# Evaluation of STAE Experiments

This notebook connects to MLflow (via DagsHub), downloads all experiment runs, selects the best run per driver and input type, exports a LaTeX results table, and creates visualizations.

In [None]:
# Automatically reload imported modules (such as the local `model` package)
# when their source files change, so edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
from pathlib import Path
from pprint import pprint

import dagshub
import matplotlib.pyplot as plt
import pandas as pd
from mlflow.client import MlflowClient

from model.fonts import set_cmu_typewriter_font
from model.latex import get_caption, pivot_table_to_latex, pivotize_drivers
from model.mlflow import download_all_runs, download_predictions, load_predictions
from model.plot import plot_results

In [None]:
# Point MLflow tracking at the DagsHub repository, then create a client for it.
REPO_NAME = 'driver-stae'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore
client = MlflowClient()

# Show every DataFrame column instead of truncating the view.
pd.options.display.max_columns = None

# Plot styling: whitegrid theme, the project's CMU Typewriter font helper,
# and a large base font size.
plt.style.use('seaborn-v0_8-whitegrid')
font = set_cmu_typewriter_font()
plt.rcParams['font.size'] = 34

## Get all experiment runs

In [None]:
# Print the names of the experiments returned by `search_experiments()`
# (called with its default arguments).
experiments = client.search_experiments()
experiment_names = [exp.name for exp in experiments]
pprint(experiment_names)

In [None]:
# Collect the runs of all listed experiments into one DataFrame.
runs_df = download_all_runs(experiments=experiments, client=client)

print(f'Total runs: {len(runs_df)}')
# Preview the first ten runs as the cell output.
runs_df.head(10)

In [None]:
# Raw `source_type` values -> display labels used in tables and plots.
source_type_map = dict(
    masks='Mask',
    depth='MDE',
    images='RGB',
    rgb='RGB',
    rgbd='RGBD',
    rgbdm='RGBDM',
)

# Per display label: line colour and line style used when plotting.
source_type_color_map = dict(
    Mask='tab:orange',
    MDE='tab:blue',
    RGB='tab:green',
    RGBD='tab:red',
    RGBDM='tab:purple',
)
source_type_linestyle_map = dict(
    MDE='-',
    Mask='--',
    RGB='-.',
    RGBD=':',
    RGBDM='-',
)

# Driver name -> consecutive integer ID, starting at 1.
driver_name_mapping = {
    name: number
    for number, name in enumerate(
        ('dans', 'geordi', 'jakub', 'michal', 'poli'), start=1
    )
}

# Output directory for generated files; created if it does not exist yet.
OUTPUT_DIR = Path('outputs')
OUTPUT_DIR.mkdir(exist_ok=True)

## Filtering

In [None]:
# Cast these columns to int so they can be compared numerically
# (e.g. `image_size == 64` in the filtering cell).
integer_columns = ['image_size', 'batch_size', 'early_stopping']
for column in integer_columns:
    runs_df[column] = runs_df[column].astype(int)

In [None]:
# Keep runs with image_size 64 from one specific commit, excluding the
# 'dmd' dataset.
COMMIT_ID = '36865282077c25364392dcb25cfea2e3e09b1edb'
mask_image_size = runs_df['image_size'] == 64
mask_commit = runs_df['tag.Commit ID'] == COMMIT_ID
mask_dataset = runs_df['tag.Dataset'] != 'dmd'
df_selection = runs_df[mask_image_size & mask_commit & mask_dataset]

In [None]:
# For every (driver, source_type, image_size) group, pick the index label of
# the run with the highest ROC AUC and keep only those runs.
idx = (
    df_selection
    .groupby(['driver', 'source_type', 'image_size'])['metric.roc_auc']
    .idxmax()
)
df_mrl_stae = df_selection.loc[idx]

# Summary of the selected runs, shown as the cell output.
summary_columns = [
    'driver',
    'source_type',
    'metric.roc_auc',
    'metric.pr_auc',
    'early_stopping',
    'patience',
    'best_metric',
]
df_mrl_stae[summary_columns]

In [None]:
# Reshape the best runs with `pivotize_drivers`, using the display labels and
# numeric driver IDs defined earlier; preview the result.
pivot_kwargs = {
    'source_type_map': source_type_map,
    'driver_name_mapping': driver_name_mapping,
}
df_mrl_stae_pivot = pivotize_drivers(df_mrl_stae, **pivot_kwargs)
df_mrl_stae_pivot.head()

In [None]:
# Write the pivot table to a .tex file in the output directory.
pivot_tex_path = OUTPUT_DIR / 'stae_mrl_pivot.tex'
pivot_table_to_latex(
    df_mrl_stae_pivot,
    path=pivot_tex_path,
    caption=get_caption('STAE', 'MRL'),
    label='tab:stae-mrl-pivot',
)

## Download predictions

In [None]:
# Download the prediction files of the selected runs, then load them.
# Later cells index the result as data_mrl_stae[driver][source_type].
df_mrl_stae = download_predictions(df=df_mrl_stae, client=client)
data_mrl_stae = load_predictions(df_mrl_stae, source_type_map=source_type_map)

## Visualizations

In [None]:
# Smaller base font for the result figures.
plt.rcParams['font.size'] = 17

# Keyword arguments shared by every `plot_results` call below.
plot_kwargs = {
    'source_type_color_map': source_type_color_map,
    'source_type_linestyle_map': source_type_linestyle_map,
    'driver_name_mapping': driver_name_mapping,
    'fig_height_multiplier': 5,
    'fig_width_multiplier': 3.6,
    'n_rows': 1,
    'linewidth': 2,
    'legend_outside': True,
}

In [None]:
# ROC curves of the best runs; saved under OUTPUT_DIR (the same directory as
# the LaTeX table) instead of a hard-coded 'outputs/' path.
plot_results(
    'roc', data_mrl_stae, save_path=str(OUTPUT_DIR / 'roc_auc.pdf'), **plot_kwargs
)

In [None]:
# Precision-recall curves of the best runs; saved under OUTPUT_DIR instead of
# a hard-coded 'outputs/' path.
plot_results(
    'pr', data_mrl_stae, save_path=str(OUTPUT_DIR / 'pr_auc.pdf'), **plot_kwargs
)

## Recompute ROC and PR results from MAE and MSE error scores

In [None]:
import copy

from model.eval import compute_best_roc_auc

# Recompute ROC/PR results for every (driver, source type) pair using the MAE
# error score. `data` is a deep copy, so `data_mrl_stae` is never modified.
redata = defaultdict(dict)
data = copy.deepcopy(data_mrl_stae)

# `iqr` bounds passed to `compute_best_roc_auc`; (0.00, 0.95) is an
# alternative setting that was also tried.
iqr = (0.00, 1.00)

drivers = list(data)
# Source types of the first driver (empty if there is no data); printed here
# and kept as a module-level name for later cells.
source_types = list(next(iter(data.values()), {}))
pprint(source_types)
pprint(drivers)

for driver, driver_data in data.items():
    # Iterate each driver's own source types: a driver whose set differs from
    # the first driver's no longer raises KeyError or gets silently skipped.
    for source_type, predictions in driver_data.items():
        # Copy per entry so `update` never mutates `data`, which the MSE cell
        # reuses.
        entry = copy.deepcopy(predictions)
        entry.update(
            compute_best_roc_auc(
                entry['y_true'],
                entry['errors'],
                iqr=iqr,
                metric='mae',
            )
        )
        redata[driver][source_type] = entry

In [None]:
# ROC curves recomputed with the MAE score; saved under OUTPUT_DIR instead of
# a hard-coded 'outputs/' path.
plot_results(
    'roc', redata, save_path=str(OUTPUT_DIR / 'roc_auc_mae.pdf'), **plot_kwargs
)

In [None]:
# Precision-recall curves recomputed with the MAE score; saved under
# OUTPUT_DIR instead of a hard-coded 'outputs/' path.
plot_results(
    'pr', redata, save_path=str(OUTPUT_DIR / 'pr_auc_mae.pdf'), **plot_kwargs
)

In [None]:
# Recompute ROC/PR results with the MSE error score. `redata` is rebuilt from
# scratch so no MAE entries from the previous cell remain mixed in.
redata = defaultdict(dict)
for driver, driver_data in data.items():
    # Iterate each driver's own source types rather than the first driver's.
    for source_type, predictions in driver_data.items():
        # Copy per entry so `update` never mutates the shared `data`.
        entry = copy.deepcopy(predictions)
        entry.update(
            compute_best_roc_auc(
                entry['y_true'],
                entry['errors'],
                iqr=iqr,
                metric='mse',
            )
        )
        redata[driver][source_type] = entry

In [None]:
# ROC curves recomputed with the MSE score; saved under OUTPUT_DIR instead of
# a hard-coded 'outputs/' path.
plot_results(
    'roc', redata, save_path=str(OUTPUT_DIR / 'roc_auc_mse.pdf'), **plot_kwargs
)

In [None]:
# Precision-recall curves recomputed with the MSE score; saved under
# OUTPUT_DIR instead of a hard-coded 'outputs/' path.
plot_results(
    'pr', redata, save_path=str(OUTPUT_DIR / 'pr_auc_mse.pdf'), **plot_kwargs
)