# Model Pipeline Creation

In [None]:
## Importing Libs

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImage,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    DivisiblePadd,
    RandAffined,
    RandRotated,
    RandGaussianNoised,
    Resized
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, SwinUNETR
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
from torch.utils.data import ConcatDataset
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from datetime import datetime
import shutil

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
## Getting the dataset

##IMPORTANT: CHANGE HERE TO THE DATA PATH##
# The path below follows the same directory layout that is used on the cluster.

def pair_images_with_labels(image_paths, label_paths):
    """Build the list of ``{"image": ..., "label": ...}`` sample dicts.

    Args:
        image_paths: volume file paths, in the order they should be paired.
        label_paths: segmentation file paths, in the same order.

    Returns:
        One dict per (image, label) pair, in input order.

    Raises:
        ValueError: if the two lists differ in length. ``zip`` would otherwise
            silently drop the unmatched files, hiding a missing volume or
            segmentation and possibly pairing volumes with the wrong labels.
    """
    if len(image_paths) != len(label_paths):
        raise ValueError(
            f"found {len(image_paths)} volumes but {len(label_paths)} segmentations; "
            "every volume needs exactly one segmentation"
        )
    return [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(image_paths, label_paths)
    ]


data_dir = "/tsi/data_education/data_challenge"
# Both listings are sorted so that volume i lines up with segmentation i
# (assumes the two folders use matching file names — TODO confirm).
train_images = sorted(glob.glob(os.path.join(data_dir, "train/volume", "*.nii*")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "train/seg", "*.nii*")))
data_dicts = pair_images_with_labels(train_images, train_labels)
# The last 90 samples are held out for validation (about 1/3 of the data
# according to the original note); everything before them is used for training.
train_files, val_files = data_dicts[:-90], data_dicts[-90:]
print(train_images)  # check that the data was actually found

In [None]:
## TRANSFORMS PreProcessing

# Deterministic preprocessing applied to every training sample.
train_transforms = Compose(
    [
        # Read the volume and its segmentation from disk.
        LoadImaged(keys=["image", "label"]),
        # Make sure both arrays carry a leading channel dimension.
        EnsureChannelFirstd(keys=["image", "label"]),
        # Crop image and label to the foreground bounding box found in the image.
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # Bring every case to the same RAS axis orientation.
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample to a common voxel spacing; labels use nearest-neighbour so
        # class indices are never blended.
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        # Pad each spatial dimension up to a multiple of 16.
        DivisiblePadd(["image", "label"], 16),
        # Downsample to 192^3 because of memory limits. The per-key mode keeps
        # the default "area" interpolation for the image but uses "nearest"
        # for the label: interpolating a label map would produce fractional
        # values that no longer correspond to a class index.
        Resized(keys=["image", "label"], spatial_size=(192, 192, 192), mode=("area", "nearest")),
    ]
)

# Validation uses exactly the same preprocessing as training.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        DivisiblePadd(["image", "label"], 16),
        # Same per-key interpolation as in training: "area" for the image,
        # "nearest" for the label so it stays a valid class-index map.
        Resized(keys=["image", "label"], spatial_size=(192, 192, 192), mode=("area", "nearest")),
    ]
)

In [None]:
## Creating the datasets

# batch size = 1 due to memory limitation
train_ds = Dataset(data=train_files, transform=train_transforms)
# Reshuffle the training samples every epoch; without it the optimizer sees
# the volumes in the same sorted-by-filename order on every pass.
train_loader = DataLoader(train_ds, num_workers=4, batch_size=1, shuffle=True)

val_ds = Dataset(data=val_files, transform=val_transforms)
# The validation loss is averaged over the whole set, so order is irrelevant
# and the loader stays unshuffled.
val_loader = DataLoader(val_ds, num_workers=4, batch_size=1)

In [None]:
# Early-stopping helper: signals when the validation loss has stopped improving

class EarlyStopper:
    """Track the best validation loss and tell the caller when to stop.

    A check counts as "bad" when the loss is more than ``min_delta`` above the
    best loss seen so far. A new best loss resets the bad-check counter; a
    loss that is not a new best but still within ``min_delta`` of it leaves
    the counter untouched.
    """

    def __init__(self, patience: int = 1, min_delta: float = 0):
        # Number of bad checks (since the last new best) required to stop.
        self.patience = patience
        # Tolerance above the best loss that is not counted as a bad check.
        self.min_delta = min_delta
        # Bad checks accumulated since the last new best loss.
        self.counter = 0
        # Lowest validation loss observed so far.
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss: float) -> bool:
        """Record ``validation_loss`` and return True when training should stop."""
        # New best: remember it and start counting from zero again.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Not a new best, but not beyond the tolerance either: nothing changes.
        tolerated = self.min_validation_loss + self.min_delta
        if not validation_loss > tolerated:
            return False

        # Clearly worse than the best loss: one more bad check.
        self.counter += 1
        return self.counter >= self.patience

In [None]:
## Creating the model instance - SwinUNETR

# SwinUNETR segmentation network: one input channel, two output channels,
# sized for the 192^3 volumes produced by the preprocessing transforms.
# use_checkpoint enables gradient checkpointing; the three *_rate arguments
# set the dropout rates used inside the network.
model = SwinUNETR(
    img_size=(192, 192, 192),
    in_channels=1,
    out_channels=2,
    use_checkpoint=True,
    attn_drop_rate=0.2,
    dropout_path_rate=0.2,
    drop_rate=0.2,
)
model = model.to(device)

# AdamW with the AMSGrad variant and a learning rate of 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, amsgrad=True)

# Dice + cross-entropy loss: labels are one-hot encoded inside the loss,
# softmax is applied to the logits, and the background channel is excluded.
loss_function = DiceCELoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True,
)

# Stop once 3 validation checks were more than 0.3 above the best loss.
early_stopper = EarlyStopper(patience=3, min_delta=0.3)

In [None]:
## Pytorch usual pipeline

# Training loop: one pass over train_loader per epoch, with a validation pass,
# a validation-loss log entry, an early-stopping check and a model checkpoint
# every `val_interval` epochs.

total_time = 0  # accumulated training time in seconds (validation excluded)
max_epochs = 10  # 10 epochs gave the best results so far; change if needed
val_interval = 2  # run validation every `val_interval` epochs
print_interval = 1  # print per-batch losses every `print_interval` epochs
epoch_loss_values = []  # average training loss of each epoch
losses_validation = []  # average validation loss of each validation pass

for epoch in range(max_epochs):
    start_time = datetime.now()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()  # enable training-mode behaviour (e.g. dropout)
    epoch_loss = 0
    step = 0

    # One optimisation step per batch.
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )

        # Standard forward / backward / update sequence.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is the loss of this single batch.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        if epoch % print_interval == 0:
            # len(train_loader) is the real number of steps per epoch, also
            # when the dataset size is not a multiple of the batch size.
            print(f"{step}/{len(train_loader)}, " f"train_loss: {batch_loss:.4f}")

    # Time spent on the training pass of this epoch.
    actual_time = datetime.now() - start_time
    print(f"time to train this epoch: {actual_time}")
    total_time += actual_time.total_seconds()

    # Average training loss over the batches of this epoch.
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()  # switch off training-only behaviour (e.g. dropout)
        with torch.no_grad():
            loss_val = 0
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                # Loss on the validation set.
                outputs = model(val_inputs)
                loss_val += loss_function(outputs, val_labels).item()
                # NOTE: this prints the running sum, not the per-batch loss.
                print(f"val_loss: {loss_val}")

            loss_val_avg = loss_val / len(val_loader)
            losses_validation.append(loss_val_avg)

            print(f"validation average loss: {loss_val_avg:.4f}")
            # Append to the log; `with` closes the file even if the write fails.
            with open("valLossesSWIN.txt", "a") as log_file:
                log_file.write(f"validation average loss: {loss_val_avg:.4f}\n")

            if early_stopper.early_stop(loss_val_avg):
                print("early stopped!")
                break

        # Checkpoint after every validation pass (not every epoch). It is
        # skipped when early stopping breaks out above, so the file on disk
        # then holds the model from the previous validation pass. The model is
        # still in eval mode here.
        # NOTE(review): this pickles the whole module, so loading it requires
        # the same model class to be importable; ".h5" is only a file name.
        torch.save(model, "modelSWIN.h5")
