# Generative perspective on State Space Models with spike train observations

In this notebook, we will play with the parameters of a state space model and generate various spike trains.
These population activity patterns will (hopefully) be fun to look at.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# set a default figure size for matplotlib figures (using 1 + golden ratio)
plt.rcParams['figure.figsize'] = (10 * np.array([1, 1 / (1 + (np.sqrt(5)+ 1)/2)])).tolist()
# set a default font size for matplotlib figures
plt.rcParams['font.size'] = 12

## Neural Manifolds and Causality

Neurons cause the coordinated activity pattern that we experimentally observe. However, with the current experimental technology, we do not yet have enough data to recover the full spiking neural network faithfully.

Fortunately, neural recordings have a lot of spatial structure that restricts activity to a low-dimensional manifold.
Moreover, neural recordings have a lot of temporal structure that defines an effective "flow" on the manifold.

Therefore, to analyze the neural data, we could focus on the effective collective behavior of the population reflected in neural recordings.
We assume we have access to the low-dimensional population state via the partially observed neurons.
Think of the neurons as noisy measurement devices that are coupled with the thing we really want to measure, the neural population state.

This means we can define an observation model $p(y(t) \mid x(t))$, where $y(t)$ is the neural data and $x(t)$ is the latent (i.e., hidden or unobserved) neural state. $x(t)$ lives in the coordinate system that parametrizes the neural manifold or its embedding space.

This may look acausal, but it makes sense statistically. Using this model does not deny that the neurons are fundamentally generating the latent state dynamics; it is merely a methodological necessity.

For educational purposes, it is useful to generate spike trains as if this acausal model is true.
It allows us to gain intuition about what the statistical model expects the observed data to look like if its assumptions hold.

## A simple 1-D latent (neural) trajectory

For illustration, we will use a sinusoid as the 1-D latent trajectory.
$$ x(t) = \sin(2\pi f\cdot t) $$
In this example, $x(t)$ represents the (instantaneous) state of the neural population of interest.

In [None]:
# simulate a simple latent trajectory
nT = 1000  # number of time points (samples) in the simulated trajectory
T = 10 # duration of the trajectory in seconds
frq = 0.3 # frequency of the sinusoid in Hz
tr = np.linspace(0, T, nT) # time range
dt = tr[1] - tr[0] # time step in seconds
x = np.sin(2 * np.pi * frq * tr)[:, np.newaxis]  # generate a sinusoid over time, shape [nT, 1]
fig = plt.figure(); plt.plot(tr, x); plt.title('1-D latent trajectory'); plt.xlabel('time (s)');

## One Poisson neuron driven by the latent process

We will generate spike trains from a *Poisson neuron*, which is just an inhomogeneous Poisson process with a time varying firing rate function $\lambda(t)$.
The spike count $y(t)$ in a small time bin of size $\Delta$ follows a Poisson distribution:
$$ y(t) \sim \text{Poisson}(\Delta\lambda(t)) $$
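
As a quick sanity check of the binned Poisson model, the sketch below draws many counts for a single bin and compares the empirical mean count with $\Delta\lambda$. It also compares the fraction of non-empty bins with $1 - \exp(-\Delta\lambda)$. The constant rate `rate_hz`, the bin size `bin_s`, and the number of bins `n_bins` are hypothetical values chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator so the sketch is reproducible

rate_hz = 20.0    # hypothetical constant firing rate (spikes per second)
bin_s = 0.01      # hypothetical bin size Delta in seconds
n_bins = 200_000  # number of independent bins to simulate

# spike counts per bin, each drawn from Poisson(Delta * lambda)
counts = rng.poisson(rate_hz * bin_s, size=n_bins)

mean_count = counts.mean()            # empirical mean count per bin
frac_nonempty = (counts > 0).mean()   # empirical P(at least one spike in a bin)
expected_nonempty = 1.0 - np.exp(-rate_hz * bin_s)

print(f"mean count per bin: {mean_count:.3f} (theory {rate_hz * bin_s:.3f})")
print(f"P(at least one spike): {frac_nonempty:.3f} (theory {expected_nonempty:.3f})")
```

For small $\Delta\lambda$, most bins are empty and bins with more than one spike are rare, which is why a raster of the non-empty bins is a reasonable picture of the spike train.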

Importantly, the firing rate will be a function of $x(t)$, but not of past $x$ nor past $y$.
$$ \lambda(t) = g(x(t)) $$
The only constraint for $g(\cdot)$ is that the resulting firing rate has to be non-negative.
A mathematically convenient choice is the exponential function, which also relates to the spike-response model (Gerstner & Kistler 2002).

$$ \lambda(t) = \exp(a x(t) + b) = \exp(b)\exp(a x(t)) $$

- Gerstner, W., & Kistler, W. M. (2002). Spiking Neuron Models. Cambridge University Press. https://doi.org/10.1017/cbo9780511815706
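
To get a feel for the two parameters, the sketch below evaluates the exponential nonlinearity at a few latent values. The values of `a` and `b` here are illustrative assumptions: `b` sets the baseline rate $\exp(b)$ at $x = 0$, and `a` sets how strongly the rate is modulated, so the ratio between the rates at $x = 1$ and $x = -1$ is $\exp(2a)$.

```python
import numpy as np

a = 1.0          # illustrative gain on the latent state
b = np.log(5.0)  # illustrative bias; baseline rate exp(b) = 5 Hz at x = 0

x_vals = np.array([-1.0, 0.0, 1.0])  # latent values to probe
rates = np.exp(a * x_vals + b)       # firing rate in Hz at each latent value

for xv, r in zip(x_vals, rates):
    print(f"x = {xv:+.1f} -> rate = {r:.2f} Hz")

# ratio between the largest and smallest probed rate, equal to exp(2 * a)
modulation_ratio = rates[-1] / rates[0]
print(f"rate(x=1) / rate(x=-1) = {modulation_ratio:.2f}")
```

Because the rate is an exponential of the latent state, it is positive for any choice of `a` and `b`, which is the only constraint on $g(\cdot)$.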

In [None]:
# plot the log-linear Poisson neuron's tuning curve
# y-axis is the firing rate, x-axis is the signal to be encoded
a = 5
b = -3
lam = np.exp(a * x + b)

plt.figure(figsize=(6, 4))
plt.plot(x, lam)
plt.xlabel('x (latent process)')
plt.ylabel('Firing rate λ(x) in Hz')
plt.title('Log-linear Poisson neuron tuning curve')
plt.grid(True)
plt.show()

In [None]:
y = np.random.poisson(lam*dt)

plt.figure(figsize=(10, 2))
plt.plot(tr, lam, label='firing rate');
plt.eventplot(np.nonzero(y)[0]/nT*T, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.legend();

## A population of Poisson neurons driven by a common 1-D latent process

We can have more than one neuron driven by the same latent process.
This way, we can have more observation dimensions than latent state space dimensions.
Let's give them a random amount of "drive".

In [None]:
nNeuron = 200
C = 2 * np.random.randn(nNeuron, 1)
b = -2.0 + np.random.rand(1, nNeuron)
lam = np.exp(x @ C.T + b)
y = np.random.poisson(lam*dt)

We can make a spike raster plot. Since we know the amount of drive (each value in $C$), we can also sort the neurons accordingly.
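
A minimal sketch of the sorting step, using a small hypothetical loading vector `C_demo` rather than the notebook's `C`. `np.argsort` sorts along the last axis by default, so a column of shape `(nNeuron, 1)` needs to be flattened first; otherwise every returned index is zero.

```python
import numpy as np

C_demo = np.array([[0.5], [-1.2], [2.0], [0.1]])  # hypothetical loadings, shape (4, 1)

# argsort along the last axis of a (4, 1) array sorts each single-element row,
# so it returns all zeros rather than an ordering of the neurons
wrong = np.argsort(C_demo)

# take the column as a 1-D array to get an ordering of neurons by their drive
order = np.argsort(C_demo[:, 0])

print(wrong.ravel())     # indices from the naive call
print(order)             # neuron indices from weakest to strongest drive
print(C_demo[order, 0])  # loadings in ascending order
```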

In [None]:
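# C has shape (nNeuron, 1); sort its single column to get a 1-D ordering of neurons
cidx = np.argsort(C[:, 0])

raster = []
rasterSorted = []
for k in range(nNeuron):
    # y has shape (nT, nNeuron): neurons are indexed along the second axis
    raster.append(np.nonzero(y[:, k])[0]/nT*T)
    rasterSorted.append(np.nonzero(y[:, cidx[k]])[0]/nT*T)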

plt.subplots(1,2, figsize=(10, 4))
plt.subplot(1,2,1)
plt.eventplot(raster, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.title('raster plot'); plt.ylabel('neurons');
plt.subplot(1,2,2)
plt.eventplot(rasterSorted, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.title('raster plot (again)'); plt.ylabel('sorted neurons');

## Signal-to-Noise Ratio of population spike trains

We will use the `neurofisherSNR` package to estimate the upper bound of the signal-to-noise ratio to see how much information the spike trains have about the latent variable per time bin.
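
Because the SNR is reported in decibels, it helps to recall how decibels relate to a linear power ratio. The helpers below are a stand-alone sketch using the standard $10\log_{10}(\cdot)$ definition. `to_dB` and `from_dB` are hypothetical names and are not claimed to match the implementation of `power_to_dB` and `power_from_dB` in `neurofisherSNR`.

```python
import numpy as np

def to_dB(power_ratio):
    """Convert a linear power ratio to decibels (standard 10*log10 definition)."""
    return 10.0 * np.log10(power_ratio)

def from_dB(dB):
    """Convert decibels back to a linear power ratio."""
    return 10.0 ** (dB / 10.0)

snr_a_dB, snr_b_dB = 0.0, 10.0  # illustrative values only

# average the two SNRs on the linear power scale, then convert back to dB
mean_linear_dB = to_dB((from_dB(snr_a_dB) + from_dB(snr_b_dB)) / 2)
# naive arithmetic mean of the dB values, shown for comparison
mean_of_dB = (snr_a_dB + snr_b_dB) / 2

print(f"average in linear power: {mean_linear_dB:.2f} dB")
print(f"naive average of dB values: {mean_of_dB:.2f} dB")
```

The two numbers differ because the decibel scale is logarithmic: averaging dB values corresponds to a geometric mean of the power ratios, not an arithmetic mean.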

In [None]:
#!git clone https://github.com/catniplab/neurofisherSNR.git
#!pip install ./neurofisherSNR

from neurofisherSNR.snr import SNR_bound_instantaneous
from neurofisherSNR.utils import power_to_dB, power_from_dB

In [None]:
SNRdb = SNR_bound_instantaneous(x, C.T, b)
print(f"{SNRdb:.2f} dB")  # note that decibel is a logarithmic unit

## 2D latent space example

Here we will build a 2D manifold with two independent processes.
The first latent dimension will be the same as above, but we will add $x_2(t)$ as a sawtooth function with a period of 1 s, centered around zero and scaled:
$$ x_2(t) = 1.5\,\bigl((t \bmod 1) - 0.5\bigr) $$

In [None]:
x2 = 1.5 * ((tr % 1) - 0.5)[:, np.newaxis]
X = np.hstack([x, x2]) # (time) x (latent dim), python likes to have the time dimension first
dLatent = X.shape[1]

In [None]:
plt.subplots(2,1,figsize=(10,4))
plt.subplot(2,1,1);
plt.plot(tr, x ); plt.ylabel('first latent dim'); plt.xlabel('time')
plt.subplot(2,1,2);
plt.plot(tr, x2); plt.ylabel('second latent dim'); plt.xlabel('time')

Now that we have more than one latent dimension, we face a choice of how neurons relate to each of the latent dimensions.

### Random projection observation

Random projection assumes that each neuron is driven by all the latent dimensions, each by a random amount.
Under this assumption, the neural manifold is likely oblique to the axes, i.e., each neuron will be modulated by changes in any direction in the latent state space.
The theoretical analysis of [Gao & Ganguli 2015] assumes random projections and shows that not many neurons need to be sampled (observed) to recover the manifold structure.
In addition, the so-called *mixed selectivity* [Tye et al. 2024] appears as a result.

- Gao, P., & Ganguli, S. (2015). On Simplicity and Complexity in the Brave New World of Large-Scale Neuroscience. Current Opinion in Neurobiology, 32, 148–155.
- Tye, K. M., Miller, E. K., Taschbach, F. H., Benna, M. K., Rigotti, M., & Fusi, S. (2024). Mixed selectivity: Cellular computations for complexity. Neuron, 112(14), 2289–2303. https://doi.org/10.1016/j.neuron.2024.04.017
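
One way to see the mixed selectivity implied by random projections is to look at each neuron's loading on both latent dimensions. The sketch below draws a hypothetical loading matrix `C_demo` (independent of the notebook's `C`) and counts how many neurons have a non-negligible loading on both dimensions. The threshold `thresh` is an arbitrary choice made for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)  # seeded generator so the sketch is reproducible

n_neuron, d_latent = 200, 2
C_demo = 0.8 * rng.standard_normal((n_neuron, d_latent))  # hypothetical random loadings

# preferred direction of each neuron in the 2-D latent plane, in radians
angles = np.arctan2(C_demo[:, 1], C_demo[:, 0])

thresh = 0.1  # arbitrary cutoff for calling a loading "non-negligible"
mixed = np.all(np.abs(C_demo) > thresh, axis=1)  # True if both loadings exceed the cutoff

print(f"fraction of neurons driven by both dimensions: {mixed.mean():.2f}")
print(f"preferred directions span [{angles.min():.2f}, {angles.max():.2f}] rad")
```

With Gaussian loadings, most neurons respond to both latent dimensions, and their preferred directions are spread over the whole plane rather than clustered on the axes.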

In [None]:
C = 0.8 * np.random.randn(nNeuron, dLatent) # random projection
b = 0.1 * np.random.randn(nNeuron) + np.log(5)
lam = np.exp(X @ C.T + b)
y = np.random.poisson(lam*dt)

SNRdb = SNR_bound_instantaneous(X, C.T, b)
print(f"{SNRdb:.2f} dB")

In [None]:
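# np.lexsort uses the LAST key in the tuple as the primary sort key
cidx1 = np.lexsort((C[:,0], C[:,1]), axis=0)  # ordered primarily by the loading on latent dim 2
cidx2 = np.lexsort((C[:,1], C[:,0]), axis=0)  # ordered primarily by the loading on latent dim 1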

In [None]:
raster = []; rasterSorted1 = []; rasterSorted2 = []
for k in range(nNeuron):
    raster.append(np.nonzero(y[:,k])[0]/nT*T)
    rasterSorted1.append(np.nonzero(y[:, cidx1[k]])[0]/nT*T)
    rasterSorted2.append(np.nonzero(y[:, cidx2[k]])[0]/nT*T)

plt.subplots(1,3, figsize=(10, 3))
plt.subplot(1,3,1)
plt.eventplot(raster, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.title('raster plot'); plt.ylabel('neurons');
plt.subplot(1,3,2)
plt.eventplot(rasterSorted1, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.ylabel('sorted neurons');
plt.subplot(1,3,3)
plt.eventplot(rasterSorted2, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.ylabel('sorted neurons');

### Axis aligned observation

Biologists have long loved neurons that are tuned specifically for a particular feature but not modulated by others.
In our context, each neuron will be driven by either the first or the second dimension of the latent process, but not both.
A recent paper argues that this is optimal [Whittington et al. 2022].

- Whittington, J. C. R., Dorrell, W., Ganguli, S., & Behrens, T. E. J. (2022). Disentangling with Biological Constraints: A Theory of Functional Cell Types. In arXiv [q-bio.NC]. arXiv. http://arxiv.org/abs/2210.01768
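
A minimal sketch of what axis alignment means for the firing rates, using a hypothetical two-neuron loading matrix `C_demo` and bias `b_demo` (not the notebook's `C` and `b`). Changing only the second latent dimension leaves the rate of the neuron tuned to the first dimension untouched.

```python
import numpy as np

# hypothetical axis-aligned loadings: one row per neuron, one column per latent dim
C_demo = np.array([[1.0, 0.0],    # neuron 0 is driven only by latent dim 1
                   [0.0, -0.7]])  # neuron 1 is driven only by latent dim 2
b_demo = np.log(5.0)              # illustrative bias shared by both neurons

X_a = np.array([[0.3, -0.5]])  # one latent state, shape (1, 2)
X_b = np.array([[0.3, 0.9]])   # same first coordinate, different second coordinate

rate_a = np.exp(X_a @ C_demo.T + b_demo)  # rates of both neurons at state a
rate_b = np.exp(X_b @ C_demo.T + b_demo)  # rates of both neurons at state b

print("neuron 0 rates:", rate_a[0, 0], rate_b[0, 0])  # unchanged between the two states
print("neuron 1 rates:", rate_a[0, 1], rate_b[0, 1])  # changes with the second latent dim
```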

In [None]:
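bidx = np.random.rand(nNeuron) < 0.5  # True: neuron is driven only by the 2nd latent dim
C[bidx, 0] = 0   # remove the 1st-dim loading of the selected neurons
C[~bidx, 1] = 0  # remove the 2nd-dim loading of the remaining neurons
b[bidx] += 1.5 # boost the firing rate a bit for the neurons driven by the 2nd latent dim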
lam = np.exp(X @ C.T + b)
y = np.random.poisson(lam*dt)

# for independent populations with independent latents, the overall SNR combines the SNRs of the two populations;
# the last print compares it with their average taken in linear power units (not in dB)
SNRdb1 = SNR_bound_instantaneous(X[:,[0]], C[:,[0]].T, b)  # keep the singular rows or columns
SNRdb2 = SNR_bound_instantaneous(X[:,[1]], C[:,[1]].T, b)
print(f"{SNRdb1:.2f} dB, {SNRdb2:.2f} dB")
SNRdb = SNR_bound_instantaneous(X, C.T, b)
print(f"{SNRdb:.2f} dB = ({SNRdb1:.2f} dB + {SNRdb2:.2f} dB)/2 = {power_to_dB((power_from_dB(SNRdb1)+power_from_dB(SNRdb2))/2):.2f} dB")

In [None]:
cidx = np.lexsort((C[:,1], C[:,0]), axis=0)

raster = []
rasterSorted = []
for k in range(nNeuron):
    raster.append(np.nonzero(y[:,k])[0]/nT*T)
    rasterSorted.append(np.nonzero(y[:,cidx[k]])[0]/nT*T)

plt.subplots(1,2, figsize=(10, 4))
plt.subplot(1,2,1)
plt.eventplot(raster, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.title('raster plot'); plt.ylabel('neurons');
plt.subplot(1,2,2)
plt.eventplot(rasterSorted, lw=0.5, color='k', label='spikes')
plt.xlim(0, T); plt.xlabel('time'); plt.yticks([]); plt.title('raster plot (again)'); plt.ylabel('sorted neurons');

## What's next?

Now we better understand the generative process of the model. But what we are interested in is the opposite direction: how do we infer the model parameters given just the observations (neural data)? This is the statistical inference problem of interest.