# Evaluation of STAE Experiments

This notebook connects to MLflow, downloads all experiment runs and creates visualizations.

In [None]:
# IPython autoreload in mode 2: modules are re-imported before each cell runs,
# so edits to the local `model` package take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import json
from collections import defaultdict
from pathlib import Path
from pprint import pprint

import dagshub
import matplotlib.pyplot as plt
import pandas as pd
from mlflow.client import MlflowClient
from mlflow.entities import ViewType
from tqdm import tqdm

from model.fonts import set_cmu_typewriter_font
from model.plot import plot_results

In [None]:
# DagsHub repository whose MLflow tracking data this notebook reads.
REPO_NAME = 'driver-stae'
USER_NAME = 'matejfric'

# Initialise DagsHub with MLflow enabled *before* creating the client, so the
# client is built after the DagsHub/MLflow setup has run.
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore
client = MlflowClient()

# Show every column when a DataFrame is displayed.
pd.options.display.max_columns = None

# Global plot appearance: grid style, project font, large base font size.
plt.style.use('seaborn-v0_8-whitegrid')
font = set_cmu_typewriter_font()
plt.rcParams['font.size'] = 34

## Get all experiment runs

In [None]:
# Fetch the experiments known to the tracking server and print their names.
experiments = client.search_experiments()
pprint([exp.name for exp in experiments])

In [None]:
# Get all runs from the experiments.
# `search_runs` returns one page of results at a time, so keep following the
# page token until the experiment has no further pages; otherwise runs beyond
# the first page would be silently missing from the evaluation.
all_runs = []
for experiment in experiments:
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string='',
            run_view_type=ViewType.ACTIVE_ONLY,
            page_token=page_token,
        )
        all_runs.extend(page)
        page_token = page.token
        if not page_token:
            break

# Experiment names are already available from `experiments`; look them up
# locally instead of issuing one `get_experiment` request per run.
experiment_names = {e.experiment_id: e.name for e in experiments}

# Create a DataFrame from the runs: one row per run, with the run info fields,
# every logged parameter as its own column, and every metric as a
# "metric."-prefixed column.
runs_df = pd.DataFrame(
    [
        {
            'run_id': r.info.run_id,
            'experiment_id': r.info.experiment_id,
            'experiment_name': experiment_names[r.info.experiment_id],
            'status': r.info.status,
            'start_time': pd.to_datetime(r.info.start_time, unit='ms'),
            # Runs that have not finished have no end time.
            'end_time': pd.to_datetime(r.info.end_time, unit='ms')
            if r.info.end_time
            else None,
            'artifact_uri': r.info.artifact_uri,
            **r.data.params,  # Add all parameters
            **{
                f'metric.{k}': v for k, v in r.data.metrics.items()
            },  # Add all metrics with "metric." prefix
        }
        for r in all_runs
    ]
)

# Display the DataFrame
print(f'Total runs: {len(runs_df)}')
runs_df.head()

In [None]:
# Preview the first rows of the runs table.
runs_df.head()

## Filtering

In [None]:
# Parameters that should be treated as integers.
integer_columns = ['image_size', 'batch_size', 'early_stopping']
# Parameter columns come out of the runs table as text, and a run that did not
# log a parameter has a missing value there. A plain `astype(int)` raises on
# missing values, so parse strictly with `pd.to_numeric` (still raises on
# non-numeric text) and store the result as nullable `Int64`, which keeps
# missing entries as <NA>.
runs_df[integer_columns] = (
    runs_df[integer_columns].apply(pd.to_numeric).astype('Int64')
)

In [None]:
# Keep only the runs with image_size 64. `.copy()` makes the filtered result
# an independent frame, so the column assignment below does not write into a
# slice of the original frame (which triggers pandas' SettingWithCopyWarning).
runs_df = runs_df[runs_df['image_size'] == 64].copy()
# Runs that did not log `source_type` are labelled 'depth'.
runs_df['source_type'] = runs_df['source_type'].fillna('depth')

In [None]:
# For every (driver, source_type, image_size) combination keep the run with
# the highest ROC AUC.
# Runs without a logged `metric.roc_auc` are removed first: a group made up
# only of missing values makes `idxmax` yield NaN (or raise, depending on the
# pandas version), and that label would then break the `.loc` lookup below.
scored_runs_df = runs_df.dropna(subset=['metric.roc_auc'])
idx = scored_runs_df.groupby(['driver', 'source_type', 'image_size'])[
    'metric.roc_auc'
].idxmax()
# `idx` holds index labels of `runs_df`, so the lookup uses the full frame.
best_runs_df = runs_df.loc[idx]

In [None]:
# Columns shown in the overview of the selected runs.
summary_columns = [
    'driver',
    'source_type',
    'metric.roc_auc',
    'metric.pr_auc',
    'early_stopping',
    'patience',
    'best_metric',
]

# Best runs restricted to image_size 64; `df` is the working frame used by
# the following cells.
df = best_runs_df.loc[best_runs_df['image_size'] == 64]
df[summary_columns]

## Download predictions

In [None]:
# Each run's artifacts are stored in <cwd>/outputs/mlflow_artifacts/<run_id>/.
local_root = Path.cwd() / 'outputs' / 'mlflow_artifacts'
artifact_dir = 'outputs/'

# New column that will hold the path of each downloaded predictions file.
df = df.assign(local_path=None)

# Download predictions.json for every selected run and record where it landed.
for index, run_id in tqdm(df['run_id'].items(), total=len(df)):
    local_dir = local_root / str(run_id)
    local_dir.mkdir(parents=True, exist_ok=True)
    df.at[index, 'local_path'] = client.download_artifacts(
        run_id, artifact_dir + 'predictions.json', str(local_dir)
    )

In [None]:
# Logged `source_type` value -> label used as the key in the plotting maps.
source_type_map = dict(
    masks='Mask',
    depth='Depth',
    images='RGB',
    rgbd='RGBD',
    rgbdm='RGBDM',
)

# One row per label: (line colour, line style). Keeping both in a single table
# makes sure the colour and linestyle maps always cover the same labels.
_source_type_styles = (
    ('Mask', 'tab:orange', '-'),
    ('Depth', 'tab:blue', '--'),
    ('RGB', 'tab:green', '-.'),
    ('RGBD', 'tab:red', ':'),
    ('RGBDM', 'tab:purple', '-'),
)
source_type_color_map = {label: color for label, color, _ in _source_type_styles}
source_type_linestyle_map = {
    label: linestyle for label, _, linestyle in _source_type_styles
}

In [None]:
# Load the predictions from the local paths
data = defaultdict(dict)
for index, row in df.iterrows():
    with open(row['local_path']) as f:
        results = json.load(f)
    data[row['driver']][source_type_map[row['source_type']]] = results

In [None]:
# Drivers that have loaded predictions, in insertion order.
drivers = list(data)
# Source type labels of the first driver. `next(..., {})` yields an empty list
# when no predictions were loaded, instead of raising IndexError.
source_types = list(next(iter(data.values()), {}))
pprint(source_types)
pprint(drivers)

## Visualizations

In [None]:
# Driver name -> number (1-based, in the order listed), passed to plot_results.
driver_name_mapping = {
    name: number
    for number, name in enumerate(
        ('dans', 'geordi', 'jakub', 'michal', 'poli'), start=1
    )
}

In [None]:
# Keyword arguments shared by every plot_results call in this notebook.
plot_kwargs = {
    'source_type_color_map': source_type_color_map,
    'source_type_linestyle_map': source_type_linestyle_map,
    'driver_name_mapping': driver_name_mapping,
    'linewidth': 3,
    'legend_outside': True,
}

In [None]:
# 'roc' plot of the downloaded predictions with the shared styling; the PDF
# path is passed as save_path (presumably where plot_results writes the
# figure — see model.plot).
plot_results('roc', data, save_path='outputs/roc_auc.pdf', **plot_kwargs)

In [None]:
# 'pr' plot of the downloaded predictions with the shared styling; the PDF
# path is passed as save_path (presumably where plot_results writes the
# figure — see model.plot).
plot_results('pr', data, save_path='outputs/pr_auc.pdf', **plot_kwargs)

## Recalculate metrics

In [None]:
import copy

from model.eval import compute_best_roc_auc

# Re-evaluated copy of `data`: redata[driver][source type] holds a deep copy
# of the stored results with the values returned by compute_best_roc_auc
# merged on top. `data` itself is left untouched.
redata = defaultdict(dict)

# Walk the entries each driver actually has, so a driver that lacks one of
# the source types is skipped for that type instead of raising KeyError.
for driver, per_source in data.items():
    for source_type, results in per_source.items():
        x = copy.deepcopy(results)
        # NOTE(review): (0.00, 0.95) is passed through unchanged; its meaning
        # is defined by compute_best_roc_auc in model.eval — confirm there.
        x.update(compute_best_roc_auc(x['y_true'], x['errors'], (0.00, 0.95)))
        redata[driver][source_type] = x

In [None]:
# Same 'roc' plot as before, but built from the recalculated results in
# `redata` and passed a separate PDF path.
plot_results('roc', redata, save_path='outputs/roc_auc_iqr.pdf', **plot_kwargs)

In [None]:
# Same 'pr' plot as before, but built from the recalculated results in
# `redata` and passed a separate PDF path.
plot_results('pr', redata, save_path='outputs/pr_auc_iqr.pdf', **plot_kwargs)