# Training the pathological finding classifier

After we identified our model `EfficientNetV2`, we train our classifier with all samples.

The idea is to extract crops over a range of scales and aspect ratios and use them to train our classifier.
We then train the classifier using the best parameters found in [the earlier notebook](02-GridSearch.ipynb).

Finally, we save the trained model and do some inference using our explainable algorithm.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ast import literal_eval
from tqdm.notebook import tqdm

from datetime import datetime

import torch
from torch import optim
from torch.utils.data import DataLoader, WeightedRandomSampler

from torchvision.transforms import v2 as transforms
from torchvision.ops import sigmoid_focal_loss

from torchmetrics.classification import (
    MultilabelAccuracy,
    MultilabelPrecision,
    MultilabelRecall,
    MultilabelF1Score,
    MultilabelAUROC,
)
from torchinfo import summary

from FindClf import Dataset, Models


In [None]:
# ---- Data locations ----
# Image directory with the VinDr dataset images processed with our method
imagepath = ""
# Grouped annotations for asymmetries and retractions
csvpath = "finding_annotations_V2.csv"

# ---- Target labels ----
# The position of a name in this list is used as its class index.
label_names = [
    "No Finding",
    "Mass",
    "Suspicious Calcification",
    "Asymmetries",
    "Architectural Distortion",
    "Suspicious Lymph Node",
    "Skin Thickening",
    "Retractions",
]

# ---- Training setup ----
batch_size = 32
epochs = 100

# NOTE(review): `scales` and `ratios` are presumably the crop scale and
# aspect-ratio ranges mentioned in the introduction; confirm where they are
# consumed.
scales = (0.05, 5.0)
ratios = (0.33, 1.66)
# Size every image is resized to before it reaches the network
window_size = (256, 256)

# ---- Values found by Ray Tune ----
lr = 1e-4
weight_decay = 1e-5
# 0.95 scored best in the search but was judged too high, so 0.8 is used
focal_alpha = 0.8
# 2.0 scored best, as expected
focal_gamma = 2.0

# ---- Runtime ----
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of DataLoader worker processes
njobs = 16


## Instantiate dataset

In [None]:
# Image pipelines for the training and test splits
def _resize_and_convert():
    """Return fresh instances of the final steps shared by both pipelines.

    Images are resized to ``window_size`` (bilinear, antialiased) and then
    converted to ``float32`` with ``scale=True``.
    """
    return [
        transforms.Resize(
            window_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
            antialias=True,
        ),
        transforms.ToDtype(torch.float32, scale=True),
    ]


# Random flips and a rotation of up to 30 degrees
_geometric_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(
        degrees=30, interpolation=transforms.InterpolationMode.BILINEAR
    ),
]

# Colour jitter and channel permutation are applied together, half of the time
_colour_augmentation = transforms.RandomApply(
    [
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomChannelPermutation(),
    ],
    p=0.5,
)

# Training: augmentations first, then the shared resize/convert steps
train_transforms = transforms.Compose(
    [*_geometric_augmentations, _colour_augmentation, *_resize_and_convert()]
)

# Test: only the shared resize/convert steps, no augmentation
test_transforms = transforms.Compose(_resize_and_convert())

In [None]:
# Read the annotation table and pick the rows of each predefined split
df = pd.read_csv(csvpath)
_by_split = df.groupby("split")  # group once, reuse for both lookups
df_train = _by_split.get_group("training")
df_test = _by_split.get_group("test")

In [None]:
# create the dataset objects
# Both datasets read from the same `imagepath`; they differ in the annotation
# rows, the transform pipeline and the `stage` flag they receive.
# NOTE(review): how VindrDataset interprets `stage` is defined in
# FindClf.Dataset, not in this notebook — confirm there.
train_dataset = Dataset.VindrDataset(
    df_train, imagepath, train_transforms, stage="train"
)
test_dataset = Dataset.VindrDataset(df_test, imagepath, test_transforms, stage="test")

In [None]:
# Weighted sampler functions
def get_class_weights(dataset):
    """Compute an inverse-frequency weight for every entry of ``label_names``.

    Args:
        dataset: Object exposing ``df.finding_categories`` (an iterable of
            strings, each the literal of a list of label names) and
            supporting ``len()``.

    Returns:
        np.ndarray: float64 array aligned with ``label_names`` holding
        ``len(dataset) / occurrences`` for each class. A class that never
        occurs gets ``0.0`` instead of ``inf``.

    Raises:
        ValueError: If an annotation contains a label that is not in
            ``label_names``.
    """
    n_classes = len(label_names)
    events_per_class = np.zeros(n_classes, dtype=np.int64)
    for val in dataset.df.finding_categories:
        # Each entry is stored as text, e.g. "['Mass', 'Asymmetries']"
        for label in literal_eval(val):
            events_per_class[label_names.index(label)] += 1

    n_samples = len(dataset)
    class_weights = np.zeros(n_classes, dtype=np.float64)
    # Divide only where the class is present: absent classes keep weight 0.0
    # and no divide-by-zero warning / inf is produced.
    np.divide(
        n_samples, events_per_class, out=class_weights, where=events_per_class > 0
    )
    return class_weights


def get_sample_weights(dataset):
    """Return one sampling weight per row of ``dataset.df``.

    A row's weight is the sum of the class weights (as returned by
    ``get_class_weights``) of every label listed in its
    ``finding_categories`` entry.

    Args:
        dataset: Object exposing ``df.finding_categories`` and ``len()``.

    Returns:
        np.ndarray: One weight per annotation row, in row order.
    """
    per_class = get_class_weights(dataset)
    row_totals = [
        # Entries are stored as text and parsed back into a list of names
        sum([per_class[label_names.index(name)] for name in literal_eval(entry)])
        for entry in dataset.df.finding_categories
    ]
    return np.array(row_totals)

In [None]:
# Build the weighted sampler for the training split
class_weights = get_class_weights(train_dataset)
print(f"Class weights: {class_weights}")

# One weight per training row; the sampler draws as many indices as there are
# rows, with replacement.
train_weights = get_sample_weights(train_dataset)
sampler = WeightedRandomSampler(
    weights=train_weights,
    num_samples=len(train_weights),
    replacement=True,
)

In [None]:
# Dataloaders
# Options shared by both loaders
_loader_options = {"batch_size": batch_size, "num_workers": njobs}

# Training batches are drawn through the weighted sampler
train_loader = DataLoader(train_dataset, sampler=sampler, **_loader_options)

# The test split is iterated in shuffled order.
# NOTE(review): shuffling is not needed to compute the validation metrics;
# confirm whether later cells rely on the random order before disabling it.
test_loader = DataLoader(test_dataset, shuffle=True, **_loader_options)

## Define the model and optimizer

In [None]:
# Build the network, passing the number of labels to the factory, and move
# it to the selected device.
model = Models.create_efficientNetV2(len(label_names))
model.to(device)

# Summarise the network for a single 3-channel image. The spatial size is
# taken from `window_size` so the summary always matches the resolution the
# transforms actually produce.
summary(
    model,
    input_size=(1, 3, *window_size),
    col_names=["input_size", "output_size", "num_params"],
    depth=3,
)

In [None]:
# Optimizer: Adam with the learning rate and weight decay set above
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Scheduler: cosine annealing over `epochs` steps down to eta_min
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=1e-7,
)

In [None]:
# Multilabel metrics tracked during training and validation.
#
# `ignore_index` is deliberately left unset: in torchmetrics it names a target
# *value* to ignore, not a class index. With 0/1 targets, `ignore_index=0`
# would drop every negative target and make all three metrics degenerate.
# NOTE(review): if the intent was to leave the first label ("No Finding") out
# of the averages, compute with `average=None` and drop that entry instead.
metrics = {
    "Accuracy": MultilabelAccuracy(num_labels=len(label_names), average="macro"),
    "F1": MultilabelF1Score(num_labels=len(label_names), average="macro"),
    "AUC": MultilabelAUROC(
        num_labels=len(label_names), average="weighted", thresholds=10
    ),
}

# Move the metric objects to the same device as the model outputs
for metric in metrics.values():
    metric.to(device)

## Training loop

In [None]:
def crear_carpeta(path):
    """Create the directory ``path``, including missing parents.

    Does nothing when the directory already exists.

    Args:
        path: Directory to create.

    Raises:
        OSError: If the directory cannot be created, e.g. ``path`` exists but
            is not a directory.
    """
    # exist_ok=True already covers the "directory is there" case
    os.makedirs(path, exist_ok=True)


# Each run stores its checkpoints in a folder named after its start time
start = f"{datetime.now():%Y%m%d_%H%M%S}"
ckpt_path = f"checkpoints/{start}"
crear_carpeta(ckpt_path)

In [None]:
def _focal_loss(logits, targets):
    """Summed sigmoid focal loss using the notebook's alpha and gamma."""
    return sigmoid_focal_loss(
        logits,
        targets,
        alpha=focal_alpha,
        gamma=focal_gamma,
        reduction="sum",
    )


def _reset_metrics():
    """Clear the accumulated state of every tracked metric."""
    for metric in metrics.values():
        metric.reset()


def _update_metrics(logits, targets, running_loss, n_batches):
    """Feed one batch to every metric and build the progress-bar text.

    The loss shown is the mean of the per-batch summed losses seen so far.
    """
    text = f"Loss: {running_loss / n_batches:.4f}"
    for name, metric in metrics.items():
        metric.update(logits, targets.int())
        text += f" {name}: {metric.compute():.4f}"
    return text


def _checkpoint_state(epoch):
    """Collect the state saved in a checkpoint for ``epoch``."""
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }


with tqdm(total=epochs, desc="Training") as trainbar:
    for epoch in range(1, epochs + 1):
        # ---- Training pass ----
        model.train()  # set the model to training mode

        current_lr = scheduler.get_last_lr()[0]
        train_loss = 0.0
        _reset_metrics()
        with tqdm(
            total=len(train_loader), desc=f"Training (lr={current_lr:.2e})", leave=False
        ) as pbar:
            for i, (images, labels) in enumerate(train_loader, 1):
                images, labels = images.to(device), labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = model(images)  # forward pass
                loss = _focal_loss(outputs["Classifier"], labels)
                loss.backward()  # backward pass
                optimizer.step()

                # metrics
                train_loss += loss.item()
                metric_text = _update_metrics(
                    outputs["Classifier"], labels, train_loss, i
                )
                pbar.set_postfix_str(metric_text)
                pbar.update()

        # ---- Validation pass ----
        model.eval()  # set the model to evaluation mode
        _reset_metrics()
        val_loss = 0.0
        # Gradients are never needed here, so disable them once for the
        # whole pass instead of once per batch.
        with tqdm(
            total=len(test_loader), desc="Validation", leave=False
        ) as pbar, torch.no_grad():
            for i, (images, labels) in enumerate(test_loader, 1):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)  # forward pass
                loss = _focal_loss(outputs["Classifier"], labels)

                # metrics
                val_loss += loss.item()
                metric_text = _update_metrics(
                    outputs["Classifier"], labels, val_loss, i
                )
                pbar.set_postfix_str(metric_text)
                pbar.update()

        # ---- Post-epoch operations ----
        # The metrics were last reset before validation, so these are the
        # validation scores of this epoch.
        trainbar.set_postfix_str(
            f" F1: {metrics['F1'].compute():.4f} AUROC: {metrics['AUC'].compute():.4f}"
        )
        trainbar.update()

        scheduler.step()  # update the learning rate

        # Keep an intermediate checkpoint every 10 epochs
        if epoch % 10 == 0:
            current_ckpt = os.path.join(
                ckpt_path, f"EfficientNetV2_epoch{epoch:03d}.pth"
            )
            save_dict = _checkpoint_state(epoch)
            torch.save(save_dict, current_ckpt)

# Save the final model
# NOTE: `epoch` is the loop variable, so this assumes at least one epoch ran.
current_ckpt = os.path.join(ckpt_path, "EfficientNetV2_final.pth")
save_dict = _checkpoint_state(epoch)
torch.save(save_dict, current_ckpt)
print(f"Model saved at {current_ckpt}")
