# Evaluation of TAE Experiments

This notebook connects to MLflow, downloads all experiment runs and creates visualizations.

In [None]:
import json
import warnings
from collections import defaultdict
from pathlib import Path
from pprint import pprint

import dagshub
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mlflow.client import MlflowClient
from mlflow.entities import ViewType
from sklearn.metrics import auc, precision_recall_curve, roc_curve
from tqdm import tqdm

In [None]:
# DagsHub repository name and owner; `dagshub.init` is called with MLflow
# enabled before the MLflow client is created.
REPO_NAME = 'driver-tae'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore
client = MlflowClient()

# Do not truncate columns when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

# Global matplotlib appearance: white-grid style and a larger base font.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 20

## Get all experiment runs

In [None]:
# Fetch every experiment except the one named 'Default' and list the names.
experiments = client.search_experiments(filter_string="name!='Default'")
found_names = [exp.name for exp in experiments]
pprint(found_names)

In [None]:
# Experiment names are already available locally; build the lookup once so
# that no extra request per run is needed when assembling the table.
experiment_name_by_id = {e.experiment_id: e.name for e in experiments}

# Get all runs from the experiments.
# `search_runs` returns one page of results at a time, so keep requesting
# pages until the continuation token is exhausted; otherwise runs beyond the
# first page would be silently dropped.
all_runs = []
for experiment in experiments:
    page_token = None
    while True:
        runs = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string='',
            run_view_type=ViewType.ACTIVE_ONLY,
            page_token=page_token,
        )
        all_runs.extend(runs)
        page_token = runs.token
        if not page_token:
            break

# Create a DataFrame from the runs: one row per run with the run info,
# every logged parameter as its own column and every metric under a
# "metric." prefix.
runs_df = pd.DataFrame(
    [
        {
            'run_id': r.info.run_id,
            'experiment_id': r.info.experiment_id,
            'experiment_name': experiment_name_by_id[r.info.experiment_id],
            'status': r.info.status,
            'start_time': pd.to_datetime(r.info.start_time, unit='ms'),
            # Runs without an end time keep None here.
            'end_time': pd.to_datetime(r.info.end_time, unit='ms')
            if r.info.end_time
            else None,
            'artifact_uri': r.info.artifact_uri,
            **r.data.params,  # Add all parameters
            **{
                f'metric.{k}': v for k, v in r.data.metrics.items()
            },  # Add all metrics with "metric." prefix
        }
        for r in all_runs
    ]
)

# Display the DataFrame
print(f'Total runs: {len(runs_df)}')
runs_df.head()

In [None]:
# Preview the first five rows of the runs table.
runs_df.head(n=5)

## Filtering

In [None]:
# Columns that are cast to plain integers before grouping and filtering
# (presumably stored as strings by MLflow — verify against the raw runs).
integer_columns = ['image_size', 'latent_dim', 'batch_size', 'early_stopping']
runs_df = runs_df.astype(dict.fromkeys(integer_columns, int))

In [None]:
# For each (driver, source_type, latent_dim, image_size) combination keep
# only the run with the highest ROC AUC.
group_keys = ['driver', 'source_type', 'latent_dim', 'image_size']
idx = runs_df.groupby(group_keys)['metric.roc_auc'].idxmax()
best_runs_df = runs_df.loc[idx]

In [None]:
# Keep the best runs with image_size 64 and latent_dim 128 whose dataset is
# not 'dmd'.
is_selected = (
    best_runs_df['image_size'].eq(64)
    & best_runs_df['latent_dim'].eq(128)
    & best_runs_df['dataset'].ne('dmd')
)
df = best_runs_df[is_selected]

# Show only the columns of interest for a quick overview of the selection.
summary_columns = [
    'driver',
    'source_type',
    'metric.roc_auc',
    'early_stopping',
    'patience',
    'min_epochs',
    'best_metric',
]
df[summary_columns]

## Download predictions

In [None]:
# Add an (initially empty) column that will hold the local path of each
# run's downloaded predictions file.
df = df.assign(local_path=None)
local_root = Path.cwd() / 'outputs' / 'mlflow_artifacts'
artifact_dir = 'outputs/'
remote_file = artifact_dir + 'predictions.json'

# Download the predictions of every selected run into its own sub-directory.
for index, run_id in tqdm(df['run_id'].items(), total=len(df)):
    local_dir = local_root / str(run_id)
    local_dir.mkdir(parents=True, exist_ok=True)
    local_path = client.download_artifacts(run_id, remote_file, str(local_dir))
    # Remember where the file was stored for the loading step.
    df.at[index, 'local_path'] = local_path

In [None]:
# Raw `source_type` parameter value -> label used in legends.
source_type_map = dict(
    depth='Depth',
    images='RGB',
    masks='Mask',
    rgbd='RGBD',
    rgbdm='RGBDM',
)
# Legend label -> matplotlib colour used for that source type's curve.
source_type_color_map = dict(
    Depth='tab:blue',
    RGB='tab:orange',
    Mask='tab:green',
    RGBD='tab:red',
    RGBDM='tab:purple',
)

In [None]:
# Read every downloaded predictions file into a nested mapping:
# driver -> source-type label -> parsed JSON content.
data = defaultdict(dict)
for driver, source_type, json_path in zip(
    df['driver'], df['source_type'], df['local_path']
):
    with open(json_path) as fh:
        predictions = json.load(fh)
    data[driver][source_type_map[source_type]] = predictions

In [None]:
# Drivers that have predictions, and the source-type labels available for
# the first of them.
drivers = list(data)
source_types = list(data[drivers[0]])
pprint(source_types)
pprint(drivers)

## Visualizations

In [None]:
def plot_roc_auc_chart(
    data: dict[str, dict[str, dict[str, list[float | int] | float]]],
    cmap: str = 'rainbow',
    figsize: tuple[int, int] | None = None,
    plot_thresholds: bool = False,
    cbar_text: str = 'Threshold',
    save_path: str | Path | None = None,
    justification: int = 5,
) -> None:
    """Plot ROC curves, one panel per driver and one curve per source type.

    Args:
        data: Mapping ``driver name -> source-type label -> results``. Each
            ``results`` dict must provide ``'y_true'`` (labels) and
            ``'y_proba'`` (scores). Every source-type label must be a key of
            the module-level ``source_type_color_map``.
        cmap: Colormap used for the threshold markers.
        figsize: Figure size; defaults to a 7x7 inch square per panel.
        plot_thresholds: If true, mark each ROC point coloured by its
            threshold and add a colorbar.
        cbar_text: Label of the colorbar.
        save_path: If given, the figure is also written to this path. Missing
            parent directories are created.
        justification: Width to which source-type labels are left-justified
            in the legend so that the AUC values line up.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError('`data` must contain at least one driver.')

    n_plots = len(data)

    # Calculate default figsize if not provided
    if figsize is None:
        figsize = (7 * n_plots, 7)

    # Create figure with gridspec; the extra last column holds the colorbar
    # and collapses to zero width when no colorbar is drawn.
    fig = plt.figure(figsize=figsize)
    gs = plt.GridSpec(  # type: ignore
        1,
        n_plots + 1,
        width_ratios=[1] * n_plots + ([0.05] if plot_thresholds else [0.0]),
    )
    axes = [fig.add_subplot(gs[0, i]) for i in range(n_plots)]

    # Use fixed colormap range from 0 to 1
    norm = plt.Normalize(vmin=0, vmax=1)  # type: ignore

    # Last threshold scatter; stays None when nothing was drawn.
    scatter = None

    for idx, (ax, (driver_name, driver_results)) in enumerate(zip(axes, data.items())):
        for source_type, results in driver_results.items():
            y_true: list[int] = results['y_true']  # type: ignore
            y_pred_proba: list[float] = results['y_proba']  # type: ignore

            # Calculate ROC curve; labels are truncated to the number of
            # available scores.
            fpr, tpr, thresholds = roc_curve(y_true[: len(y_pred_proba)], y_pred_proba)
            roc_auc = auc(fpr, tpr)

            # Plot ROC curve
            ax.plot(
                fpr,
                tpr,
                c=source_type_color_map[source_type],
                label=f'{source_type.ljust(justification)} AUC={roc_auc:.3f}',
                linewidth=2,
            )

            if plot_thresholds:
                scatter = ax.scatter(fpr, tpr, c=thresholds, cmap=cmap, norm=norm)

        # Random predictions curve: drawn once per panel so that the
        # semi-transparent line is not stacked once per source type.
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.5)

        # Set title and limits
        ax.set_title(driver_name.capitalize())
        ax.set_xlim([0, 1])  # type: ignore
        ax.set_ylim([0, 1])  # type: ignore
        ax.axis('square')

        # Handle axis labels and ticks
        if idx == 0:
            ax.set_ylabel('True positive rate')
        else:
            # Remove y-axis labels for all but the first plot
            ax.set_yticklabels([])

        # Add x-label to all plots
        ax.set_xlabel('False positive rate')
        ax.legend(loc='lower right', prop={'family': 'monospace'})

    # Add colorbar in the last column of gridspec
    if plot_thresholds and scatter is not None:
        cbar_ax = fig.add_subplot(gs[0, -1])
        cbar = fig.colorbar(scatter, cax=cbar_ax)
        cbar.set_label(cbar_text)

    plt.tight_layout()

    if save_path:
        # Make sure the target directory exists before writing the file.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()

In [None]:
# Draw the ROC comparison for all drivers and save it as a PDF.
plot_roc_auc_chart(
    data,
    save_path='outputs/roc_auc.pdf',
)

In [None]:
def plot_pr_auc_chart(
    data: dict[str, dict[str, dict[str, list[float | int] | float]]],
    cmap: str = 'rainbow',
    figsize: tuple[int, int] | None = None,
    plot_thresholds: bool = False,
    cbar_text: str = 'Threshold',
    save_path: str | Path | None = None,
    justification: int = 5,
) -> None:
    """Plot precision-recall curves, one panel per driver.

    Args:
        data: Mapping ``driver name -> source-type label -> results``. Each
            ``results`` dict must provide ``'y_true'`` (labels) and
            ``'y_proba'`` (scores). Every source-type label must be a key of
            the module-level ``source_type_color_map``.
        cmap: Colormap used for the threshold markers.
        figsize: Figure size; defaults to a 7x7 inch square per panel.
        plot_thresholds: If true, mark each thresholded PR point coloured by
            its threshold and add a colorbar.
        cbar_text: Label of the colorbar.
        save_path: If given, the figure is also written to this path. Missing
            parent directories are created.
        justification: Width to which source-type labels are left-justified
            in the legend so that the AUC values line up.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError('`data` must contain at least one driver.')

    n_plots = len(data)

    # Calculate default figsize if not provided
    if figsize is None:
        figsize = (7 * n_plots, 7)

    # Create figure with gridspec; the extra last column holds the colorbar
    # and collapses to zero width when no colorbar is drawn.
    fig = plt.figure(figsize=figsize)
    gs = plt.GridSpec(  # type: ignore
        1,
        n_plots + 1,
        width_ratios=[1] * n_plots + ([0.05] if plot_thresholds else [0.0]),
    )
    axes = [fig.add_subplot(gs[0, i]) for i in range(n_plots)]

    # Use fixed colormap range from 0 to 1
    norm = plt.Normalize(vmin=0, vmax=1)  # type: ignore

    # Last threshold scatter; stays None when nothing was drawn.
    scatter = None

    for idx, (ax, (driver_name, driver_results)) in enumerate(zip(axes, data.items())):
        for source_type, results in driver_results.items():
            y_true: list[int] = results['y_true']  # type: ignore
            y_pred_proba: list[float] = results['y_proba']  # type: ignore

            # Calculate the precision-recall curve; labels are truncated to
            # the number of available scores.
            precision, recall, thresholds = precision_recall_curve(
                y_true[: len(y_pred_proba)], y_pred_proba
            )
            # Guard against values marginally above 1.
            precision = np.minimum(precision, 1.0)
            recall = np.minimum(recall, 1.0)
            pr_auc = auc(recall, precision)

            # `thresholds[i]` belongs to `precision[i]` / `recall[i]`; the
            # final (precision=1, recall=0) point has no threshold, so it is
            # the LAST element that must be dropped to align the arrays.
            thr_precision = precision[:-1]
            thr_recall = recall[:-1]

            # F1 per threshold; points with precision + recall == 0 get an
            # F1 of 0 instead of NaN so they can never win the argmax.
            denominator = thr_precision + thr_recall
            f1_scores = np.divide(
                2 * thr_precision * thr_recall,
                denominator,
                out=np.zeros_like(denominator, dtype=float),
                where=denominator > 0,
            )
            # Currently not used by the plot itself.
            optimal_threshold = thresholds[f1_scores.argmax()]  # noqa

            # Plot the precision-recall curve.
            ax.step(
                recall,
                precision,
                c=source_type_color_map[source_type],
                label=f'{source_type.ljust(justification)} AUC={pr_auc:.3f}',
                linewidth=2,
            )

            if plot_thresholds:
                # Draw on this panel's axes (not the pyplot "current" axes).
                scatter = ax.scatter(
                    thr_recall, thr_precision, c=thresholds, cmap=cmap, norm=norm
                )

        # Set title and limits
        ax.set_title(driver_name.capitalize())

        ax.set_xlim([0, 1.1])  # type: ignore
        ax.set_ylim([0, 1.1])  # type: ignore

        # NOTE(review): 'tight' is applied after the explicit limits above
        # and may override them — confirm which of the two is intended.
        ax.axis('tight')

        ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
        ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

        ax.grid(True)

        # Handle axis labels and ticks
        if idx == 0:
            ax.set_ylabel('Precision')
        else:
            # Remove y-axis labels for all but the first plot
            ax.set_yticklabels([])

        # Add x-label to all plots; the legend is placed below the axes.
        ax.set_xlabel('Recall')
        ax.legend(
            prop={'family': 'monospace'},
            loc='upper center',
            bbox_to_anchor=(0.5, -0.15),
        )

    # Add colorbar in the last column of gridspec
    if plot_thresholds and scatter is not None:
        cbar_ax = fig.add_subplot(gs[0, -1])
        cbar = fig.colorbar(scatter, cax=cbar_ax)
        cbar.set_label(cbar_text)

    plt.tight_layout()
    # Leave room at the bottom for the legends placed below the axes.
    plt.subplots_adjust(bottom=0.2)

    if save_path:
        # Make sure the target directory exists before writing the file.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()

In [None]:
# Draw the precision-recall comparison for all drivers and save it as a PDF.
plot_pr_auc_chart(
    data,
    save_path='outputs/pr_auc.pdf',
)