In [31]:
import os
os.environ["OMP_NUM_THREADS"] = "8"  # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "8"  # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "8"  # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "8"  # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "8"  # export NUMEXPR_NUM_THREADS=6

import math
import torch
import torch.nn as nn
import xfads.utils as utils
import xfads.prob_utils as prob_utils
import pytorch_lightning as lightning
import matplotlib.pyplot as plt

from hydra import compose, initialize
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from xfads.ssm_modules.likelihoods import PoissonLikelihood
from xfads.ssm_modules.dynamics import DenseGaussianDynamics
from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
from xfads.smoothers.lightning_trainers import LightningNonlinearSSM, LightningDMFCRSG
from xfads.smoothers.nonlinear_smoother import NonlinearFilter, LowRankNonlinearStateSpaceModel
# from dev.smoothers.nonlinear_smoother_causal_debug import NonlinearFilter, LowRankNonlinearStateSpaceModel

In [10]:
torch.cuda.empty_cache()
initialize(version_base=None, config_path="", job_name="dmfc_rsg")

In [162]:
cfg = compose(config_name="config")

In [21]:
# seeds = [1234, 1235, 1236]
# seeds = [1235, 1236]
seed = 1239
n_bins_bhv = 140

In [157]:
"""config"""

cfg.seed = seed

lightning.seed_everything(cfg.seed, workers=True)
torch.set_default_dtype(torch.float32)

Seed set to 1239


In [120]:
"""downloading the data"""

%run download_data.py

  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)


done
MultiIndex([(    'align_time',   ''),
            (    'clock_time',   ''),
            ('heldout_spikes', 1007),
            ('heldout_spikes', 1016),
            ('heldout_spikes', 1017),
            ('heldout_spikes', 1053),
            ('heldout_spikes', 1054),
            ('heldout_spikes', 1063),
            ('heldout_spikes', 1071),
            ('heldout_spikes', 1099),
            ('heldout_spikes', 1101),
            ('heldout_spikes', 1103),
            ('heldout_spikes', 1108),
            ('heldout_spikes', 2098),
            ('heldout_spikes', 3052),
            ('heldout_spikes', 3103),
            (        'margin',   ''),
            (        'spikes', 1001),
            (        'spikes', 1002),
            (        'spikes', 1003),
            (        'spikes', 1004),
            (        'spikes', 1012),
            (        'spikes', 1019),
            (        'spikes', 1024),
            (        'spikes', 1026),
            (        'spikes', 1028),
       

  tp.append(torch.tensor(tp_t).unsqueeze(-1))
  ts.append(torch.tensor(ts_t).unsqueeze(-1))


In [107]:
"""loading the data"""

data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)
n_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_time_bins_enc = train_data['n_time_bins_enc']

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, )
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, )
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, )
train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)

In [111]:
y_valid_obs.shape

torch.Size([150, 260, 54])

In [140]:
import yaml
import torch

from itertools import product
from nlb_tools.nwb_interface import NWBDataset


def get_int_to_verbose_map():
    options = [('S', 'L'), ('E', 'H'), ('1', '2', '3', '4', '5')]
    combinations = list(product(*options))
    mapping = {i: combination for i, combination in enumerate(combinations)}

    return mapping


def get_verbose_to_int_map():
    options = [('S', 'L'), ('E', 'H'), ('1', '2', '3', '4', '5')]
    combinations = list(product(*options))
    reverse_mapping = {combination: i for i, combination in enumerate(combinations)}

    return reverse_mapping

In [142]:
datapath = 'data/000130/sub-Haydn'
dataset = NWBDataset(datapath)
save_root_path = 'data/'

# Extract neural data and lagged hand velocity
binsize = 10
n_neurons = 54
dataset.resample(binsize)

start = -1300
end = 1300
trial_length = (end - start) // binsize

verbose_to_int_map = get_verbose_to_int_map()
int_to_verbose_map = get_int_to_verbose_map()

# Extract neural data
trial_info = dataset.trial_info  # .dropna()
trial_info['color'] = None
trial_info['position_id'] = None
trial_data = dataset.make_trial_data(align_field='set_time', align_range=(start, end))
n_trials = trial_data.shape[0] // trial_length

y = []
tp = []
ts = []
task_id = []

print('done')
print(trial_data.columns)
count = 0
for trial_id, trial in trial_data.groupby('trial_id'):
    trial_id_trial_info = trial_info[trial_info['trial_id'] == trial_id]
    is_outlier_t = trial_id_trial_info['is_outlier'].iloc[0]
    tp_t = torch.tensor(trial_id_trial_info['tp'].iloc[0])
    ts_t = torch.tensor(trial_id_trial_info['ts'].iloc[0])
    is_short_t = trial_id_trial_info['is_short'].iloc[0]
    is_eye_t = trial_id_trial_info['is_eye'].iloc[0]

    if is_outlier_t or tp_t < 0:
        continue

    if is_short_t:
        task_str_1 = 'S'

        if ts_t == 480:
            task_str_3 = '1'
        elif ts_t == 560:
            task_str_3 = '2'
        elif ts_t == 640:
            task_str_3 = '3'
        elif ts_t == 720:
            task_str_3 = '4'
        elif ts_t == 800:
            task_str_3 = '5'
    else:
        task_str_1 = 'L'

        if ts_t == 800:
            task_str_3 = '1'
        elif ts_t == 900:
            task_str_3 = '2'
        elif ts_t == 1000:
            task_str_3 = '3'
        elif ts_t == 1100:
            task_str_3 = '4'
        elif ts_t == 1200:
            task_str_3 = '5'

    if is_eye_t:
        task_str_2 = 'E'
    else:
        task_str_2 = 'H'

    y_heldin_t = torch.tensor(trial.spikes.values)
    y_heldout_t = torch.tensor(trial.heldout_spikes.values)
    y_t = torch.concat([y_heldin_t, y_heldout_t], dim=-1)
    y.append(y_t.reshape(1, trial_length, n_neurons))

    task_id_key = (task_str_1, task_str_2, task_str_3)
    task_id_int = verbose_to_int_map[task_id_key]
    task_id.append(torch.tensor(task_id_int).unsqueeze(-1))

    tp.append(torch.tensor(tp_t).unsqueeze(-1))
    ts.append(torch.tensor(ts_t).unsqueeze(-1))

    if is_outlier_t:
        count += 1

y = torch.concat(y, dim=0)
task_id = torch.concat(task_id, dim=0)

subset_ex = 10
subset_ex_loc = torch.where(task_id == subset_ex)[0]

y_subset = y[subset_ex_loc]
y_psth = y_subset.mean(dim=0)

ts = torch.stack(ts, dim=0)
tp = torch.stack(tp, dim=0)

with open('data/old_data/int_condition_map.yaml', 'w') as outfile:
    yaml.dump(int_to_verbose_map, outfile, default_flow_style=False)

train_data, valid_data, test_data = {}, {}, {}
untrained_trials = 300
seq_len = trial_length

train_data['y_obs'] = y[:-untrained_trials]
train_data['task_id'] = task_id[:-untrained_trials]
train_data['ts'] = ts[:-untrained_trials]
train_data['tp'] = tp[:-untrained_trials]
train_data['n_neurons_enc'] = n_neurons
train_data['n_neurons_obs'] = n_neurons
train_data['n_time_bins_enc'] = seq_len

valid_data['y_obs'] = y[-untrained_trials:-untrained_trials // 2]
valid_data['task_id'] = task_id[-untrained_trials:-untrained_trials // 2]
valid_data['ts'] = ts[-untrained_trials:-untrained_trials // 2]
valid_data['tp'] = tp[-untrained_trials:-untrained_trials // 2]
valid_data['n_neurons_enc'] = n_neurons
valid_data['n_neurons_obs'] = n_neurons
valid_data['n_time_bins_enc'] = seq_len

test_data['y_obs'] = y[-untrained_trials // 2:]
test_data['task_id'] = task_id[-untrained_trials // 2:]
test_data['ts'] = ts[-untrained_trials // 2:]
test_data['tp'] = tp[-untrained_trials // 2:]
test_data['n_neurons_enc'] = n_neurons
test_data['n_neurons_obs'] = n_neurons
test_data['n_time_bins_enc'] = seq_len

torch.save(train_data, save_root_path + f'data_train_{binsize}ms.pt')
torch.save(valid_data, save_root_path + f'data_valid_{binsize}ms.pt')
torch.save(test_data, save_root_path + f'data_test_{binsize}ms.pt')

  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)


done
MultiIndex([(    'align_time',   ''),
            (    'clock_time',   ''),
            ('heldout_spikes', 1007),
            ('heldout_spikes', 1016),
            ('heldout_spikes', 1017),
            ('heldout_spikes', 1053),
            ('heldout_spikes', 1054),
            ('heldout_spikes', 1063),
            ('heldout_spikes', 1071),
            ('heldout_spikes', 1099),
            ('heldout_spikes', 1101),
            ('heldout_spikes', 1103),
            ('heldout_spikes', 1108),
            ('heldout_spikes', 2098),
            ('heldout_spikes', 3052),
            ('heldout_spikes', 3103),
            (        'margin',   ''),
            (        'spikes', 1001),
            (        'spikes', 1002),
            (        'spikes', 1003),
            (        'spikes', 1004),
            (        'spikes', 1012),
            (        'spikes', 1019),
            (        'spikes', 1024),
            (        'spikes', 1026),
            (        'spikes', 1028),
       

  tp.append(torch.tensor(tp_t).unsqueeze(-1))
  ts.append(torch.tensor(ts_t).unsqueeze(-1))


In [143]:
trial_info

  return method()


Unnamed: 0,trial_id,start_time,end_time,go_time,split,fix_on_time,fix_time,target_on_time,ready_time,set_time,...,theta,ts,tp,fix_time_dur,target_time_dur,iti,reward_dur,is_outlier,color,position_id
0,0,0 days 00:00:00,0 days 00:00:01.700000,0 days 00:00:01.500000,test,NaT,NaT,NaT,NaT,NaT,...,,,,,,,,,,
1,1,0 days 00:00:01.800000,0 days 00:00:03.500000,0 days 00:00:03.300000,test,NaT,NaT,NaT,NaT,NaT,...,,,,,,,,,,
2,2,0 days 00:00:03.600000,0 days 00:00:05.300000,0 days 00:00:05.100000,test,NaT,NaT,NaT,NaT,NaT,...,,,,,,,,,,
3,3,0 days 00:00:05.400000,0 days 00:00:07.100000,0 days 00:00:06.900000,test,NaT,NaT,NaT,NaT,NaT,...,,,,,,,,,,
4,4,0 days 00:00:07.200000,0 days 00:00:08.900000,0 days 00:00:08.700000,test,NaT,NaT,NaT,NaT,NaT,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1438,1438,0 days 01:28:20.781500,0 days 01:28:22.815000,NaT,none,0 days 01:28:20.781500,0 days 01:28:20.832500,0 days 01:28:21.382500,0 days 01:28:22.249000,NaT,...,180.0,900.0,-64.421,545.444910,849.090111,500.0,0.000700,True,,
1439,1439,0 days 01:28:24.349000,0 days 01:28:27.131500,0 days 01:28:27.066000,train,0 days 01:28:24.349000,0 days 01:28:24.382500,0 days 01:28:24.982500,0 days 01:28:25.532500,0 days 01:28:26.332500,...,180.0,800.0,756.019,609.255693,537.199998,500.0,44.344417,False,,
1440,1440,0 days 01:28:27.649000,0 days 01:28:30.465000,0 days 01:28:30.378500,val,0 days 01:28:27.649000,0 days 01:28:27.682500,0 days 01:28:28.299500,0 days 01:28:28.616000,0 days 01:28:29.516000,...,180.0,900.0,884.761,625.105868,307.792749,500.0,62.098296,False,,
1441,1441,0 days 01:28:30.981500,0 days 01:28:34.465000,0 days 01:28:34.401500,val,0 days 01:28:30.981500,0 days 01:28:31.149500,0 days 01:28:31.666000,0 days 01:28:32.282500,0 days 01:28:33.382500,...,180.0,1100.0,1041.113,526.285915,610.483410,500.0,45.017636,False,,


In [146]:
"""loading the data"""

data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)
n_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_time_bins_enc = train_data['n_time_bins_enc']

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, )
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, )
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, )
train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)

In [169]:
"""likelihood pdf"""
H = utils.ReadoutLatentMask(cfg.n_latents, cfg.n_latents_read)
readout_fn = nn.Sequential(H, nn.Linear(cfg.n_latents_read, n_neurons_obs))
readout_fn[-1].bias.data = prob_utils.estimate_poisson_rate_bias(train_dataloader, cfg.bin_sz)
likelihood_pdf = PoissonLikelihood(readout_fn, n_neurons_obs, cfg.bin_sz, device=cfg.device)

"""dynamics module"""
Q_diag = 1. * torch.ones(cfg.n_latents, device=cfg.device)
dynamics_fn = utils.build_gru_dynamics_function(cfg.n_latents, cfg.n_hidden_dynamics, device=cfg.device)
dynamics_mod = DenseGaussianDynamics(dynamics_fn, cfg.n_latents, Q_diag, device=cfg.device)

"""initial condition"""
m_0 = torch.zeros(cfg.n_latents, device=cfg.device)
Q_0_diag = 1. * torch.ones(cfg.n_latents, device=cfg.device)
initial_condition_pdf = DenseGaussianInitialCondition(cfg.n_latents, m_0, Q_0_diag, device=cfg.device)

"""local/backward encoder"""
backward_encoder = BackwardEncoderLRMvn(cfg.n_latents, cfg.n_hidden_backward, cfg.n_latents,
                                        rank_local=cfg.rank_local, rank_backward=cfg.rank_backward,
                                        device=cfg.device)
local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons_obs, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
                                  device=cfg.device, dropout=cfg.p_local_dropout)
nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)

"""sequence vae"""
ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
                                      local_encoder, nl_filter, device=cfg.device)

"""lightning"""
#seq_vae = LightningDMFCRSG.load_from_checkpoint('ckpts/smoother/acausal/epoch=997_valid_loss=3288.73_valid_bps_enc=0.61_valid_bps_bhv=0.12.ckpt',
#                                                ssm=ssm, cfg=cfg, n_time_bins_enc=n_time_bins_enc, n_time_bins_bhv=n_bins_bhv, strict=False)
# seq_vae = LightningDMFCRSG(ssm, cfg, n_time_bins_enc, n_bins_bhv)
seq_vae = LightningDMFCRSG(ssm, cfg, n_time_bins_enc, n_bins_bhv)
csv_logger = CSVLogger('logs/smoother/acausal/', name=f'sd_{cfg.seed}_r_y_{cfg.rank_local}_r_b_{cfg.rank_backward}', version='smoother_acausal')
ckpt_callback = ModelCheckpoint(save_top_k=3, monitor='valid_bps_enc', mode='max', dirpath='ckpts/smoother/acausal/', save_last=True,
                                filename='{epoch:0}_{valid_loss:0.2f}_{valid_bps_enc:0.2f}_{valid_bps_bhv:0.2f}')

In [None]:
trainer = lightning.Trainer(max_epochs=cfg.n_epochs,
                            gradient_clip_val=1.0,
                            default_root_dir='lightning/',
                            callbacks=[ckpt_callback],
                            logger=csv_logger,
                            strategy='ddp_notebook',
                            accelerator='cpu',
                            )

trainer.fit(model=seq_vae, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
torch.save(ckpt_callback.best_model_path, 'ckpts/smoother/acausal/best_model_path.pt')
trainer.test(dataloaders=test_dataloader, ckpt_path='last')

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Seed set to 1239
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

/opt/anaconda3/envs/xfads/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/mahmoud/catnip/xfads/workshop/monkey_timing/ckpts/smoother/acausal exists and is not empty.

  | Name | Type                            | Params
---------------------------------------------------------
0 | ssm  | LowRankNonlinearStateSpaceModel | 478 K 
---------------------------------------------------------
478 K     Trainable params
0    

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/xfads/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]current_epoch: 0
                                                                           

/opt/anaconda3/envs/xfads/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/opt/anaconda3/envs/xfads/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:  33%|███▎      | 1/3 [04:30<09:00,  0.00it/s, v_num=usal]