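# Apply the trained AR-HMM to the rest of the dataset

This notebook proceeds in three steps:
1. Whiten every session's PC scores using the mean and covariance of the training PCs.
2. Run the fitted AR-HMM to sample a syllable sequence per session.
3. Save the results to `syllables.h5`, next to the full-dataset PC scores.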

In [1]:
from jax.config import config
config.update("jax_enable_x64", True)

import h5py
import joblib
import numpy as np
import jax.numpy as jnp
from pathlib import Path
from tqdm.auto import tqdm
from toolz import partial, valmap, partition_all
from jax_moseq.utils import batch, convert_data_precision, unbatch
from jax_moseq.models.arhmm.gibbs import resample_discrete_stateseqs

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Load training PCs and compute whitening parameters

In [2]:
def apply_whitening(data, L, mu):
    """Whiten frames with the affine map ``x -> L^{-1} (x - mu)``.

    Parameters
    ----------
    data : np.ndarray
        Frames to whiten; in this notebook a 2-D ``(n_frames, n_dims)`` array.
    L : np.ndarray
        ``(n_dims, n_dims)`` matrix to invert (the Cholesky factor returned by
        ``get_whitening_params``).
    mu : np.ndarray
        ``(n_dims,)`` mean subtracted from every frame before solving.

    Returns
    -------
    np.ndarray
        Whitened frames, same shape as ``data``.
    """
    centered = data - mu
    # Solve L @ w = (x - mu) for all frames at once: frames enter as the
    # columns of the right-hand side and are transposed back to rows.
    whitened_cols = np.linalg.solve(L, centered.T)
    return whitened_cols.T


def get_whitening_params(data_dict):
    """Estimate the mean and Cholesky factor used to whiten PC scores.

    Frames (first-axis entries) containing any NaN are dropped from every
    array in ``data_dict``. The remaining frames are pooled to compute the
    mean and the population covariance (normalized by N, ``bias=True``).

    Parameters
    ----------
    data_dict : dict
        Mapping of session key -> array whose first axis indexes frames
        (2-D ``(n_frames, n_dims)`` as loaded in this notebook).

    Returns
    -------
    mu : np.ndarray
        Mean over all NaN-free frames, shape ``(n_dims,)``.
    L : np.ndarray
        Lower-triangular Cholesky factor of the pooled covariance
        (``L @ L.T == Sigma``).

    Raises
    ------
    ValueError
        If ``data_dict`` is empty or contains no NaN-free frames.
    numpy.linalg.LinAlgError
        If the pooled covariance is not positive definite.
    """
    def drop_nan_frames(x):
        # Flatten everything after the frame axis so the NaN test works for
        # arrays of any rank.
        flat = np.reshape(x, (x.shape[0], -1))
        return x[~np.isnan(flat).any(axis=1)]

    valid = [drop_nan_frames(x) for x in data_dict.values()]
    if not valid:
        raise ValueError("data_dict is empty; cannot compute whitening parameters")
    pooled = np.concatenate(valid)
    if pooled.shape[0] == 0:
        # Without this check, mean/cov silently become NaN.
        raise ValueError("no NaN-free frames found; cannot compute whitening parameters")

    mu = pooled.mean(axis=0)
    Sigma = np.cov(pooled, rowvar=False, bias=True)
    L = np.linalg.cholesky(Sigma)
    return mu, L

In [3]:
pca_path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_01/_pca/pca_scores.h5')

# Training PC scores, keyed like the 'scores' group; only the first 10 PCs
# are kept (the whitening parameters are fit on these).
with h5py.File(pca_path, 'r') as h5f:
    pc_data = {session: scores[:, :10] for session, scores in h5f['scores'].items()}

In [4]:
# Whitening parameters (mean `mu`, lower Cholesky factor `L`) from the training PCs.
mu, L = get_whitening_params(data_dict=pc_data)

## Load new data in batches, whiten it, and apply the MoSeq model

In [5]:
# NOTE: joblib.load unpickles the file, which can execute arbitrary code —
# only load model files from a trusted location.
model = joblib.load(filename='/n/groups/datta/win/longtogeny/data/ontogeny/version_01/model_params.p')

In [6]:
# PC scores for the full dataset; syllables are written next to this file.
all_pcs_path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_01') / 'all_data_pca' / 'pca_scores.h5'

In [7]:
# Top-level keys of the loaded model dict.
list(model.keys())

['seed', 'states', 'params', 'hypparams']

In [16]:
# Whiten every session in the full dataset and sample AR-HMM syllable labels
# in batches of `batch_size` sessions. One dataset per session key is written
# to `syllables.h5`; the file is opened with mode 'w', so an existing one is
# overwritten.
# NOTE(review): model['seed'] is passed unchanged for every batch, so batches
# presumably draw from the same random stream — confirm this is intended.
batch_size = 25
# Slice the same number of PCs the whitening params were fit on, instead of
# repeating the hard-coded count from the training-data cell.
n_pcs = mu.shape[0]

with h5py.File(all_pcs_path, 'r') as h5f, h5py.File(all_pcs_path.with_name('syllables.h5'), 'w') as out_h5:
    scores = h5f['scores']
    for uuids in tqdm(list(partition_all(batch_size, scores))):
        # Batch-local names keep the training `pc_data` (the input to
        # get_whitening_params) intact, so re-running that cell stays correct.
        batch_pcs = {uuid: scores[uuid][:, :n_pcs] for uuid in uuids}
        batch_pcs = valmap(partial(apply_whitening, L=L, mu=mu), batch_pcs)

        batch_data = {}
        batch_data['x'], batch_data['mask'], labels = batch(batch_pcs)
        # Frames with any NaN PC get mask 0; NaN values are zero-filled so
        # they never reach the model.
        nan_vals = jnp.isnan(batch_data['x'])
        batch_data['mask'] = jnp.where(nan_vals.any(-1), 0, batch_data['mask'])
        batch_data['x'] = jnp.where(nan_vals, 0, batch_data['x'])
        batch_data = convert_data_precision(batch_data)
        batch_data['mask'] = batch_data['mask'].astype('int')

        stateseqs = resample_discrete_stateseqs(
            **batch_data, **model, **model['params'], **model['hypparams'], robust=True
        )

        # `z` maps session key -> syllable sequence. It stays bound to the
        # last batch so it can be inspected after the loop.
        z = unbatch(stateseqs, labels)
        for uuid, syllables in z.items():
            out_h5.create_dataset(uuid, data=np.asarray(syllables))

  0%|          | 0/193 [00:00<?, ?it/s]

In [15]:
# Inspect the syllable sequence of the first session in `z`.
next(iter(z.values()))

array([35, 35, 35, ..., 92, 92, 92], dtype=int32)