This notebook fits a hierarchical GLM-HMM. The model includes observed discrete states $z_t \in \{1,\dots,N\}$, unobserved hidden states $w_t \in \{1,\dots,M\}$, and external inputs $u_t \in \mathbb{R}^U$. The external inputs modify the transition probabilities between the observed states according to a GLM whose weights depend on the current hidden state. Formally, the states are generated by the following model:

$$ w_t \mid w_{t-1} \sim \text{Categorical}(\pi_{w_{t-1}}) $$

$$ z_t \mid z_{t-1}, w_t, u_t \sim \text{Categorical}\left(\operatorname{softmax}\left(W_{w_t} u_t + P_{z_{t-1}}\right)\right) $$

where $\pi$ is the transition matrix between hidden states, $W_{w_t} \in \mathbb{R}^{N \times U}$ is the GLM weight matrix for hidden state $w_t$, and $P \in \mathbb{R}^{N \times N}$ holds the baseline transition logits between observed states: row $P_{z_{t-1}}$ is added to the logits of $z_t$. At the first time step there is no previous state, so this baseline term is omitted. The parameters are drawn from the following prior distributions:

$$ \pi \sim \text{sticky-HDP}(\gamma, \alpha, \kappa) $$

$$ W_{i,j} \sim \text{Normal}(0, \sigma_W^2)$$

$$ P_{i,j} \sim \text{Normal}(0, \sigma_P^2)$$

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jr

from jax_moseq.utils.distributions import sample_hmm_stateseq
from jax_moseq.utils.transitions import resample_hdp_transitions
from jax_moseq.utils import pad_along_axis
from functools import partial
# Short alias for ``jnp.newaxis`` (i.e. ``None``), used when inserting
# broadcast axes, e.g. ``W[na, na, :, :]``.
na = jnp.newaxis


# TODO(review): unfinished sketch of a blackjax HMC sampler, presumably for the
# GLM parameters. The original notes were:
#     init = blackjax.mcmc.hmc.init
#     step = blackjax.mcmc.hmc.kernel()
# (blackjax is not imported in this notebook.)

def glm_log_likelihood(z, u, P, W):
    """
    Computes the log-likelihood of each ``z`` transition under the GLM.

    The log-probability of ``z[s, t]`` is read from a softmax over the logits
    ``P[z[s, t-1]] + W @ u[s, t]``. At ``t = 0`` there is no previous state,
    so the baseline term is zero and only ``W @ u[s, 0]`` contributes.

    Parameters
    ----------
    z : jnp.ndarray of shape (num_seqs, T)
        Observed discrete states (integer indices into the N states).
    u : jnp.ndarray of shape (num_seqs, T, U)
        Observed external inputs.
    P : jnp.ndarray of shape (N, N)
        Baseline transition logits. Row ``P[i]`` is added to the logits of the
        next state when the previous state is ``i``.
    W : jnp.ndarray of shape (N, U)
        GLM weights.

    Returns
    -------
    log_likelihoods : jnp.ndarray of shape (num_seqs, T)
        ``log p(z[s, t] | z[s, t-1], u[s, t])`` for every sequence and step.
    """
    # The baseline logits come from the *previous* state's row of P. Step t=0
    # has no predecessor, so prepend zeros along the time axis.
    prev_rows = P[z[:, :-1]]                                    # (num_seqs, T-1, N)
    baseline = jnp.pad(prev_rows, ((0, 0), (1, 0), (0, 0)))     # (num_seqs, T, N)

    # Input-driven logits: bias[s, t] = W @ u[s, t].
    bias = jnp.einsum('nu,stu->stn', W, u)                      # (num_seqs, T, N)

    log_probs = jax.nn.log_softmax(baseline + bias, axis=-1)    # (num_seqs, T, N)

    # Select the log-probability of the state actually observed at each step.
    return jnp.take_along_axis(log_probs, z[..., None], axis=-1)[..., 0]
    

# @jax.jit
def resample_discrete_stateseqs(seed, z, u, mask, pi, P, W, **kwargs):
    """
    Sample the hidden-state sequences ``w`` given observed states and inputs.

    For each hidden state ``m``, the GLM log-likelihood of the observed
    sequence ``z`` is evaluated with that state's weights ``W[m]``. The stacked
    per-state log-likelihoods are then handed, together with ``pi`` and
    ``mask``, to ``sample_hmm_stateseq``, vectorized over sequences with one
    PRNG key per sequence.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    z : jnp.ndarray of shape (num_seqs, T)
        Observed discrete states.
    u : jnp.ndarray of shape (num_seqs, T, U)
        Observed external inputs.
    mask : jnp.ndarray of shape (num_seqs, T)
        Mask for observed data.
    pi : jnp.ndarray of shape (M, M)
        Hidden state transition matrix.
    P : jnp.ndarray of shape (N, N)
        Baseline transition matrix for the observed states.
    W : jnp.ndarray of shape (M, N, U)
        GLM weights for each hidden state.
    **kwargs
        Ignored; allows passing whole model dictionaries with ``**``.

    Returns
    -------
    w : jax_array of shape (num_seqs, T)
        Sampled hidden-state sequences.
    """
    # One (num_seqs, T) log-likelihood array per hidden state, stacked to
    # (M, num_seqs, T) and then moved so the state axis is last.
    per_state_ll = jax.lax.map(
        lambda weights: glm_log_likelihood(z, u, P, weights), W)
    per_state_ll = jnp.moveaxis(per_state_ll, 0, -1)

    seq_keys = jr.split(seed, mask.shape[0])

    # The transition matrix is shared across sequences (in_axes=None); keys,
    # log-likelihoods and masks are mapped over the leading sequence axis.
    sample_per_seq = jax.vmap(sample_hmm_stateseq, in_axes=(0, None, 0, 0))
    _, w = sample_per_seq(seq_keys, pi, per_state_ll, mask.astype(float))
    return w


def resample_model(data, seed, states, params, hypparams, states_only=False, **kwargs):
    """
    Resample the GLM-HMM model.

    Parameters
    ----------
    data : dict
        Data dictionary containing
        - `z` : jnp.ndarray of shape (num_seqs, T)
            Observed discrete states.
        - `u` : jnp.ndarray of shape (num_seqs, T, U)
            Observed external inputs.
        - `mask` : jnp.ndarray of shape (num_seqs, T)
            Mask for observed data.

    seed : jr.PRNGKey
        JAX random seed.

    states : dict
        State dictionary containing
        - `w` : jnp.ndarray of shape (num_seqs, T)
            Hidden states.

    params : dict
        Parameter dictionary containing
        - `pi` : jnp.ndarray of shape (M, M)
            Hidden state transition matrix.
        - `betas` : jnp.ndarray of shape (M,)
            Global concentration weights for the HDP prior over hidden state transitions.
        - `P` : jnp.ndarray of shape (N, N)
            Baseline transition matrix for the observed states.
        - `W` : jnp.ndarray of shape (M, N, U)
            GLM weights for each hidden state.

    hypparams : dict
        Hyperparameter dictionary (see `init_hyperparams`)

    states_only : bool, default=False
        Only resample states if True.

    Returns
    -------
    model : dict
        Dictionary containing the hyperparameters and
        updated seed, states, and parameters of the model.

    Notes
    -----
    ``states`` and ``params`` are updated in place and also returned.
    ``resample_glm_params`` is not defined in this cell; presumably it is
    defined elsewhere in the notebook (confirm).
    """
    # Use an independent key for every random step. Reusing one key across
    # steps would make their random draws correlated. The first key is
    # returned as the model's new seed and is not consumed here.
    seed, trans_seed, glm_seed, state_seed = jr.split(seed, 4)

    if not states_only:
        # The sticky-HDP prior is over the *hidden*-state transition matrix
        # `pi`, so transition counts must come from `w`. Splatting `data`
        # alone would pass the observed states as `z`. This assumes
        # `resample_hdp_transitions` reads the state sequence from its `z`
        # argument (as in jax_moseq) — confirm.
        trans_inputs = {**data, **states, **params, 'z': states['w']}
        params['betas'], params['pi'] = resample_hdp_transitions(
            trans_seed, **trans_inputs,
            **hypparams['trans_hypparams'])

        params['W'], params['P'] = resample_glm_params(
            glm_seed, **data, **states, **params,
            **hypparams['glm_hypparams'])

    states['w'] = resample_discrete_stateseqs(
        state_seed, **data, **states, **params)

    return {'seed': seed,
            'states': states,
            'params': params,
            'hypparams': hypparams}