In [1]:
import numpy as np
import pandas as pd
import cv2
from sldata import SessionData

### Load Bonsai Data (Now with SLEAP) 

In [2]:
# Session identifiers. They select which recording SessionData loads and are
# reused further down when building output file names.
experiment = "clickbait-motivate"
mouse_id = "7004"
session_id = "m4"

print(
    "=== Testing SessionData Class ===",
    f"Loading data for {mouse_id}_{session_id}...",
    sep="\n",
)

# Construct the session object used by every later cell.
# NOTE(review): min_spikes=50 presumably drops low-spike-count clusters and
# verbose=True enables the loading log printed below -- confirm against sldata.
data = SessionData(
    mouse_id=mouse_id, session_id=session_id, experiment=experiment,
    min_spikes=50, verbose=True,
)

=== Testing SessionData Class ===
Loading data for 7004_m4...
Successfully loaded spike_times from S:\clickbait-motivate\kilosorted\7004\m4\spike_times.npy
Successfully loaded spike_templates from S:\clickbait-motivate\kilosorted\7004\m4\spike_templates.npy
Successfully loaded templates from S:\clickbait-motivate\kilosorted\7004\m4\templates.npy
Successfully loaded sniff from S:\clickbait-motivate\preprocessed\7004\m4\sniff.npy
Loaded data: ['spike_times', 'spike_templates', 'templates', 'sniff']
Successfully loaded events from S:\clickbait-motivate\bonsai\7004\m4\events.csv
Video properties: 888x1968, 30.0 FPS, 78923 frames
Filtering and decimating sniff signal from 30kHz to 1kHz...
Found 20125 peaks in sniff signal
Peak times range: 16.0 - 2649261.0 ms
Loaded 33 clusters for 7004_m4
Loaded sniff data: 2649560 samples
Found 20125 sniff events


In [3]:
# Show the last five rows of the merged events table.
data.events.tail(5)

Unnamed: 0,trial_number,timestamp,poke_left,poke_right,bonsai_centroid_x,bonsai_centroid_y,target_cell,iti,water_left,water_right,...,instance.score,nose.x,nose.y,nose.score,centroid.x,centroid.y,centroid.score,tailbase.x,tailbase.y,tailbase.score
78904,38,2025-06-25 15:58:13.646528000,False,False,484,1657,40.0,False,False,False,...,1.588075,357.057348,1744.105432,0.0,504.62442,1684.380005,0.888498,596.663086,1596.383545,0.621944
78905,38,2025-06-25 15:58:13.681881600,False,False,484,1658,40.0,False,False,False,...,1.598764,355.95792,1743.258335,0.0,504.500336,1684.322998,0.885732,596.537231,1596.363525,0.630794
78906,38,2025-06-25 15:58:13.720460800,False,False,483,1658,40.0,False,False,False,...,1.63421,354.858493,1742.411237,0.0,504.331146,1684.206909,0.87617,596.137268,1592.485107,0.667485
78907,38,2025-06-25 15:58:13.748121600,False,False,483,1658,40.0,False,False,False,...,1.601967,353.759065,1741.564139,0.0,500.595245,1688.119995,0.871896,593.237183,1592.397339,0.642041
78908,38,2025-06-25 15:58:13.779187200,False,False,481,1660,40.0,False,False,False,...,1.79543,352.659637,1740.717041,0.330317,504.169464,1688.426025,0.883555,596.473022,1592.727173,0.581558


### Visualize SLEAP Tracking

In [8]:
# Visualization options
scale = 0.5  # Scaling factor applied to every frame before preview/saving
save = False  # True: write each frame to a PNG; False: show a live preview window
video = True  # True: draw on the video frame; False: draw on a black canvas
bonsai_centroid = False  # Also draw the centroid from realtime (Bonsai) tracking
frame_delay_ms = 16  # Delay between preview frames, passed to cv2.waitKey
window_name = 'SLEAP Test'

# SLEAP nodes as (x, y) per frame. Kept as float so that non-finite values can
# be detected at draw time instead of being cast to arbitrary integers.
# NOTE(review): presumably NaN marks an undetected node -- confirm against sldata.
nose = data.events[['nose.x', 'nose.y']].to_numpy(dtype=float)
cent = data.events[['centroid.x', 'centroid.y']].to_numpy(dtype=float)
base = data.events[['tailbase.x', 'tailbase.y']].to_numpy(dtype=float)
bon_cent = data.events[['bonsai_centroid_x', 'bonsai_centroid_y']].to_numpy(dtype=float)


def _draw_node(canvas, point, radius, color):
    """Draw a filled, anti-aliased circle at `point` (x, y) on `canvas`.

    Points containing NaN/inf are skipped. Coordinates are truncated to
    plain Python ints, which is what cv2.circle expects for `center`.
    """
    if not np.all(np.isfinite(point)):
        return
    cv2.circle(img=canvas, center=(int(point[0]), int(point[1])), radius=radius,
               color=color, thickness=-1, lineType=cv2.LINE_AA)


# Never index past the end of the tracking arrays, even if total_frames is larger.
n_frames = min(data.total_frames, len(nose))
out_size = (int(data.width * scale), int(data.height * scale))

if video:
    # Seek once, then read sequentially: each read() advances by one frame, so
    # this yields the same frames as seeking on every iteration but far faster.
    data.video.set(cv2.CAP_PROP_POS_FRAMES, data.video_offset)

window_open = False  # True once at least one frame has been shown
interrupted = False  # True if the user quit the preview with 'q' or ESC

try:
    for ii in range(n_frames):
        if video:
            ret, canvas = data.video.read()  # Get current video frame
            if not ret:
                # read() failed (end of file or decode error); canvas is unusable.
                print(f"Could not read video frame {ii + data.video_offset}; stopping.")
                break
        else:
            canvas = np.zeros((data.height, data.width, 3), dtype=np.uint8)  # Clear canvas

        # Draw Bonsai centroid
        if bonsai_centroid:
            _draw_node(canvas, bon_cent[ii], radius=12, color=(128, 128, 128))

        # Draw SLEAP nodes (colors are in OpenCV's BGR channel order)
        _draw_node(canvas, nose[ii], radius=5, color=(255, 0, 0))  # Nose
        _draw_node(canvas, cent[ii], radius=5, color=(0, 255, 0))  # Centroid
        _draw_node(canvas, base[ii], radius=5, color=(0, 0, 255))  # Tail base

        canvas = cv2.resize(canvas, dsize=out_size)
        canvas = cv2.rotate(canvas, cv2.ROTATE_90_CLOCKWISE)

        if save:
            out_path = f"S:/track-test-video/{mouse_id}_{session_id}_tracking_{ii}.png"
            # imwrite returns False (without raising) when it cannot write the
            # file, e.g. when the output directory does not exist.
            if not cv2.imwrite(out_path, canvas):
                raise OSError(f"Failed to write frame to {out_path}")
        else:
            cv2.imshow(window_name, canvas)
            window_open = True
            # Wait for key press with timeout
            key = cv2.waitKey(frame_delay_ms) & 0xFF
            # Break on 'q' key or ESC key
            if key == ord('q') or key == 27:
                interrupted = True
                break

    # Hold the last frame until a key is pressed, but only if a preview window
    # exists and the user has not already asked to quit.
    if window_open and not interrupted:
        cv2.waitKey(0)
finally:
    # Always close the preview window, even if the cell errors or is interrupted.
    cv2.destroyAllWindows()