# Grid Search

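In this notebook we run a hyperparameter search on the training split to fine-tune some training parameters. Despite the title, this is not an exhaustive grid: Ray Tune samples 50 random configurations from the search space, and the ASHA scheduler stops unpromising trials early.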

More specifically, we search for the learning rate and weight decay of the Adam optimizer, and for the `alpha` and `gamma` parameters of the focal loss function.

In [None]:
import os
from functools import partial
import numpy as np
import pandas as pd
from pathlib import Path
import tempfile

import torch
from torch import optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.transforms import v2 as transforms
from torchvision.ops import sigmoid_focal_loss

from torchmetrics.classification import MultilabelAccuracy, MultilabelF1Score
from torchinfo import summary

from FindClf import Dataset, Models

from ray import tune
from ray import train as Train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle

In [None]:
# Parameters

# --- data ------------------------------------------------------------------
imagepath = ""  # Path to images directory
# Grouped annotations for asymmetries and retractions
csvpath = "finding_annotations_V2.csv"
# Output columns of the classifier, in label-index order (index 0 is "No Finding").
label_names = [
    "No Finding",
    "Mass",
    "Suspicious Calcification",
    "Asymmetries",
    "Architectural Distortion",
    "Suspicious Lymph Node",
    "Skin Thickening",
    "Retractions",
]

# --- training --------------------------------------------------------------
batch_size = 48
epochs = 10

# --- image geometry --------------------------------------------------------
# NOTE(review): `scales` and `ratios` are not referenced by the cells shown here.
scales = (0.05, 5.0)
ratios = (0.33, 1.66)
window_size = (256, 256)  # target (height, width) passed to transforms.Resize

# --- hardware --------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
njobs = 12  # worker processes per DataLoader


## Construct the dataset objects

In [None]:
# Preprocessing shared by every split: resize to `window_size` (bilinear,
# antialiased), then convert to float32 with value scaling enabled
# (`scale=True`, so integer pixel values are rescaled to [0, 1]).
_resize = transforms.Resize(
    window_size,
    interpolation=transforms.InterpolationMode.BILINEAR,
    antialias=True,
)
_to_float = transforms.ToDtype(torch.float32, scale=True)
general_transforms = transforms.Compose([_resize, _to_float])

In [None]:
# Load the annotation table and take the official training / test splits.
df = pd.read_csv(csvpath)
splits_by_name = df.groupby("split")
traindf = splits_by_name.get_group("training")

# Hold out 20% of the training split for validation. The fixed seed means every
# run sees the same partition. Images of one patient may land on both sides;
# that leakage is accepted for hyperparameter tuning.
df_train = traindf.sample(frac=0.8, random_state=42)
df_val = traindf.drop(df_train.index)

df_test = splits_by_name.get_group("test")

for split_label, split_frame in (
    ("Training", df_train),
    ("Validation", df_val),
    ("Test", df_test),
):
    print(f"{split_label}: {len(split_frame)}")


In [None]:
# create the dataset objects
train_dataset = Dataset.VindrDataset(
    df_train, imagepath, general_transforms, stage="train"
)
val_dataset = Dataset.VindrDataset(df_val, imagepath, general_transforms, stage="val")
test_dataset = Dataset.VindrDataset(
    df_test, imagepath, general_transforms, stage="test"
)

Next, we define the training function that Ray Tune runs once for each sampled hyperparameter configuration (trial).

In [None]:
def _restore_checkpoint(model, optimizer):
    """Load model/optimizer state from this trial's Ray checkpoint, if any.

    Returns:
        Index of the first epoch that still has to be trained (0 on a fresh start).
    """
    checkpoint = get_checkpoint()
    if not checkpoint:
        return 0
    with checkpoint.as_directory() as ckpt_dir:
        # These pickles are written by `_save_and_report` below, not by an
        # external source, so unpickling them is trusted.
        with open(Path(ckpt_dir) / "data.pkl", "rb") as fp:
            ckpt_state = pickle.load(fp)
    model.load_state_dict(ckpt_state["net_state_dict"])
    optimizer.load_state_dict(ckpt_state["optimizer_state_dict"])
    # "epoch" is the index of the last *completed* epoch; resume with the next one.
    return ckpt_state["epoch"] + 1


def _focal_loss(logits, labels, config):
    """Summed sigmoid focal loss with the trial's `alpha` / `gamma`."""
    return sigmoid_focal_loss(
        logits,
        labels,
        alpha=config["alpha"],
        gamma=config["gamma"],
        reduction="sum",
    )


def _train_one_epoch(model, loader, optimizer, config, device, epoch):
    """Run one optimisation pass over `loader`, logging the loss every 100 batches."""
    model.train()
    window_loss = 0.0
    window_samples = 0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = _focal_loss(model(inputs)["Classifier"], labels, config)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        window_samples += labels.size(0)
        if i % 100 == 99:
            # Average over the batches since the previous print only, per sample.
            print(
                f"[{epoch + 1:d}, {i + 1:5d}] loss: {window_loss / window_samples:.4f}"
            )
            window_loss = 0.0
            window_samples = 0


def _validate(model, loader, config, device, accuracy, f1score):
    """Evaluate on `loader`.

    Returns:
        Tuple `(mean focal loss per sample, macro accuracy, macro F1)`.
    """
    # eval() so BatchNorm uses (and does not update) its running statistics
    # and dropout is disabled.
    model.eval()
    accuracy.reset()
    f1score.reset()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)["Classifier"]
            accuracy.update(logits, labels)
            f1score.update(logits, labels)
            total_loss += _focal_loss(logits, labels, config).item()
            total_samples += labels.size(0)
    return (
        total_loss / total_samples,
        accuracy.compute().item(),
        f1score.compute().item(),
    )


def _save_and_report(model, optimizer, epoch, metrics):
    """Pickle model/optimizer state into a checkpoint and report `metrics` to Ray."""
    ckpt_data = {
        "epoch": epoch,
        "net_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    with tempfile.TemporaryDirectory() as ckpt_dir:
        with open(Path(ckpt_dir) / "data.pkl", "wb") as fp:
            pickle.dump(ckpt_data, fp)
        Train.report(metrics, checkpoint=Checkpoint.from_directory(ckpt_dir))


def train_model(config):
    """Train one Ray Tune trial and report validation metrics after every epoch.

    Builds a fresh EfficientNetV2 with one output per entry of `label_names`
    and an Adam optimizer. It resumes from this trial's Ray checkpoint when
    one exists. Each epoch it trains on `train_dataset`, evaluates on
    `val_dataset`, and reports `loss` (mean summed focal loss per sample),
    `accuracy` and `f1score` (macro over all labels) together with a
    checkpoint.

    Args:
        config: Trial hyperparameters. Keys read: `"lr"`, `"decay"` (Adam
            learning rate / weight decay), `"alpha"`, `"gamma"` (focal loss).
    """
    model = Models.create_efficientNetV2(len(label_names))

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    model.to(device)

    optimizer = optim.Adam(
        model.parameters(), lr=config["lr"], weight_decay=config["decay"]
    )
    start_epoch = _restore_checkpoint(model, optimizer)

    trainloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=njobs,
        drop_last=True,
    )
    # Validation must score every sample exactly once: no shuffling and
    # no dropped tail batch.
    valloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=njobs,
        drop_last=False,
    )

    # Macro metrics over all labels. torchmetrics' `ignore_index` names a
    # *target value* to skip. The former `ignore_index=0` therefore discarded
    # every negative target, leaving F1 a function of recall alone. To exclude
    # the "No Finding" column instead, slice `[:, 1:]` from logits and labels
    # and use `num_labels=len(label_names) - 1`.
    accuracy = MultilabelAccuracy(num_labels=len(label_names), average="macro")
    accuracy.to(device)
    f1score = MultilabelF1Score(num_labels=len(label_names), average="macro")
    f1score.to(device)

    for epoch in range(start_epoch, epochs):
        _train_one_epoch(model, trainloader, optimizer, config, device, epoch)
        val_loss, val_acc, val_f1 = _validate(
            model, valloader, config, device, accuracy, f1score
        )
        print(
            f"Val Epoch {epoch + 1:d} loss: {val_loss:.4f} accuracy: {val_acc:.4f} f1score: {val_f1:.4f}"
        )
        _save_and_report(
            model,
            optimizer,
            epoch,
            {"loss": val_loss, "accuracy": val_acc, "f1score": val_f1},
        )
    print("Finished Training")

In [None]:
def test_accuracy(model, device="cuda"):
    testloader = DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=njobs
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs["Classifier"].data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [None]:
# config space
config = {
    "lr": tune.loguniform(1e-5, 1e-2),
    "decay": tune.loguniform(1e-6, 1e-3),
    #'batch_size': tune.choice([8, 16, 32, 48]),
    "alpha": tune.uniform(0, 1),
    "gamma": tune.quniform(1.0, 5.0, 0.5),
}

In [None]:
# Launch the search: sample 50 configurations from `config` and let ASHA
# stop trials with weak validation F1 early.
result = tune.run(
    train_model,
    resources_per_trial={"cpu": 24, "gpu": 1},
    config=config,
    num_samples=50,
    scheduler=ASHAScheduler(
        metric="f1score",
        mode="max",
        # ASHA counts time in reported iterations. `train_model` reports once
        # per epoch, so the horizon must follow `epochs`, not a separate literal.
        max_t=epochs,
        grace_period=2,
        reduction_factor=2,
    ),
)

In [None]:
# Rank trials by the validation F1 of their final report and show the winner.
best = result.get_best_trial(metric="f1score", mode="max", scope="last")
print(f"Best trial config: {best.config}")
for metric_name in ("loss", "accuracy", "f1score"):
    print(f"Best trial final validation {metric_name}: {best.last_result[metric_name]}")

In [None]:
# This cell used to be an exact duplicate of the previous one. The previous
# cell ranks trials by their *last* report. ASHA stops many trials after a few
# epochs, so also check which trial reached the highest validation F1 at *any*
# epoch (scope="all"). If it differs, the last-report ranking hides a
# configuration that peaked early.
best_any_epoch = result.get_best_trial("f1score", "max", "all")
print(f"Best trial (any epoch) config: {best_any_epoch.config}")
print(
    f"Best trial (any epoch) peak validation f1score: {best_any_epoch.metric_analysis['f1score']['max']}"
)
print(
    f"Best trial (any epoch) final validation f1score: {best_any_epoch.last_result['f1score']}"
)

In [None]:
# Tabulate the trials; sampled hyperparameters appear as `config/...` columns.
res_df = result.results_df
res_df.head(n=5)

In [None]:
# Persist the tuning results so the plots below can be redone without re-tuning.
res_df.to_csv(path_or_buf="vindr_tune_results.csv")

In [None]:
from scipy.interpolate import griddata
import matplotlib.pyplot as plt

In [None]:
# Validation F1 interpolated over the focal-loss (alpha, gamma) plane, over the
# same ranges as the search space. Trials without a numeric result are dropped
# so NaNs do not spread through the triangulation. Cells outside the convex
# hull of the sampled points are NaN and render blank. Black dots mark trials.
focal_trials = res_df.dropna(subset=["config/alpha", "config/gamma", "f1score"])
alpha_grid, gamma_grid = np.meshgrid(np.linspace(0, 1, 100), np.linspace(1, 5, 100))
f1_grid = griddata(
    focal_trials[["config/alpha", "config/gamma"]].to_numpy(),
    focal_trials["f1score"].to_numpy(),
    (alpha_grid, gamma_grid),
    method="linear",
)

fig, ax = plt.subplots()
im = ax.imshow(f1_grid, extent=(0, 1, 1, 5), aspect="auto", origin="lower")
ax.scatter(focal_trials["config/alpha"], focal_trials["config/gamma"], c="k", s=8)
ax.set(xlabel="alpha", ylabel="gamma", title="Validation F1 vs. focal-loss parameters")
fig.colorbar(im, ax=ax, label="f1score")
plt.show()

In [None]:
# grid for learning rate and weight decay
ax, gy = np.meshgrid(np.logspace(-5, -2, 100), np.logspace(-6, -3, 100))
grid = griddata(
    res_df[["config/lr", "config/decay"]], res_df["f1score"], (ax, gy), method="linear"
)
plt.imshow(grid, extent=(-5, -2, -6, -3), aspect="auto", origin="lower")
plt.xlabel("Learning Rate")
plt.ylabel("Weight Decay")

In [None]:
# Sampled focal-loss parameters next to the validation F1 each trial reached.
res_df.loc[:, ["config/alpha", "config/gamma", "f1score"]]