In [None]:
# Copyright (C) 2024 Mila - Institut québécois d'intelligence artificielle
# SPDX-License-Identifier: Apache-2.0

In [None]:
# This notebook visualizes the ID-level predictions of all runs in a k-fold experiment.

In [None]:
import os
import glob

import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

In [None]:
# Experiment configuration -- adapt these values to your setup.
# Identifier of the cable whose k-fold experiment is inspected.
cable = "C01"
# Root of the results tree: the current user's home directory.
# `expanduser` is used instead of `os.environ["HOME"]` so that the lookup does
# not raise a KeyError when the HOME variable is not defined (e.g. on Windows).
root_directory = os.path.expanduser("~")
# Directory of the k-fold experiment; its sub-directories are globbed as runs below.
experiment_directory = os.path.join(root_directory, f"results/patchcore/hq/hq_kfold_unsupervised_{cable}")

In [None]:
# Get all runs directory: every sub-directory of the experiment directory.
# `glob.glob` returns paths in an arbitrary, file-system dependent order, so the
# list is sorted to process the runs in the same order on every execution.
runs_directories = sorted(glob.glob(f"{experiment_directory}/*/"))
print(f"# runs: {len(runs_directories)}")

In [None]:
# Gather the ID level predictions of every run into a single DataFrame.
# The per-run frames are collected in a list and concatenated once at the end,
# which avoids re-copying the accumulated data on every iteration.
run_predictions = []
for run_directory in runs_directories:
    # Load anomaly IDs level predictions
    predictions_fname = os.path.join(run_directory, "test_identification_predictions.csv")
    if not os.path.isfile(predictions_fname):
        # A run without a predictions file is reported and skipped.
        print(f"Broken run: {run_directory}")
        continue
    predictions = pd.read_csv(predictions_fname)
    run_predictions.append(predictions)

# Fail early with an explicit message instead of leaving `predictions_all`
# undefined/None and crashing later in the notebook.
if not run_predictions:
    raise FileNotFoundError(
        "No run directory contains a 'test_identification_predictions.csv' file: nothing to visualize."
    )
predictions_all = pd.concat(run_predictions, ignore_index=True)

In [None]:
# Map the integer labels of the "predictions" column to readable names.
# The result is rebound instead of using `inplace=True`, so the frame produced
# by the loading cell is not mutated behind the reader's back.
predictions_all = predictions_all.replace({"predictions": {0: "Nominal", 1: "Anomalous"}})
# Bar color associated with each prediction label.
color_map = {"Nominal": "tab:blue", "Anomalous": "tab:orange"}

In [None]:
# Cable unique anomaly predictions:
# stacked bar chart of the prediction labels, one bar per identification.
fontsize = 22
# Keep only the rows whose identification is not "good"
# (this filters on the "identification" column, not on the predicted label).
anomalies_predictions_all = predictions_all[predictions_all["identification"] != "good"]
# Number of rows for each (identification, prediction) pair.
df_groups = anomalies_predictions_all.groupby(["identification", "predictions"]).size()
# One row per identification, one column per prediction label; pairs that never
# occur are counted as 0 rather than left as NaN.
anomaly_groups = df_groups.unstack(fill_value=0)
ax = anomaly_groups.plot(
    figsize=(20, 5),
    kind="bar",
    stacked=True,
    color=color_map,
    fontsize=19,  # tick label size; the y tick labels are enlarged below
)
# Ticks and label
ax.set_xlabel("Identification", fontsize=fontsize)
ax.set_ylabel("# of folds", fontsize=fontsize)
ax.tick_params(axis="y", labelsize=fontsize)
# Grid: solid lines on major y ticks, dotted lines on minor y ticks
ax.minorticks_on()
ax.grid(axis="y", which="major", linestyle="-", linewidth=0.5, color="black")
ax.grid(axis="y", which="minor", linestyle=":", linewidth=0.5, color="black")
# Legend: entries are reversed so that they follow the top-to-bottom stacking order
handles_, labels_ = ax.get_legend_handles_labels()
ax.legend(
    handles_[::-1], labels_[::-1], title="Predictions:", loc="lower right", fontsize=fontsize, title_fontsize=fontsize
)
# plt.title(f"Cable {cable[-1]} unique anomaly predictions")
plt.show()