In [3]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

#from itertools import product
from nlb_tools.nwb_interface import NWBDataset
from pynwb import NWBHDF5IO

import warnings
warnings.filterwarnings("ignore")

In [None]:
# `dandi download` places the DANDI 000070 NWB files in this folder.
datapath = 'data/NWB/000070/sub-Jenkins/'
dataset = NWBDataset(datapath, split_heldout=False)

# Bin width in ms. Later cells depend on it: the crop around movement onset
# (48 bins before / 132 after = -240..+660 ms at 5 ms bins), the event-bin
# conversions, and the saved file names `data_*_{binsize}ms.pt`.
# It used to be commented out, leaving `binsize` undefined for those cells.
binsize = 5  # ms
dataset.resample(binsize)

trial_info = dataset.trial_info

Dropping Position_Cursor due to timestamp mismatch.
Dropping Position_Eye due to timestamp mismatch.
Dropping Position_Hand due to timestamp mismatch.
Spikes found outside of observed interval.
Dropping Position_Cursor due to timestamp mismatch.
Dropping Position_Eye due to timestamp mismatch.
Dropping Position_Hand due to timestamp mismatch.
Spikes found outside of observed interval.
Dropping Position_Cursor due to timestamp mismatch.
Dropping Position_Eye due to timestamp mismatch.
Dropping Position_Hand due to timestamp mismatch.
Spikes found outside of observed interval.


In [2]:
# Total neuron count = held-in units ('spikes') + held-out units ('heldout_spikes').
n_null_trials = trial_info['success'].isnull().sum()
n_neurons = sum(dataset.data[group].shape[1] for group in ('spikes', 'heldout_spikes'))

for label, value in (('number of neurons', n_neurons),
                     ('total number of trials', len(trial_info)),
                     ('number of null trials', n_null_trials)):
    print(f'{label}: {value}')

NameError: name 'trial_info' is not defined

In [4]:
# Per-trial metadata table (one row per trial).
trial_info

NameError: name 'trial_info' is not defined

In [None]:
# Print every metadata column name, one per line.
for col_name in trial_info.columns.tolist():
    print(col_name)

In [None]:
# Several trials at the beginning of the session(s) carry no metadata at all;
# NWBDataset.make_trial_data() drops all of them.
print(f"Total num of trials: {len(trial_info)}\n")
print("Count of null values in each column:\n")

null_counts = trial_info.isna().sum()
null_counts

In [None]:
# Number of time bins in the entire continuous dataset: each row of
# `dataset.data` holds the measurements of one time bin.
# NOTE(review): an earlier note here claimed the original sampling rate is 100 Hz
# (10 ms bins), while later cells assume `binsize`-ms bins — confirm the actual
# bin width of `dataset.data` before relying on either.
# Keep in mind that not all the trials are the same length.
len(dataset.data)

In [None]:
# make_trial_data() cuts the continuous data into trials. The result has the same
# fields as `dataset.data`, plus `trial_id`, `trial_time` and `align_time`. Each row
# is still one time bin. Cropping/alignment around move_onset is done manually below.
trial_data = dataset.make_trial_data()
trial_data.shape[0]

In [None]:
# Trialized data: one row per time bin, tagged with trial_id / trial_time / align_time.
trial_data

In [None]:
# Trials explicitly marked unsuccessful (NaN 'success' rows do not compare equal to False).
trial_info.loc[trial_info['success'].eq(False)]

In [None]:
# How often each target index was the active one.
trial_info.active_target.value_counts()

In [None]:
# Sanity check: True if any held-in spike bin is NaN (any() already reduces over all axes).
np.isnan(trial_data['spikes'].to_numpy()).any()

In [None]:
# Duration (s) of every trial in `trial_data`, from its first to its last time bin.
trial_lens = []

for trial_id, trial in trial_data.groupby('trial_id'):
    trial_time = trial.trial_time.values
    trial_lens.append((trial_time[-1] - trial_time[0]) / np.timedelta64(1, 's'))

os.makedirs('output_figs', exist_ok=True)  # savefig fails if the folder is missing
plt.hist(trial_lens, bins='auto', density=False, alpha=0.7, edgecolor='gray')

plt.title('trial length\n')
plt.xlabel('time interval of trial length (sec)')
plt.savefig('output_figs/tl.png')
plt.show()  # was `plt.show` without a call, which only displayed the function object

In [None]:
# One DataFrame per kept trial, in trial_id order (reused by later cells).
grouped = list(trial_data.groupby('trial_id'))
trials = [trial for _, trial in grouped]
kept_trial_ids = [trial_id for trial_id, _ in grouped]

# Start-to-start interval (s) between consecutive kept trials. Start times are looked up
# by the real trial_ids. Reconstructing ids as `i + n_null_trials - 1` paired the first
# kept trial with a null trial and misaligned whenever ids were not contiguous.
# drop_duplicates keeps the first row per trial_id, like the previous `.iloc[0]` lookup.
start_s = (trial_info.drop_duplicates('trial_id')
           .set_index('trial_id')
           .loc[kept_trial_ids, 'start_time'] / np.timedelta64(1, 's')).to_numpy()
inter_trial_intervals = np.diff(start_s).tolist()

os.makedirs('output_figs', exist_ok=True)  # savefig fails if the folder is missing
plt.hist(inter_trial_intervals, bins='auto', density=False, alpha=0.7, edgecolor='gray')

plt.title('inter-trial time interval\n')
plt.xlabel('inter-trial time interval (sec)')
plt.savefig('output_figs/iit.png')
plt.show()

In [None]:
# Number of trials kept in `trial_data` (one DataFrame per trial in `trials`).
len(trials)

In [None]:
# Time (s) from trial start to movement onset for every trial kept in `trial_data`.
# Trials are looked up by their real trial_id rather than by position + n_null_trials,
# which assumed contiguous ids. drop_duplicates keeps the first row per id, as before.
kept_info = (trial_info.drop_duplicates('trial_id')
             .set_index('trial_id')
             .loc[[trial_id for trial_id, _ in trial_data.groupby('trial_id')]])
time_utill_move = (kept_info['move_onset_time'] / np.timedelta64(1, 's')
                   - kept_info['start_time'] / np.timedelta64(1, 's')).tolist()

os.makedirs('output_figs', exist_ok=True)  # savefig fails if the folder is missing
plt.hist(time_utill_move, bins='auto', density=False, alpha=0.7, edgecolor='gray')
plt.title('time until movement\n')
plt.xlabel('time interval before move_onset (sec)')
plt.savefig('output_figs/tibm.png')
plt.show()

In [None]:
# Time (s) from movement onset to trial end for every trial kept in `trial_data`.
# Trials are looked up by their real trial_id rather than by position + n_null_trials,
# which assumed contiguous ids. drop_duplicates keeps the first row per id, as before.
kept_info = (trial_info.drop_duplicates('trial_id')
             .set_index('trial_id')
             .loc[[trial_id for trial_id, _ in trial_data.groupby('trial_id')]])
time_after_move = (kept_info['end_time'] / np.timedelta64(1, 's')
                   - kept_info['move_onset_time'] / np.timedelta64(1, 's')).tolist()

os.makedirs('output_figs', exist_ok=True)  # savefig fails if the folder is missing
plt.hist(time_after_move, bins='auto', density=False, alpha=0.7, edgecolor='gray')
plt.title('time after movement\n')
plt.xlabel('time interval after move_onset (sec)')
plt.savefig('output_figs/tiam.png')
plt.show()

### Forming conditions

In [None]:
def get_simple_cond(angle):
    """Map a reach angle (degrees) to one of 8 coarse direction sectors.

    Sectors, in degrees after wrapping the angle into [0, 360):
        0: [350, 360) and [0, 38)    4: [212, 232)
        1: [38, 125)                 5: [232, 280)
        2: [125, 175)                6: [280, 329]
        3: [175, 212)                7: (329, 350)

    Args:
        angle: Angle in degrees. Any finite value is accepted and wrapped modulo
            360, so 360 maps like 0 and -10 maps like 350. Previously exactly 360
            raised, yet callers computing `degrees(arctan2(...)) + 180` can produce it.

    Returns:
        int: Sector index in 0..7.

    Raises:
        ValueError: If `angle` is not finite (NaN or infinity).
    """
    wrapped = angle % 360
    # Float modulo of a tiny negative value can round up to exactly 360.
    if wrapped == 360:
        wrapped = 0

    if 350 <= wrapped < 360 or 0 <= wrapped < 38:
        return 0
    elif 38 <= wrapped < 125:
        return 1
    elif 125 <= wrapped < 175:
        return 2
    elif 175 <= wrapped < 212:
        return 3
    elif 212 <= wrapped < 232:
        return 4
    elif 232 <= wrapped < 280:
        return 5
    elif 280 <= wrapped <= 329:
        return 6
    elif 329 < wrapped < 350:
        return 7
    # Only NaN reaches here (inf % 360 is NaN as well).
    raise ValueError("Angle out of range")

def n_unigue_conds(trial_conds):
    """Count distinct conditions, comparing each condition as an unordered set of items.

    Order and repetition inside a condition are ignored: [1, 2] and [2, 1] count once.

    Args:
        trial_conds: Iterable of iterables of hashable items.

    Returns:
        int: Number of distinct item-sets.
    """
    return len({frozenset(items) for items in trial_conds})

In [None]:
## Plot the end point of each condition's trial-averaged reach, plus the direction-sector edges.

# Unique (trial_type, trial_version) conditions, dropping the NaN pairs of null trials.
conds = trial_info.set_index(['trial_type', 'trial_version']).index.unique().tolist()
conds = [cond for cond in conds if not any(math.isnan(x) for x in cond)]

fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

for cond in conds:
    # Trials belonging to this condition.
    mask = np.all(trial_info[['trial_type', 'trial_version']] == cond, axis=1)
    trial_d = dataset.make_trial_data(ignored_trials=(~mask))
    # Trial-averaged hand trajectory; only its final point is drawn.
    traj = trial_d.groupby('align_time')[[('hand_pos', 'x'), ('hand_pos', 'y')]].mean().to_numpy()
    # Reach direction = angle of the active target's position (used for the line colour).
    cond_info = trial_info[mask]
    active_target = cond_info.target_pos.iloc[0][int(cond_info.active_target.iloc[0])]
    reach_angle = np.arctan2(*active_target[::-1])
    ax.plot([0, traj[-1, 0]], [0, traj[-1, 1]], linewidth=0.7, color=plt.cm.hsv(reach_angle / (2*np.pi) + 0.5))


def _draw_ray(ax, angle_deg, length, linewidth, color):
    """Draw a segment from the origin at `angle_deg` degrees with the given length."""
    angle_rad = np.radians(angle_deg)
    ax.plot([0, length * np.cos(angle_rad)], [0, length * np.sin(angle_rad)], linewidth=linewidth, color=color)


# (angle in degrees, ray length): the sector edges used by get_simple_cond.
# Candidate edges at 16 and 193 degrees were tried earlier and are intentionally not drawn.
for angle_deg, length in [(350, 140), (38, 80), (125, 50), (175, 140),
                          (212, 140), (232, 140), (280, 100), (329, 140)]:
    _draw_ray(ax, angle_deg, length, 0.7, 'black')

# Cardinal axes for reference.
for angle_deg, length in [(0, 140), (90, 50), (180, 140), (270, 100)]:
    _draw_ray(ax, angle_deg, length, 1.7, 'navy')

plt.axis('off')
plt.show()

In [None]:
## Plot trial-averaged reaches aligned to movement onset (-240..+660 ms) and group
## the trials by condition.

# Unique (trial_type, trial_version) conditions, dropping the NaN pairs of null trials.
maze_conds = trial_info.set_index(['trial_type', 'trial_version']).index.unique().tolist()
maze_conds = [cond for cond in maze_conds if not any(math.isnan(x) for x in cond)]

orig_conds = {}                          # condition index -> array of trial_ids
simp_conds = {k: [] for k in range(8)}   # get_simple_cond sector -> arrays of trial_ids

fig = plt.figure(figsize=(6, 6))
fig.suptitle('Aligned trials (same length)')
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

for cond_idx, cond in enumerate(maze_conds):
    # Trials belonging to this condition.
    mask = np.all(trial_info[['trial_type', 'trial_version']] == cond, axis=1)
    trial_d = dataset.make_trial_data(align_field='move_onset_time', align_range=(-240, 660), ignored_trials=(~mask))
    traj = trial_d.groupby('align_time')[[('hand_pos', 'x'), ('hand_pos', 'y')]].mean().to_numpy()
    # Reach direction = angle of the active target's position (used for the line colour).
    cond_info = trial_info[mask]
    reach_angle = np.arctan2(*cond_info.target_pos.iloc[0][int(cond_info.active_target.iloc[0])][::-1])
    ax.plot(traj[:, 0], traj[:, 1], linewidth=0.7, color=plt.cm.hsv(reach_angle / (2*np.pi) + 0.5))

    cond_trial_ids = trial_d.trial_id.drop_duplicates().values
    orig_conds[cond_idx] = cond_trial_ids
    # arctan2 gives (-180, 180] degrees, so +180 lands in (0, 360]. Wrap 360 to 0,
    # which get_simple_cond would otherwise reject.
    # NOTE(review): the +180 shift mirrors the colour-map offset, but the sector edges
    # drawn in the previous figure are in unshifted coordinates — confirm which is intended.
    sector_angle = (math.degrees(reach_angle) + 360 / 2) % 360
    simp_conds[get_simple_cond(sector_angle)].append(cond_trial_ids)

# Flatten each sector's per-condition id arrays. An empty sector becomes an empty
# array, because np.concatenate raises on an empty list.
simp_conds = {key: np.concatenate(value) if value else np.array([], dtype=int)
              for key, value in simp_conds.items()}

plt.axis('off')
plt.show()

In [None]:
# Condition index (a key of `orig_conds`) of every trial in `trial_data`, in trial_id order.
# Build the trial_id -> condition lookup once instead of scanning every condition per trial.
# The first condition listing a trial wins, as before.
trial_to_cond = {}
for cond, cond_trial_ids in orig_conds.items():
    for cond_trial_id in cond_trial_ids:
        trial_to_cond.setdefault(cond_trial_id, cond)

# NOTE(review): trials absent from `orig_conds` are skipped, so `conds` can be shorter
# than the number of trials — compare the printed shape with len(trials).
conds = [trial_to_cond[trial_id] for trial_id, _ in trial_data.groupby('trial_id')
         if trial_id in trial_to_cond]

# as_tensor leaves an existing tensor unchanged, so re-running this cell is safe
# (torch.tensor on a tensor warns and copies).
maze_conds = torch.as_tensor(maze_conds)
conds = torch.tensor(conds)

print(maze_conds.shape)
print(conds.shape)

In [None]:
# Number of trials per condition index (torch.unique sorts the indices ascending).
cond_counts = torch.unique(conds, return_counts=True)[1]
cond_counts

## Forming trials and label vectors

In [None]:
# Columns containing an 'x' or 'y' entry. For tuple (MultiIndex) columns such as
# ('hand_pos', 'x'), `in` tests the tuple elements.
# Earlier variant: columns matching 'vel', 'pos', 'force', 'acc' or 'target'.
label_cols = [col for col in trial_data.columns if 'x' in col or 'y' in col]

In [None]:
# Inspect the selected label columns.
label_cols

In [None]:
# Crop every trial to a fixed window around its movement-onset bin and collect the
# spikes, cursor positions, target info and condition index of each trial.
y = []
labels = []
target_pos = []
active_target = []
conds = []

# Window of 48 + 132 = 180 bins. At binsize = 5 ms that is 240 ms before and 660 ms
# after movement onset (900 ms total), matching align_range=(-240, 660) used above.
bins_before_move = 48
bins_after_move = 132
trial_length = bins_before_move + bins_after_move

# trial_id -> condition index; the first condition listing a trial wins.
trial_to_cond = {}
for cond, cond_trial_ids in orig_conds.items():
    for cond_trial_id in cond_trial_ids:
        trial_to_cond.setdefault(cond_trial_id, cond)

for trial_id, trial in trial_data.groupby('trial_id'):
    trial_id_info = trial_info[trial_info['trial_id'] == trial_id]

    # Time from trial start until movement onset (ms), then the bin holding movement onset.
    # Assumes row 0 of `trial` is at start_time — TODO confirm.
    move_time = (trial_id_info['move_onset_time'].iloc[0] / np.timedelta64(1, 'ms')) - (trial_id_info['start_time'].iloc[0] / np.timedelta64(1, 'ms'))
    move_bin = int(move_time // binsize)
    crop = slice(move_bin - bins_before_move, move_bin + bins_after_move)

    # Previously these cases failed later with an opaque reshape error, or silently
    # misaligned `conds` with `y`.
    if crop.start < 0 or crop.stop > len(trial):
        raise ValueError(f'trial {trial_id}: crop window [{crop.start}, {crop.stop}) lies outside the trial (0..{len(trial)})')
    if trial_id not in trial_to_cond:
        raise ValueError(f'trial {trial_id} is not listed in any condition of orig_conds')

    y_heldin_t = torch.tensor(trial.spikes.values)
    y_heldout_t = torch.tensor(trial.heldout_spikes.values)
    y_t = torch.concat([y_heldin_t[crop, :], y_heldout_t[crop, :]], dim=-1)

    y.append(y_t.reshape(1, trial_length, n_neurons))
    # Labels are cursor positions (x, y) over the cropped window.
    labels.append(torch.tensor(trial.cursor_pos.values[crop, :]).reshape(1, trial_length, 2))

    target_pos.append(trial_id_info.target_pos.values[0])
    active_target.append(int(trial_id_info.active_target.values[0]))
    conds.append(trial_to_cond[trial_id])

y = torch.concat(y, dim=0)
labels = torch.concat(labels, dim=0)
conds = torch.tensor(conds)
active_target = torch.tensor(active_target)

print(y.shape)
print(labels.shape)
print(conds.shape)
print(len(target_pos))
print(active_target.shape)

### Forming event-occurrence time bins

In [None]:
# NOTE(review): each iteration overwrites target_on / gocue / move_onset, so afterwards
# they only hold the last trial's values. The next cell recomputes them per trial;
# consider deleting this cell.
# Maps position i -> trial_id via `i + n_null_trials`. This assumes the kept trials have
# contiguous ids starting right after the null trials — TODO confirm.
for i, _ in enumerate(trials):
    trial_id = i + n_null_trials
    trial_id_trial_info = trial_info[trial_info['trial_id'] == trial_id]
    
    # Event time relative to trial start (ms), floor-divided into `binsize`-ms bins.
    target_on = (((trial_id_trial_info['target_on_time'].iloc[0] / np.timedelta64(1, 'ms')) - (trial_id_trial_info['start_time'].iloc[0] / np.timedelta64(1, 'ms'))) // binsize)
    gocue = (((trial_id_trial_info['go_cue_time'].iloc[0] / np.timedelta64(1, 'ms')) - (trial_id_trial_info['start_time'].iloc[0] / np.timedelta64(1, 'ms'))) // binsize)
    move_onset = (((trial_id_trial_info['move_onset_time'].iloc[0] / np.timedelta64(1, 'ms')) - (trial_id_trial_info['start_time'].iloc[0] / np.timedelta64(1, 'ms'))) // binsize)

In [None]:
# Bin index of each event (target on, go cue, movement onset) inside the cropped
# window, where movement onset sits at bin `bins_before_move`. Rows are events and
# columns are trials in trial_id order (matching `trials`). Trials are looked up by
# their real trial_id rather than `i + n_null_trials`.
kept_info = (trial_info.drop_duplicates('trial_id')
             .set_index('trial_id')
             .loc[[trial_id for trial_id, _ in trial_data.groupby('trial_id')]])


def _bins_since_start(field):
    """Per-trial floor((field - start_time) / binsize), same formula as `move_bin` in the crop cell."""
    return (kept_info[field] / np.timedelta64(1, 'ms') - kept_info['start_time'] / np.timedelta64(1, 'ms')) // binsize


target_on = _bins_since_start('target_on_time')
gocue = _bins_since_start('go_cue_time')
move_onset = _bins_since_start('move_onset_time')

target_bins = (target_on - move_onset + bins_before_move).tolist()
gocue_bins = (gocue - move_onset + bins_before_move).tolist()
move_bins = [bins_before_move] * len(kept_info)

# One float32 tensor, so the next cell can set out-of-window bins to NaN. Stacking the
# previous int (move) and float (target/go) tensors relied on dtype promotion.
event_bins = torch.tensor([target_bins, gocue_bins, move_bins], dtype=torch.float32)
event_bins.shape

In [None]:
# Reshape to (n_trials, 3 events) and mark events outside the cropped window
# [0, bins_before_move + bins_after_move) as NaN.
# NOTE: permute is not idempotent — re-running this cell alone transposes back.
event_bins = event_bins.permute(1, 0).float()

window_len = bins_before_move + bins_after_move
event_bins[event_bins < 0] = float('nan')
# Valid indices are 0..window_len-1. The old `>` kept an event at index window_len.
event_bins[event_bins >= window_len] = float('nan')
event_bins.shape

In [None]:
# Per-trial event bin indices within the cropped window (NaN = outside the window).
event_bins

### Save data splits

In [None]:
# Split the tensors chronologically (no shuffling). The last `n_valid_trials` trials are
# held out: the first half for validation, the second half for testing.
save_root_path = 'data/'
os.makedirs(save_root_path, exist_ok=True)  # replaces `!mkdir data`, which errors once the folder exists

train_data, valid_data, test_data = {}, {}, {}
n_trials, seq_len, n_neurons = y.shape
n_valid_trials = 574  # total held-out trials (validation + test)

splits = (train_data, valid_data, test_data)
split_slices = (slice(None, -n_valid_trials),
                slice(-n_valid_trials, -n_valid_trials // 2),
                slice(-n_valid_trials // 2, None))


def _store(key, values):
    """Store float32 copies of `values`, split along dim 0, under `key` in each split dict."""
    array = np.asarray(values)
    for split, sl in zip(splits, split_slices):
        # torch.tensor copies, so only the slice is serialized (not the whole storage).
        split[key] = torch.tensor(array[sl], dtype=torch.float32)


# obs: observations
_store('y_obs', y)

# 'n_bins_obs' / 'n_neurons_obs': observed trial length (after alignment) and neuron count.
# 'n_bins_enc' / 'n_neurons_enc': what later modeling uses for encoding (default: all).
for split in splits:
    split['n_bins_obs'] = split['n_bins_enc'] = seq_len
    split['n_neurons_obs'] = split['n_neurons_enc'] = n_neurons

# One 1D tensor per event with each trial's event bin. 'targrt_on_bin' is the historical
# (misspelled) key, kept for existing consumers; 'target_on_bin' aliases the same tensor.
for event_id, event in enumerate(['targrt_on_bin', 'go_cue_bin', 'move_onset_bin']):
    _store(event, event_bins[:, event_id])
for split in splits:
    split['target_on_bin'] = split['targrt_on_bin']

# `labels` holds cursor *positions* (trial.cursor_pos), not hand velocity. The legacy key
# 'hand_vel' is kept for existing consumers; 'cursor_pos' aliases the same tensor.
_store('hand_vel', labels)
for split in splits:
    split['cursor_pos'] = split['hand_vel']

_store('conds', conds)

for split, sl in zip(splits, split_slices):
    split['target_pos'] = target_pos[sl]

_store('active_target', active_target)

for split_name, split in zip(('train', 'valid', 'test'), splits):
    torch.save(split, os.path.join(save_root_path, f'data_{split_name}_{binsize}ms.pt'))

print('Data splits (train/valid/test) saved into the "data" folder.')

## Load data splits

In [None]:
# Load the saved splits and wrap them in DataLoaders.
# NOTE(review): `cfg` (bin_sz_ms, data_device, batch_sz, n_bins_bhv) is not defined in this
# notebook; it must be created from the training config before running this cell.
# NOTE(review): the split dicts also hold non-tensor objects (e.g. 'target_pos'). Recent
# torch versions default torch.load to weights_only=True and may refuse them — confirm.
data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)

# Labels: the save cell stores them under 'hand_vel'; there is no 'velocity' key.
# Note that this tensor holds cursor positions despite the key name.
vel_train = train_data['hand_vel'].type(torch.float32).to(cfg.data_device)
vel_valid = val_data['hand_vel'].type(torch.float32).to(cfg.data_device)
vel_test = test_data['hand_vel'].type(torch.float32).to(cfg.data_device)

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)

train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
# Full test set in one batch (was sized from the validation set).
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_test_obs.shape[0], shuffle=False)

# Data dimensions
n_train_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_valid_trials = y_valid_obs.shape[0]
n_test_trials = y_test_obs.shape[0]
n_time_bins_enc = train_data['n_bins_enc']  # key written by the save cell

batch_sz_train = list(y_train_obs.shape)[:-1]
batch_sz_valid = list(y_valid_obs.shape)[:-1]
batch_sz_test = list(y_test_obs.shape)[:-1]

print("# training trials: {0}".format(n_train_trials))
print("# validation trials: {0}".format(n_valid_trials))
print("# testing trials: {0}".format(n_test_trials))
print("# neurons: {0}".format(n_neurons_obs))
print("# time bins: {0}".format(n_time_bins))
print("# time bins used for forecasting: {0}".format(cfg.n_bins_bhv))
print("# predicted time bins: {0}".format(n_time_bins - cfg.n_bins_bhv))

### Reach variability

In [None]:
# Distinct condition indices (ascending) and the number of trials in each.
conds_ids, cond_counts = conds.unique(return_counts=True)
cond_counts

In [None]:
# The 5 largest trial counts and their positions in `cond_counts`
# (equal to condition ids only when the ids are 0..K-1).
top_cs, top_ids = cond_counts.topk(5)
top_ids

In [None]:
# 'num_targets' column of the trial metadata (presumably the number of targets shown per trial — confirm).
trial_info['num_targets']

In [None]:
# Trial-averaged hand trajectories of every reach condition (108 per the original note),
# aligned to movement onset (-240..+660 ms) and coloured by reach direction.

valid_conds = [cond for cond in trial_info.set_index(['trial_type', 'trial_version']).index.unique().tolist()
               if not any(math.isnan(x) for x in cond)]

fig = plt.figure(figsize=(6, 6))
fig.suptitle('Aligned trials (same length)')
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

for cond_idx, cond in enumerate(valid_conds):
    # `trial_info` is `dataset.trial_info` (assigned when the dataset was loaded).
    mask = np.all(trial_info[['trial_type', 'trial_version']] == cond, axis=1)
    trial_d = dataset.make_trial_data(align_field='move_onset_time', align_range=(-240, 660), ignored_trials=(~mask))
    traj = trial_d.groupby('align_time')[[('hand_pos', 'x'), ('hand_pos', 'y')]].mean().to_numpy()
    cond_info = trial_info[mask]
    target_xy = cond_info.target_pos.iloc[0][int(cond_info.active_target.iloc[0])]
    reach_angle = np.arctan2(*target_xy[::-1])
    ax.plot(traj[:, 0], traj[:, 1], linewidth=0.7, color=plt.cm.hsv(reach_angle / (2*np.pi) + 0.5))

plt.axis('off')
plt.savefig('trial_averged_reaches_108')
plt.show()

In [None]:
# Short aliases used by the reach-variability analysis below.
# Note: `y_vel` holds cursor positions (built from trial.cursor_pos), despite its name.
y_data, y_vel, y_conds, y_t_pos, act_t = y, labels, conds, target_pos, active_target

In [None]:
def get_cond_trials(y_vel, y_conds):
    """Average trials per condition (condition-averaged trajectories, PSTH-like).

    Args:
        y_vel: Array or CPU tensor of shape (n_trials, n_bins, n_dims).
        y_conds: 1D array or CPU tensor of length n_trials with one condition id per trial.

    Returns:
        np.ndarray: float64 array of shape (n_conds, n_bins, n_dims). Row k is the mean
        over the trials of the k-th distinct condition id in ascending order. This equals
        condition id k when the ids are exactly 0..n_conds-1.
    """
    values = np.asarray(y_vel, dtype=np.float64)
    cond_labels = np.asarray(y_conds)
    # Distinct ids come from `y_conds` itself (previously from the global `conds_ids`).
    # Rows are filled by position, so non-contiguous ids no longer index out of range.
    cond_ids = np.unique(cond_labels)
    psth = np.zeros((len(cond_ids), values.shape[1], values.shape[2]))
    for row, cond in enumerate(cond_ids):
        psth[row] = values[cond_labels == cond].mean(axis=0)
    return psth

def calc_var_to_mean_ratio(psth):
    """Variance-to-mean ratio over axis 0, summed over axis 1.

    For arrays from get_cond_trials, axis 0 is conditions and axis 1 is time bins, so the
    result has one value per remaining dimension. NaN results (e.g. 0/0) become 0. As with
    any np.nan_to_num call, +/-inf become the largest finite floats.
    NOTE(review): on signed quantities such as positions the mean can be negative or near
    zero, which makes this ratio hard to interpret — confirm it is the intended metric.
    """
    ratio = psth.var(axis=0) / psth.mean(axis=0)
    summed = ratio.sum(axis=0)
    return np.nan_to_num(summed, nan=0)

In [None]:
# Condition-averaged label trajectories, shape (n_conds, n_bins, n_dims).
psth = get_cond_trials(y_vel, y_conds)

In [None]:
# Expect (n_conds, n_bins, n_dims).
psth.shape

In [None]:
# Variance-to-mean ratio across conditions, summed over time bins; one value per label dim.
vmr = calc_var_to_mean_ratio(psth)

In [None]:
# Inspect the per-dimension variance-to-mean ratios.
vmr

In [None]:
# Averaged labels of the first condition row at the first time bin (all dims).
psth[0, 0, :]

In [None]:
# For each time bin of the first condition's average trajectory: variance across its dims.
var = [float(np.var(row)) for row in psth[0]]

In [None]:
# Per-time-bin variances computed above.
var

In [None]:
# One variance per time bin: expect (n_bins,).
np.array(var).shape

In [None]:
# First condition's average trajectory: (n_bins, n_dims).
psth[0, :, :].shape

In [None]:


# Sort the tensors based on the calculated variances
sorted_tensors = [tensor for _, tensor in sorted(zip(variances, tensors), key=lambda x: x[0])]

# Output the sorted tensors
for i, tensor in enumerate(sorted_tensors):
    print(f"Tensor {i+1}:\n{tensor}\n")

In [None]:
# Full average trajectory of the first condition row: (n_bins, n_dims).
psth[0, :, :]