In [1]:
import numpy as np
import pandas as pd
import datetime
import ast
from clickbait import *

### Search for Clickbait Sessions

In [2]:
# Where to look for clickbait datasets
data_dir = f'A:/Clickbait/'

# Get session and filenames for directories with .avi files greater than 1gb
datasets, sessions, files = scan_dataset(data_dir, min_size_bytes=1e9, filetype='.avi')

In [3]:
print(f"Located {len(datasets)} sessions.")
for ii in range(len(datasets)):
    print(f"{datasets[ii]}/{sessions[ii]}/{files[ii]}")


Located 26 sessions.
1003/diet1/11112024_1003_diet1
1003/diet2/11122024_1003_diet2
1003/diet3/11132024_1003_diet3
1003/diet4/11142024_1003_diet4
1003/full1/11152024_1003_full1
1003/full2/11182024_1003_full2
1003/full3/11192024_1003_full3
1003/full4/11212024_1003_full4
1003/full5/11222024_1003_full5
1003/full6/11252024_1003_full6
1003/full7/11272024_1003_full7
1003/full8/11282024_1003_full8
1003/full9/11282024_1003_full9
1006/diet1/11112024_1006_diet1
1006/diet2/11122024_1006_diet2
1006/diet3/11132024_1006_diet3
1006/diet4/11142024_1006_diet4
1006/full1/11152024_1006_full1
1006/full2/11182024_1006_full2
1006/full3/11192024_1006_full3
1006/full4/11212024_1006_full4
1006/full5/11222024_1006_full5
1006/full6/11252024_1006_full6
1006/full7/11272024_1006_full7
1006/full8/11282024_1006_full8
1006/full9/11282024_1006_full9


### Load data for a single session.

In [4]:
# Set session index
idx = 0

# Set path and filename prefix
data_path = f"{data_dir}{datasets[idx]}/{sessions[idx]}/{files[idx]}"

# Get video filename
video_filename = f"{data_path}.avi"

# Load video timestamps
video_ts = pd.read_csv(f"{data_path}_video_timestamp.csv")
video_ts.columns = ['timestamp']

# Load events .csv part A
col_names_a = ['trial_number', 'timestamp', 'poke_left', 'poke_right', 'centroid_x', 'centroid_y', 'target_cell']
event_data_a = pd.read_csv(f"{data_path}_eventsA.csv")
event_data_a.columns = col_names_a
pd.to_datetime(event_data_a['timestamp'])

# Load events .csv part B
col_names_b = ['iti', 'reward_state', 'water_left', 'water_right', 'click']
event_data_b = pd.read_csv(f"{data_path}_eventsB.csv")
event_data_b.columns = col_names_b

# Concatenate eventsA abd eventsB dataframes
if len(event_data_a) == len(event_data_b):
    event_data = pd.concat([event_data_a, event_data_b], axis=1)
else:
    print("Event dataframes must contain same number of rows")
    min_length = min(len(event_data_a), len(event_data_b))
    max_length = max(len(event_data_a), len(event_data_b))
    print(f"Trimmed long dataframe by {max_length-min_length} rows.")
    event_data_a = event_data_a.iloc[:min_length]
    event_data_b = event_data_b.iloc[:min_length]
    event_data = pd.concat([event_data_a, event_data_b], axis=1)

#Set types for each column in the dataframes
video_ts = video_ts.astype({'timestamp': 'datetime64[ns]'})
event_data = event_data.astype({
    'trial_number': 'uint8',
    'timestamp': 'datetime64[ns]',
    'poke_left': 'bool',
    'poke_right': 'bool',
    'centroid_x': 'uint16',
    'centroid_y': 'uint16',
    'target_cell': 'str',
    'iti': 'bool',
    'water_left': 'bool',
    'water_right': 'bool',
    'reward_state': 'bool',
    'click': 'bool'})

# Convert string representations of lists to actual lists
event_data['target_cell'] = event_data['target_cell'].apply(ast.literal_eval)

# Check Lengths of video and events dataframe
print(f"Mouse: {datasets[idx]} Session: {sessions[idx]}")
print(f"Video length: {len(video_ts)} frames")
print(f"Events Data Length: {len(event_data)} rows")

Mouse: 1003 Session: diet1
Video length: 138106 frames
Events Data Length: 296855 rows


### Resample to synch event and video timestamps

In [5]:
if len(video_ts) < len(event_data):
    # Synchronize events data with video, using nearest matches in the 'timestamp' column 
    video_ts['timestamp'] = pd.to_datetime(video_ts['timestamp'])
    event_data['timestamp'] = pd.to_datetime(event_data['timestamp'])

    # Set 'timestamp' as the index of each dataframe
    video_ts = video_ts.set_index('timestamp')
    event_data = event_data.set_index('timestamp')

    # Map event_data onto the video_data timestamps, using nearest matches between timestamps
    event_data = event_data.reindex(video_ts.index, method='nearest')
    video_ts = video_ts.reset_index()
    event_data = event_data.reset_index()

    # Check that event dataframe is the same length as video frames
    print(f"Event data resampled to match video length:")
    print(f"Video length: {len(video_ts)} frames")
    print(f"Events Data Length: {len(event_data)} rows")

elif len(video_ts) > len(event_data):
    # If video is longer than event data, slice excess timestamps out of the beginning of timestamp list
    video_ts, frame_idx = slice_video_timestamp(video_ts, event_data)
    
    # Check that event dataframe is the same length as video frames
    print(f"Video timestamps trimmed by {frame_idx + 1} to match event data length:")
    print(f"Video length: {len(video_ts)} frames")
    print(f"Events Data Length: {len(event_data)} rows")

print(f"Video length at 50.6 FPS: {len(video_ts)/50.6/60:.2f} minutes")

Event data resampled to match video length:
Video length: 138106 frames
Events Data Length: 138106 rows
Video length at 50.6 FPS: 45.49 minutes


### Calculate additional columns (add to library later)

In [6]:
# Add a column that is tracking distance between conssecutive centroids
event_data['distance'] = np.sqrt(
    (event_data['centroid_x'] - event_data['centroid_x'].shift(1))**2 + 
    (event_data['centroid_y'] - event_data['centroid_y'].shift(1))**2)

# Add a column that is tracking time between timestamps in milliseconds
event_data['frame_ms'] = event_data['timestamp'].diff().dt.total_seconds() * 1000

# Add a column that tags gaps if distance exceeds a threshold
gap_tresh = 100
event_data['gap'] = (event_data['distance'] >= gap_tresh).astype(np.uint8)

### Check for frame drops

In [7]:
gaps_total = []

for ii in range(np.max(event_data['trial_number'])):
    trial = event_data[event_data['trial_number'] == ii]
    gaps_trial = int(np.sum(trial['gap']))
    gaps_total.append(gaps_trial)

clean_trials = [x for x in gaps_total if x == 0]
gap_trials = [x for x in gaps_total if x > 0]

# Check if gaps are distributed toward the beginning or end of session
gap_idx = []
for ii, x in enumerate(gaps_total):
    if x > 0:
        gap_idx.append(ii) 

print(f"Trials with gaps: {len(gap_trials)}")
print(f"Trials without gaps: {len(clean_trials)}")
print(f"Maximum gaps in a single trial: {np.max(gaps_total)}")
print(f"Mean index of gap trials: {np.mean(gap_idx):.0f}")

Trials with gaps: 45
Trials without gaps: 27
Maximum gaps in a single trial: 23
Mean index of gap trials: 35


### Preview Events Dataframe

In [8]:
event_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
0,2024-11-11 12:55:21.913996800,0,False,False,676,1596,[9],False,False,False,False,False,,,0
1,2024-11-11 12:55:21.927244800,0,False,False,676,1596,[9],False,False,False,False,False,0.000000,13.2480,0
2,2024-11-11 12:55:21.945113600,0,False,False,676,1596,[9],False,False,False,False,False,0.000000,17.8688,0
3,2024-11-11 12:55:21.959296000,0,False,False,676,1596,[9],False,False,False,False,False,0.000000,14.1824,0
4,2024-11-11 12:55:21.972569600,0,False,False,676,1596,[9],False,False,False,False,False,0.000000,13.2736,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
138101,2024-11-11 13:40:50.363366400,72,False,False,730,1240,[9],False,False,False,False,False,1.000000,20.6976,0
138102,2024-11-11 13:40:50.383347200,72,False,False,730,1240,[9],False,False,False,False,False,0.000000,19.9808,0
138103,2024-11-11 13:40:50.455564800,72,False,False,730,1239,[9],False,False,False,False,False,1.000000,72.2176,0
138104,2024-11-11 13:40:50.522342400,72,False,False,730,1235,[9],False,False,False,False,False,4.000000,66.7776,0


In [9]:
# Check to see what the largest distance between observations is, and get the index
max_distance_idx = event_data['distance'].idxmax()
print(np.max(event_data['distance']))
event_data.loc[max_distance_idx-1:max_distance_idx+1]

1604.4500615475697


Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
115690,2024-11-11 13:33:27.631129600,62,False,False,739,452,[],False,True,False,False,False,1.0,22.6048,0
115691,2024-11-11 13:33:27.650316800,62,False,False,87,1918,[],False,True,False,False,False,1604.450062,19.1872,1
115692,2024-11-11 13:33:27.669350400,62,False,False,87,1918,[],False,True,False,False,False,0.0,19.0336,0


### Visualize Trajectories

In [10]:
# Extract list of trials from session
trials_list = event_data['trial_number'].unique().tolist()
print(trials_list)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72]


In [11]:
visualize_trial_trajectory(event_data, trial_number=[trials_list[-1]], color_code="frame_number", target_frame=True, opacity=.5)


Downcasting object dtype arrays on .fillna, .ffill, .bfill is deprecated and will change in a future version. Call result.infer_objects(copy=False) instead. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`



Index([], dtype='int64')


In [12]:
trial_data = event_data[event_data['trial_number'] == trials_list[-1]]
trial_data = trial_data.reset_index()
trial_data[20:30]

Unnamed: 0,index,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
20,134927,2024-11-11 13:39:47.679846400,72,False,False,266,720,[9],False,False,False,False,False,2.0,18.3552,0
21,134928,2024-11-11 13:39:47.700121600,72,False,False,267,718,[9],False,False,False,False,False,2.236068,20.2752,0
22,134929,2024-11-11 13:39:47.720460800,72,False,False,266,714,[9],False,False,False,False,False,4.123106,20.3392,0
23,134930,2024-11-11 13:39:47.740198400,72,False,False,266,713,[9],False,False,False,False,False,1.0,19.7376,0
24,134931,2024-11-11 13:39:47.760384000,72,False,False,269,714,[9],False,False,False,False,False,3.162278,20.1856,0
25,134932,2024-11-11 13:39:47.783936000,72,False,False,272,716,[9],False,False,False,False,False,3.605551,23.552,0
26,134933,2024-11-11 13:39:47.801049600,72,False,False,275,717,[9],False,False,False,False,False,3.162278,17.1136,0
27,134934,2024-11-11 13:39:47.823449600,72,False,False,276,720,[9],False,False,False,False,False,3.162278,22.4,0
28,134935,2024-11-11 13:39:47.846988800,72,False,False,277,723,[9],False,False,False,False,False,3.162278,23.5392,0
29,134936,2024-11-11 13:39:47.866060800,72,False,False,277,727,[9],False,False,False,False,False,4.0,19.072,0


In [13]:
max_distance_idx = trial_data['distance'].idxmax()
trial_data.loc[max_distance_idx-1:max_distance_idx+1]

Unnamed: 0,index,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
351,135258,2024-11-11 13:39:54.214233600,72,False,False,180,163,[9],False,False,False,False,False,9.848858,25.7664,0
352,135259,2024-11-11 13:39:54.231129600,72,False,False,197,250,[9],False,False,False,False,False,88.645361,16.896,0
353,135260,2024-11-11 13:39:54.251008000,72,False,False,205,250,[9],False,False,False,False,False,8.0,19.8784,0


In [14]:
# Set video dimensions in pixels
scale = 2 # Reduction factor
dim_x = 894//scale
dim_y = 1952//scale

# Create GidMaze object
grid = GridMaze((dim_y,dim_x),(9,4), border=True)

In [15]:
# Blank white canvas to draw on
canvas = np.ones(grid.shape)
# Draw grid in black
grid_img = grid.draw_grid(canvas, color=(0,0,0)).astype(np.uint8)
# Convert grayscale image to BGR
grid_img = cv2.cvtColor(grid_img, cv2.COLOR_GRAY2RGB)

In [16]:
test_data = event_data.loc[event_data['trial_number'].isin([12])].copy()
test_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
25018,2024-11-11 13:03:36.145625600,12,False,False,616,761,[10],False,False,False,False,False,3.000000,18.0480,0
25019,2024-11-11 13:03:36.163852800,12,False,False,615,757,[10],False,False,False,False,False,4.123106,18.2272,0
25020,2024-11-11 13:03:36.183283200,12,False,False,614,753,[10],False,False,False,False,False,4.123106,19.4304,0
25021,2024-11-11 13:03:36.203750400,12,False,False,612,748,[10],False,False,False,False,False,5.385165,20.4672,0
25022,2024-11-11 13:03:36.224012800,12,False,False,611,743,[10],False,False,False,False,False,5.099020,20.2624,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29743,2024-11-11 13:05:09.487347200,12,False,False,150,1055,[],True,False,False,False,False,2.828427,20.5312,0
29744,2024-11-11 13:05:09.507929600,12,False,False,151,1057,[],True,False,False,False,False,2.236068,20.5824,0
29745,2024-11-11 13:05:09.528806400,12,False,False,152,1057,[],True,False,False,False,False,1.000000,20.8768,0
29746,2024-11-11 13:05:09.549286400,12,False,False,153,1056,[],True,False,False,False,False,1.414214,20.4800,0


In [17]:
# test_data.index[0]

In [18]:
video_filename

'A:/Clickbait/1003/diet1/11112024_1003_diet1.avi'

In [19]:
single_trial = False
display = True
write = False
loop = False

if single_trial:
    test_data = event_data.loc[event_data['trial_number'].isin([20])].copy()
    ii = test_data.index[0]

else:
    test_data = event_data.copy()
    ii = 0  # Start from 0

# Load video
video = cv2.VideoCapture(video_filename)
video.set(cv2.CAP_PROP_POS_FRAMES, ii)

# Display video with event overlay
while True:
    # Load video frame

    ret, frame = video.read()
    frame = cv2.resize(frame, (frame.shape[1]//scale, frame.shape[0]//scale)).astype(np.uint8)
    
    current_frame = int(video.get(cv2.CAP_PROP_POS_FRAMES))

    # Draw Grid
    frame = grid.draw_grid(frame, color=(0,0,0), opacity=.25).astype(np.uint8)

    # Draw target
    target = grid.get_target_cell(test_data['target_cell'][ii])
    frame = cv2.rectangle(frame, target[0], target[1], (0,0,0), -1)

    # Get mouse centroid
    pt_x = test_data['centroid_x'][ii]//scale
    pt_y = test_data['centroid_y'][ii]//scale

    # Get mouse cell and draw
    cell_i, cell_j = grid.get_mouse_cell(pt_x, pt_y)
    frame = grid.draw_cell(frame, cell_i, cell_j, (128,128,255), -1, opacity=.25)

    # Draw mouse centroid
    frame = cv2.circle(frame, (pt_x, pt_y), 5, (255,255,255), -1, cv2.LINE_AA)

    # Set state color for trial number
    if test_data['iti'][ii] == True:
        state_color = (255,128,128)
    elif test_data['reward_state'][ii] == True:
        state_color = (128,255,128)
    else:
        state_color = (128,128,255)

    # Draw trial number in state color
    frame = cv2.putText(frame, str(test_data['trial_number'][ii]), (20,40), cv2.FONT_HERSHEY_SIMPLEX, 1, state_color, 1, cv2.LINE_AA)
    # Draw video frame number
    frame_count = f"Video frame: {current_frame}, DataFrame index: {ii}"
    frame = cv2.putText(frame, frame_count, (20,60), cv2.FONT_HERSHEY_SIMPLEX, .5, (255,255,255), 1, cv2.LINE_AA)
    # Draw session filename frame number
    frame = cv2.putText(frame, video_filename[29:], (20,dim_y-30), cv2.FONT_HERSHEY_SIMPLEX, .5, (255,255,255), 1, cv2.LINE_AA)


    # Display frame
    if display:
        cv2.imshow("clickbait", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    if write:
        cv2.imwrite(f"output/frame_{ii}.jpg", frame)

    # Increase counter
    ii += 1

    # Loop counter if needed
    if loop:
        if ii >= test_data.index[-1]:
            ii = test_data.index[0]
            video.set(cv2.CAP_PROP_POS_FRAMES, ii) 
    else:
        if ii >= test_data.index[-1]:
            break

cv2.destroyAllWindows()