# Explore and Model Behaviour distributions

## 0. Scope
This notebook analyses the CHMM model (omnibus and individual) on the Q2 data.

### 0.1 Requirements
 * `MODELLING_Q2`: Modelling Datasets as generated using `Build_Modelling_Dataset.ipynb` for Q2
 * `PARAMS_NORMAL`: Model of Normality (best model)
 * `SCORES_NORMAL`: Scores of the model on the normal cages

In [None]:
from mpctools.extensions import utils, skext, npext, mplext
from string import ascii_uppercase as sau
from mpctools.parallel import ProgressBar
from IPython.display import display, HTML
from sklearn import metrics as skmetrics
import matplotlib.pyplot as plt
import scipy.stats as scstats
import matplotlib as mpl
import itertools as it
import seaborn as sns
import pandas as pd
import numpy as np
import time as tm
import joblib
import sys
import os

# Add the Project Directories to the path
sys.path.append('../../../../')

# Add own Tools
from Scripts.Constants import Const, CAGE_SHORTHAND
from Tools.Parsers import BORISParser

# Location Logic: flag whether the notebook is running on the 'MPCWork' machine
MPC_WORK = (os.uname().nodename == 'MPCWork')

# Display Options: compact numpy printing and a wider notebook container, on that machine only
if MPC_WORK:
    np.set_printoptions(precision=4, linewidth=150, suppress=True)
    display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
# ======= Sources ======= #
MODELLING_Q2 = '/media/veracrypt4/Q2/Modelling/Outlier.df'
RESULT_LOC = os.path.join(Const['Results.Scratch'], 'Modelling')
FIGURES = os.path.join(RESULT_LOC, 'Figures'); utils.make_dir(FIGURES)
RESULTS = os.path.join(RESULT_LOC, 'Analyse_HoldOut'); utils.make_dir(RESULTS)

# Model of normality (best model of the global analysis): its file names are keyed on its |Z|
Z_NORMAL = 7
PARAMS_NORMAL = os.path.join(RESULT_LOC, 'Analyse_Global', f'Params_Best.Global.Z{Z_NORMAL}.jlib')
SCORES_NORMAL = os.path.join(RESULT_LOC, 'Analyse_Global', f'Scores_Best.Global.Z{Z_NORMAL}.df')
BASELINE_MDL = os.path.join(RESULT_LOC, 'Analyse_Global', 'Baseline.Global.df')

# ==== Visualisations/Setups ==== #
# Materialised as lists (rather than dict views): they are re-used for column selection and labelling throughout
BEH_ORDER = list(BORISParser.BEHAVIOURS(True).values())
BEH_NAMES = list(BORISParser.BEHAVIOURS(True, True).values())
# Every ordering of the three mice, mapped to the matching index order, e.g. 'GRB' -> (1, 0, 2)
PERM_ORDER = {''.join(name): order for name, order in zip(it.permutations('RGB', 3), it.permutations((0, 1, 2), 3))}

MSE_CMAP = mpl.colors.ListedColormap(['red', 'green', 'blue'], 'RGB'); MSE_CMAP.set_bad('gainsboro')
COLOUR_MAPS = {'Z': 'Purples', 'R': 'Reds', 'G': 'Greens', 'B': 'Blues'}

VIS_HEIGHTS = {3:3.9, 6:5.7, 7:6.3}  # Figure height for each |Z| shown in the parameter plots

# ==== Experimental Conditions ==== #
SEGMENT_BTIS = 30*60  # Number of BTIs in one segment
RUN_SEGMENTS = 5 # How many segments per run

MAX_ITER = 150
RESTARTS = 50

Z_SEARCH = (2, 3, 4, 5, 6, 7)
Z_PARAMS = (3, 6, 7)
Z_TEMP = 6

N_JOBS = 8 if MPC_WORK else 60
RANDOM_STATE = 101
CONVERGE_TOL = 1e-5

# ==== Execution Control ==== #
ANALYSE_NORMALITY = True
ANALYSE_GLOBAL_Q2 = True

FIT_MODELS = False

FOR_PAPER = True

In [None]:
def reindex_run(rdf):
    """Re-index one run of per-mouse data onto a single, contiguous run-level time axis.

    The time of each sample is ``(Segment - first Segment) * SEGMENT_BTIS + BTI``, i.e. the segments of the
    run are laid end-to-end. The mice are then unstacked into the columns, and the rows are reindexed onto
    ``0 .. RUN_SEGMENTS * SEGMENT_BTIS - 1``. BTIs that are absent become NaN rows.

    Parameters
    ----------
    rdf : pd.DataFrame
        Rows of a single run. The index must hold the levels 'CageID', 'Run', 'Segment', 'BTI' plus exactly
        one further level, which becomes 'Mouse'. The frame is not modified.

    Returns
    -------
    pd.DataFrame
        Indexed by run-level time (named 'BTI'). Columns are (Mouse, behaviour), restricted to BEH_ORDER, with
        the mice ordered R, G, B.
    """
    segments = rdf.index.get_level_values('Segment')
    run_time = (segments - segments.min()) * SEGMENT_BTIS + rdf.index.get_level_values('BTI')
    # Append the time directly as a new index level, not through a temporary column. This way the caller's
    # frame (typically a groupby slice) is never written to, which also avoids SettingWithCopyWarning.
    rdf = rdf.set_index(pd.Index(run_time, name='Time'), append=True)
    rdf = rdf.droplevel(('CageID', 'Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
    rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
    return rdf[BEH_ORDER].reorder_levels((1, 0), 1)[['R', 'G', 'B']]

def optimise_perm(X_c, mdl, normalise=False):
    """Pick the mouse ordering (a permutation of R/G/B) under which a cage's data is most likely.

    Each permutation in PERM_ORDER is applied to axis 1 of every run array in ``X_c``. The re-ordered runs are
    scored with ``mdl.logpdf(..., norm=normalise)``.

    Returns
    -------
    tuple
        (permutation name, e.g. 'GRB'; the index order applied to axis 1; the log-likelihood it achieved)
    """
    scores = {
        name: mdl.logpdf([run[:, idx, :] for run in X_c], norm=normalise)
        for name, idx in PERM_ORDER.items()
    }
    winner = pd.Series(scores).idxmax()
    return winner, PERM_ORDER[winner], scores[winner]

def check_converged(lls, tol=None):
    """Report whether a log-likelihood history has converged.

    Parameters
    ----------
    lls : sequence of float
        History of log-likelihoods, one per fitting step, oldest first.
    tol : float, optional
        Threshold on the relative change between the last two values. Defaults to the notebook-level
        ``CONVERGE_TOL``.

    Returns
    -------
    bool
        False if fewer than two values are available. False (with a warning) if the last step *decreased*
        the log-likelihood. Otherwise ``|(lls[-1] - lls[-2]) / lls[-2]| < tol``.
    """
    # `warnings` is not imported at the top of the notebook; without this the drop branch raises NameError.
    import warnings

    if len(lls) < 2:
        return False
    if lls[-1] < lls[-2]:
        # NOTE: a persistent drop keeps returning False. A `while not check_converged(...)` loop will
        # therefore not terminate by itself in that situation.
        warnings.warn("Drop in Log-Likelihood Observed! Results are probably wrong.")
        return False
    if tol is None:
        tol = CONVERGE_TOL
    return abs((lls[-1] - lls[-2]) / lls[-2]) < tol

def train_from_cage(X, c_id, z):
    """Fit a global CHMM on all cages, initialised from a model trained on a single cage.

    Steps:
    1. Fit a CategoricalHMM with ``z`` latent states on cage ``c_id`` alone.
    2. Alternate between (i) choosing, for every cage, the R/G/B mouse permutation that maximises the current
       model's likelihood (``optimise_perm``) and (ii) a partial fit on all runs re-ordered by those
       permutations.
    3. Stop once ``check_converged`` reports convergence of the partial-fit log-likelihoods.

    Parameters
    ----------
    X : dict
        Maps CageID -> list of per-run arrays, each indexed as ``[time, mouse, behaviour]``.
    c_id : hashable
        Key into ``X`` of the cage used to initialise the model.
    z : int
        Number of latent states (|Z|).

    Returns
    -------
    tuple
        ``((z, c_id), pd.Series {'LL': final un-normalised log-likelihood}, parameters, evolutions)``.
        ``parameters['Perm']`` maps CageID -> ``(permutation name, index order, log-likelihood)`` as returned
        by ``optimise_perm``.
    """
    # A. Setup
    rng = np.random.default_rng(RANDOM_STATE)  # Default Random Number Generator
    init_psi = [[np.ones(7) * .5 for _ in range(z)] for _ in range(3)]  # Initialiser
    ll, perms = [], {}
    # B. Train Single Cage Model
    mdl = skext.CategoricalHMM(z, (3, 7), init_psi=init_psi, tol=CONVERGE_TOL, max_iter=MAX_ITER, inits=RESTARTS, random_state=rng, n_jobs=0).fit(X[c_id])
    # C. Train Global Model
    # C.I Alternate permutation optimisation and model fitting until convergence
    while not check_converged(ll):
        # C.I.1 Find Best Permutation for each cage under the current model
        perm_pcage = {cid: optimise_perm(X_c, mdl) for cid, X_c in X.items()}
        utils.extend_dict(perms, {k: v[0] for k, v in perm_pcage.items()})
        # C.I.2 Organise data according to this permutation
        W_tr = [x[:, p_c[1], :] for _, (x_c, p_c) in utils.dzip(X, perm_pcage) for x in x_c]
        # C.I.3 Fit Model on all data
        mdl.fit_partial(W_tr)
        ll.append(mdl.Evolution[-1])
    # C.II Run final Permutation optimisation. Keep the full (name, order, LL) tuples so that the per-cage
    #      log-likelihoods can be summed below and stored alongside the chosen permutation.
    perm_final = {cid: optimise_perm(X_c, mdl, False) for cid, X_c in X.items()}
    utils.extend_dict(perms, {k: v[0] for k, v in perm_final.items()})
    ll_final = np.sum([l for (_, _, l) in perm_final.values()])
    # D. Report Progress and return
    sys.stdout.write('>'); sys.stdout.flush()
    return (
        (z, c_id),                                                              # Indexing
        pd.Series({'LL': ll_final}, name=(z, c_id)),                            # Log-Likelihood (un-normalised)
        {'Pi': mdl.Pi, 'Psi': mdl.Psi, 'Omega': mdl.Omega, 'Perm': perm_final}, # Model Parameters
        {'Perm': perms, 'LL': ll}                                               # Evolutions
    )
        
def show_omega(omega, ax, stats=True):
    """Draw the latent-state transition matrix Omega as an annotated grid on ``ax``.

    Every entry is rounded to 2 decimals and printed without its leading zero. Entries that round to 0 are
    left blank, and entries that round to 1 are shown as '~1'. Rows are the state at t and columns the state
    at t+1, both labelled A, B, C, ...

    Parameters
    ----------
    omega : np.ndarray
        Square transition matrix.
    ax : matplotlib Axes
        Target axes.
    stats : bool
        If True, each column label also shows that state's stationary probability
        (``npext.markov_stationary``) and dwell time (``npext.markov_dwell``). If False, the column labels are
        the plain letters and the grid is drawn with equal aspect.
    """
    def _fmt_val(v):
        if v == 0:
            return ''
        elif v == 1:
            return '~1'
        else:
            return f'{v:.2f}'[1:]  # drop the leading '0'
    # Prepare
    if stats:
        omega_stat = npext.markov_stationary(omega)
        omega_dwell = npext.markov_dwell(omega)
        x_labs=[f'{a}\n'+f'{s:.2f}\n'[1:]+f'{n:.0f}' for a, s, n in zip(sau, omega_stat, omega_dwell)]
    else:
        x_labs=sau[:len(omega)]
    y_labs=sau[:len(omega)]
    # Format every entry (diagonal included). np.vectorize replaces DataFrame.applymap, which is deprecated
    # since pandas 2.1; the object dtype keeps the strings untruncated.
    omega_vis = np.vectorize(_fmt_val, otypes=[object])(omega.round(2))
    # Plot: a uniform grey background, with the values drawn as annotations
    sns.heatmap(np.zeros_like(omega), annot=omega_vis, cmap=[(0.9, 0.9, 0.9)], fmt='s', ax=ax, cbar=False, annot_kws={"size": 19}, linewidths=2,)
    ax.set_xticks(np.arange(0.5, len(x_labs) + 0.5)); ax.set_xticklabels(x_labs, rotation=0, fontsize=19)
    ax.set_yticks(np.arange(0.5, len(y_labs) + 0.5)); ax.set_yticklabels(y_labs, rotation=0, va="center", fontsize=19)
    ax.set_xlabel('$Z^{[t+1]}$', fontsize=22); ax.set_ylabel('$Z^{[t]}$', fontsize=22, rotation=0, labelpad=22, va='center')
    ax.set_title(r'$\Omega$', fontsize=22)
    if not stats:
        ax.set_aspect('equal')
        
def show_psi_k(psi_k, ax, full_names=True, prc=2, axis_labels=False, k=None):
    """Draw the emission matrix of one mouse as a Hinton diagram on ``ax``.

    Parameters
    ----------
    psi_k : np.ndarray
        Emission probabilities with one row per latent state and one column per behaviour (BEH_NAMES).
    ax : matplotlib Axes
        Target axes.
    full_names : bool
        Use full behaviour names (rotated) on the x-axis; otherwise only their first letters.
    prc : int
        Decimals shown in the annotations; values <= 0 disable the annotations.
    axis_labels : bool
        Label the rows with the state letters (A, B, ...) and the y-axis with '$Z$'.
    k : int, optional
        Zero-based mouse index, used for the '$X_{k+1}$' / '$\\Psi_{k+1}$' labels. For backward compatibility,
        when omitted the notebook-global ``k`` is used, which is what the original implementation relied on
        implicitly.

    Raises
    ------
    TypeError
        If ``k`` is omitted and no global ``k`` exists.
    """
    def _fmt_val(v):
        if v == 0:
            return ''
        elif v == 1:
            return '~1'
        else:
            return f'{v:.{prc}f}'[1:]  # drop the leading '0'

    if k is None:
        try:
            k = globals()['k']
        except KeyError:
            raise TypeError("show_psi_k() needs the mouse index: pass k=...") from None
    # np.vectorize replaces DataFrame.applymap (deprecated since pandas 2.1); object dtype keeps the strings whole
    psi_k_annot = np.vectorize(_fmt_val, otypes=[object])(psi_k.round(prc)) if prc > 0 else False
    x_labels = BEH_NAMES if full_names else [b[0] for b in BEH_NAMES]
    y_labels = sau[:psi_k.shape[0]] if axis_labels else None
    mplext.plot_matrix(psi_k, mode='hinton', show_val=psi_k_annot, x_labels=x_labels, y_labels=y_labels, fmt='s', fs=20, x_rot=90 if full_names else 0, buffer=0.6, ax=ax)
    ax.set_xlabel(f'$X_{k+1}$', fontsize=20, labelpad=4)
    if axis_labels:
        ax.set_ylabel('$Z$', fontsize=22, rotation=0, va="center", labelpad=10)
    if full_names:
        ax.xaxis.set_tick_params(pad=-2)
    # Raw f-string: '\P' is not a valid escape sequence
    ax.set_title(rf'$\Psi_{k+1}$', fontsize=22)
        
def summarise_omega(omega, ax):
    """Tabulate, for each latent state of ``omega``, its stationary probability ('SS') and dwell time ('DT').

    Uses ``npext.markov_stationary`` and ``npext.markov_dwell``. Stationary probabilities are rounded to
    2 decimals and printed without the leading zero ('~0' when they round to 0). Dwell times are rounded to
    whole numbers.
    """
    def _fmt_prob(p):
        return '~0' if p == 0 else f'{p:.2f}'[1:]

    stationary = pd.Series(npext.markov_stationary(omega).round(2)).apply(_fmt_prob).to_numpy()
    dwell = npext.markov_dwell(omega).round(0).astype(int).astype(str)
    table = np.vstack([stationary, dwell]).T
    sns.heatmap(np.zeros(table.shape), annot=table, cmap=[(0.9, 0.9, 0.9)], fmt='s', ax=ax, cbar=False, annot_kws={"size": 19}, linewidths=2,)
    ax.set_xticks([0.5, 1.5])
    ax.set_xticklabels(['SS', 'DT'], rotation=0, fontsize=19)
    ax.set_yticks([])
    ax.set_yticks([])

In [None]:
# Load the main modelling data (bz2-compressed pickle)
print('Loading Behaviour Predictions ... ', end='', flush=True)
data = pd.read_pickle(MODELLING_Q2, compression='bz2')
# Keep track of sizes: number of complete (NaN-free) samples per cage
x_len = (
    data['ALM.Prob'][BEH_ORDER]
    .dropna(how='any')
    .groupby('CageID')
    .size()
    .rename('Samples')
)
print('Done')

## 1. Analyse Normal Model on Q2 Data Globally

The aim here is to compute the likelihood scores of the new (Q2) cages under the model of normality, after fitting the best mouse permutation for each cage.

### 1.1 Generate Scores

In [None]:
if ANALYSE_NORMALITY and FIT_MODELS:
    # 1. Rebuild the model of normality from its stored parameters, and load the baseline model
    mdl = joblib.load(PARAMS_NORMAL)
    mdl = skext.CategoricalHMM(sZ=mdl['Pi'], sKX=mdl['Psi'], omega=mdl['Omega'])
    bl = pd.read_pickle(BASELINE_MDL, compression='bz2').to_numpy()
    # 2. Per cage: one (time, mouse, behaviour) array per run
    X = {
        cid: [reindex_run(rdf).to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7) for rid, rdf in cdf.groupby('Run')]
        for cid, cdf in data['ALM.Prob'].groupby('CageID')
    }
    # 3. Log-likelihood of every cage under its best mouse permutation
    ll_Q2 = pd.Series({c_id: optimise_perm(X_c, mdl, False)[2] for c_id, X_c in X.items()}, name='LL').rename_axis('CageID')
    # 4. Baseline log-likelihood of every cage
    bl_Q2 = (
        data['ALM.Prob'][BEH_ORDER]
        .dropna(how='any')
        .groupby('CageID')
        .apply(lambda cdf: (cdf.to_numpy() @ bl).sum())
        .rename('BL')
    )
    # 5. Per-sample scores with LaTeX column labels
    scores_Q2 = ll_Q2.to_frame().join(x_len).join(bl_Q2)
    scores_Q2['LL'] = scores_Q2['LL'] / scores_Q2['Samples']
    scores_Q2['BL'] = scores_Q2['BL'] / scores_Q2['Samples']
    scores_Q2 = scores_Q2[['LL', 'BL', 'Samples']].rename(columns={'LL': r'$\mathcal{L}_{\text{GM}}$', 'BL': r'$\mathcal{L}_{\text{BL}}$', 'Samples': 'S'})
    # 6. Store
    scores_Q2.to_pickle(os.path.join(RESULTS, f'Scores.Normality.Z{Z_NORMAL}.df'), compression='bz2')

In [None]:
if ANALYSE_NORMALITY and MPC_WORK:
    gm_col = r'$\mathcal{L}_{\text{GM}}$'
    scores_Q1 = pd.read_pickle(SCORES_NORMAL, compression='bz2')
    scores_Q2 = pd.read_pickle(os.path.join(RESULTS, f'Scores.Normality.Z{Z_NORMAL}.df'), compression='bz2')
    # Per-sample global-model LL: original cages ('Adult') beside the Q2 cages ('Young'), keyed by cage shorthand
    scores_all = (
        pd.concat([scores_Q1[gm_col], scores_Q2[gm_col]], axis=1, keys=['Adult', 'Young'])
        .rename(CAGE_SHORTHAND)
        .sort_index()
    )
    # Key statistics per age group: Shapiro-Wilk p-value, mean and standard deviation
    ll = scores_all
    sw_pvalues = {ds: scstats.shapiro(ll[ds].dropna())[1] for ds in ['Adult', 'Young']}
    stats = pd.Series(sw_pvalues).to_frame('SW-Test').join(
        ll.mean(axis=0).to_frame('Mean').join(ll.std(axis=0).to_frame('St. Dev.'))
    )
    summary = pd.concat([ll.T, stats], axis=1, keys=['Per-Cage', 'Statistics'])

### 1.2 Visualise (Table)

In [None]:
if ANALYSE_NORMALITY and MPC_WORK:
    # Welch's (unequal-variance) t-test, treating the Adult and Young scores as independent samples
    t_up = scstats.ttest_ind(ll['Adult'].dropna(), ll['Young'].dropna(), equal_var=False, alternative='two-sided')
    dof_up = npext.welch_dof(ll["Adult"].dropna(), ll["Young"].dropna())
    # Paired t-test on the cages that have both an Adult and a Young score. A paired test has n_pairs - 1
    # degrees of freedom; Welch's approximation only applies to the independent-samples test above.
    pairs = ll.dropna()
    t_long = scstats.ttest_rel(pairs['Adult'], pairs['Young'], alternative='two-sided')
    dof_long = len(pairs) - 1
    # Display (LaTeX). Backslashes are escaped explicitly: '\m' is not a valid escape sequence.
    print(summary['Per-Cage'].style.format(precision=2, na_rep="").to_latex(hrules=True, multicol_align='c', multirow_align='m'))
    print(summary['Statistics'].unstack().to_frame('Per-Subset').T.style.format(precision=2, na_rep="").to_latex(hrules=True, multicol_align='c', multirow_align='m'))
    print(f'\\multirow[m]{{2}}{{*}}{{T-Tests}} & \\multicolumn{{6}}{{l}}{{Independent: t-Statistic={t_up[0]:.2f} (\\textsl{{p}}-value$={t_up[1]:.2e}$, DoF$\\approx {dof_up:.1f}$)}}\\\\')
    print(f'                            & \\multicolumn{{6}}{{l}}{{Paired: t-Statistic={t_long[0]:.2f} (\\textsl{{p}}-value$={t_long[1]:.2e}$, DoF$\\approx {dof_long:.1f}$)}}\\\\')

### 1.3 Visualise (Figure)

In [None]:
if ANALYSE_NORMALITY and MPC_WORK:
    # Axes: scatter of per-cage scores, plus a twin y-axis for the threshold metrics
    fig, ax_s = plt.subplots(1, 1, figsize=[17, 4.5 if FOR_PAPER else 5.5], tight_layout=True)
    ax_mtr = ax_s.twinx()
    # Long format: one row per (age group, cage)
    nll_all = summary['Per-Cage'].stack().rename_axis(['Age Group', 'Cage ID']).to_frame('NLL').reset_index().sort_values('Cage ID')
    nll_sorted = nll_all.sort_values('NLL')
    # Metrics when every cage scoring at or below each successive threshold is flagged as 'Young'
    is_young = nll_sorted['Age Group'] == 'Young'
    is_adult = nll_sorted['Age Group'] == 'Adult'
    young_hits = is_young.cumsum()
    recall = young_hits / is_young.sum()
    precision = young_hits / np.arange(1, len(nll_all) + 1)
    accuracy = (young_hits + is_adult.sum() - is_adult.cumsum()) / len(nll_all)
    fpr = (~is_young).cumsum() / (~is_young).sum()
    # Scatter of the scores per cage
    sns.scatterplot(nll_all, y='Cage ID', x='NLL', hue='Age Group', s=200, ax=ax_s)
    ax_s.tick_params(labelsize=20)
    ax_s.set_ylabel('Cage ID', fontsize=20)
    ax_s.set_xlabel(r'$\widehat{\mathcal{L}}$', fontsize=24)
    ax_s.grid('on', axis='y', lw=4, alpha=0.5)
    # Threshold metrics as step curves on the twin axis
    l_acc = ax_mtr.step(nll_sorted['NLL'], accuracy, label='Accuracy', c='k', lw=2.5)
    if not FOR_PAPER:
        ax_s.legend(fontsize=18, markerscale=2, title='Age Group   ', title_fontsize=20, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0)
        ax_mtr.step(nll_sorted['NLL'], recall, '--', label='Recall', c='olive', lw=2.5)
        ax_mtr.step(nll_sorted['NLL'], precision, ':', label='Precision', c='magenta', lw=4, ms=10)
        ax_mtr.legend(fontsize=18, title='Metric        ', title_fontsize=20, handlelength=1.7, handletextpad=0.5, loc='lower left', bbox_to_anchor=[1.05, 0], borderaxespad=0)
        ax_mtr.set_ylabel(' Accuracy', fontsize=20, labelpad=30, va='bottom')
    else:
        h, l = ax_s.get_legend_handles_labels()
        ax_s.legend([*h, l_acc[0]], [*l, 'Accuracy'], fontsize=18, markerscale=2, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0)
        ax_mtr.set_ylabel('Accuracy     ', fontsize=20, labelpad=30, va='bottom', ha='center')
    ax_mtr.set_ylim(0.0, 1.05)
    ax_s.set_xlim(-1.72, -1.09)
    ax_mtr.tick_params(labelsize=20)
    # Save figure, then report the area under the ROC curve of the threshold classifier
    plt.savefig(os.path.join(FIGURES, 'fig_model_anomaly_roc.png'), bbox_inches='tight', dpi=150)
    print(skmetrics.auc(fpr, recall))

## 2. Global Model on Held-Out Data

Here I fit and analyse a global model trained directly on the held-out (Q2) data.

### 2.1 Generate Statistics.

We need to consider that we have to iterate over: 
 1. $|Z|$ — the best number of latent states for this data is not known in advance (it may differ from the one chosen previously)
 2. Initialiser (i.e. which cage)
 
However, we will not use folds in this case, i.e. we will work with the training log-likelihood (the hope is that a distinct kink will be visible thanks to the multiple cages).

In [None]:
if ANALYSE_GLOBAL_Q2 and FIT_MODELS:
    print('Training Global Model on Q2')
    # 1. Format Data: per cage, one (time, mouse, behaviour) array per run
    print(' + Extracting Data : ', end='', flush=True)
    X = {
        cid: [reindex_run(rdf).to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7) for rid, rdf in cdf.groupby('Run')]
        for cid, cdf in data['ALM.Prob'].groupby('CageID')
    }
    print(' Done!')
    # 2. Train one model per (|Z|, initial cage) pair in parallel. The header line puts one tick per 10 jobs,
    #    matching the '>' each job prints when it finishes.
    n_ticks = int(1 + len(Z_SEARCH) * len(X) / 10)
    print(f' + Running Models :{"         ".join("|" * n_ticks)}')
    print(' + Running Models : ', end='', flush=True)
    s = tm.time()
    jobs = (joblib.delayed(train_from_cage)(X, cid, sZ) for sZ, cid in it.product(reversed(Z_SEARCH), X.keys()))
    results = joblib.Parallel(n_jobs=N_JOBS, prefer='processes')(jobs)
    print(f'* Done! [{utils.show_time(tm.time() - s)}]')
    # 3. Format the Data & store
    print(' + Formatting : ', end='', flush=True)
    scores = pd.concat([res[1] for res in results], axis=1).T.rename_axis(index=('|Z|', 'Init Model'))
    scores['Samples'] = x_len.sum()
    scores['NLL'] = scores['LL'] / scores['Samples']
    scores.to_pickle(os.path.join(RESULTS, 'Scores.HeldOut.Z_All.df'), compression='bz2')
    # Group parameters and evolutions as {|Z|: {initial cage: ...}}, in result order
    params, evos = {}, {}
    for (sz, cid), _, cz_param, cz_evos in results:
        params.setdefault(sz, {})[cid] = cz_param
        evos.setdefault(sz, {})[cid] = cz_evos
    joblib.dump(params, os.path.join(RESULTS, 'Params.HeldOut.Z_All.jlib'), compress=True)
    joblib.dump(evos, os.path.join(RESULTS, 'Evolutions.HeldOut.Z_All.jlib'), compress=True)
    print(' Done!\n-------------------')

In [None]:
if ANALYSE_GLOBAL_Q2:
    # Reload the stored held-out results: the scores table, then the parameter and evolution dictionaries
    scores = pd.read_pickle(os.path.join(RESULTS, 'Scores.HeldOut.Z_All.df'), compression='bz2')
    params, evos = (joblib.load(os.path.join(RESULTS, f'{prefix}.HeldOut.Z_All.jlib')) for prefix in ('Params', 'Evolutions'))

### 2.2 Choose $|Z|$

This is based on evaluating the best in each $|Z|$.

In [None]:
if ANALYSE_GLOBAL_Q2 and MPC_WORK:
    fig, ax = plt.subplots(1, 1, figsize=[18 if FOR_PAPER else 9.5, 5], tight_layout=True)
    # Per-sample LL of every (|Z|, initial cage) model; the dotted line aggregates over initial cages
    nll = scores['NLL'].rename(CAGE_SHORTHAND, level=1).sort_index().reset_index()
    sns.lineplot(nll, x='|Z|', y='NLL', errorbar=None, linestyle=':', marker='P', ms=13, lw=2, ax=ax)
    # Shift each initial cage horizontally so that the points at one |Z| do not overlap
    jitter = {'D': -0.175, 'G': -0.125, 'H': -0.075, 'I': -0.025, 'K': 0.025, 'L': 0.075, 'M': 0.125, 'Y': 0.175}
    nll['|Z|'] += nll['Init Model'].map(jitter)
    sns.scatterplot(nll, x='|Z|', y='NLL', hue='Init Model', s=100, ax=ax)
    ax.tick_params(labelsize=20)
    ax.set_xlabel('|Z|', fontsize=20)
    ax.set_ylabel(r'$\widehat{\mathcal{L}}$', fontsize=30, rotation=0, labelpad=20)
    # Legend: one entry per initial cage, plus the aggregate line
    handles, labels = ax.get_legend_handles_labels()
    elem = [*handles, *ax.get_lines()]
    lbls = [*labels, 'Mean']
    ax.legend(elem, lbls, fontsize=17, markerscale=1.5, ncol=9 if FOR_PAPER else 5, handlelength=1.8, handletextpad=0.2, columnspacing=0.8, borderaxespad=0.2, title='Initial Model', title_fontsize=17)
    plt.savefig(os.path.join(FIGURES, 'fig_model_Q2_ll_Z_All.png'), bbox_inches='tight', dpi=200)

### 2.3 Visualise Parameters

In [None]:
if ANALYSE_GLOBAL_Q2 and MPC_WORK:
    for z in Z_PARAMS:
        # Parameters of the best (highest-LL) initialisation at this |Z|
        best_init = scores.loc[z, 'LL'].idxmax()
        z_pars = params[z][best_init]
        omega, psi = z_pars['Omega'], z_pars['Psi']
        if not FOR_PAPER:
            fig, axs = plt.subplots(1, 4, figsize=[(15+z/2), VIS_HEIGHTS[z]], tight_layout=True, gridspec_kw={'width_ratios':[z,7,7,7]})
        else:
            fig, axs = plt.subplots(2, 2, figsize=[VIS_HEIGHTS[z]*1.8, VIS_HEIGHTS[z]*1.8 - 1.5], tight_layout=True)
            axs = axs.ravel()
        show_omega(omega, axs[0], stats=not FOR_PAPER)
        # NOTE: keep the loop variable named `k`. show_psi_k, called here without k=, picks the mouse index up
        # from the global `k`.
        for k, ax in enumerate(axs[1:]):
            show_psi_k(psi[k, :, :], ax, not FOR_PAPER, 1 if FOR_PAPER else -1, FOR_PAPER)
        plt.tight_layout(w_pad=0)
        plt.savefig(os.path.join(FIGURES, f'fig_model_Q2_params_M{best_init}_Z{z}.png'), bbox_inches='tight', dpi=150)

### 2.4 Temporal Ethograms

In [None]:
if ANALYSE_GLOBAL_Q2 and MPC_WORK:
    # Best global model at |Z| = Z_TEMP. It is rebuilt with keyword arguments, in the same form as when loading
    # the model of normality, so each parameter binds to the intended constructor argument.
    best_par = params[Z_TEMP][scores.loc[Z_TEMP, 'LL'].idxmax()]
    chmm = skext.CategoricalHMM(sZ=best_par['Pi'], sKX=best_par['Psi'], omega=best_par['Omega'])
    progress = ProgressBar(len(data.index.unique('Run'))).reset('Plotting:')
    for cid, cdf in data.groupby('CageID'):
        perm = best_par['Perm'][cid][0]  # Name of the best mouse ordering for this cage (e.g. 'GRB')
        for rid, rdf in cdf.droplevel(0, axis=0).groupby(by='Run'):
            # Re-Index the run onto a contiguous time axis. The time is appended straight to the index, rather
            # than through a temporary ('Sensors', 'Time') column, so the groupby slice is never written to.
            segments = rdf.index.get_level_values('Segment')
            run_time = (segments - segments.min()) * SEGMENT_BTIS + rdf.index.get_level_values('BTI')
            rdf = rdf.set_index(pd.Index(run_time, name='Time'), append=True)
            rdf = rdf.droplevel(('Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
            rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
            # Extract Data:
            #  - behaviours, with the mice in the cage's best order
            #  - latent-state probabilities (predict_proba)
            #  - light status
            r_X = rdf['ALM.Prob'][BEH_ORDER].reorder_levels((1, 0), 1)[[*perm]]
            r_Z = chmm.predict_proba([r_X.to_numpy().reshape(SEGMENT_BTIS * RUN_SEGMENTS, 3, 7).astype('float', order='C')])[0].T
            r_L = rdf['Sensors']['Light']['R'].astype(float).to_numpy()[np.newaxis, :]
            # Set up Figure
            fig, axs = plt.subplots(5, 1, figsize=[25, 4+Z_TEMP], tight_layout=True, sharex=True, gridspec_kw={'height_ratios':[0.8, Z_TEMP-2, 5, 5, 5]})
            #  i) Plot Light Status
            cmap = mpl.colormaps['gray']; cmap.set_bad('tan')
            axs[0].imshow(r_L, cmap=cmap, aspect='auto'); axs[0].set_yticks([0]); axs[0].set_yticklabels(['Light'], fontsize=22)
            # ii) Plot the sequence of Latent States
            cmap = mpl.colormaps[COLOUR_MAPS['Z']]; cmap.set_bad('gray')
            axs[1].imshow(r_Z, cmap=cmap, aspect='auto', interpolation='none')
            axs[1].set_yticks(np.arange(len(r_Z))); axs[1].set_yticklabels(sau[:len(r_Z)], fontsize=22)
            axs[1].set_ylabel('Z', fontsize=25, rotation=0, va='center', labelpad=80)
            # iii) Plot each of the three mouse behaviours.
            for k, m in enumerate(perm):
                cmap = mpl.colormaps[COLOUR_MAPS[m]]; cmap.set_bad('gray')
                m_X = r_X[m].to_numpy().T
                axs[2+k].imshow(m_X, cmap=cmap, aspect='auto', interpolation='none')
                axs[2+k].set_yticks(np.arange(7)); axs[2+k].set_yticklabels(BEH_NAMES, fontsize=22)
                axs[2+k].set_ylabel(f'$X_{k+1}$', fontsize=25, va='center', rotation=0, labelpad=20)
            # iv) Commonalities: shared time axis with a tick every 600 BTIs
            time_ticks = np.arange(0, RUN_SEGMENTS * SEGMENT_BTIS+1, 600)
            axs[-1].set_xticks(time_ticks); axs[-1].set_xticklabels(utils.time_list(time_ticks*1000., fmt='%H:%M'), fontsize=22, ha='center')
            axs[-1].set_xlabel('Time (Hrs:Mins)', fontsize=18)
            # Save Figure
            plt.tight_layout(h_pad=0)
            plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_q2_temporal_R{rid}_C{cid}_Z{Z_TEMP}.png'), bbox_inches='tight', dpi=200)
            plt.close()
            progress.update()