In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

import h5py

from adaptnn.model_fitting import ArtificialModel
import tltorch
# Run on the first GPU when one is present; fall back to CPU so the notebook
# still executes (more slowly) on machines without CUDA instead of failing at
# the first tensor allocation.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(DEVICE)

# Seed both RNGs so the synthetic dataset, the initial weights and the training
# run are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Settings handed to ArtificialModel as `dataset_params` (ground-truth generator).
dataset_params = dict(
    filter_spatial=(15, 15),
    filter_time=10,
    num_cells=4,
    out_noise_std_train=0.1,
    out_noise_std_test=None,
    filter_rank=2,
    disjoint_segments=True,
)

# Settings handed to ArtificialModel as `net_params` (the network being fitted).
# NOTE(review): last entry of layer_channels (4) equals num_cells (4); presumably
# the output channels must match the number of cells — confirm.
net_params = dict(
    layer_time_lengths=(10, 1),
    layer_rf_pixel_widths=(5, 5),
    layer_channels=(4, 4),
    layer_spatio_temporal_rank=4,
    layer_spatio_temporal_factorization_type=('spatial',),
    out_normalization=True,
    layer_normalization=True,
)

model = ArtificialModel(dataset_params=dataset_params, net_params=net_params)

# Quick sanity check of the generated training split.
print(model.dataset.start_idx_X_train)
print(model.dataset.X_train.shape)

In [None]:
# Fit the network. scheduler_params is left unspecified, so whatever
# ArtificialModel.train uses when it is omitted applies.
model.train(
    epochs=2000,
    print_every=10,
    penalty_params={"en_lambda": 0.0001},
    optimizer_params={"lr": 1e-3},
)

In [None]:
# predict() yields two tensors (fitted, true); move each to the CPU as a
# NumPy array with singleton axes removed.
fit_tensor, true_tensor = model.predict()
Y_fit, Y_true = (t.cpu().numpy().squeeze() for t in (fit_tensor, true_tensor))

In [None]:
# Overlay true vs. fitted responses for the first T samples, one panel per cell.
NC = 4                                   # panels per row
num_cells = model.dataset.num_cells
# Divide before taking the ceiling: int(np.ceil(n) / NC) floors the quotient,
# dropping a partially filled last row (or giving 0 rows when n < NC), which
# makes add_subplot fail with an out-of-range index.
NR = max(1, int(np.ceil(num_cells / NC)))

T = 100                                  # number of samples shown per panel
# Y_true / Y_fit were squeezed upstream; if the cell axis was removed
# (single-cell run), restore it so [ii, :T] indexing still works.
Y_true_2d = np.atleast_2d(Y_true)
Y_fit_2d = np.atleast_2d(Y_fit)

fig = plt.figure(figsize=(NC*3, NR*2))
for ii in range(num_cells):
    ax = fig.add_subplot(NR, NC, ii+1)
    ax.plot(Y_true_2d[ii, :T], label="true")
    ax.plot(Y_fit_2d[ii, :T], label="fit")
    ax.set_title(f"cell {ii}")
    if ii == 0:
        ax.legend(fontsize="small")
fig.tight_layout()
plt.show()

In [None]:
with torch.no_grad():
    # Prepend two singleton axes (same as two unsqueeze(0) calls) and run the
    # raw network on the whole training stimulus.
    net_input = model.dataset.X_train[None, None]
    Y2 = model.model(net_input).cpu().numpy().squeeze()
    # Trim the first 9 samples of the noise-free(?) training target.
    # NOTE(review): presumably 9 == filter_time - 1, i.e. the samples a valid
    # temporal convolution cannot produce — confirm and derive from config.
    Y1 = model.dataset.Y_train_0.cpu().numpy()[:, 9:]

In [None]:
# Overlay the trimmed training target (Y1) and the network output (Y2) for the
# first T samples, one panel per cell.
NC = 4                                   # panels per row
num_cells = model.dataset.num_cells
# Divide before taking the ceiling: int(np.ceil(n) / NC) floors the quotient,
# dropping a partially filled last row (or giving 0 rows when n < NC), which
# makes add_subplot fail with an out-of-range index.
NR = max(1, int(np.ceil(num_cells / NC)))

T = 200                                  # number of samples shown per panel
# Y2 was squeezed; if the cell axis was removed (single-cell run), restore it
# so [ii, :T] indexing still works. Applied to Y1 too for symmetry.
Y1_2d = np.atleast_2d(Y1)
Y2_2d = np.atleast_2d(Y2)

fig = plt.figure(figsize=(NC*3, NR*2))
for ii in range(num_cells):
    ax = fig.add_subplot(NR, NC, ii+1)
    ax.plot(Y1_2d[ii, :T], label="target")
    ax.plot(Y2_2d[ii, :T], label="model")
    ax.set_title(f"cell {ii}")
    if ii == 0:
        ax.legend(fontsize="small")
fig.tight_layout()
plt.show()

In [None]:
# Layer-1 sizing: 15x15 spatial extent, 40 temporal taps, 8 output channels.
# Unfactored (dense) kernel size.
mp = 15 * 15 * 40 * 8
# "Paper" reference count, written as sums of 3x3 kernel terms.
pp = 8 * 3 * 3 * 40 + 6 * 8 * 8 * 3 * 3

print(f"max params {mp}")
print(f"paper params {pp}")

# Factored kernel per channel: one rank-wide factor per mode (15, 15, 40)
# plus a rank^3 core, repeated for the 8 channels.
for rank in range(1, 8):
    factor_cost = rank * (15 + 15 + 40)
    core_cost = rank ** 3
    fp = (factor_cost + core_cost) * 8
    print(f"factored params {fp} with rank {rank}")

# layer 1: spatial 15x15, temporal 40, channels 8, rank 6

In [None]:
# Layer-2 sizing: 11x11 spatial extent, `time` temporal taps, 8 channels.
time = 10
# Unfactored (dense) kernel size.
mp = 11 * 11 * time * 8
# "Paper" reference: 5 kernels of 3x3, each carrying an 8*8 channel factor.
pp = 5 * 8 * 8 * 3 * 3
# NOTE(review): pp includes an 8*8 channel factor while mp and fp below include
# only a single 8; if this layer sees 8 input channels, mp/fp likely omit the
# input-channel factor — confirm the comparison is like-for-like.

print(f"max params {mp}")
print(f"paper params {pp}")

# Factored kernel per channel: rank-wide factors for (11, 11, time) plus a
# rank^3 core, times 8 channels.
for rank in range(1, 8):
    fp = 8 * (rank * (11 + 11 + time) + rank ** 3)
    print(f"factored params {fp} with rank {rank}")

# layer 2: spatial 11x11, temporal 8, channels 8, rank 5
# NOTE(review): the line above says temporal 8 but `time = 10` here — confirm which is intended.