In [None]:
import matplotlib.pyplot as plt
from elephant.spike_train_generation import NonStationaryPoissonProcess
from neo.core import SpikeTrain, AnalogSignal
from quantities import s, ms, Hz
import numpy as np

# Plots

Let's first set up some helper functions for simulating and plotting a two-dimensional linear dynamical system

$\dot{\mathbf{z}} = f(\mathbf{z}) = A\,\mathbf{z}$

In [None]:
def f(A, x):
    """Evaluate the linear vector field f(z) = A z.

    Parameters
    ----------
    A : array_like, shape (d, d)
        System matrix of the linear dynamics.
    x : array_like, shape (d,) or (..., d)
        A single state, or a batch of states with the state components on the
        last axis (e.g. a meshgrid stacked with ``axis=-1``).

    Returns
    -------
    ndarray
        For a single state: ``A @ x`` with shape (d,).
        For a batch: the time derivatives with the component axis moved to the
        front, shape (d, ...), so that ``dz1, dz2 = f(A, grid)`` unpacks one
        array per component.
    """
    A = np.asarray(A)
    x = np.asarray(x)
    if x.ndim > 1:
        # (..., d) @ (d, d)^T -> (..., d), then put the component axis first.
        # No squeeze: size-1 batch axes are preserved.
        return np.moveaxis(x @ A.T, -1, 0)
    return A @ x

def odeint(a, x0, dt, T, rhs=None):
    """Integrate dx/dt = rhs(a, x) from x(0) = x0 with the forward Euler method.

    Parameters
    ----------
    a : array_like
        Parameters passed through to the right-hand side (for the notebook's
        ``f`` this is the system matrix A).
    x0 : array_like, shape (d,)
        Initial state at t = 0.
    dt : float
        Integration timestep; must be positive.
    T : float
        Total duration; the time grid covers [0, T) in steps of ``dt``.
    rhs : callable, optional
        Vector field ``rhs(a, x) -> dx/dt`` for a single state ``x`` of shape
        (d,). Defaults to the notebook's ``f``.

    Returns
    -------
    x : ndarray, shape (n_steps, d)
        State at every grid time; ``x[0]`` equals ``x0``.
    t : ndarray, shape (n_steps,)
        Grid times ``k * dt`` for ``k = 0 .. n_steps - 1``.

    Raises
    ------
    ValueError
        If ``dt`` or ``T`` is not positive.
    """
    if rhs is None:
        rhs = f
    if dt <= 0 or T <= 0:
        raise ValueError(f"dt and T must be positive, got dt={dt}, T={T}")

    # Number of grid points strictly below T. Rounding T/dt first absorbs
    # floating-point noise (e.g. 1.1 / 0.1 == 11.000000000000002), which would
    # otherwise make np.arange emit an extra sample at t == T.
    n_steps = int(np.ceil(round(T / dt, 9)))
    t = np.arange(n_steps) * dt

    x0 = np.asarray(x0, dtype=float)
    x = np.zeros((n_steps, x0.shape[0]))
    x[0] = x0  # state at t_0

    for k in range(1, n_steps):
        # Forward Euler step: x_k = x_{k-1} + dt * rhs(x_{k-1})
        x[k] = x[k - 1] + dt * rhs(a, x[k - 1])

    return x, t


# Choose parameters
A = np.array([[-0.5, 1], [-1, 0]])  # system matrix of the linear vector field f(z) = A z
T = 10  # total simulated duration in seconds
dt = 0.001  # Euler timestep in seconds
x0 = np.array([1, 1])  # initial condition z(0)

# Integrate with the forward Euler method. Note: `x` holds the *latent*
# trajectory (z in the equations); the next cell derives spike trains from it.
x, t = odeint(A, x0, dt, T)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: latent trajectory over time
axs[0].plot(t, x[:, 0], label="z1")
axs[0].plot(t, x[:, 1], label="z2")
axs[0].set_title("Latent trajectory")
axs[0].set_ylabel("trajectory")
axs[0].set_xlabel("time/s")
axs[0].legend()

# Right panel: vector field on a symmetric grid over [-1.5, 1.5] x [-1.5, 1.5].
# linspace includes both endpoints with step 0.2 (np.arange(-1.5, 1.5, 0.2)
# would stop at 1.3 and skew the grid).
x1_range = np.linspace(-1.5, 1.5, 16)
x2_range = np.linspace(-1.5, 1.5, 16)
X1, X2 = np.meshgrid(x1_range, x2_range)

# f returns the derivatives component-first, so they unpack into the two arrow components
dx1_dt, dx2_dt = f(A, np.stack([X1, X2], axis=-1))

axs[1].quiver(X1, X2, dx1_dt, dx2_dt, color='r', label="vector field")
axs[1].plot(x[:, 0], x[:, 1], label="trajectory")
axs[1].set_xlabel('z1')
axs[1].set_ylabel('z2')
axs[1].set_title('Vector Field')
axs[1].grid()
axs[1].legend()

plt.show()

# Ground-Truth: spike train generation from latent dynamics

We simulate an inhomogeneous Poisson process conditioned on the latent dynamics:


$\dot{\mathbf{z}} = f(\mathbf{z})$

$\lambda(t) = \exp(\gamma \, \mathbf{w}_{\mathrm{rand}} \, \mathbf{z}(t))$

$\mathbf{x} \sim \mathrm{Poisson}(\cdot \mid \lambda(t))$

where $\gamma$ is a constant scaling factor and $\mathbf{w}_{\mathrm{rand}}$ is a fixed $\dim(\mathbf{x}) \times \dim(\mathbf{z})$ matrix with randomly generated entries. The exponential is applied elementwise, giving one firing rate per neuron.

In [None]:
# Ground-truth generative model: map the latent trajectory `x` (z in the
# equations, from the previous cell) to per-neuron rates and sample spikes.
SEED = 0  # fixes the weight draw below
N_NEURONS = 4  # number of observed neurons, i.e. dim(x) in the equations

# Seeds NumPy's global RNG. NOTE(review): elephant's Poisson sampling presumably
# draws from this global RNG as well -- confirm for the installed version.
np.random.seed(SEED)

# Logistic link; not used in this cell (the rate below uses an exponential link).
sigmoid = lambda z: 1/(1 + np.exp(-z))

# dim(x) x dim(z) weight matrix with entries drawn uniformly from [0, 1)
w_rand = np.random.random_sample([N_NEURONS, x.shape[1]])
gamma = 3

# lambda(t) = exp(gamma * w_rand z(t)) at every time step; shape (n_steps, N_NEURONS)
rate_signal = np.exp(gamma * x @ w_rand.T)

# squeeze=False keeps `axs` 2-D even when N_NEURONS == 1
fig, axs = plt.subplots(N_NEURONS, 3, figsize=(20, 3.75 * N_NEURONS), squeeze=False)
axs[0, 0].plot(t, x[:, 0], label="z1")
axs[0, 0].plot(t, x[:, 1], label="z2")
axs[0, 0].set_ylabel("latent trajectory")
axs[0, 0].set_xlabel("time/s")
axs[0, 0].legend()

for ax in axs[1:, 0]:
    ax.axis("off")

for i, r in enumerate(rate_signal.T):
    # Sampling period comes from the simulation timestep rather than a
    # hard-coded 1 ms, so the rate's time axis stays correct if dt changes.
    rate = AnalogSignal(r, units=Hz, sampling_period=dt * s)
    axs[i, 1].plot(rate.times.rescale(s).magnitude, r)
    axs[i, 1].set_ylabel("rate/Hz")
    axs[i, 1].set_xlabel("time/s")

    p = NonStationaryPoissonProcess(rate_signal=rate)
    train = p.generate_spiketrain()
    # Plot every spike (no thinning), with times expressed in seconds
    axs[i, 2].eventplot([train.rescale(s).magnitude], linelengths=0.75, color='black')
    axs[i, 2].set_ylabel(f"neuron {i}")
    axs[i, 2].set_xlabel("time/s")

plt.show()

# Inference

Next, we want to infer (a distribution over) the latent dynamics, in a scalable way, from the observed spike trains $\mathbf{x}$ alone.

Here are some references:

[1] Anqi Wu, Nicholas A. Roy, Stephen Keeley, Jonathan W. Pillow, *Gaussian process based nonlinear latent structure discovery in multivariate spike train data*, https://proceedings.neurips.cc/paper/2017/hash/b3b4d2dbedc99fe843fd3dedb02f086f-Abstract.html

[2] Qi She, Anqi Wu, *Neural Dynamics Discovery via Gaussian Process Recurrent Neural Networks*, http://proceedings.mlr.press/v115/she20a.html

## TODOs

- (Re-)implement the methods above and/or LFADS or a comparable method
- maybe we can come up with an alternative to [1] and [2] using either GP dynamics models and/or neural odes. The idea would be similar to [2] but 