# Task 2: midRT Segmentation
This notebook reuses the code written for Task 1 and adapts it to Task 2 (midRT segmentation).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import LoadImage
from monai.data import Dataset, DataLoader, ThreadDataLoader, CacheDataset, decollate_batch
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.utils import set_determinism
import os.path
import random
import torch
import gc
from monai.inferers import sliding_window_inference
from tqdm import tqdm
from collections import defaultdict
import time

from monai.transforms import (
    Compose,
    LoadImaged,
    Compose,
    LoadImaged,
    RandSpatialCropd,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    AsDiscreted,
    EnsureTyped,
    RandFlipd,
    NormalizeIntensityd,
    AsDiscrete
)

In [None]:
# Dataset roots on the cluster; each patient lives in its own sub-folder.
# Both paths end with "/" because patient folder names are appended directly.
data_path, test_data_path = (
    "/cluster/projects/vc/data/mic/open/HNTS-MRG/train/",
    "/cluster/projects/vc/data/mic/open/HNTS-MRG/test/",
)

In [None]:
# Build one {"image": T2 path, "label": mask path} record per patient folder.
# The listing is sorted because os.listdir returns entries in arbitrary order,
# and the train/validation split slices this list by position; without sorting
# the split would differ between machines/filesystems despite set_determinism.
data_midRT = []
for patient_num in sorted(os.listdir(data_path)):
    patient = os.path.join(data_path, patient_num)
    if not os.path.isdir(patient):
        continue  # skip stray files in the dataset root; they are not patients
    image = os.path.join(patient, "midRT", f"{patient_num}_midRT_T2.nii.gz")
    mask = os.path.join(patient, "midRT", f"{patient_num}_midRT_mask.nii.gz")

    data_midRT.append({"image": image, "label": mask})

print(len(data_midRT))

In [None]:
# Load test data: one {"image", "label"} record per patient folder, in sorted
# order so the evaluation order is reproducible (os.listdir order is arbitrary).
test_midRT = []
for patient_num in sorted(os.listdir(test_data_path)):
    patient = os.path.join(test_data_path, patient_num)
    if not os.path.isdir(patient):
        continue  # skip stray files in the dataset root; they are not patients
    image = os.path.join(patient, "midRT", f"{patient_num}_midRT_T2.nii.gz")
    mask = os.path.join(patient, "midRT", f"{patient_num}_midRT_mask.nii.gz")

    test_midRT.append({"image": image, "label": mask})

print(len(test_midRT))

In [None]:
set_determinism(seed=1)

# Positional split: the first 105 patients train, the remaining ones validate.
training_data, validation_data = data_midRT[:105], data_midRT[105:]
# Spatial patch size shared by the random training crop and sliding-window inference.
roi = (256, 256, 32)

# Training pipeline: load -> typed tensors -> channel-first -> fixed-size random
# crop -> random flip along each of the three spatial axes -> intensity
# normalization of the image (nonzero voxels, per channel).
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        RandSpatialCropd(
            keys=["image", "label"],
            roi_size=list(roi),
            random_center=True,
            random_size=False,
        ),
        *(
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis)
            for axis in (0, 1, 2)
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation/inference pipeline: the same deterministic steps, without the
# random crop and flips, so whole volumes are returned.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ]
)

In [None]:
# Cache the transformed samples in memory; cache_rate=1.0 caches all of them.
train_ds, val_ds = (
    CacheDataset(data=training_data, transform=train_transforms, cache_rate=1.0),
    CacheDataset(data=validation_data, transform=val_transforms, cache_rate=1.0),
)

In [None]:
# Training: shuffled batches of 4 random crops. Validation: one whole volume per batch, fixed order.
train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=4, num_workers=0)
val_dataloader = DataLoader(val_ds, shuffle=False, batch_size=1, num_workers=0)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Baseline 3D U-Net: single-channel input, 3 output channels (matching the
# one-hot size of 3 used by the metric post-processing), five feature levels
# with stride-2 downsampling between them.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2,) * 4,
).to(device)
model_name = "U-Net"

In [None]:
# Alternative architecture: Swin-UNETR built for the training patch size `roi`.
# Running this cell REPLACES the U-Net defined above; run only one of the two.
# NOTE(review): `img_size` is deprecated in newer MONAI releases — confirm
# against the installed version.
model = SwinUNETR(in_channels=1, out_channels=3, img_size=roi).to(device)
model_name = "Swin-UNet"

### Train the model

In [None]:
# Dice + cross-entropy loss. Labels are single-channel class indices, so they are
# one-hot encoded inside the loss, and the network logits go through softmax.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): include_background=True averages the background class into the
# reported Dice; confirm this is the intended metric.
dice_metric = DiceMetric(include_background=True, reduction="mean")
max_epochs = 200
# Metric post-processing: labels -> one-hot over 3 classes; logits -> argmax -> one-hot.
post_label = AsDiscrete(to_onehot=3)
post_pred = AsDiscrete(argmax=True, to_onehot=3)

print(f"Training: {model_name}")

training_loss_pr_epoch = []
validation_loss_pr_epoch = []
dice_metric_pr_epoch = []
training_dice_pr_epoch = []

start = time.perf_counter()

for epoch in range(max_epochs):
    # Drop unreachable Python objects first so their GPU memory can then be
    # returned by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()

    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- Training on random crops ----
    model.train()
    training_losses = []
    for batch_data in tqdm(train_dataloader):
        images, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(images)

        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        training_losses.append(loss.item())

        # Training Dice on this batch's predictions. `outputs` is detached so the
        # metric bookkeeping does not keep references into the autograd graph.
        train_labels_convert = [post_label(t) for t in decollate_batch(labels)]
        train_output_convert = [post_pred(t) for t in decollate_batch(outputs.detach())]
        dice_metric(y_pred=train_output_convert, y=train_labels_convert)

    training_dice = dice_metric.aggregate().item()
    training_dice_pr_epoch.append(training_dice)
    dice_metric.reset()

    # ---- Validation on whole volumes via sliding-window inference ----
    model.eval()
    validation_losses = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            images, labels = batch["image"].to(device), batch["label"].to(device)

            outputs = sliding_window_inference(
                images,
                roi_size=roi,
                sw_batch_size=4,
                predictor=model,
                overlap=0.5,
            )

            loss = loss_function(outputs, labels)
            validation_losses.append(loss.item())

            val_labels_convert = [post_label(t) for t in decollate_batch(labels)]
            val_output_convert = [post_pred(t) for t in decollate_batch(outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)

        validation_dice = dice_metric.aggregate().item()
        dice_metric.reset()
        dice_metric_pr_epoch.append(validation_dice)

    mean_training_loss = np.mean(training_losses)
    mean_validation_loss = np.mean(validation_losses)
    training_loss_pr_epoch.append(mean_training_loss)
    validation_loss_pr_epoch.append(mean_validation_loss)
    print(f"Training mean loss: {mean_training_loss}")
    print(f"Validation mean loss: {mean_validation_loss}")
    print(f"Training dice {training_dice}")
    print(f"Validation Mean Dice: {validation_dice}")

end = time.perf_counter()
# Elapsed wall-clock time is end - start (the original printed a negative value).
print(f"{max_epochs} epochs took {end - start:.1f} seconds.")

In [None]:
# Checkpoint name such as "UNet_200.pth": the model name without its hyphen plus
# the number of completed epochs. This follows the "UNet_400.pth" naming used
# when loading weights, and avoids spaces and a missing extension in the file name.
save_name = f"{model_name.replace('-', '')}_{len(training_dice_pr_epoch)}.pth"
torch.save(model.state_dict(), save_name)

In [None]:
# Loss curves: per-epoch mean training and validation loss.
for series, series_label in (
    (training_loss_pr_epoch, "Training loss"),
    (validation_loss_pr_epoch, "Validation loss "),
):
    plt.plot(series, label=series_label)
plt.title(f"Training and validation loss - {model_name} - DiceCELoss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

In [None]:
# Validation Dice per epoch.
plt.plot(dice_metric_pr_epoch, label="Dice Metric")
plt.title(f"Validation Dice Metric {model_name}-DiceCELoss")
plt.xlabel("Epoch")
plt.ylabel("Dice Metric")
plt.legend()

## Running the model on the test set

In [None]:
# Test set: reuse the deterministic validation preprocessing (no crop, no flips),
# served one volume at a time in a fixed order.
test_ds = Dataset(transform=val_transforms, data=test_midRT)
test_dataloader = DataLoader(test_ds, shuffle=False, batch_size=1, num_workers=0)

### Load the weights if needed

In [None]:
# Load a saved state dict. map_location remaps tensors that were saved on another
# device (e.g. a GPU checkpoint) onto the current `device`, so this also works on
# a CPU-only host.
# NOTE(review): this checkpoint name refers to a U-Net; make sure `model`
# currently holds the matching architecture (the SwinUNETR cell replaces it).
state_dict = torch.load("UNet_400.pth", map_location=device)
# Load the weights into the model
model.load_state_dict(state_dict)

In [None]:
# Evaluate on the test set with the same sliding-window inference and Dice
# post-processing used for validation.
model.eval()
# Start from a clean metric state, in case an interrupted training/validation run
# left accumulated batches behind.
dice_metric.reset()
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        images, labels = batch["image"].to(device), batch["label"].to(device)

        outputs = sliding_window_inference(
            images,
            roi_size=roi,
            sw_batch_size=4,
            predictor=model,
            overlap=0.5,
        )

        test_labels_convert = [post_label(t) for t in decollate_batch(labels)]
        test_output_convert = [post_pred(t) for t in decollate_batch(outputs)]
        dice_metric(y_pred=test_output_convert, y=test_labels_convert)

    test_dice = dice_metric.aggregate().item()
    dice_metric.reset()

    print(f"The mean test dice is {test_dice}")
