# Apply AR-HMM to the rest of the dataset

In [1]:
from jax.config import config
config.update("jax_enable_x64", True)

import h5py
import joblib
import numpy as np
import jax.numpy as jnp
from pathlib import Path
from tqdm.auto import tqdm
from toolz import partial, valmap, partition_all
from jax_moseq.utils import batch, convert_data_precision, unbatch
from jax_moseq.models.arhmm.gibbs import resample_discrete_stateseqs

In [2]:
# Dataset version; it selects the versioned data folder below.
version = 7
# Root folder for this version; the PCA score files and the model parameters
# are read from paths beneath it.
folder = Path(f'/n/groups/datta/win/longtogeny/data/ontogeny/version_{version:02d}')

## Load training PCs, compute whitening parameters

In [3]:
def apply_whitening(data, L, mu):
    """Center ``data`` by ``mu`` and whiten it with the factor ``L``.

    Solves ``L @ y = (data - mu).T`` for ``y`` and returns ``y.T``, so the
    result has one row per row of ``data``.

    Assumes ``data`` is 2-D (rows are observations, columns are dimensions)
    and ``L`` is square -- TODO confirm against callers.
    """
    centered = data - mu
    solved = np.linalg.solve(L, centered.T)
    return solved.T


def get_whitening_params(data_dict):
    """Estimate whitening parameters from a dictionary of arrays.

    Every array has the rows that contain any NaN removed; the remaining
    rows of all arrays are pooled, and the pooled mean and covariance are
    computed.

    Parameters
    ----------
    data_dict : dict
        Mapping whose values are arrays indexed by row along axis 0.

    Returns
    -------
    mu : numpy.ndarray
        Mean of the pooled rows.
    L : numpy.ndarray
        Lower Cholesky factor of the pooled covariance, which is computed
        with ``bias=1`` (normalised by N rather than N - 1).
    """

    def _drop_nan_rows(x):
        # Flatten trailing axes so a NaN anywhere in a row removes that row.
        flat = np.reshape(x, (x.shape[0], -1))
        return x[~np.isnan(flat).any(1)]

    pooled = np.concatenate([_drop_nan_rows(x) for x in data_dict.values()])

    mu = pooled.mean(0)
    Sigma = np.cov(pooled, rowvar=False, bias=1)
    L = np.linalg.cholesky(Sigma)

    return mu, L

In [4]:
# PCA scores used to estimate the whitening parameters.
pca_path = folder / '_pca/pca_scores.h5'

# Read every dataset under the 'scores' group, keeping only the first 10 columns.
with h5py.File(pca_path, 'r') as h5f:
    pc_data = {k: h5f['scores'][k][:, :10] for k in h5f['scores']}

In [5]:
# Mean and Cholesky factor estimated from the PCs loaded above; both are
# reused to whiten the new data in the cells below.
mu, L = get_whitening_params(pc_data)

## Load new data in batches, apply whitening and MoSeq model

In [6]:
# NOTE(review): joblib.load unpickles the file, so only load model files
# that come from a trusted source.
model = joblib.load(folder / 'model_params.p')

In [7]:
# PCA scores to run the model on; the output files below are written
# next to this file (via `with_name`).
all_pcs_path = folder / 'all_data_pca/pca_scores.h5'

In [8]:
# Display the top-level keys of the loaded model.
list(model)

['seed', 'states', 'params', 'hypparams']

In [9]:
# Number of entries under "scores" processed per batch.
batch_size = 80

syllables_file = all_pcs_path.with_name("syllables.h5")
# if the syllables file already exists, don't replace already computed data
mode = "a" if syllables_file.exists() else "w"
# Probe the file: if it cannot be opened for reading (OSError), fall back to
# write mode so a fresh file is created.
try:
    h5f = h5py.File(syllables_file, 'r')
    h5f.close()
except OSError:
    mode = "w"

with h5py.File(all_pcs_path, "r") as h5f, h5py.File(syllables_file, mode) as out_h5:
    if mode == "w":
        # Fresh output file: process every entry.
        seq = partition_all(batch_size, h5f["scores"])
    else:
        # Appending: skip entries that already have a dataset in the output file.
        seq = partition_all(batch_size, filter(lambda u: u not in out_h5.keys(), h5f["scores"]))
    # The batches are materialized into a list so tqdm can show a total.
    for uuids in tqdm(list(seq)):
        # Keep the first 10 columns and whiten them with the parameters
        # estimated earlier (mu, L).
        pc_data = {uuid: h5f["scores"][uuid][:, :10] for uuid in uuids}
        pc_data = valmap(partial(apply_whitening, L=L, mu=mu), pc_data)

        data = {}
        data["x"], data["mask"], lbls = batch(pc_data)
        # Zero the mask wherever any value along the last axis of x is NaN,
        # then replace the NaNs in x with 0.
        data["mask"] = jnp.where(jnp.isnan(data["x"]).any(-1), 0, data["mask"])
        data["x"] = jnp.where(jnp.isnan(data["x"]), 0, data["x"])
        data = convert_data_precision(data)
        data["mask"] = data["mask"].astype("int")

        # Resample the discrete state sequences using the loaded model; the
        # model's top-level entries, params and hypparams are all passed as
        # keyword arguments.
        z = resample_discrete_stateseqs(
            **data, **model, **model["params"], **model["hypparams"], robust=True
        )

        # Split the batched result back into per-label arrays (presumably
        # keyed by uuid -- confirm against `unbatch`) and write one dataset each.
        z = unbatch(z, lbls)
        for k, v in z.items():
            out_h5.create_dataset(k, data=v)

  0%|          | 0/1 [00:00<?, ?it/s]

## Compute AR likelihoods

In [10]:
import jax
from functools import partial
from jax_moseq.utils.autoregression import robust_ar_log_likelihood, get_nlags

In [11]:
def compute_likelihood(x, mask, Ab, Q, **kwargs):
    """Evaluate ``robust_ar_log_likelihood`` on ``x`` for every entry of ``Ab``.

    ``jax.lax.map`` iterates over the leading axis of ``Ab``, ``Q``,
    ``kwargs["nu"]`` and a per-entry copy of the mask, calling
    ``robust_ar_log_likelihood(x, (Ab_i, Q_i, nu_i, mask_i))`` each time.

    Parameters
    ----------
    x, mask : arrays
        Data and its mask; the first ``nlags`` entries along the mask's last
        axis are dropped before it is passed on.
    Ab, Q : arrays
        Parameters indexed along their leading axis (presumably one entry
        per discrete state -- TODO confirm).
    **kwargs
        Must contain ``"nu"``; any other keys are ignored.

    Returns
    -------
    The stacked outputs of ``robust_ar_log_likelihood``, with the mapped
    axis first.
    """
    nlags = get_nlags(Ab)
    # One repetition per entry of Ab along a new leading axis; all other
    # axes of the mask are left unrepeated.
    reps = (len(Ab),) + (1,) * len(mask.shape)
    mapped_args = (
        Ab,
        Q,
        kwargs["nu"],
        jnp.tile(mask[None, ..., nlags:], reps),
    )
    likelihood_fn = partial(robust_ar_log_likelihood, x)
    return jax.lax.map(likelihood_fn, mapped_args)

In [12]:
# Number of entries under "scores" processed per batch.
batch_size = 80

# NOTE(review): `syllables_file` is reused here, but it now points to the
# AR log-likelihood output file rather than the syllables file.
syllables_file = all_pcs_path.with_name("ar_log_likelihoods.h5")
# if the likelihoods file already exists, don't replace already computed data
mode = "a" if syllables_file.exists() else "w"
# Probe the file: if it cannot be opened for reading (OSError), fall back to
# write mode so a fresh file is created.
try:
    h5f = h5py.File(syllables_file, 'r')
    h5f.close()
except OSError:
    mode = "w"

with h5py.File(all_pcs_path, "r") as h5f, h5py.File(syllables_file, mode) as out_h5:
    if mode == "w":
        # Fresh output file: process every entry.
        seq = partition_all(batch_size, h5f["scores"])
    else:
        # Appending: skip entries that already have a dataset in the output file.
        seq = partition_all(batch_size, filter(lambda u: u not in out_h5.keys(), h5f["scores"]))
    # The batches are materialized into a list so tqdm can show a total.
    for uuids in tqdm(list(seq)):
        # Keep the first 10 columns and whiten them with the parameters
        # estimated earlier (mu, L).
        pc_data = {uuid: h5f["scores"][uuid][:, :10] for uuid in uuids}
        pc_data = valmap(partial(apply_whitening, L=L, mu=mu), pc_data)

        data = {}
        data["x"], data["mask"], lbls = batch(pc_data)
        # Zero the mask wherever any value along the last axis of x is NaN,
        # then replace the NaNs in x with 0.
        data['mask'] = jnp.where(jnp.isnan(data['x']).any(-1), 0, data['mask'])
        data['x'] = jnp.where(jnp.isnan(data['x']), 0, data['x'])
        data = convert_data_precision(data)
        data['mask'] = data['mask'].astype('int')

        likes = compute_likelihood(**data, **model, **model['params'], **model['hypparams'])
        # Move the leading axis returned by compute_likelihood to the end
        # before splitting the batch back into per-label arrays.
        likes = jnp.moveaxis(likes, 0, -1)
        likes = unbatch(likes, lbls)

        # Store each result as float32 with LZF compression.
        for k, v in likes.items():
            out_h5.create_dataset(k, data=v.astype('float32'), compression='lzf')

  0%|          | 0/46 [00:00<?, ?it/s]