## SLDS Example Notebook
### Overview
This notebook illustrates how to use `jax_moseq.models.slds` to fit a sticky hierarchical Dirichlet process switching linear dynamical system (henceforth simply an "SLDS") to time series data via Gibbs sampling. Like the ARHMM also implemented in this package, this Bayesian nonparametric variant of the SLDS was originally described by Fox et al. (2008) [1]. Our version extends the original formulation by explicitly modeling the level of uncertainty for each observation, which allows us to downweight the influence of outliers when resampling the continuous latent states. For illustration purposes, we simply fit the model to noisy observations derived from the depth principal components used in the ARHMM notebook.

### Model
#### Intuition
The SLDS is a natural extension of the ARHMM (see the ARHMM notebook for the needed background) in which the AR dynamics unfold in a continuous latent space rather than directly in the observation space; the observations are then noisy emissions that are conditionally independent given the latent states. This allows the model to learn smoothly evolving, low-dimensional latent dynamics, and it places an additional buffer between the inferred state sequences/AR parameters and the noise in the observations.

#### Formalism
The SLDS includes all the variables in the ARHMM (though it construes the continuous trajectories $X$ as a set of latent states to be inferred rather than as the observed variables) and the following additions (where $o$ denotes the observation dimensionality):

- The continuous observations: $Y = \{ y_t \in \mathbb{R}^{o} \}_{t=1}^{T}$
- The noise scales: $S = \{ s_{t} \in \mathbb{R}_{+}^{o} \}_{t=1}^{T}$
- The emission parameters: $C \in \mathbb{R}^{o \times d}$ (where $d$ is the latent dimensionality) and the offset $d \in \mathbb{R}^{o}$
- The unscaled noise variances: $\sigma^2 \in \mathbb{R}_{+}^{o}$.

Correspondingly, the generative model for the SLDS contains the following additions:

- $y_t \sim \mathcal{N}(C x_{t} + d, S_t)$
- $\sigma_i^2 \sim \chi^{-2}(\nu_{\sigma}, \sigma_0^2)$
- $s_{t, i} \sim \chi^{-2} (\nu_s, s_0)$

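Above, the noise covariance matrix at time $t$, $S_t = \text{diag}(s_t \odot \sigma^2)$, is diagonal with $[S_t]_{ii} = s_{t, i} \sigma_i^2$. A large noise scale $s_{t,i}$ therefore inflates the variance of the corresponding observation and reduces its influence on the latent states. Finally, we have the following hyperparameters: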

- The number of chi-squared degrees of freedom for the unscaled noise: $\nu_{\sigma} \in \mathbb{Z}_+$
- The inverse chi-squared scaling factor for the unscaled noise: $\sigma_0^2$
- The number of chi-squared degrees of freedom for the noise scales: $\nu_s \in \mathbb{Z}_+$
- The inverse chi-squared scaling factor for the noise scales: $s_0$.
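The generative additions above can be illustrated by forward-sampling the emission model. The following is a plain NumPy sketch, not `jax_moseq` code. It assumes that $\chi^{-2}(\nu, \tau^2)$ denotes the scaled inverse chi-squared distribution, sampled via the identity $\nu \tau^2 / X \sim \chi^{-2}(\nu, \tau^2)$ when $X \sim \chi^2(\nu)$. The helper names `sample_scaled_inv_chi2` and `sample_emissions` are hypothetical.

```python
import numpy as np

def sample_scaled_inv_chi2(rng, nu, scale, size):
    """Draw from chi^{-2}(nu, scale) (assumed scaled inverse chi-squared parametrization).

    Uses the identity: if X ~ chi^2(nu), then nu * scale / X ~ chi^{-2}(nu, scale).
    """
    return nu * scale / rng.chisquare(nu, size=size)

def sample_emissions(rng, x, C, d, nu_sigma, sigmasq_0, nu_s, s_0):
    """Sample y_t ~ N(C x_t + d, diag(s_t * sigma^2)) for every timepoint.

    x: (T, latent_dim); C: (obs_dim, latent_dim); d: (obs_dim,).
    """
    T = x.shape[0]
    obs_dim = C.shape[0]
    sigmasq = sample_scaled_inv_chi2(rng, nu_sigma, sigmasq_0, obs_dim)  # (obs_dim,)
    s = sample_scaled_inv_chi2(rng, nu_s, s_0, (T, obs_dim))             # (T, obs_dim)
    variances = s * sigmasq                # diagonal entries [S_t]_ii = s_{t,i} * sigma_i^2
    mean = x @ C.T + d                     # C x_t + d for every t, shape (T, obs_dim)
    y = mean + np.sqrt(variances) * rng.standard_normal((T, obs_dim))
    return y, s, sigmasq

# Toy usage with arbitrary sizes and the obs_hypparams values used later in this notebook
rng = np.random.default_rng(0)
T, latent_dim_demo, obs_dim_demo = 500, 4, 6
x_demo = rng.standard_normal((T, latent_dim_demo))
C_demo = rng.standard_normal((obs_dim_demo, latent_dim_demo))
d_demo = np.zeros(obs_dim_demo)
y_demo, s_demo, sigmasq_demo = sample_emissions(
    rng, x_demo, C_demo, d_demo, nu_sigma=1e5, sigmasq_0=10.0, nu_s=5, s_0=1.0)
print(y_demo.shape, s_demo.shape, sigmasq_demo.shape)  # (500, 6) (500, 6) (6,)
```

Because $\nu_\sigma$ is large here, the sampled $\sigma^2$ concentrate near $\sigma_0^2$, whereas the small $\nu_s$ yields heavy-tailed noise scales, so occasional timepoints receive much larger variances.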


#### Fitting
We fit the SLDS via Gibbs sampling. Note that our current implementation does not resample the emission parameters $C, d$, which we fix in advance via PCA.
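A sketch of how the per-observation noise model can enter a Gibbs sampler: under the scaled inverse chi-squared priors above, both $s_{t,i}$ and $\sigma_i^2$ have conjugate conditional posteriors given the residuals $y_t - (C x_t + d)$. The NumPy code below implements these standard conjugate updates under that assumed parametrization. It is an illustration, not necessarily how `jax_moseq` implements its resampling step, and the function names are hypothetical. The toy example shows the outlier-downweighting effect described in the overview: a large residual pulls its noise scale up, which shrinks the precision weight $1/[S_t]_{ii}$ that observation receives when the latent states are resampled.

```python
import numpy as np

def resample_noise_scales(rng, resid, sigmasq, nu_s, s_0):
    """Draw each s_{t,i} from its conditional posterior given the residuals.

    resid: (T, obs_dim) residuals y_t - (C x_t + d); sigmasq: (obs_dim,).
    Prior s_{t,i} ~ chi^{-2}(nu_s, s_0) and resid_{t,i} ~ N(0, s_{t,i} sigma_i^2) give the
    posterior chi^{-2}(nu_s + 1, (nu_s * s_0 + resid_{t,i}^2 / sigma_i^2) / (nu_s + 1)).
    """
    nu_post = nu_s + 1
    scale_post = (nu_s * s_0 + resid**2 / sigmasq) / nu_post
    # If X ~ chi^2(nu), then nu * scale / X ~ chi^{-2}(nu, scale)
    return nu_post * scale_post / rng.chisquare(nu_post, size=resid.shape)

def resample_sigmasq(rng, resid, s, nu_sigma, sigmasq_0):
    """Draw each sigma_i^2 from its conditional posterior, pooling over all T timepoints."""
    T, obs_dim = resid.shape
    nu_post = nu_sigma + T
    scale_post = (nu_sigma * sigmasq_0 + (resid**2 / s).sum(axis=0)) / nu_post
    return nu_post * scale_post / rng.chisquare(nu_post, size=obs_dim)

# Toy example: one ordinary residual and one outlier in a single observed dimension
rng = np.random.default_rng(0)
resid = np.array([[0.5], [20.0]])   # shape (T=2, obs_dim=1); second row is an outlier
sigmasq = np.array([1.0])
draws = np.stack([resample_noise_scales(rng, resid, sigmasq, nu_s=5, s_0=1.0)
                  for _ in range(5000)])            # (5000, 2, 1)
# Average precision weight 1 / [S_t]_ii for each timepoint across the draws.
# Analytic posterior means: 6 / 5.25 (about 1.14) and 6 / 405 (about 0.015).
weights = (1.0 / (draws * sigmasq)).mean(axis=0).ravel()
new_sigmasq = resample_sigmasq(rng, resid, draws[-1], nu_sigma=5, sigmasq_0=1.0)
```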

### References
[1] Fox, E., Sudderth, E., Jordan, M., & Willsky, A. (2008). Nonparametric Bayesian learning of switching linear dynamical systems. *Advances in Neural Information Processing Systems*, 21.

### Code
Before running this notebook, be sure to install `jax_moseq` and its associated dependencies. This notebook also requires `tqdm` and `matplotlib`. Also note that while a GPU is not required, it certainly doesn't hurt.

In [None]:
from jax_moseq.models import slds

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np

from tqdm.auto import trange
import matplotlib.pyplot as plt

#### Helper Methods

In [None]:
def print_dict(d, depth=0, ind='  '):
    """Recursively print a nested dict, showing array shapes instead of array contents."""
    if not isinstance(d, dict):
        try:
            item = d.shape
        except AttributeError:
            item = d
        print(f'{ind * depth}{item}')
        return
    
    for k, v in d.items():
        print(f'{ind * depth}{k}')
        print_dict(v, depth + 1, ind)

In [None]:
def plot_ll(key, ll_history):
    plt.title(f'Log Likelihood of {key}')
    plt.xlabel('Iteration')
    plt.ylabel('Log Likelihood')
    plt.plot(ll_history)
    plt.show()

#### Loading the Data
The data are stored in a dictionary with two entries:
- `'Y'` - a jax array of shape `(num_sessions, num_timesteps, obs_dim)` containing the continuous observations to which the model will be fit. Here, these data are artificially derived from the mouse depth PCs (see Wiltschko et al. 2015) by projecting them into a 24-dimensional observation space with a random Gaussian matrix.
- `'mask'` - a jax array of shape `(num_sessions, num_timesteps)` indicating which data points are valid (useful when sessions differ in length and must be padded to a common number of timesteps).
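If your own sessions differ in length, one way to build `'Y'` and `'mask'` is to zero-pad every session to the longest one and mark the padded timepoints with 0 in the mask. The helper `pad_sessions` below is a hypothetical NumPy sketch, not part of `jax_moseq`; the zero padding value is an arbitrary choice, since padded timepoints are masked out.

```python
import numpy as np

def pad_sessions(sessions):
    """Stack variable-length sessions into a padded array plus a validity mask.

    sessions: list of arrays, each of shape (num_timesteps_i, obs_dim).
    Returns Y of shape (num_sessions, max_timesteps, obs_dim) and mask of shape
    (num_sessions, max_timesteps), with 1 for real data and 0 for padding.
    """
    num_sessions = len(sessions)
    max_len = max(s.shape[0] for s in sessions)
    obs_dim = sessions[0].shape[1]
    Y = np.zeros((num_sessions, max_len, obs_dim))
    mask = np.zeros((num_sessions, max_len))
    for i, s in enumerate(sessions):
        Y[i, :len(s)] = s       # copy the real data; the rest stays zero-padded
        mask[i, :len(s)] = 1    # mark the real timepoints as valid
    return Y, mask

# Two toy sessions of different lengths
rng = np.random.default_rng(0)
sessions = [rng.standard_normal((100, 24)), rng.standard_normal((80, 24))]
Y_padded, mask_padded = pad_sessions(sessions)
print(Y_padded.shape, mask_padded.shape)  # (2, 100, 24) (2, 100)
print(mask_padded.sum(axis=1).tolist())   # [100.0, 80.0]
```

The resulting arrays can be moved to the device with `jax.device_put` (or `jnp.asarray`) before being placed in the `data` dictionary.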

In [None]:
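# Load the depth PCs, which play the role of the continuous latent trajectories x
x_path = 'example_data.npy'
x = jax.device_put(np.load(x_path))

# Randomly project the latent trajectories into an obs_dim-dimensional observation space
latent_dim = x.shape[-1]
obs_dim = 24
projection_matrix = jr.normal(jr.PRNGKey(0), (obs_dim, latent_dim))

# Y[..., o] = sum_d x[..., d] * projection_matrix[o, d]
Y = jnp.einsum('...d,od->...o', x, projection_matrix)
del x

# All timepoints are valid here, so the mask is all ones
data = {'Y': Y,
        'mask': jnp.ones(Y.shape[:2])}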
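As noted under Fitting, $C$ and $d$ are fixed in advance via PCA rather than resampled. Below is a minimal NumPy sketch of one way to do this, assuming $d$ is the mean of the valid observations and the columns of $C$ are the top `latent_dim` principal axes; the exact scaling and conventions used by `jax_moseq` may differ, and `pca_emissions` is a hypothetical helper. On noiseless data built like the cell above (a random projection of `latent_dim`-dimensional trajectories), the PCA emission model reconstructs the observations essentially exactly, because the centered data lie in a `latent_dim`-dimensional subspace.

```python
import numpy as np

def pca_emissions(Y, mask, latent_dim):
    """Estimate emission parameters (C, d) by PCA.

    Y: (num_sessions, num_timesteps, obs_dim); mask: (num_sessions, num_timesteps).
    Returns C of shape (obs_dim, latent_dim) with orthonormal columns and d of shape (obs_dim,).
    """
    valid = Y[mask > 0]                  # (num_valid, obs_dim); drops padded timepoints
    d = valid.mean(axis=0)               # offset = mean of the valid observations
    # Right singular vectors of the centered data are the principal axes
    _, _, Vt = np.linalg.svd(valid - d, full_matrices=False)
    C = Vt[:latent_dim].T                # top latent_dim axes as columns
    return C, d

# Toy data mimicking the cell above (sizes are arbitrary)
rng = np.random.default_rng(0)
latent_dim_demo, obs_dim_demo = 4, 24
x_demo = rng.standard_normal((2, 200, latent_dim_demo))
proj_demo = rng.standard_normal((obs_dim_demo, latent_dim_demo))
Y_demo = np.einsum('...d,od->...o', x_demo, proj_demo)
mask_demo = np.ones(Y_demo.shape[:2])

C_pca, d_pca = pca_emissions(Y_demo, mask_demo, latent_dim_demo)
x_hat = (Y_demo - d_pca) @ C_pca     # latent coordinates, shape (2, 200, 4)
Y_hat = x_hat @ C_pca.T + d_pca      # map back through the emission model C x + d
print(np.abs(Y_demo - Y_hat).max() < 1e-8)  # True
```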

#### Setting the Hyperparameters

In [None]:
num_states = 100
nlags = 3

# Note: `latent_dim` is the dimensionality of the continuous
# trajectories x, which are latent in the SLDS; the same name is
# used in the ARHMM notebook to harmonize the lingo across the ARHMM and SLDS.

# TODO: identify a good set of hyperparameters for the dataset

trans_hypparams = {
    'gamma': 1e3, 
    'alpha': 5.7, 
    'kappa': 2e5,
    'num_states': num_states}

ar_hypparams = {
    'S_0_scale': 10,
    'K_0_scale': 0.1,
    'latent_dim': latent_dim,
    'num_states': num_states,
    'nlags': nlags}

obs_hypparams = {
    'nu_sigma': 1e5,
    'sigmasq_0': 10,
    'nu_s': 5,
    's_0': 1
}
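To get a feel for `obs_hypparams`, one can compute the prior means they imply, assuming $\chi^{-2}(\nu, \tau^2)$ is the scaled inverse chi-squared distribution, whose mean is $\nu \tau^2 / (\nu - 2)$ for $\nu > 2$. The helper below is a hypothetical illustration, and it assumes `jax_moseq` uses this same parametrization.

```python
def scaled_inv_chi2_mean(nu, scale):
    """Mean of chi^{-2}(nu, scale), i.e. nu * scale / (nu - 2); infinite for nu <= 2."""
    if nu <= 2:
        return float('inf')
    return nu * scale / (nu - 2)

# Same values as `obs_hypparams` above
sigmasq_prior_mean = scaled_inv_chi2_mean(1e5, 10)   # nu_sigma, sigmasq_0
s_prior_mean = scaled_inv_chi2_mean(5, 1)            # nu_s, s_0
print(f'{sigmasq_prior_mean:.4f}')  # 10.0002
print(f'{s_prior_mean:.4f}')        # 1.6667
```

With `nu_sigma = 1e5` the prior on $\sigma^2$ is very tightly concentrated around `sigmasq_0`, whereas `nu_s = 5` gives a heavy-tailed prior that leaves the noise scales free to take large values at outlying timepoints.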

#### Fitting the Model

In [None]:
model = slds.init_model(data,
                        trans_hypparams=trans_hypparams,
                        ar_hypparams=ar_hypparams,
                        obs_hypparams=obs_hypparams,
                        verbose=True)

print()
print_dict(model)

In [None]:
ar_iters = 50       # number of initial iterations that resample only the AR-HMM components
total_iters = 75    # total number of Gibbs iterations, including the AR-only ones

ll_keys = ['z', 'x', 's', 'Y']
ll_history = {key: [] for key in ll_keys}

for i in trange(ar_iters):
    # Perform Gibbs resampling of the AR-HMM components only (ar_only=True)
    model = slds.resample_model(data, **model, ar_only=True)
    
    # Compute the likelihood of the data and
    # resampled states given the resampled params
    ll = slds.model_likelihood(data, **model)
    for key in ll_keys:
        ll_history[key].append(ll[key].item())
        
for i in trange(ar_iters, total_iters):
    # Perform Gibbs resampling of the full SLDS
    model = slds.resample_model(data, **model)
    
    # Compute the likelihood of the data and
    # resampled states given the resampled params
    ll = slds.model_likelihood(data, **model)
    for key in ll_keys:
        ll_history[key].append(ll[key].item())

In [None]:
for k, v in ll_history.items():
    plot_ll(k, v)