In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from adaptnn.model_fitting import MCJN05DataModel


# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs end-to-end on machines without one (a hard-coded 'cuda:0' makes every
# later tensor allocation fail there).
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(DEVICE)

In [2]:
# Options forwarded to the dataset and to the network, kept as named dicts so
# they are easy to inspect or tweak before the model is built.
# NOTE(review): the saved output of this cell reports convolutional layers of
# size 40 and 1 with 8 channels each, which does not match the (40, 12) /
# (16, 8) values requested here -- the output looks stale; re-run to confirm.
dataset_params = {
    "dtype": torch.float32,
    "train_long_contrast_levels": (1, 3),
}
net_params = {
    "layer_time_lengths": (40, 12),
    "layer_channels": (16, 8),
}
model = MCJN05DataModel(dataset_params=dataset_params, net_params=net_params)

Building multi-layer temporal convolutional model for 9 neurons and full-field stimuli.
Adding full-rank convolutional layer of size 40 and 8 channels.
Adding nonlinearity: Softplus.
Adding full-rank convolutional layer of size 1 and 8 channels.
Adding nonlinearity: Softplus.
Adding full-connected linear layer: 8 to 9.
Adding output nonlinearity: Softplus.
Model initialized.


In [3]:
# Training budget and logging cadence.
# NOTE(review): the saved log below stops at epoch 950 although 4000 epochs
# are requested here -- the run was presumably interrupted; re-run to confirm.
N_EPOCHS = 4000
PRINT_EVERY = 50
model.train(N_EPOCHS, print_every=PRINT_EVERY)

epoch 50, loss 14.252159600956782, step size 0.0001
epoch 100, loss 11.623149816385055, step size 0.0001
epoch 150, loss 10.623230267629108, step size 0.0001
epoch 200, loss 9.713234041510166, step size 0.0001
epoch 250, loss 8.999063212087266, step size 0.0001
epoch 300, loss 8.530660571564134, step size 0.0001
epoch 350, loss 8.276841308618321, step size 0.0001
epoch 400, loss 8.186914389870969, step size 0.0001
epoch 450, loss 8.148620045867176, step size 0.0001
epoch 500, loss 8.121884845598746, step size 0.0001
epoch 550, loss 8.101445698893357, step size 0.0001
epoch 600, loss 8.090498383072424, step size 0.0001
epoch 650, loss 8.07585142774582, step size 0.0001
epoch 700, loss 8.064733773268928, step size 0.0001
epoch 750, loss 8.049437077129776, step size 0.0001
epoch 800, loss 8.033686130090338, step size 0.0001
epoch 850, loss 8.019568220046924, step size 0.0001
epoch 900, loss 8.00846486659848, step size 0.0001
epoch 950, loss 7.999636632816127, step size 0.0001


In [None]:
# Query the model for repeat sets 1, 2 and 3 (in that order), then move the
# results to host memory as NumPy arrays for plotting.
#   X*   : first value returned by predict_rpt, converted to NumPy
#   Y*_0 : second value returned by predict_rpt, left untouched
#   Y*   : Y*_0 averaged over its first axis, converted to NumPy
# NOTE(review): which of the two outputs is the prediction and which is the
# recorded response is not visible from this notebook -- confirm in adaptnn.
(X1, Y1_0), (X2, Y2_0), (X3, Y3_0) = [model.predict_rpt(level) for level in (1, 2, 3)]

with torch.no_grad():
    Y1, Y2, Y3 = [raw.mean(dim=0).cpu().numpy() for raw in (Y1_0, Y2_0, Y3_0)]
    X1, X2, X3 = [out.cpu().numpy() for out in (X1, X2, X3)]

In [None]:


# Grid of NR rows x NC columns: one row per index along the first axis of the
# Y*/X* arrays, one column per repeat set (1, 2, 3).
NC = 3
NR = 9
plt.figure(figsize=(NC*4, NR*3))

# (black trace, default-colour trace) for each column, left to right.
column_traces = ((Y1, X1), (Y2, X2), (Y3, X3))

for cc in range(NR):
    for col, (y_arr, x_arr) in enumerate(column_traces, start=1):
        # Subplot indices are 1-based and run row-major across the grid.
        plt.subplot(NR, NC, cc*NC + col)
        plt.plot(y_arr[cc, :], color='black')
        plt.plot(x_arr[cc, :])

In [None]:
# Take the last entry along the first axis of the dataset's full X and Y
# arrays; these feed the linear fit in the following cells.
full_data = model.dataset
X = full_data.X_full[-1, ...]
Y = full_data.Y_full[-1, ...]

In [None]:
P = 50       # number of consecutive X samples in each regression row
T = 10000    # number of time points (rows) in the regression
t_0 = 1000   # index of the first time point used


def _as_numpy(values):
    """Return `values` as a NumPy array, copying torch tensors to host memory first.

    The notebook sets a (possibly CUDA) default torch device, and NumPy cannot
    read device tensors directly, so tensors are detached and moved to the CPU.
    """
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _lagged_design(stimulus, responses, n_lags, n_samples, start):
    """Build the lagged design matrix and matching targets for a linear fit.

    Parameters
    ----------
    stimulus : 1-D array or tensor
        Signal indexed by time.
    responses : 2-D array or tensor
        Responses indexed as ``responses[unit, time]``.
    n_lags, n_samples, start : int
        Window length, number of rows, and first time index used.

    Returns
    -------
    design : ndarray, shape (n_samples, n_lags + 1)
        Row ``i`` holds ``stimulus[start + i - n_lags + 1 : start + i + 1]``
        followed by a constant 1 (bias column).
    targets : ndarray, shape (n_samples, n_units)
        Row ``i`` holds ``responses[:, start + i]``.

    Raises
    ------
    ValueError
        If the inputs have the wrong rank or the requested windows fall
        outside the available samples.
    """
    stimulus = _as_numpy(stimulus)
    responses = _as_numpy(responses)
    if stimulus.ndim != 1 or responses.ndim != 2:
        raise ValueError("expected a 1-D stimulus and 2-D (unit, time) responses")

    first = start - n_lags + 1   # start index of the first window
    stop = start + n_samples     # one past the last time point used
    if first < 0 or stop > stimulus.shape[0] or stop > responses.shape[1]:
        raise ValueError("requested lag windows fall outside the available samples")

    # All length-n_lags windows at once; row k of the view is stimulus[k:k+n_lags].
    windows = np.lib.stride_tricks.sliding_window_view(stimulus, n_lags)

    design = np.ones((n_samples, n_lags + 1))
    design[:, :n_lags] = windows[first:first + n_samples]
    targets = np.ascontiguousarray(responses[:, start:stop].T, dtype=float)
    return design, targets


Xs, Ys = _lagged_design(X, Y, P, T, t_0)

In [None]:
# Ordinary least-squares solution of Xs @ b ~= Ys; only the coefficient matrix
# (first element of lstsq's return tuple) is kept.
lstsq_result = np.linalg.lstsq(Xs, Ys, rcond=None)
b = lstsq_result[0]

In [None]:
# Drop the last row of `b` (the weight on the constant column of Xs) and plot
# the remaining rows, one line per column of `b`.
lag_weights = b[:-1, :]
plt.plot(lag_weights);

In [None]:
# Quick shape check of the averaged repeat-set-1 array built earlier.
Y1.shape