# Explore and Model Behaviour distributions

## 0. Scope
This notebook performs a temporal analysis of the behaviour classifications using a Categorical Hidden Markov Model (CHMM).

### 0.1 Requirements
 * `SEGMENT_LIST`: DataFrame with Segment information
 * `BEHAVIOURS`: Annotation of Behaviours (for comparing against)
 * `MODELLING_DF`: Modelling Datasets as generated using `Build_Modelling_Dataset.ipynb`
 
### 0.2 Extent of Analysis
 * Models are trained using cross-validation over runs: each run of a cage serves as the validation fold once.

In [None]:
from mpctools.extensions import mplext, npext, utils, skext
from string import ascii_uppercase as sau
from IPython.display import display, HTML
from mpctools.parallel import ProgressBar
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib as mpl
import itertools as it
import seaborn as sns
import pandas as pd
import numpy as np
import time as tm
import joblib
import sys
import os

# Add the Project Directories to the path
sys.path.append('../../../../')

# Add own Tools
from Scripts.Constants import Const, CAGE_SHORTHAND
from Tools.Parsers import BORISParser

# Display Options: widen the notebook container and keep numpy printing compact
display(HTML("<style>.container { width:95% !important; }</style>"))
np.set_printoptions(linewidth=150, precision=4, suppress=True)

# Location Logic: True only when running on the 'MPCWork' host (selects N_JOBS and gates the plotting cells)
MPC_WORK = (os.uname().nodename == 'MPCWork')

In [None]:
# ======= Sources ======= #
RESULT_LOC = os.path.join(Const['Results.Scratch'], 'Modelling')
# NOTE(review): MODELLING_DF (listed under Requirements and read by the data-loading cell) is not defined
#   in this section; it must be provided before that cell runs.

# ==== Visualisations ==== #
# Behaviour keys (used to order columns) and their display names, as returned by BORISParser.BEHAVIOURS
BEH_ORDER = BORISParser.BEHAVIOURS(True).values()
BEH_NAMES = BORISParser.BEHAVIOURS(True, True).values()

COLOUR_MAPS = {'Z': 'Purples', 'R': 'Reds', 'G': 'Greens', 'B': 'Blues'}

INDIVIDUAL = True  # Selects the two-panel (HMM / HMM - MoC) per-cage figure in the |Z| optimisation

# Qualitative palette: the tab10 colours followed by Set1[5] and Dark2[5]
clrs = [*mpl.colormaps['tab10'].colors, mpl.colormaps['Set1'].colors[5], mpl.colormaps['Dark2'].colors[5]]
CMAP = mpl.colors.ListedColormap(clrs, 'Custom')

# Figure heights for the parameter plots, keyed by |Z|
VIS_HEIGHTS = {3: 3.9, 6: 5.7, 7: 6.3}

# ==== Experimental Conditions ==== #
SEGMENT_BTIS = 30 * 60  # BTIs per segment
RUN_SEGMENTS = 5  # How many segments per run

RESTARTS = 50
MAX_ITER = 150

Z_SIZES = (2, 3, 4, 5, 6, 7, 9, 11, 13)  # Latent sizes searched over
Z_VIS = (3, 7)  # Latent sizes whose parameters are visualised
Z_TEMP = 3  # Latent size used for the temporal visualisation
Z_PERM = 7  # Latent size used for the permutation analysis
CG_PERM = 4181380  # Cage used for the per-run permutation analysis

# Each mouse ordering (e.g. 'RBG') mapped to the index permutation that realises it (e.g. (0, 2, 1))
PERM_ORDER = {''.join(perm): tuple('RGB'.index(m) for m in perm) for perm in it.permutations('RGB')}

N_JOBS = 63 if not MPC_WORK else 8
RANDOM_STATE = 101

# ==== Execution Control ==== #
FIT_MODELS = False
OPTIMISE_Z = True
VISUALISE_PARAM = True
VISUALISE_TEMP = False
EXPLORE_PERMUTATIONS = False

In [None]:
# Output locations (created if missing): model results and saved figures
results_path, figures_path = (os.path.join(RESULT_LOC, sub) for sub in ('Analyse_CHMM', 'Figures'))
utils.make_dir(results_path); utils.make_dir(figures_path)

# ==== Functions ==== #
def min_max(vec):
    """Return the (minimum, maximum) of `vec`, using its own .min()/.max() methods."""
    lo = vec.min()
    hi = vec.max()
    return lo, hi

def reindex_run(rdf):
    """
    Re-index one run onto a contiguous, run-relative time axis padded to the full run length.

    The row index is expected to have five levels, where level 'Segment' (position 2) is the segment and
    position 3 is the BTI within that segment; the last level is renamed 'Mouse'. The run-relative time is
    (Segment - first Segment) * SEGMENT_BTIS + BTI.

    Unlike the previous implementation, the input frame is not modified (no helper column is written into it).

    @param rdf: DataFrame for a single run.
    @return: DataFrame indexed by (BTI, Mouse), with BTI covering 0 .. RUN_SEGMENTS * SEGMENT_BTIS - 1 for
             every mouse; missing BTIs are NaN rows.
    """
    first_seg = rdf.index.get_level_values('Segment').min()
    run_time = (rdf.index.get_level_values(2) - first_seg) * SEGMENT_BTIS + rdf.index.get_level_values(3)
    # Append the time directly as an index level instead of via a temporary column on the caller's frame
    rdf = rdf.set_index(pd.Index(run_time, name='Time'), append=True)
    rdf = rdf.droplevel((0, 1, 2, 3)).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
    # Pivot mice into columns so that reindexing pads every mouse to the full run length, then pivot back
    return rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS)).stack(-1, dropna=False)

# Define Function for running search over in parallel
#   Note that this uses the same random key (and hence generator) for all cages: this is ok since cages are independent, and we do not care about inter-cage similarity
#     - It also means that the same initialisation sequence is typically used for each latent parameter (which is to some extent also desirable)
def evaluate_percage(cid, X, fid, sZ):
    """
    Leave-one-run-out evaluation of a Categorical HMM for one cage and latent size, plus a fit on all runs.

    @param cid: Cage ID (only used to tag the returned result)
    @param X:   List of per-run arrays (each size [T, K, X]); each run is held out for validation exactly once.
    @param fid: Fold names aligned with X (fid[k] names the fold in which X[k] is the validation run)
    @param sZ:  Z dimension
    @return: ((cid, sZ), scores DataFrame indexed by (Fold, Dataset), dict of parameters/diagnostics of the
             model fitted on all runs)
    """
    # A single generator is shared by every fit below, so the order in which models are fitted matters
    rng = np.random.default_rng(RANDOM_STATE)
    init_psi = [[np.ones(7) * .5 for _ in range(sZ)] for _ in range(3)]

    def _fit(runs):
        model = skext.CategoricalHMM(sZ, (3, 7), init_psi=init_psi, tol=1e-5, max_iter=MAX_ITER, inits=RESTARTS, random_state=rng, n_jobs=0)
        return model.fit(runs)

    def _summarise(model, runs):
        return {'Folds': len(runs), 'Samples.All': np.sum([len(run) for run in runs]), 'LL': model.logpdf(runs)}

    scores = {}
    for k in range(len(X)):
        training = X[:k] + X[k + 1:]
        held_out = [X[k]]
        model = _fit(training)
        scores[fid[k]] = pd.DataFrame({'Train': _summarise(model, training), 'Validate': _summarise(model, held_out)})

    # Final model on every run: its scores and parameters are what gets reported
    final = _fit(X)
    scores['All'] = pd.DataFrame({'All': _summarise(final, X)})
    sys.stdout.write('>'); sys.stdout.flush()  # progress marker
    table = pd.concat(scores, axis=1, names=['Fold', 'Dataset']).T
    fitted = {'Pi': final.Pi, 'Psi': final.Psi, 'Omega': final.Omega, 'LL.Evo': final.Evolution, 'LL.Fin': final.Stability}
    return (cid, sZ), table, fitted

def show_omega(omega, ax):
    """
    Draw the latent transition matrix Omega as an annotated grid on `ax`.

    Cells show the transition probabilities rounded to 2 d.p. with the leading zero dropped (e.g. '.25').
    Exact zeros are left blank, and values that round to 1 are shown as '~1'. Each column label shows the
    state letter, its stationary probability (npext.markov_stationary) and its dwell time
    (npext.markov_dwell).

    @param omega: Square transition matrix; rows are drawn as Z[t] and columns as Z[t+1]
    @param ax:    Matplotlib axis to draw on
    """
    def _compact(txt):
        # Drop the leading zero of '0.xx' / '-0.xx' only: slicing unconditionally turned '1.00' into '.00'
        return txt.replace('0.', '.', 1) if txt.lstrip('-').startswith('0.') else txt

    def _fmt_val(v):
        if v == 0:
            return ''
        elif v == 1:
            return '~1'
        else:
            return _compact(f'{v:.2f}')

    # Prepare
    omega_stat = npext.markov_stationary(omega)
    omega_dwell = npext.markov_dwell(omega)
    y_labs = sau[:len(omega)]
    x_labs = [f'{a}\n{_compact(f"{s:.2f}")}\n{n:.0f}' for a, s, n in zip(sau, omega_stat, omega_dwell)]
    # np.vectorize gives the same object array of strings without DataFrame.applymap (deprecated in pandas 2.1)
    omega_vis = np.vectorize(_fmt_val, otypes=[object])(np.asarray(omega).round(2))
    # Plot: uniform grey background; all information is carried by the annotations
    sns.heatmap(np.zeros_like(omega), annot=omega_vis, cmap=[(0.9, 0.9, 0.9)], fmt='s', ax=ax, cbar=False, annot_kws={"size": 19}, linewidths=2,)
    ax.set_xticks(np.arange(0.5, len(x_labs) + 0.5)); ax.set_xticklabels(x_labs, rotation=0, fontsize=19)
    ax.set_yticks(np.arange(0.5, len(y_labs) + 0.5)); ax.set_yticklabels(y_labs, rotation=0, va="center", fontsize=19)
    ax.set_xlabel('$Z^{[t+1]}$', fontsize=22); ax.set_ylabel('$Z^{[t]}$', fontsize=22, rotation=0, labelpad=22, va='center')
    ax.set_title(r'$\Omega$', fontsize=22)
        
def show_psi_k(psi_k, ax, full_names=True, prc=2, k=None):
    """
    Draw one emission matrix Psi_k as an annotated Hinton diagram on `ax`.

    @param psi_k:      2D emission matrix for one observable. Columns are labelled with BEH_NAMES; rows
                       presumably index latent states (cf. init_psi in evaluate_percage) — confirm.
    @param ax:         Matplotlib axis to draw on
    @param full_names: If True, label behaviours with their full names (rotated), else with their first letter
    @param prc:        Decimal places shown in the annotations
    @param k:          Zero-based index of the observable, used for the X_{k+1} / Psi_{k+1} labels. If None,
                       the module-level variable `k` is used; this is the legacy behaviour that existing
                       callers looping with `for k, ax in ...` rely on. Prefer passing it explicitly.
    """
    if k is None:
        # Legacy: previously `k` was read implicitly from the notebook's global namespace
        k = globals()['k']

    def _fmt_val(v):
        if v == 0:
            return ''
        elif v == 1:
            return '~1'
        else:
            return f'{v:.{prc}f}'[1:]

    # np.vectorize gives the same object array of strings without DataFrame.applymap (deprecated in pandas 2.1)
    psi_k_annot = np.vectorize(_fmt_val, otypes=[object])(np.asarray(psi_k).round(prc))
    mplext.plot_matrix(psi_k, mode='hinton', show_val=psi_k_annot, x_labels=BEH_NAMES if full_names else [b[0] for b in BEH_NAMES], y_labels=None, fmt='s', fs=20, x_rot=90 if full_names else 0, buffer=0.6, ax=ax)
    ax.set_xlabel(f'$X_{k+1}$', fontsize=20, labelpad=4)
    if full_names: ax.xaxis.set_tick_params(pad=-2)
    # Raw f-string: '\P' is an invalid escape sequence in a normal string literal (SyntaxWarning on 3.12+)
    ax.set_title(rf'$\Psi_{k+1}$', fontsize=22)

In [None]:
if FIT_MODELS or VISUALISE_TEMP or EXPLORE_PERMUTATIONS:
    # Load the behaviour predictions for all three splits into a single frame.
    #   NOTE(review): MODELLING_DF is not defined in this notebook section; it must be set beforehand.
    sys.stdout.write('Loading Behaviour Predictions ... '); sys.stdout.flush()
    data = pd.concat(
        pd.read_pickle(os.path.join(MODELLING_DF, f'{split}.df'), compression='bz2')
        for split in ('Train', 'Validate', 'Test')
    ).drop([4188079, 4185530, 4154665], level=0)  # Remove Unused Cages
    print('Done')

--------

## 1. Fit Models

The idea is to search over architectures (latent-state dimensionality $|Z|$) for the CHMM, selecting based on the log-likelihood (LL).

In [None]:
if FIT_MODELS:
    print('Generating Statistics for various sZ ... (will take some time)')
    # Create Location (just in case)
    utils.make_dir(results_path)
    # Create Runs: per cage, one [T, 3, 7] array per run (T = RUN_SEGMENTS * SEGMENT_BTIS, mice R, G, B,
    #   behaviours in BEH_ORDER, i.e. the same layout used when evaluating/visualising the models below)
    print(f' + Generating folds per-Cage ... ', end='', flush=True); s = tm.time()
    X, fids = defaultdict(list), defaultdict(list)
    obs = defaultdict(dict)
    for cid, cdf in data.groupby(by='CageID'):
        non_nan = {}
        for rid, rdf in cdf['ALM.Prob'].groupby(by='Run'):
            first_seg = rdf.index.get_level_values('Segment').min()
            # `assign` returns a new frame, so the groupby slice itself is never written to (no SettingWithCopy)
            rdf = rdf.assign(Time=(rdf.index.get_level_values('Segment') - first_seg) * SEGMENT_BTIS + rdf.index.get_level_values('BTI'))
            rdf = rdf.set_index('Time', append=True).droplevel((0, 1, 2, 3)).unstack(0)  # columns: (Behaviour, Mouse)
            # Explicit behaviour ordering (as in every other cell) so Psi's behaviour axis matches BEH_NAMES
            rdf = rdf[BEH_ORDER].reorder_levels((1, 0), axis=1)[['R', 'G', 'B']]  # columns: (Mouse, Behaviour)
            non_nan[rid] = len(rdf.dropna(how='all'))  # BTIs with at least one observation
            rdf = rdf.reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))  # pad to the full run length with NaN
            X[cid].append(rdf.to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7))
            fids[cid].append(rid)
        # Resolve Observables: the fold validating on run `rid` trains on all other runs of the cage
        total = sum(non_nan.values())
        for rid, rs in non_nan.items():
            obs[cid][rid] = {'Train': total - rs, 'Validate': rs}
        obs[cid] = pd.DataFrame(obs[cid]).stack().reorder_levels((1, 0))
    obs = pd.concat(obs).rename('Samples.Observable').rename_axis(('CageID', 'Fold', 'Dataset'))
    print(f' Done! [{utils.show_time(tm.time() - s)}]')
    # Run Experiments (one job per (|Z|, cage); largest |Z| first)
    print(f' + Learning :{"         ".join("|"*int(1 + (len(X)*len(Z_SIZES))/10))}')
    print(f' + Learning : ', end='', flush=True); s = tm.time()
    results = joblib.Parallel(n_jobs=N_JOBS, prefer='processes')(joblib.delayed(evaluate_percage)(cid, X[cid], fids[cid], sZ) for sZ, cid in it.product(reversed(Z_SIZES), X.keys()))
    print(f'* Done! [{utils.show_time(tm.time() - s)}]')
    # Store
    stats = pd.concat({cz: cz_df for cz, cz_df, _ in results}, names=['CageID', '|Z|']).join(obs)
    stats.to_pickle(os.path.join(results_path, f'Scores.df'), compression='bz2')
    # Regroup parameters as params[cage][|Z|] in a single pass (previously quadratic in the number of results)
    params = defaultdict(dict)
    for (c_id, sz), _cz_df, cz_param in results:
        params[c_id][sz] = cz_param
    params = dict(params)
    joblib.dump(params, os.path.join(results_path, f'Params.jlib'), compress=True)

## 2. Optimise $|Z|$

### 2.1 Best Log-Likelihood
The idea is to compare log-likelihoods for various values of $|Z|$. Note that I report the normalised log-likelihood, i.e. the log-likelihood divided by the number of observable samples.

#### 2.1.1 Per-Cage

This is the version used in the thesis.

In [None]:
if OPTIMISE_Z and MPC_WORK:
    stats = pd.read_pickle(os.path.join(results_path, f'Scores.df'), compression='bz2')
    static = pd.read_pickle(os.path.join(RESULT_LOC, 'Optimise_MoC', f'Scores.MoC.df'), compression='bz2')[['LL']].join(stats['Folds'], how='right') # Get also static for comparison
    for cid, cstats in stats.groupby(by='CageID'):
        if INDIVIDUAL:
            # Top: CHMM normalised LL per fold; bottom: CHMM minus MoC. The x positions are indices into Z_SIZES
            #   (offset by -/+0.1 for Train/Validate), so ticks are placed at those indices, not at raw |Z| values.
            fig, axs = plt.subplots(2, 1, figsize=[8.3, 10], tight_layout=True, sharex=True, gridspec_kw={'height_ratios':[2,1]}); ax_t, ax_m = axs
            # Do HMM alone
            nll_t = (cstats['LL'] / cstats['Samples.Observable']).drop('All', level='Fold').to_frame('NLL').reset_index()
            nll_t['|Z|'] = nll_t['|Z|'].map(lambda z: Z_SIZES.index(z)) + nll_t['Dataset'].map({'Train': -0.1, 'Validate': 0.1})
            sns.lineplot(nll_t, x='|Z|', y='NLL', hue='Dataset', errorbar=None, linestyle=':', palette='tab10', marker='P', ms=12, ax=ax_t)
            sns.scatterplot(nll_t, x='|Z|', y='NLL', hue='Dataset', s=80, ax=ax_t)
            elem = [*ax_t.get_legend_handles_labels()[0][2:], *ax_t.get_lines()[:2]]; elem = [elem[l] for l in [0, 2, 1, 3]]
            ax_t.legend(handles=elem, labels=['Folds (T)', 'Mean (T)', 'Folds (V)', 'Mean (V)'], fontsize=17, ncol=4, loc=(0, 1.02), markerscale=1.2, handlelength=1.4, handletextpad=0.3, columnspacing=0.8)
            ax_t.tick_params(labelsize=20) 
            ax_t.set_ylabel(r'$\widehat{\mathcal{L}}_{HMM}$', fontsize=20); ax_t.set_xlabel('|Z|', fontsize=20)
            # Ticks at raw Z_SIZES would expand the shared x-limits up to max(Z_SIZES), leaving empty space
            ax_t.set_xticks(np.arange(len(Z_SIZES))); ax_t.set_xticklabels([f'|Z|={z}' for z in Z_SIZES], rotation=30, ha='right')
            # Now do HMM vs MoC
            nll_t = (cstats['LL'] / cstats['Samples.Observable']).drop('All', level='Fold').to_frame('NLL')
            nll_s = (static.loc[cid, 'LL'] / cstats.loc[cid, 'Samples.Observable']).drop('All', level='Fold').to_frame('NLL')
            join = (nll_t - nll_s).reset_index(); join['|Z|'] = join['|Z|'].map(lambda z: Z_SIZES.index(z)) + join['Dataset'].map({'Train': -0.1, 'Validate': 0.1})
            sns.scatterplot(join, x='|Z|', y='NLL', hue='Dataset', s=80, ax=ax_m, legend=None)
            ax_m.tick_params(labelsize=20); ax_m.set_ylabel(r'$\widehat{\mathcal{L}}_{HMM} - \widehat{\mathcal{L}}_{MoC}$', fontsize=20, labelpad=15); ax_m.set_xlabel('|Z|', fontsize=20)
            ax_m.set_xticks(np.arange(len(Z_SIZES))); ax_m.set_xticklabels([f'|Z|={z}' for z in Z_SIZES], rotation=30, ha='right', fontsize=20)
            plt.tight_layout(h_pad=0)
        else:
            # Create Fig/Axis (here x is the raw |Z| value, so ticks at Z_SIZES are correct)
            fig, ax_t = plt.subplots(1, 1, figsize=[8, 8], tight_layout=True)
            # Do Temporal
            nll_t = (cstats['LL'] / cstats['Folds']).drop('All', level='Fold').to_frame('NLL').reset_index()
            nll_t['|Z|'] += nll_t['Dataset'].map({'Train': -0.05, 'Validate': 0.05})
            sns.lineplot(nll_t, x='|Z|', y='NLL', hue='Dataset', errorbar=None, linestyle=':', palette='tab10', marker='P', ms=10, ax=ax_t)
            sns.scatterplot(nll_t, x='|Z|', y='NLL', hue='Dataset', s=40, ax=ax_t)
            ax_t.legend(handles=ax_t.get_lines(), labels=['Mean (T)', 'Mean (V)', 'Folds (T)', 'Folds (V)'], fontsize=14, title='HMM', title_fontsize=16, ncol=2, loc=(0, 1.02))
            ax_t.tick_params(labelsize=14) 
            # Now do Static (on alternative axis)
            ax_s = ax_t.twinx()
            nll_s = (static.loc[cid, 'LL']/static.loc[cid, 'Folds']).drop('All', level='Fold').groupby(by=['|Z|', 'Dataset']).mean().to_frame('NLL').reset_index()
            sns.lineplot(nll_s, x='|Z|', y='NLL', hue='Dataset', marker='^', ms=10, linestyle=':', palette='Dark2', ax=ax_s)
            ax_s.legend(handles=ax_s.get_lines(), labels=['Train', 'Validate'], fontsize=14, title='MoC', title_fontsize=16, ncol=2, loc=(0.55, 1.02))
            ax_s.tick_params(labelright=False, right=False); ax_s.set_ylabel(None)
            # Commonalities: both axes share the same y-range so HMM and MoC are directly comparable
            y_lim = (min(ax_t.get_ylim()[0], ax_s.get_ylim()[0]), max(ax_t.get_ylim()[1], ax_s.get_ylim()[1]))
            ax_t.set_ylabel('Normalised LL', fontsize=15); ax_t.set_xlabel('|Z|', fontsize=15)
            ax_t.set_ylim(*y_lim); ax_s.set_ylim(*y_lim)
            ax_t.set_xticks(Z_SIZES); ax_t.set_xticklabels([f'|Z|={z}' for z in Z_SIZES], rotation=30, ha='right')
        plt.savefig(os.path.join(figures_path, f'fig_model_chmm_ll_C{cid}.png'), bbox_inches='tight', dpi=150)
        plt.close()

#### 2.1.2 Overall

This is used in the paper to justify the choice of $|Z|=7$.

In [None]:
if OPTIMISE_Z and MPC_WORK:
    # Normalised LL (LL per observable sample) per cage/fold; seaborn averages these per |Z| and dataset
    stats = pd.read_pickle(os.path.join(results_path, f'Scores.df'), compression='bz2')
    fig, ax = plt.subplots(1, 1, figsize=[10, 3])
    per_fold = stats['LL'].div(stats['Samples.Observable']).drop('All', level=2)
    nll_t = per_fold.to_frame('NLL').reset_index()
    sns.lineplot(nll_t, x='|Z|', y='NLL', hue='Dataset', errorbar=None, linestyle=':', palette='tab10', marker='P', ms=14, lw=2.5, ax=ax)
    ax.tick_params(labelsize=20)
    ax.set_xlabel('|Z|', fontsize=20)
    ax.set_ylabel(r'Average $\widehat{\mathcal{L}}$', fontsize=20)
    ax.set_xticks(Z_SIZES)
    ax.set_xticklabels([str(z) for z in Z_SIZES])
    ax.legend(fontsize=20, loc=4, ncol=2)
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', 'fig_model_chmm_ll_CALL.png'), bbox_inches='tight', dpi=150)

### 2.2 Stability of Solution

(Distribution over random restarts)

In [None]:
if OPTIMISE_Z and MPC_WORK:
    # Compute: each model's stored 'LL.Fin' values divided by their own mean
    #   NOTE(review): the section header speaks of random restarts while the y-label says "across folds" — confirm.
    params = joblib.load(os.path.join(results_path, f'Params.jlib'))
    rel_ll = {
        (cid, sZ): zpar['LL.Fin'] / zpar['LL.Fin'].mean()
        for cid, cpar in params.items()
        for sZ, zpar in cpar.items()
    }
    stability = pd.DataFrame(rel_ll).T.rename_axis(index=('CageID', '|Z|'))
    stability = stability.stack().rename('LL').reset_index((0, 1))
    # Display
    fig, ax = plt.subplots(1, 1, figsize=[22, 6], tight_layout=True)
    sns.boxplot(stability, x='CageID', y='LL', hue='|Z|', ax=ax)
    ax.legend(fontsize=14, title='|Z|', title_fontsize=15, ncol=3)
    ax.tick_params(labelsize=14)
    ax.set_xlabel('Cage ID', fontsize=15)
    ax.set_ylabel('LL relative to Mean (across folds)', fontsize=15)

## 3. Visualise Parameters

In [None]:
if VISUALISE_PARAM and MPC_WORK:
    # One figure per (cage, |Z|): Omega followed by the three emission matrices Psi_1..Psi_3
    params = joblib.load(os.path.join(results_path, f'Params.jlib'))
    for cid, z in it.product(params, Z_VIS):
        z_par = params[cid][z]
        omega, psi = z_par['Omega'], z_par['Psi']
        fig, axs = plt.subplots(1, 4, figsize=[15 + z / 2, VIS_HEIGHTS[z]], tight_layout=True, gridspec_kw={'width_ratios': [z, 7, 7, 7]})
        show_omega(omega, axs[0])
        # The loop variable must stay named `k`: show_psi_k reads it (as a global) for its axis labels
        for k in range(3):
            ax = axs[1 + k]
            show_psi_k(psi[k, :, :], ax)
        plt.tight_layout(w_pad=0)
        plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_chmm_params_C{cid}_Z={z}.png'), bbox_inches='tight', dpi=150)
        plt.close()

## 4. Visualise Temporal

In [None]:
if VISUALISE_TEMP and MPC_WORK:
    # For every cage and run, plot the light status, the CHMM state probabilities (|Z| = Z_TEMP) and the
    #   per-mouse behaviour probabilities over time; one figure is saved per (run, cage).
    params = joblib.load(os.path.join(results_path, f'Params.jlib'))
    # NOTE(review): counts distinct Run IDs in `data`; this equals the number of figures only if Run IDs are
    #   unique across cages — confirm.
    progress = ProgressBar(len(data.index.unique('Run'))).reset('Plotting:')
    for cid, cdict in params.items():
        c_par = cdict[Z_TEMP]
        # Rebuild the CHMM from the parameters stored for the model fitted on all of this cage's runs
        chmm = skext.CategoricalHMM(c_par['Pi'], c_par['Psi'], c_par['Omega'])
        for rid, rdf in data.loc[cid].groupby(by='Run'):
            # Run-relative time: (Segment - first Segment) * SEGMENT_BTIS + BTI
            first_seg = rdf.index.get_level_values('Segment').min()
            rdf[('Sensors', 'Time')] = (rdf.index.get_level_values('Segment') - first_seg) * SEGMENT_BTIS + rdf.index.get_level_values('BTI')
            # Index becomes (BTI = run-relative time, Mouse)
            rdf = rdf.set_index(('Sensors', 'Time'), append=True).droplevel(('Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
            # Move mice into the columns and pad to the full run length (missing BTIs become NaN rows)
            rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
            # Extract Data
            #   r_X: columns (Mouse, Behaviour) with behaviours in BEH_ORDER and mice R, G, B
            r_X = rdf['ALM.Prob'][BEH_ORDER].reorder_levels((1, 0), 1)[['R', 'G', 'B']]
            #   r_Z: per-timestep state probabilities from predict_proba, transposed to [|Z|, T] for imshow
            r_Z = chmm.predict_proba([r_X.to_numpy().reshape(SEGMENT_BTIS * RUN_SEGMENTS, 3, 7).astype('float', order='C')])[0].T
            #   r_L: the light sensor column recorded against mouse R, as a [1, T] row
            r_L = rdf['Sensors']['Light']['R'].astype(float).to_numpy()[np.newaxis, :]
            # Set up Figure: light row, latent-state rows, then one panel per mouse
            fig, axs = plt.subplots(5, 1, figsize=[25, 5+Z_TEMP if Z_TEMP < 5 else 4+Z_TEMP], tight_layout=True, sharex=True, gridspec_kw={'height_ratios':[0.8, Z_TEMP if Z_TEMP < 5 else Z_TEMP-2, 5, 5, 5]})
            #  i) Plot Light Status
            #     (mpl.colormaps[...] returns a copy in Matplotlib >= 3.6, so set_bad does not alter the registered map)
            cmap = mpl.colormaps['gray']; cmap.set_bad('tan')
            axs[0].imshow(r_L, cmap=cmap, aspect='auto'); axs[0].set_yticks([0]); axs[0].set_yticklabels(['Light'], fontsize=22)
            # ii) Plot the sequence of Latent States
            cmap = mpl.colormaps[COLOUR_MAPS['Z']]; cmap.set_bad('gray')
            axs[1].imshow(r_Z, cmap=cmap, aspect='auto', interpolation='none')
            axs[1].set_yticks(np.arange(len(r_Z))); axs[1].set_yticklabels(sau[:len(r_Z)], fontsize=22)
            axs[1].set_ylabel('Z', fontsize=25, rotation=0, va='center', labelpad=80)
            # iii) Plot each of the three mouse behaviours.
            for k, m in enumerate('RGB'):
                cmap = mpl.colormaps[COLOUR_MAPS[m]]; cmap.set_bad('gray')
                m_X = r_X[m].to_numpy().T
                axs[2+k].imshow(m_X, cmap=cmap, aspect='auto', interpolation='none')
                axs[2+k].set_yticks(np.arange(7)); axs[2+k].set_yticklabels(BEH_NAMES, fontsize=22)
                axs[2+k].set_ylabel(f'$X_{k+1}$', fontsize=25, va='center', rotation=0, labelpad=20)
            # iv) Commonalities: a tick every 600 BTIs; the x1000 scaling for utils.time_list presumably
            #     converts to milliseconds (i.e. 1 BTI = 1 s) — confirm
            time_ticks = np.arange(0, RUN_SEGMENTS * SEGMENT_BTIS+1, 600)
            axs[-1].set_xticks(time_ticks); axs[-1].set_xticklabels(utils.time_list(time_ticks*1000., fmt='%H:%M'), fontsize=22, ha='center')
            axs[-1].set_xlabel('Time (Hrs:Mins)', fontsize=18)
            # Save Figure
            plt.tight_layout(h_pad=0)
            plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_chmm_temporal_R{rid}_C{cid}_Z{Z_TEMP}.png'), bbox_inches='tight', dpi=200)
            plt.close()
            progress.update()

## 5. Permutation Analysis

Explore how permuting the mice within a cage affects the log-likelihood.

 1. For each cage model:
    1. For each other cage (including itself)
        1. Permute over Mice and find LL
        
### 5.1 Generate Data

In [None]:
# First Generate the Data
# For every (model cage, data cage, mouse ordering), compute the CHMM log-likelihood (logpdf with norm=True).
#   Also compute a per-cage behaviour-frequency baseline; both are stored for the analyses below.
if EXPLORE_PERMUTATIONS:
    # First Generate
    params = joblib.load(os.path.join(results_path, f'Params.jlib'))
    lls_hmm = defaultdict(lambda : defaultdict(dict))  # lls_hmm[model cage][data cage][ordering] -> LL
    lls_bl = {}  # lls_bl[data cage] -> baseline LL
    # Total = models x data cages x orderings, plus two updates (data prep + baseline) per data cage.
    #   NOTE(review): assumes `data` contains the same set of cages as `params` — confirm.
    progress = ProgressBar(len(params) ** 2 * len(PERM_ORDER) + len(params)*2, prec=2).reset('Generating')
    for x_id, x_data in data.groupby(level=0):
        # 1) Prepare Data
        X = []
        for rid, rdf in x_data['ALM.Prob'].groupby('Run'):
            # Run-relative time: (Segment - first Segment) * SEGMENT_BTIS + BTI
            first_seg = rdf.index.get_level_values('Segment').min()
            rdf['Time'] = (rdf.index.get_level_values('Segment') - first_seg) * SEGMENT_BTIS + rdf.index.get_level_values('BTI')
            rdf = rdf.set_index('Time', append=True).droplevel(('CageID', 'Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
            # Move mice into the columns and pad to the full run length (missing BTIs become NaN rows)
            rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
            # [T, 3, 7] array: mice R, G, B; behaviours in BEH_ORDER
            X.append(rdf[BEH_ORDER].reorder_levels((1, 0), 1)[['R', 'G', 'B']].to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7))
        progress.update()
        # 2) Generate Permutation Data
        # Iterate over models
        for mdl_id, mdl_par in params.items():
            pars = mdl_par[Z_PERM]; chmm = skext.CategoricalHMM(pars['Pi'], pars['Psi'], pars['Omega'])
            # Iterate over Ordering: reorder the mouse axis of every run before scoring
            for name, order in PERM_ORDER.items():
                lls_hmm[mdl_id][x_id][name] = chmm.logpdf([x[:, order, :] for x in X], norm=True)
                progress.update()
        # 3) Generate Baseline Data
        # Fit Model: Just Value Counts over all data
        #   (mouse identity is ignored: all mice and BTIs are pooled, keeping only fully finite rows)
        x_anon = np.vstack([x.reshape(-1, 7) for x in X])
        x_anon = x_anon[np.isfinite(x_anon).all(axis=1), :]
        # NOTE(review): a behaviour with zero total mass gives log(0) = -inf, and 0 * -inf = nan in the product below
        bl_mdl = np.log(npext.sum_to_one(x_anon.sum(axis=0)))
        # Now Compute Log-Likelihood: just multiply and sum (then average over rows)
        lls_bl[x_id] = (x_anon @ bl_mdl).mean()
        progress.update()

    # Stores
    lls_bl = pd.Series(lls_bl, name='Baseline')
    lls_bl.to_pickle(os.path.join(results_path, f'LL.BL.Same.df'), compression='bz2')
    # Rows: model cage; columns: (data cage, ordering)
    lls_hmm = pd.concat({m_id: pd.DataFrame(clls) for m_id, clls in lls_hmm.items()}, axis=0).unstack(-1)
    lls_hmm = lls_hmm.rename_axis(index=('Model',), columns=('CageID', 'Permutation'))
    lls_hmm.to_pickle(os.path.join(results_path, f'LL.CHMM.Z{Z_PERM}.df'), compression='bz2')

### 5.2 Permutations within Cage

This is to show the within-cage individuality of the mice.

In [None]:
if EXPLORE_PERMUTATIONS:
    # Now Visualise
    lls_hmm = pd.read_pickle(os.path.join(results_path, f'LL.CHMM.Z{Z_PERM}.df'), compression='bz2')
    lls_bl = pd.read_pickle(os.path.join(results_path, f'LL.BL.Same.df'), compression='bz2').rename(CAGE_SHORTHAND)
    # Each cage's own model on its own data: rows are orderings, columns are cages (short names)
    own = {CAGE_SHORTHAND[cid]: lls_hmm.loc[cid, cid] for cid in lls_hmm.index}
    ll_pc = pd.concat(own, axis=1).sub(lls_bl)  # Subtract Baseline
    # Percentage drop relative to the original 'RGB' ordering; transposed so rows are cages
    ref = ll_pc.loc['RGB']
    ll_pc = ll_pc.rsub(ref).div(ref).T * 100
    clr_rng = (0, ll_pc.to_numpy().max())

    def _heat(ax, mat, **labels):
        # Shared heatmap styling for all three panels (common colour range)
        mplext.plot_matrix(mat, mode='heatmap', min_max=clr_rng, show_val=True, fs=17, cax=False, ax=ax, hm_args={'cmap': 'Blues'}, **labels)

    # Create as Figure: full table, then per-cage mean and standard deviation
    fig, axs = plt.subplots(1, 3, figsize=[7.5, 7], tight_layout=True, sharey=False, gridspec_kw={'width_ratios': [6, 1, 1]})
    _heat(axs[0], ll_pc.to_numpy(), x_labels=ll_pc.columns, y_labels=ll_pc.index)
    axs[0].set_xlabel('Permutation', fontsize=18)
    axs[0].set_ylabel('Cage ID', fontsize=18)
    for ax, summary, title in zip(axs[1:], (ll_pc.mean(axis=1), ll_pc.std(axis=1)), ('Mean', 'StDev')):
        _heat(ax, summary.to_numpy()[:, np.newaxis], x_labels=[], y_labels=[])
        ax.set_xlabel(title, fontsize=18, labelpad=20)
    # Save
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_anomaly_per_cage_Z{Z_PERM}.png'), bbox_inches='tight', dpi=150)
    # Print
    print(f'Overall Stats: Mean={ll_pc.to_numpy().mean():.1f}')

### 5.3  [OBSOLETE] 

#### 5.3.2 Per-Run Log-Likelihood

In [None]:
if EXPLORE_PERMUTATIONS:
    # Build Data: per-run [T, 3, 7] arrays (mice R, G, B; behaviours in BEH_ORDER) for cage CG_PERM
    X, r = [], []  # r holds the Run ID of each entry in X
    for rid, rdf in data.loc[CG_PERM, 'ALM.Prob'].groupby('Run'):
        # Run-relative time: (Segment - first Segment) * SEGMENT_BTIS + BTI
        first_seg = rdf.index.get_level_values('Segment').min()
        rdf['Time'] = (rdf.index.get_level_values('Segment') - first_seg) * SEGMENT_BTIS + rdf.index.get_level_values('BTI')
        rdf = rdf.set_index('Time', append=True).droplevel(('Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
        # Move mice into the columns and pad to the full run length (missing BTIs become NaN rows)
        rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
        X.append(rdf[BEH_ORDER].reorder_levels((1, 0), 1)[['R', 'G', 'B']].to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7))
        r.append(rid)
    
    # Evaluate per-Run: each cage's model (|Z| = Z_PERM) scores CG_PERM's runs. For each model, the mouse ordering
    #   used is the one with the highest stored LL on cage CG_PERM (from the permutation results file).
    params = joblib.load(os.path.join(results_path, f'Params.jlib'))
    lls_best = pd.read_pickle(os.path.join(results_path, f'LL.CHMM.Z{Z_PERM}.df'), compression='bz2')[CG_PERM].idxmax(axis=1).map(PERM_ORDER)
    lls = {}
    progress = ProgressBar(len(params), prec=2).reset('Generating')
    for mdl_id, mdl_par in params.items():
        pars = mdl_par[Z_PERM]; chmm = skext.CategoricalHMM(pars['Pi'], pars['Psi'], pars['Omega'])
        lls[mdl_id] = pd.Series(chmm.logpdf([x[:, lls_best[mdl_id], :] for x in X], per_run=True, norm=True), index=r, name='Run')
        progress.update()
    
    # Store: rows are models, columns are runs
    lls = pd.concat(lls, axis=1).T.rename_axis(index='Model', columns='Run')
    lls.to_pickle(os.path.join(results_path, f'LL_PRUN.CHMM.C{CG_PERM}.Z{Z_PERM}.df'), compression='bz2')

In [None]:
if EXPLORE_PERMUTATIONS:
    # Heatmap of per-run LL for every model on cage CG_PERM, with the across-model mean underneath
    lls = pd.read_pickle(os.path.join(results_path, f'LL_PRUN.CHMM.C{CG_PERM}.Z{Z_PERM}.df'), compression='bz2').rename(index=CAGE_SHORTHAND)
    fig, (ax_run, ax_mean) = plt.subplots(2, 1, figsize=(5.5, 7), sharex='col', tight_layout=True, gridspec_kw={'height_ratios': [12, 1]})
    heat_kw = dict(mode='heatmap', show_val=True, fs=17, fmt='.2f', cax=False)
    # Per-Run Plot
    mplext.plot_matrix(lls.to_numpy(), y_labels=lls.index, ax=ax_run, hm_args={'cmap': 'Blues_r'}, **heat_kw)
    ax_run.set_ylabel('Model', fontsize=17)
    # Mean across Models (per-Run)
    run_mean = lls.mean(axis=0).to_numpy()[np.newaxis, :]
    mplext.plot_matrix(run_mean, x_labels=lls.columns, y_labels=['Mean'], ax=ax_mean, hm_args={'cmap': 'Blues_r'}, **heat_kw)
    ax_mean.set_xlabel('Run', fontsize=17)
    # Common and Save
    plt.tight_layout(h_pad=0.3)
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_anomaly_per_run_C{CG_PERM}_Z{Z_PERM}.png'), bbox_inches='tight', dpi=150)