In [None]:
import sys

import jax
import jax.numpy as jnp
import jax.random as jxr
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../')  # make the local `utils` and `models` modules importable
import utils
from models import ParamswGPLDS, WeightSpaceGaussianProcess, wGPLDS

In [None]:
#! Dummy data, change to yours
# Shapes: Y is (trials, time bins, neurons); U is (trials, time bins, conditions).
B_batches, T, N_neurons, M_conditions = 100, 40, 89, 2
Y = jnp.empty((B_batches, T, N_neurons))
U = jnp.empty((B_batches, T, M_conditions))

# Condition spaces (the range spanned by each condition variable)
t_range = jnp.linspace(0, 0.1, 10)  # To be taken from U; make sure it is consistent across batches
coherencies = jnp.linspace(-1, 1, 10)  # To be taken from U; make sure it is consistent across batches

# Emissions and conditions are shuffled together below, so they must agree on (trials, time).
assert Y.shape[:2] == U.shape[:2], f"Y {Y.shape} and U {U.shape} disagree on (trials, time bins)"

# Partition data into training and test sets
partition = 0.8  # fraction of trials used for training
n_trials = Y.shape[0]
n_train = int(n_trials * partition)
n_test = n_trials - n_train
assert n_train > 0 and n_test > 0, "partition must leave at least one training and one test trial"

# Randomly shuffle trials (fixed seed for reproducibility) and pick training and test sets
key = jxr.PRNGKey(0)
key, subkey = jxr.split(key)
perm = jxr.permutation(subkey, n_trials)

Y = Y[perm]
U = U[perm]
Y_train, Y_test = Y[:n_train], Y[n_train:]
U_train, U_test = U[:n_train], U[n_train:]

In [3]:
# Define model

# Hyperparams
latent_dim = 5
n_basis_funcs = 5  # first argument of utils.T2_basis (presumably basis functions per condition axis — confirm)
sigma = 1.0  # Fix to the scale of the data. I usually standardize first: Y = (Y - Y.mean()) / Y.std()
kappa = 0.2  # Values in this range are a good start

# Weight space GP priors
# Pad with 3 lengthscales on each side of the range so the conditions are not treated as periodic
# NOTE(review): assumes the first basis axis is the t_range condition and the second the coherency
# condition — confirm against utils.T2_basis and the column order of U.
t_period = (t_range.max() - t_range.min()) + 6 * kappa
c_period = (coherencies.max() - coherencies.min()) + 6 * kappa
torus_basis_funcs = utils.T2_basis(n_basis_funcs, sigma, kappa, t_period, c_period)
# torus_basis_funcs = utils.T1_basis(n_basis_funcs, sigma, kappa, c_period)  # If using one condition

n_neurons = Y.shape[-1]
A_prior = WeightSpaceGaussianProcess(torus_basis_funcs, D1=latent_dim, D2=latent_dim)  # dynamics matrix
b_prior = WeightSpaceGaussianProcess(torus_basis_funcs, D1=latent_dim, D2=1)  # dynamics bias
m0_prior = WeightSpaceGaussianProcess(torus_basis_funcs, D1=latent_dim, D2=1)  # initial-state mean

# Instantiate model; no weight-space GP prior is passed for the emission matrix C.
model = wGPLDS(
    wgps={'A': A_prior, 'b': b_prior, 'C': None, 'm0': m0_prior,}, 
    state_dim=latent_dim, 
    emission_dim=n_neurons,
    )

# # Plot prior samples
# # NOTE: requires `coherence_indices` (trial indices grouped by coherence), not defined in this notebook.
# fig, ax = plt.subplots(figsize=[4,3])
# for i in coherence_indices[0][:3]:
#     ax.plot(A_prior.sample(jxr.PRNGKey(0), U[i])[:,0,0], c='tab:blue');
#     ax.plot(A_prior.sample(jxr.PRNGKey(0), U[i])[:,0,1], c='tab:orange');
# ax.set_title('A prior samples');

In [None]:
num_timesteps = Y.shape[1]

# Initialize
seed = 0
num_em_iters = 50  # number of EM iterations passed to model.fit_em
A_key, b_key, C_key, m0_key = jxr.split(jxr.PRNGKey(seed), 4)
initial_params = ParamswGPLDS(
    dynamics_gp_weights = A_prior.sample_weights(A_key),  # random draw from the A prior
    Q = jnp.eye(latent_dim),
    R = jnp.eye(n_neurons),
    m0 = jnp.zeros(latent_dim),
    S0 = jnp.eye(latent_dim),
    # The same random emission matrix repeated at every time step: (num_timesteps, n_neurons, latent_dim)
    Cs = jnp.tile(jxr.normal(C_key, (n_neurons, latent_dim)), (num_timesteps, 1, 1)),
    emissions_gp_weights = None,  # no GP prior on C (the model is built with 'C': None)
    # One bias row per step except the first (presumably one per transition — confirm in models.py)
    bs = jnp.zeros((num_timesteps-1, latent_dim)),
    bias_gp_weights = b_prior.sample_weights(b_key),
    m0_gp_weights = m0_prior.sample_weights(m0_key),
)

# Fit model with EM on the training trials
params, log_probs = model.fit_em(initial_params, emissions=Y_train, conditions=U_train, num_iters=num_em_iters)

# Evaluate on the held-out trials
test_ll = model.marginal_log_lik(params, emissions=Y_test, conditions=U_test)
print(f"Test log likelihood: {test_ll}")

# Show results
fig, ax = plt.subplots(figsize=[4,3])
ax.plot(log_probs)
ax.set_ylabel('log prob')
ax.set_xlabel('Epochs');

In [None]:
# Plot reconstruction

# NOTE(review): set these for your data. `event` names the alignment event drawn at t = 0,
# `bin_size` is the duration of one time bin of Y in seconds.
event = 'trial start'
bin_size = 1.0
# One entry per time bin of Y. (t_range is a condition space with its own length, not a time axis.)
time_axis = jnp.arange(Y.shape[1]) * bin_size

def reconstruct_y(b):
    '''Reconstructs the firing rate of the model for a given trial.

    Runs the model smoother on trial `b` and maps the smoothed latent means through the
    per-timestep emission matrices: y_hat[t] = params.Cs[t] @ smoothed_means[t].

    Args:
        b: integer trial index into Y and U.

    Returns:
        Array indexed (time, neuron).
    '''
    _, _, (smoothed_means, _, _) = model.smoother(params, Y[b], U[b])
    _reconstructed_ys = jnp.einsum('tnl,tl->tn', params.Cs, smoothed_means)
    return _reconstructed_ys

# Reconstruct every trial in Y (training and test trials alike)
reconstructed_ys = jax.vmap(reconstruct_y)(jnp.arange(len(Y)))

# Trial-averaged activity, computed once: (time bins, neurons)
Y_trial_mean = Y.mean(0)
recon_trial_mean = reconstructed_ys.mean(0)

fig, ax = plt.subplots();
for i in range(n_neurons):
    ax.plot(time_axis, Y_trial_mean[:, i], c='k', alpha=0.1)
    ax.plot(time_axis, recon_trial_mean[:, i], c='tab:orange', alpha=0.1)

# Bold lines: additionally averaged over neurons
ax.plot(time_axis, Y_trial_mean.mean(1), c='k', label='Data')
ax.plot(time_axis, recon_trial_mean.mean(1), c='tab:orange', label='Model')
ax.set_ylabel('Firing rate')
ax.set_xlabel(f'Time from {event} (s)')
ax.set_title('Reconstruction per neuron, averaged over trials')
ax.axvline(x=0, c='tab:gray', zorder=-1);
ax.legend();