In [1]:
import numpy as np
import pandas as pd
import datetime
import ast
from clickbait import *

### Search for Clickbait Sessions

In [5]:
# Root folder that contains the clickbait datasets
data_dir = 'A:/Clickbait/'

# Find sessions whose .avi recording passes the size cut (min_size_bytes=1e9, i.e. ~1 GB).
# The three returned lists are parallel: entry k of each describes the same session.
datasets, sessions, files = scan_dataset(
    data_dir,
    min_size_bytes=1e9,
    filetype='.avi',
)

### Load data for a single session.

In [6]:
# Choose which scanned session to load (position in the scan_dataset results)
idx = 48

# Path prefix shared by every file belonging to this session
data_path = f"{data_dir}{datasets[idx]}/{sessions[idx]}/{files[idx]}"
video_filename = data_path + ".avi"

# Video timestamps, exposed under a single 'timestamp' column.
# NOTE(review): read_csv consumes the first line as a header; if the timestamp file has no
# header line, its first timestamp is silently dropped here — confirm the file format.
video_ts = pd.read_csv(data_path + "_video_timestamp.csv").set_axis(['timestamp'], axis='columns')

# Load events .csv part A (the file's first line is read as a header and then replaced
# by these column names).
# NOTE(review): if the logger writes no header line, pass header=None or the first event is lost.
col_names_a = ['trial_number', 'timestamp', 'poke_left', 'poke_right', 'centroid_x', 'centroid_y', 'target_cell']
event_data_a = pd.read_csv(f"{data_path}_eventsA.csv")
event_data_a.columns = col_names_a
# Timestamps stay as strings here; they are parsed by the 'datetime64[ns]' dtype cast that
# follows the concatenation.

# Load events .csv part B
col_names_b = ['iti', 'reward_state', 'water_left', 'water_right', 'click']
event_data_b = pd.read_csv(f"{data_path}_eventsB.csv")
event_data_b.columns = col_names_b

# Concatenate eventsA and eventsB side by side. Rows are paired by position (both frames
# carry the default 0..n-1 index), so if one log has extra rows, drop the trailing rows
# of the longer frame before joining.
if len(event_data_a) != len(event_data_b):
    min_length = min(len(event_data_a), len(event_data_b))
    max_length = max(len(event_data_a), len(event_data_b))
    print(f"Event dataframes have different row counts (A: {len(event_data_a)}, B: {len(event_data_b)}).")
    print(f"Trimmed long dataframe by {max_length-min_length} rows.")
    event_data_a = event_data_a.iloc[:min_length]
    event_data_b = event_data_b.iloc[:min_length]
event_data = pd.concat([event_data_a, event_data_b], axis=1)

# Guard the narrow unsigned-integer casts below: integer astype can wrap out-of-range
# values around silently (e.g. trial 256 would become 0 as uint8) instead of raising.
for _col, _dtype in (('trial_number', 'uint8'), ('centroid_x', 'uint16'), ('centroid_y', 'uint16')):
    _limits = np.iinfo(_dtype)
    _lo, _hi = event_data[_col].min(), event_data[_col].max()
    if _lo < _limits.min or _hi > _limits.max:
        raise ValueError(
            f"Column '{_col}' spans [{_lo}, {_hi}], which does not fit in {_dtype} "
            f"[{_limits.min}, {_limits.max}]; use a wider dtype.")

# Set types for each column in the dataframe
event_data = event_data.astype({
    'trial_number': 'uint8',
    'timestamp': 'datetime64[ns]',
    'poke_left': 'bool',
    'poke_right': 'bool',
    'centroid_x': 'uint16',
    'centroid_y': 'uint16',
    'target_cell': 'str',
    'iti': 'bool',
    'water_left': 'bool',
    'water_right': 'bool',
    'reward_state': 'bool',
    'click': 'bool'})

# Convert string representations of lists (e.g. "[7]") to actual lists.
# ast.literal_eval only evaluates Python literals, so it cannot execute code from the file.
event_data['target_cell'] = event_data['target_cell'].apply(ast.literal_eval)

# Sanity check: video timestamps and event rows come from separate files, so their counts can differ
print(f"Video length: {len(video_ts)} frames\nEvents Data Length: {len(event_data)} rows")

Video length: 138110 frames
Events Data Length: 138083 rows


### Resample to sync event and video timestamps

In [7]:
# Synchronize events data with video, using nearest matches in the 'timestamp' column
video_ts['timestamp'] = pd.to_datetime(video_ts['timestamp'])
event_data['timestamp'] = pd.to_datetime(event_data['timestamp'])

# Set 'timestamp' as the index of each dataframe
video_ts = video_ts.set_index('timestamp')
event_data = event_data.set_index('timestamp')

# reindex(method='nearest') raises unless the source index is unique and monotonic.
# Sort the events stably and keep only the first row of any repeated timestamp, so
# out-of-order or duplicated log rows don't abort the synchronization.
event_data = event_data.sort_index(kind='stable')
duplicated_ts = event_data.index.duplicated(keep='first')
if duplicated_ts.any():
    print(f"Dropped {int(duplicated_ts.sum())} event rows with duplicated timestamps.")
    event_data = event_data.loc[~duplicated_ts]

# Map event_data onto the video timestamps, taking the nearest event row for each frame.
# No tolerance is set, so frames before the first / after the last event reuse the
# first / last event row. Afterwards event_data['timestamp'] holds the video timestamps,
# and row k of event_data lines up with row k of video_ts.
event_data = event_data.reindex(video_ts.index, method='nearest')
video_ts = video_ts.reset_index()
event_data = event_data.reset_index()

# Check that event dataframe is the same length as video frames
print("Data resampled to match video length:")
print(f"Video length: {len(video_ts)} frames")
print(f"Events Data Length: {len(event_data)} rows")

Data resampled to match video length:
Video length: 138110 frames
Events Data Length: 138110 rows


### Preview Events Dataframe

In [8]:
# Preview the synchronized events table: after the nearest-timestamp reindex it has one row per video timestamp
event_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click
0,2024-11-27 14:45:13.180288000,0,False,False,716,168,[7],False,False,False,False,False
1,2024-11-27 14:45:13.197568000,0,False,False,716,168,[7],False,False,False,False,False
2,2024-11-27 14:45:13.213875200,0,False,False,716,168,[7],False,False,False,False,False
3,2024-11-27 14:45:13.233971200,0,False,False,716,168,[7],False,False,False,False,False
4,2024-11-27 14:45:13.253888000,0,False,False,716,168,[7],False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...
138105,2024-11-27 15:30:41.726604800,106,False,True,743,1061,[],False,False,False,False,False
138106,2024-11-27 15:30:41.743628800,106,False,True,741,1061,[],False,False,False,False,False
138107,2024-11-27 15:30:41.762726400,106,False,True,740,1061,[],False,False,False,False,False
138108,2024-11-27 15:30:41.853875200,106,False,True,740,1060,[],False,False,False,False,False


### Visualize Trajectories

In [9]:
# Distinct trial numbers in this session, in order of first appearance, as plain Python ints
trials_list = event_data['trial_number'].drop_duplicates().tolist()
print(trials_list)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106]


In [11]:
# Plot the trajectory of the trial at position 17 of trials_list
visualize_trial_trajectory(
    event_data,
    trial_number=[trials_list[17]],
    color_code="frame_number",
    target_frame=True,
    opacity=0.5,
)


Downcasting object dtype arrays on .fillna, .ffill, .bfill is deprecated and will change in a future version. Call result.infer_objects(copy=False) instead. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`



In [30]:
# Downscaling factor applied to every frame, and the video width (x) / height (y) after it
scale = 2
dim_x, dim_y = 894 // scale, 1952 // scale

# Maze grid over the downscaled frame; GridMaze takes (height, width).
# Use (9, 4) instead of (13, 6) for diet clickbait.
grid = GridMaze((dim_y, dim_x), (13, 6), border=True)

In [31]:
# Blank white canvas to draw on. White is 255 in 8-bit images; np.ones would give 1,
# i.e. a near-black canvas on which the black grid is invisible.
canvas = np.full(grid.shape, 255.0)
# Draw grid in black
grid_img = grid.draw_grid(canvas, color=(0,0,0)).astype(np.uint8)
# Replicate the single gray channel into 3 channels (identical in RGB or BGR order)
grid_img = cv2.cvtColor(grid_img, cv2.COLOR_GRAY2RGB)

In [32]:
# Independent copy of the rows belonging to trial 18, for inspection
test_data = event_data[event_data['trial_number'] == 18].copy()
test_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click
24805,2024-11-27 14:53:23.217932800,18,False,False,523,1247,[28],False,False,False,False,False
24806,2024-11-27 14:53:23.237171200,18,False,False,530,1238,[28],False,False,False,False,False
24807,2024-11-27 14:53:23.257318400,18,False,False,537,1230,[28],False,False,False,False,False
24808,2024-11-27 14:53:23.280947200,18,False,False,545,1226,[28],False,False,False,False,False
24809,2024-11-27 14:53:23.298816000,18,False,False,556,1222,[28],False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...
25555,2024-11-27 14:53:38.029145600,18,False,False,246,684,[],True,False,False,False,False
25556,2024-11-27 14:53:38.049408000,18,False,False,244,679,[],True,False,False,False,False
25557,2024-11-27 14:53:38.069299200,18,False,False,242,673,[],True,False,False,False,False
25558,2024-11-27 14:53:38.088192000,18,False,False,238,662,[],True,False,False,False,False


In [None]:
# TODO: exporting individual-trial clips is still a work in progress.
# Shows the first row label of test_data. Because event_data was reindexed onto the video
# timestamps and renumbered from 0, this label is presumably the trial's first video frame — confirm.
test_data.index[0]

In [35]:
# Path of the session video (data_path + ".avi") that the render cell opens
video_filename

'A:/Clickbait/1006/full7/11272024_1006_full7.avi'

In [41]:
# Render the session video with the event data overlaid on every frame.
#   single_trial      - render only `selected_trial` instead of the whole session
#   show_window       - show frames in an OpenCV window (press 'q' to stop); named so it
#                       does not shadow IPython's built-in display() function
#   write             - save each frame to output/frame_<row>.png (the 'output' directory must exist)
#   loop              - after the last selected row, restart from the first one
#   show_frame_number - overlay the video position and DataFrame row on each frame
single_trial = False
selected_trial = 20
show_window = False
write = True
loop = False
show_frame_number = False

if single_trial:
    test_data = event_data.loc[event_data['trial_number'].isin([selected_trial])].copy()
else:
    test_data = event_data.copy()
if test_data.empty:
    raise ValueError("No event rows selected for rendering (check selected_trial).")

# Row labels are used as video frame positions: event_data was reindexed onto the video
# timestamps and renumbered from 0, so row k pairs with the k-th video timestamp.
first_row, last_row = test_data.index[0], test_data.index[-1]
ii = first_row

# Load video
video = cv2.VideoCapture(video_filename)
try:
    if not video.isOpened():
        raise IOError(f"Could not open video: {video_filename}")
    video.set(cv2.CAP_PROP_POS_FRAMES, ii)

    # Display video with event overlay
    while True:
        # Load video frame; stop cleanly if the video runs out before the event rows do
        ret, frame = video.read()
        if not ret:
            print(f"Video ended before DataFrame row {ii}; stopping.")
            break
        frame = cv2.resize(frame, (frame.shape[1]//scale, frame.shape[0]//scale)).astype(np.uint8)

        # Draw Grid
        frame = grid.draw_grid(frame, color=(0,0,0), opacity=.75).astype(np.uint8)

        # Draw target
        target = grid.get_target_cell(test_data.at[ii, 'target_cell'])
        frame = cv2.rectangle(frame, target[0], target[1], (0,0,0), -1)

        # Get mouse centroid, scaled to the resized frame
        pt_x = test_data.at[ii, 'centroid_x']//scale
        pt_y = test_data.at[ii, 'centroid_y']//scale

        # Get mouse cell and draw
        cell_i, cell_j = grid.get_mouse_cell(pt_x, pt_y)
        frame = grid.draw_cell(frame, cell_i, cell_j, (128,128,255), -1, opacity=.25)

        # Draw mouse centroid
        frame = cv2.circle(frame, (pt_x, pt_y), 5, (255,255,255), -1, cv2.LINE_AA)

        # Set state color for trial number: ITI first, then reward state, otherwise default
        if test_data.at[ii, 'iti']:
            state_color = (255,128,128)
        elif test_data.at[ii, 'reward_state']:
            state_color = (128,255,128)
        else:
            state_color = (128,128,255)

        # Rotate frame 90 degrees counter-clockwise
        frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)

        # Draw 1-based trial number in state color; int() avoids uint8 wrap-around on the +1
        trial_label = str(int(test_data.at[ii, 'trial_number']) + 1)
        frame = cv2.putText(frame, trial_label, (20,55), cv2.FONT_HERSHEY_SIMPLEX, 1.5, state_color, 2, cv2.LINE_AA)

        # Optional debug overlay, drawn below the trial number
        if show_frame_number:
            # After read(), CAP_PROP_POS_FRAMES is the 0-based index of the next frame to decode
            current_frame = int(video.get(cv2.CAP_PROP_POS_FRAMES))
            frame_count = f"Video frame: {current_frame}, DataFrame index: {ii}"
            frame = cv2.putText(frame, frame_count, (20,90), cv2.FONT_HERSHEY_SIMPLEX, .5, (255,255,255), 1, cv2.LINE_AA)

        # Display frame
        if show_window:
            cv2.imshow("clickbait", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # cv2.imwrite reports failure (e.g. missing directory) by returning False
        if write and not cv2.imwrite(f"output/frame_{ii}.png", frame):
            raise IOError(f"cv2.imwrite failed for output/frame_{ii}.png (does the 'output' directory exist?)")

        # Advance; last_row itself is rendered before stopping or looping
        ii += 1
        if ii > last_row:
            if not loop:
                break
            ii = first_row
            video.set(cv2.CAP_PROP_POS_FRAMES, ii)
finally:
    # Release the decoder even if rendering is interrupted (e.g. stopping the kernel)
    video.release()
    if show_window:
        cv2.destroyAllWindows()

In [15]:
# Inspect the rows selected for rendering.
# NOTE(review): the saved output below is from a different recording (2024-11-13) than the
# session loaded above (2024-11-27); re-run this cell to refresh it.
test_data

Unnamed: 0,timestamp,trial_number,poke_left,poke_right,centroid_x,centroid_y,target_cell,iti,reward_state,water_left,water_right,click
0,2024-11-13 11:56:10.782566400,0,False,False,545,126,[13],False,False,False,False,False
1,2024-11-13 11:56:10.795033600,0,False,False,545,126,[13],False,False,False,False,False
2,2024-11-13 11:56:10.811315200,0,False,False,545,126,[13],False,False,False,False,False
3,2024-11-13 11:56:10.833548800,0,False,False,545,126,[13],False,False,False,False,False
4,2024-11-13 11:56:10.851788800,0,False,False,545,126,[13],False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...
137995,2024-11-13 12:41:39.151731200,25,False,False,209,903,[],False,True,False,False,False
137996,2024-11-13 12:41:39.213094400,25,False,False,210,899,[],False,True,False,False,False
137997,2024-11-13 12:41:39.275878400,25,False,False,211,893,[],False,True,False,False,False
137998,2024-11-13 12:41:39.335628800,25,False,False,210,888,[],False,True,False,False,False


In [16]:
# Timestamp stored in row 4029 of the video timestamp table (label-based lookup)
video_ts.at[4029, 'timestamp']

Timestamp('2024-11-13 11:57:30.375846400')

In [17]:
# Inspect one entry of grid.cells; the output suggests each cell is a pair of [x, y]
# corner points — confirm against GridMaze.
# NOTE(review): this output (In [17]) was produced before the grid cell (In [30]) was last
# run, so it may not reflect the current (13, 6) grid.
grid.cells[1]

[[[0, 108], [111, 216]],
 [[111, 108], [222, 216]],
 [[222, 108], [333, 216]],
 [[333, 108], [444, 216]]]

In [18]:
# Flatten the nested grid.cells list into a single flat list of cells
from itertools import chain
flattened = list(chain.from_iterable(grid.cells))

In [19]:
# Flatten grid.cells (a list of lists of cells) into one list. A nested comprehension is
# linear-time and accepts any iterable rows, whereas sum(..., []) re-copies the growing
# list for every row and only works when each row is a list.
flattened = [cell for row in grid.cells for cell in row]

In [20]:
# Spot-check the flattened list: first element of the fifth cell (apparently an [x, y] corner point)
flattened[4][0]

[0, 108]