In [23]:
import math
import h5py
import torch
import numpy as np
from vjf.model import VJF
import matplotlib.pyplot as plt
from einops import rearrange

from config import get_cfg_defaults
# pip install git+https://github.com/catniplab/vjf.git


In [24]:
# Disabled: synthetic data generator (noisy sin/cos latents, linear readout), kept for reference.
# # loading matrix
# n_latents = 2
# n_neurons = 150
# n_time_bins = 1000
#
# n_trials = 5
# bin_size_ms = 5
# time_delta = bin_size_ms * 1e-3
#
# b = torch.randn(n_neurons)
# C = torch.randn(n_latents, n_neurons)
#
# # latent states
# t = torch.arange(0, n_time_bins)*time_delta  # time point to be evaluated
# X = torch.column_stack((torch.sin(t), torch.cos(t)))  # latent trajectory
# X = X + 0.1 * torch.randn_like(X)
#
# # observations
# Y = X @ C + b
# Y = Y + 0.1 * torch.randn_like(Y)

In [28]:
# Load the observations and ground-truth latents, then derive the problem
# dimensions from the array shapes.

cfg = get_cfg_defaults()

# Open read-only inside a context manager so the HDF5 handle is released as
# soon as the datasets have been copied into in-memory NumPy arrays.
# `data` stays bound afterwards but refers to a closed file.
with h5py.File('data/poisson_obs.h5', 'r') as data:
    Y = np.array(data['Y'])
    X = np.array(data['X'])
    C = np.array(data['C'])
    b = np.array(data['bias'])

# setup parameters
n_trials = 1
bin_size_ms = 5
time_delta = bin_size_ms * 1e-3

# Dimensions must be read BEFORE the leading axis is dropped below.
# NOTE(review): presumably the stored layout is (trial, time, feature) --
# confirm against the file.
n_latents = X.shape[2]
n_neurons = Y.shape[2]
n_time_bins = Y.shape[1]

# Keep only the first entry along axis 0 (matches n_trials = 1); Y and X have
# one dimension fewer from here on.
Y = Y[0]
X = X[0]

In [None]:
# Build the VJF model (fitting happens in the next cell).
#
# Hyperparameters:
#   n_rbf        -- number of radial basis functions for the dynamical system
#   hidden_sizes -- sizes of the hidden layers of the recognition model
#   likelihood   -- observation model, either 'gaussian' or 'poisson'
n_rbf = 64
hidden_sizes = [32]
likelihood = 'poisson'

model = VJF.make_model(
    n_neurons,
    n_latents,
    udim=0,
    n_rbf=n_rbf,
    hidden_sizes=hidden_sizes,
    likelihood=likelihood,
)

In [None]:
# Fit the model to the observations. Per the upstream example, `fit` returns
# the state posterior as (mean, log variance) plus a third value that is
# ignored here.
m_tensor, logvar, _ = model.fit(Y, max_iter=100)

# Detach the posterior mean from the autograd graph, convert it to NumPy and
# drop singleton axes so later cells can treat it as a plain array.
m = m_tensor.detach().numpy().squeeze()

In [None]:
# Regress the inferred latents onto the true latents to account for the
# invariance of the latent space, then compare them dimension by dimension.

# X_hat is used as a 2-D matrix below. With several trials it would need
# flattening first, e.g. rearrange(m, 'batch time lat -> (batch time) lat').
X_hat = m

# X had its leading axis dropped when it was loaded, so the three-index
# accesses used for plotting would raise IndexError on it. Work on an explicit
# (trial, time, latent) view instead.
X_true = X.reshape(n_trials, n_time_bins, n_latents)

# Least-squares map S such that X_hat @ S approximates the true latents.
S = np.linalg.pinv(X_hat) @ X_true.reshape(n_trials * n_time_bins, n_latents)
X_hat_tilde = (X_hat @ S).reshape(n_trials, n_time_bins, n_latents)

print(f'X_hat shape: {X_hat.shape}')

# Plot the first two latent dimensions of trial 0, one per row.
fig, axs = plt.subplots(2, 1, sharex='all')
for dim, ax in enumerate(axs):
    ax.plot(X_hat_tilde[0, :, dim], label='inferred (aligned)')
    ax.plot(X_true[0, :, dim], label='true')
    ax.set_ylabel(f'latent {dim}')
axs[0].legend()
axs[-1].set_xlabel('time bin')
plt.show()

In [None]:
# plot the true and inferred latents of trial 0
# Zero-based trial index; it must be < n_trials. The previous value of 1 was
# out of range for a single trial and contradicted the comment above.
trial_num = 0

# X lost its leading axis at load time; restore an explicit
# (trial, time, latent) view so it can be indexed like X_hat_tilde.
X_true = X.reshape(n_trials, n_time_bins, n_latents)

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))
ax0.plot(X_true[trial_num, :, 0], X_true[trial_num, :, 1])
ax0.set(title='true latents', xlabel='latent 0', ylabel='latent 1')
ax1.plot(X_hat_tilde[trial_num, :, 0], X_hat_tilde[trial_num, :, 1])
ax1.set(title='inferred latents (aligned)', xlabel='latent 0', ylabel='latent 1')
plt.show()