# Explore and Model Behaviour distributions

## 0. Scope
This notebook analyses the global CHMM model with permutation.

### 0.1 Requirements
 * `MODELLING_DF`: Modelling Datasets as generated using `Build_Modelling_Dataset.ipynb`
 * `PAIRWISE_DF` : Results from the pairwise modelling (as generated using Model_CHMM)
 * `PARAMS_DF`   : Model training parameters for single-cage data.

In [None]:
from mpctools.extensions import mplext, npext, utils, skext, pdext
from string import ascii_uppercase as sau
from mpctools.parallel import ProgressBar
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import matplotlib as mpl
import itertools as it
import seaborn as sns
import pandas as pd
import numpy as np
import time as tm
import joblib
import sys
import os

# Add the Project Directories to the path
sys.path.append('../../../../')

# Add own Tools
from Scripts.Constants import Const, CAGE_SHORTHAND
from Tools.Parsers import BORISParser

# Location Logic
# MPC_WORK is True only when the notebook runs on the host named 'MPCWork'. Later cells use it to gate
# the plotting/analysis steps and to choose the number of parallel jobs.
MPC_WORK = os.uname().nodename == 'MPCWork'

# Display Options (applied on the MPCWork host only)
if MPC_WORK:
    display(HTML("<style>.container { width:95% !important; }</style>"))  # Widen the notebook container
    np.set_printoptions(precision=4, linewidth=150, suppress=True)

In [None]:
# ======= Sources ======= #
# Directory holding the modelling datasets (read below as '<split>.df' for Train/Validate/Test)
MODELLING_DF = '/media/veracrypt4/Q1/Modelling'

# Root of all results. PAIRWISE_DF is a template which is formatted with the number of latent states.
RESULT_LOC = os.path.join(Const['Results.Scratch'], 'Modelling')
PAIRWISE_DF = os.path.join(RESULT_LOC, 'Analyse_CHMM', 'LL.CHMM.Z{:d}.df')
PARAMS_DF = os.path.join(RESULT_LOC, 'Analyse_CHMM', 'Params.jlib')

# Output directories (each is passed to utils.make_dir as soon as it is defined)
FIGURES = os.path.join(RESULT_LOC, 'Figures'); utils.make_dir(FIGURES)
RESULTS = os.path.join(RESULT_LOC, 'Analyse_Global'); utils.make_dir(RESULTS)

# ==== Visualisations/Setups ==== #
# Behaviour identifiers (BEH_ORDER) and display names (BEH_NAMES) as provided by the BORIS parser
BEH_ORDER = BORISParser.BEHAVIOURS(True).values()
BEH_NAMES = BORISParser.BEHAVIOURS(True, True).values()
# Maps each ordering of the mice 'R', 'G', 'B' (e.g. 'RGB', 'RBG', ...) to the matching tuple of indices
PERM_ORDER = {''.join(name): order for name, order in zip(it.permutations('RGB', 3), it.permutations((0, 1, 2), 3))}

# Per-cage colour palette: the 'tab10' colours extended with one colour each from 'Set1' and 'Dark2'
clrs = list(mpl.colormaps['tab10'].colors); clrs.insert(10, mpl.colormaps['Set1'].colors[5]); clrs.insert(11, mpl.colormaps['Dark2'].colors[5])
CAGE_CMAP = mpl.colors.ListedColormap(clrs, 'Custom')
# Three-colour map (one colour per mouse) with missing entries drawn in 'gainsboro'
MSE_CMAP = mpl.colors.ListedColormap(['red', 'green', 'blue'], 'RGB'); MSE_CMAP.set_bad('gainsboro')

# Colour map names for the latent state (Z) and for each mouse (R/G/B) in the temporal plots
COLOUR_MAPS = {'Z': 'Purples', 'R': 'Reds', 'G': 'Greens', 'B': 'Blues'}

# Figure heights for the parameter plots, looked up by Z_GLOBAL
VIS_HEIGHTS = {3:3.9, 6:5.7, 7:6.3}

# Switches the parameter figures between the compact (2x2) and the wide (1x4) layout
FOR_PAPER = False

# ==== Experimental Conditions ==== #
SEGMENT_BTIS = 30*60 # Number of BTIs per segment
RUN_SEGMENTS = 5 # How many segments per run

MAX_ITER = 150 # Presumably an upper bound on the number of training iterations

MODEL_INIT = 'A' # Row of the pairwise log-likelihood table used for the permutation-peakedness analysis
Z_GLOBAL = 7 # Number of latent states (|Z|) of the models that are loaded/trained
MODEL_GLOBAL = None # Selection of cages for the global model (None: use all cages)

N_JOBS = 8 if MPC_WORK else 60 # Number of parallel joblib workers
RANDOM_STATE = 101
CONVERGE_TOL = 1e-5 # Relative log-likelihood change below which training is deemed converged

# ==== Execution Control ==== #
# Each flag enables the correspondingly named section(s) of the notebook.
ANALYSE_BEST_MODEL = True
FIT_OUTLIER_MODEL = False
ANALYSE_OUTLIER_FITS = False
FIT_GLOBAL_MODEL = False
ANALYSE_GLOBAL_MODEL = True
VISUALISE_TEMPORAL = False

In [None]:
def reindex_run(rdf):
    """
    Re-index the observations of a single run onto one contiguous, fixed-length time axis.

    The per-segment BTI index is converted into a run-wide time index (segment offset * SEGMENT_BTIS + BTI),
    the mice are moved into the columns, and the result is padded (with NaN rows) to exactly
    RUN_SEGMENTS * SEGMENT_BTIS time-points.

    :param rdf: DataFrame for one run. The index must contain the levels 'CageID', 'Run', 'Segment' and 'BTI'
                plus one further level which ends up named 'Mouse' (assumed to hold 'R'/'G'/'B' - confirm
                against the dataset builder). The columns must include the behaviours in BEH_ORDER.
    :return: DataFrame indexed by run-wide BTI, with (Mouse, Behaviour) columns ordered R, G, B.

    Note: the input frame is NOT modified (the time column is added to a copy), so this is safe to call on
          groupby slices.
    """
    segments = rdf.index.get_level_values('Segment')
    # assign() works on a copy: writing the column in place would mutate (a slice of) the caller's data
    rdf = rdf.assign(Time=(segments - segments.min()) * SEGMENT_BTIS + rdf.index.get_level_values('BTI'))
    # Swap the original index for (Time, Mouse), labelled as (BTI, Mouse)
    rdf = rdf.set_index('Time', append=True).droplevel(('CageID', 'Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
    # Mice go to the columns, and the time axis is padded to the full run length
    rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
    return rdf[BEH_ORDER].reorder_levels((1, 0), 1)[['R', 'G', 'B']]

def optimise_perm(X_c, mdl):
    """
    Pick the mouse permutation under which a cage's data is most likely given the model.

    :param X_c: Iterable of arrays (one per run) whose second axis indexes the mice.
    :param mdl: Model exposing `logpdf(list_of_arrays, norm=False)`.
    :return: Tuple of (permutation name, index order of that permutation, its log-likelihood).
    """
    # Score the cage under every candidate ordering of the mice
    scores = {
        name: mdl.logpdf([run[:, order, :] for run in X_c], norm=False)
        for name, order in PERM_ORDER.items()
    }
    winner = pd.Series(scores).idxmax()
    return winner, PERM_ORDER[winner], scores[winner]

def check_converged(lls):
    """
    Convergence check on the log-likelihood evolution.

    :param lls: Sequence of log-likelihoods, one per completed iteration.
    :return: True if the relative change between the last two entries is below CONVERGE_TOL. False if
             there are fewer than two entries, or if the log-likelihood dropped (in which case a warning
             is also issued).
    """
    # Imported locally since the module is not among the notebook's top-level imports
    import warnings

    if len(lls) < 2:
        return False
    elif lls[-1] < lls[-2]:
        warnings.warn("Drop in Log-Likelihood Observed! Results are probably wrong.")
        return False
    else:
        return abs((lls[-1] - lls[-2]) / lls[-2]) < CONVERGE_TOL
        
def train_model(p_init, _X):
    """
    Convenience wrapper for training a global model.

    Alternates between (i) choosing the best mouse permutation for each cage under the current parameters
    and (ii) a single partial fit of the model on the data arranged accordingly. This repeats until
    check_converged() is satisfied or MAX_ITER iterations have been run (a warning is issued in the latter
    case).

    :param p_init: Dict with the initial parameters under the keys 'Pi', 'Psi' and 'Omega'.
    :param _X: Dict mapping cage id -> list of per-run arrays (second axis indexes the mice).
    :return: Tuple of (Pi, Psi, Omega, permutation history, total log-likelihood, log-likelihood evolution).
             The permutation history is accumulated through utils.extend_dict (presumably one list of
             permutation names per cage, whose last entry is the final permutation); the total
             log-likelihood is summed over the cages under their final permutations.
    """
    # Initialise model (max_iter=1: the permutations are re-optimised between successive partial fits)
    mdl = skext.CategoricalHMM(sZ=p_init['Pi'], sKX=p_init['Psi'], omega=p_init['Omega'], max_iter=1)
    # Start the Optimisation
    ll, perms = [], {}
    while not check_converged(ll):
        # Guard against a fit which never settles (e.g. an oscillating log-likelihood)
        if len(ll) >= MAX_ITER:
            import warnings
            warnings.warn(f'Global model did not converge within {MAX_ITER} iterations.')
            break
        # Find best permutation for each cage (and store)
        perm_pcage = {cid: optimise_perm(X_c, mdl) for cid, X_c in _X.items()}
        utils.extend_dict(perms, {k: v[0] for k, v in perm_pcage.items()})
        # Organise data according to this permutation
        W_tr = [x[:, p_c[1], :] for _, (x_c, p_c) in utils.dzip(_X, perm_pcage) for x in x_c]
        # Fit Model
        mdl.fit_partial(W_tr)
        # Append Likelihood evolution
        ll.append(mdl.Evolution[-1])
    # Resolve best model: permutations (and log-likelihoods) under the final parameters
    perm_final = {cid: optimise_perm(X_c, mdl) for cid, X_c in _X.items()}
    perms = utils.extend_dict(perms, {k: v[0] for k, v in perm_final.items()})
    total_ll = np.sum([cage_ll for _, _, cage_ll in perm_final.values()])
    # Return
    return mdl.Pi, mdl.Psi, mdl.Omega, perms, total_ll, ll

def check_outlier(v_id, m_id, p_init, X_tr, X_val):
    """
    Train a model on the training cages and score it on the held-out cage.

    :param v_id: Identifier of the held-out (validation) cage.
    :param m_id: Identifier of the cage whose parameters initialise the model.
    :param p_init: Initial parameters (as expected by train_model).
    :param X_tr: Training data (as expected by train_model).
    :param X_val: Data of the held-out cage (as expected by optimise_perm).
    :return: Tuple of ((v_id, m_id), Series with the 'Train'/'Eval' log-likelihoods, dict of the fitted
             parameters and final permutations, dict with the permutation/log-likelihood evolutions).
    """
    # Fit on the training cages
    pi, psi, omega, perm_history, train_ll, ll_history = train_model(p_init, X_tr)

    # Score the held-out cage under its own best permutation
    fitted = skext.CategoricalHMM(sZ=pi, sKX=psi, omega=omega)
    held_out_perm, _, held_out_ll = optimise_perm(X_val, fitted)

    # Final permutation per cage: the last entry of each training history, plus the held-out cage
    last_perms = {cage: history[-1] for cage, history in perm_history.items()}
    last_perms[v_id] = held_out_perm

    # Progress marker
    sys.stdout.write('>')
    sys.stdout.flush()

    key = (v_id, m_id)
    log_likes = pd.Series({'Train': train_ll, 'Eval': held_out_ll}, name=key)
    parameters = {'Pi': pi, 'Psi': psi, 'Omega': omega, 'Perm': last_perms}
    evolutions = {'Perm': perm_history, 'LL': ll_history}
    return key, log_likes, parameters, evolutions

def train_global_model(m_id, p_init, X):
    """
    Train one global model from a given initialisation and package the outcome.

    :param m_id: Identifier of the cage whose parameters initialise the model.
    :param p_init: Initial parameters (as expected by train_model).
    :param X: Training data (as expected by train_model).
    :return: Tuple of (m_id, Series with the training log-likelihood 'LL', dict of the fitted parameters
             and final permutations, dict with the permutation/log-likelihood evolutions).
    """
    pi, psi, omega, perm_history, train_ll, ll_history = train_model(p_init, X)

    # Keep only the last permutation of every cage
    last_perms = {cage: history[-1] for cage, history in perm_history.items()}

    # Progress marker
    sys.stdout.write('>')
    sys.stdout.flush()

    log_like = pd.Series({'LL': train_ll}, name=m_id)
    parameters = {'Pi': pi, 'Psi': psi, 'Omega': omega, 'Perm': last_perms}
    evolutions = {'Perm': perm_history, 'LL': ll_history}
    return m_id, log_like, parameters, evolutions

def show_omega(omega, ax, stats=True):
    """
    Draw the transition matrix as an annotated grid on the given axes.

    All entries are shown rounded to two decimals: zeros are left blank, ones appear as '~1' and anything
    else is printed without its leading zero.

    :param omega: Square transition matrix (numpy array).
    :param ax: Matplotlib axes to draw on.
    :param stats: If True, each column label also lists the value from npext.markov_stationary and the
                  value from npext.markov_dwell for that state; if False, only the state letter is shown
                  and the axes are forced to an equal aspect ratio.
    """
    def _fmt_val(v):
        if v == 0:
            return ''
        if v == 1:
            return '~1'
        return f'{v:.2f}'[1:]

    n_states = len(omega)
    # Column labels (optionally with per-state statistics)
    if stats:
        stationary = npext.markov_stationary(omega)
        dwell = npext.markov_dwell(omega)
        col_names = [f'{a}\n'+f'{s:.2f}\n'[1:]+f'{n:.0f}' for a, s, n in zip(sau, stationary, dwell)]
    else:
        col_names = sau[:n_states]
    row_names = sau[:n_states]
    cell_text = pd.DataFrame(omega.round(2)).applymap(_fmt_val).to_numpy()

    # Uniform grey grid: only the annotations carry information
    sns.heatmap(np.zeros_like(omega), annot=cell_text, cmap=[(0.9, 0.9, 0.9)], fmt='s', ax=ax, cbar=False, annot_kws={"size": 19}, linewidths=2,)
    ax.set_xticks(np.arange(0.5, len(col_names) + 0.5))
    ax.set_xticklabels(col_names, rotation=0, fontsize=19)
    ax.set_yticks(np.arange(0.5, len(row_names) + 0.5))
    ax.set_yticklabels(row_names, rotation=0, va="center", fontsize=19)
    ax.set_xlabel('$Z^{[t+1]}$', fontsize=22)
    ax.set_ylabel('$Z^{[t]}$', fontsize=22, rotation=0, labelpad=22, va='center')
    ax.set_title(r'$\Omega$', fontsize=22)
    if not stats:
        ax.set_aspect('equal')
        
def show_psi_k(psi_k, ax, full_names=True, prc=2, axis_labels=False, k=None):
    """
    Draw the emission matrix of one mouse as a Hinton diagram on the given axes.

    :param psi_k: Emission matrix for a single mouse (rows: latent states, columns: behaviours).
    :param ax: Matplotlib axes to draw on.
    :param full_names: If True, label the columns with the full behaviour names (rotated); otherwise only
                       with their first letter.
    :param prc: Number of decimals for the cell annotations. If not positive, no values are annotated.
    :param axis_labels: If True, label the rows with the state letters and add the 'Z' axis label.
    :param k: Zero-based index of the mouse, used in the axis label and title (shown as k+1). If None, the
              index is read from a global variable named `k` (as set by the calling loop), which is how
              this function has historically been used.
    :raises NameError: If `k` is not provided and no global `k` exists.
    """
    if k is None:
        # Backward compatibility: earlier callers relied on the loop variable `k` leaking in as a global
        try:
            k = globals()['k']
        except KeyError:
            raise NameError("name 'k' is not defined") from None

    def _fmt_val(v):
        if v == 0:
            return ''
        elif v == 1:
            return '~1'
        else:
            return f'{v:.{prc}f}'[1:]  # Strip the leading zero

    psi_k_annot = pd.DataFrame(psi_k.round(prc)).applymap(_fmt_val).to_numpy() if prc > 0 else False
    mplext.plot_matrix(psi_k, mode='hinton', show_val=psi_k_annot, x_labels=BEH_NAMES if full_names else [b[0] for b in BEH_NAMES], y_labels=sau[:psi_k.shape[0]] if axis_labels else None , fmt='s', fs=20, x_rot=90 if full_names else 0, buffer=0.6, ax=ax)
    ax.set_xlabel(f'$X_{k+1}$', fontsize=20, labelpad=4)
    if axis_labels:
        ax.set_ylabel('$Z$', fontsize=22, rotation=0, va="center", labelpad=10)
    if full_names:
        ax.xaxis.set_tick_params(pad=-2)
    # Raw f-string: '\P' is not a valid escape sequence in a normal string literal
    ax.set_title(rf'$\Psi_{k+1}$', fontsize=22)
        
def summarise_omega(omega, ax):
    """
    Draw a two-column table of per-state statistics derived from the transition matrix.

    The first column ('SS') holds the values of npext.markov_stationary (two decimals, '~0' when they round
    to zero) and the second ('DT') those of npext.markov_dwell (rounded to integers).

    :param omega: Square transition matrix (numpy array).
    :param ax: Matplotlib axes to draw on.
    """
    def _fmt_stat(x):
        return '~0' if x == 0 else f'{x:.2f}'[1:]

    stationary = pd.Series(npext.markov_stationary(omega).round(2)).apply(_fmt_stat).to_numpy()
    dwell = npext.markov_dwell(omega).round(0).astype(int).astype(str)
    table = np.vstack([stationary, dwell]).T
    # Uniform grey grid: only the annotations carry information
    sns.heatmap(np.zeros(table.shape), annot=table, cmap=[(0.9, 0.9, 0.9)], fmt='s', ax=ax, cbar=False, annot_kws={"size": 19}, linewidths=2,)
    ax.set_xticks([0.5, 1.5])
    ax.set_xticklabels(['SS', 'DT'], rotation=0, fontsize=19)
    ax.set_yticks([])

In [None]:
#  1. Load the Main Data
#     `lls_bl` and `x_len` are also needed when analysing the outlier fits, so that flag must trigger the
#     load as well (those cells only execute on the MPCWork host, hence the extra condition).
if FIT_OUTLIER_MODEL or (ANALYSE_OUTLIER_FITS and MPC_WORK) or FIT_GLOBAL_MODEL or VISUALISE_TEMPORAL or ANALYSE_BEST_MODEL:
    sys.stdout.write('Loading Behaviour Predictions ... '); sys.stdout.flush()
    data = pd.concat([pd.read_pickle(os.path.join(MODELLING_DF, f'{k}.df'), compression='bz2') for k in ('Train', 'Validate', 'Test')])
    #  2. Train also the baseline model (and store): log of the normalised per-behaviour probability mass
    bl_global = np.log(npext.sum_to_one(data['ALM.Prob'][BEH_ORDER].sum().to_numpy()))
    pd.Series(bl_global, index=BEH_ORDER).to_pickle(os.path.join(RESULTS, 'Baseline.Global.df'), compression='bz2')
    #     Per-cage baseline log-likelihood (rows with any missing behaviour are ignored)
    lls_bl = data['ALM.Prob'][BEH_ORDER].dropna(how='any').groupby('CageID').apply(lambda cdf: (cdf.to_numpy() @ bl_global).sum())
    #  3. Get also the lengths (for upscaling): number of complete rows per cage
    x_len = data['ALM.Prob'][BEH_ORDER].dropna(how='any').groupby('CageID').size()
    print('Done')

#  4. Load the pairwise log-likelihoods (these are also used when summarising a freshly fitted global model)
if ANALYSE_BEST_MODEL or ANALYSE_GLOBAL_MODEL or FIT_GLOBAL_MODEL:
    ll_hmm = pd.read_pickle(PAIRWISE_DF.format(Z_GLOBAL), compression='bz2')

## 1. Analyse Feasibility of Global Model

### 1.1 Matrix of Models for Cages

This has a two-fold purpose:
 * Show that all cages are quite similar, and hence can be grouped together.
 * Pick the model to investigate for permutation peakedness.

In [None]:
# Figure: (model x cage) matrix of log-likelihoods relative to the baseline, with marginal means.
if ANALYSE_BEST_MODEL and MPC_WORK:
    # Compute Best and then subtract baseline
    #  (the maximum is taken within each CageID column group - presumably over the permutations - and then
    #   scaled by the number of samples per cage so that the un-normalised baseline can be subtracted)
    ll_best = (ll_hmm.groupby(by='CageID', axis=1).max().rename(index=CAGE_SHORTHAND) * x_len - lls_bl)
    # Create Figure (main matrix, with a narrow column/row for the marginals)
    fig, axs = plt.subplots(2, 2, figsize=(12, 7), sharex='col', sharey='row', tight_layout=True, gridspec_kw={'height_ratios': [12, 1], 'width_ratios': [12, 1]})
    # First overall (normalised again by the number of samples per cage)
    ax = axs[0, 0]
    mplext.plot_matrix((ll_best / x_len).to_numpy(), mode='heatmap', y_labels=ll_best.index, show_val=True, fs=17, fmt='.3f', cax=False, ax=ax, hm_args={'cmap': 'Blues'})
    ax.set_ylabel('Model', fontsize=17)
    # Mean Across Cages (per-Model)
    ax = axs[0, 1]
    mplext.plot_matrix((ll_best.sum(axis=1) / x_len.sum()).to_numpy()[:, np.newaxis], mode='heatmap', y_labels=ll_best.index, show_val=True, fs=17, fmt='.3f', cax=False, ax=ax, hm_args={'cmap': 'Blues'})
    ax.set_title('Mean', fontsize=17)
    # Mean Across Models (per-Cage)
    #  NOTE(review): the cage columns are labelled using the (shorthand) model index - this assumes that
    #                models and cages appear in the same order; confirm.
    ax = axs[1, 0]
    mplext.plot_matrix((ll_best.mean(axis=0) / x_len).to_numpy()[np.newaxis, :], mode='heatmap', x_labels=ll_best.index, y_labels=['Mean'], show_val=True, fs=17, fmt='.3f', cax=False, ax=ax, hm_args={'cmap': 'Blues'})
    ax.set_xlabel('Cage ID', fontsize=17);
    # Turn off other axis
    axs[1, 1].axis('off')
    # Common & Save
    plt.tight_layout(h_pad=0.3, w_pad=0.3)
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_mdl_x_cage_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150)

### 1.2 Peakedness of Permutations

This is the posterior over which permutation matrix to use. We use Bayes' rule:
$$P(q|X) = \frac{\xi_q P\left(Q_q^{-1}X\right)}{\sum_{q'}\xi_{q'}P\left(Q_{q'}^{-1}X\right)}$$

Now note:
 * We assume a uniform prior $P(q)=\xi_q$ hence its contribution can be ignored.
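 * We cannot work directly with the (per-sample) normalised versions of the log-likelihood of $\tilde{X}$: normalising in log-space amounts to taking a power (root) in probability space, so we must first multiply the normalisation back out (i.e. scale by the number of samples). We also cannot simply exponentiate and sum the resulting probabilities (they underflow), and hence use the log-sum-exp trick.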

In [None]:
# Figure: posterior over the permutations for each cage, under the model selected by MODEL_INIT.
if ANALYSE_BEST_MODEL and MPC_WORK:
    # Compute Probabilities
    #  Un-normalised log-likelihoods (scaled by the number of samples per cage)
    lls_perm = ll_hmm.loc[MODEL_INIT].unstack(0) * x_len
    #  Subtract the per-column maximum before exponentiating (guards against underflow), then normalise
    #  each column to sum to one and transpose.
    lls_perm = pd.DataFrame(npext.sum_to_one(np.exp(lls_perm - lls_perm.max()).to_numpy(), axis=0), index=lls_perm.index, columns=lls_perm.columns).T
    probs_W = lls_perm.rename(index=CAGE_SHORTHAND).sort_index(axis=1, ascending=False)
    # Visualise
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    mplext.plot_matrix(probs_W.to_numpy(), mode='heatmap', x_labels=probs_W.columns, y_labels=probs_W.index, show_val=True, fs=17, fmt='.2f', cax=False, ax=ax, hm_args={'cmap': 'Blues'})
    ax.set_xlabel('Permutation', fontsize=18); ax.set_ylabel('Cage', fontsize=18)
    # Save Figure
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_peakedness_C{MODEL_INIT}_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150)

## 2. Fit Outlier Model (in cross-validation)

The premise here is to use leave-one(cage)-out cross-validation: train on all the other cages and then evaluate on the held-out cage.
In short, the procedure is as follows:

 * For each 'Test' Cage (the one we are going to class as outlier or not):
     1. Train Model on Other Cages:
         * Iterate over all Possible Training Cages as Initial Models:
             1. Find best permutation for each cage to the current parameters
             2. Run one step of EM using this permutation.
             3. Evaluate Observable Likelihood.
             4. Repeat from step 1 until the likelihood no longer changes
     2. Evaluate likelihood on Test Cage.

We can also envision replacing step A with multiple restarts, one from each cage in the training set, and then keeping the best model (or possibly computing multiple points for the test cage)

In [None]:
# Leave-one-cage-out training: for every held-out cage, one model is trained per initialising cage.
if FIT_OUTLIER_MODEL:
    print('Generating Statistics for outlier detection ... (will take some time)')
    # Extract Data: per cage, a list with one (time x 3 x 7) array per run
    print(f' + Extracting : ', end='', flush=True)
    X_all = {cid: [reindex_run(rdf).to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7) for rid, rdf in cdf.groupby('Run')] for cid, cdf in data['ALM.Prob'].groupby('CageID')}
    # Get also Model Parameters (per-cage parameters for the chosen number of latent states)
    P_all = {cid: utils.subdict(cpar[Z_GLOBAL], ('Pi', 'Psi', 'Omega')) for cid, cpar in joblib.load(PARAMS_DF).items()}
    # Create Data Splits
    parallel_split = {}
    for v_id in P_all.keys():
        # Split Data: everything but the held-out cage is used for training/initialisation
        t_ids = set(P_all.keys()).difference({v_id})
        parallel_split[v_id] = {
            'X_trn': utils.subdict(X_all, t_ids),
            'X_val': X_all[v_id],
            'P_init': utils.subdict(P_all, t_ids)
        }
    print(f' Done!')
    # Run Cross-Validation (the first line prints a ruler with a '|' every 10 jobs, against which the
    #  '>' markers written by the workers can be read)
    print(f' + Running CV :{"         ".join("|"*int(1 + (len(P_all)*(len(P_all)-1))/10))}')
    print(f' + Running CV : ', end='', flush=True); s = tm.time()    
    # Iterate over initialisation point (try each cage)
    results = joblib.Parallel(n_jobs=N_JOBS, prefer='processes')(joblib.delayed(check_outlier)(v_id, m_id, m_par, v_setup['X_trn'], v_setup['X_val']) for v_id, v_setup in parallel_split.items() for m_id, m_par in v_setup['P_init'].items())
    print(f'* Done! [{utils.show_time(tm.time() - s)}]')
    # Format the Data & store
    #  NOTE: the parameters/evolutions are joblib dumps despite carrying a '.df' extension.
    print(f' + Formatting : ', end='', flush=True)
    scores = pd.concat([res[1] for res in results], axis=1).T.rename_axis(index=('Test Cage', 'Init Model'))
    scores = scores.join(x_len.rename_axis('Test Cage').rename('Samples'))
    scores.to_pickle(os.path.join(RESULTS, f'Scores.Outlier.Z{Z_GLOBAL}.df'), compression='bz2')
    #  Regroup the flat list of results as {held-out cage: {initialising cage: value}}
    params = {v_id: {m_id: vm_param for (vid, m_id), _, vm_param, _ in results if v_id == vid} for (v_id, _), _, _, _ in results}
    joblib.dump(params, os.path.join(RESULTS, f'Params.Outlier.Z{Z_GLOBAL}.df'), compress=True)
    evos = {v_id: {m_id: vm_evos for (vid, m_id), _, _, vm_evos in results if v_id == vid} for (v_id, _), _, _, _ in results}
    joblib.dump(evos, os.path.join(RESULTS, f'Evolutions.Outlier.Z{Z_GLOBAL}.df'), compress=True)
    print(f' Done!\n-------------------')

## 3. Analyse Outliers

In [None]:
# Load the stored outlier-model scores and evolutions.
if ANALYSE_OUTLIER_FITS and MPC_WORK:
    # Load Data (the evolutions file is a joblib dump despite its '.df' extension)
    scores = pd.read_pickle(os.path.join(RESULTS, f'Scores.Outlier.Z{Z_GLOBAL}.df'), compression='bz2')
    evos = joblib.load(os.path.join(RESULTS, f'Evolutions.Outlier.Z{Z_GLOBAL}.df'))

### 3.1 Analyse Log-Likelihoods

(Always, relative to Baseline)

#### 3.1.1 Training Statistics

In [None]:
# Figure: training log-likelihood (relative to baseline, per sample) for each (test cage, initial model).
if ANALYSE_OUTLIER_FITS and MPC_WORK:
    # Compute
    #  This needs some tweaking to be able to get the training-set baseline and length: the trick is to subtract the validation cage (LL or length) from the total over all cages
    train_samples = scores['Samples'].groupby(by='Test Cage').first().sum() - scores['Samples']
    train_loglike = lls_bl.sum() - lls_bl.rename_axis('Test Cage')
    train_ll = ((scores['Train'] - train_loglike) / train_samples).unstack(-1).rename(index=CAGE_SHORTHAND, columns=CAGE_SHORTHAND)
    #  Per test-cage summaries over the initial models
    train_ll_best = train_ll.max(axis=1)
    train_ll_mean = train_ll.mean(axis=1)
    # Show as figure
    fig, axs = plt.subplots(1, 3, figsize=[14, 6], tight_layout=True, sharey=False, gridspec_kw={'width_ratios': [12, 1, 1]})
    cmap = mpl.colormaps['Blues']; cmap.set_bad('gray')  # Missing entries are drawn in gray
    # 1) Overall Matrix
    ax = axs[0]
    mplext.plot_matrix(train_ll.to_numpy(), mode='heatmap', x_labels=train_ll.columns, y_labels=train_ll.index, show_val=True, fmt='.3f', fs=17, cax=False, ax=ax, hm_args={'cmap': cmap})
    ax.set_xlabel('Initial Model', fontsize=18); ax.set_ylabel('Test Cage', fontsize=18)
    # 2) Best
    ax = axs[1]
    mplext.plot_matrix(train_ll_best.to_numpy()[:, np.newaxis], mode='heatmap', x_labels=[], y_labels=[], show_val=True, fmt='.3f', fs=17, cax=False, ax=ax, hm_args={'cmap': cmap})
    ax.set_xlabel('Best', fontsize=18, labelpad=20)
    # 3) Mean
    ax = axs[2]
    mplext.plot_matrix(train_ll_mean.to_numpy()[:, np.newaxis], mode='heatmap', x_labels=[], y_labels=[], show_val=True, fmt='.3f', fs=17, cax=False, ax=ax, hm_args={'cmap': cmap})
    ax.set_xlabel('Mean', fontsize=18, labelpad=20)
    # Save Figure
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_outlier_traincages_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150)

#### 3.1.2 Validation Statistics (including Difference from Training)

In [None]:
# Figure: held-out log-likelihood (relative to baseline, per sample) and its drop relative to training.
#  Relies on `train_ll` as computed in the training-statistics cell.
if ANALYSE_OUTLIER_FITS and MPC_WORK:
    # Compute
    val_ll = ((scores['Eval'] - lls_bl.rename_axis('Test Cage')) / scores['Samples']).unstack(-1).rename(index=CAGE_SHORTHAND, columns=CAGE_SHORTHAND)
    val_ll_best = val_ll.max(axis=1)
    val_ll_mean = val_ll.mean(axis=1)
    #  Drop from training to validation, as a percentage of the training value
    diff_ll = (train_ll - val_ll) / train_ll * 100
    #  'Best' drop: per test cage, the entry with the smallest absolute value (NaNs ignored)
    diff_ll_best = diff_ll.to_numpy()[np.arange(len(diff_ll)), np.nanargmin(np.abs(diff_ll.to_numpy()), axis=1)]
    diff_ll_mean = diff_ll.mean(axis=1)
    # Show as figure
    fig, axs = plt.subplots(1, 5, figsize=[16, 5.5], tight_layout=True, sharey=False, gridspec_kw={'width_ratios': [12, 1, 1, 1, 1]})
    cmap = mpl.colormaps['Blues']; cmap.set_bad('gray')  # Missing entries are drawn in gray
    # 1) Overall Matrix
    ax = axs[0]
    mplext.plot_matrix(val_ll.to_numpy(), mode='heatmap', x_labels=val_ll.columns, y_labels=val_ll.index, show_val=True, fmt='.3f', fs=17, cax=False, ax=ax, hm_args={'cmap': cmap})
    ax.set_xlabel('Initial Model', fontsize=18, labelpad=10); ax.set_ylabel('Test Cage', fontsize=18)
    # 2) Best
    ax = axs[1]
    mplext.plot_matrix(val_ll_best.to_numpy()[:, np.newaxis], mode='heatmap', x_labels=['Best'], y_labels=[], show_val=True, fmt='.3f', fs=17, cax=False, ax=ax, hm_args={'cmap': cmap})
    # 3) Mean
    ax = axs[2]
    mplext.plot_matrix(val_ll_mean.to_numpy()[:, np.newaxis], mode='heatmap', x_labels=['Mean'], y_labels=[], show_val=True, fmt='.3f', fs=17, cax=False, ax=ax, hm_args={'cmap': cmap})
    # 4) Best Drop
    ax = axs[3]
    mplext.plot_matrix(diff_ll_best[:, np.newaxis], mode='heatmap', x_labels=['Best'], y_labels=[], show_val=True, fmt='.2f', fs=17, cax=False, ax=ax, hm_args={'cmap': 'bwr'})
    # 5) Mean Drop
    ax = axs[4]
    mplext.plot_matrix(diff_ll_mean.to_numpy()[:, np.newaxis], mode='heatmap', x_labels=['Mean'], y_labels=[], show_val=True, fmt='.2f', fs=17, cax=False, ax=ax, hm_args={'cmap': 'bwr'})
    # 6) Global Stuff (shared labels for the two groups of summary columns)
    axs[1].set_xlabel('Statistics', fontsize=18, ha='left', labelpad=10); axs[3].set_xlabel(r'$\hat{\mathcal{L}}_{train} - \hat{\mathcal{L}}_{valid}$', fontsize=18, ha='left', x=0.3)
    # Save Figure
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_outlier_testcage_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150)

### 3.2 Analyse Evolutions

In [None]:
# One figure per held-out cage: log-likelihood evolution (top) and permutation evolutions (strips below).
if ANALYSE_OUTLIER_FITS and MPC_WORK:
    progress = ProgressBar(len(evos), prec=2).reset('Plotting')
    for v_id, t_evo in evos.items():
        # Compute
        #  Log-likelihood per iteration (in units of 1e6), in long format with one row per (model, iteration)
        ll_evos = {CAGE_SHORTHAND[m_id]: pd.Series(m_evo['LL'], name='LL')/1e6 for m_id, m_evo in t_evo.items()}
        ll_evos = pd.concat(ll_evos).rename_axis(('Model Cage', 'Iter')).reset_index()
        #  Permutation history of the initialising cage itself, split into its three letters and mapped to 0/1/2
        perm_evos = {CAGE_SHORTHAND[m_id]: pd.Series(m_evo['Perm'][m_id], name='Perm') for m_id, m_evo in t_evo.items()}
        perm_evos = pd.DataFrame(perm_evos).stack().str.split('', expand=True)[[1, 2, 3]]
        perm_evos = perm_evos.applymap(lambda x: {'R': 0, 'G': 1, 'B': 2}[x]).unstack(-1).reorder_levels((1, 0), 1).T.sort_index()
        # Plot
        #  NOTE(review): the layout is hard-coded to 1 + 11 rows, i.e. it assumes 11 initialising cages.
        fig, axs = plt.subplots(12, 1, figsize=[10, 10], tight_layout=True, sharey=False, sharex=False, gridspec_kw={'height_ratios': [12, *([1]*11)]})
        # 1) LL Evolution
        ax = axs[0]
        sns.lineplot(data=ll_evos, x='Iter', y='LL', hue='Model Cage', lw=3, palette=CAGE_CMAP.colors, ax=ax)
        ax.tick_params(labelsize=25); ax.set_ylabel(r'Log-Likelihood ($\times 10^{6}$)', fontsize=25); ax.set_xticks([]); ax.set_xlabel(None)
        ax.legend(fontsize=22, ncol=4, title='Cage ID', title_fontsize=25, borderaxespad=0.2, handlelength=1.8, handletextpad=0.5, columnspacing=1.5)
        # 2) Now Plot the Individual Permutation Evolutions
        for ax, (m_id, m_perm) in zip(axs[1:], perm_evos.groupby(level=0)):
            mplext.plot_matrix(m_perm.to_numpy(), mode='heatmap', show_val=False, y_labels=[], x_labels=[], cax=False, ax=ax, hm_args={'cmap': MSE_CMAP})
            ax.set_yticks([1.5]); ax.set_yticklabels([m_id], fontsize=22, rotation=0, va='center')
        # 3) Global (shared axis labels; ticks every 2 iterations, or every 3 for longer evolutions)
        axs[6].set_ylabel('Permutation', fontsize=25, labelpad=70)
        axs[-1].tick_params(labelsize=25); axs[-1].set_xlabel('Iteration', fontsize=25)
        axs[-1].set_xticks(np.arange(0, ll_evos.groupby('Model Cage').size().max(), 2 if len(ll_evos['Iter'].unique()) < 15 else 3))
        axs[-1].set_xticklabels(np.arange(0, ll_evos.groupby('Model Cage').size().max(), 2 if len(ll_evos['Iter'].unique()) < 15 else 3))
        plt.tight_layout(h_pad=0)
        # Save and Update
        plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_outlier_evos_V{v_id}_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150); plt.close()
        progress.update()

## 4. Fit Global Model

This just fits the model once on selected Cages (or all if required)

### 4.1 Fit Models

In [None]:
# Train the global model on the selected cages (all of them if MODEL_GLOBAL is None), once per initialising cage.
if FIT_GLOBAL_MODEL:
    print('Generating Statistics for Global Model ... (will take some time)')
    # Extract Data: per cage, a list with one (time x 3 x 7) array per run
    print(f' + Extracting : ', end='', flush=True)
    selection = data.loc[MODEL_GLOBAL, 'ALM.Prob'] if MODEL_GLOBAL is not None else data.loc[:, 'ALM.Prob']
    X_all = {cid: [reindex_run(rdf).to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7) for rid, rdf in cdf.groupby('Run')] for cid, cdf in selection.groupby('CageID')}
    # Get also Model Parameters (restricted to the selected cages, if any)
    P_all = {cid: utils.subdict(cpar[Z_GLOBAL], ('Pi', 'Psi', 'Omega')) for cid, cpar in joblib.load(PARAMS_DF).items()}
    if MODEL_GLOBAL is not None:
        P_all = utils.subdict(P_all, MODEL_GLOBAL)
    print(f' Done!')
    # Run Training (the progress messages say 'CV', but each job is a fit on all the selected data)
    #  The first line prints a ruler with a '|' every 10 jobs; each worker then writes one '>'.
    print(f' + Running CV :{"         ".join("|"*int(1 + len(P_all)/10))}')
    print(f' + Running CV : ', end='', flush=True); s = tm.time()    
    # Iterate over initialisation point (try each cage)
    results = joblib.Parallel(n_jobs=N_JOBS, prefer='processes')(joblib.delayed(train_global_model)(m_id, m_par, X_all) for m_id, m_par in P_all.items())
    print(f'* Done! [{utils.show_time(tm.time() - s)}]')
    # Format the Data & store (files are tagged 'Inlier' when a subset of cages was selected, else 'Global')
    print(f' + Formatting : ', end='', flush=True)
    scores = pd.concat([res[1] for res in results], axis=1).T.rename_axis(index=('Init Model'))
    #  'NLL' is the log-likelihood normalised by the total number of samples (i.e. not negated)
    scores['Samples'] = (x_len.loc[MODEL_GLOBAL] if MODEL_GLOBAL is not None else x_len).sum(); scores['NLL'] = scores['LL']/scores['Samples']
    scores.to_pickle(os.path.join(RESULTS, f'Scores.{"Inlier" if MODEL_GLOBAL is not None else "Global"}.Z{Z_GLOBAL}.df'), compression='bz2')
    params = {m_id: m_param for m_id, _, m_param, _ in results}
    joblib.dump(params, os.path.join(RESULTS, f'Params.{"Inlier" if MODEL_GLOBAL is not None else "Global"}.Z{Z_GLOBAL}.jlib'), compress=True)
    evos = {m_id: m_evos for m_id, _, _, m_evos in results}
    joblib.dump(evos, os.path.join(RESULTS, f'Evolutions.{"Inlier" if MODEL_GLOBAL is not None else "Global"}.Z{Z_GLOBAL}.jlib'), compress=True)
    print(f' Done!\n-------------------')

### 4.2 Generate some Statistics

In [None]:
# Score the best global model on every cage and store a per-cage summary.
#  Relies on `scores`/`params` from the fitting cell and on `ll_hmm`, `lls_bl` and `x_len` from the loading cell.
if FIT_GLOBAL_MODEL:
    # Compute the Score on the (per-cage) Data
    #  1. Prepare Data (all cages, irrespective of MODEL_GLOBAL)
    best_mdl = scores['NLL'].idxmax()  # Initialisation with the highest normalised log-likelihood
    X_vld = {cid: [reindex_run(rdf).to_numpy().reshape(RUN_SEGMENTS * SEGMENT_BTIS, 3, 7) for rid, rdf in cdf.groupby('Run')] for cid, cdf in data['ALM.Prob'].groupby('CageID')}
    #  2. Prepare Model
    g_mdl = params[best_mdl]
    g_mdl = skext.CategoricalHMM(sZ=g_mdl['Pi'], sKX=g_mdl['Psi'], omega=g_mdl['Omega'])
    #  3. Run Evaluation (generating also per-cage readings)
    #     GM: global model (under the best permutation), PC: the cage's own model, BL: baseline, S: samples
    a_eval = {}
    pc = ll_hmm.groupby(by='CageID', axis=1).max()
    for v_id, v_X in X_vld.items():
        a_eval[v_id] = {
            r'$\mathcal{L}_{\text{GM}}$': optimise_perm(v_X, g_mdl)[2] / x_len[v_id],
            r'$\mathcal{L}_{\text{PC}}$': pc.loc[v_id, v_id],
            r'$\mathcal{L}_{\text{BL}}$': lls_bl[v_id]/x_len[v_id],
            'S': x_len[v_id]
        }
    #  4. Format and Store
    summary = pd.DataFrame(a_eval).T; summary['S'] = summary['S'].astype(int)
    joblib.dump(params[best_mdl], os.path.join(RESULTS, f'Params_Best.Global.Z{Z_GLOBAL}.C{best_mdl}.jlib'), compress=True)
    summary.to_pickle(os.path.join(RESULTS, f'Scores_Best.Global.Z{Z_GLOBAL}.C{best_mdl}.df'), compression='bz2')

## 5. Analyse Global Model

In [None]:
# Load the stored global-model results and report training statistics over the initialisations.
#  `best_mdl` (defined here) is also used by the later analysis/visualisation cells.
if ANALYSE_GLOBAL_MODEL:
    # A) Load Data (files are tagged 'Inlier' when MODEL_GLOBAL selects a subset of cages, else 'Global')
    scores = pd.read_pickle(os.path.join(RESULTS, f'Scores.{"Inlier" if MODEL_GLOBAL is not None else "Global"}.Z{Z_GLOBAL}.df'), compression='bz2')
    params = joblib.load(os.path.join(RESULTS, f'Params.{"Inlier" if MODEL_GLOBAL is not None else "Global"}.Z{Z_GLOBAL}.jlib'))
    evos = joblib.load(os.path.join(RESULTS, f'Evolutions.{"Inlier" if MODEL_GLOBAL is not None else "Global"}.Z{Z_GLOBAL}.jlib'))
    
    # B) Statistics on Training Data
    #  1. Best Model (including initialiser): highest value of the 'NLL' column
    best_mdl, best_nll = scores['NLL'].idxmax(), scores['NLL'].max()
    #  2. Mean/Std
    mean_nll, std_nll = scores['NLL'].mean(), scores['NLL'].std()
    #  3. Format and visualise
    t_eval = pd.Series({'Best': best_nll, 'Mean': mean_nll, 'St.Dev': std_nll})
    display(t_eval)
    display(f'Best Model = {best_mdl} [{CAGE_SHORTHAND[best_mdl]}]')

### 5.1 Quality of Fit

In [None]:
# Per-cage quality of fit of the best global model, relative to the per-cage models.
if ANALYSE_GLOBAL_MODEL and MPC_WORK:
    # A) Load and Display
    summary = pd.read_pickle(os.path.join(RESULTS, f'Scores_Best.Global.Z{Z_GLOBAL}.C{best_mdl}.df'), compression='bz2')
    display(summary.round(3).rename(CAGE_SHORTHAND).T)
    # B) Compute RDL (and display)
    #    RDL: absolute percentage difference between the global-model (GM) and per-cage (PC) scores,
    #    both taken relative to the baseline (BL).
    l_gm = summary[r'$\mathcal{L}_{\text{GM}}$'] - summary[r'$\mathcal{L}_{\text{BL}}$']
    l_pc = summary[r'$\mathcal{L}_{\text{PC}}$'] - summary[r'$\mathcal{L}_{\text{BL}}$']
    rdl = ((l_gm - l_pc) * 100 /l_pc).abs().rename('RDL')
    l_rdl = pd.concat([summary[r'$\mathcal{L}_{\text{GM}}$'], rdl], axis=1).rename(CAGE_SHORTHAND).T
    print(l_rdl.style.format(precision=2).to_latex(hrules=True))  # LaTeX table
    print(f'Mean RDL: {rdl.mean():.2f}')

### 5.2 Analyse Evolution

This is the evolution of the global model (for all initialisations)

In [None]:
# Figure: evolution of the (per-sample) log-likelihood and of the permutation changes, for every initialisation.
if ANALYSE_GLOBAL_MODEL and MPC_WORK:
    # A. Compute
    #  Log-likelihood per iteration, normalised by the number of samples (long format)
    ll_evos = pd.concat({CAGE_SHORTHAND[m_id]: pd.Series(m_evo['LL']) for m_id, m_evo in evos.items()}, names=['Init Model', 'Iter'])
    ll_evos = (ll_evos / scores.rename(CAGE_SHORTHAND)['Samples']).rename('LL').reset_index()
    #  Per iteration, a count derived from pdext.diff over the permutation history and summed across cages
    #  (presumably the number of cages whose permutation changed - confirm against pdext.diff)
    perm_evos = pd.concat({m_id: pdext.diff(pd.DataFrame(m_evo['Perm']), 1, True).sum(axis=1) for m_id, m_evo in evos.items()}, axis=1)
    perm_evos = perm_evos.T.rename(CAGE_SHORTHAND)
    # B. Find some common stats
    n_cages = len(utils.default(MODEL_GLOBAL, perm_evos.index.unique(0)))
    n_iters = ll_evos.groupby('Init Model').size().max()
    max_dif = perm_evos.max().max()
    # C. Plot (one tall axis for the log-likelihood, followed by one strip per initialisation)
    fig, axs = plt.subplots(1+n_cages, 1, figsize=[10, 10], tight_layout=True, sharey=False, sharex=False, gridspec_kw={'height_ratios': [12, *([1]*n_cages)]})
    #  1) LL Evolution
    ax = axs[0]
    sns.lineplot(data=ll_evos, x='Iter', y='LL', hue='Init Model', lw=4, palette=CAGE_CMAP.colors, ax=ax)
    ax.tick_params(labelsize=22); ax.set_ylabel(r'$\widehat{\mathcal{L}}$', fontsize=25, labelpad=10); ax.set_xticks([]); ax.set_xlabel(None); ax.set_xlim([-0.5, n_iters - 0.5])
    ax.legend(fontsize=22, ncol=4, title='Initialiser', title_fontsize=25, borderaxespad=0.2, handlelength=1.8, handletextpad=0.5, columnspacing=1.5, markerscale=2)
    #  2) Now Plot the Individual Permutation Evolutions
    #     (bars show the counts; iterations without a value are covered by a hatched white bar)
    for ax, (m_id, m_perm) in zip(axs[1:], perm_evos.groupby(level=0)):
        m_perm.iloc[0].plot.bar(ax=ax, fontsize=22, width=(0.5 if n_iters < 15 else 0.9)); ax.bar(pdext.idx_where(m_perm.iloc[0], pd.isna), max_dif+1, width=1, facecolor='white', hatch="xx")
        ax.set_ylim([0, max_dif + 0.5]); ax.set_xlim([-0.5, n_iters + 0.5]); ax.set_xticks([])
        ax.set_yticks([(max_dif+.5)/2]); ax.set_yticklabels([m_id], fontsize=22, rotation=0, va='center')
    #  3) Global (shared axis labels; ticks every 2 iterations, or every 5 for longer evolutions)
    axs[int(n_cages/2)].set_ylabel(f'Changes in $Q$ (max {max_dif:.0f})', fontsize=24, labelpad=55, va='bottom')
    axs[-1].tick_params(labelsize=22, rotation=0); axs[-1].set_xlabel('Iteration', fontsize=25)
    axs[-1].set_xticks(np.arange(0, n_iters+1, (2 if n_iters < 15 else 5)))
    axs[-1].set_xticklabels(np.arange(0, n_iters+1, 2 if n_iters < 15 else 5))
    plt.tight_layout(h_pad=0)
    # Save and Update
    plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_global_evos_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150)

### 5.3 Show Parameters

This is again for all models.

In [None]:
# One figure per initialisation, showing the fitted transition matrix and the three emission matrices.
if ANALYSE_GLOBAL_MODEL and MPC_WORK:
    # Report Immobility for the Best Model
    #  (first behaviour column of each emission matrix, weighted by the stationary distribution - this
    #   presumes that the first behaviour is 'Immobile'; confirm against BEH_ORDER)
    prob_imm = params[best_mdl]['Psi'][:, :, 0] @ npext.markov_stationary(params[best_mdl]['Omega'])
    print(f'|Z|={Z_GLOBAL}: ' + ' '.join(f'IMM_{k+1}={prob_imm[k]:.2f}' for k in range(3)))
    for m_id, m_pars in params.items():
        omega, psi = m_pars['Omega'], m_pars['Psi']
        # Compact 2x2 grid for the paper, otherwise a single wide row
        if FOR_PAPER:
            fig, axs = plt.subplots(2, 2, figsize=[VIS_HEIGHTS[Z_GLOBAL]*1.8, VIS_HEIGHTS[Z_GLOBAL]*1.8 - 1.5], tight_layout=True); axs=axs.ravel()
        else:
            fig, axs = plt.subplots(1, 4, figsize=[(15+Z_GLOBAL/2), VIS_HEIGHTS[Z_GLOBAL]], tight_layout=True, gridspec_kw={'width_ratios':[Z_GLOBAL,7,7,7]})
        show_omega(omega, axs[0], stats=not FOR_PAPER)
        # NOTE: the loop variable must be called `k`, since show_psi_k picks up the mouse index from the
        #       global `k` when it is not given one explicitly.
        for k, ax in enumerate(axs[1:]):
            show_psi_k(psi[k, :, :], ax, not FOR_PAPER, 1 if FOR_PAPER else -1, FOR_PAPER)
        plt.tight_layout(w_pad=0)
        plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_global_params_M{m_id}_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=150)
        plt.close()

### 5.4 Temporal Progression

In [None]:
# One figure per (cage, run): light status, latent-state posterior and the behaviours of the three mice over time.
#  NOTE(review): `best_mdl` is only defined by the global-model fitting/analysis cells, so one of those must
#                have been run beforehand.
if VISUALISE_TEMPORAL and MPC_WORK:
    best_par = joblib.load(os.path.join(RESULTS, f'Params_Best.Global.Z{Z_GLOBAL}.C{best_mdl}.jlib'))
    chmm = skext.CategoricalHMM(best_par['Pi'], best_par['Psi'], best_par['Omega'])
    # NOTE(review): the bar is sized by the number of unique 'Run' values but is updated once per
    #               (cage, run) pair - confirm that run identifiers are unique across cages.
    progress = ProgressBar(len(data.index.unique('Run'))).reset('Plotting:')
    for cid, cdf in data.groupby('CageID'):
        perm = best_par['Perm'][cid] # Get the order for this Cage
        for rid, rdf in cdf.droplevel(0, axis=0).groupby(by='Run'):
            # Re-Index the run (same scheme as reindex_run, but keeping all column groups)
            first_seg = rdf.index.get_level_values('Segment').min()
            rdf[('Sensors', 'Time')] = (rdf.index.get_level_values('Segment') - first_seg) * SEGMENT_BTIS + rdf.index.get_level_values('BTI')
            rdf = rdf.set_index(('Sensors', 'Time'), append=True).droplevel(('Run', 'Segment', 'BTI')).reorder_levels((1, 0)).rename_axis(('BTI', 'Mouse'))
            rdf = rdf.unstack(-1).reindex(np.arange(RUN_SEGMENTS * SEGMENT_BTIS))
            # Extract Data
            #  r_X: behaviour probabilities with the mice ordered by this cage's permutation
            #  r_Z: latent-state posterior, transposed so that states are along the rows
            #  r_L: light status (taken from mouse 'R'), as a single row
            r_X = rdf['ALM.Prob'][BEH_ORDER].reorder_levels((1, 0), 1)[[*perm]]
            r_Z = chmm.predict_proba([r_X.to_numpy().reshape(SEGMENT_BTIS * RUN_SEGMENTS, 3, 7).astype('float', order='C')])[0].T
            r_L = rdf['Sensors']['Light']['R'].astype(float).to_numpy()[np.newaxis, :]
            # Set up Figure
            fig, axs = plt.subplots(5, 1, figsize=[25, 4+Z_GLOBAL], tight_layout=True, sharex=True, gridspec_kw={'height_ratios':[0.8, Z_GLOBAL-2, 5, 5, 5]})
            #  i) Plot Light Status
            cmap = mpl.colormaps['gray']; cmap.set_bad('tan')
            axs[0].imshow(r_L, cmap=cmap, aspect='auto'); axs[0].set_yticks([0]); axs[0].set_yticklabels(['Light'], fontsize=22)
            # ii) Plot the sequence of Latent States
            cmap = mpl.colormaps[COLOUR_MAPS['Z']]; cmap.set_bad('gray')
            axs[1].imshow(r_Z, cmap=cmap, aspect='auto', interpolation='none')
            axs[1].set_yticks(np.arange(len(r_Z))); axs[1].set_yticklabels(sau[:len(r_Z)], fontsize=22)
            axs[1].set_ylabel('Z', fontsize=25, rotation=0, va='center', labelpad=80)
            # iii) Plot each of the three mouse behaviours (coloured by the identity of the mouse).
            for k, m in enumerate(perm):
                cmap = mpl.colormaps[COLOUR_MAPS[m]]; cmap.set_bad('gray')
                m_X = r_X[m].to_numpy().T
                axs[2+k].imshow(m_X, cmap=cmap, aspect='auto', interpolation='none')
                axs[2+k].set_yticks(np.arange(7)); axs[2+k].set_yticklabels(BEH_NAMES, fontsize=22)
                axs[2+k].set_ylabel(f'$X_{k+1}$', fontsize=25, va='center', rotation=0, labelpad=20)
            # iv) Commonalities (a time tick every 600 BTIs)
            time_ticks = np.arange(0, RUN_SEGMENTS * SEGMENT_BTIS+1, 600)
            axs[-1].set_xticks(time_ticks); axs[-1].set_xticklabels(utils.time_list(time_ticks*1000., fmt='%H:%M'), fontsize=22, ha='center')
            axs[-1].set_xlabel('Time (Hrs:Mins)', fontsize=18)
            # Save Figure
            plt.tight_layout(h_pad=0)
            plt.savefig(os.path.join(RESULT_LOC, 'Figures', f'fig_model_global_temporal_R{rid}_C{cid}_Z{Z_GLOBAL}.png'), bbox_inches='tight', dpi=200)
            plt.close()
            progress.update()