## Evaluation of classifier

In this notebook we load our trained model and compute its metrics on the classification task.

(We will also load the original model and check its metrics, so the two can be compared.)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ast import literal_eval
from tqdm.notebook import tqdm

from FindClf import Dataset, Models

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as transforms
from torchmetrics.classification import (
    MultilabelAccuracy,
    MultilabelPrecision,
    MultilabelStatScores,
    MultilabelRecall,
    MultilabelF1Score,
    MultilabelROC,
    MultilabelAUROC,
)

In [None]:
# Parameters

# Image directory with vindr Dataset images processed with our method
imagepath = ""
# Grouped annotations for asymmetries and retractions
csvpath = "finding_annotations_V2.csv"

# Names of the findings the classifier predicts; the number of entries is used
# as the number of labels for the model head and for every metric.
# NOTE(review): presumably the order matches the dataset's label columns — confirm.
label_names = [
    "No Finding",
    "Mass",
    "Suspicious Calcification",
    "Asymmetries",
    "Architectural Distortion",
    "Suspicious Lymph Node",
    "Skin Thickening",
    "Retractions",
]

# DataLoader settings
batch_size = 32
njobs = 16  # number of DataLoader worker processes

# Images are resized to this size before being fed to the model
window_size = (256, 256)

# NOTE(review): scales / ratios are not used by the cells visible here;
# presumably kept in sync with the training configuration — confirm.
scales = (0.05, 5.0)
ratios = (0.33, 1.66)

# Seed used for the DataLoader generator
seed = 348

# Run on the GPU when one is available, otherwise on the CPU
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

In [None]:
# we will set a seed for reproducibility when evaluating both models
import random


def seed_worker(worker_id):
    """Seed NumPy and `random` from the current torch seed.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader`` so that
    every worker process seeds its NumPy and Python RNGs from its torch seed.

    Args:
        worker_id: Index of the worker; part of the ``worker_init_fn``
            signature but not used here.
    """
    # Fold the 64-bit torch seed into the 32-bit range NumPy accepts.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


## Seed (for reproducibility): dedicated generator, seeded with `seed`,
## that is later handed to the DataLoader
g = torch.Generator()
g.manual_seed(seed)

In [None]:
# Load the transforms
# Evaluation-time preprocessing only (no augmentation): resize every image to
# `window_size`, then convert it to float32 with value scaling enabled.
_resize_step = transforms.Resize(
    size=window_size,
    interpolation=transforms.InterpolationMode.BILINEAR,
    antialias=True,
)
_to_float_step = transforms.ToDtype(dtype=torch.float32, scale=True)

test_transforms = transforms.Compose([_resize_step, _to_float_step])

In [None]:
# load the csv
df = pd.read_csv(csvpath)

# Keep only the rows whose "split" column equals "test"
df_test = df.groupby("split").get_group("test")
test_dataset = Dataset.VindrDataset(df_test, imagepath, test_transforms, stage="test")

# No shuffling, so batches arrive in dataset order; workers are seeded through
# `seed_worker` and the seeded generator `g` for reproducibility.
_loader_options = {
    "batch_size": batch_size,
    "shuffle": False,
    "num_workers": njobs,
    "worker_init_fn": seed_worker,
    "generator": g,
}
test_loader = DataLoader(test_dataset, **_loader_options)

In [None]:
# we load the model
model_path = ""  # path to the model checkpoint

# The checkpoint is a dict; the network weights live under "model_state_dict".
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
_model_weights = state_dict["model_state_dict"]

# Build the network with one output per label, then restore the weights
model = Models.create_efficientNetV2(len(label_names))
model.load_state_dict(_model_weights)
model.to(device)

# Inference mode for layers such as dropout / batch-norm
model.eval()
print("Model loaded")

In [None]:
# Load all metrics
nlabels = len(label_names)

# Aggregate metrics: one support-weighted score over all labels.
#
# NOTE: `ignore_index` is intentionally NOT passed. In torchmetrics it names a
# *target value* to ignore, not a label/column index, so `ignore_index=0` on
# 0/1 multilabel targets discards every negative target (no false positives
# or true negatives are ever counted, which inflates Precision and makes
# Accuracy collapse into Recall).
# NOTE(review): if the intent was to leave the "No Finding" label out of the
# aggregate scores, drop that column from predictions and targets and build
# these metrics with `num_labels=nlabels - 1` instead.
gral_metrics = {
    "Accuracy": MultilabelAccuracy(num_labels=nlabels, average="weighted"),
    "Precision": MultilabelPrecision(num_labels=nlabels, average="weighted"),
    "Recall": MultilabelRecall(num_labels=nlabels, average="weighted"),
    "F1": MultilabelF1Score(num_labels=nlabels, average="weighted"),
    "AUROC": MultilabelAUROC(num_labels=nlabels, average="weighted"),
}

# Per-label metrics: `average=None` yields one value per label.
# "StatScores" additionally provides the per-label counts and support.
class_metrics = {
    "Accuracy": MultilabelAccuracy(num_labels=nlabels, average=None),
    "Precision": MultilabelPrecision(num_labels=nlabels, average=None),
    "Recall": MultilabelRecall(num_labels=nlabels, average=None),
    "F1": MultilabelF1Score(num_labels=nlabels, average=None),
    "AUROC": MultilabelAUROC(num_labels=nlabels, average=None),
    "StatScores": MultilabelStatScores(num_labels=nlabels, average=None),
}

# Metric states must live on the same device as the tensors they are updated with
for metric in (*gral_metrics.values(), *class_metrics.values()):
    metric.to(device)

print("Metrics loaded")

In [None]:
# Per-batch ground truth and predicted probabilities, kept as NumPy arrays
yreal, predictions = [], []

# Gradients are never needed during evaluation
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating"):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # The model returns a dict; the classification logits are under "Classifier"
        pred = torch.sigmoid(outputs["Classifier"])
        targets = labels.int()

        # Feed the same probabilities to both metric groups. torchmetrics only
        # treats float predictions as logits when some value falls outside
        # [0, 1], so passing raw logits to one group could score a batch
        # differently from the per-class metrics.
        for metric in (*gral_metrics.values(), *class_metrics.values()):
            metric.update(pred, targets)

        yreal.append(labels.detach().cpu().numpy())
        predictions.append(pred.detach().cpu().numpy())

In [None]:
# general Metrics
# Each aggregate metric yields a single scalar; print one per line.
for name in gral_metrics:
    value = gral_metrics[name].compute()
    print(f"{name}: {value:.4f}")

In [None]:
# Per-label stat scores. With `average=None` torchmetrics documents the result
# as one row per label with columns [tp, fp, tn, fn, support].
scores = class_metrics["StatScores"].compute()
# Last column: support of each label (displayed as the cell output)
scores[:, -1]

In [None]:
from sklearn import metrics

In [None]:
# Stack the per-batch arrays along the sample axis so that each becomes a
# single array covering the whole test set.
yyreal, yypred = (
    np.concatenate(batches, axis=0) for batches in (yreal, predictions)
)
print(yyreal.shape, yypred.shape)

In [None]:
# One 2x2 confusion matrix per label; probabilities are binarised at 0.5
cms = metrics.multilabel_confusion_matrix(y_true=yyreal, y_pred=yypred > 0.5)

In [None]:
# Confusion matrix of the label at index 1
# NOTE(review): that is "Mass" if the label columns follow `label_names` order — confirm.
cms[1]

In [None]:
# scikit-learn lays each matrix out as [[TN, FP], [FN, TP]]. The diagonal is
# [TN, TP] and `sum(axis=1)` gives the column totals [TN + FN, FP + TP], so
# each row of the result is [TN / (TN + FN), TP / (TP + FP)], i.e. the
# negative predictive value and the precision of that label.
# NOTE(review): if specificity / recall were intended, the denominator would
# be `cms.sum(axis=2)` — confirm.
cms.diagonal(axis1=1, axis2=2) / cms.sum(axis=1)

In [None]:
print("Class Metrics")
print(
    f"{'Labels':>30}: {'Accuracy':^10} {'Precision':^10} {'Recall':^10} {'F1':^10} {'AUROC':^10} {'Support':^10}"
)

# Compute every per-class metric exactly once (one tensor with a value per
# label) instead of re-running `compute()` for every label/metric pair.
metric_names = ["Accuracy", "Precision", "Recall", "F1", "AUROC"]
per_class = {name: class_metrics[name].compute() for name in metric_names}

# One row per label: the five metrics followed by the label's support, which
# is the last column of `scores`.
for i, label in enumerate(label_names):
    print(f"{label:>30}:", end=" ")
    for name in metric_names:
        print(f"{per_class[name][i]:^10.4f}", end=" ")
    print(f"{scores[i, -1].item():^10d}")

In [None]:
# Spanish display names, in the same order as `label_names`
es_label_names = [
    "Sin Hallazgo",
    "Masa",
    "Calcificación Sospechosa",
    "Asimetrías",
    "Distorsión Arquitectura",
    "Linfonodo Sospechoso",
    "Engrosamiento de Piel",
    "Retracciones",
]

# Compute each per-class metric once and turn it into a column of plain Python
# floats, rather than calling `compute()` and writing cell by cell via `.loc`.
metric_names = ["Accuracy", "Precision", "Recall", "F1", "AUROC"]
stats_columns = {
    name: class_metrics[name].compute().tolist() for name in metric_names
}
# Support of each label: last column of the stat scores, as Python ints
stats_columns["Soporte"] = scores[:, -1].tolist()

# dtype=object keeps the values as Python floats / ints, exactly as filling an
# empty frame cell by cell did, so integer formatting of "Soporte" downstream
# is unaffected.
statsdf = pd.DataFrame(
    stats_columns,
    index=pd.Index(es_label_names, name="Hallazgo"),
    dtype=object,
)

statsdf

In [None]:
# Styled view of the table: percentages for Accuracy / Precision / Recall,
# four decimals for F1 / AUROC and plain integers for the support column.
_percent_columns = ["Accuracy", "Precision", "Recall"]
_decimal_columns = ["F1", "AUROC"]

asd = statsdf.style.format("{:.3%}", subset=_percent_columns)
asd = asd.format("{:.4f}", subset=_decimal_columns)
asd = asd.format("{:d}", subset=["Soporte"])

In [None]:
# Render the table as a grid-style text table with floats at four decimals
_markdown_table = statsdf.to_markdown(tablefmt="grid", floatfmt=".4f")
print(_markdown_table)