In [None]:
# Copyright (C) 2024 Mila - Institut québécois d'intelligence artificielle
# SPDX-License-Identifier: Apache-2.0

In [None]:
# This notebook visualizes, for a given run, the anomaly score predictions
# and their statistics.

In [None]:
from __future__ import annotations

import os
import sys

import matplotlib.colors
import matplotlib.pyplot as plt
import networkx
import numpy as np
import pandas as pd

from itertools import combinations
from omegaconf import OmegaConf
from PIL import Image

sys.path.append("../")
from utils import compute_metrics, plot_confusion_matrix, plot_histogram, plot_precision_recall_curve, plot_roc_curve

%matplotlib inline

In [None]:
# To adapt: paths and identifiers of the run to visualize.
# expanduser("~") resolves the home directory even when the HOME environment
# variable is not set (where os.environ["HOME"] would raise a KeyError).
root_directory = os.path.expanduser("~")
experiment_name = "hq_kfold_unsupervised_C01"
run_name = "run.2024-05-27_11-25-44"
# Folder that contains labels.csv and the image/mask files
data_folder = os.path.join(root_directory, "CableInspect-AD")
# Folder that contains config.yaml, the prediction CSV files and the logs of the run
exp_folder = os.path.join(root_directory, "results", "patchcore", "hq", experiment_name, run_name)

cable_side_pass = "C01"  # Only used for HQ dataset.

In [None]:
# Load the config file
# config.yaml is read from the experiment folder of the selected run.
config_path = os.path.join(exp_folder, "config.yaml")
config = OmegaConf.load(config_path)
# The dataset format is compared to "hq" further down to filter the images.
dataset = config.dataset.format
print(f"Dataset: {dataset}")

In [None]:
# Show the split mode stored in the run configuration
print(f"Split config:\n  {config.dataset.split_mode}")

In [None]:
# Load the predictions and metrics
# All CSV files below live directly in the experiment folder, except the
# metrics log which sits under logs/lightning_logs/version_0.
val_img_pred_file_name, test_img_pred_file_name, normalization_file_name = (
    os.path.join(exp_folder, csv_name)
    for csv_name in (
        "validation_image_predictions.csv",
        "test_image_predictions.csv",
        "normalization_stats.csv",
    )
)
metrics_file_name = os.path.join(exp_folder, "logs", "lightning_logs", "version_0", "metrics.csv")

# Image-level predictions of the validation and test splits
df_val, df_test = pd.read_csv(val_img_pred_file_name), pd.read_csv(test_img_pred_file_name)

stats = pd.read_csv(normalization_file_name)
# Column-wise maximum over the last two rows of the metrics log
# (presumably one row per evaluated split -- TODO confirm against the logger output)
metrics = pd.read_csv(metrics_file_name).iloc[-2:].max()

In [None]:
# Metrics reported in the score tables of this notebook
metrics_lst = ["F1Score", "Precision", "Recall", "AUPR"]
# Image-level threshold taken from the first row of the normalization stats, rounded to 2 decimals
image_threshold = np.round(stats["image_threshold"].values[0], 2)

In [None]:
# Color assigned to each label in the histograms of this notebook.
# "nominal" and "anomalous" are the two image-level classes; the other keys
# are anomaly types combined with their grade.
colormap = {
    "nominal": "tab:blue",
    "anomalous": "tab:orange",
    "bent strand important": plt.cm.tab20(2),
    "bent strand light": plt.cm.tab20(3),
    "broken strands complete": plt.cm.tab20c(8),
    "broken strands extracted": plt.cm.tab20c(10),
    "broken strands partial": plt.cm.tab20c(11),
    "crushed important": plt.cm.tab20(6),
    "crushed light": plt.cm.tab20(7),
    "deposit important": plt.cm.tab20(8),
    "deposit light": plt.cm.tab20(9),
    "long scratches important": plt.cm.tab20(10),
    "long scratches light": plt.cm.tab20(11),
    "spaced strands important": plt.cm.tab20(12),
    "spaced strands light": plt.cm.tab20(13),
    "welded strands deep": plt.cm.tab20c(16),
    "welded strands partial": plt.cm.tab20c(18),
    "welded strands superficial": plt.cm.tab20c(19),
}

# Plot colormap
# Only the anomaly types are shown, in alphabetical order.
labels = sorted(k for k in colormap if k not in ("nominal", "anomalous"))
colors = [colormap[k] for k in labels]
# One integer value per anomaly type; the count follows the colormap instead of being hard-coded.
x = np.arange(1, len(labels) + 1)
cmap = matplotlib.colors.ListedColormap(colors)
# Half-integer boundaries put each value of x in its own color bin, so that
# the colorbar is discrete and every tick is centered on its color.
norm = matplotlib.colors.BoundaryNorm(np.arange(1, len(labels) + 2) - 0.5, cmap.N)
sc = plt.scatter(x, x, c=x, s=100, cmap=cmap, norm=norm)
cbar = plt.colorbar(sc, ticks=x)
cbar.ax.set_yticklabels(labels)
plt.show()

# Per split

In [None]:
# Define split to visualize
split = "test"  # "validation"
normalized = False

# Pick the image-level predictions of the requested split
split_frames = {"validation": df_val, "test": df_test}
if split not in split_frames:
    raise ValueError("split should be validation or test.")
df = split_frames[split]

# The score column prefix, the decision threshold and the figure title all
# depend on whether the normalized scores are visualized.
prefix_scores = "normalize_" if normalized else ""
threshold = 0.5 if normalized else image_threshold
title = f"Image {split} normalized anomaly scores" if normalized else f"Image {split} anomaly scores"

# HQ dataset: keep only the rows whose image path contains `cable_side_pass`
if dataset == "hq":
    df = df.loc[df["image_path"].str.contains(cable_side_pass)]

# Get labels
# When the split contains only nominal images (all targets are 0), a single
# label group is used. Note that in that case the end of the notebook will
# fail, which is expected since there are no anomalous samples to analyze.
only_nominal = df["target"].unique().tolist() == [0]
label_groups = ["nominal"] if only_nominal else ["nominal", "anomalous"]

In [None]:
# Plot the anomaly score distribution of the split
score_column = f"{prefix_scores}anomaly_score"
# One list of scores per target value, ordered by target
groups = df.groupby(["target"])[score_column].apply(list).tolist()

# Define bins for histogram (may need to be readjusted according to the runs)
min_score, max_score = df[score_column].min(), df[score_column].max()
first_edge = int(min_score - 0.5)
last_edge = int(max_score + 0.5)
bin_range = last_edge - first_edge + 1
# Unit-width bins by default; wider bins once the range exceeds 15
bin_width = bin_range // 15 if bin_range > 15 else 1
bins = list(range(first_edge, last_edge + bin_width + 1, bin_width))

# Legend title
metric_prefix = f"{split}_image_"
metrics_list_ = ["Precision", "Recall"]
metrics_ = metrics[[metric_prefix + m for m in metrics_list_]].tolist()
# The legend title stays empty when the split contains only nominal images
legend_title = ""
if len(label_groups) > 1:
    legend_title = "\n".join(f"{k}: {v:0.2f}" for k, v in zip(metrics_list_, metrics_))

# Keep title outside of the plot for the report
print(title)
print(f"Anomaly score: min = {min_score:0.2f}, max = {max_score:0.2f}")
plot_histogram(bins, groups, label_groups, threshold, legend_title, fontsize=20)

In [None]:
# The remaining cells analyze anomalous samples; warn when the split has none
split_has_anomalies = "anomalous" in label_groups
if not split_has_anomalies:
    print("WARNING: The rest of the notebook should not be run.")

In [None]:
# Plot the confusion matrix
actual = df["target"].to_numpy()
predicted = df[f"{prefix_scores}anomaly_score"].to_numpy()
# Binarize the scores: 1 when the score reaches the threshold, 0 otherwise
predicted_labels = (predicted >= threshold).astype(int)
plot_confusion_matrix(predicted_labels, actual, label_groups, title, 22)

In [None]:
# Plot the Precision-Recall curve (image level)
# The image-level targets and scores are passed first. The two None arguments
# fill the positions that receive the ID-level targets and scores in the
# ID-level Precision-Recall cells further down.
plot_precision_recall_curve(
    actual,
    predicted,
    None,
    None,
    threshold,
    recall_level="image",
    precision_level="image",
    title="",  # f"{split.capitalize()} set\nPrecision-Recall curve"
)

In [None]:
# Plot the ROC curve
# Uses the image-level targets (`actual`) and raw scores (`predicted`)
# defined in the confusion matrix cell; the title is left empty.
plot_roc_curve(
    actual,
    predicted,
    threshold,
    title="",
)

## Per anomaly type

In [None]:
# Extract anomaly types with grades
# Note that multiple anomalies can happen in the same image.
# For those cases the score will be duplicated so that each type of anomaly is represented in the figure.
# Other possible option: for a given image keep the annotation for the more pronounced anomaly.
# NOTE: this cell overwrites `df` with the merged frame, so it must be run only once
# after the split cell (a second merge would add suffixed duplicate columns).
labels = pd.read_csv(os.path.join(data_folder, "labels.csv"))
# "<type> <grade>" for anomalies; rows without an anomaly type become "good".
# The result is assigned back instead of calling replace(..., inplace=True) on a
# column selection, which does not update the DataFrame under pandas Copy-on-Write.
labels["anomaly_types"] = (
    labels["anomaly_type"].fillna("good") + " " + labels["anomaly_grade"].fillna("")
).replace("good ", "good")
column_names = ["image_path", "frame_id", "anomaly_types", "identification"]
# Left merge: every prediction row is kept, one row per annotation of the image
df = df.merge(labels[column_names], on="image_path", how="left")

df["anomaly_types"] = df["anomaly_types"].replace("good", "nominal")

In [None]:
# Anomaly scores gathered per anomaly type
groups = df.groupby(["anomaly_types"])[f"{prefix_scores}anomaly_score"].apply(list)
labels_groups = groups.index.tolist()
# Put nominal in first position to plot it in blue
labels_groups.insert(0, labels_groups.pop(labels_groups.index("nominal")))
groups_values = groups[labels_groups].tolist()

# First entry: the global metrics logged for the split
metric_prefix = f"{split}_image_"
metrics_ = metrics[[metric_prefix + m for m in metrics_lst]].tolist()
metrics_dict = {name: [value] for name, value in zip(metrics_lst, metrics_)}
metrics_idx = ["Global"]

# Following entries: scores computed on the nominal images plus one single type of anomaly
for anomaly_type in labels_groups[1:]:
    metrics_idx.append(anomaly_type)
    nominal_scores, anomaly_scores = groups["nominal"], groups[anomaly_type]
    y_pred = nominal_scores + anomaly_scores
    y_true = [0] * len(nominal_scores) + [1] * len(anomaly_scores)
    scores = compute_metrics(y_true, y_pred, threshold, metrics_lst)
    for metric_name, score in zip(metrics_lst, scores):
        metrics_dict[metric_name].append(score.round(4))

In [None]:
# Plot anomaly score distribution per anomaly type
# Reuses the bins of the split histogram; each group is colored through `colormap`.
print(title)
plot_histogram(bins, groups_values, labels_groups, threshold, "", 16, colormap)

In [None]:
# Zoom in on the anomalous images: same histogram per anomaly type, without the nominal group
labels_groups_without_nominal = [name for name in labels_groups if name != "nominal"]
groups_values_without_nominal = groups[labels_groups_without_nominal].tolist()
plot_histogram(bins, groups_values_without_nominal, labels_groups_without_nominal, threshold, "", 16, colormap)

In [None]:
# Print scores: one row per metric, one column per entry of metrics_idx
metrics_table = pd.DataFrame(metrics_dict, index=metrics_idx)
metrics_table.transpose()

In [None]:
# Build the multi-class targets and predictions of the per-anomaly-type confusion matrix.
# Class indices follow the order of labels_groups (nominal is 0). An extra
# "anomalous" class, appended after all the real groups, receives the nominal
# images whose score reaches the threshold.
# "anomalous" is skipped here because it has no scores in groups_values; this keeps
# the cell safe to re-run after the label has already been appended.
true_groups = [name for name in labels_groups if name != "anomalous"]
actual = []
predicted = []
for idx, group_name in enumerate(true_groups):
    group_scores = np.array(groups_values[idx])
    actual += [idx] * len(group_scores)
    # Detected nominal images go to the "anomalous" class; detected anomalies keep
    # their own class. Anything below the threshold is predicted nominal (0).
    anomalous_label = len(true_groups) if group_name == "nominal" else idx
    predicted += list(np.where(group_scores >= threshold, anomalous_label, 0))
actual = np.array(actual)
predicted = np.array(predicted)
# The confusion matrix cell expects the extra class name at the end of labels_groups
if "anomalous" not in labels_groups:
    labels_groups.append("anomalous")

In [None]:
# The nominal images that are badly predicted will appear in the Anomalous category
# `predicted` and `actual` hold the class indices built in the previous cell,
# and `labels_groups` provides the class names (with "anomalous" appended last).
print(title)
plot_confusion_matrix(predicted, actual, labels_groups, "", 10)

# Per anomaly IDs (HQ dataset only)

In [None]:
# Rows without an anomaly identification are labeled "nominal"
df = df.assign(identification=df["identification"].fillna("nominal"))
# Binary prediction: 1 when the anomaly score reaches the threshold, 0 otherwise
score_column = f"{prefix_scores}anomaly_score"
df["prediction"] = df[score_column].ge(threshold).astype(int)

In [None]:
# Count the predictions (0 or 1) per anomaly ID, each image being counted once per ID
df_ids = df[["image_path", "identification", "target", "prediction"]].drop_duplicates()
df_groups = df_ids.groupby(["identification", "prediction"]).size()
# Drop the nominal row by label rather than by position, so that no anomaly ID
# is lost when "nominal" is not the last index entry (or is absent).
anomaly_groups = df_groups.unstack().drop(index="nominal", errors="ignore")
anomaly_groups.plot(figsize=(15, 7), kind="bar", stacked=True, xlabel="Identification", ylabel="Count")
plt.minorticks_on()
plt.grid(axis="y", which="major", linestyle="-", linewidth="0.5", color="grey")
plt.grid(axis="y", which="minor", linestyle=":", linewidth="0.5", color="grey")
plt.show()

In [None]:
# Prediction counts of the nominal images.
# The row is selected by its "nominal" label rather than by position, so that an
# anomaly ID is never plotted by mistake; a KeyError is raised when the split
# contains no nominal image.
nominal_groups = df_groups.unstack().loc["nominal"]
nominal_groups.plot(figsize=(15, 7), kind="bar", stacked=True, xlabel="Prediction nominal images", ylabel="Count")
plt.minorticks_on()
plt.grid(axis="y", which="major", linestyle="-", linewidth="0.5", color="grey")
plt.grid(axis="y", which="minor", linestyle=":", linewidth="0.5", color="grey")
plt.show()

In [None]:
# Compute the metrics at the anomaly ID level: one sample per unique anomaly,
# plus one sample per nominal image.
score_column = f"{prefix_scores}anomaly_score"
df_ids = df[["image_path", "identification", "target", score_column]].drop_duplicates()
df_ids_abn = df_ids[df_ids["identification"] != "nominal"]
df_ids_norm = df_ids[df_ids["identification"] == "nominal"][["identification", score_column]]
# An anomaly is considered well predicted if found in at least one frame
df_ids_abn = df_ids_abn.groupby(["identification"])[score_column].max().reset_index()
pred_ids_level = pd.concat([df_ids_abn, df_ids_norm], axis=0)
pred_ids_level["target"] = (pred_ids_level["identification"] != "nominal").astype(int)

scores = compute_metrics(pred_ids_level["target"], pred_ids_level[score_column], threshold, metrics_lst)

# Store the scores right after the "Global" entry. When the entry already exists
# (cell run more than once) it is overwritten instead of being inserted again,
# which would leave duplicated labels in the score table.
row_name = "Per unique anomaly"
if row_name in metrics_idx:
    position = metrics_idx.index(row_name)
    for metric_name, score in zip(metrics_lst, scores):
        metrics_dict[metric_name][position] = score.round(4)
else:
    metrics_idx.insert(1, row_name)
    for metric_name, score in zip(metrics_lst, scores):
        metrics_dict[metric_name].insert(1, score.round(4))

In [None]:
# Print scores: one row per metric, one column per entry of metrics_idx
metrics_df = pd.DataFrame(metrics_dict, index=metrics_idx).transpose()
metrics_df

In [None]:
# Global scores with duplicates
# (`df` now holds one row per annotation, so an image with several anomalies is repeated)
actual = df["target"].to_numpy()
predicted = df[f"{prefix_scores}anomaly_score"].to_numpy()
# Binarize the scores: 1 when the score reaches the threshold, 0 otherwise
predicted_labels = (predicted >= threshold).astype(int)
plot_confusion_matrix(predicted_labels, actual, label_groups, title, 22)

In [None]:
# Plot the Precision-Recall curve with duplicate (image level)
# Only the image-level targets and scores are provided; the positions that
# receive the ID-level targets and scores in the following cells are left to None.
plot_precision_recall_curve(
    actual,
    predicted,
    None,
    None,
    threshold,
    recall_level="image",
    precision_level="image",
    title="",  # f"{split.capitalize()} set\nPrecision-Recall curve"
)

In [None]:
# Plot the Precision-Recall curve with duplicate (ID level)
# ID-level targets and scores come from `pred_ids_level`: one row per unique
# anomaly (with its maximum score) and one row per nominal image.
actual_id = pred_ids_level["target"]
predicted_id = pred_ids_level[f"{prefix_scores}anomaly_score"]
# The image-level positions are left to None since both levels are set to "ID".
plot_precision_recall_curve(
    None,
    None,
    actual_id,
    predicted_id,
    threshold,
    recall_level="ID",
    precision_level="ID",
    title="",  # f"{split.capitalize()} set\nPrecision-Recall curve per unique anomaly"
)

In [None]:
# Plot the Precision-Recall curve with duplicate (image level precision vs ID level recall)
# Both the image-level (`actual`, `predicted`) and the ID-level
# (`actual_id`, `predicted_id`) targets and scores are provided here.
plot_precision_recall_curve(
    actual,
    predicted,
    actual_id,
    predicted_id,
    threshold,
    recall_level="ID",
    precision_level="image",
    title="",  # f"{split.capitalize()} set\nPrecision-Recall curve per unique anomaly"
)

### Information about anomaly ids and their connection

In [None]:
# List the frame ids recorded for each (identification, anomaly type) pair
frames_by_anomaly = df.groupby(["identification", "anomaly_types"])["frame_id"]
frames_by_anomaly.apply(list)

In [None]:
anomalous_img_labels = df[df["target"] == 1].copy()
anomalous_img_paths = anomalous_img_labels["image_path"].unique()
# One list of anomaly ids per anomalous image, following the sorted image paths
lists = [
    anomalous_img_labels.loc[anomalous_img_labels["image_path"] == img, "identification"].tolist()
    for img in sorted(anomalous_img_paths)
]
# A graph is used to connect the anomalies that appear in a single image.
# That way, we make sure that we have no leak between the splits.
anomaly_graph = networkx.Graph()
for sub_list in lists:
    # Every pair of anomaly ids found in the same image becomes an edge
    anomaly_graph.add_edges_from(combinations(sub_list, r=2))
# Keep only the components that link at least two different anomaly ids
connected_anomalies = [
    component for component in networkx.connected_components(anomaly_graph) if len(component) > 1
]

In [None]:
# Display the sets of anomaly ids that are connected through at least one shared image
connected_anomalies

In [None]:
# Get wrongly predicted images
is_wrong = df["target"].ne(df["prediction"])
df_wrong = df.loc[is_wrong]
# Paths of the wrongly predicted images, gathered per anomaly ID ("nominal" for nominal images)
groups = df_wrong.groupby(["identification"])["image_path"].apply(list)

In [None]:
# Identifications that have at least one wrongly predicted image
groups.index

In [None]:
# Print, for each identification, the paths of its wrongly predicted images
for idx, wrong_paths in groups.items():
    print(idx)
    for path in wrong_paths:
        print(path)
    print("\n\n")

In [None]:
# Plot wrongly predicted images for a given anomaly ID or for nominal images

# Uncomment idx variable and change the anomaly ID to plot its wrongly predicted images
# Or change idx variable to "nominal" to plot wrongly predicted nominal images
# When left commented, idx keeps the last value assigned by the printing loop above.
# idx = "001_00"
paths = groups.loc[idx]
is_nominal = idx == "nominal"

# TODO: see if we can sort by scores
# Nominal images: 4 images per row. Anomalous images: one row per image (image + mask).
# The number of rows is rounded up so that the figure height is never 0 and
# no extra empty row is created when the number of images is a multiple of 4.
n_cols = 4 if is_nominal else 2
n_rows = (len(paths) + n_cols - 1) // n_cols if is_nominal else len(paths)
# squeeze=False always returns a 2D array of axes, whatever the grid shape
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3.5 * n_cols, 3.5 * n_rows), squeeze=False)
fig.subplots_adjust(hspace=0.3, wspace=0.3)
axs = axs.ravel()

i = 0
for f in paths:
    # Image (the context manager closes the file once the resized copy is made)
    with Image.open(os.path.join(data_folder, f)) as img:
        axs[i].imshow(np.asarray(img.resize((224, 224))))
    axs[i].set_title(f, fontsize=8)
    i += 1
    if not is_nominal:
        # Mask: "images" is replaced in the relative path only, so that
        # data_folder is never altered by the substitution.
        mask_path = f.replace("images", "masks")
        with Image.open(os.path.join(data_folder, mask_path)) as img:
            axs[i].imshow(np.asarray(img.resize((224, 224))), cmap="gray", vmin=0, vmax=255)
        axs[i].set_title(mask_path, fontsize=8)
        i += 1

# Hide the axes left unused in the last row
for ax in axs[i:]:
    ax.axis("off")
plt.show()