# Evaluation of TAE Experiments

This notebook connects to the DagsHub-hosted MLflow tracking server, downloads all experiment runs, selects the best run per configuration, and plots ROC and PR curves for the selected runs.

In [None]:
# Reload imported modules automatically before executing code, so edits to the
# project modules (model.*) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import warnings
from collections import defaultdict
from pathlib import Path
from pprint import pprint

import dagshub
import matplotlib.pyplot as plt
import pandas as pd
from mlflow.client import MlflowClient
from tqdm import tqdm

from model.fonts import set_cmu_typewriter_font
from model.mlflow import download_all_runs
from model.plot import plot_results

In [None]:
# Display: never truncate DataFrame columns when rendering.
pd.set_option('display.max_columns', None)

# Tracking: configure MLflow for the DagsHub repository, then create a client
# (the client is created after dagshub.init so it picks up that configuration).
REPO_NAME = 'driver-tae'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore
client = MlflowClient()

# Plotting: global Matplotlib style plus the project's CMU Typewriter font helper.
plt.style.use('seaborn-v0_8-whitegrid')
font = set_cmu_typewriter_font()

## Get all experiment runs

In [None]:
# All experiments except MLflow's built-in 'Default' one.
experiments = client.search_experiments(
    filter_string="name!='Default'",
)
pprint([exp.name for exp in experiments])

In [None]:
# Pull the runs of the experiments listed above into one DataFrame.
runs_df = download_all_runs(client=client, experiments=experiments)

print('Total runs:', len(runs_df))
# Last expression of the cell: Jupyter renders a preview of the first rows.
runs_df.head()

## Filter runs and select the best ones

In [None]:
# Hyper-parameter columns that must be compared as integers below
# (presumably they come back from MLflow as strings — confirm in download_all_runs).
integer_columns = ['image_size', 'latent_dim', 'batch_size', 'early_stopping']
runs_df = runs_df.astype(dict.fromkeys(integer_columns, int))

In [None]:
# Keep batch-size-256 runs on the 'dmd' dataset, excluding EfficientNetEncoder runs.
runs_df = runs_df[
    runs_df['batch_size'].eq(256)
    & runs_df['dataset'].eq('dmd')
    & runs_df['encoder_name'].ne('EfficientNetEncoder')
]

In [None]:
# Keep only the run with the highest ROC AUC for every
# (driver, source_type, latent_dim, image_size) combination.
# Rows without a logged 'metric.roc_auc' are ignored first: for a group whose
# metric is entirely NaN, groupby-idxmax returns NaN (deprecated; newer pandas
# raises), which would make the `.loc` lookup below fail.
idx = (
    runs_df.dropna(subset=['metric.roc_auc'])
    .groupby(['driver', 'source_type', 'latent_dim', 'image_size'])['metric.roc_auc']
    .idxmax()
)
# dropna preserves index labels, so idx can index the unfiltered runs_df.
best_runs_df = runs_df.loc[idx]

In [None]:
# Restrict the best runs to image_size == 64 and latent_dim == 128.
df = best_runs_df[
    best_runs_df['image_size'].eq(64) & best_runs_df['latent_dim'].eq(128)
]
print(df.shape)
# Summary of the selected runs (last expression, rendered by Jupyter).
df[['driver', 'source_type', 'metric.roc_auc', 'early_stopping',
    'patience', 'min_epochs', 'best_metric', 'encoder_name']]

## Download predictions

In [None]:
df = df.assign(local_path=None)
local_root = Path.cwd() / 'outputs' / 'mlflow_artifacts'
artifact_dir = 'outputs/'

# Fetch each selected run's predictions.json into
# outputs/mlflow_artifacts/<run_id>/ and record the returned local path.
for index, run_id in tqdm(zip(df.index, df['run_id']), total=len(df)):
    local_dir = local_root / str(run_id)
    local_dir.mkdir(parents=True, exist_ok=True)
    local_path = client.download_artifacts(
        run_id, f'{artifact_dir}predictions.json', str(local_dir)
    )
    df.at[index, 'local_path'] = local_path

In [None]:
# Raw `source_type` values from the runs -> display labels used in the plots.
source_type_map = dict(
    depth='MDE',
    source_depth='Depth',
    images='RGB',
    rgb='RGB',
    masks='Mask',
    rgbd='RGBD',
    rgbdm='RGBDM',
)

# Display label -> Matplotlib colour.
source_type_color_map = dict(
    Depth='tab:blue',
    RGB='tab:orange',
    Mask='tab:green',
    MDE='tab:red',
    RGBD='tab:purple',
    RGBDM='tab:pink',
)

# Display label -> Matplotlib line style (NOTE: has no 'MDE' entry).
source_type_linestyle_map = dict(
    Mask='-',
    Depth='--',
    RGB='-.',
    RGBD=':',
    RGBDM='-',
)

In [None]:
# Load each selected run's predictions into data[driver][display source type].
data = defaultdict(dict)
for index, row in df.iterrows():
    source_type = source_type_map[row['source_type']]
    # RGBDM runs are excluded from the plots; skip them before touching their file.
    if source_type == 'RGBDM':
        continue
    driver_results = data[row['driver']]
    if source_type in driver_results:
        # Several raw source types share one label (e.g. 'images' and 'rgb' -> 'RGB'),
        # so two selected runs can collide here. The later row wins; warn about it.
        warnings.warn(
            f"Driver {row['driver']!r}: {source_type!r} predictions already loaded; "
            f"overwriting with run {row['run_id']}."
        )
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(row['local_path'], encoding='utf-8') as f:
        results = json.load(f)
    driver_results[source_type] = results

In [None]:
drivers = list(data.keys())
# Union of source types over all drivers, in first-seen order. Inspecting a
# single driver would break with only one driver loaded and hide source types
# that this driver happens to lack.
source_types = list(
    dict.fromkeys(
        source_type
        for per_driver in data.values()
        for source_type in per_driver
    )
)
pprint(source_types)
pprint(drivers)

## Visualizations

In [None]:
# Driver name -> anonymised label ('all' is kept as-is).
driver_name_mapping = dict(
    dans=1,
    geordi=2,
    jakub=3,
    michal=4,
    poli=5,
    all='all',
)

In [None]:
plt.rc('font', size=16)

# Styling arguments shared by the plot_results calls below.
plot_kwargs = {
    'source_type_color_map': source_type_color_map,
    # Optional extras, currently disabled:
    # 'source_type_linestyle_map': source_type_linestyle_map,
    # 'driver_name_mapping': driver_name_mapping,
    'fig_height_multiplier': 5,
    'fig_width_multiplier': 3.5,
    'n_rows': 1,
    'linewidth': 2,
    'legend_outside': True,
}

In [None]:
# ROC curves from `data`, written to outputs/roc_auc.pdf via save_path.
plot_results(
    'roc',
    data,
    save_path='outputs/roc_auc.pdf',
    **plot_kwargs,
)

In [None]:
# PR curves from `data`, written to outputs/pr_auc.pdf via save_path.
plot_results(
    'pr',
    data,
    save_path='outputs/pr_auc.pdf',
    **plot_kwargs,
)