# Evaluation of artefact detector

This notebook generates test metrics (ROC-AUC and per-class confusion matrices) for the artefact detector.

In [None]:
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer, seed_everything
import matplotlib.pyplot as plt
from dataset import EMBEDMammoDataModule, ANNOTATION_FILE
from sklearn.metrics import multilabel_confusion_matrix
from artifact_detector_model import Multilabel_ArtifactDetector, MARKER_NAMES
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
# Set matplotlib's default font family to serif for the figures drawn below.
plt.rcParams["font.family"] = "Serif"

In [None]:
# Fix the global random seed (Lightning's seed_everything) before building the
# data module and running the evaluation.
seed_everything(42, workers=True)
# NOTE: despite the name, this is the path to a single checkpoint file (.ckpt),
# not a directory; it is passed to load_from_checkpoint further down.
model_dir = "output/artifact-detector/version_0/checkpoints/epoch=6-step=2898.ckpt"

In [None]:
# Read the annotation table and record how many marker classes there are.
df = pd.read_csv(ANNOTATION_FILE)
num_classes = len(MARKER_NAMES)


def _marker_vector(row):
    """Return one row's marker columns, in MARKER_NAMES order, as an array."""
    return np.array([row[name] for name in MARKER_NAMES])


# One multilabel target vector per annotated row.
df["multilabel_markers"] = df.apply(_marker_vector, axis=1)

# Data module and trained detector used for the evaluation.
data = EMBEDMammoDataModule(
    df,
    image_size=(512, 384),
    target="artifact",
    batch_size=32,
    split_dataset=True,
)
model = Multilabel_ArtifactDetector.load_from_checkpoint(model_dir)
trainer = Trainer()

# Choose the evaluation loop matching the requested split, then run it.
split = "test"
if split == "test":
    method = trainer.test
else:
    method = trainer.validate
method(model=model, datamodule=data)

In [None]:
# Targets and predicted scores that the model stored while running the
# evaluation loop for the selected split.
targets = model.val_trgts if split == "val" else model.test_trgts
predictions = model.val_preds if split == "val" else model.test_preds
labels = MARKER_NAMES

y_true = np.asarray(targets)
# Binarise the scores at 0.5. The threshold is applied to the ndarray rather
# than to the raw `predictions` container, so it also works for a plain list.
y_pred = np.asarray(predictions) > 0.5

# One 2x2 matrix per marker; rows are the true class, columns the predicted
# class (scikit-learn layout: [[TN, FP], [FN, TP]]).
cm = multilabel_confusion_matrix(y_true, y_pred)

# Row-normalise each matrix so every true-class row sums to 1. Rows without
# any samples stay at 0 instead of producing NaN and a divide warning.
row_totals = cm.sum(axis=2, keepdims=True)
cm_normalized = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

# Five panels on a 2x6 grid: three panels of width 2 in the first row and two
# panels of width 2 centred in the second row.
fig = plt.figure(figsize=(6, 4), facecolor="none")
gs = gridspec.GridSpec(2, 6, figure=fig)
ax = [
    fig.add_subplot(gs[0, 0:2]),
    fig.add_subplot(gs[0, 2:4]),
    fig.add_subplot(gs[0, 4:]),
    fig.add_subplot(gs[1, 1:3]),
    fig.add_subplot(gs[1, 3:5]),
]
fig.subplots_adjust(wspace=0.3)

# zip stops at the shortest input, so at most len(ax) markers are drawn.
for counts, fractions, label, axi in zip(cm, cm_normalized, labels, ax):
    # Cell text: row-normalised percentage on top, raw count underneath.
    annotations = np.asarray(
        [
            f"{fraction:.2%}\n(N={count:0.0f})"
            for count, fraction in zip(counts.flatten(), fractions.flatten())
        ]
    ).reshape(2, 2)
    sns.heatmap(
        fractions,
        annot=annotations,
        fmt="",
        cmap="Blues",
        ax=axi,
        vmin=0,
        vmax=1,
        cbar=False,
    )
    # Bold title via mathtext, one bold group per word of the marker name.
    axi.set_title(
        " ".join([r"$\bf{" + t + "}$" for t in label.capitalize().split(" ")]),
        fontsize=11,
    )

plt.tight_layout()
plt.savefig("output/confusion_matrix.pdf", bbox_inches="tight")
plt.show()

## [Optional] Plot the errors (false negatives for one marker class)

In [None]:
import matplotlib
import os
from skimage.io import imread

# Set to True to display test images the detector missed for one marker class.
plot = False
# Position (in the prediction/target vectors) of the marker class to inspect.
ERROR_CLASS_INDEX = 3
# Number of image panels in the figure; further errors are not displayed.
MAX_ERROR_PANELS = 4

if plot:
    # model.test_image_ids is nested (a sequence of sequences); flatten it so it
    # lines up element-wise with the rows of y_pred / y_true.
    y_id = [x for xs in model.test_image_ids for x in xs]

    # False negatives for the chosen class: predicted 0 while the target is 1.
    incorrect = [
        image_id
        for pred_row, true_row, image_id in zip(y_pred, y_true, y_id)
        if pred_row[ERROR_CLASS_INDEX] == 0 and true_row[ERROR_CLASS_INDEX] == 1
    ]

    f, axes = plt.subplots(1, MAX_ERROR_PANELS, figsize=(25, 10))

    for i, panel in enumerate(axes.flat):
        if i >= len(incorrect):
            # Fewer errors than panels: leave the remaining panels blank.
            panel.axis("off")
            continue

        img_path = incorrect[i]
        image = imread(img_path).astype(np.float32)

        # Min-max scale to [0, 255]; a constant image would divide by zero, so
        # it is rendered as all zeros instead.
        lowest, highest = np.min(image), np.max(image)
        if highest > lowest:
            image = (image - lowest) / (highest - lowest)
        else:
            image = np.zeros_like(image)
        image = (image * 255).astype(np.uint8)

        panel.imshow(image, cmap="gray")
        panel.axis("off")

        file_name = os.path.basename(img_path)
        panel.set_title(file_name)
        print(file_name)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()