In [3]:
import numpy as np
import pandas as pd
import seaborn as sns
import math
import yaml
import hydra

import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import pytorch_lightning as lightning

from itertools import product
from nlb_tools.nwb_interface import NWBDataset

from hydra import compose, initialize
from xfads import plot_utils
from xfads.smoothers.lightning_trainers import LightningMonkeyReaching
from xfads.ssm_modules.prebuilt_models import create_xfads_poisson_log_link

In [4]:
torch.cuda.empty_cache()

In [5]:
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="", job_name="monkey_reaching")

cfg = compose(config_name="config")

In [6]:
lightning.seed_everything(cfg.seed, workers=True)
torch.set_default_dtype(torch.float32)

Seed set to 1236


In [7]:
def get_int_to_verbose_map():
    options = [('S', 'L'), ('E', 'H'), ('1', '2', '3', '4', '5')]
    combinations = list(product(*options))
    mapping = {i: combination for i, combination in enumerate(combinations)}

    return mapping


def get_verbose_to_int_map():
    options = [('S', 'L'), ('E', 'H'), ('1', '2', '3', '4', '5')]
    combinations = list(product(*options))
    reverse_mapping = {combination: i for i, combination in enumerate(combinations)}

    return reverse_mapping

In [13]:
"""Downloading the data"""

datapath = 'data/000128/sub-Jenkins/'
dataset = NWBDataset(datapath)
save_root_path = 'data/'

  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)


In [14]:
# Extract neural data and lagged hand velocity
binsize = cfg.bin_sz_ms
n_neurons = cfg.n_neurons
dataset.resample(binsize)

start = -450
end = 450
trial_length = (end - start) // binsize

verbose_to_int_map = get_verbose_to_int_map()
int_to_verbose_map = get_int_to_verbose_map()

# Extract neural data
trial_info = dataset.trial_info  # .dropna()
trial_info['color'] = None
trial_info['position_id'] = None
trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(start, end))
n_trials = trial_data.shape[0] // trial_length

print('done')
print(trial_data.columns)

Shortened 573 trials to prevent overlap.
NaNs found in `self.data`. Dropping 18.19% of points to remove NaNs from `trial_data`.


done
MultiIndex([('align_time',   ''),
            ('clock_time',   ''),
            ('cursor_pos',  'x'),
            ('cursor_pos',  'y'),
            (   'eye_pos',  'x'),
            (   'eye_pos',  'y'),
            (  'hand_pos',  'x'),
            (  'hand_pos',  'y'),
            (  'hand_vel',  'x'),
            (  'hand_vel',  'y'),
            ...
            (    'spikes', 2861),
            (    'spikes', 2862),
            (    'spikes', 2871),
            (    'spikes', 2881),
            (    'spikes', 2911),
            (    'spikes', 2931),
            (    'spikes', 2951),
            (    'spikes', 2961),
            (  'trial_id',   ''),
            ('trial_time',   '')],
           length=195)


In [25]:
trial_info

Unnamed: 0,trial_id,start_time,end_time,move_onset_time,split,trial_type,trial_version,maze_id,success,target_on_time,go_cue_time,rt,delay,num_targets,target_pos,num_barriers,barrier_pos,active_target,color,position_id
0,0,0 days 00:00:00,0 days 00:00:00.700000,0 days 00:00:00.250000,test,,,,,NaT,NaT,,,,,,,,,
1,1,0 days 00:00:00.800000,0 days 00:00:01.500000,0 days 00:00:01.050000,test,,,,,NaT,NaT,,,,,,,,,
2,2,0 days 00:00:01.600000,0 days 00:00:02.300000,0 days 00:00:01.850000,test,,,,,NaT,NaT,,,,,,,,,
3,3,0 days 00:00:02.400000,0 days 00:00:03.100000,0 days 00:00:02.650000,test,,,,,NaT,NaT,,,,,,,,,
4,4,0 days 00:00:03.200000,0 days 00:00:03.900000,0 days 00:00:03.450000,test,,,,,NaT,NaT,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2864,2864,0 days 02:03:15.800000,0 days 02:03:18.996000,0 days 02:03:17.785000,train,34.0,1.0,91.0,True,0 days 02:03:16.562000,0 days 02:03:17.477000,308.0,915.0,1.0,"[[116, -77]]",7.0,"[[66, -43, 30, 9], [-66, 1, 11, 70], [-35, 50,...",0.0,,
2865,2865,0 days 02:03:19.100000,0 days 02:03:21.936000,0 days 02:03:20.841000,train,15.0,1.0,75.0,True,0 days 02:03:19.917000,0 days 02:03:20.465000,376.0,548.0,1.0,"[[133, -81]]",9.0,"[[-33, 47, 37, 6], [-77, 48, 61, 11], [-64, -2...",0.0,,
2866,2866,0 days 02:03:22,0 days 02:03:24.966000,0 days 02:03:23.914000,train,23.0,0.0,67.0,True,0 days 02:03:22.665000,0 days 02:03:23.596000,318.0,931.0,1.0,"[[94, -86]]",0.0,[],0.0,,
2867,2867,0 days 02:03:25,0 days 02:03:28.401000,0 days 02:03:26.816000,val,25.0,2.0,84.0,True,0 days 02:03:25.831000,0 days 02:03:25.863000,953.0,32.0,3.0,"[[-111, -82], [-108, 81], [118, 72]]",8.0,"[[69, 31, 14, 99], [69, 54, 5, 101], [-62, -48...",2.0,,


In [33]:
trial_data.columns

MultiIndex([('align_time',   ''),
            ('clock_time',   ''),
            ('cursor_pos',  'x'),
            ('cursor_pos',  'y'),
            (   'eye_pos',  'x'),
            (   'eye_pos',  'y'),
            (  'hand_pos',  'x'),
            (  'hand_pos',  'y'),
            (  'hand_vel',  'x'),
            (  'hand_vel',  'y'),
            ...
            (    'spikes', 2861),
            (    'spikes', 2862),
            (    'spikes', 2871),
            (    'spikes', 2881),
            (    'spikes', 2911),
            (    'spikes', 2931),
            (    'spikes', 2951),
            (    'spikes', 2961),
            (  'trial_id',   ''),
            ('trial_time',   '')],
           length=195)

In [35]:
trial_data

Unnamed: 0_level_0,align_time,clock_time,cursor_pos,cursor_pos,eye_pos,eye_pos,hand_pos,hand_pos,hand_vel,hand_vel,...,spikes,spikes,spikes,spikes,spikes,spikes,spikes,spikes,trial_id,trial_time
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,x,y,x,y,x,y,x,y,...,2861,2862,2871,2881,2911,2931,2951,2961,Unnamed: 20_level_1,Unnamed: 21_level_1
22955,-1 days +23:59:59.540000,0 days 00:07:40.660000,-1.463915,-7.492066,-18.382911,-0.377749,-1.424627,-42.511253,3.055079,-0.315920,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,574,0 days 00:00:01.460000
22956,-1 days +23:59:59.560000,0 days 00:07:40.680000,-1.409679,-7.503377,-18.282735,-0.799903,-1.400660,-42.510206,-5.051846,-0.871977,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,574,0 days 00:00:01.480000
22957,-1 days +23:59:59.580000,0 days 00:07:40.700000,-1.459685,-7.517751,-17.032929,-0.337737,-1.534191,-42.588972,-2.760242,-5.309706,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,2.0,574,0 days 00:00:01.500000
22958,-1 days +23:59:59.600000,0 days 00:07:40.720000,-1.621291,-7.682808,-18.543133,-0.939321,-1.494304,-42.601084,2.597506,4.587658,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,574,0 days 00:00:01.520000
22959,-1 days +23:59:59.620000,0 days 00:07:40.740000,-1.477806,-7.588030,-18.233750,-0.126550,-1.505363,-42.536439,-2.003645,-1.305288,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,574,0 days 00:00:01.540000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
126225,0 days 00:00:00.340000,0 days 02:03:30.740000,-118.430515,-77.428451,-115.449735,-124.429779,-117.619761,-112.831567,83.182184,-82.032653,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2868,0 days 00:00:02.240000
126226,0 days 00:00:00.360000,0 days 02:03:30.760000,-116.444063,-79.152946,-116.129629,-124.352170,-116.450647,-114.055396,27.777021,-44.439482,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,2868,0 days 00:00:02.260000
126227,0 days 00:00:00.380000,0 days 02:03:30.780000,-115.842393,-79.840834,-115.741469,-124.183166,-116.174928,-114.680376,13.272618,-17.231327,...,0.0,0.0,0.0,0.0,0.0,4.0,1.0,0.0,2868,0 days 00:00:02.280000
126228,0 days 00:00:00.400000,0 days 02:03:30.800000,-115.891162,-80.008894,-114.388179,-125.555668,-115.712684,-114.782808,29.807791,3.919408,...,0.0,1.0,0.0,1.0,0.0,2.0,2.0,0.0,2868,0 days 00:00:02.300000


In [24]:
trial_data['trial_id']

45910      574
45911      574
45912      574
45913      574
45914      574
          ... 
252455    2868
252456    2868
252457    2868
252458    2868
252459    2868
Name: trial_id, Length: 206550, dtype: int64

In [51]:
torch.tensor(trial.spikes.values).shape

torch.Size([90, 137])

In [57]:
trial.hand_vel.values.shape

(90, 2)

In [59]:
y_t.shape

torch.Size([90, 182])

In [63]:
y_t.reshape(1, trial_length, n_neurons).shape

torch.Size([1, 90, 182])

In [65]:
y = []
target = []
#tp = []
#ts = []
#task_id = []

count = 0

for trial_id, trial in trial_data.groupby('trial_id'):
    trial_id_trial_info = trial_info[trial_info['trial_id'] == trial_id]
    #is_outlier_t = trial_id_trial_info['is_outlier'].iloc[0]
    #tp_t = torch.tensor(trial_id_trial_info['tp'].iloc[0])
    #ts_t = torch.tensor(trial_id_trial_info['ts'].iloc[0])
    #is_short_t = trial_id_trial_info['is_short'].iloc[0]
    #is_eye_t = trial_id_trial_info['is_eye'].iloc[0]
    
    """
    if is_outlier_t or tp_t < 0:
        continue

    if is_short_t:
        task_str_1 = 'S'

        if ts_t == 480:
            task_str_3 = '1'
        elif ts_t == 560:
            task_str_3 = '2'
        elif ts_t == 640:
            task_str_3 = '3'
        elif ts_t == 720:
            task_str_3 = '4'
        elif ts_t == 800:
            task_str_3 = '5'
    else:
        task_str_1 = 'L'

        if ts_t == 800:
            task_str_3 = '1'
        elif ts_t == 900:
            task_str_3 = '2'
        elif ts_t == 1000:
            task_str_3 = '3'
        elif ts_t == 1100:
            task_str_3 = '4'
        elif ts_t == 1200:
            task_str_3 = '5'

    if is_eye_t:
        task_str_2 = 'E'
    else:
        task_str_2 = 'H'
    """

    y_heldin_t = torch.tensor(trial.spikes.values)
    y_heldout_t = torch.tensor(trial.heldout_spikes.values)
    y_t = torch.concat([y_heldin_t, y_heldout_t], dim=-1)
    y.append(y_t.reshape(1, trial_length, n_neurons))

    target.append(torch.tensor(trial.hand_vel.values).reshape(1, trial_length, 2))
    
    """
    task_id_key = (task_str_1, task_str_2, task_str_3)
    task_id_int = verbose_to_int_map[task_id_key]
    task_id.append(torch.tensor(task_id_int).unsqueeze(-1))

    tp.append(torch.tensor(tp_t).unsqueeze(-1))
    ts.append(torch.tensor(ts_t).unsqueeze(-1))

    if is_outlier_t:
        count += 1
    """
    
y = torch.concat(y, dim=0)
target = torch.concat(target, dim=0)

"""

task_id = torch.concat(task_id, dim=0)

subset_ex = 10
subset_ex_loc = torch.where(task_id == subset_ex)[0]

y_subset = y[subset_ex_loc]
y_psth = y_subset.mean(dim=0)

ts = torch.stack(ts, dim=0)
tp = torch.stack(tp, dim=0)
"""

'\n\ntask_id = torch.concat(task_id, dim=0)\n\nsubset_ex = 10\nsubset_ex_loc = torch.where(task_id == subset_ex)[0]\n\ny_subset = y[subset_ex_loc]\ny_psth = y_subset.mean(dim=0)\n\nts = torch.stack(ts, dim=0)\ntp = torch.stack(tp, dim=0)\n'

In [67]:
y.shape

torch.Size([2295, 90, 182])

In [69]:
target.shape

torch.Size([2295, 90, 2])

In [80]:
n_valid_trials = 574
y[:-n_valid_trials].shape

torch.Size([1721, 90, 182])

In [92]:
y[-n_valid_trials:].shape

torch.Size([574, 90, 182])

In [115]:
target[-n_valid_trials:][:n_valid_trials//2].shape

torch.Size([287, 90, 2])

In [117]:
target[-n_valid_trials:][n_valid_trials//2:].shape

torch.Size([287, 90, 2])

In [113]:
target[:-n_valid_trials].shape

torch.Size([1721, 90, 2])

In [123]:
y[-n_valid_trials:-n_valid_trials // 2] == y[-n_valid_trials // 2:]

tensor([[[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True, False,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True, False,  True],
         [ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True, False,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [

In [None]:
train_data, valid_data, test_data = {}, {}, {}
seq_len = data_dict['train_encod_data'].shape[1]
n_valid_trials = 574

train_data['y_obs'] = y[:-n_valid_trials]
train_data['velocity'] = target[:-n_valid_trials]
train_data['n_neurons_enc'] = y.shape[-1]
train_data['n_neurons_obs'] = y.shape[-1]
train_data['n_time_bins_enc'] = seq_len

valid_data['y_obs'] = y[-n_valid_trials:-n_valid_trials // 2]
valid_data['velocity'] = target[-n_valid_trials:-n_valid_trials // 2]
valid_data['n_neurons_enc'] = y.shape[-1]
train_data['n_neurons_obs'] = y.shape[-1]
valid_data['n_time_bins_enc'] = seq_len

test_data['y_obs'] = y[-n_valid_trials // 2:]
test_data['velocity'] = target[-n_valid_trials // 2:]
test_data['n_neurons_enc'] = y.shape[-1]
test_data['n_neurons_obs'] = y.shape[-1]
test_data['n_time_bins_enc'] = seq_len

In [None]:
torch.save(train_data, save_root_path + f'data_train_{binsize}ms.pt')
torch.save(valid_data, save_root_path + f'data_valid_{binsize}ms.pt')
torch.save(test_data, save_root_path + f'data_test_{binsize}ms.pt')

In [95]:
""" Loading the data"""

data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)

vel_train = train_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_valid = val_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_test = test_data['velocity'].type(torch.float32).to(cfg.data_device)

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)

train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)

# Data dimensions
n_train_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_valid_trials = y_valid_obs.shape[0]
n_test_trials = y_test_obs.shape[0]
n_time_bins_enc = train_data['n_time_bins_enc']

n_bins_bhv = 10  # at t=n_bins_bhv start forecast
stim_onset = 12  # stimulus onset

trial_list=[1, 287//4, 28//4 + 287//2, 286]

print("# training trials: {0}".format(n_train_trials))
print("# validation trials: {0}".format(n_valid_trials))
print("# testing trials: {0}".format(n_test_trials))
print("# neurons: {0}".format(n_neurons_obs))
print("# time bins: {0}".format(n_time_bins))
print("# time bins used for forcasting: {0}".format(n_bins_bhv))
print("# predicted time bins: {0}".format(n_time_bins_enc))

KeyError: 'velocity'

In [97]:
train_data.keys()

dict_keys(['y_obs', 'task_id', 'ts', 'tp', 'n_neurons_enc', 'n_neurons_obs', 'n_time_bins_enc'])

In [33]:
dataset.bin_width

10

In [35]:
data_dict['train_behavior']

NameError: name 'data_dict' is not defined