In [11]:
import math
import h5py
import torch
import numpy as np
from vjf.model import VJF
import matplotlib.pyplot as plt

from config import get_cfg_defaults
# Requires the VJF package: pip install git+https://github.com/catniplab/vjf.git

cfg = get_cfg_defaults()

# Open read-only explicitly: h5py < 3.0 defaulted to mode 'a', which silently
# creates an empty file when the path is wrong instead of raising an error.
# NOTE(review): `data` is not used in the cells shown; call data.close() when done.
data = h5py.File('data/poisson_obs.h5', 'r')

n_latents = 2
n_neurons = 150
n_time_bins = 1000  # number of time bins (a count, not a duration)

n_trials = 5
bin_size_ms = 5
time_delta = bin_size_ms * 1e-3  # bin width converted from ms to seconds

In [12]:
# Synthetic dataset: a noisy 2-D circular latent trajectory mapped linearly
# (plus bias and Gaussian noise) to n_neurons real-valued observations.
torch.manual_seed(0)  # make the synthetic data, and re-runs of this cell, reproducible

latent_noise_std = 0.1  # std of Gaussian noise added to the latent trajectory
obs_noise_std = 0.1  # std of Gaussian noise added to the observations

# loading matrix (n_latents x n_neurons) and per-neuron bias
b = torch.randn(n_neurons)
C = torch.randn(n_latents, n_neurons)

# latent states: exactly n_time_bins samples spaced time_delta seconds apart.
# (torch.arange(0, n_time_bins, step=time_delta) would treat n_time_bins as a
# duration and produce n_time_bins / time_delta bins -- 200,000 here.)
t = time_delta * torch.arange(n_time_bins)  # time points to be evaluated
X = torch.column_stack((torch.sin(t), torch.cos(t)))  # latent trajectory, (n_time_bins, 2)
X = X + latent_noise_std * torch.randn_like(X)

# observations: linear readout of the latents, (n_time_bins, n_neurons)
Y = X @ C + b
Y = Y + obs_noise_std * torch.randn_like(Y)

assert X.shape == (n_time_bins, n_latents), "sin/cos latent trajectory assumes n_latents == 2"
assert Y.shape == (n_time_bins, n_neurons)

In [13]:
# Set up the VJF model (fitting happens in the next cell)
n_rbf = 5  # number of radial basis functions for dynamical system
hidden_sizes = [20]  # size of hidden layers of recognition model
# Options: 'gaussian' or 'poisson'. Y above is real-valued (a noisy linear
# readout that can be negative), so the Gaussian likelihood matches it;
# 'poisson' would require non-negative count observations.
likelihood = 'gaussian'

model = VJF.make_model(n_neurons, n_latents, udim=0, n_rbf=n_rbf, hidden_sizes=hidden_sizes, likelihood=likelihood)

In [None]:
# Fit VJF and extract the state-posterior means as a NumPy array.
# NOTE(review): a single iteration is only a smoke test; increase for a real fit.
n_fit_iter = 1
m, logvar, _ = model.fit(Y, max_iter=n_fit_iter)  # state posterior (mean, log variance), plus an unused third value
m = m.detach().cpu().numpy().squeeze()  # .cpu() so this also works if the model runs on a GPU

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
np.shape(m)  # shape of the squeezed posterior mean; presumably (n_time_bins, n_latents) -- confirm