In [33]:
import numpy as np
import pandas as pd
import datetime
import ast
from clickbait import *

### Search for Clickbait Sessions

In [34]:
# Where to look for clickbait datasets
data_dir = f'A:/Clickbait/'

# Get session and filenames for directories with .avi files greater than 1gb
datasets, sessions, files = scan_dataset(data_dir, min_size_bytes=1e9, filetype='.avi')

In [35]:
print(f"Located {len(datasets)} sessions.")
for ii in range(len(datasets)):
    print(f"{datasets[ii]}/{sessions[ii]}/{files[ii]}")


Located 51 sessions.
1003/diet1/11112024_1003_diet1
1003/diet2/11122024_1003_diet2
1003/diet3/11132024_1003_diet3
1003/diet4/11142024_1003_diet4
1003/full1/11152024_1003_full1
1003/full2/11182024_1003_full2
1003/full3/11192024_1003_full3
1003/full4/11212024_1003_full4
1003/full5/11222024_1003_full5
1003/full6/11252024_1003_full6
1003/full7/11272024_1003_full7
1003/full8/11282024_1003_full8
1003/full9/11282024_1003_full9
1004/diet1/11112024_1004_diet1
1004/diet2/11122024_1004_diet2
1004/diet3/11132024_1004_diet3
1004/diet4/11142024_1004_diet4
1004/diet5/11152024_1004_diet5
1004/diet6/11182024_1004_diet6
1004/diet7/11192024_1004_diet7
1004/full1/11212024_1004_full1
1004/full2/11222024_1004_full2
1004/full3/11252024_1004_full3
1004/full4/11272024_1004_full4
1004/full5/11282024_1004_full5
1004/full6/11282024_1004_full6
1005/diet1/11112024_1005_diet1
1005/diet2/11122024_1005_diet2
1005/diet3/11132024_1005_diet3
1005/diet4/11142024_1005_diet4
1005/diet5/11152024_1005_diet5
1005/diet6/1118202

### Load data for a single session.

In [36]:
# Set session index
idx = 31

# Set path and filename prefix
data_path = f"{data_dir}{datasets[idx]}/{sessions[idx]}/{files[idx]}"

# Get video filename
video_filename = f"{data_path}.avi"

# Load video timestamps
video_ts = pd.read_csv(f"{data_path}_video_timestamp.csv")
video_ts.columns = ['timestamp']

# Load events .csv part A
col_names_a = ['trial_number', 'timestamp', 'poke_left', 'poke_right', 'centroid_x', 'centroid_y', 'target_cell']
event_data_a = pd.read_csv(f"{data_path}_eventsA.csv")
event_data_a.columns = col_names_a
pd.to_datetime(event_data_a['timestamp'])

# Load events .csv part B
col_names_b = ['iti', 'reward_state', 'water_left', 'water_right', 'click']
event_data_b = pd.read_csv(f"{data_path}_eventsB.csv")
event_data_b.columns = col_names_b

# Concatenate eventsA abd eventsB dataframes
if len(event_data_a) == len(event_data_b):
    event_data = pd.concat([event_data_a, event_data_b], axis=1)
else:
    print("Event dataframes must contain same number of rows")
    min_length = min(len(event_data_a), len(event_data_b))
    max_length = max(len(event_data_a), len(event_data_b))
    print(f"Trimmed long dataframe by {max_length-min_length} rows.")
    event_data_a = event_data_a.iloc[:min_length]
    event_data_b = event_data_b.iloc[:min_length]
    event_data = pd.concat([event_data_a, event_data_b], axis=1)

#Set types for each column in the dataframes
video_ts = video_ts.astype({'timestamp': 'datetime64[ns]'})
event_data = event_data.astype({
    'trial_number': 'uint8',
    'timestamp': 'datetime64[ns]',
    'poke_left': 'bool',
    'poke_right': 'bool',
    'centroid_x': 'uint16',
    'centroid_y': 'uint16',
    'target_cell': 'str',
    'iti': 'bool',
    'water_left': 'bool',
    'water_right': 'bool',
    'reward_state': 'bool',
    'click': 'bool'})

# Convert string representations of lists to actual lists
event_data['target_cell'] = event_data['target_cell'].apply(ast.literal_eval)

# Check Lengths of video and events dataframe
print(f"Mouse: {datasets[idx]} Session: {sessions[idx]}")
print(f"Video length: {len(video_ts)} frames")
print(f"Events Data Length: {len(event_data)} rows")

Mouse: 1005 Session: diet6
Video length: 122475 frames
Events Data Length: 276609 rows


### Resample to synch event and video timestamps

In [37]:
if len(video_ts) < len(event_data):
    # Synchronize events data with video, using nearest matches in the 'timestamp' column 
    video_ts['timestamp'] = pd.to_datetime(video_ts['timestamp'])
    event_data['timestamp'] = pd.to_datetime(event_data['timestamp'])

    # Set 'timestamp' as the index of each dataframe
    video_ts = video_ts.set_index('timestamp')
    event_data = event_data.set_index('timestamp')

    # Map event_data onto the video_data timestamps, using nearest matches between timestamps
    event_data = event_data.reindex(video_ts.index, method='nearest')
    video_ts = video_ts.reset_index()
    event_data = event_data.reset_index()

    # Check that event dataframe is the same length as video frames
    print(f"Event data resampled to match video length:")
    print(f"Video length: {len(video_ts)} frames")
    print(f"Events Data Length: {len(event_data)} rows")

elif len(video_ts) > len(event_data):
    # If video is longer than event data, slice excess timestamps out of the beginning of timestamp list
    video_ts, frame_idx = slice_video_timestamp(video_ts, event_data)
    
    # Check that event dataframe is the same length as video frames
    print(f"Video timestamps trimmed by {frame_idx + 1} to match event data length:")
    print(f"Video length: {len(video_ts)} frames")
    print(f"Events Data Length: {len(event_data)} rows")

print(f"Video length at 50.6 FPS: {len(video_ts)/50.6/60:.2f} minutes")

Event data resampled to match video length:
Video length: 122475 frames
Events Data Length: 122475 rows
Video length at 50.6 FPS: 40.34 minutes


### Calculate additional columns (add to library later)

In [38]:
# Add a column that is tracking distance between conssecutive centroids
event_data['distance'] = np.sqrt(
    (event_data['centroid_x'] - event_data['centroid_x'].shift(1))**2 + 
    (event_data['centroid_y'] - event_data['centroid_y'].shift(1))**2)

# Add a column that is tracking time between timestamps in milliseconds
event_data['frame_ms'] = event_data['timestamp'].diff().dt.total_seconds() * 1000

# Add a column that tags gaps if distance exceeds a threshold
gap_tresh = 100
event_data['gap'] = (event_data['distance'] >= gap_tresh).astype(np.uint8)

### Check for frame drops

In [39]:
gaps_total = []

for ii in range(np.max(event_data['trial_number'])):
    trial = event_data[event_data['trial_number'] == ii]
    gaps_trial = int(np.sum(trial['gap']))
    gaps_total.append(gaps_trial)

clean_trials = [x for x in gaps_total if x == 0]
gap_trials = [x for x in gaps_total if x > 0]

# Check if gaps are distributed toward the beginning or end of session
gap_idx = []
for ii, x in enumerate(gaps_total):
    if x > 0:
        gap_idx.append(ii) 

print(f"Trials with gaps: {len(gap_trials)}")
print(f"Trials without gaps: {len(clean_trials)}")
print(f"Maximum gaps in a single trial: {np.max(gaps_total)}")
print(f"Mean index of gap trials: {np.mean(gap_idx):.0f}")

Trials with gaps: 107
Trials without gaps: 9
Maximum gaps in a single trial: 13
Mean index of gap trials: 55


### Preview Events Dataframe

In [40]:
event_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
0,2024-11-18 12:27:01.515251200,0,False,False,737,1742,[0],False,False,False,False,False,,,0
1,2024-11-18 12:27:01.531225600,0,False,False,737,1742,[0],False,False,False,False,False,0.000000,15.9744,0
2,2024-11-18 12:27:01.554060800,0,False,False,737,1742,[0],False,False,False,False,False,0.000000,22.8352,0
3,2024-11-18 12:27:01.570547200,0,False,False,737,1742,[0],False,False,False,False,False,0.000000,16.4864,0
4,2024-11-18 12:27:01.589452800,0,False,False,737,1742,[0],False,False,False,False,False,0.000000,18.9056,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
122470,2024-11-18 13:12:30.049164800,116,False,False,657,993,[],False,True,False,False,False,2.236068,19.4816,0
122471,2024-11-18 13:12:30.074419200,116,False,False,659,990,[],False,True,False,False,False,3.605551,25.2544,0
122472,2024-11-18 13:12:30.092096000,116,False,False,662,985,[],False,True,False,False,False,5.830952,17.6768,0
122473,2024-11-18 13:12:30.115468800,116,False,False,663,980,[],False,True,False,False,False,5.099020,23.3728,0


In [41]:
# Check to see what the largest distance between observations is, and get the index
max_distance_idx = event_data['distance'].idxmax()
print(np.max(event_data['distance']))
event_data.loc[max_distance_idx-1:max_distance_idx+1]

355.36038045904894


Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
81186,2024-11-18 12:57:16.485376000,92,False,False,449,1482,[13],False,False,False,False,False,25.632011,20.4544,0
81187,2024-11-18 12:57:16.534822400,92,False,False,549,1141,[13],False,False,False,False,False,355.36038,49.4464,1
81188,2024-11-18 12:57:16.553638400,92,False,False,561,1125,[13],False,False,False,False,False,20.0,18.816,0


### Visualize Trajectories

In [42]:
# Extract list of trials from session
trials_list = event_data['trial_number'].unique().tolist()
print(trials_list)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116]


In [43]:
visualize_trial_trajectory(event_data, trial_number=[trials_list[-1]], color_code="frame_number", target_frame=True, opacity=.5)


Downcasting object dtype arrays on .fillna, .ffill, .bfill is deprecated and will change in a future version. Call result.infer_objects(copy=False) instead. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`



In [44]:
trial_data = event_data[event_data['trial_number'] == trials_list[-1]]
trial_data = trial_data.reset_index()
trial_data[20:30]

Unnamed: 0,index,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
20,113292,2024-11-18 13:09:05.076966400,116,False,False,182,1052,[8],False,False,False,False,False,1.0,22.1056,0
21,113293,2024-11-18 13:09:05.100019200,116,False,False,182,1052,[8],False,False,False,False,False,0.0,23.0528,0
22,113294,2024-11-18 13:09:05.122291200,116,False,False,182,1052,[8],False,False,False,False,False,0.0,22.272,0
23,113295,2024-11-18 13:09:05.140416000,116,False,False,183,1052,[8],False,False,False,False,False,1.0,18.1248,0
24,113296,2024-11-18 13:09:05.164121600,116,False,False,183,1052,[8],False,False,False,False,False,0.0,23.7056,0
25,113297,2024-11-18 13:09:05.192678400,116,False,False,183,1052,[8],False,False,False,False,False,0.0,28.5568,0
26,113298,2024-11-18 13:09:05.216691200,116,False,False,182,1051,[8],False,False,False,False,False,1.414214,24.0128,0
27,113299,2024-11-18 13:09:05.233945600,116,False,False,182,1052,[8],False,False,False,False,False,1.0,17.2544,0
28,113300,2024-11-18 13:09:05.256844800,116,False,False,181,1052,[8],False,False,False,False,False,1.0,22.8992,0
29,113301,2024-11-18 13:09:05.281510400,116,False,False,180,1052,[8],False,False,False,False,False,1.0,24.6656,0


In [45]:
max_distance_idx = trial_data['distance'].idxmax()
trial_data.loc[max_distance_idx-1:max_distance_idx+1]

Unnamed: 0,index,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
1801,115073,2024-11-18 13:09:44.867353600,116,False,False,578,1516,[8],False,False,False,False,False,18.027756,19.264,0
1802,115074,2024-11-18 13:09:44.911168000,116,False,False,670,1393,[8],False,False,False,False,False,153.60013,43.8144,1
1803,115075,2024-11-18 13:09:44.936320000,116,False,False,675,1381,[8],False,False,False,False,False,13.0,25.152,0


In [46]:
# Set video dimensions in pixels
scale = 2 # Reduction factor
dim_x = 894//scale
dim_y = 1952//scale

# Create GidMaze object
grid = GridMaze((dim_y,dim_x),(9,4), border=True)

In [47]:
# Blank white canvas to draw on
canvas = np.ones(grid.shape)
# Draw grid in black
grid_img = grid.draw_grid(canvas, color=(0,0,0)).astype(np.uint8)
# Convert grayscale image to BGR
grid_img = cv2.cvtColor(grid_img, cv2.COLOR_GRAY2RGB)

In [48]:
test_data = event_data.loc[event_data['trial_number'].isin([12])].copy()
test_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click,distance,frame_ms,gap
12098,2024-11-18 12:31:34.161817600,12,False,False,164,1063,[0],False,False,False,False,False,4.000000,22.9504,0
12099,2024-11-18 12:31:34.183513600,12,False,False,170,1065,[0],False,False,False,False,False,6.324555,21.6960,0
12100,2024-11-18 12:31:34.204300800,12,False,False,173,1067,[0],False,False,False,False,False,3.605551,20.7872,0
12101,2024-11-18 12:31:34.225241600,12,False,False,175,1068,[0],False,False,False,False,False,2.236068,20.9408,0
12102,2024-11-18 12:31:34.245337600,12,False,False,175,1069,[0],False,False,False,False,False,1.000000,20.0960,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12399,2024-11-18 12:31:40.942028800,12,False,False,242,427,[],True,False,False,False,False,10.630146,21.5040,0
12400,2024-11-18 12:31:40.963648000,12,False,False,251,423,[],True,False,False,False,False,9.848858,21.6192,0
12401,2024-11-18 12:31:40.985280000,12,False,False,260,421,[],True,False,False,False,False,9.219544,21.6320,0
12402,2024-11-18 12:31:41.009177600,12,False,False,269,420,[],True,False,False,False,False,9.055385,23.8976,0


In [49]:
# test_data.index[0]

In [50]:
video_filename

'A:/Clickbait/1005/diet6/11182024_1005_diet6.avi'

In [51]:
single_trial = False
display = True
write = False
loop = False

if single_trial:
    test_data = event_data.loc[event_data['trial_number'].isin([20])].copy()
    ii = test_data.index[0]

else:
    test_data = event_data.copy()
    ii = 0  # Start from 0

# Load video
video = cv2.VideoCapture(video_filename)
video.set(cv2.CAP_PROP_POS_FRAMES, ii)

# Display video with event overlay
while True:
    # Load video frame
    ret, frame = video.read()
    frame = cv2.resize(frame, (frame.shape[1]//scale, frame.shape[0]//scale)).astype(np.uint8)
    current_frame = int(video.get(cv2.CAP_PROP_POS_FRAMES))

    # Draw Grid
    frame = grid.draw_grid(frame, color=(0,0,0), opacity=.75).astype(np.uint8)

    # Draw target
    target = grid.get_target_cell(test_data['target_cell'][ii])
    frame = cv2.rectangle(frame, target[0], target[1], (0,0,0), -1)

    # Get mouse centroid
    pt_x = test_data['centroid_x'][ii]//scale
    pt_y = test_data['centroid_y'][ii]//scale

    # Get mouse cell and draw
    cell_i, cell_j = grid.get_mouse_cell(pt_x, pt_y)
    frame = grid.draw_cell(frame, cell_i, cell_j, (128,128,255), -1, opacity=.25)

    # Draw mouse centroid
    frame = cv2.circle(frame, (pt_x, pt_y), 5, (255,255,255), -1, cv2.LINE_AA)

    # Set state color for trial number
    if test_data['iti'][ii] == True:
        state_color = (255,128,128)
    elif test_data['reward_state'][ii] == True:
        state_color = (128,255,128)
    else:
        state_color = (128,128,255)

    # Rotate frame 90 degrees
    frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
    #frame = cv2.rotate(frame, cv2.ROTATE_180)
    
    ## Draw trial number in state color
    frame = cv2.putText(frame, str(test_data['trial_number'][ii]+1), (20,55), cv2.FONT_HERSHEY_SIMPLEX, 1.5, state_color, 2, cv2.LINE_AA)
    ## Draw video frame number
    #frame_count = f"Video frame: {current_frame}, DataFrame index: {ii}"
    #frame = cv2.putText(frame, frame_count, (20,60), cv2.FONT_HERSHEY_SIMPLEX, .5, (255,255,255), 1, cv2.LINE_AA)

    # Display frame
    if display:
        cv2.imshow("clickbait", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    if write:
        cv2.imwrite(f"output/frame_{ii}.png", frame)

    # Increase counter
    ii += 1

    # Loop counter if needed
    if loop:
        if ii >= test_data.index[-1]:
            ii = test_data.index[0]
            video.set(cv2.CAP_PROP_POS_FRAMES, ii) 
    else:
        if ii >= test_data.index[-1]:
            break

cv2.destroyAllWindows()

In [52]:
reward_state = event_data[event_data['reward_state'] == True]
search_state = event_data[(event_data['reward_state'] == False) & (test_data['poke_left'] | test_data['poke_right'] == False)]

In [53]:
print(reward_state['distance'].mean())
print(search_state['distance'].mean())

7.835663274190006
6.982436548268861
