In [1]:
import numpy as np
import pandas as pd
import sys
import matplotlib.pyplot as plt
from tools import * 
from utils import *

# Choose dataset and mouse

In [2]:
dataset = 'M030_2024_04_12_09_40'
mouse_id = 'M030'


# Load events data 

In [3]:
dir_to_events = '/data/mouse_data/processed/'+mouse_id+'/'+dataset+'/'+dataset+'_behav/'

df_events = pd.read_csv(dir_to_events + dataset+'_events.csv')

 Subtract before camera trigger so that time 0 equals the first camera trigger :) 


In [4]:
df_events['time_shifted'] = df_events['time'] - df_events[df_events['name']=='before_camera_trigger'].time.values

In [5]:
df_events.head()

Unnamed: 0.1,Unnamed: 0,type,name,time,duration,value,time_shifted
0,0,info,Experiment name,,,run_task,
1,1,info,Task name,,,earthquake-long,
2,2,info,Task file hash,,,397468066,
3,3,info,Setup ID,,,COM9,
4,4,info,Subject ID,,,M030_2024_04_12_09_40,


# Load behaviour (joint angles and keypoints)

In [6]:
dir_to_key_points = '/data/mouse_data/processed/'+mouse_id+'/'+dataset+'/'+dataset+'_pose_estimation/'
df_pose = pd.read_csv(dir_to_key_points + dataset+'_3dpts_angles.csv')

In [7]:
df_pose.head()

Unnamed: 0,shoulder_center_x,shoulder_center_y,shoulder_center_z,left_shoulder_x,left_shoulder_y,left_shoulder_z,left_paw_x,left_paw_y,left_paw_z,right_shoulder_x,...,right_wrist_error,right_wrist_score,right_wrist_ncams,fnum,left_elbow_angle,right_elbow_angle,left_knee_angle,right_knee_angle,left_ankle_angle,right_ankle_angle
0,22.568039,-12.787642,176.544237,14.649775,-5.457195,178.72409,16.995655,8.254706,178.622392,27.003074,...,0,1,6,1,129.184099,117.118031,88.995706,81.432577,117.187683,67.129205
1,22.725028,-13.038947,176.558044,14.945477,-5.462571,178.348523,17.321954,8.374876,175.367738,26.900694,...,0,1,6,2,139.103293,111.405901,92.631714,75.380036,128.459479,60.49034
2,22.791554,-13.152582,176.561081,15.169879,-5.306657,177.722464,17.778597,8.563668,171.919859,26.751527,...,0,1,6,3,141.715123,108.16515,85.320908,69.078621,128.520922,56.628642
3,22.814629,-13.194355,176.544332,15.30902,-5.111557,177.01062,17.920335,8.888127,168.541629,26.64667,...,0,1,6,4,143.962208,108.133504,67.598765,67.333393,108.557077,58.424544
4,22.899409,-13.193358,176.530568,15.495248,-4.947729,176.506788,17.719539,9.013606,165.271706,26.741539,...,0,1,6,5,145.88697,107.088191,57.654866,63.841905,83.688857,62.82321


In [8]:
len(df_pose)

360014

# Load and process spike data

In [9]:
probe_nb = '0'
break_rec = 198
str_motor_spikes, m1_spikes = load_spike_data(mouse_id, dataset,probe_nb,break_rec)



In [10]:
probe_nb = '1'
break_rec = 170
str_sensor_spikes, s1_spikes = load_spike_data(mouse_id, dataset,probe_nb,break_rec)



In [11]:
str_sensor_spikes.shape

(153, 361811)

# Create dataframe

Create one trial (row) for each solenoid

In [12]:
# Define bin size and window size
bin_size = 10  # 10 ms bins
window_size = 1000  # 1 second before and after event
bins_before_after = window_size // bin_size  # Number of bins before and after event

# Dictionary to store the spike data arrays
spike_data_dict = {
    'm1_spikes': m1_spikes,
    'str_motor_spikes': str_motor_spikes,
    's1_spikes': s1_spikes,
    'str_sensor_spikes': str_sensor_spikes,
}

list_of_angles = [col for col in df_pose.columns if 'angle' in col]

angles_data_dict = {}

for column in list_of_angles:
    angles_data_dict[column] = df_pose[column].values




In [89]:
def extract_spike_trial_data(spike_data, start_bin,end_bin):
    # Ensure the indices are within bounds
    if start_bin >= 0 and end_bin <= spike_data.shape[1]:
        trial_data = spike_data[:,start_bin:end_bin]
        return trial_data
    else:
        return None

In [90]:
def extract_keypoints_trial_data(spike_data, start_bin,end_bin):
    # Ensure the indices are within bounds
    if start_bin >= 0 and end_bin <= spike_data.shape[0]:
        
        trial_data = spike_data[start_bin:end_bin]
        return trial_data.flatten()
    else:
        return None

In [128]:
# Create a dictionary to store the trial data for each spike array and variable
trial_data_dict = {key: [] for key in list(spike_data_dict.keys())+ list(angles_data_dict.keys())}

sol_direction = df_events[df_events['name']=='Sol_direction'].time_shifted.values

##### Make the first trial the whole quiet period at the beginning 

event_time = df_events[df_events['name']=='quiet_period_end'].time_shifted.values
event_bin = event_time // bin_size  # Convert event time to bin index
start_bin = 0
end_bin = int(event_bin)

for key, spike_data in spike_data_dict.items():

    trial_data = extract_spike_trial_data(spike_data, start_bin, end_bin)
    trial_data_dict[key].append(trial_data)

for key, keypoint_data in angles_data_dict.items():

    trial_data = extract_keypoints_trial_data(keypoint_data, start_bin, end_bin)
    trial_data_dict[key].append(trial_data)


# create one trial for each perturbation 

for event_time in sol_direction:
    event_time = event_time // 10
    event_bin = event_time // bin_size  # Convert event time to bin index
    start_bin = int(event_bin - bins_before_after)
    end_bin = int(event_bin + bins_before_after)

    for key, spike_data in spike_data_dict.items():
        if isinstance(spike_data, np.ndarray):
            trial_data = extract_spike_trial_data(spike_data, start_bin, end_bin)
        else:
            trial_data = spike_data[event_time]
        
        if trial_data is not None:
            trial_data_dict[key].append(trial_data)
    
    for key, keypoint_data in angles_data_dict.items():
        if isinstance(keypoint_data, np.ndarray):
            trial_data = extract_keypoints_trial_data(keypoint_data, start_bin, end_bin)
        else:
            trial_data = keypoint_data[event_time]
        
        if trial_data is not None:
            trial_data_dict[key].append(trial_data)
    
##### Make the last trial the whole quiet period at the beginning 

event_time = df_events[df_events['name']=='quiet_period_start'].time_shifted.values
event_bin = event_time // bin_size  # Convert event time to bin index
start_bin = int(event_bin)
end_bin = int(df_events[df_events['name']=='session_timer'].time_shifted.values // bin_size )

for key, spike_data in spike_data_dict.items():

    trial_data = extract_spike_trial_data(spike_data, start_bin, end_bin)
    trial_data_dict[key].append(trial_data)

for key, keypoint_data in angles_data_dict.items():

    trial_data = extract_keypoints_trial_data(keypoint_data, start_bin, end_bin)
    trial_data_dict[key].append(trial_data)




# Create a DataFrame where each row is a trial, with separate columns for each spike array and variable

df_trials = pd.DataFrame({'trial_id':range(len(sol_direction) + 2),'event_time':[np.nan] + list(sol_direction) + [np.nan]})

df_trials['sol_direction'] = [np.nan] + list(df_events[df_events['name']=='Sol_direction'].value.values) + [np.nan]

df_trials['perturbation_bin'] = [np.nan] + list(np.ones(len(sol_direction))*100) + [np.nan]

df_trials['trialType']= 'perturbation'

df_trials.loc[[0,len(df_trials)-1],'trialType'] = 'spontaneous' # first and last trials are quiet period

for key in trial_data_dict.keys():
    df_trials[key] = trial_data_dict[key]




  end_bin = int(event_bin)
  start_bin = int(event_bin)
  end_bin = int(df_events[df_events['name']=='session_timer'].time_shifted.values // bin_size )


In [129]:
df_trials.head()

Unnamed: 0,trial_id,event_time,sol_direction,perturbation_bin,trialType,m1_spikes,str_motor_spikes,s1_spikes,str_sensor_spikes,left_elbow_angle,right_elbow_angle,left_knee_angle,right_knee_angle,left_ankle_angle,right_ankle_angle
0,0,,,,spontaneous,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[129.18409900512486, 139.10329336738377, 141.7...","[117.11803139586924, 111.4059006364921, 108.16...","[88.99570582203008, 92.63171364274764, 85.3209...","[81.43257675883319, 75.38003632027481, 69.0786...","[117.18768277113504, 128.45947891963243, 128.5...","[67.12920499283922, 60.49033956452664, 56.6286..."
1,1,305009.0,5.0,100.0,perturbation,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[117.06575623027888, 113.78532894902098, 112.6...","[125.33736724957572, 118.22442978812428, 112.8...","[45.7071365696036, 46.30436048830133, 46.37119...","[80.4188553189747, 78.03497515412027, 74.58574...","[39.44832713281205, 39.2727000182535, 38.30628...","[77.39167059645153, 74.17487792977438, 69.3768..."
2,2,310703.0,4.0,100.0,perturbation,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[95.00810768555134, 118.169696143769, 116.7976...","[146.9310595922003, 161.939151994095, 166.7651...","[59.50232593947056, 57.996965747462895, 55.452...","[69.16711458185428, 70.07436241028357, 68.1145...","[39.10847001306168, 37.05807651298933, 46.0287...","[30.380371873105343, 40.75181260865588, 48.924..."
3,3,316669.0,2.0,100.0,perturbation,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[97.211065826272, 99.92305917912942, 101.23160...","[129.42472903076177, 140.9474363941409, 163.61...","[66.96711667774626, 77.55387549080794, 89.4424...","[41.5958233165213, 36.45963979049096, 52.48881...","[81.95405844575761, 92.21193377245424, 107.373...","[95.02164962297178, 66.25845508160552, 52.4953..."
4,4,324523.0,5.0,100.0,perturbation,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[135.5664808018621, 128.84889328833623, 119.67...","[148.79163293003285, 141.38604429051617, 132.7...","[48.93661819872716, 48.679780207549456, 50.625...","[53.723365916853865, 51.36813911587471, 64.248...","[62.70530910171245, 65.57326397691517, 69.6818...","[55.79656059450058, 50.19269436680059, 57.7548..."


In [100]:
np.unique(df_events['name'])

array(['Experiment name ', 'Setup ID', 'Sol_direction', 'Start date',
       'Subject ID', 'Task file hash', 'Task name',
       'before_camera_trigger', 'motion', 'quiet_period_end',
       'quiet_period_start', 'session_end', 'session_middle',
       'session_start', 'session_timer', 'trial_off', 'trial_on'],
      dtype=object)

In [102]:
df_events[df_events['name']=='quiet_period_end'].time_shifted.values

array([299996.])