In [None]:
import scipy.io as sio
import autograd.numpy as np
import autograd.numpy.random as npr
npr.seed(0)

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
# Seaborn style and context presets applied to the figures created below.
sns.set_style("dark")
sns.set_context("talk")

# xkcd color names, in the order they are indexed when plotting.
color_names = [
    "windows blue", "red", "amber", "faded green",
    "dusty purple", "orange", "clay", "pink",
    "greyish", "mint", "light cyan", "steel blue",
    "forest green", "pastel purple", "salmon", "dark brown",
]

# Resolve the names into a palette so a color can be selected by integer index.
colors = sns.xkcd_palette(color_names)

import ssm
from ssm.util import random_rotation, find_permutation

In [None]:
def format_data(input, dtype=int):
    """Split a trial-stacked 3-D array into a list of per-trial arrays.

    Parameters
    ----------
    input : ndarray
        Array indexed as ``input[trial, :, :]``; the first axis is iterated.
    dtype : data-type, optional
        Type each per-trial array is converted to. Defaults to ``int``
        because the Poisson observation model requires integer-typed data.
        Pass ``float`` for real-valued arrays, which an integer cast would
        truncate.

    Returns
    -------
    list of ndarray
        One array per entry along the first axis of ``input``, squeezed of
        singleton dimensions and converted to ``dtype``.
    """
    return [np.asarray(np.squeeze(input[i, :, :]), dtype=dtype)
            for i in range(input.shape[0])]

# Load only the two variables this notebook uses from the MATLAB file.
mat = sio.loadmat('rawdata.mat', squeeze_me=True, variable_names=['spdata', 'inpdata'])

# Both arrays are indexed (trial, time bin, channel), as the shape reads below show.
assert mat['inpdata'].shape[:2] == mat['spdata'].shape[:2], \
    "spdata and inpdata must agree on the number of trials and time bins"

# Observations go through format_data, which casts to int for the Poisson model.
data_sp = format_data(mat['spdata'])

# Inputs are not Poisson observations: keep them as floats so real-valued
# inputs are not truncated by an integer cast.
data_inp = [np.asarray(np.squeeze(trial), dtype=float) for trial in mat['inpdata']]

binsize = 0.01  # passed to the emission model as bin_size; presumably seconds -- TODO confirm
n_trials, n_timebins, n_neurons = mat['spdata'].shape

inp_dim = mat['inpdata'].shape[2]
state_dim = 10  # dimensionality of the continuous latent state
disc_dim = 3    # number of discrete states (only used by the SLDS variant)

In [None]:
# Build a linear dynamical system with Poisson emissions (softplus link)
# driven by inp_dim external inputs.
# NOTE(review): transitions="sticky" is kept from the SLDS variant below;
# it presumably has no effect for a plain LDS -- confirm.
fit_lds = ssm.LDS(n_neurons, state_dim,
                  M=inp_dim,
                  transitions="sticky",
                  emissions="poisson_orthog",
                  emission_kwargs=dict(link="softplus", bin_size=binsize))

# Switching variant with disc_dim discrete states:
# fit_lds = ssm.SLDS(n_neurons, disc_dim, state_dim,
#                    M = inp_dim,
#                    transitions="sticky",
#                    emissions="poisson_orthog",
#                    emission_kwargs=dict(link="softplus", bin_size=binsize))

# Initialize with the same inputs that are used for fitting; without them
# the initialization would not see the real inputs although M = inp_dim.
fit_lds.initialize(data_sp, inputs=data_inp)

# The model was initialized explicitly above, so skip fit()'s own
# initialization (initialize=False).
elbos, q = fit_lds.fit(data_sp, inputs=data_inp,
                       method="laplace_em",
                       variational_posterior="structured_meanfield",
                       initialize=False,
                       num_iters=10)

In [None]:
# Plot the ELBOs
# Plot every value returned by the fit rather than a hard-coded first-10
# slice, so the curve stays complete if the number of iterations changes.
plt.plot(elbos, label="Laplace-EM")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend()

In [None]:
# Show the first matrix stored in the fitted emission model's Cs as a heat map.
plt.figure()
plt.imshow(fit_lds.emissions.Cs[0, :, :], aspect='auto')
plt.colorbar()

In [None]:
tr = 10  # index of the trial to visualize

# Get the posterior mean of the continuous states
# (assumed to be indexed [time bin, latent dimension] -- TODO confirm)
q_x = q.mean_continuous_states[tr]

# Smooth the data under the variational posterior. yhat lives in observation
# space, so it is not what the state plots below should show.
yhat = fit_lds.smooth(q_x, data_sp[tr], input=data_inp[tr])
zhat = fit_lds.most_likely_states(q_x, data_sp[tr], input=data_inp[tr])

# Continuous latent states over time, one trace per latent dimension.
plt.figure(figsize=(8,4))
for d in range(state_dim):
    plt.plot(q_x[:,d], '-', color=colors[d % len(colors)],
             label="Estimated States" if d==0 else None)
plt.ylabel("$x$")
plt.xlabel("time")
plt.title("Estimated States")
plt.show()

# Most likely discrete state over time. Uses a fixed color and label instead
# of the loop variable left over from the figure above.
plt.figure(figsize=(8,4))
plt.plot(zhat, '-', color=colors[0], label="Most Likely Discrete State")
plt.ylabel("$z$")
plt.xlabel("time")
plt.title("Most Likely Discrete States")
plt.show()

# Trajectory in the plane of the first two latent dimensions: index columns
# (dimensions), not rows (time bins), and label each axis with what it shows.
plt.figure(figsize=(8,4))
plt.plot(q_x[:,0], q_x[:,1])
plt.xlabel("$x_0$")
plt.ylabel("$x_1$")
plt.title("Estimated State Trajectory")
plt.show()


In [None]:
# Draw one simulated trial of n_timebins steps from the fitted model, driven by
# the inputs of trial `tr`; returns discrete states, continuous states, observations.
z_sim, x_sim, y_sim = fit_lds.sample(n_timebins, input=data_inp[tr])

In [None]:
# Simulated observations as an image (transposed, so time runs along the
# horizontal axis); the color scale is capped at 1.
plt.figure()
plt.imshow(y_sim.T, aspect='auto', vmax=1)
plt.colorbar()

In [None]:
# Observed data for trial `tr` as an image (transposed, so time runs along the
# horizontal axis); the color scale is capped at 1.
plt.figure()
plt.imshow(data_sp[tr].T, aspect='auto', vmax=1)
plt.colorbar()

In [None]:
cnum = 45  # column (cell) index to compare between simulation and data

# Overlay the simulated and observed traces for one cell; label both so the
# two lines can be told apart.
plt.figure()
plt.plot(y_sim[:,cnum], label="Simulated")
plt.plot(data_sp[tr][:,cnum], label="Data")
plt.xlabel("time bin")
plt.ylabel("count")
plt.title("Cell {}, trial {}".format(cnum, tr))
plt.legend()