# Format Model Outputs for Evaluation

This script:
 * Visualises the progression of validation-set Accuracy/Score as per Thesis Report
 * Formats the model outputs for evaluation in the new data format (***N.B.***: this can only be used to format the Tuning outputs, not the End2End ones!)

In [None]:
# General Libraries
from mpctools.extensions import utils, npext, mplext
from IPython.display import display, HTML
from scipy.special import softmax
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import json
import sys
import os

# Add the Project Directories to the path
sys.path.append('../../../../')

# Add specific project tools
from Scripts.Constants import Const

# Finally Display Options
#   -> Inject a CSS rule setting the notebook `.container` width to 95%
display(HTML("<style>.container { width:95% !important; }</style>"))
#   -> Let pandas show up to 50 columns before truncating the display
pd.set_option('display.max_columns', 50)

In [None]:
# === Data === #
# Value of `GT.Behaviour` that is filtered out of the ground-truth before joining (see Section 2)
HIDDEN = 0

# === Paths === #
# Absolute path to the curated behaviour data (hard-coded: adjust to the local mount point)
BASE_DATA = '/media/veracrypt4/Q1/Snippets/Curated/Behaviour'
# Root for everything this notebook reads (logs, raw CSVs) and writes (figures, formatted pickles)
BASE_RESULTS = os.path.join(Const['Results.Scratch'], 'Behaviour')
# Ground-truth DataFrame, read in Section 2 as a bz2-compressed pickle
GT_DATA = os.path.join(BASE_DATA, 'Common', 'AVA.Data.df')

# Sub-directory of BASE_RESULTS holding the training logs listed in the VISUALISE_EVOLUTION_* dictionaries
TRAINING_EVOLUTION = 'Training_Evolution'

# === Execution Control === #
# ------ Evolution ------ #
# Both dictionaries map {legend label -> log file (relative to BASE_RESULTS/TRAINING_EVOLUTION)}.
# An empty dictionary disables the corresponding plot in Section 1.

# LFB runs: each file is read line-by-line as JSON records (see Section 1.1).
VISUALISE_EVOLUTION_LFB = {
    # CONFIG LFB.A
    # This is a combination of previous runs for display in the Report
    'R=5e-4 V=11:8 Ag=[-CR]': 'LFB/Config_A/LFB_C11_S08_L5e-4.json',                  # Formerly Cfg 3
    'R=5e-4 V=13:4 Ag=[-CR]': 'LFB/Config_A/LFB_C13_S04_L5e-4.json',                  # Formerly Cfg 3
    'R=5e-4 V=4:16 Ag=[ECR]': 'LFB/Config_A/train_lfb_50_16_L5e-4_W10_RJ_DCE.json',   # Formerly Cfg 2
    'R=5e-4 V=4:16 Ag=[-CR]': 'LFB/Config_A/train_lfb_50_16_L5e-4_W10_RJ.json',       # Formerly Cfg 1
    'R=5e-4 V=4:16 Ag=[-C-]': 'LFB/Config_A/train_lfb_50_16_L5e-4_W10_J.json',        # Formerly Cfg 1
    'R=5e-4 V=4:16 Ag=[---]': 'LFB/Config_A/train_lfb_50_16_L5e-4_W10_NA.json',       # Formerly Cfg 1
    
}

# STLT/CACNF runs: each file is a text log that is scanned for 'INFO:root:' lines (see Section 1.2).
# N.B.: The label is not purely cosmetic here: labels containing 'VIS' are parsed for the 'top1/caf'
#       metric, all others for 'top1/stlt'.
VISUALISE_EVOLUTION_STLT_CACNF = {
    'BBox     R=5e-6 L=36:3 V=--:-'      : 'STLT/train_stlt_36_3_L5e-6.Fixed.log',       # Formerly STLT/Config_5
    'BBox+VIS R=1e-7 L=36:3 V=12:1'      : 'STLT/train_cacnf_12+1_1e-7_Raw.log',         # Formerly CACNF/Config_3
    'BBox+VIS R=1e-7 L=36:3 V=12:1 {DCE}': 'STLT/train_cacnf_12+1_1e-7_DCE.log',         # Formerly CACNF/Config_4
    'BBox+VIS R=1e-7 L=36:3 V=12:2'      : 'STLT/train_cacnf_36+3_12s2_1e-7.log',        # Formerly CACNF/Config_5
    'BBox+VIS R=1e-7 L=36:3 V=24:3'      : 'STLT/train_cacnf_36+3_24s3_1e-7.log',        # Formerly CACNF/Config_5
    'BBox+VIS R=1e-7 L=36:3 V=12:2 {Mse}': 'STLT/train_cacnf_bbx_v12s2_128_10_1e-7.log', # Formerly CACNF/Config_6
}

# ------ Format ------ #
# List of (raw CSV, formatted pickle) pairs, both relative to BASE_RESULTS.
FORMAT_LFB = [
    # For Fixed (Tuning) Data: one pair per split
    *(
        (f'Features/Raw/Fixed.{split}.csv', f'Features/Formatted/LFB.Fixed.{split}.df')
        for split in ('Train', 'Validate', 'Test')
    ),
    # For Folds (Tuning) Data: folds 1..14 (output name is zero-padded, input name is not)
    *(
        (f'Features/Raw/Folds.{fold}.csv', f'Features/Formatted/LFB.Folds.{fold:02d}.df')
        for fold in range(1, 15)
    ),
]

In [None]:
def split_sample_info(df, bti_col=3, mse_col=4):
    """
    Split the '_'-separated sample identifier in column `0` into index levels.

    The first three parts of the identifier become the integer levels `CageID`, `Segment` and
    `Snippet`. Optionally, further parts are extracted as `BTI` (integer) and `Mouse` (left as a string).

    The input DataFrame is NOT modified: all work is done on a copy.

    :param df: DataFrame with the string identifiers in the column labelled `0`.
    :param bti_col: Position (after splitting on '_') of the BTI part, or None to omit the `BTI` level.
    :param mse_col: Position (after splitting on '_') of the Mouse part, or None to omit the `Mouse` level.
    :return: New DataFrame without column `0`, indexed by the extracted levels (this replaces the
             original index).
    """
    # Get Index and split
    parts = df[0].str.split('_', expand=True)

    # Collect the new levels in order (dict order defines the order of the index levels)
    levels = {
        'CageID': parts[0].astype(int),
        'Segment': parts[1].astype(int),
        'Snippet': parts[2].astype(int),
    }
    if bti_col is not None:
        levels['BTI'] = parts[bti_col].astype(int)
    if mse_col is not None:
        levels['Mouse'] = parts[mse_col]

    # `drop` and `assign` both return new objects, so the caller's frame is left untouched
    return df.drop(columns=[0]).assign(**levels).set_index(list(levels))

## 1. Visualise Evolution

I visualise the evolution of various models.

### 1.1 Evolution of LFB Models

In [None]:
# Plot the validation mAP (per epoch) of each configured LFB run on one shared axis and save the figure.
if VISUALISE_EVOLUTION_LFB:
    fig, ax = plt.subplots(1, 1, figsize=[18, 5], tight_layout=True)
    for schedule, log_file in VISUALISE_EVOLUTION_LFB.items():
        log_path = os.path.join(BASE_RESULTS, TRAINING_EVOLUTION, log_file)
        # Only the file access is guarded: a missing log is skipped with a warning (best-effort).
        try:
            with open(log_path, 'r') as fin:
                # One JSON record per (non-blank) line. The first record is skipped.
                # NOTE(review): presumably the first line is a header/config entry - confirm.
                records = [json.loads(line) for line in fin if line.strip()][1:]
        except FileNotFoundError:
            print(f'Warning: Could not find JSON for {schedule}')
            continue

        # Keep only the validation records, keyed by epoch
        epochs = pd.Series({e['epoch']: e['mAP@0.5IOU'] for e in records if e['mode'] == 'val'}, name='mAP')
        if epochs.empty:
            # idxmax() raises on an empty Series, so skip logs without validation entries
            print(f'Warning: No validation records in JSON for {schedule}')
            continue

        ax.plot(epochs.index, epochs.values, 'o-', label=schedule)
        print(f'Best Performance for {schedule: <15} -> {epochs.max():.03f} (@{epochs.idxmax(): 3d})')

    # Styling
    #   -> The legend font is fully specified through `prop` (a separate `fontsize` is ignored when `prop` is given).
    ax.legend(ncol=2, prop={'family': 'monospace', 'size': 22}, handlelength=1.8, handletextpad=0.5, borderaxespad=0.2, columnspacing=1.5)
    ax.tick_params(axis='both', which='major', labelsize=23)
    ax.set_xlabel('Epochs', fontsize=23)
    ax.set_ylabel('mAP @ IoU=0.5', fontsize=23)
    ax.set_xlim([0, 51])
    ax.set_ylim([0.21, 0.51])
#     ax.set_title('Validation-Set Performance', fontsize=20)

    # Store (ensuring that the output directory exists)
    figures_dir = os.path.join(BASE_RESULTS, 'Figures')
    os.makedirs(figures_dir, exist_ok=True)
    fig.savefig(os.path.join(figures_dir, 'fig_beh_lfb_evolution.png'), bbox_inches='tight', dpi=150)

### 1.2 Evolution of STLT Models

In [None]:
# Plot the validation top-1 accuracy (per epoch) of each configured STLT/CACNF run and save the figure.
if VISUALISE_EVOLUTION_STLT_CACNF:
    fig, axs = plt.subplots(1, 1, figsize=[18, 5], tight_layout=True)
    for schedule, log_file in VISUALISE_EVOLUTION_STLT_CACNF.items():
        # The label selects the metric: labels containing 'VIS' are parsed for 'top1/caf', all others for 'top1/stlt'
        metric = 'top1/caf' if 'VIS' in schedule else 'top1/stlt'

        # Extract Data: the value is the second whitespace-separated token of each matching log line
        log_path = os.path.join(BASE_RESULTS, TRAINING_EVOLUTION, log_file)
        try:
            with open(log_path, 'r') as fin:
                accuracies = [
                    float(line.split()[1])
                    for line in fin
                    if 'INFO:root:' in line and metric in line
                ]
        except FileNotFoundError:
            # Missing logs are skipped with a warning, as in the LFB plot
            print(f'Warning: Could not find log for {schedule}')
            continue
        if not accuracies:
            # max()/argmax() raise on empty input, so skip logs without any matching entry
            print(f'Warning: No {metric} entries in log for {schedule}')
            continue

        # Plot
        #   -> At most the first 50 entries are shown; values are divided by 100
        #      (NOTE(review): presumably logged as percentages - confirm).
        accuracies = np.asarray(accuracies[:50]) / 100
        axs.plot(np.arange(1, len(accuracies) + 1), accuracies, 'o-', label=schedule)
        print(f'Best Performance for {schedule: <25} -> {accuracies.max():.02f} (@{np.argmax(accuracies)+1: 3d})')

    # Styling
    #   -> The legend font is fully specified through `prop` (a separate `fontsize` is ignored when `prop` is given).
    axs.tick_params(axis='both', which='major', labelsize=23)
    axs.set_xlabel('Epochs', fontsize=23)
    axs.legend(ncol=2, prop={'family': 'monospace', 'size': 17})
    axs.set_ylabel('Accuracy', fontsize=23)
#     axs.set_title('Performance on Validation Set', fontsize=23)
    axs.set_xlim([0, 51])

    # Store (ensuring that the output directory exists)
    figures_dir = os.path.join(BASE_RESULTS, 'Figures')
    os.makedirs(figures_dir, exist_ok=True)
    fig.savefig(os.path.join(figures_dir, 'fig_beh_stlt_evolution.png'), bbox_inches='tight', dpi=150)

## 2. Prepare Outputs for Evaluation

This prepares the output in a consistent format for evaluation. 

This has been stripped down to operate only on the LFB Models since the dataset changed: the STLT/CACNF models should use the formerly generated outputs.
For the LFB outputs, the mouse identity is not part of the raw predictions: to recover it, we need to join them with the base (ground-truth) detections.

In [None]:
# Load Ground-Truth Data
#   -> Note that I do Filtering here, since this is meant to only be used for the LFB Computations.
#   -> N.B.: This unpickles GT_DATA: only ever point it at a trusted file.
gts = pd.read_pickle(GT_DATA, compression='bz2')
# Keep only rows whose behaviour is not HIDDEN and whose source is 'A'
gts = gts[(gts['GT.Behaviour'] != HIDDEN) & (gts['GT.Source'] == 'A')]
# Keep columns 0-3 (rounded to 3 d.p.), append them to the index and pull `Mouse` out of the index as a column.
#   NOTE(review): columns 0-3 are presumably the bounding-box coordinates - confirm against the AVA data schema.
gts = gts[[0, 1, 2, 3]].round(3).set_index([0, 1, 2, 3], append=True).reset_index('Mouse')

# Iterate over Samples
for raw_in, clean_out in FORMAT_LFB:
    # Load the CSV
    csv = pd.read_csv(os.path.join(BASE_RESULTS, raw_in), header=None)
    # Index by CageID/Segment/Snippet (parsed from column 0) plus columns 1-6, then pivot the last
    # level (column 6) into the columns.
    #   NOTE(review): presumably column 6 is the class label and the remaining column its score - confirm.
    csv = split_sample_info(csv, None, None).set_index([1, 2, 3, 4, 5, 6], append=True).unstack(-1)
    # Drop the outer column level, clear the column-axis name and rename the index levels
    # (1 -> 'BTI', 2..5 -> 0..3), which mirrors the level names 0-3 set on `gts` above.
    csv = csv.droplevel(0, axis=1).rename_axis(columns='').rename_axis(index={1: 'BTI', 2: 0, 3: 1, 4: 2, 5: 3})
    # Join together to store
    #   Since we do a left join on the CSV, then this will ignore `Hidden` mice for which we do not have predictions if need be.
    #   NOTE(review): the join keys include the float levels 0-3, which are rounded to 3 d.p. on the `gts`
    #                 side only - this assumes the CSV values are already stored at that precision.
    #   Afterwards, `Mouse` is appended to the index and the levels 0-3 are dropped.
    preds = csv.join(gts).set_index('Mouse', append=True).reset_index([0, 1, 2, 3], drop=True)
    preds.to_pickle(os.path.join(BASE_RESULTS, clean_out), compression='bz2')