In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch

from vjf.model import VJF

import h5py
from einops import rearrange

def grid(n, lims):
    """Build an ``n`` x ``n`` square evaluation grid.

    Both axes use the same ``n`` evenly spaced points, ``np.linspace(*lims, n)``.

    Returns:
        X, Y: ``(n, n)`` coordinate matrices from ``np.meshgrid`` (default "xy" indexing).
        grids: ``(n * n, 2)`` array whose rows are the ``(x, y)`` grid points, listed in
            the row-major order of ``X`` and ``Y``.
    """
    axis_points = np.linspace(*lims, n)
    X, Y = np.meshgrid(axis_points, axis_points)
    grids = np.stack((X.ravel(), Y.ravel()), axis=1)
    return X, Y, grids

In [None]:
# Reproducibility: run torch in double precision and fix the NumPy and torch RNG seeds.
SEED = 0
torch.set_default_dtype(torch.float64)  # torch.float64 is the same dtype as torch.double
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Simulate a noisy circular 2-D latent trajectory and a noisy linear Gaussian read-out of it.
T = 100.  # total duration
dt = 1e-2 * math.pi  # spacing between evaluated time points
xdim = 2  # latent state dimensionality
ydim = 20  # observation dimensionality
udim = 0  # control-input dimensionality (no inputs)

# Random linear read-out y = x @ C + d (C is drawn before d).
C = torch.randn(xdim, ydim)  # loading matrix
d = torch.randn(ydim)  # per-channel bias

t = torch.arange(0, T, step=dt)  # evaluation times
x = torch.stack([torch.sin(t), torch.cos(t)], dim=1)  # (sin t, cos t) traces the unit circle
x = x + 0.1 * torch.randn_like(x)  # latent noise with std 0.1

# Observations
y = x @ C + d
y = y + 0.1 * torch.randn_like(y)  # observation noise with std 0.1

# Plot both coordinates of the simulated (noisy) latent trajectory against time-step index.
# A dedicated figure/axes replaces the old `add_subplot(221)`, which squeezed this single
# plot into one quarter of the figure. `np.asarray` works whether x is still a torch
# tensor or has already been converted to NumPy by a later cell, so re-running is safe.
fig, ax = plt.subplots()
ax.plot(np.asarray(x))
ax.set(title='True state', xlabel='time step', ylabel='state')
ax.legend(['x1', 'x2']);

In [None]:
# Build the VJF model and fit it to the Gaussian observations y.
n_rbf = 100  # number of radial basis functions for the dynamical system
hidden_sizes = [20]  # hidden-layer sizes of the recognition model
likelihood = 'gaussian'  # 'gaussian' or 'poisson'

model = VJF.make_model(
    ydim,
    xdim,
    udim=udim,
    n_rbf=n_rbf,
    hidden_sizes=hidden_sizes,
    likelihood=likelihood,
)

# fit returns (posterior mean, posterior log variance, <unused>); keep the mean as a NumPy array.
m, logvar, _ = model.fit(y, max_iter=150)
m = m.detach().numpy().squeeze()

In [None]:
# Compare the inferred latents with the true latents. The inferred coordinates need not
# match the simulated ones, so the posterior means are first mapped onto the true state
# with a least-squares linear transform (via the pseudo-inverse), then plotted per dimension.

# Later cells use x as a NumPy array; converting only when needed keeps this cell re-runnable
# (calling `.numpy()` a second time on an ndarray would raise AttributeError).
if isinstance(x, torch.Tensor):
    x = x.numpy()

X_hat = m  # assumes m is (time bins, xdim) after the squeeze in the fit cell — TODO confirm
# x is already (time bins, 2), so no reshape is needed (the old `n_time_bins` was undefined).
S = np.linalg.pinv(X_hat) @ x
X_hat_tilde = X_hat @ S

fig, axs = plt.subplots(2, 1, sharex='all')
for dim, panel_ax in enumerate(axs):
    panel_ax.plot(x[:, dim], label='Data')
    panel_ax.plot(X_hat_tilde[:, dim], label='Fit')
axs[1].legend()
plt.show()

In [None]:
# Draw the inferred velocity field: evaluate model.transition.velocity on a 51x51 grid
# spanning the inferred trajectory, and overlay the posterior-mean trajectory m.
# Uses its own figure; the old `fig.add_subplot(223)` attached an unused axes to the
# previous cell's (already shown) figure while pyplot drew somewhere else.
fig, ax = plt.subplots()

# Largest absolute inferred coordinate sets the plot extent. Named `lim` so it does not
# clobber the spike-count array `r` used in the Poisson section.
lim = np.abs(m).max()

Xm, Ym, XYm = grid(51, [-1.5 * lim, 1.5 * lim])
Um, Vm = model.transition.velocity(torch.tensor(XYm)).detach().numpy().T  # per-point velocity components
Um = np.reshape(Um, Xm.shape)
Vm = np.reshape(Vm, Ym.shape)
ax.streamplot(Xm, Ym, Um, Vm)
ax.plot(*m.T, color='C1', alpha=0.5, zorder=5)
ax.set_title('Velocity field');

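## Poisson observations

The same latent trajectory `x` is now observed through simulated spike counts (at most one spike per time bin) and fit with a Poisson likelihood.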

In [None]:
# Generate Poisson spike observations from the same latent trajectory x.
x_latent = np.asarray(x)  # works whether x is still a torch tensor or already NumPy
n_latents = x_latent.shape[1]  # must match the columns of x for `x @ C.T` below
n_neurons = 50  # number of simulated neurons; not defined elsewhere in the notebook — adjust as needed
delta = 5e-3  # time bin size

# Random loading matrix: clip entries to [-max_c, max_c], then scale by 4.
C = np.random.normal(size=(n_neurons, n_latents))
max_c = .8
C = 4 * np.clip(C, -max_c, max_c)
b = np.log(0.5 + 0.05 * np.random.rand(n_neurons))  # per-neuron log baseline rate
rates = np.exp(x_latent @ C.T + b)  # (time bins, neurons)
r = delta * rates  # Poisson mean (expected spike count) per bin

Y_unchopped = np.random.poisson(r)
Y = np.minimum(Y_unchopped, 1)  # keep at most one spike per bin

In [None]:
# Expected spike count per bin (the Poisson mean r = delta * rates) for the first neuron.
fig, ax = plt.subplots()
ax.plot(r[:, 0])
ax.set(title='Expected spike count per bin, neuron 0', xlabel='time bin', ylabel='delta * rate');

In [None]:
# Raster plots of the simulated spikes Y (time bins x neurons): left in the original
# neuron order, right with neurons reordered by their loading weights.
# np.lexsort uses the LAST key as the primary key, so cidx1 sorts by C[:, 1] and breaks
# ties with C[:, 0]; cidx2 is the opposite ordering (not used below).
cidx1 = np.lexsort((C[:, 0], C[:, 1]), axis=0)
cidx2 = np.lexsort((C[:, 1], C[:, 0]), axis=0)

# Convert spike bin indices to times on [0, T). Sizes come from Y itself; the old code
# relied on `nT` and `n_neurons`, and `nT` is not defined anywhere in the notebook.
n_bins, n_cells = Y.shape
raster = [np.nonzero(Y[:, k])[0] / n_bins * T for k in range(n_cells)]
rasterSorted = [np.nonzero(Y[:, cidx1[k]])[0] / n_bins * T for k in range(n_cells)]

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
panels = [
    (raster, 'raster plot', 'neurons'),
    (rasterSorted, 'raster plot (again)', 'sorted neurons'),
]
for panel_ax, (events, title, ylabel) in zip(axs, panels):
    panel_ax.eventplot(events, lw=0.5, color='k', label='spikes')
    panel_ax.set(xlim=(0, T), xlabel='time', yticks=[], title=title, ylabel=ylabel)

In [None]:
# Build the VJF model and fit it to the binary spike observations Y with a Poisson likelihood.
n_rbf = 100  # number of radial basis functions for the dynamical system
hidden_sizes = [50]  # hidden-layer sizes of the recognition model
likelihood = 'poisson'  # 'gaussian' or 'poisson'

# Sizes are taken from defined quantities instead of the previously undefined
# `n_neurons` / `n_latents`: Y is (time bins, neurons), and the latent size matches
# the simulated state (xdim), as in the Gaussian fit.
n_obs = Y.shape[1]
model = VJF.make_model(n_obs, xdim, udim=udim, n_rbf=n_rbf, hidden_sizes=hidden_sizes, likelihood=likelihood)

# Pass a floating-point tensor in the default dtype (like the tensor `y` in the Gaussian
# fit) rather than the raw integer NumPy array.
Y_obs = torch.as_tensor(Y, dtype=torch.get_default_dtype())
m, logvar, _ = model.fit(Y_obs, max_iter=150)  # (posterior mean, posterior log variance, <unused>)

m = m.detach().numpy().squeeze()

In [None]:
# Compare the Poisson-fit latents with the true latents. The inferred coordinates need not
# match the simulated ones, so the posterior means are first mapped onto the true state
# with a least-squares linear transform (via the pseudo-inverse), then plotted per dimension.
x_true = np.asarray(x)  # works whether x is still a torch tensor or already NumPy

X_hat = m  # assumes m is (time bins, latent dim) after the squeeze in the fit cell — TODO confirm
# x_true is already (time bins, 2), so no reshape is needed (the old `n_time_bins` was undefined).
S = np.linalg.pinv(X_hat) @ x_true
X_hat_tilde = X_hat @ S

fig, axs = plt.subplots(2, 1, sharex='all')
for dim, panel_ax in enumerate(axs):
    panel_ax.plot(x_true[:, dim], label='Data')
    panel_ax.plot(X_hat_tilde[:, dim], label='Fit')
axs[1].legend()
plt.show()

In [None]:
# Draw the inferred velocity field of the Poisson fit: evaluate model.transition.velocity
# on a 51x51 grid spanning the inferred trajectory, and overlay the posterior-mean trajectory m.
# Uses its own figure; the old `fig.add_subplot(223)` attached an unused axes to the
# previous cell's (already shown) figure while pyplot drew somewhere else.
fig, ax = plt.subplots()

# Largest absolute inferred coordinate sets the plot extent. Named `lim` so it does not
# overwrite the spike-count array `r`.
lim = np.abs(m).max()

Xm, Ym, XYm = grid(51, [-1.5 * lim, 1.5 * lim])
Um, Vm = model.transition.velocity(torch.tensor(XYm)).detach().numpy().T  # per-point velocity components
Um = np.reshape(Um, Xm.shape)
Vm = np.reshape(Vm, Ym.shape)
ax.streamplot(Xm, Ym, Um, Vm)
ax.plot(*m.T, color='C1', alpha=0.5, zorder=5)
ax.set_title('Velocity field');