In [5]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt

from config import get_cfg_defaults
from SequentialVAE import NeuralVAE, SeqDataLoader

In [8]:
# Load the project configuration and the observation arrays used for training.
cfg = get_cfg_defaults()

# Open the file read-only and close it as soon as the arrays are copied out.
# np.array(...) materialises each dataset in memory, so the tensors stay valid
# after the file handle is closed.
with h5py.File('data/poisson_obs.h5', 'r') as data:
    Y, X, C, b = [
        torch.tensor(np.array(data[key]), dtype=torch.float32)
        for key in ('Y', 'X', 'C', 'bias')
    ]

# Training hyper-parameters.
n_epochs = 100
batch_size = 10
time_delta = 5e-3  # presumably the time-bin width -- TODO confirm units against NeuralVAE

# Problem dimensions, read off the loaded arrays.
n_latents = X.shape[2]    # size of the third axis of X
n_neurons = Y.shape[2]    # size of the third axis of Y
n_time_bins = Y.shape[1]  # size of the second axis of Y

In [7]:
# Build the model and install the readout parameters (C, b) loaded from the data file.
vae = NeuralVAE(cfg, time_delta, n_neurons, n_latents, n_time_bins)
vae.manually_set_readout_params(C, b)

# Freeze the decoder readout: with requires_grad disabled no gradients are
# computed for these tensors, so training does not update them.
vae.decoder.C.bias.requires_grad_(False)
vae.decoder.C.weight.requires_grad_(False)
train_data_loader = SeqDataLoader((Y, X), batch_size)


opt = torch.optim.Adam(vae.parameters(), lr=1e-2)
for epoch in range(n_epochs):
    for batch_idx, (y, x) in enumerate(train_data_loader):
        # NOTE(review): y is passed as both of the first two arguments and the
        # meaning of the trailing 1.0 is defined by NeuralVAE -- confirm intent.
        loss, z, mu_t, log_var_t = vae(y, y, 1.0)
        loss.backward()
        # Clip the global L2 gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1e-1, norm_type=2)
        opt.step()
        opt.zero_grad()

    print(epoch)
    if epoch % 5 == 0:
        # Every fifth epoch, overlay one slice of the model output z on one
        # slice of x, both taken from the last batch of the epoch.
        # NOTE(review): z is sliced along its first axis and x along its second;
        # presumably z is laid out with a different axis order -- confirm.
        fig, ax = plt.subplots()
        ax.plot(z[:, 0, 0].detach().cpu().numpy(), label='z[:, 0, 0]')
        ax.plot(x[0, :, 0].detach().cpu().numpy(), label='x[0, :, 0]')
        ax.set_title(f'epoch {epoch}')
        ax.legend()
        plt.show()
        # Release the figure so repeated plots do not accumulate in memory.
        plt.close(fig)

KeyboardInterrupt: 