In [1]:
# All notebook imports: standard library, then third-party, then project-local modules.
import os

# Import numpy directly instead of aliasing the private `torch._np` attribute,
# which is not part of torch's public API and is absent in current releases.
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import yaml

from models import LFADS
from utils import read_data, load_parameters, save_parameters, batchify_random_sample

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device: %s' % device)

Using device: cuda


In [3]:
# Load (or generate) the synthetic Lorenz spiking dataset and wrap it for training.
author = 'lyprince'
seed = 100
# The cache directory is derived from `seed` so the two settings cannot drift apart
# (previously the path hard-coded `lorenz_100` independently of `seed`).
data_path = './synth_data/lorenz_%s' % seed

if os.path.exists(data_path):
    # Reuse the dataset previously written to disk.
    data_dict = read_data(data_path)
else:
    # NOTE(review): assumes generate_lorenz_data(save=True) writes to `data_path`;
    # confirm, otherwise the existence check above never finds the cached copy.
    from synthetic_data import generate_lorenz_data
    data_dict = generate_lorenz_data(N_cells=30, N_inits=65, N_trials=20, N_steps=200,
                                     N_stepsinbin=2, dt_lorenz=0.015, dt_spike=1./20,
                                     base_firing_rate=1.0, save=True)

# Spike data as float32 tensors placed directly on the compute device
# (same dtype the legacy `torch.Tensor(...)` constructor produced).
train_data = torch.as_tensor(data_dict['train_spikes'], dtype=torch.float32, device=device)
valid_data = torch.as_tensor(data_dict['valid_spikes'], dtype=torch.float32, device=device)

# Ground-truth rates and latents, passed to `model.fit` alongside the datasets.
train_truth = {'rates': data_dict['train_rates'],
               'latent': data_dict['train_latent']}
valid_truth = {'rates': data_dict['valid_rates'],
               'latent': data_dict['valid_latent']}

train_ds = torch.utils.data.TensorDataset(train_data)
valid_ds = torch.utils.data.TensorDataset(valid_data)

num_trials, num_steps, num_cells = train_data.shape
# The model is sized from the training split, so validation must match on steps and cells.
assert valid_data.shape[1:] == train_data.shape[1:], 'train/valid shape mismatch'
print(train_data.shape)
print('Number of datapoints = %s' % train_data.numel())

  "just not the slower interior point methods we compared to in the papers.")
  result = result[core]


Saving variable with name:  valid_latent
Saving variable with name:  dt
Saving variable with name:  train_latent
Saving variable with name:  valid_oasis
Saving variable with name:  train_calcium
Saving variable with name:  valid_spikes
Saving variable with name:  train_data
Saving variable with name:  valid_data
Saving variable with name:  train_oasis
Saving variable with name:  conversion_factor
Saving variable with name:  train_truth
Saving variable with name:  valid_rates
Saving variable with name:  train_fluor
Saving variable with name:  valid_truth
Saving variable with name:  valid_fluor
Saving variable with name:  train_rates
Saving variable with name:  valid_calcium
Saving variable with name:  train_spikes
Saving variable with name:  loading_weights
torch.Size([1040, 100, 30])
Number of datapoints = 3120000


In [4]:
# Load the Lorenz/spikes hyperparameters, tag the run name as a demo, then persist them.
# Re-running this cell is safe: the suffix is applied to a freshly loaded copy each time.
params_file = './parameters/parameters_lorenz_spikes.yaml'
hyperparams = load_parameters(params_file)
hyperparams['run_name'] = hyperparams['run_name'] + '_demo'
save_parameters(hyperparams, path=None)

hyperparams

{'betas': (0.9, 0.99),
 'clip_val': 5.0,
 'dataset_name': 'lorenz',
 'datatype': 'spikes',
 'epsilon': 0.1,
 'factors_dim': 3,
 'g0_encoder_dim': 64,
 'g0_prior_kappa': 0.1,
 'g0_prior_var_max': 0.1,
 'g0_prior_var_min': 0.1,
 'g_dim': 64,
 'keep_prob': 0.95,
 'kernel_dim': 20,
 'kl_weight_min': 0.0,
 'kl_weight_schedule_dur': 1600,
 'kl_weight_schedule_start': 0,
 'l2_con_scale': 0,
 'l2_gen_scale': 250,
 'l2_weight_min': 0.0,
 'l2_weight_schedule_dur': 1600,
 'l2_weight_schedule_start': 0.0,
 'learning_rate': 0.01,
 'learning_rate_decay': 0.95,
 'learning_rate_min': 1e-05,
 'max_norm': 200,
 'norm_factors': True,
 'run_name': 'poisson_demo',
 'scheduler_cooldown': 6,
 'scheduler_on': True,
 'scheduler_patience': 6,
 'u_dim': 0,
 'u_prior_kappa': 0.1,
 'use_weight_schedule_fn': True}

In [5]:
# Build LFADS sized from the training tensor's (trials, steps, cells) shape and move it to `device`.
sample_dt = float(data_dict['dt'])
model = LFADS(
    inputs_dim=num_cells,
    T=num_steps,
    dt=sample_dt,
    device=device,
    model_hyperparams=hyperparams,
).to(device)

Random seed: 7968


In [None]:
# Train on the spike datasets, passing ground truth for evaluation;
# tensorboard logging is on and health checks are off for this demo run.
model.fit(
    train_ds,
    valid_ds,
    train_truth=train_truth,
    valid_truth=valid_truth,
    max_epochs=2000,
    batch_size=65,
    use_tensorboard=True,
    health_check=False,
)

Beginning training...
Epoch:    1, Step:    16, training loss: 1289.844, validation loss: 1478.769
Epoch:    2, Step:    32, training loss: 1273.732, validation loss: 1387.469
Epoch:    3, Step:    48, training loss: 1263.496, validation loss: 1283.039
Epoch:    4, Step:    64, training loss: 1257.180, validation loss: 1270.510
Epoch:    5, Step:    80, training loss: 1253.778, validation loss: 1269.829
Epoch:    6, Step:    96, training loss: 1250.417, validation loss: 1272.616
Epoch:    7, Step:   112, training loss: 1247.252, validation loss: 1283.232
Epoch:    8, Step:   128, training loss: 1243.691, validation loss: 1274.111
Epoch:    9, Step:   144, training loss: 1237.998, validation loss: 1270.269
Epoch:   10, Step:   160, training loss: 1236.176, validation loss: 1266.048
Epoch:   11, Step:   176, training loss: 1231.137, validation loss: 1259.896
Epoch:   12, Step:   192, training loss: 1227.174, validation loss: 1260.505
Epoch:   13, Step:   208, training loss: 1221.722, val