In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import math
import yaml
import hydra

import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import pytorch_lightning as lightning

from itertools import product
from nlb_tools.nwb_interface import NWBDataset

from hydra import compose, initialize
from xfads import plot_utils
from xfads.smoothers.lightning_trainers import LightningMonkeyReaching
from xfads.ssm_modules.prebuilt_models import create_xfads_poisson_log_link

In [2]:
torch.cuda.empty_cache()

In [3]:
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="", job_name="monkey_reaching")

cfg = compose(config_name="config")

In [4]:
lightning.seed_everything(cfg.seed, workers=True)
torch.set_default_dtype(torch.float32)

Seed set to 1236


In [5]:
"""Downloading the data"""

datapath = 'data/000128/sub-Jenkins/'
dataset = NWBDataset(datapath)
save_root_path = 'data/'

  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)


In [22]:
# Extract neural data and lagged hand velocity
binsize = 10 # ms
n_neurons = 182
dataset.resample(binsize)

# We want total trial length of 900ms ...
start = -450
end = 450
#... which is 90 time bins
trial_length = (end - start) // binsize

# Extract neural data
trial_info = dataset.trial_info  # .dropna()
# Trials aligned around the movement_onset time bin
trial_data = dataset.make_trial_data(
    align_field='move_onset_time', align_range=(start, end), ignored_trials=None)
n_trials = trial_data.shape[0] // trial_length

Dataset already at 10 ms resolution, skipping resampling...
Shortened 573 trials to prevent overlap.
NaNs found in `self.data`. Dropping 13.16% of points to remove NaNs from `trial_data`.


In [23]:
y = []
target = []

for trial_id, trial in trial_data.groupby('trial_id'):
    trial_id_trial_info = trial_info[trial_info['trial_id'] == trial_id]

    y_heldin_t = torch.tensor(trial.spikes.values)
    y_heldout_t = torch.tensor(trial.heldout_spikes.values)
    y_t = torch.concat([y_heldin_t, y_heldout_t], dim=-1)
    y.append(y_t.reshape(1, trial_length, n_neurons))

    target.append(torch.tensor(trial.hand_vel.values).reshape(1, trial_length, 2))
    
y = torch.concat(y, dim=0)
target = torch.concat(target, dim=0)

RuntimeError: shape '[1, 132, 182]' is invalid for input of size 2002

In [25]:
y_t.shape

torch.Size([11, 182])

In [9]:
trial_length

90

In [11]:
y.shape

torch.Size([2295, 90, 182])

In [221]:
train_data, valid_data, test_data = {}, {}, {}
seq_len = y.shape[1]
n_neurons = y.shape[-1]
n_valid_trials = 574

train_data['y_obs'] = y[:-n_valid_trials]
train_data['velocity'] = target[:-n_valid_trials]
train_data['n_neurons_enc'] = y.shape[-1]
train_data['n_neurons_obs'] = y.shape[-1]
train_data['n_time_bins_enc'] = seq_len

valid_data['y_obs'] = y[-n_valid_trials:-n_valid_trials // 2]

valid_data['velocity'] = target[-n_valid_trials:-n_valid_trials // 2]
valid_data['n_neurons_enc'] = n_neurons
train_data['n_neurons_obs'] = n_neurons
valid_data['n_time_bins_enc'] = seq_len

test_data['y_obs'] = y[-n_valid_trials // 2:]
test_data['velocity'] = target[-n_valid_trials // 2:]
test_data['n_neurons_enc'] = n_neurons
test_data['n_neurons_obs'] = n_neurons
test_data['n_time_bins_enc'] = seq_len

In [208]:
torch.save(train_data, save_root_path + f'data_train_{binsize}ms.pt')
torch.save(valid_data, save_root_path + f'data_valid_{binsize}ms.pt')
torch.save(test_data, save_root_path + f'data_test_{binsize}ms.pt')

In [209]:
data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)

vel_train = train_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_valid = val_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_test = test_data['velocity'].type(torch.float32).to(cfg.data_device)

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)

train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)

# Data dimensions
n_train_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_valid_trials = y_valid_obs.shape[0]
n_test_trials = y_test_obs.shape[0]
n_time_bins_enc = train_data['n_time_bins_enc']
batch_sz_test = list(y_test_obs.shape)[:-1]

print("# training trials: {0}".format(n_train_trials))
print("# validation trials: {0}".format(n_valid_trials))
print("# testing trials: {0}".format(n_test_trials))
print("# neurons: {0}".format(n_neurons_obs))
print("# time bins: {0}".format(n_time_bins))
print("# time bins used for forcasting: {0}".format(cfg.n_bins_bhv))
print("# predicted time bins: {0}".format(n_time_bins - cfg.n_bins_bhv))

# training trials: 1721
# validation trials: 287
# testing trials: 287
# neurons: 182
# time bins: 60
# time bins used for forcasting: 20
# predicted time bins: 40


In [None]:
y_20 = torch.load('data/data_train_20ms.pt')['y_obs'].type(torch.float32).to(cfg.data_device)
y_10 = torch.load('data/data_train_10ms.pt')['y_obs'].type(torch.float32).to(cfg.data_device)

In [None]:
y_20 == y_10

In [None]:
y_20[0]

In [None]:
y_10[0]

In [None]:
plt.imshow(y_20[0].T, aspect=0.4)

In [None]:
plt.imshow(y_10[0].T, aspect=0.8)

In [None]:
torch.mean(y_20[0], axis=0).shape

In [None]:
torch.mean(y_10[0], axis=0).shape

In [None]:
torch.mean(torch.sum(y_20, axis=0), axis=0)

In [None]:
torch.mean(torch.sum(y_10, axis=0), axis=0)

In [None]:
torch.sum(y_20[0], axis=0) == torch.sum(y_10[0], axis=0)