## Switching Linear Dynamical Systems fMRI Demo

In [None]:
import autograd.numpy as np
import autograd.numpy.random as npr
npr.seed(0)

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

%matplotlib inline

import seaborn as sns

# Global seaborn appearance: "white" style with the "talk" context.
sns.set_style("white")
sns.set_context("talk")

# xkcd colour names; their order fixes the colour index used by every plot below.
color_names = [
    "windows blue", "red", "amber", "faded green",
    "dusty purple", "orange", "clay", "pink",
    "greyish", "mint", "cyan", "steel blue",
    "forest green", "pastel purple", "salmon", "dark brown",
]

# Resolve the names to a palette and wrap it as a discrete matplotlib colormap.
colors = sns.xkcd_palette(color_names)
cmap = ListedColormap(colors)

import ssm
from ssm.util import random_rotation, find_permutation
from ssm.plots import plot_dynamics_2d

# NOTE(review): presumably meant to gate writing figures to disk; confirm which cells consult it.
save_figures = False

import scipy.io

In [None]:
# Load the MATLAB file with the fMRI recordings and pull out the ROI time-series array.
# The cells below index it as (time bins, observed dimensions, scans).
mat = scipy.io.loadmat('data/logan_tmsPredict_aug2019.mat')
data = mat['logan_timeSeries_roi25']

In [None]:
# Display the dimensions of the loaded array as a sanity check.
data.shape

In [None]:
# SLDS sizes. The three data-derived sizes are read directly off the array axes:
# axis 0 -> time bins, axis 1 -> observed dimensions, axis 2 -> scans.
time_bins, emissions_dim, n_scans = data.shape[:3]

n_disc_states = 5    # number of discrete states
latent_dim = 18      # number of latent dimensions

# Colormap restricted to one colour per discrete state.
cmap_limited = ListedColormap(colors[:n_disc_states])

In [None]:
# Largest and smallest sample of every scan: reduce over the time and
# observed-dimension axes so that one value per scan remains.
maximum = list(np.max(data, axis=(0, 1)))
minimum = list(np.min(data, axis=(0, 1)))

In [None]:
# Histogram the per-scan maxima (left) and minima (right).

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, values, title in zip(axs, (maximum, minimum), ('Maximum', 'Minimum')):
    ax.hist(values)
    ax.set_title(title)
plt.suptitle('fMRI Amplitude Distribution')
plt.show()

**Fit using Laplace-EM**

In [None]:
# Fit an independent SLDS to every scan and keep its parameters and state sequence.

# Per-scan results, appended in scan order.
states, As, bs, covs, elbos = [], [], [], [], []

for i in range(n_scans):
    print("Fitting SLDS with Laplace-EM")

    # Observations of the current scan only.
    scan = data[:, :, i]

    # A fresh model per scan, so nothing is shared between fits.
    slds = ssm.SLDS(emissions_dim, n_disc_states, latent_dim,
                    emissions="gaussian_orthog")

    # Laplace-EM with a structured mean-field variational posterior.
    q_lem_elbos, q_lem = slds.fit(
        scan,
        method="laplace_em",
        variational_posterior="structured_meanfield",
        num_iters=3,
        alpha=0.0,
    )

    # Posterior mean of the continuous states (single trial -> first entry).
    q_lem_x = q_lem.mean_continuous_states[0]
    # Most likely discrete state sequence given those continuous states.
    q_lem_z = slds.most_likely_states(q_lem_x, scan)
    # Smoothed observations under the variational posterior.
    q_lem_y = slds.smooth(q_lem_x, scan)

    # Record the learned dynamics and inference results for this scan.
    As.append(slds.dynamics.As)
    bs.append(slds.dynamics.bs)
    covs.append(slds.dynamics.Sigmas)
    states.append(q_lem_z)
    elbos.append(q_lem_elbos)

In [None]:
# Draw the ELBO trace of every per-scan fit, one figure per scan.

for q_lem_elbos in elbos:
    plt.plot(q_lem_elbos, label="Laplace-EM: Structured Mean-Field Posterior")
    plt.title("Convergence for learning an SLDS")
    plt.xlabel("Iteration")
    plt.ylabel("ELBO")
    plt.legend(bbox_to_anchor=(1.0, 1.0))
    plt.show()

In [None]:
# Eigen-decompose the stacked dynamics matrices of every scan.
# Each entry of `eig` is the (eigenvalues, eigenvectors) pair for one scan.

eig = [np.linalg.eig(scan_As) for scan_As in As]
e_vals = [pair[0] for pair in eig]
e_vects = [pair[1] for pair in eig]

In [None]:
# Scatter all eigenvalues in the complex plane, with the unit circle and the
# magnitude cut-off circle drawn for reference.

r_cutoff = 0.5

# Pool the eigenvalues of every scan and state into one flat array.
all_e_vals = np.asarray(e_vals).flatten()
x = np.real(all_e_vals)
y = np.imag(all_e_vals)

fig, ax = plt.subplots(figsize=(46, 16))
ax.scatter(x, y, s=1, color=colors[0])

# Real and imaginary axes through the origin.
ax.axhline(y=0, color='k', linewidth=0.5)
ax.axvline(x=0, color='k', linewidth=0.5)

# Reference circles: radius 1 and radius r_cutoff.
unit_circle = plt.Circle((0, 0), radius=1, color=colors[1], fill=False)
inner_circle = plt.Circle((0, 0), radius=r_cutoff, color=colors[2], fill=False)
ax.add_patch(unit_circle)
ax.add_patch(inner_circle)

ax.set_xlabel('Real')
ax.set_ylabel('Imaginary')
ax.set_aspect('equal')

ax.plot()

plt.suptitle('Eigenvalues')

In [None]:
# Estimate the intrinsic dimensionality of the dynamics as the average number
# of eigenvalues per dynamics matrix whose magnitude exceeds r_cutoff.

e_vals_magnitudes = np.abs(np.asarray(e_vals)).flatten()   # |lambda| for every eigenvalue
above_cutoff = e_vals_magnitudes > r_cutoff
n_sig_evals = np.sum(above_cutoff)                         # count of eigenvalues above the cut-off
n_matrices = n_scans * n_disc_states                       # one matrix per (scan, discrete state)
intrisic_dim = n_sig_evals / n_matrices
print('intrisic dimensionality =',intrisic_dim)

# Visualize Inferred Latent States

In [None]:
# Show the inferred discrete-state sequence of every scan as a colour strip.

for i, scan_states in enumerate(states):
    plt.figure(figsize=(12, 2))
    # Add a leading axis so the 1-D sequence renders as a single-row image.
    plt.imshow(scan_states[None, :], aspect='auto', cmap=cmap_limited)
    plt.title('fMRI Inferred States')
    plt.xlabel('Frames')
    ax = plt.gca()
    ax.set_yticks([])
    #plt.savefig('scan_%i' % (i))
    plt.show()

# Inference on unseen data
After learning a model from data, a common use-case is to compute the distribution over latent states given some new observations. For example, in the case of a simple LDS, we could use the Kalman Smoother to estimate the latent state trajectory given a set of observations. 

In the case of an SLDS (or Recurrent SLDS), the posterior over latent states can't be computed exactly. Instead, we need to live with a variational approximation to the true posterior. SSM allows us to compute this approximation using the `SLDS.approximate_posterior()` method. 

In the example below there is no ground-truth model to sample from, so we take one of the recorded fMRI scans as the "new" data. We then use the `approximate_posterior()` function to estimate the continuous and discrete states.

In [None]:
# Take the last scan as "validation" data. The model in `slds` at this point is
# the one fitted in the final iteration of the per-scan loop, i.e. on this very
# scan, so this is a practice run rather than a true hold-out evaluation.
validation = data[:, :, -1]

# Approximate the posterior over discrete and continuous states for that scan
# under the current (fixed) model parameters.
# NOTE: `elbos` is rebound here and no longer holds the per-scan fitting traces.
elbos, posterior = slds.approximate_posterior(
    validation,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    num_iters=3,
)

# The ELBO should increase, but only slightly: the latent-state estimate is
# refined while the model parameters stay untouched.
plt.plot(elbos)
plt.ylabel("ELBO")
plt.xlabel("Iteration")
plt.show()

**Estimating Latent States**  
  
`posterior` is now an `ssm.variational.SLDSStructuredMeanFieldVariationalPosterior` object. Using this object, we can estimate the continuous and discrete states just like we did after calling the fit function.

In the below cell, we get the estimated continuous states as follows:
```python
posterior_x = posterior.mean_continuous_states[0]
```
This line uses the `mean_continuous_states` property of the posterior object, which returns a list, where each entry of the list corresponds to a single trial of data. Since we have only passed in a single trial the list will have length 1, and we take the first entry.

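We then permute the discrete and continuous states to best match the ground truth. This is for aesthetic purposes when plotting. The following lines compute the best permutation that matches the predicted states (`most_likely`) to the ground-truth discrete states (`data_z`). We then permute the states of the SLDS accordingly. Note that real fMRI recordings have no ground-truth discrete states, so `data_z` is not available here; the snippet is shown for reference only and the permutation step is not run in the cells below: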
```python

most_likely = slds.most_likely_states(posterior_x, data)
perm = find_permutation(data_z, most_likely)
slds.permute(perm)
z_est = slds.most_likely_states(posterior_x, data)

```

In [None]:
# Posterior mean of the continuous states; a single trial was passed, so take entry 0.
x_est = posterior.mean_continuous_states[0]
# Most likely discrete state sequence for the validation scan given those continuous states.
z_est = slds.most_likely_states(x_est, validation)

In [None]:
# Plot the first two estimated latent dimensions, one per row.
title_str = ["$x_1$", "$x_2$"]
fig, axs = plt.subplots(2, 1, figsize=(14, 4))
for d, (ax, title) in enumerate(zip(axs, title_str)):
    # Only the first row carries a legend entry.
    label = "Laplace-EM" if d == 0 else None
    ax.plot(x_est[:, d] + 4 * d, '-', color=colors[2], label=label)
    ax.set_yticks([])
    ax.set_title(title, loc="left", y=0.5, x=-0.03)
axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# Draw a sample of `time_bins` steps from the fitted model, unpacked here as
# (model_z, model_x, model_y); model_x is compared against x_est below.
model_z, model_x, model_y = slds.sample(time_bins)

In [None]:
# Overlay the freely sampled trajectory on the estimated one (first two latent dimensions).
title_str = ["$x_1$", "$x_2$"]
fig, axs = plt.subplots(2, 1, figsize=(14, 4))
for d, (ax, title) in enumerate(zip(axs, title_str)):
    first_row = d == 0   # legend entries are attached to the first row only
    shift = 4 * d
    ax.plot(model_x[:, d] + shift, '-', color=colors[0],
            label="'Naively Generated'" if first_row else None)
    ax.plot(x_est[:, d] + shift, '-', color=colors[2],
            label="Estimated" if first_row else None)
    ax.set_yticks([])
    ax.set_title(title, loc="left", y=0.5, x=-0.03)
axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle(" 'Naively Generated' and Estimated Continuous States", va="bottom")
plt.tight_layout()

Generate by:
    
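$x_{n+1} = A_{k} x_n + b_{k} + w_n, \qquad w_n \sim \mathcal{N}(0, \Sigma_{k})$

where $k = z_n$ is the discrete state estimated at step $n$.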

In [None]:
# Roll the learned dynamics forward n_gen times, switching the active linear
# model at every step according to the estimated discrete states, and score
# each rollout against the estimated continuous states.

A = slds.dynamics.As
b = slds.dynamics.bs
cov = slds.dynamics.Sigmas
n_gen = 100

# One error value per rollout.
mse = np.zeros(n_gen)
mae = np.zeros(n_gen)

zero_mean = np.zeros(latent_dim)   # process noise is zero-mean
for j in range(n_gen):
    # Every rollout starts from the first estimated state.
    x = [x_est[0]]

    for i in range(time_bins - 1):
        k = z_est[i]
        w = np.random.multivariate_normal(zero_mean, cov[k])
        x.append(A[k] @ x[-1] + b[k] + w)

    x_gen = np.vstack(x)
    residual = x_est - x_gen
    mse[j] = np.mean(residual ** 2)
    mae[j] = np.mean(np.abs(residual))

In [None]:
# Distribution of the per-rollout errors relative to the estimated continuous states.

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
for ax, errors, title in zip(axs, (mse, mae),
                             ('Mean Squared Error', 'Mean Absolute Error')):
    ax.hist(errors)
    ax.set_title(title)

plt.suptitle("Error Distributions")
plt.tight_layout()

In [None]:
# Compare the last generated rollout with the estimate, one row per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(latent_dim)]
fig, axs = plt.subplots(latent_dim, 1, figsize=(14, 2 * latent_dim))
for d, ax in enumerate(axs):
    shift = 4 * d
    first_row = d == 0   # only the first row contributes legend entries
    ax.plot(x_gen[:, d] + shift, '-', color=colors[0],
            label="Generated" if first_row else None)
    ax.plot(x_est[:, d] + shift, '-', color=colors[2],
            label="Estimated" if first_row else None)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)
axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

# Creating Training, Validation and Test Data

In [None]:
# Split the scans 60 / 20 / 20 into training, validation and test sets.
training_fraction = .6
validation_fraction = .2
training_index = int(training_fraction * n_scans)
validation_index = int(training_index + validation_fraction * n_scans)


def _concatenate_scans(scans):
    """Lay the selected scans end to end along the time axis.

    `scans` is indexed as (time, observed dimension, scan). The result has
    shape (n_scans_selected * time, observed dimension), with all frames of
    scan 0 first, then all frames of scan 1, and so on.

    The previous `np.swapaxes(np.hstack(scans), 0, 1)` produced the same shape
    but interleaved the scans frame by frame (row t * k + s), which scrambles
    the temporal order the SLDS is supposed to model.
    """
    n_frames, n_channels, n_selected = scans.shape
    # Move the scan axis to the front, then merge (scan, time) into one axis.
    scans_first = np.transpose(scans, (2, 0, 1))
    return np.reshape(scans_first, (n_selected * n_frames, n_channels))


# NOTE(review): consecutive scans are still joined into one long series, so
# there is a discontinuity at every scan boundary.
training_data = _concatenate_scans(data[:, :, :training_index])
validation_data = _concatenate_scans(data[:, :, training_index:validation_index])
test_data = _concatenate_scans(data[:, :, validation_index:])

In [None]:
# Display the dimensions of the stacked training array as a sanity check.
training_data.shape

In [None]:
# Fit a single SLDS to the stacked training data.

print("Fitting SLDS with Laplace-EM")

# Fresh model; this rebinds `slds`, replacing the last per-scan model.
slds = ssm.SLDS(emissions_dim, n_disc_states, latent_dim,
                emissions="gaussian_orthog")

# Laplace-EM with a structured mean-field variational posterior.
q_lem_elbos, q_lem = slds.fit(
    training_data,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    num_iters=3,
    alpha=0.0,
)

# Posterior mean of the continuous states (single trial -> first entry).
q_lem_x = q_lem.mean_continuous_states[0]

# Most likely discrete state sequence given those continuous states.
q_lem_z = slds.most_likely_states(q_lem_x, training_data)

# Smoothed observations under the variational posterior.
q_lem_y = slds.smooth(q_lem_x, training_data)

In [None]:
# ELBO trace of the training fit.
plt.plot(q_lem_elbos, label="Laplace-EM: Structured Mean-Field Posterior")
plt.title("Convergence for learning an SLDS")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend(bbox_to_anchor=(1.0, 1.0))
plt.show()

In [None]:
# Show the discrete-state sequence inferred for the stacked training data.
plt.figure(figsize=(12,2))
# Add a leading axis so the 1-D sequence renders as a single-row image.
plt.imshow(q_lem_z[None,:], aspect='auto', cmap=cmap_limited)
plt.title('fMRI Inferred States')
plt.xlabel('Frames')
ax = plt.gca()
ax.set_yticks([])
# The previous 'scan_%i' % i file name relied on a loop index left over from
# an unrelated earlier cell, and the figure was written even with
# save_figures disabled. Save under a fixed name, and only on request.
if save_figures:
    plt.savefig('training_inferred_states')
plt.show()

In [None]:
# Approximate the posterior over discrete and continuous states for the
# validation data, keeping the trained model parameters fixed.
elbos, posterior = slds.approximate_posterior(
    validation_data,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    num_iters=3,
)

# The ELBO should rise only modestly: the latent-state estimate is refined
# while the model parameters are left unchanged.
plt.plot(elbos)
plt.ylabel("ELBO")
plt.xlabel("Iteration")
plt.show()

In [None]:
# Posterior mean of the continuous states; a single trial was passed, so take entry 0.
x_est = posterior.mean_continuous_states[0]
# Most likely discrete state sequence for the validation data given those continuous states.
z_est = slds.most_likely_states(x_est, validation_data)

In [None]:
def mse(ts1, ts2):
    """Return the mean of the squared element-wise difference between ts1 and ts2."""
    residual = ts1 - ts2
    return np.mean(residual ** 2)

def mae(ts1, ts2):
    """Return the mean of the absolute element-wise difference between ts1 and ts2."""
    residual = ts1 - ts2
    return np.mean(np.abs(residual))
    

In [None]:
# Baseline: after the initial state, every generated state is just a draw of
# the process noise of the currently estimated discrete state (no dynamics).

A = slds.dynamics.As
b = slds.dynamics.bs
cov = slds.dynamics.Sigmas
n_gen = 1
n_val_frames = validation_data.shape[0]

# One error value per rollout.
MSE = np.zeros(n_gen)
MAE = np.zeros(n_gen)

zero_mean = np.zeros(latent_dim)   # process noise is zero-mean
for j in range(n_gen):
    # Start from the first estimated state.
    x = [x_est[0]]

    for i in range(n_val_frames - 1):
        k = z_est[i]
        w = np.random.multivariate_normal(zero_mean, cov[k])
        x.append(w)   # the noise sample itself is the next state

    x_gen = np.vstack(x)
    MSE[j] = mse(x_est, x_gen)
    MAE[j] = mae(x_est, x_gen)

In [None]:
# Report the mean squared error of the (single) noise-only rollout.
print('Mean Squared Error:' ,MSE[0])

In [None]:
# MSE over the first t frames for t = 1..99 (noise-only rollout), and its
# numerical derivative with respect to t.
cum_err_n = [mse(x_est[:t], x_gen[:t]) for t in range(1,100)]
cum_err_n_prime = np.gradient(cum_err_n)

In [None]:
# Noise-only rollout: cumulative MSE (left) and its rate of change (right).
fig, axs = plt.subplots(1, 2, figsize=(16, 4))
panels = ((cum_err_n, 'MSE'), (cum_err_n_prime, 'd/dt MSE'))
for ax, (curve, ylabel) in zip(axs, panels):
    ax.plot(curve)
    ax.set_xlabel('Time Step')
    ax.set_ylabel(ylabel)

plt.suptitle('Noise Only')

In [None]:
# Noise-only rollout against the estimate, one row per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(latent_dim)]
fig, axs = plt.subplots(latent_dim, 1, figsize=(14 * 30, 2 * latent_dim))
for d, ax in enumerate(axs):
    shift = 4 * d
    first_row = d == 0   # only the first row contributes legend entries
    ax.plot(x_gen[:, d] + shift, '-', color=colors[0],
            label="Generated" if first_row else None)
    ax.plot(x_est[:, d] + shift, '-', color=colors[2],
            label="Estimated" if first_row else None)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)
axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# Model + noise: propagate each state through the linear dynamics of the
# currently estimated discrete state and add that state's process noise.

A = slds.dynamics.As
b = slds.dynamics.bs
cov = slds.dynamics.Sigmas
n_gen = 1
n_val_frames = validation_data.shape[0]

# One error value per rollout.
MSE = np.zeros(n_gen)
MAE = np.zeros(n_gen)

zero_mean = np.zeros(latent_dim)   # process noise is zero-mean
for j in range(n_gen):
    # Start from the first estimated state.
    x = [x_est[0]]

    for i in range(n_val_frames - 1):
        k = z_est[i]
        w = np.random.multivariate_normal(zero_mean, cov[k])
        x.append(A[k] @ x[-1] + b[k] + w)

    x_gen = np.vstack(x)
    MSE[j] = mse(x_est, x_gen)
    MAE[j] = mae(x_est, x_gen)

In [None]:
# Report the mean squared error of the (single) model + noise rollout.
print('Mean Squared Error:' ,MSE[0])

In [None]:
# MSE over the first t frames for t = 1..99 (model + noise rollout), and its
# numerical derivative with respect to t.
cum_err_mn = [mse(x_est[:t], x_gen[:t]) for t in range(1,100)]
cum_err_mn_prime = np.gradient(cum_err_mn)

In [None]:
# Model + noise rollout: cumulative MSE (left) and its rate of change (right).
# The curves are cum_err_mn / cum_err_mn_prime; the earlier references to
# `cum_err` / `cum_err_prime` named variables that are never defined.
fig, axs = plt.subplots(1, 2, figsize=(16,4))
axs[0].plot(cum_err_mn)
axs[1].plot(cum_err_mn_prime)
axs[0].set_xlabel('Time Step')
axs[1].set_xlabel('Time Step')
axs[0].set_ylabel('MSE')
axs[1].set_ylabel('d/dt MSE')

plt.suptitle('Model + Noise')

In [None]:
# Model + noise rollout against the estimate, one row per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(latent_dim)]
fig, axs = plt.subplots(latent_dim, 1, figsize=(14 * 30, 2 * latent_dim))
for d, ax in enumerate(axs):
    shift = 4 * d
    first_row = d == 0   # only the first row contributes legend entries
    ax.plot(x_gen[:, d] + shift, '-', color=colors[0],
            label="Generated" if first_row else None)
    ax.plot(x_est[:, d] + shift, '-', color=colors[2],
            label="Estimated" if first_row else None)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)
axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# Model only: propagate each state through the linear dynamics of the
# currently estimated discrete state, with no process noise added.

A = slds.dynamics.As
b = slds.dynamics.bs
cov = slds.dynamics.Sigmas
n_gen = 1
n_val_frames = validation_data.shape[0]

# One error value per rollout.
MSE = np.zeros(n_gen)
MAE = np.zeros(n_gen)

for j in range(n_gen):
    # Start from the first estimated state.
    x = [x_est[0]]

    for i in range(n_val_frames - 1):
        k = z_est[i]
        x.append(A[k] @ x[-1] + b[k])

    x_gen = np.vstack(x)
    MSE[j] = mse(x_est, x_gen)
    MAE[j] = mae(x_est, x_gen)

In [None]:
# Report the mean squared error of the (single) model-only rollout.
print('Mean Squared Error:', MSE[0])

In [None]:
# MSE over the first t frames for t = 1..99 (model-only rollout), and its
# numerical derivative with respect to t.
cum_err_m = [mse(x_est[:t], x_gen[:t]) for t in range(1,100)]
cum_err_m_prime = np.gradient(cum_err_m)

In [None]:
# Model-only rollout: cumulative MSE (left) and its rate of change (right).
fig, axs = plt.subplots(1, 2, figsize=(16, 4))
panels = ((cum_err_m, 'MSE'), (cum_err_m_prime, 'd/dt MSE'))
for ax, (curve, ylabel) in zip(axs, panels):
    ax.plot(curve)
    ax.set_xlabel('Time Step')
    ax.set_ylabel(ylabel)

plt.suptitle('Model Only')

In [None]:
# Model-only rollout against the estimate, one row per latent dimension.
title_str = ["$x_{%i}$" % dim for dim in range(latent_dim)]
fig, axs = plt.subplots(latent_dim, 1, figsize=(14 * 30, 2 * latent_dim))
for d, ax in enumerate(axs):
    shift = 4 * d
    first_row = d == 0   # only the first row contributes legend entries
    ax.plot(x_gen[:, d] + shift, '-', color=colors[0],
            label="Generated" if first_row else None)
    ax.plot(x_est[:, d] + shift, '-', color=colors[2],
            label="Estimated" if first_row else None)
    ax.set_yticks([])
    ax.set_title(title_str[d], loc="left", y=0.5, x=-0.03)
axs[0].set_xticks([])
axs[0].legend(loc="upper right")

plt.suptitle("Generated and Estimated Continuous States", va="bottom")
plt.tight_layout()

In [None]:
# Overlay the three generation schemes: cumulative MSE (left) and its
# rate of change (right).
fig, axs = plt.subplots(1, 2, figsize=(20, 6))

# (label, cumulative MSE, derivative), in plotting order.
schemes = (
    ('Noise Only', cum_err_n, cum_err_n_prime),
    ('Model + Noise', cum_err_mn, cum_err_mn_prime),
    ('Model Only', cum_err_m, cum_err_m_prime),
)
for label, curve, _ in schemes:
    axs[0].plot(curve, label=label)
for label, _, slope in schemes:
    axs[1].plot(slope, label=label)

for ax, ylabel in zip(axs, ('MSE', 'd/dt MSE')):
    ax.legend()
    ax.set_xlabel('Time Step')
    ax.set_ylabel(ylabel)

plt.suptitle('MSE Comparison')