# Code to train ARHMM

- find optimal kappa
- use kappa to fit full model

In [1]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.5
from jax.config import config
config.update("jax_enable_x64", True)

import h5py
import joblib
import numpy as np
from functools import partial
from collections import OrderedDict
from pathlib import Path
from toolz import valmap, valfilter
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp
from jax_moseq.models import arhmm
from jax_moseq.utils import batch, convert_data_precision

env: XLA_PYTHON_CLIENT_MEM_FRACTION=.5


In [2]:
def whiten_all(data_dict, center=True):
    """
    Whiten the PC Scores (with Cholesky decomposition) using all the data to compute the covariance matrix.

    Args:
    data_dict (OrderedDict): Training dataset
    center (bool): Indicates whether to center data by subtracting the mean PC score.

    Returns:
    data_dict (OrderedDict): Whitened training data dictionary
    """

    non_nan = lambda x: x[~np.isnan(np.reshape(x, (x.shape[0], -1))).any(1)]
    meancov = lambda x: (x.mean(0), np.cov(x, rowvar=False, bias=1))
    contig = partial(np.require, dtype=np.float64, requirements='C')

    mu, Sigma = meancov(np.concatenate(list(map(non_nan, data_dict.values()))))
    L = np.linalg.cholesky(Sigma)

    offset = 0. if center else mu
    apply_whitening = lambda x:  np.linalg.solve(L, (x-mu).T).T + offset

    return OrderedDict((k, contig(apply_whitening(v))) for k, v in data_dict.items())


def concatenate_stateseqs(stateseqs, mask=None):
    """
    Concatenate state sequences, optionally applying a mask.
    Parameters
    ----------
    stateseqs: dict or ndarray, shape (..., t)
        Dictionary mapping names to 1d arrays, or a single
        multi-dimensional array representing a batch of state sequences
        where the last dim indexes time
    mask: ndarray, shape (..., >=t), default=None
        Binary indicator for which elements of ``stateseqs`` are valid,
        e.g. when state sequences of different lengths have been padded.
        If ``mask`` contains more time-points than ``stateseqs``, the
        initial extra time-points will be ignored.
    Returns
    -------
    stateseqs_flat: ndarray
        1d array containing all state sequences 
    """
    if isinstance(stateseqs, dict):
        stateseq_flat = np.hstack(list(stateseqs.values()))
    elif mask is not None:
        stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1]:] > 0]
    else:
        stateseq_flat = stateseqs.flatten()
    return stateseq_flat


def get_durations(stateseqs, mask=None):
    """
    Get durations for a batch of state sequences. For a more detailed 
    description of the function parameters, see 
    :py:func:`keypoint_moseq.util.concatenate_stateseqs`
    Parameters
    ----------
    stateseqs: dict or ndarray of shape (..., t)
    mask: ndarray of shape (..., >=t), default=None
    Returns
    -------
    durations: 1d array
        The duration of each each state (across all state sequences)
    Examples
    --------
    >>> stateseqs = {
        'name1': np.array([1, 1, 2, 2, 2, 3]),
        'name2': np.array([0, 0, 0, 1])
    }
    >>> get_durations(stateseqs)
    array([2, 3, 1, 3, 1])
    """
    stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int)
    stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
    changepoints = np.diff(stateseq_padded).nonzero()[0]
    return changepoints[1:]-changepoints[:-1]

In [3]:
folder = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_03')

In [4]:
pca_path = folder / '_pca/pca_scores.h5'

with h5py.File(pca_path, 'r') as h5f:
    pc_data = {k: h5f['scores'][k][:, :10] for k in h5f['scores']}

In [5]:
pc_data = whiten_all(pc_data)

In [6]:
nan_threshold = 300  # frames
pc_data = valfilter(lambda v: np.isnan(v).any(1).sum() < nan_threshold, pc_data)

In [7]:
total_frames = sum(map(len, pc_data.values()))
total_frames

2752092

In [8]:
max_frames = 1.5e6
frames_per_session = int(max_frames // len(pc_data))

In [9]:
pc_data = valmap(lambda v: v[:frames_per_session], pc_data)

In [10]:
num_states = 100
nlags = 3

In [11]:
data = {}
data["x"], data["mask"], lbls = batch(pc_data)

non_nans = ~jnp.isnan(data['x']).any(-1)
mask = [jnp.roll(non_nans, shift) for shift in range(nlags + 1)]
mask = jnp.all(jnp.array(mask), axis=0)

data['mask'] = jnp.where(mask, data['mask'], 0)
del mask
del non_nans

data['x'] = jnp.where(jnp.isnan(data['x']), 0, data['x'])
data = convert_data_precision(data)
data['mask'] = data['mask'].astype('int')

In [12]:
kappas = np.logspace(5, 8, 9)

In [13]:
latent_dim = obs_dim = data['x'].shape[-1]
ar_hypparams = {
    'S_0_scale': .01,
    'K_0_scale': 10,
    'num_states': num_states,
    'nlags': nlags,
    'latent_dim': latent_dim
}

In [14]:
num_iters = 35

durations = {}
agg_durations = {}
for k in tqdm(kappas, desc='kappa scan'):
    ll_keys = ['z', 'x']
    ll_history = {key: [] for key in ll_keys}

    trans_hypparams = {
        'gamma': 1e3,
        'alpha': 5.7,
        'kappa': k,
        'num_states': num_states
    }

    model = arhmm.init_model(
        data,
        ar_hypparams=ar_hypparams,
        trans_hypparams=trans_hypparams,
        robust=True,
        verbose=True
    )

    pbar = tqdm(range(num_iters))

    _durs = []

    for i in pbar:
        # Perform Gibbs resampling
        model = arhmm.resample_model(data, **model)

        durs = get_durations(model['states']['z'], data['mask'])

        # Compute the likelihood of the data and
        # resampled states given the resampled params
        ll = arhmm.model_likelihood(data, **model)
        for key in ll_keys:
            ll_history[key].append(ll[key].item())
        pbar.set_description(f"LL: {ll['x']:0.2e} -- Dur {np.mean(durs) / 30:0.2f}s {np.median(durs) / 30:0.2f}s")
        _durs.append(np.mean(durs))
    durations[k] = (np.mean(durs) / 30, np.median(durs) / 30)
    agg_durations[k] = _durs

kappa scan:   0%|          | 0/9 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/35 [00:00<?, ?it/s]

In [15]:
best_duration = valmap(lambda v: np.abs(v[0] - 0.6), durations)
best_kappa = min(best_duration, key=best_duration.get)

In [16]:
best_duration

{100000.0: 0.3500198470277825,
 237137.37056616554: 0.29347060062621044,
 562341.3251903491: 0.23430025594292642,
 1333521.432163324: 0.17717050327239897,
 3162277.6601683795: 0.11916604110125173,
 7498942.093324558: 0.06275610420519873,
 17782794.100389227: 0.009200703198290916,
 42169650.342858225: 0.03680959178817855,
 100000000.0: 0.08816591139699637}

In [17]:
best_kappa

17782794.100389227

In [18]:
num_iters = 1000
ll_keys = ['z', 'x']
ll_history = {key: [] for key in ll_keys}

trans_hypparams = {
    'gamma': 1e3,
    'alpha': 5.7,
    'kappa': best_kappa,
    'num_states': num_states
}

model = arhmm.init_model(
    data,
    ar_hypparams=ar_hypparams,
    trans_hypparams=trans_hypparams,
    robust=True,
    verbose=True
)

pbar = tqdm(range(num_iters))

for i in pbar:
    # Perform Gibbs resampling
    model = arhmm.resample_model(data, **model)
    durs = get_durations(model['states']['z'], data['mask'])

    # Compute the likelihood of the data and
    # resampled states given the resampled params
    ll = arhmm.model_likelihood(data, **model)
    for key in ll_keys:
        ll_history[key].append(ll[key].item())
    pbar.set_description(f"LL: {ll['x']:0.2e} -- Dur {np.mean(durs) / 30:0.2f}s {np.median(durs) / 30:0.2f}s")

ARHMM: Initializing hyperparameters
ARHMM: Initializing parameters
ARHMM: Initializing states


  0%|          | 0/1000 [00:00<?, ?it/s]

In [19]:
joblib.dump(model, folder / 'model_params.p')

['/n/groups/datta/win/longtogeny/data/ontogeny/version_03/model_params.p']