# Introduction
Competition home page: https://www.kaggle.com/competitions/mayo-clinic-strip-ai

References:
* https://github.com/Project-MONAI/monai-bootcamp/blob/main/MONAICore
* https://github.com/Project-MONAI/tutorials
* https://docs.monai.io

# Import Libraries

In [None]:
# Check which GPU (if any) is visible to this notebook, using the NVIDIA driver CLI.
!nvidia-smi

In [None]:
%%capture
# Install MONAI pinned to 0.9.0, with the optional extras used by this notebook
# (e.g. ignite for the training/evaluation engines).
# `%%capture` must stay the first line of the cell; it suppresses the cell's output.
!pip install -qU "monai[ignite, nibabel, torchvision, tqdm]==0.9.0"

In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL

import torch
import monai

from monai.apps import download_and_extract
from monai.config import print_config
from monai.metrics import ROCAUCMetric
from monai.data import decollate_batch, partition_dataset_classes
from monai.data import PILReader
from monai.networks.nets import DenseNet121
from monai.transforms import (
    AddChannel,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    ToTensor,
    Activations,
    AsDiscrete,
    EnsureType,
    Resize,
    ResizeWithPadOrCrop,
    EnsureChannelFirst,
    CenterSpatialCrop
)
from monai.utils import set_determinism

from ignite.engine import Events
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy
from monai.handlers import ROCAUC, ValidationHandler, CheckpointSaver
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from ignite.utils import convert_tensor

# Config

In [None]:
# Define the config for the pipeline
class cfg:
    """Pipeline configuration, read throughout the notebook as class attributes (``cfg.<name>``)."""

    # --- Input data -------------------------------------------------------
    imgTrainMetaFile = "../input/mayo-clinic-strip-ai-jpg-dataset/data/train.csv"
    imgTrainDir = "../input/mayo-clinic-strip-ai-jpg-dataset/data/train"
    imgTestDir = "../input/mayo-clinic-strip-ai-jpg-dataset/data/test"
    imgotherDir = "../input/mayo-clinic-strip-ai-jpg-dataset/data/other"

    # --- Outputs ----------------------------------------------------------
    outDir = "./"
    saveModelFilename = "best_metric_model"

    # --- Runtime / reproducibility -----------------------------------------
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    seed = 0

    # --- Debug mode: when on, only `sampleSize` metadata rows are used -------
    debug = True
    sampleSize = 50

    # --- Data layout --------------------------------------------------------
    split = (8, 1, 1)  # train / validation / test ratios
    imgSize = [128, 128]  # spatial size images are resized to
    classNum = 2

    # --- Loading ------------------------------------------------------------
    batchSize = 5
    numWorkers = 2

    # --- Optimisation schedule (all intervals in epochs) ---------------------
    learningRate = 1e-5
    trainEpochs = 4
    validEpochs = 1
    saveEpochs = 4

In [None]:
# Seed the random number generators (NumPy and PyTorch among them) from cfg.seed
# so that runs are reproducible.
set_determinism(cfg.seed)

# Prepare data
We will use this dataset: https://www.kaggle.com/datasets/dariussingh/mayo-clinic-strip-ai-jpg-dataset

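Note that this notebook works on a JPEG version of the slides; other approaches to processing the original whole-slide images (WSI) exist, such as patch-based tiling.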

In [None]:
# Read the image filenames and class labels.
# Only the images under the train folder are used here, as a proof of concept;
# consider all images in practice.
df = pd.read_csv(cfg.imgTrainMetaFile)
if cfg.debug:
    # Quick debug run on a subset. Seed the draw explicitly so that re-running this
    # cell selects the same rows, and never request more rows than exist
    # (df.sample without replacement raises in that case).
    df = df.sample(n=min(cfg.sampleSize, len(df)), random_state=cfg.seed)
imageFiles = [os.path.join(cfg.imgTrainDir, f"{imageId}.jpg") for imageId in df.image_id]
# Encode the categorical label as an int: "CE" -> 1, any other label -> 0.
imageClass = [int(label == "CE") for label in df.label]
# Kept as the cell's last expression so the notebook actually displays the preview.
df.head()

In [None]:
# Visualize one randomly chosen training image as stored on disk (no MONAI transforms).
# `import PIL` alone does not guarantee the PIL.Image submodule is loaded, so import it here.
from PIL import Image

# Open the file in a context manager so the handle is closed; np.array() forces the
# pixel data to be read before the file is closed.
with Image.open(np.random.choice(imageFiles)) as im:
    arr = np.array(im)
# cmap only has an effect if the image is single-channel.
plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()

In [None]:
# Split training, validation, and test data.
# partition_dataset_classes partitions the indices class by class using the cfg.split
# ratios. The shuffle is seeded from cfg.seed, so changing the seed also changes the split.
trainIdx, valIdx, testIdx = partition_dataset_classes(
    np.arange(len(imageFiles)),
    imageClass,
    ratios=cfg.split,
    shuffle=True,
    seed=cfg.seed,
)

trainX = [imageFiles[i] for i in trainIdx]
trainY = [imageClass[i] for i in trainIdx]
valX = [imageFiles[i] for i in valIdx]
valY = [imageClass[i] for i in valIdx]
testX = [imageFiles[i] for i in testIdx]
testY = [imageClass[i] for i in testIdx]

In [None]:
# MONAI transforms
def _to_grayscale(image):
    """Convert a PIL image to single-channel grayscale ("L" mode)."""
    return image.convert("L")


def _grayscale_loader():
    """Build a LoadImage transform that reads a file with PIL and converts it to grayscale."""
    # A named converter (instead of a lambda) keeps the reader picklable for DataLoader workers.
    return LoadImage(reader=PILReader(converter=_to_grayscale), image_only=True)


trainTransforms = Compose([
    _grayscale_loader(),
    AddChannel(),
    Resize(spatial_size=cfg.imgSize),
    ScaleIntensity(),
    # Training-only augmentation. RandRotate's range_x is in radians, so convert the
    # intended +/-15 degrees explicitly (a bare 15 would mean +/-15 rad, ~859 degrees).
    RandRotate(range_x=np.deg2rad(15), prob=0.5, keep_size=True),
    RandFlip(spatial_axis=0, prob=0.5),
    RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    ToTensor(),
])
valTransforms = Compose([
    _grayscale_loader(),
    AddChannel(),
    Resize(spatial_size=cfg.imgSize),
    ScaleIntensity(),
    ToTensor(),
])
# Post-processing used for the ROC-AUC metric: softmax over the network outputs and
# one-hot encoded labels.
act = Compose([EnsureType(), Activations(softmax=True)])
toOnehot = Compose([EnsureType(), AsDiscrete(to_onehot=cfg.classNum)])

In [None]:
# Plain PyTorch map-style dataset pairing image paths with labels
class MyDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(transforms(image_files[i]), labels[i])`` pairs.

    Args:
        image_files: sequence of image entries (e.g. file paths) passed to ``transforms``.
        labels: sequence of labels aligned index-by-index with ``image_files``.
        transforms: callable applied to each image entry when it is accessed.
    """

    def __init__(self, image_files, labels, transforms):
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        # The dataset length is the number of image entries.
        return len(self.image_files)

    def __getitem__(self, index):
        image = self.transforms(self.image_files[index])
        label = self.labels[index]
        return image, label

# One dataset per split; validation and test share the non-augmenting transforms.
trainDs = MyDataset(trainX, trainY, trainTransforms)
valDs, testDs = (
    MyDataset(files, labels, valTransforms)
    for files, labels in ((valX, valY), (testX, testY))
)

In [None]:
# PyTorch data loaders; only the training loader shuffles.
loaderKwargs = dict(batch_size=cfg.batchSize, num_workers=cfg.numWorkers)
trainLoader = torch.utils.data.DataLoader(trainDs, shuffle=True, **loaderKwargs)
valLoader = torch.utils.data.DataLoader(valDs, **loaderKwargs)
testLoader = torch.utils.data.DataLoader(testDs, **loaderKwargs)

# Prepare model

In [None]:
# 2D DenseNet-121 on single-channel inputs with one output per class,
# trained with cross-entropy loss and the Adam optimizer.
net = DenseNet121(
    spatial_dims=2,
    in_channels=1,
    out_channels=cfg.classNum,
)
net = net.to(cfg.device)
lossFunction = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=cfg.learningRate)

# Prepare training loop (Ignite) 

In [None]:
# Module-level buffers filled by the trainer's event handlers.
iterLosses = []  # per-iteration training losses of the current epoch
batchSizes = []  # weights matching iterLosses, for the epoch average
epochLossValues = []  # one averaged training loss per epoch
metricValues = []  # validation ROC-AUC values reported at epoch end

# Iterations per training epoch: ceiling division, because the loader keeps the final
# partial batch (no drop_last).
stepsPerEpoch = -(-len(trainDs) // trainLoader.batch_size)


def roc_auc_trans(x):
    """Prepare (prediction, label) pairs for the ROC-AUC metric.

    Used both as the evaluator's post-processing and as the metric's output transform:

    * a dict (single sample) -> ``(act(x["pred"]), toOnehot(x["label"]))``;
    * a list of ``(pred, label)`` pairs -> the preds and the labels, each stacked along a
      new leading batch dimension.
    """
    if not isinstance(x, list):
        return act(x["pred"]), toOnehot(x["label"])
    return tuple(
        torch.cat([sample[position][None, :] for sample in x])
        for position in (0, 1)
    )


def prepare_batch(batchdata, device, non_blocking):
    """Move an ``(images, labels)`` batch from the DataLoader onto ``device``.

    Returns the converted pair as a tuple ``(images, labels)``.
    """
    images, labels = batchdata
    return tuple(convert_tensor(item, device, non_blocking) for item in (images, labels))


# Evaluator over valLoader with ROC-AUC ("rocauc") as its key metric.
evaluator = SupervisedEvaluator(
    device=cfg.device,
    val_data_loader=valLoader,
    network=net,
    postprocessing=roc_auc_trans,
    key_val_metric={"rocauc": ROCAUC(output_transform=roc_auc_trans)},
    # Save the network weights whenever the key metric improves. With a fixed
    # `key_metric_filename`, only the best model is kept, at
    # cfg.outDir/<cfg.saveModelFilename>.pt.
    val_handlers=[
        CheckpointSaver(
            save_dir=cfg.outDir,
            save_dict={"network": net},
            save_key_metric=True,
            key_metric_filename=f"{cfg.saveModelFilename}.pt",
        )
    ],
    prepare_batch=prepare_batch,
)

# Trainer: optimizes `net` on trainLoader for cfg.trainEpochs epochs. Attached handlers:
#   * ValidationHandler runs `evaluator` every cfg.validEpochs epochs;
#   * CheckpointSaver writes the network weights every cfg.saveEpochs epochs
#     (files named "network_epoch=<epoch>.pt" in cfg.outDir).
# CheckpointSaver only honours `key_metric_filename` together with `save_key_metric=True`,
# and this trainer computes no key metric, so that argument is not passed here.
trainer = SupervisedTrainer(
    device=cfg.device,
    max_epochs=cfg.trainEpochs,
    train_data_loader=trainLoader,
    network=net,
    optimizer=optimizer,
    loss_function=lossFunction,
    train_handlers=[
        ValidationHandler(cfg.validEpochs, evaluator),
        CheckpointSaver(
            save_dir=cfg.outDir,
            save_dict={"network": net},
            save_interval=cfg.saveEpochs,
        ),
    ],
    prepare_batch=prepare_batch,
)


@trainer.on(Events.ITERATION_COMPLETED)
def _end_iter(engine):
    """Log the loss of the iteration that just finished and buffer it for the epoch average.

    ``engine.state.output`` is a list with one dict per sample (it is iterated that way
    below), so its length is the number of samples in the current batch.
    """
    loss = np.average([o["loss"] for o in engine.state.output])
    # Samples in this batch; the last batch of an epoch may be smaller than the others.
    batch_len = len(engine.state.output)
    epoch = engine.state.epoch
    epochLen = engine.state.max_epochs
    stepsInEpoch = engine.state.epoch_length
    # Ignite's `iteration` counter runs across epochs; derive the 1-based step within
    # the current epoch instead of resetting the engine's counter.
    step = (engine.state.iteration - 1) % stepsInEpoch + 1
    iterLosses.append(loss)
    batchSizes.append(batch_len)

    print(f"epoch {epoch}/{epochLen}, step {step}/{stepsInEpoch}, training_loss = {loss:.4f}")


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    """Close out an epoch: record the sample-weighted mean training loss and report the
    validation ROC-AUC.

    Validation itself is run by the ValidationHandler attached when `trainer` was built.
    Ignite calls handlers in registration order, so that handler has already run for
    this epoch when this one fires.
    """
    # the overall average loss must be weighted by batch size
    overallAverageLoss = np.average(iterLosses, weights=batchSizes)
    epochLossValues.append(overallAverageLoss)

    # clear the per-iteration buffers in place for the next epoch (they are module-level lists)
    iterLosses.clear()
    batchSizes.clear()

    # The evaluator only runs every cfg.validEpochs epochs; on other epochs its metrics
    # are stale (or not yet computed), so only report on validation epochs.
    if engine.state.epoch % cfg.validEpochs != 0:
        return

    # fetch and report the validation metrics
    roc = evaluator.state.metrics["rocauc"]
    metricValues.append(roc)
    print(f"evaluation for epoch {engine.state.epoch},  rocauc = {roc:.4f}")

# Train and save model

In [None]:
# Run the training loop; validation, checkpointing and logging happen in the handlers
# attached to `trainer` in the previous cell.
trainer.run()

# Plot loss and metric

In [None]:
# Plot the per-epoch average training loss next to the validation ROC-AUC curve.
plt.figure("train", (12, 6))

panels = (("Epoch Average Loss", epochLossValues), ("Val AUC", metricValues))
for panel, (title, values) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel)
    plt.title(title)
    epochs = list(range(1, len(values) + 1))
    plt.xlabel("epoch")
    plt.plot(epochs, values)

plt.show()

# Inference

In [None]:
# Reload the weights from the last interval checkpoint written by the trainer's
# CheckpointSaver ("network_epoch=<epoch>.pt" in cfg.outDir, saved every cfg.saveEpochs
# epochs), then predict on the held-out test split.
lastSavedEpoch = (cfg.trainEpochs // cfg.saveEpochs) * cfg.saveEpochs
checkpointPath = os.path.join(cfg.outDir, f"network_epoch={lastSavedEpoch}.pt")
# map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
net.load_state_dict(torch.load(checkpointPath, map_location=cfg.device))
net.eval()
y_true = list()
y_pred = list()

with torch.no_grad():
    for batchIdx, test_data in enumerate(testLoader):
        test_images = test_data[0].to(cfg.device)
        test_labels = test_data[1].to(cfg.device)
        pred = net(test_images).argmax(dim=1)

        # Collect labels/predictions for the whole test set, not only the first batch.
        y_true.extend(test_labels.tolist())
        y_pred.extend(pred.tolist())

        if batchIdx == 0:
            # Show one example from the first batch: the second sample, or the only one
            # if the batch has a single sample.
            idx = min(1, len(pred) - 1)
            arr = np.array(test_data[0][idx])
            # arr[0] is the single image channel; intensities were scaled by ScaleIntensity.
            plt.imshow(arr[0, :, :], cmap="gray", vmin=0, vmax=1)
            plt.title(f"Label: {test_labels[idx].item()}, Pred: {pred[idx].item()}")
            plt.tight_layout()
            plt.show()