In [None]:
# Copyright (C) 2022 Mila - Institut québécois d'intelligence artificielle
# SPDX-License-Identifier: Apache-2.0

In [None]:
# This notebook visualizes the aggregated predictions
# and statistics of the unsupervised k-fold experiments.

In [None]:
from __future__ import annotations

import os
import sys

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

sys.path.append("../")
from utils import plot_distribution_per_group

%matplotlib inline

In [None]:
# To adapt
# Directory that contains the "patchcore/hq/<experiment>_<cable>" result folders.
# `expanduser("~")` resolves the home directory even when HOME is not set
# (e.g. on Windows), where `os.environ["HOME"]` would raise a KeyError.
root = os.path.join(os.path.expanduser("~"), "results")
# Experiment name without the cable id
experiment_name = "hq_kfold_unsupervised"
threshold_strategy = "max"  # "max", "beta_prime", "empirical", "whisker"

# Fail early on a typo instead of with a KeyError on a missing column later on.
_THRESHOLD_STRATEGIES = ("max", "beta_prime", "empirical", "whisker")
if threshold_strategy not in _THRESHOLD_STRATEGIES:
    raise ValueError(
        f"threshold_strategy should be one of {_THRESHOLD_STRATEGIES}, "
        f"got {threshold_strategy!r}."
    )

In [None]:
# Load the aggregated k-fold results of every cable and stack them into a
# single DataFrame with a fresh 0..n-1 index.
cables = ["C01", "C02", "C03"]
per_cable_results = []
for cable in cables:
    results_fname = f"patchcore/hq/{experiment_name}_{cable}/aggregated_results.csv"
    per_cable_results.append(pd.read_csv(os.path.join(root, results_fname)))
# Concatenate once: concatenating inside the loop re-copies the accumulated
# frame at every iteration.
results = pd.concat(per_cable_results, ignore_index=True)

In [None]:
# Remove columns with unique values (they carry no information for comparing
# runs, e.g. settings shared by every run).
# Identifier columns and metric/statistic columns ("test_*", "validation_*")
# are kept even when constant. The cells below index them directly and would
# raise a KeyError if they were dropped, e.g. when a single cable is analysed
# or when a metric takes the same value on every fold.
protected_columns = {
    "dataset.split_mode.cable",
    "dataset.split_mode.anomaly_group_id",
    "project.path",
}
constant_columns = [
    col
    for col in results.columns
    if len(results[col].unique()) == 1
    and col not in protected_columns
    and not str(col).startswith(("test_", "validation_"))
]
for col in constant_columns:
    print(f"Delete {col}: {results[col].unique()}")
results = results.drop(columns=constant_columns)

In [None]:
# Post process results
# Give the split-mode columns short, readable names.
new_column_names = {
    "dataset.split_mode.cable": "cable_id",
    "dataset.split_mode.anomaly_group_id": "anomaly_group_id",
}
results = results.rename(columns=new_column_names)
# The run name is the "/"-separated component at index 9 of the experiment path.
# NOTE(review): this index depends on the directory depth of the machine that
# produced the runs. Confirm it still points at the run folder if paths change.
results["run_name"] = results["project.path"].str.split("/", expand=True)[9]
# Human-readable cable names for the printed summaries and plots.
# Assign the result back instead of calling `replace(..., inplace=True)` on
# `results["cable_id"]`. That form is chained assignment: it warns under
# pandas Copy-on-Write and no longer updates `results` once CoW is the default.
mapping_cable = {"C01": "Cable 1", "C02": "Cable 2", "C03": "Cable 3"}
results["cable_id"] = results["cable_id"].replace(mapping_cable)

In [None]:
# Size of the combined results table.
nrow = len(results.index)
ncol = len(results.columns)
print(f"# rows: {nrow}, # columns: {ncol}")

In [None]:
results.head(n=5)  # preview of the first processed rows

In [None]:
# Show best and worst runs per cable based on the metric
metric = "AUPR"  # To adapt
# The metric column only depends on the configuration, so resolve it once
# instead of on every loop iteration.
# NOTE(review): for the "max" strategy this uses `test_image_{metric}` (no
# strategy prefix). The test-metrics and Precision/Recall cells always use
# `test_image_{threshold_strategy}_...`. Confirm which naming the aggregated
# CSVs actually contain.
if threshold_strategy == "max":
    selected_metric = f"test_image_{metric}"
else:
    selected_metric = f"test_image_{threshold_strategy}_{metric}"

cables = results["cable_id"].unique()
for cable in cables:
    cable_results = results[results["cable_id"] == cable]
    metric_values = cable_results[selected_metric]
    # Groups stats. `cable_id` already holds the display name (e.g. "Cable 1"),
    # so it is printed without an extra "Cable " prefix.
    nb_anomaly_group = cable_results["anomaly_group_id"].nunique()
    print(f"{cable} # of unique anomaly group ID: {nb_anomaly_group}")
    # Best run: idxmax gives the label of the first row holding the maximum.
    best_idx = metric_values.idxmax()
    max_metric = metric_values[best_idx]
    run_name = cable_results.at[best_idx, "run_name"]
    print(f"Best run {selected_metric}: {round(max_metric, 2)} - {run_name}")
    # Worst run: idxmin gives the label of the first row holding the minimum.
    worst_idx = metric_values.idxmin()
    min_metric = metric_values[worst_idx]
    run_name = cable_results.at[worst_idx, "run_name"]
    print(f"Worst run {selected_metric}: {round(min_metric, 2)} - {run_name}\n")

In [None]:
# Image threshold validation set (multiple folds)
# Per-cable distribution of the image threshold obtained on each fold's
# validation set with the configured thresholding strategy.
threshold_col = "validation_" + threshold_strategy
plot_distribution_per_group(
    results,
    "cable_id",
    [threshold_col],
    "",  # alternative x label: "Cable ID (# of folds)"
    title="",  # alternative title: "Image threshold validation set (multiple folds)"
    ylabel="Image threshold",
)

In [None]:
# Image scores min/max anomaly score (multiple folds)
# For each statistic, plot per cable the validation and test values of that
# image anomaly-score statistic across folds.
stats = ["min", "max"]
for stat in stats:
    score_columns = [f"validation_image_{stat}", f"test_image_{stat}"]
    score_label = f"Image {stat} anomaly score"
    # Alternative labels: x label "Cable ID (# of folds)",
    # title f"Image {stat} anomaly score (multiple folds)".
    plot_distribution_per_group(
        results, "cable_id", score_columns, "", score_label, title=""
    )

In [None]:
# Metrics test set (multiple folds)
# Options: "F1Score", "Precision", "Recall", "FPR", "AUPR", "AUROC"
metric = "AUPR"
# NOTE(review): this always includes the strategy in the column name, whereas
# the best/worst-run cell omits it for the "max" strategy. Confirm which
# naming the aggregated CSVs use.
metric_col = f"test_image_{threshold_strategy}_{metric}"
ylim = dict(ymax=1.0, ymin=0.0)  # To adapt
plot_distribution_per_group(
    results,
    "cable_id",
    [metric_col],
    "",  # alternative x label: "Cable ID (# of folds)"
    metric,
    ylim=ylim,
    title="",  # alternative title: f"{metric} test set (multiple folds)"
)

In [None]:
# Precision vs. Recall
prec_col = f"test_image_{threshold_strategy}_Precision"
rec_col = f"test_image_{threshold_strategy}_Recall"

# Display names for the axes; the empty string hides the legend title.
labels = {prec_col: "Precision", rec_col: "Recall", "cable_id": ""}

# Both axes share the same range.
axis_range = [0.35, 1.03]
fig = px.scatter(
    data_frame=results,
    x=prec_col,
    y=rec_col,
    color="cable_id",
    labels=labels,
    range_x=list(axis_range),
    range_y=list(axis_range),
    marginal_x="box",
    marginal_y="box",
    width=700,
    height=700,
)
(
    fig.update_layout(
        font=dict(size=22),
        legend=dict(xanchor="right", yanchor="top", x=0.98, y=0.96),
    )
    .update_traces(marker=dict(size=10))
    .update_xaxes(tickmode="array", tickvals=np.linspace(0.4, 1.0, 7))
)
fig.show()

In [None]:
# Compare threshold or FPR for different thresholding strategies
# To adapt:
variable = "Image threshold"  # Options: "Image threshold", "FPR"

# Legend label of each thresholding strategy; iteration order is plotting order.
strategy_labels = {
    "max": "max",
    "beta_prime": "beta-prime95",
    "empirical": "empirical95",
    "whisker": "boxplot outliers",
}
# Column template (filled with the strategy name) holding `variable`.
column_templates = {
    "Image threshold": "validation_{}",
    "FPR": "test_image_{}_FPR",
}
# Validate the option once, before plotting; ValueError signals a bad value.
if variable not in column_templates:
    raise ValueError("variable should be 'Image threshold' or 'FPR'.")

colors = px.colors.qualitative.Plotly
# Defining x axis: the cable of each row, identical for every strategy.
x = results["cable_id"].tolist()

fig = go.Figure()
for idx, (thr, thr_label) in enumerate(strategy_labels.items()):
    fig.add_trace(
        go.Box(
            y=results[column_templates[variable].format(thr)].tolist(),
            x=x,
            boxpoints="all",
            name=thr_label,
            marker_color=colors[idx % len(colors)],
        )
    )

fig.update_layout(
    yaxis_title=variable,
    boxmode="group",
    boxgap=0.0,
    boxgroupgap=0.5,
    margin={"l": 0, "r": 0, "b": 0, "pad": 0},
    font={"size": 30},
    width=1300,
    height=600,
    legend={"orientation": "h", "yanchor": "bottom", "xanchor": "right", "x": 1, "y": 1.02},
)
fig.update_yaxes(showgrid=True, gridwidth=2, gridcolor="white", minor_griddash="dot")
# One categorical slot per cable. The upper bound equals the previous
# hand-tuned 2.47 for three cables and adapts when the number of cables changes.
n_cables = results["cable_id"].nunique()
fig.update_xaxes(range=[-0.5, n_cables - 0.53])
fig.show()