# CAN Workshop — LFADS demo in PyTorch

## Import necessary modules

In [1]:
%matplotlib inline

import torch
import torchvision
np = torch._np
import matplotlib.pyplot as plt

import os
import yaml

from lfads import LFADS_Net
from utils import read_data, load_parameters, save_parameters
import scipy.io
# plt.style.use('dark_background')
import shutil
path = './models'
if os.path.isdir(path): 
    shutil.rmtree(path)

In [2]:
# Select device to train LFADS on
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device: %s' % device)

Using device: cuda


## Load data and segment it into trials

In [None]:
# Analysis window (column indices into the .mat recordings) and trial layout.
all_time = 12000
start_time = 2000
Time = 30             # bins per LFADS trial
neuron_num = 71
MAX_COUNT = 25        # entries of y_data above this value are clipped to it

datax = scipy.io.loadmat("spike71_k09_26000.mat")
y_data = datax["spike71_k09_26000"][:, start_time:all_time].T
traj = scipy.io.loadmat("trj71_k09_26000.mat")
traj = traj["trj71_k09_26000"][:, start_time:all_time].T
print(y_data.shape)

# The reshape below would silently mis-segment the data if the column count
# were wrong, so check the layout explicitly (rows = bins, columns = neurons).
assert y_data.shape[1] == neuron_num, y_data.shape
assert traj.shape[0] == y_data.shape[0], (traj.shape, y_data.shape)

# Clip in place: the plotting cells further down read the clipped y_data.
y_data[y_data > MAX_COUNT] = MAX_COUNT

# Cut the recording into NRep non-overlapping trials of Time bins each.
# Bins beyond NRep * Time do not fit a whole trial and are dropped.
NRep = (all_time - start_time) // Time
output = y_data[:NRep * Time].reshape(NRep, Time, neuron_num).astype('float32')
output = torch.as_tensor(output, device=device)
print(output.shape)
output_valid = output.detach().clone()

## View the reference trajectory and the observed spike data

In [None]:
# NOTE(review): both datasets wrap the same tensor, so any "validation" metric
# reported during training is computed on the training trials.
train_ds = torch.utils.data.TensorDataset(output)
valid_ds = torch.utils.data.TensorDataset(output)

# Trajectory plotted as column 0 vs column 1.
fig, ax = plt.subplots()
ax.plot(traj[:, 0], traj[:, 1])
ax.set(title='Trajectory', xlabel='traj[:, 0]', ylabel='traj[:, 1]')

# Each trajectory coordinate over time, sharing the time axis.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(4, 4), sharex=True)
ax1.plot(traj[:, 0], lw=4, color='k')
ax2.plot(traj[:, 1], lw=4, color='k')

# Z-score each neuron across all bins. scipy cannot read CUDA tensors, so this
# works on a CPU copy. The result gets its own name so the tensor behind the
# datasets above is not replaced; re-running this cell therefore stays safe.
# A neuron with zero variance yields NaN here.
from scipy import stats
output_z = stats.zscore(output.cpu().numpy().reshape(Time * NRep, neuron_num), axis=0)
output_z = output_z.reshape(NRep, Time, neuron_num)
print(output_z[0].shape)

# Raster of the clipped data. The x-axis is the bin index within the analysis
# window. NOTE(review): confirm the units before relying on the "Hz" label.
fig, ax = plt.subplots(figsize=(12, 12))
im = ax.imshow(y_data.T, cmap=plt.cm.plasma, aspect='auto')
ax.set(xlabel='Time bin', ylabel='Cell #')
fig.colorbar(im, ax=ax, orientation='horizontal', label='Firing Rate (Hz)')

In [None]:
# One panel per neuron: the trajectory coloured by that neuron's activity.
# The colour range is fixed to [0, 8] so panels are comparable. Titles use the
# 0-based column index, matching the "Cell #" axis of the raster above.
for i in range(neuron_num):
    fig, ax = plt.subplots()
    mappable = ax.scatter(traj[:, 0], traj[:, 1], c=y_data[:, i],
                          cmap='coolwarm', vmin=0, vmax=8, s=1)
    fig.colorbar(mappable, ax=ax)
    ax.set_title(f'Cell {i}')
    plt.show()
    plt.close(fig)  # avoid keeping neuron_num open figures in memory

## LFADS Schema
<img src='lfads_schema.png' width=800 align=left>

## Load model hyperparameters 

In [6]:
# Read the LFADS hyperparameters from YAML and pass them to the project helper
# save_parameters. The dict is left as the last expression so it is displayed.
params_file = './parameters.yaml'
hyperparams = load_parameters(params_file)
save_parameters(hyperparams)
hyperparams

{'dataset_name': 'chaotic_rnn',
 'run_name': 'demo',
 'g_dim': 20,
 'u_dim': 20,
 'factors_dim': 2,
 'g0_encoder_dim': 20,
 'c_encoder_dim': 20,
 'controller_dim': 20,
 'g0_prior_kappa': 0.1,
 'u_prior_kappa': 0.1,
 'keep_prob': 0.95,
 'clip_val': 5.0,
 'max_norm': 200,
 'learning_rate': 0.015,
 'learning_rate_min': 1e-05,
 'learning_rate_decay': 0.95,
 'scheduler_on': True,
 'scheduler_patience': 6,
 'scheduler_cooldown': 6,
 'kl_weight_schedule_start': 0,
 'kl_weight_schedule_dur': 2000,
 'l2_weight_schedule_start': 0,
 'l2_weight_schedule_dur': 2000,
 'epsilon': 0.1,
 'betas': (0.9, 0.99),
 'l2_gen_scale': 2000,
 'l2_con_scale': 0}

## Instantiate LFADS model

In [7]:
# Build the LFADS network: one input channel per neuron and Time bins per trial.
# NOTE(review): the constructor prints a random seed (see output). Check whether
# LFADS_Net can take a fixed seed so that runs are reproducible.
lfads_kwargs = dict(
    inputs_dim=neuron_num,
    T=Time,
    dt=1,
    device=device,
    model_hyperparams=hyperparams,
)
model = LFADS_Net(**lfads_kwargs).to(device)

Random seed: 6583


#### Pick up where you left off (if you have a recent save) 

In [8]:
# Optionally resume from the most recent checkpoint. The setup cell at the top
# removes ./models via shutil.rmtree, so a checkpoint only exists if that
# removal was skipped.
RESUME_FROM_RECENT = False
if RESUME_FROM_RECENT:
    model.load_checkpoint('recent')

# Mini-batch size used by model.fit below.
batch_size = 3

## Fit model

Rule of thumb: you can usually see a good fit after 200 epochs (~30 min on a ThinkPad GPU, ~2.5 hours on CPU), but good inference of perturbation timings needs about 800 epochs (~2 hours on a ThinkPad GPU). These figures come from the original `chaotic_rnn` demo; runtimes on the spike data used here may differ.

In [None]:
# Train LFADS for up to 800 epochs.
# NOTE(review): train_ds and valid_ds wrap the same tensor, and the same
# datasets are also passed as the *_truth arguments; confirm this is intended.
model.fit(
    train_ds,
    valid_ds,
    max_epochs=800,
    batch_size=batch_size,
    use_tensorboard=False,
    train_truth=train_ds,
    valid_truth=valid_ds,
)  # 4270 (annotation carried over from the original; meaning not recorded)

Beginning training...
Epoch:    1, Step:   111, training loss: 2620.159
recon: 3079, kl:   193, dir:    0, klw:   0
Epoch:    2, Step:   222, training loss: 2464.798
recon: 2893, kl:   367, dir:    2, klw:   0
Epoch:    3, Step:   333, training loss: 2426.019
recon: 2612, kl:   544, dir:    1, klw:   0
Epoch:    4, Step:   444, training loss: 2397.630
recon: 2745, kl:   597, dir:    1, klw:   0
Epoch:    5, Step:   555, training loss: 2379.417
recon: 2119, kl:   703, dir:    0, klw:   0
Epoch:    6, Step:   666, training loss: 2328.338
recon: 2550, kl:   883, dir:    4, klw:   0
Epoch:    7, Step:   777, training loss: 2309.886
recon: 2091, kl:  1405, dir:    2, klw:   0
Epoch:    8, Step:   888, training loss: 2288.126
recon: 2718, kl:  1303, dir:    0, klw:   0
Epoch:    9, Step:   999, training loss: 2280.330
recon: 2662, kl:  1836, dir:    0, klw:   0
Epoch:   10, Step:  1110, training loss: 2266.379
recon: 2269, kl:  2012, dir:    5, klw:   0
Epoch:   11, Step:  1221, training los

## Load checkpoint with lowest validation error 

In [None]:
# Restore the checkpoint with the lowest validation error before the analysis
# below. This is off by default, so the analysis uses the weights left by
# model.fit. Note that valid_ds holds the same trials as train_ds.
LOAD_BEST_CHECKPOINT = False
if LOAD_BEST_CHECKPOINT:
    model.load_checkpoint('best')

## Plot results summary 

In [None]:
print(output_valid.shape)

# Run the trained model over every trial and lay the inferred factors out as
# one (bin, factor) array aligned with the rows of y_data / traj.
infer_batch_size = 1  # separate name so the training `batch_size` is not overwritten
n_factors = hyperparams['factors_dim']
tt = np.zeros(((all_time - start_time), n_factors))
# Placeholder for reconstructed activity. Nothing in this cell fills it, so it
# stays all zeros.
sp = np.zeros(((all_time - start_time), neuron_num))
# NRep * Time can be smaller than all_time - start_time; rows of tt that no
# trial covers remain zero.
for first in range(0, NRep, infer_batch_size):
    batch = output_valid[first:first + infer_batch_size, :]
    t = model.infer_factors(batch)
    # batch.shape[0] also covers a shorter final batch.
    for i in range(batch.shape[0]):
        for j in range(Time):
            tt[(first + i) * Time + j, :] = t[i][j].to('cpu').detach().numpy().copy()

fig, ax = plt.subplots()
ax.plot(tt)
ax.set(title='Inferred LFADS factors', xlabel='Time bin', ylabel='Factor value');

In [None]:
# Per-neuron overlay of the (clipped) data and the values stored in sp.
for i in range(neuron_num):
    fig, ax = plt.subplots()
    ax.plot(y_data[:, i], linewidth=2, label='data')
    # An all-zero column means nothing was written into sp, so skip it rather
    # than drawing a flat line that looks like a reconstruction.
    if sp[:, i].any():
        ax.plot(sp[:, i], 'r', linewidth=1, label='reconstruction')
    ax.set_title(f'Cell {i}')
    # Upper limit follows the data so clipped peaks above 20 stay visible.
    ax.set_ylim(-2, max(20, float(y_data[:, i].max()) + 1))
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid keeping neuron_num open figures in memory

In [None]:
import funs
from scipy import signal

# Compare the inferred factors with the recorded trajectory after an optimal
# linear alignment. Only the first NRep * Time bins were covered by a trial in
# the inference step; the remaining rows of tt are zero padding. They are
# excluded so they do not bias the alignment or the smoothing.
n_covered = NRep * Time
qz_est_norm = tt[:n_covered]      # used without normalisation
z_true_norm = traj[:n_covered]    # used without centring/normalisation

# Map from factor space onto trajectory space (project helper, scale=True).
R = funs.compute_optimal_rotation(qz_est_norm.copy(), z_true_norm, scale=True)
qz_est_norm_R = qz_est_norm.dot(R)

# Savitzky-Golay smoothing of every aligned dimension along time.
SG_WINDOW, SG_POLYORDER = 51, 5
qz_est_norm_R = signal.savgol_filter(qz_est_norm_R, SG_WINDOW, SG_POLYORDER, axis=0)

st, en = 0, 4000  # bin range shown in the plots
fig, axes = plt.subplots(2, 1, figsize=(14, 8))
for dim, ax in enumerate(axes):
    ax.plot(qz_est_norm_R[st:en, dim], 'r', linewidth=1, label='posterior mean')
    ax.plot(z_true_norm[st:en, dim], 'k', linewidth=1, label='"true" mean')
    ax.legend()

fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(z_true_norm[st:en, 0], z_true_norm[st:en, 1], lw=2, color='k', label='"true" mean')
ax.plot(qz_est_norm_R[st:en, 0], qz_est_norm_R[st:en, 1], lw=2, color='r', label='posterior mean')
ax.legend();

In [None]:
#np.savetxt('lfads_k09_L.csv', qz_est_norm_R, delimiter=',')