In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import sys
# Make the directory two levels up importable (presumably the repo root that holds `neuroprob`).
sys.path.append("../..")


import neuroprob as mdl
from neuroprob import utils
# TODO(review): later cells use `tools` and `distributions`, which are not imported anywhere
# in this notebook -- import them here so Restart & Run All works.

# Project-local helpers (class_x_t, L_R_run) used in the data cells.
import utils_func

# Select the torch device; on the recorded run this printed the PyTorch version and chose cuda:0.
dev = utils.pytorch.get_device()

PyTorch version: 1.7.0
Using device: cuda:0


### Data

In [2]:
# loading data
dataset = 'hc3'
session_id = 'ec014.468'

# For a .npz archive np.load returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is released once every array has been read into memory.
with np.load('./checkpoint/{}_{}.npz'.format(dataset, session_id)) as data:
    spktrain = data['spktrain']  # units along axis 0 (see `units` below)
    x_t = data['x_t']
    y_t = data['y_t']
    hd_t = data['hd_t']
    theta_t = data['theta_t']
    eeg_t = data['eeg_t']
    sample_bin = data['sample_bin']

units = spktrain.shape[0]

In [5]:
# Classify every position sample, then derive a per-time-step run-direction label.
# The indices below split time steps by that label: -1 -> "L_R", +1 -> "R_L", 0 -> stationary.
# Each ind_* is the 1-tuple returned by np.where (not a flat array).
c_x_t = utils_func.class_x_t(x_t)
dir_t = utils_func.L_R_run(c_x_t)
ind_L_R, ind_R_L, ind_stat = (np.where(dir_t == label) for label in (-1, 1, 0))

# incompleted runs:  1


In [None]:
# Total spike count per unit within each movement class.
# np.where returned 1-tuples, so take element [0] to index time with a flat index array.
# Passing the tuple itself would make NumPy insert an extra length-1 axis.
# Rows of n_spikes are ordered (L_R, stationary, R_L).
n_spikes_L_R, n_spikes_stat, n_spikes_R_L = (
    spktrain[:, ind[0]].sum(axis=1, dtype=int)
    for ind in (ind_L_R, ind_stat, ind_R_L)
)
n_spikes = np.vstack((n_spikes_L_R, n_spikes_stat, n_spikes_R_L))

### Decoder

In [None]:
# Build decoder inputs (activity history) and targets (covariates), split into mini-batches.
# TODO(review): `rhd_t`, `rc_t` and `tools` are never defined/imported in this notebook;
# presumably a binning cell produced rhd_t / rc_t from hd_t / spktrain -- restore it,
# otherwise this cell fails on a fresh kernel.
hist_len = 1  # number of history steps per input window
batch_size = 1000  # time steps per mini-batch

rp_t = np.stack((rhd_t,))  # covariates to decode: (x_dims, time)
x_dims = rp_t.shape[0]

cov_hist = tools.lagged_input(torch.tensor(rp_t).float(), hist_len=hist_len, hist_stride=1, time_stride=1)
act_hist = tools.lagged_input(torch.tensor(rc_t).float(), hist_len=hist_len, hist_stride=1, time_stride=1) # dim, time, hist

# inputs, reordered to be time-major
dec_input = act_hist.permute(1, 0, 2).float() # time, neurons, hist
enc_input = cov_hist.permute(1, 0, 2).float()


def _drop_last(x, n):
    """Return `x` without its last `n` entries along the leading (time) axis.

    Unlike `x[:-n]`, this keeps everything when n == 0, so hist_len == 1 needs no special case.
    """
    return x[:max(x.shape[0] - n, 0)]


# Targets skip their first hist_len-1 steps while inputs drop their last hist_len-1 windows,
# pairing input window i with the target at time i + hist_len - 1
# (presumably the window's last step -- confirm against tools.lagged_input).
n_lag = hist_len - 1
cov_target = torch.tensor(rp_t[:, n_lag:].T).float()
dec_input_valid = _drop_last(dec_input, n_lag)
assert dec_input_valid.shape[0] == cov_target.shape[0], "decoder inputs and targets are misaligned in time"

# targets
cov_batched = torch.split(cov_target, batch_size)

dec_input_batched = torch.split(dec_input_valid, batch_size)
enc_input_batched = torch.split(_drop_last(enc_input, n_lag), batch_size)
act_batched = torch.split(_drop_last(torch.tensor(rc_t.T).float(), n_lag), batch_size)

In [None]:
# Decoder model: two MLPs map population activity to the location and scale of the
# decoded covariate distribution.
# TODO(review): `distributions` and `tools` are not imported in this notebook -- confirm
# the providing modules (presumably part of neuroprob) and import them in the first cell.

# Previously undefined: derive them from tensors built in the data-prep cell.
neurons = dec_input.shape[1]  # units fed to the decoder (dec_input is time, neurons, hist)
lat_dims = x_dims  # presumably one decoded output per covariate row of rp_t -- confirm

x_dist = [distributions.Tn_Normal]  # presumably one output distribution per covariate
muMLP = tools.MLP([50, 20, 20, 30], neurons, lat_dims, nonlin=tools.Siren(), out=None)
sigmaMLP = tools.MLP([50, 20, 20, 30], neurons, lat_dims, nonlin=tools.Siren(), out=nn.Softplus())
# Zero what is presumably the final Linear layer (net[-1] being the Softplus output), so the
# predicted scale starts at the constant softplus(0) = log 2 for every input.
sigmaMLP.net[-2].weight.data.fill_(0.)
sigmaMLP.net[-2].bias.data.fill_(0.)
# Uniform weights over the history window (presumably used by rate_decoder to average history).
kern = torch.ones((neurons, hist_len)) / hist_len

rate_dec = utils.decoding.rate_decoder(muMLP, sigmaMLP, kern)
rate_dec.to(dev)

In [None]:
# training
# Adam optimiser; MultiplicativeLR scales the learning rate by 0.9 on every sch.step()
# (presumably called once per epoch inside fit_decoder -- confirm).
# TODO(review): `resamples` is not defined in this notebook (presumably produced by the same
# binning step as rc_t / rhd_t); restore that cell before running this one.
optimizer = optim.Adam(rate_dec.parameters(), lr=1e-4, weight_decay=0.0)
sch = optim.lr_scheduler.MultiplicativeLR(optimizer, lambda e: 0.9)
print("Number of parameters: {}".format(sum(p.numel() for p in rate_dec.parameters())))

# NOTE(review): the positional 100 and 10000 go straight to fit_decoder; their meaning is
# not visible here -- confirm and promote them to named constants.
loss = utils.decoding.fit_decoder(rate_dec, x_dist, dec_input_batched, cov_batched, resamples, optimizer, sch, 100,
                                  dev, 10000)

fig, ax = plt.subplots()
ax.plot(loss)
ax.set(title="Decoder training loss", xlabel="step", ylabel="loss")
plt.show()