In [None]:
# Copyright (C) 2024 Mila - Institut québécois d'intelligence artificielle
# SPDX-License-Identifier: Apache-2.0

In [None]:
# This notebook visualizes the custom AUPR and F1 metrics, in which precision
# is computed at the image level and recall at the ID level.

In [None]:
import glob
import os
import sys

import pandas as pd
import tqdm

sys.path.append("../")
from utils import compute_aupr_and_f1, plot_distribution_per_group

%matplotlib inline

In [None]:
# To adapt: root directory under which the ``results/`` tree is located.
# ``os.path.expanduser("~")`` resolves the home directory portably (HOME or the
# password database on POSIX, USERPROFILE on Windows). Indexing
# ``os.environ["HOME"]`` directly would raise KeyError when HOME is unset.
root_directory = os.path.expanduser("~")

In [None]:
# Collect, for every cable, the AUPR and F1 score of each k-fold run.
# A run directory is usable only if it contains all files listed in
# ``required_files``. Runs missing any of them are reported and skipped
# instead of crashing the whole cell.
cables = ["C01", "C02", "C03"]
metrics_lst = ["AUPR", "F1Score"]
metrics_dict = {cable: {k: [] for k in metrics_lst} for cable in cables}
required_files = (
    "test_image_predictions.csv",
    "test_identification_predictions.csv",
    "normalization_stats.csv",
)
for cable in cables:
    experiment_directory = os.path.join(root_directory, f"results/patchcore/hq/hq_kfold_unsupervised_{cable}")
    # glob returns matches in arbitrary order; sort so that the per-run metric
    # lists are reproducible across machines and re-runs. The prefix is escaped
    # so that glob metacharacters in the path are matched literally.
    runs_directories = sorted(glob.glob(os.path.join(glob.escape(experiment_directory), "*", "")))
    for run_directory in tqdm.tqdm(runs_directories, desc=cable):
        missing_files = [
            fname for fname in required_files if not os.path.isfile(os.path.join(run_directory, fname))
        ]
        if missing_files:
            print(f"Broken run: {run_directory} (missing: {', '.join(missing_files)})")
            continue
        # Image-level predictions (passed as the image-level inputs below)
        pred = pd.read_csv(os.path.join(run_directory, "test_image_predictions.csv"))
        # ID-level predictions (passed as the ID-level inputs below)
        pred_ids_level = pd.read_csv(os.path.join(run_directory, "test_identification_predictions.csv"))
        # Image threshold, read from the first row and rounded to 6 decimals
        normalization_stats = pd.read_csv(os.path.join(run_directory, "normalization_stats.csv"))
        threshold = round(normalization_stats["image_threshold"].values[0], 6)
        # Compute metrics with precision at the image level and recall at the ID level
        aupr, f1_score = compute_aupr_and_f1(
            pred["target"],
            pred["anomaly_score"],
            pred_ids_level["target"],
            pred_ids_level["anomaly_score"],
            threshold,
            recall_level="ID",
            precision_level="image",
        )
        metrics_dict[cable]["AUPR"].append(aupr)
        metrics_dict[cable]["F1Score"].append(f1_score)

In [None]:
# Stack the per-cable metric lists into a single long-format frame with one row
# per run and columns AUPR, F1Score and cable. The per-cable frames are built
# first and concatenated once, which avoids repeated concatenation in a loop.
cables = sorted(metrics_dict.keys())
df = pd.concat(
    [pd.DataFrame(metrics_dict[cable]).assign(cable=cable) for cable in cables],
    ignore_index=True,
)
# Fail loudly here rather than producing an empty plot further down.
assert not df.empty, "No valid runs were collected; check root_directory and the experiment paths."

In [None]:
# Replace the cable codes with human-readable labels for the plot.
mapping_cable = {"C01": "Cable 1", "C02": "Cable 2", "C03": "Cable 3"}
# Assign the result back to the column. Calling ``replace(..., inplace=True)``
# on ``df["cable"]`` is chained assignment through an inplace method. pandas
# warns about it, and under Copy-on-Write it never updates ``df``.
# Unmapped values are left unchanged, so re-running this cell is harmless.
df["cable"] = df["cable"].replace(mapping_cable)

In [None]:
# Distribution across folds of a test-set metric (AUPR or F1 score), with
# precision computed at the image level and recall at the ID level.
# Possible options: "F1Score", "AUPR"
metric = "AUPR"
if metric not in metrics_lst:
    raise ValueError(f"Unknown metric {metric!r}; expected one of {metrics_lst}.")
# y-axis range; the 0.85 lower bound may need lowering for other metrics/runs.
ylim = {"ymax": 1.0, "ymin": 0.85}
plot_distribution_per_group(
    df,
    "cable",
    [metric],
    "",  # presumably the x-axis label; e.g. "Cable ID (# of folds)"
    metric,
    title="",  # e.g. f"{metric} test set (multiple folds)\n (Precision image level and Recall ID level)"
    ylim=ylim,
)