In [31]:
import numpy as np
import pandas as pd
import seaborn as sns
import math
import yaml

import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import pytorch_lightning as lightning

from itertools import product
from nlb_tools.nwb_interface import NWBDataset

from hydra import compose, initialize
from xfads import plot_utils
from xfads.smoothers.lightning_trainers import LightningMonkeyReaching
from xfads.ssm_modules.prebuilt_models import create_xfads_poisson_log_link

In [6]:
torch.cuda.empty_cache()

In [9]:
initialize(version_base=None, config_path="", job_name="monkey_reaching")

In [63]:
cfg = compose(config_name="config")

In [14]:
lightning.seed_everything(cfg.seed, workers=True)
torch.set_default_dtype(torch.float32)

Seed set to 1236


In [54]:
def get_int_to_verbose_map():
    options = [('S', 'L'), ('E', 'H'), ('1', '2', '3', '4', '5')]
    combinations = list(product(*options))
    mapping = {i: combination for i, combination in enumerate(combinations)}

    return mapping


def get_verbose_to_int_map():
    options = [('S', 'L'), ('E', 'H'), ('1', '2', '3', '4', '5')]
    combinations = list(product(*options))
    reverse_mapping = {combination: i for i, combination in enumerate(combinations)}

    return reverse_mapping

In [56]:
"""Downloading the data"""

datapath = 'data/000128/sub-Jenkins/'
dataset = NWBDataset(datapath)
save_root_path = 'data/'

In [67]:
# Extract neural data and lagged hand velocity
binsize = cfg.bin_sz_ms
n_neurons = cfg.n_neurons
dataset.resample(binsize)

start = -450
end = 459
trial_length = (end - start) // binsize

verbose_to_int_map = get_verbose_to_int_map()
int_to_verbose_map = get_int_to_verbose_map()

# Extract neural data
trial_info = dataset.trial_info  # .dropna()
trial_info['color'] = None
trial_info['position_id'] = None
trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(start, end))
n_trials = trial_data.shape[0] // trial_length

print('done')
print(trial_data.columns)

Dataset already at 10 ms resolution, skipping resampling...
Shortened 573 trials to prevent overlap.
NaNs found in `self.data`. Dropping 18.19% of points to remove NaNs from `trial_data`.


done
MultiIndex([('align_time',   ''),
            ('clock_time',   ''),
            ('cursor_pos',  'x'),
            ('cursor_pos',  'y'),
            (   'eye_pos',  'x'),
            (   'eye_pos',  'y'),
            (  'hand_pos',  'x'),
            (  'hand_pos',  'y'),
            (  'hand_vel',  'x'),
            (  'hand_vel',  'y'),
            ...
            (    'spikes', 2861),
            (    'spikes', 2862),
            (    'spikes', 2871),
            (    'spikes', 2881),
            (    'spikes', 2911),
            (    'spikes', 2931),
            (    'spikes', 2951),
            (    'spikes', 2961),
            (  'trial_id',   ''),
            ('trial_time',   '')],
           length=195)


In [69]:
trial_info

Unnamed: 0,trial_id,start_time,end_time,move_onset_time,split,trial_type,trial_version,maze_id,success,target_on_time,go_cue_time,rt,delay,num_targets,target_pos,num_barriers,barrier_pos,active_target,color,position_id
0,0,0 days 00:00:00,0 days 00:00:00.700000,0 days 00:00:00.250000,test,,,,,NaT,NaT,,,,,,,,,
1,1,0 days 00:00:00.800000,0 days 00:00:01.500000,0 days 00:00:01.050000,test,,,,,NaT,NaT,,,,,,,,,
2,2,0 days 00:00:01.600000,0 days 00:00:02.300000,0 days 00:00:01.850000,test,,,,,NaT,NaT,,,,,,,,,
3,3,0 days 00:00:02.400000,0 days 00:00:03.100000,0 days 00:00:02.650000,test,,,,,NaT,NaT,,,,,,,,,
4,4,0 days 00:00:03.200000,0 days 00:00:03.900000,0 days 00:00:03.450000,test,,,,,NaT,NaT,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2864,2864,0 days 02:03:15.800000,0 days 02:03:18.996000,0 days 02:03:17.785000,train,34.0,1.0,91.0,True,0 days 02:03:16.562000,0 days 02:03:17.477000,308.0,915.0,1.0,"[[116, -77]]",7.0,"[[66, -43, 30, 9], [-66, 1, 11, 70], [-35, 50,...",0.0,,
2865,2865,0 days 02:03:19.100000,0 days 02:03:21.936000,0 days 02:03:20.841000,train,15.0,1.0,75.0,True,0 days 02:03:19.917000,0 days 02:03:20.465000,376.0,548.0,1.0,"[[133, -81]]",9.0,"[[-33, 47, 37, 6], [-77, 48, 61, 11], [-64, -2...",0.0,,
2866,2866,0 days 02:03:22,0 days 02:03:24.966000,0 days 02:03:23.914000,train,23.0,0.0,67.0,True,0 days 02:03:22.665000,0 days 02:03:23.596000,318.0,931.0,1.0,"[[94, -86]]",0.0,[],0.0,,
2867,2867,0 days 02:03:25,0 days 02:03:28.401000,0 days 02:03:26.816000,val,25.0,2.0,84.0,True,0 days 02:03:25.831000,0 days 02:03:25.863000,953.0,32.0,3.0,"[[-111, -82], [-108, 81], [118, 72]]",8.0,"[[69, 31, 14, 99], [69, 54, 5, 101], [-62, -48...",2.0,,


In [85]:
y = []
tp = []
ts = []
task_id = []

count = 0

for trial_id, trial in trial_data.groupby('trial_id'):
    trial_id_trial_info = trial_info[trial_info['trial_id'] == trial_id]
    #is_outlier_t = trial_id_trial_info['is_outlier'].iloc[0]
    #tp_t = torch.tensor(trial_id_trial_info['tp'].iloc[0])
    #ts_t = torch.tensor(trial_id_trial_info['ts'].iloc[0])
    #is_short_t = trial_id_trial_info['is_short'].iloc[0]
    #is_eye_t = trial_id_trial_info['is_eye'].iloc[0]
    
    """
    if is_outlier_t or tp_t < 0:
        continue

    if is_short_t:
        task_str_1 = 'S'

        if ts_t == 480:
            task_str_3 = '1'
        elif ts_t == 560:
            task_str_3 = '2'
        elif ts_t == 640:
            task_str_3 = '3'
        elif ts_t == 720:
            task_str_3 = '4'
        elif ts_t == 800:
            task_str_3 = '5'
    else:
        task_str_1 = 'L'

        if ts_t == 800:
            task_str_3 = '1'
        elif ts_t == 900:
            task_str_3 = '2'
        elif ts_t == 1000:
            task_str_3 = '3'
        elif ts_t == 1100:
            task_str_3 = '4'
        elif ts_t == 1200:
            task_str_3 = '5'

    if is_eye_t:
        task_str_2 = 'E'
    else:
        task_str_2 = 'H'
    """

    y_heldin_t = torch.tensor(trial.spikes.values)
    y_heldout_t = torch.tensor(trial.heldout_spikes.values)
    y_t = torch.concat([y_heldin_t, y_heldout_t], dim=-1)
    y.append(y_t.reshape(1, trial_length, n_neurons))
    
    """
    task_id_key = (task_str_1, task_str_2, task_str_3)
    task_id_int = verbose_to_int_map[task_id_key]
    task_id.append(torch.tensor(task_id_int).unsqueeze(-1))

    tp.append(torch.tensor(tp_t).unsqueeze(-1))
    ts.append(torch.tensor(ts_t).unsqueeze(-1))

    if is_outlier_t:
        count += 1
    """

y = torch.concat(y, dim=0)
"""
task_id = torch.concat(task_id, dim=0)

subset_ex = 10
subset_ex_loc = torch.where(task_id == subset_ex)[0]

y_subset = y[subset_ex_loc]
y_psth = y_subset.mean(dim=0)

ts = torch.stack(ts, dim=0)
tp = torch.stack(tp, dim=0)
"""

'\ntask_id = torch.concat(task_id, dim=0)\n\nsubset_ex = 10\nsubset_ex_loc = torch.where(task_id == subset_ex)[0]\n\ny_subset = y[subset_ex_loc]\ny_psth = y_subset.mean(dim=0)\n\nts = torch.stack(ts, dim=0)\ntp = torch.stack(tp, dim=0)\n'

In [111]:
y.shape

torch.Size([2295, 90, 182])

In [105]:
"""
with open('data/old_data/int_condition_map.yaml', 'w') as outfile:
    yaml.dump(int_to_verbose_map, outfile, default_flow_style=False)
"""

train_data, valid_data, test_data = {}, {}, {}
untrained_trials = 300
seq_len = trial_length

train_data['y_obs'] = y[:-untrained_trials]
train_data['velocity'] = torch.Tensor(data_dict['train_behavior'])
#train_data['task_id'] = task_id[:-untrained_trials]
#train_data['ts'] = ts[:-untrained_trials]
#train_data['tp'] = tp[:-untrained_trials]
train_data['n_neurons_enc'] = n_neurons
train_data['n_neurons_obs'] = n_neurons
train_data['n_time_bins_enc'] = seq_len

valid_data['y_obs'] = y[-untrained_trials:-untrained_trials // 2]
valid_data['velocity'] = torch.Tensor(data_dict['valid_behavior'])
#valid_data['task_id'] = task_id[-untrained_trials:-untrained_trials // 2]
#valid_data['ts'] = ts[-untrained_trials:-untrained_trials // 2]
#valid_data['tp'] = tp[-untrained_trials:-untrained_trials // 2]
valid_data['n_neurons_enc'] = n_neurons
valid_data['n_neurons_obs'] = n_neurons
valid_data['n_time_bins_enc'] = seq_len

test_data['y_obs'] = y[-untrained_trials // 2:]
test_data['velocity'] = torch.Tensor(data_dict['valid_behavior'])
#test_data['task_id'] = task_id[-untrained_trials // 2:]
#test_data['ts'] = ts[-untrained_trials // 2:]
#test_data['tp'] = tp[-untrained_trials // 2:]
test_data['n_neurons_enc'] = n_neurons
test_data['n_neurons_obs'] = n_neurons
test_data['n_time_bins_enc'] = seq_len

torch.save(train_data, save_root_path + f'data_train_{binsize}ms.pt')
torch.save(valid_data, save_root_path + f'data_valid_{binsize}ms.pt')
torch.save(test_data, save_root_path + f'data_test_{binsize}ms.pt')

NameError: name 'data_dict' is not defined

In [95]:
""" Loading the data"""

data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)

vel_train = train_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_valid = val_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_test = test_data['velocity'].type(torch.float32).to(cfg.data_device)

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)

train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)

# Data dimensions
n_train_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_valid_trials = y_valid_obs.shape[0]
n_test_trials = y_test_obs.shape[0]
n_time_bins_enc = train_data['n_time_bins_enc']

n_bins_bhv = 10  # at t=n_bins_bhv start forecast
stim_onset = 12  # stimulus onset

trial_list=[1, 287//4, 28//4 + 287//2, 286]

print("# training trials: {0}".format(n_train_trials))
print("# validation trials: {0}".format(n_valid_trials))
print("# testing trials: {0}".format(n_test_trials))
print("# neurons: {0}".format(n_neurons_obs))
print("# time bins: {0}".format(n_time_bins))
print("# time bins used for forcasting: {0}".format(n_bins_bhv))
print("# predicted time bins: {0}".format(n_time_bins_enc))

KeyError: 'velocity'

In [97]:
train_data.keys()

dict_keys(['y_obs', 'task_id', 'ts', 'tp', 'n_neurons_enc', 'n_neurons_obs', 'n_time_bins_enc'])

In [115]:
dir(dataset)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_make_midx',
 'add_continuous_data',
 'add_trialized_data',
 'bin_width',
 'calculate_onset',
 'data',
 'descriptions',
 'fpath',
 'load',
 'make_trial_data',
 'prefix',
 'resample',
 'smooth_spk',
 'trial_info']

In [137]:
for c in dataset.data.columns:
    print(c)

('cursor_pos', 'x')
('cursor_pos', 'y')
('eye_pos', 'x')
('eye_pos', 'y')
('hand_pos', 'x')
('hand_pos', 'y')
('hand_vel', 'x')
('hand_vel', 'y')
('heldout_spikes', 1021)
('heldout_spikes', 1022)
('heldout_spikes', 1041)
('heldout_spikes', 1051)
('heldout_spikes', 1062)
('heldout_spikes', 1091)
('heldout_spikes', 1092)
('heldout_spikes', 1102)
('heldout_spikes', 1122)
('heldout_spikes', 1123)
('heldout_spikes', 1211)
('heldout_spikes', 1233)
('heldout_spikes', 1251)
('heldout_spikes', 1401)
('heldout_spikes', 1411)
('heldout_spikes', 1431)
('heldout_spikes', 1471)
('heldout_spikes', 1511)
('heldout_spikes', 1561)
('heldout_spikes', 1751)
('heldout_spikes', 1791)
('heldout_spikes', 1812)
('heldout_spikes', 1841)
('heldout_spikes', 1902)
('heldout_spikes', 2121)
('heldout_spikes', 2132)
('heldout_spikes', 2151)
('heldout_spikes', 2191)
('heldout_spikes', 2231)
('heldout_spikes', 2251)
('heldout_spikes', 2272)
('heldout_spikes', 2281)
('heldout_spikes', 2301)
('heldout_spikes', 2351)
('he

In [139]:
trial_data

Unnamed: 0_level_0,align_time,clock_time,cursor_pos,cursor_pos,eye_pos,eye_pos,hand_pos,hand_pos,hand_vel,hand_vel,...,spikes,spikes,spikes,spikes,spikes,spikes,spikes,spikes,trial_id,trial_time
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,x,y,x,y,x,y,x,y,...,2861,2862,2871,2881,2911,2931,2951,2961,Unnamed: 20_level_1,Unnamed: 21_level_1
45910,-1 days +23:59:59.550000,0 days 00:07:40.660000,-1.427704,-7.503850,-18.150551,-0.201034,-1.433218,-42.508814,4.476467,-0.080540,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,574,0 days 00:00:01.460000
45911,-1 days +23:59:59.560000,0 days 00:07:40.670000,-1.562753,-7.492515,-18.717431,-0.671962,-1.379399,-42.508926,3.620276,-0.023929,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,574,0 days 00:00:01.470000
45912,-1 days +23:59:59.570000,0 days 00:07:40.680000,-1.401592,-7.506672,-18.552826,-0.827292,-1.394624,-42.513925,-6.734999,-1.251761,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,574,0 days 00:00:01.480000
45913,-1 days +23:59:59.580000,0 days 00:07:40.690000,-1.295584,-7.486479,-16.852135,-0.614574,-1.484415,-42.537580,-8.867741,-3.560745,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,574,0 days 00:00:01.490000
45914,-1 days +23:59:59.590000,0 days 00:07:40.700000,-1.516653,-7.531085,-17.646533,-0.284019,-1.537296,-42.583118,-1.131259,-5.115167,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,574,0 days 00:00:01.500000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
252455,0 days 00:00:00.400000,0 days 02:03:30.790000,-115.903462,-79.987314,-114.762998,-124.707191,-115.973210,-114.773909,21.487100,-4.834804,...,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,2868,0 days 00:00:02.290000
252456,0 days 00:00:00.410000,0 days 02:03:30.800000,-115.884467,-79.973531,-114.558286,-125.530125,-115.719873,-114.778053,28.939125,3.577302,...,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,2868,0 days 00:00:02.300000
252457,0 days 00:00:00.420000,0 days 02:03:30.810000,-115.601053,-79.912645,-113.967848,-125.804510,-115.410274,-114.710502,31.679641,9.450206,...,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,2868,0 days 00:00:02.310000
252458,0 days 00:00:00.430000,0 days 02:03:30.820000,-115.274500,-79.782599,-114.248345,-126.373757,-115.109901,-114.598085,27.761826,12.650392,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,2868,0 days 00:00:02.320000
