# Inferring Latent Neural States

Let's analyze some neural data using popular dimensionality reduction methods.
We will use the following methods, listed in order of progressively richer modeling assumptions:
- PCA (Principal Components Analysis)
  - Gaussian observation
  - Independent, identically distributed Gaussian noise for every neuron (one noise variance shared across neurons)
- GPFA (Gaussian Process Factor Analysis)
  - Gaussian observation
  - Unequal magnitude of noise per neuron
  - Smoothness assumption on the latent trajectory
- vLGP (variational Latent Gaussian Process)
  - Poisson observation
  - Unequal magnitude of noise per neuron
  - Smoothness assumption on the latent trajectory
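To make the differences between the three observation models concrete, here is a toy simulation. It is a minimal sketch, not part of the analysis: every parameter (the 100 ms lengthscale, the loading matrix, the noise levels, the 20 spikes/s baseline) is made up for illustration. The same smooth latent is reused for all three models only so the observation noise can be compared; PCA itself places no smoothness prior on the latent.

```python
import numpy as np

rng = np.random.default_rng(0)

n_time, n_neuron, n_latent = 200, 30, 2
dt = 0.005                          # 5 ms bins, as in this notebook
t = np.arange(n_time) * dt          # bin times in seconds

# Smooth latent: each latent is an independent draw from a GP with an RBF kernel
ell = 0.1                           # illustrative 100 ms lengthscale (made up)
K = np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2 / ell ** 2)
L = np.linalg.cholesky(K + 1e-6 * np.eye(n_time))   # small jitter keeps K positive definite
x = L @ rng.standard_normal((n_time, n_latent))     # (time, latent)

C = rng.normal(scale=0.5, size=(n_latent, n_neuron))  # loading matrix (made up)
d = rng.normal(size=n_neuron)                        # per-neuron mean for the Gaussian models

# PCA-like model (probabilistic PCA): one noise variance shared by all neurons
y_pca = x @ C + d + 0.1 * rng.standard_normal((n_time, n_neuron))

# GPFA-like model: a separate noise standard deviation for each neuron
noise_sd = rng.uniform(0.05, 0.3, size=n_neuron)
y_gpfa = x @ C + d + noise_sd * rng.standard_normal((n_time, n_neuron))

# vLGP-like model: Poisson counts with an exponential link (baseline ~20 spikes/s)
b = np.log(20 * dt)
y_vlgp = rng.poisson(np.exp(x @ C + b))
```

In the Poisson model the variance of each count equals its mean, so neurons with different firing rates automatically get different noise levels without a separate noise parameter.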

## Load Monkey delayed-reaching task data

In [None]:
import h5py
import numpy as np
import pickle
import matplotlib.pyplot as plt
from einops import rearrange
import scipy.ndimage

baseDir = 'mc_maze/data/'
trial_info_save_path = baseDir + 'info_per_trial_{}.pkl'

# Per-trial metadata for the training and validation splits
with open(trial_info_save_path.format("train"), "rb") as openfile:
    trial_info_train = pickle.load(openfile)

with open(trial_info_save_path.format("val"), "rb") as openfile:
    trial_info_val = pickle.load(openfile)

# Binned spikes and hand positions (the file stays open until the last cell)
m5 = h5py.File(baseDir + 'monkey.hdf5', 'r')

In [None]:
nTrial = m5['pos-train'].shape[0]
nT = m5['pos-train'].shape[1]
nNeuron = m5['spk-train'].shape[2]
dt = 0.005  # 5 ms bin
T = dt * nT

In [None]:
y = m5['spk-train'][()]

In [None]:
kTrial = 100  # trial to display
raster = []
for kNeuron in range(nNeuron):
    # spike times (s) of this neuron: bin index * bin width
    raster.append(np.nonzero(y[kTrial, :, kNeuron])[0] * dt)
plt.eventplot(raster, lw=0.5, color='k')
plt.xlim(0, T); plt.xlabel('time (s)'); plt.title('raster plot'); plt.ylabel('neurons');

In [None]:
for i in range(100,120):
    plt.plot(m5['pos-train'][i,:,0], m5['pos-train'][i,:,1])
    
plt.xlabel('X hand position'); plt.ylabel('Y hand position'); plt.grid(); plt.title('center out reaching trajectory')

## PCA

To perform PCA, we first smooth each neuron's spike counts over time with a Gaussian kernel (σ = 50 ms), and then concatenate the trials so that the data matrix has shape (trial × time) × neurons.
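For reference, the smooth-then-stack step can be written with NumPy alone. This is a sketch under a few assumptions: the kernel is truncated at 4σ (the default `truncate` of `scipy.ndimage.gaussian_filter1d`), trial edges are zero-padded rather than reflected as `gaussian_filter1d` does by default, and the helper names `gaussian_kernel` and `smooth_and_stack` are hypothetical. It also shows that the stacked row index is `trial * n_time + time`, which is the ordering the `einops` pattern `'(trial time) neurons'` produces.

```python
import numpy as np

def gaussian_kernel(sigma_bins, truncate=4.0):
    """Gaussian kernel sampled at integer bin offsets, normalized to sum to 1."""
    radius = int(truncate * sigma_bins + 0.5)
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma_bins) ** 2)
    return kernel / kernel.sum()

def smooth_and_stack(y, sigma_bins):
    """Smooth a (trial, time, neuron) count array along time, then stack
    trials into a (trial*time, neuron) matrix.

    Zero-pads at trial edges and assumes each trial is longer than the kernel.
    """
    y = np.asarray(y, dtype=float)  # cast first so integer counts are not truncated
    kernel = gaussian_kernel(sigma_bins)
    # np.convolve is 1-D, so apply it to every (trial, neuron) time series
    smoothed = np.apply_along_axis(np.convolve, 1, y, kernel, mode='same')
    n_trial, n_time, n_neuron = y.shape
    # C-order reshape: row index = trial * n_time + time
    return smoothed.reshape(n_trial * n_time, n_neuron)

# Usage: one spike in trial 1, neuron 2, bin 100; 5 ms bins, sigma = 50 ms
y_toy = np.zeros((3, 200, 4), dtype=int)
y_toy[1, 100, 2] = 1
X = smooth_and_stack(y_toy, sigma_bins=0.050 / 0.005)
```

The single spike is spread over neighbouring bins of trial 1 only, and its smoothed mass still sums to one.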

In [None]:
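# Smooth each neuron's spike counts over time with a Gaussian kernel (sigma = 50 ms).
# Cast to float first: gaussian_filter1d returns the input dtype, so integer counts would be truncated.
data_stacked = rearrange(y.astype(float), 'trial time neurons -> (trial neurons) time')
data_smooth = scipy.ndimage.gaussian_filter1d(input=data_stacked, sigma=0.050/dt, axis=1)
# Stack trials: (trial x time) x neurons
data_smooth = rearrange(data_smooth, '(trial neurons) time -> (trial time) neurons', trial=nTrial, neurons=nNeuron)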

In [None]:
data_centered = data_smooth - np.mean(data_smooth, axis=0)

In [None]:
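# Smoothed, centered activity of the first 10 neurons during the second trial
tidx = slice(nT, 2*nT)
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
tr = np.arange(nT) * dt  # bin times (s); avoids the off-by-one risk of np.arange with a float step
ax.plot(tr, data_centered[tidx, 0:10]);
ax.set_xlabel("time (s)");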

In [None]:
# PCA using SVD
u, s, vh = np.linalg.svd(data_centered, full_matrices=False)
u.shape, s.shape
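
The normalized squared singular values used below are fractions of variance because, for a centered data matrix $X_c$ with $n$ rows, the eigenvalues of the sample covariance $X_c^\top X_c/(n-1)$ are $s_i^2/(n-1)$. The left singular vectors `u` have unit norm, so the PC scores themselves are `u * s` (equivalently `X_c @ vh.T`). A quick check of both facts on random toy data (not the monkey data):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 6)) @ rng.normal(size=(6, 6))  # correlated toy data
Xc = X - X.mean(axis=0)                                   # center each column
n = Xc.shape[0]

u, s, vh = np.linalg.svd(Xc, full_matrices=False)

# Eigenvalues of the sample covariance equal s**2 / (n - 1)
cov = Xc.T @ Xc / (n - 1)
eigvals = np.linalg.eigvalsh(cov)[::-1]   # eigvalsh returns ascending order; flip to match s
var_per_pc = s ** 2 / (n - 1)

# PC scores: projection onto the right singular vectors, equal to u * s
scores = Xc @ vh.T
```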

In [None]:
norm_sv = s**2/np.sum(s**2)  # fraction of variance explained by each PC
top2sv = np.sum(norm_sv[:2])
print("Variance explained by the first two principal components: {0:.2f}%".format(top2sv*100))

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 3))

axs[0].plot(norm_sv * 100, 'o-')
axs[1].plot(norm_sv.cumsum() * 100, 'o-')
axs[1].set_ylim([0, 100])
axs[2].plot(10*np.log10(norm_sv), 'o-')  # variance is a power quantity, so dB = 10*log10

for ax in axs:
    ax.grid()
    ax.set_xlabel("PC (ordered)")
axs[0].set_ylabel("Variance explained per PC (%)")
axs[1].set_ylabel("Cumulative variance explained (%)")
axs[2].set_ylabel("Variance explained (dB)")
fig.suptitle("What's the dimensionality? Variance explained by each PC");

In [None]:
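# Project onto the top two PCs. Columns of u are unit-norm, so scale by s to get
# PC scores (equivalent to data_centered @ vh[:2].T)
top2scores = u[:, :2] * s[:2]
X_hat_PCA = rearrange(top2scores, '(trial time) pcs -> trial time pcs', trial=nTrial)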

In [None]:
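# Latent trajectories of the first 10 trials; circles mark the end of each trial
for k in range(10):
    plt.plot(X_hat_PCA[k,:,0],  X_hat_PCA[k,:,1])
    plt.plot(X_hat_PCA[k,-1,0], X_hat_PCA[k,-1,1], 'o')

plt.xlabel('PC1'); plt.ylabel('PC2'); plt.grid(); plt.title('Top two PCs (first 10 trials)')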

## GPFA

We are using the implementation included in the Elephant package:
https://elephant.readthedocs.io/en/latest/reference/gpfa.html

 - Yu, B. M., Cunningham, J. P., Santhanam, G., Ryu, S. I., Shenoy, K. V., & Sahani, M. (2009). Gaussian-process factor analysis for low-dimensional single-trial analysis of neural population activity. Journal of Neurophysiology, 102(1), 614–635.
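
GPFA places an independent Gaussian-process prior on each latent dimension. In Yu et al. (2009) the covariance is a squared-exponential (RBF) kernel plus a small diagonal term, $k(t_1, t_2) = (1-\epsilon)\exp\left(-\frac{(t_1-t_2)^2}{2\tau^2}\right) + \epsilon\,\delta_{t_1 t_2}$. As far as we can tell, the Elephant port follows the original MATLAB code and stores the timescale $\tau$ as `gamma` $= (\Delta t/\tau)^2$, with time measured in bins; treat that parameterization as an assumption and check it against your Elephant version. The sketch below builds that covariance and draws one latent; `gpfa_rbf_cov` and the 100 ms timescale are hypothetical.

```python
import numpy as np

def gpfa_rbf_cov(n_bins, gamma, eps=1e-3):
    """RBF covariance over time bins in the (assumed) GPFA parameterization:
    K[i, j] = (1 - eps) * exp(-gamma * (i - j)**2 / 2) + eps * (i == j)."""
    lag = np.arange(n_bins)[:, None] - np.arange(n_bins)[None, :]   # time difference in bins
    return (1 - eps) * np.exp(-0.5 * gamma * lag ** 2) + eps * np.eye(n_bins)

bin_size = 0.005                     # seconds, as in this notebook
tau = 0.1                            # hypothetical 100 ms GP timescale
gamma = (bin_size / tau) ** 2        # assumed Elephant / Yu et al. parameterization

K = gpfa_rbf_cov(200, gamma)
rng = np.random.default_rng(2)
x = np.linalg.cholesky(K) @ rng.standard_normal(200)   # one smooth latent trajectory

tau_recovered = bin_size / np.sqrt(gamma)               # invert gamma back to seconds
```

The diagonal term $\epsilon$ keeps the covariance well conditioned, which is why the Cholesky factorization here needs no extra jitter.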

In [None]:
from elephant.gpfa import GPFA
import neo
import quantities as pq

In [None]:
# ---- Convert binned counts to neo.SpikeTrains ---- #
def array_to_spiketrains(array, bin_size):
    """Convert a (trial x time x neuron) spike-count array to a list (trials) of lists (neurons) of SpikeTrains"""
    stList = []

    for trial in range(array.shape[0]):
        trialList = []
        for channel in range(array.shape[2]):
            # Bins with at least one spike; a bin with count c contributes c identical spike times
            times = np.nonzero(array[trial, :, channel])[0]
            counts = array[trial, times, channel].astype(int)
            times = np.repeat(times, counts)
            # Spike times are placed at the left edge of their bin
            st = neo.SpikeTrain(times*bin_size, t_stop=array.shape[1]*bin_size)
            trialList.append(st)
        stList.append(trialList)
    return stList

Y_st_train = array_to_spiketrains(y, dt*pq.s)
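
Before handing the data to Elephant, it is worth checking that the conversion loses no spikes. Below is a NumPy-only sketch of the same count-to-time expansion (the `np.repeat` step above) together with its inverse via `np.histogram`. The helper names are hypothetical, and unlike `array_to_spiketrains`, this sketch places spikes at bin centres rather than left edges, a design choice that keeps re-binning unaffected by floating-point rounding at bin edges.

```python
import numpy as np

def counts_to_times(counts, bin_size):
    """Expand a 1-D array of per-bin spike counts into spike times (s).

    A bin with count c contributes c identical times, placed at the bin centre."""
    counts = np.asarray(counts).astype(int)
    bins = np.nonzero(counts)[0]
    return (np.repeat(bins, counts[bins]) + 0.5) * bin_size

def times_to_counts(times, n_bins, bin_size):
    """Re-bin spike times (s) into n_bins counts of width bin_size."""
    edges = np.arange(n_bins + 1) * bin_size
    counts, _ = np.histogram(times, bins=edges)
    return counts

# Usage: a short count sequence survives the round trip unchanged
counts = np.array([0, 2, 0, 1, 3])
times = counts_to_times(counts, 0.005)
recovered = times_to_counts(times, len(counts), 0.005)
```

Applied to `y[trial, :, neuron]` for any trial and neuron, the same round trip should return the original counts.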

In [None]:
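# ---- Run GPFA ---- #
nLatents = 3
gpfa = GPFA(bin_size=(dt * pq.s), x_dim=nLatents)
gpfa_val_result = gpfa.fit_transform(Y_st_train)  # fit, then infer latents on the training trials

# Elephant (following Yu et al., 2009) stores each latent's GP timescale tau as
# gamma = (bin_size / tau)**2, so convert gamma back to timescales in seconds
length_scales = dt / np.sqrt(gpfa.params_estimated['gamma'])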

In [None]:
X_hat_GPFA = rearrange(np.stack(gpfa_val_result, 0), 'trials lat time -> trials time lat')

In [None]:
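for k in range(10):
    plt.plot(X_hat_GPFA[k,:,1], X_hat_GPFA[k,:,2])
plt.xlabel('latent 2'); plt.ylabel('latent 3'); plt.grid(); plt.title('GPFA latents (first 10 trials)')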

## vLGP

 - Zhao, Y., & Park, I. M. (2017). Variational Latent Gaussian Process for Recovering Single-Trial Dynamics from Population Spike Trains. Neural Computation, 29(5), 1293–1316. arXiv.
 - Nam, H. (2015). Poisson Extension of Gaussian Process Factor Analysis for Modelling Spiking Neural Populations (J. Macke (ed.)). Eberhard-Karls-Universität Tübingen.
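
In its basic form (ignoring any spike-history terms the full vLGP model may include), vLGP replaces the Gaussian likelihood with a Poisson one: each count is $y_{tn} \sim \mathrm{Poisson}\left(\exp(x_t^\top c_n + b_n)\right)$, with the latent $x_t$ under a GP prior. A minimal sketch of that log-likelihood follows; the function name and argument layout are hypothetical and make no claim to match the `vlgpax` internals.

```python
import numpy as np
from math import lgamma

def poisson_loglik(y, x, C, b):
    """Sum over time bins and neurons of log Poisson(y | exp(x @ C + b)).

    y: (time, neuron) counts, x: (time, latent), C: (latent, neuron), b: (neuron,)
    """
    log_rate = x @ C + b                                  # log expected count per bin
    log_y_factorial = np.vectorize(lgamma)(np.asarray(y, dtype=float) + 1.0)
    return float(np.sum(y * log_rate - np.exp(log_rate) - log_y_factorial))
```

Unlike the Gaussian models, there is no separate noise-variance parameter: the Poisson variance equals the expected count, so the effective noise level of each neuron follows from its rate.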

In [None]:
from vlgpax.kernel import RBF
from vlgpax import Session, vi

In [None]:
# Session is the top-level data container; here it is constructed with the bin size (dt, in seconds)
session = Session(dt)

for i, yy in enumerate(y):
    session.add_trial(i + 1, y=yy)  # add each trial (time x neuron counts) under a 1-based trial ID

# Build the model
kernel = RBF(scale=1., lengthscale=25 * dt)  # RBF kernel with a 125 ms lengthscale

In [None]:
random_seed = 20221011
np.random.seed(random_seed)
session, params = vi.fit(session, n_factors=nLatents, kernel=kernel, seed=random_seed, max_iter=50)

In [None]:
X_hat_VLGP = rearrange(session.z, '(trials time) lat -> trials time lat', time=nT)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 4))

# Same 10 trials for each method; circles mark the end of each trial
for ax, (name, X_hat) in zip(axs, [('PCA', X_hat_PCA), ('GPFA', X_hat_GPFA), ('vLGP', X_hat_VLGP)]):
    for k in range(10):
        ax.plot(X_hat[k,:,0],  X_hat[k,:,1])
        ax.plot(X_hat[k,-1,0], X_hat[k,-1,1], 'o')
    ax.set_xticks([]); ax.set_yticks([]); ax.axis('equal')
    ax.set_title(name)
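
With a linear readout, the latent axes are broadly only defined up to a linear transformation ($Cx = (CA^{-1})(Ax)$ for any invertible $A$), so the three panels need not line up even when they capture similar structure. One simple way to compare them, sketched here as an optional extra rather than part of the original analysis, is a least-squares affine map of one method's latents onto another's. The helper `align_latents` is hypothetical.

```python
import numpy as np

def align_latents(source, target):
    """Least-squares affine map of `source` latents onto `target` latents.

    source: (samples, d_src), target: (samples, d_tgt). Returns the mapped
    source (samples, d_tgt) and the fitted weights (d_src + 1, d_tgt)."""
    A = np.hstack([source, np.ones((source.shape[0], 1))])  # append an intercept column
    W, *_ = np.linalg.lstsq(A, target, rcond=None)
    return A @ W, W

# Toy check: an affine distortion of the target is undone exactly
rng = np.random.default_rng(3)
target = rng.normal(size=(100, 2))
M = np.array([[2.0, 0.5], [-1.0, 1.5]])   # invertible mixing matrix (made up)
source = (target - 0.3) @ M
mapped, W = align_latents(source, target)
```

For example, `align_latents(X_hat_GPFA.reshape(-1, nLatents), X_hat_PCA.reshape(-1, 2))` would express the GPFA latents in PCA coordinates; reshape the result to `(nTrial, nT, 2)` before plotting.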

In [None]:
m5.close() # closing the data file