## Switching Linear Dynamical Systems fMRI Demo

In [None]:
import os
import pickle
import copy

import autograd.numpy as np
import autograd.numpy.random as npr
npr.seed(12345)

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.colors import ListedColormap
%matplotlib inline

import seaborn as sns
# Palette used to color discrete states. The models in this notebook are fit
# with K = 5 discrete states, so at least five distinct colors are needed:
# the plotting helpers index the palette with `z % len(colors)`, and with
# only four entries state 4 would be drawn in the same color as state 0.
color_names = ["windows blue", "red", "amber", "faded green", "dusty purple"]
colors = sns.xkcd_palette(color_names)
sns.set_style("white")
sns.set_context("talk")

# Discrete colormap over the same palette, used for the state rasters.
cmap = ListedColormap(colors)

import ssm
from ssm.util import random_rotation, find_permutation

import scipy.io
import scipy.stats

# Helper functions for plotting results
def plot_trajectory(z, x, ax=None, ls="-"):
    """Draw the first two columns of ``x`` as a path colored by ``z``.

    The path is cut wherever consecutive entries of ``z`` differ, and each
    piece is drawn in the palette color of its state (``colors`` is cycled
    when the state index exceeds the palette length).  A new 4x4 figure is
    created when ``ax`` is None.  Returns the axes that were drawn on.
    """
    # Segment boundaries: 0, every index at which z changes, and len(z).
    change_points = np.where(np.diff(z))[0] + 1
    bounds = np.concatenate(([0], change_points, [z.size]))

    if ax is None:
        ax = plt.figure(figsize=(4, 4)).gca()

    for seg_start, seg_stop in zip(bounds[:-1], bounds[1:]):
        # One extra sample is included so neighbouring segments connect.
        seg = slice(seg_start, seg_stop + 1)
        state_color = colors[z[seg_start] % len(colors)]
        ax.plot(x[seg, 0], x[seg, 1],
                lw=1, ls=ls, color=state_color, alpha=1.0)
    return ax

def plot_observations(z, y, ax=None, ls="-", lw=1):
    """Plot every column of ``y`` against time, colored by the state ``z``.

    ``y`` is unpacked as a 2-D array (rows are time steps).  Each column is
    drawn in pieces, split wherever consecutive entries of ``z`` differ, with
    each piece in the palette color of its state.  A new 4x4 figure is
    created when ``ax`` is None.  Returns the axes that were drawn on.
    """
    # Segment boundaries: 0, every index at which z changes, and len(z).
    change_points = np.where(np.diff(z))[0] + 1
    bounds = np.concatenate(([0], change_points, [z.size]))

    if ax is None:
        ax = plt.figure(figsize=(4, 4)).gca()

    num_steps, num_channels = y.shape
    time = np.arange(num_steps)
    segments = list(zip(bounds[:-1], bounds[1:]))

    for channel in range(num_channels):
        for seg_start, seg_stop in segments:
            # One extra sample is included so neighbouring segments connect.
            seg = slice(seg_start, seg_stop + 1)
            state_color = colors[z[seg_start] % len(colors)]
            ax.plot(time[seg], y[seg, channel],
                    lw=lw, ls=ls, color=state_color, alpha=1.0)
    return ax


def plot_most_likely_dynamics(model,
    xlim=(-4, 4), ylim=(-3, 3), nxpts=20, nypts=20,
    alpha=0.8, ax=None, figsize=(3, 3)):
    """Quiver plot of the dynamics of the highest-scoring state on a 2-D grid.

    For every point of an ``nxpts`` x ``nypts`` grid spanning ``xlim`` and
    ``ylim``, the discrete state with the largest transition score
    ``xy @ Rs.T + r`` is selected, and the one-step displacement
    ``A_k xy + b_k - xy`` under that state's dynamics is drawn as an arrow in
    the state's palette color.

    Parameters
    ----------
    model : object exposing ``D``, ``transitions.Rs``, ``transitions.r``,
        ``dynamics.As`` and ``dynamics.bs``; ``model.D`` must equal 2.
    xlim, ylim : (low, high) extents of the grid.
    nxpts, nypts : number of grid points along each axis.
    alpha : arrow opacity.
    ax : axes to draw on; a new figure of size ``figsize`` is created if None.

    Returns
    -------
    The axes that were drawn on.
    """
    assert model.D == 2
    x = np.linspace(*xlim, nxpts)
    y = np.linspace(*ylim, nypts)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))

    # Index of the highest-scoring discrete state at each grid location
    # (an argmax over the transition scores, not a probability).
    z = np.argmax(xy.dot(model.transitions.Rs.T) + model.transitions.r, axis=1)

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
        # One-step displacement of every grid point under state k's dynamics.
        dxydt_m = xy.dot(A.T) + b - xy

        # Only draw arrows where state k is the selected one.
        zk = z == k
        if zk.any():
            ax.quiver(xy[zk, 0], xy[zk, 1],
                      dxydt_m[zk, 0], dxydt_m[zk, 1],
                      color=colors[k % len(colors)], alpha=alpha)

    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')

    plt.tight_layout()

    return ax

In [None]:
# Import fMRI data
# Load the MATLAB file and pull out the 'logan_timeSeries_roi25' variable.
# Later cells index this array as (frames, observed channels, scans).
mat = scipy.io.loadmat('data/logan_tmsPredict_aug2019.mat')
data = mat['logan_timeSeries_roi25']

In [None]:
# Display the dimensions of the loaded array.
data.shape

In [None]:
# Observations of the first scan (index 0 along the last axis).
y = data[:,:,0]

In [None]:
# Global parameters

# Sizes read off the data array: frames, observed channels, scans.
T = data.shape[0]
D_obs = data.shape[1]
n_scans = data.shape[2]

# Model sizes: number of discrete states and dimension of the latent space.
K = 5
D_latent = 2

In [None]:
# Fit an rSLDS to the first scan, starting from ssm's default initialization
# and using Laplace-EM with a structured mean-field variational posterior.
rslds = ssm.SLDS(
    D_obs, K, D_latent,
    transitions="recurrent_only",
    dynamics="diagonal_gaussian",
    emissions="gaussian_orthog",
    single_subspace=True,
)
rslds.initialize(y)

# initialize=False because the model was already initialized just above.
q_elbos_lem, q_lem = rslds.fit(
    y,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    initialize=False,
    num_iters=3,
    alpha=0.0,
)

# Posterior-mean continuous latents and the most likely discrete states.
xhat_lem = q_lem.mean_continuous_states[0]
zhat_lem = rslds.most_likely_states(xhat_lem, y)

# Keep an independent copy of the fitted model; `rslds` is re-bound later.
rslds_lem = copy.deepcopy(rslds)

In [None]:
# Plot the ELBO trace returned by the fit on the first scan.
plt.plot(q_elbos_lem, label="Laplace-EM: Structured Variational Posterior")
plt.title("Convergence for learning an SLDS")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend(bbox_to_anchor=(1.0, 1.0))
plt.show()

In [None]:
# Inferred 2-D latent trajectory, colored by the inferred discrete state.
plt.figure()
ax = plt.subplot()
plot_trajectory(zhat_lem, xhat_lem, ax=ax)
ax.set_title("Inferred, Laplace-EM")
plt.tight_layout()

In [None]:
# Vector field of the fitted dynamics over the region the latents visit.
plt.figure(figsize=(6, 4))
ax = plt.subplot(111)

# Symmetric per-dimension limits: largest absolute latent value plus 1.
lim = np.abs(xhat_lem).max(axis=0) + 1
plot_most_likely_dynamics(
    rslds_lem,
    xlim=(-lim[0], lim[0]),
    ylim=(-lim[1], lim[1]),
    ax=ax,
)
plt.title("Inferred Dynamics, Laplace-EM")

In [None]:
# Raster of the inferred discrete state at every frame of the first scan.
plt.figure(figsize=(12, 2))
# Fixing vmin/vmax maps state k to colors[k] (for k < len(colors)) no matter
# which states happen to occur; without them imshow rescales to the observed
# range and the colors stop matching the other plots.
plt.imshow(zhat_lem[None, :], aspect='auto', cmap=cmap,
           vmin=0, vmax=len(colors) - 1)
plt.title('fMRI Inferred States')
plt.xlabel('Frames')
ax = plt.gca()
ax.set_yticks([])
# These states come from scan 0 (y = data[:, :, 0]). No loop index `i` is
# bound yet at this point of the notebook, so name the scan explicitly.
plt.savefig('scan_%i' % 0)
plt.show()

In [None]:
# Roll the fitted dynamics forward with process noise, following the inferred
# discrete-state sequence, and compare with the posterior-mean latents.
A = rslds.dynamics.As
b = rslds.dynamics.bs
cov = rslds.dynamics.Sigmas
n_gen = 1  # number of independent roll-outs
n_val_frames = y.shape[0]

mse = np.zeros(n_gen)
mae = np.zeros(n_gen)

for j in range(n_gen):
    # Start every roll-out from the first inferred latent.
    x = [xhat_lem[0]]

    for i in range(1, n_val_frames):
        # The step that produces frame i is governed by the state at frame i
        # (x_i = A[z_i] x_{i-1} + b[z_i] + noise), the same convention
        # plot_most_likely_dynamics uses; indexing the state at frame i-1
        # would apply every switch one step late.
        k = zhat_lem[i]
        w = np.random.multivariate_normal(np.zeros(D_latent), cov[k])
        x.append(A[k] @ x[-1] + b[k] + w)

    x_gen = np.vstack(x)
    mse[j] = np.mean((xhat_lem - x_gen)**2)
    mae[j] = np.mean(np.abs(xhat_lem - x_gen))

In [None]:
# Mean squared error of each generated trajectory against the inferred latents.
print(mse)

In [None]:
# Generated vs. estimated latents, one panel per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(D_latent)]
fig, axs = plt.subplots(D_latent, 1, figsize=(14, 2 * D_latent))

for d, ax in enumerate(axs):
    offset = 4 * d
    # Only the first panel carries legend labels.
    gen_label = "Generated" if d == 0 else None
    est_label = "Estimated" if d == 0 else None
    ax.plot(x_gen[:, d] + offset, '-', color=colors[0], label=gen_label)
    ax.plot(xhat_lem[:, d] + offset, '-', color=colors[2], label=est_label)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)

axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# Roll the fitted dynamics forward without process noise, following the
# inferred discrete-state sequence, and compare with the posterior-mean latents.
A = rslds.dynamics.As
b = rslds.dynamics.bs
cov = rslds.dynamics.Sigmas  # not used by this noise-free roll-out
n_gen = 1  # number of roll-outs
n_val_frames = y.shape[0]

mse = np.zeros(n_gen)
mae = np.zeros(n_gen)

for j in range(n_gen):
    # Start every roll-out from the first inferred latent.
    x = [xhat_lem[0]]

    for i in range(1, n_val_frames):
        # The step that produces frame i is governed by the state at frame i
        # (x_i = A[z_i] x_{i-1} + b[z_i]), the same convention
        # plot_most_likely_dynamics uses; indexing the state at frame i-1
        # would apply every switch one step late.
        k = zhat_lem[i]
        x.append(A[k] @ x[-1] + b[k])

    x_gen = np.vstack(x)
    mse[j] = np.mean((xhat_lem - x_gen)**2)
    mae[j] = np.mean(np.abs(xhat_lem - x_gen))

In [None]:
# Mean squared error of the noise-free generated trajectory.
print(mse)

In [None]:
# Noise-free generated vs. estimated latents, one panel per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(D_latent)]
fig, axs = plt.subplots(D_latent, 1, figsize=(14, 2 * D_latent))

for d, ax in enumerate(axs):
    offset = 4 * d
    # Only the first panel carries legend labels.
    gen_label = "Generated" if d == 0 else None
    est_label = "Estimated" if d == 0 else None
    ax.plot(x_gen[:, d] + offset, '-', color=colors[0], label=gen_label)
    ax.plot(xhat_lem[:, d] + offset, '-', color=colors[2], label=est_label)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)

axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# Global parameters

# Sizes read off the data array: frames, observed channels, scans.
T = data.shape[0]
D_obs = data.shape[1]
n_scans = data.shape[2]

# Model sizes: number of discrete states and a larger latent dimension (24).
K = 5
D_latent = 24

In [None]:
# Fit a separate rSLDS to every scan (default initialization, Laplace-EM with
# a structured mean-field posterior) and collect the fitted quantities.
states = []
As = []
bs = []
covs = []
elbos = []

for i in range(n_scans):
    scan = data[:, :, i]

    rslds = ssm.SLDS(
        D_obs, K, D_latent,
        transitions="recurrent_only",
        dynamics="diagonal_gaussian",
        emissions="gaussian_orthog",
        single_subspace=True,
    )
    rslds.initialize(scan)
    q_elbos_lem, q_lem = rslds.fit(
        scan,
        method="laplace_em",
        variational_posterior="structured_meanfield",
        initialize=False,
        num_iters=3,
        alpha=0.0,
    )
    xhat_lem = q_lem.mean_continuous_states[0]
    zhat_lem = rslds.most_likely_states(xhat_lem, scan)

    # Keep a copy of the model fitted to this scan.
    rslds_lem = copy.deepcopy(rslds)

    # Per-scan dynamics parameters, state sequence and ELBO trace.
    dynamics = rslds.dynamics
    As.append(dynamics.As)
    bs.append(dynamics.bs)
    covs.append(dynamics.Sigmas)
    states.append(zhat_lem)
    elbos.append(q_elbos_lem)

In [None]:
# Plot the ELBO trace of the most recent fit (q_elbos_lem holds the trace of
# the last scan processed by the loop).
plt.plot(q_elbos_lem, label="Laplace-EM: Structured Variational Posterior")
plt.title("Convergence for learning an SLDS")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend(bbox_to_anchor=(1.0, 1.0))
plt.show()

In [None]:
# Find eigenvalues and eigenvectors of each scan's dynamics matrices
# (As holds one entry per scan).
eig = [np.linalg.eig(scan_As) for scan_As in As]
e_vals = [decomposition[0] for decomposition in eig]
e_vects = [decomposition[1] for decomposition in eig]

In [None]:
# Plot the eigenvalues of all dynamics matrices in the complex plane.

# Eigenvalues with magnitude above this radius are counted as significant.
r_cutoff = 0.5

# Dedicated names are used for the coordinates: binding `y` here would
# overwrite the first-scan observations that the earlier fitting cells read.
all_e_vals = np.asarray(e_vals).flatten()
e_vals_real = np.real(all_e_vals)
e_vals_imag = np.imag(all_e_vals)

# Reference circles: the unit circle and the significance cutoff.
unit_circle = plt.Circle((0,0), radius=1, color=colors[1], fill=False)
inner_circle = plt.Circle((0,0), radius=r_cutoff, color=colors[2], fill=False)


fig, ax = plt.subplots(figsize=(46,16))
ax.scatter(e_vals_real, e_vals_imag, s=1, color=colors[0])

# Real and imaginary axes through the origin.
ax.axhline(y=0, color = 'k', linewidth=0.5)
ax.axvline(x=0, color = 'k', linewidth=0.5)

ax.add_patch(unit_circle)
ax.add_patch(inner_circle)

ax.set_xlabel('Real')
ax.set_ylabel('Imaginary')
# Equal aspect so the circles are drawn round.
ax.set_aspect('equal')

plt.suptitle('Eigenvalues')

In [None]:
# Find the intrinsic dimensionality of the dynamics: the average number of
# eigenvalues per dynamics matrix whose magnitude exceeds r_cutoff.

e_vals_magnitudes = np.abs(np.asarray(e_vals).flatten())  # magnitude of every eigenvalue
n_sig_evals = np.sum(e_vals_magnitudes > r_cutoff)  # how many exceed r_cutoff
# Each scan contributes one dynamics matrix per discrete state, and K is the
# number of discrete states the models were built with.
intrinsic_dim = n_sig_evals / (n_scans * K)
print('intrinsic dimensionality =',intrinsic_dim)

In [None]:
# Truncate the estimate to an integer (shown as the cell output).
int(intrinsic_dim)

In [None]:
# Global parameters

# Sizes read off the data array: frames, observed channels, scans.
T = data.shape[0]
D_obs = data.shape[1]
n_scans = data.shape[2]

# Number of discrete states.
K = 5

# Latent dimension, hard-coded to 17 instead of int(intrinsic_dim);
# the original note lists 18 as a typical output.
D_latent = 17

In [None]:
# Split the scans into training / validation / test sets (60% / 20% / rest).
training_fraction = .6
validation_fraction = .2
training_index = int(training_fraction * n_scans)
validation_index = int(training_index + validation_fraction * n_scans)


def _stack_scans(scans):
    """Concatenate a (frames, channels, scans) array scan-by-scan along time.

    Returns a (scans * frames, channels) array in which the frames of each
    scan stay contiguous and in their original order.
    """
    n_channels = scans.shape[1]
    return np.transpose(scans, (2, 0, 1)).reshape(-1, n_channels)


# np.hstack applied to the 3-D array iterates over its first (time) axis, so
# hstack + swapaxes yields rows ordered frame-major: scan 0 frame 0, scan 1
# frame 0, ...  That interleaves the scans and destroys the temporal order a
# dynamical model relies on.  Stack whole scans one after another instead.
# NOTE: the boundary between two consecutive scans is still a discontinuity.
training_data = _stack_scans(data[:, :, :training_index])
validation_data = _stack_scans(data[:, :, training_index:validation_index])
test_data = _stack_scans(data[:, :, validation_index:])

In [None]:
# Display the dimensions of the stacked training array.
training_data.shape

In [None]:
# Fit an rSLDS to the training data, starting from ssm's default
# initialization and using Laplace-EM with a structured mean-field posterior.
rslds = ssm.SLDS(
    D_obs, K, D_latent,
    transitions="recurrent_only",
    dynamics="diagonal_gaussian",
    emissions="gaussian_orthog",
    single_subspace=True,
)
rslds.initialize(training_data)

# initialize=False because the model was already initialized just above.
q_elbos_lem, q_lem = rslds.fit(
    training_data,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    initialize=False,
    num_iters=3,
    alpha=0.0,
)

# Posterior-mean continuous latents and the most likely discrete states.
xhat_lem = q_lem.mean_continuous_states[0]
zhat_lem = rslds.most_likely_states(xhat_lem, training_data)

# Keep an independent copy of the fitted model.
rslds_lem = copy.deepcopy(rslds)

In [None]:
# Plot the ELBO trace returned by the fit on the training data.
plt.plot(q_elbos_lem, label="Laplace-EM: Structured Variational Posterior")
plt.title("Convergence for learning an SLDS")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend(bbox_to_anchor=(1.0, 1.0))
plt.show()

In [None]:
# Infer the discrete and continuous latent states of the held-out validation
# data under the current (already fitted) model parameters.
elbos, posterior = rslds.approximate_posterior(
    validation_data,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    num_iters=3,
)

# Sanity check: the ELBO should go up over the iterations. Only the latent
# state estimates are being updated, not the model parameters, so a large
# increase is not expected.
plt.plot(elbos)
plt.ylabel("ELBO")
plt.xlabel("Iteration")
plt.show()

In [None]:
# Posterior-mean continuous latents for the validation data, and the most
# likely discrete state sequence given those latents and the observations.
x_est = posterior.mean_continuous_states[0]
z_est = rslds.most_likely_states(x_est, validation_data)

In [None]:
def mse(ts1, ts2):
    """Return the mean of the squared element-wise difference of two arrays."""
    residual = ts1 - ts2
    return np.mean(residual ** 2)

def mae(ts1, ts2):
    """Return the mean of the absolute element-wise difference of two arrays."""
    residual = ts1 - ts2
    return np.mean(np.abs(residual))
    

In [None]:
# Model + Noise
# Roll the fitted dynamics forward with process noise over the validation
# frames, following the inferred discrete-state sequence, and score the
# result against the posterior-mean latents.

A = rslds.dynamics.As
b = rslds.dynamics.bs
cov = rslds.dynamics.Sigmas
n_gen = 1  # number of independent roll-outs
n_val_frames = validation_data.shape[0]

MSE = np.zeros(n_gen)
MAE = np.zeros(n_gen)

for j in range(n_gen):
    # Start every roll-out from the first estimated latent.
    x = [x_est[0]]

    for i in range(1, n_val_frames):
        # The step that produces frame i is governed by the state at frame i
        # (x_i = A[z_i] x_{i-1} + b[z_i] + noise), the same convention
        # plot_most_likely_dynamics uses; indexing the state at frame i-1
        # would apply every switch one step late.
        k = z_est[i]
        w = np.random.multivariate_normal(np.zeros(D_latent), cov[k])
        x.append(A[k] @ x[-1] + b[k] + w)

    x_gen = np.vstack(x)
    MSE[j] = mse(x_est, x_gen)
    MAE[j] = mae(x_est, x_gen)

In [None]:
# Report the error of the first (and only) roll-out.
print('Mean Squared Error:' ,MSE[0])

In [None]:
# Generated vs. estimated validation latents, one panel per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(D_latent)]
fig, axs = plt.subplots(D_latent, 1, figsize=(14 * 30, 2 * D_latent))

for d, ax in enumerate(axs):
    offset = 4 * d
    # Only the first panel carries legend labels.
    gen_label = "Generated" if d == 0 else None
    est_label = "Estimated" if d == 0 else None
    ax.plot(x_gen[:, d] + offset, '-', color=colors[0], label=gen_label)
    ax.plot(x_est[:, d] + offset, '-', color=colors[2], label=est_label)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)

axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# MSE over the first t frames for t = 1..99, and its finite-difference slope.
prefix_lengths = range(1, 100)
cum_err_mn = [mse(x_est[:t], x_gen[:t]) for t in prefix_lengths]
cum_err_mn_prime = np.gradient(cum_err_mn)

In [None]:
# Left: cumulative MSE of the Model + Noise roll-out; right: its slope.
fig, axs = plt.subplots(1, 2, figsize=(16, 4))

panels = zip(axs, (cum_err_mn, cum_err_mn_prime), ('MSE', 'd/dt MSE'))
for panel, series, ylabel in panels:
    panel.plot(series)
    panel.set_xlabel('Time Step')
    panel.set_ylabel(ylabel)

plt.suptitle('Model + Noise')

In [None]:
# Model Only
# Roll the fitted dynamics forward without process noise over the validation
# frames, following the inferred discrete-state sequence, and score the
# result against the posterior-mean latents.

A = rslds.dynamics.As
b = rslds.dynamics.bs
cov = rslds.dynamics.Sigmas  # not used by this noise-free roll-out
n_gen = 1  # number of roll-outs
n_val_frames = validation_data.shape[0]

MSE = np.zeros(n_gen)
MAE = np.zeros(n_gen)

for j in range(n_gen):
    # Start every roll-out from the first estimated latent.
    x = [x_est[0]]

    for i in range(1, n_val_frames):
        # The step that produces frame i is governed by the state at frame i
        # (x_i = A[z_i] x_{i-1} + b[z_i]), the same convention
        # plot_most_likely_dynamics uses; indexing the state at frame i-1
        # would apply every switch one step late.
        k = z_est[i]
        x.append(A[k] @ x[-1] + b[k])

    x_gen = np.vstack(x)
    MSE[j] = mse(x_est, x_gen)
    MAE[j] = mae(x_est, x_gen)

In [None]:
# Per-dimension similarity between the generated and estimated latents.
# Both arrays are stored as (frames, D_latent), so latent dimension i is
# column i; indexing with [i] alone would select the i-th frame instead.
# NOTE(review): this is an inner product scaled by the RMSE, not a Pearson
# correlation coefficient - confirm this is the intended statistic.
correlations = [
    np.correlate(x_gen[:, i], x_est[:, i])
    / np.sqrt(np.mean((x_gen[:, i] - x_est[:, i])**2))
    for i in range(D_latent)
]

In [None]:
# Display the per-dimension values.
correlations

In [None]:
# Correlation-coefficient matrix over the latent dimensions of x_est followed
# by those of x_gen (each transposed so that a row is one dimension).
plt.imshow(np.corrcoef(x_est.T, x_gen.T), cmap='coolwarm')
plt.colorbar()

In [None]:
# Report the error of the first (and only) roll-out.
print('Mean Squared Error:' ,MSE[0])

In [None]:
# Noise-free generated vs. estimated validation latents, one panel per
# latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(D_latent)]
fig, axs = plt.subplots(D_latent, 1, figsize=(14 * 30, 2 * D_latent))

for d, ax in enumerate(axs):
    offset = 4 * d
    # Only the first panel carries legend labels.
    gen_label = "Generated" if d == 0 else None
    est_label = "Estimated" if d == 0 else None
    ax.plot(x_gen[:, d] + offset, '-', color=colors[0], label=gen_label)
    ax.plot(x_est[:, d] + offset, '-', color=colors[2], label=est_label)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)

axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# MSE over the first t frames for t = 1..99, and its finite-difference slope.
prefix_lengths = range(1, 100)
cum_err_m = [mse(x_est[:t], x_gen[:t]) for t in prefix_lengths]
cum_err_m_prime = np.gradient(cum_err_m)

In [None]:
# Left: cumulative MSE of the Model Only roll-out; right: its slope.
fig, axs = plt.subplots(1, 2, figsize=(16, 4))

panels = zip(axs, (cum_err_m, cum_err_m_prime), ('MSE', 'd/dt MSE'))
for panel, series, ylabel in panels:
    panel.plot(series)
    panel.set_xlabel('Time Step')
    panel.set_ylabel(ylabel)

plt.suptitle('Model Only')

In [None]:
# Noise Only
# Baseline: after the first frame every sample is a draw of the process noise
# alone (zero mean, covariance of the inferred state), with no dynamics.

A = rslds.dynamics.As  # not used by this baseline
b = rslds.dynamics.bs  # not used by this baseline
cov = rslds.dynamics.Sigmas
n_gen = 1  # number of independent draws of the whole sequence
n_val_frames = validation_data.shape[0]

MSE = np.zeros(n_gen)
MAE = np.zeros(n_gen)

for j in range(n_gen):
    # Start from the first estimated latent, as the other roll-outs do.
    x = [x_est[0]]

    for i in range(1, n_val_frames):
        # Use the covariance of the state at the frame being generated, the
        # same frame/state alignment as the model-based roll-outs.
        k = z_est[i]
        w = np.random.multivariate_normal(np.zeros(D_latent), cov[k])
        x.append(w)

    x_gen = np.vstack(x)
    MSE[j] = mse(x_est, x_gen)
    MAE[j] = mae(x_est, x_gen)

In [None]:
# Report the error of the first (and only) draw.
print('Mean Squared Error:' ,MSE[0])

In [None]:
# MSE over the first t frames for t = 1..99, and its finite-difference slope.
prefix_lengths = range(1, 100)
cum_err_n = [mse(x_est[:t], x_gen[:t]) for t in prefix_lengths]
cum_err_n_prime = np.gradient(cum_err_n)

In [None]:
# Left: cumulative MSE of the Noise Only baseline; right: its slope.
fig, axs = plt.subplots(1, 2, figsize=(16, 4))

panels = zip(axs, (cum_err_n, cum_err_n_prime), ('MSE', 'd/dt MSE'))
for panel, series, ylabel in panels:
    panel.plot(series)
    panel.set_xlabel('Time Step')
    panel.set_ylabel(ylabel)

plt.suptitle('Noise Only')

In [None]:
# Overlay the three roll-outs: cumulative MSE on the left, slope on the right.
fig, axs = plt.subplots(1, 2, figsize=(20, 6))

# (legend label, cumulative error, slope), in plotting order.
curves = [
    ('Noise Only', cum_err_n, cum_err_n_prime),
    ('Model + Noise', cum_err_mn, cum_err_mn_prime),
    ('Model Only', cum_err_m, cum_err_m_prime),
]
for label, err, err_slope in curves:
    axs[0].plot(err, label=label)
    axs[1].plot(err_slope, label=label)

for panel, ylabel in zip(axs, ('MSE', 'd/dt MSE')):
    panel.legend()
    panel.set_xlabel('Time Step')
    panel.set_ylabel(ylabel)

plt.suptitle('MSE Comparison')

In [None]:
# Raster of the inferred discrete state at every frame of the validation data.
plt.figure(figsize=(12*30, 2))
# Fixing vmin/vmax maps state k to colors[k] (for k < len(colors)) no matter
# which states happen to occur; without them imshow rescales to the observed
# range and the colors stop matching the other plots.
plt.imshow(z_est[None, :], aspect='auto', cmap=cmap,
           vmin=0, vmax=len(colors) - 1)
plt.title('fMRI Inferred States')
plt.xlabel('Frames')
ax = plt.gca()
ax.set_yticks([])
# z_est spans the whole validation set rather than a single scan, so use a
# fixed descriptive file name instead of formatting in whatever value a
# previous cell's loop left in `i`.
plt.savefig('validation_states')
plt.show()

In [None]:
# True wherever two consecutive frames share the same discrete state.
difference = [current == following
              for current, following in zip(z_est[:-1], z_est[1:])]

In [None]:
# Length, in frames, of every run of consecutive identical states.
# `difference[i]` is True when frames i and i+1 share a state.
# Assumes the state sequence has at least one frame.
runs = []
counter = 1  # the current run always contains at least its first frame
for d in difference:
    if d:
        counter += 1
    else:
        # State changed: close the current run and start a new one.
        runs.append(counter)
        counter = 1
# Close the final run too; it is not followed by a state change, so the loop
# never records it (and a sequence with no switches would yield no runs).
runs.append(counter)

n = len(runs)

In [None]:
 see plt.hist(runs)
plt.title('Run Length (time step)')

In [None]:
# For every run length 1..m: how many runs have that length (weights) and how
# many time steps are spent in runs of that length (mass = length * count).
m = np.max(runs)
run_lengths = range(1, m + 1)
weights = [runs.count(length) for length in run_lengths]
mass = np.asarray(run_lengths) * np.asarray(weights)

In [None]:
# Total time steps spent in runs of each length, against the run length.
plt.scatter(range(1,m+1), mass)
plt.title('Time Spent in States by Length')