# Task 2: midRT Segmentation
This notebook reuses the code written for Task 1 and adapts it to Task 2: segmentation of the mid-radiotherapy (midRT) T2-weighted MRI scans.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import LoadImage
from monai.data import Dataset, DataLoader, ThreadDataLoader
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.utils import set_determinism
import os.path
import random
import torch

from monai.transforms import (
    Compose,
    LoadImaged,
    Compose,
    LoadImaged,
    RandSpatialCropd,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    AsDiscreted
)

In [None]:
# Root folder of the HNTS-MRG training set; it holds one sub-folder per patient.
data_path: str = "/cluster/projects/vc/data/mic/open/HNTS-MRG/train/"

In [None]:
# Build one {"image": ..., "label": ...} dict of file paths per patient folder,
# pointing at that patient's midRT T2 scan and its segmentation mask.
# os.listdir returns entries in arbitrary order. The train/validation split is
# taken by slicing this list, so sort it to make the split reproducible.
data_midRT = []
for patient_num in sorted(os.listdir(data_path)):
    patient = os.path.join(data_path, patient_num)
    if not os.path.isdir(patient):
        # Skip stray files sitting next to the patient folders.
        continue
    midrt_dir = os.path.join(patient, "midRT")
    image = os.path.join(midrt_dir, f"{patient_num}_midRT_T2.nii.gz")
    mask = os.path.join(midrt_dir, f"{patient_num}_midRT_mask.nii.gz")

    data_midRT.append({"image": image, "label": mask})

print(len(data_midRT))

In [None]:
from monai.transforms import CenterSpatialCropd

set_determinism(seed=0)

# The first N_TRAIN entries of data_midRT are used for training and the rest
# are held out for validation.
N_TRAIN = 105
training_data = data_midRT[:N_TRAIN]
validation_data = data_midRT[N_TRAIN:]

# Patch size (per spatial axis) cropped from every volume.
ROI_SIZE = [256, 256, 32]

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # A random patch location acts as augmentation during training.
        RandSpatialCropd(
            keys=["image", "label"],
            roi_size=ROI_SIZE,
            random_center=True,
            random_size=False,
        ),
        # Labels become a 3-channel one-hot tensor to match the model output.
        AsDiscreted(keys=["label"], to_onehot=3),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation uses a deterministic centre crop instead of a random one, so that
# validation loss/Dice are computed on the same patches every epoch and can be
# compared across epochs.
# NOTE(review): a centre crop may not contain the whole target structure;
# sliding-window inference over full volumes would give a more faithful score.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CenterSpatialCropd(keys=["image", "label"], roi_size=ROI_SIZE),
        AsDiscreted(keys=["label"], to_onehot=3),
        ToTensord(keys=["image", "label"]),
    ]
)

train_ds = Dataset(data=training_data, transform=train_transforms)
val_ds = Dataset(data=validation_data, transform=val_transforms)


train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=0)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Human-readable name printed at the start of training.
model_name = "U-Net"

# 3-D U-Net. It takes a single-channel input and emits 3 output channels, one
# per class of the one-hot labels. The encoder doubles its width at each of the
# five levels (16 -> 256) and halves the resolution between levels.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=tuple(16 * 2**level for level in range(5)),
    strides=(2,) * 4,
)
model = model.to(device)

In [None]:
import gc
from tqdm import tqdm

# DiceLoss(softmax=True) applies a softmax to the raw network outputs before
# comparing them with the one-hot labels.
loss_function = DiceLoss(softmax=True)
optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): include_background=True averages the background class into the
# reported score, which typically inflates the mean Dice. Consider
# include_background=False if only foreground structures should be scored.
dice_metric = DiceMetric(include_background=True, reduction="mean")
max_epochs = 10


def _logits_to_onehot(logits):
    """Turn (batch, classes, *spatial) logits into a one-hot prediction tensor.

    DiceMetric expects binarised predictions. Scoring raw logits produces a
    meaningless Dice value, so take the argmax class per voxel and re-expand
    it to one channel per class.
    """
    num_classes = logits.shape[1]
    hard_labels = logits.argmax(dim=1)  # (batch, *spatial), int64
    onehot = torch.nn.functional.one_hot(hard_labels, num_classes=num_classes)
    # one_hot appends the class axis last; move it back to the channel axis.
    return torch.movedim(onehot, -1, 1).to(logits.dtype)


print(f"Training: {model_name}")

for epoch in range(max_epochs):
    torch.cuda.empty_cache()
    gc.collect()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training pass ----
    model.train()
    training_losses = []
    for batch_data in tqdm(train_dataloader):
        images = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        training_losses.append(loss.item())

    # ---- validation pass ----
    model.eval()
    validation_losses = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            outputs = model(images)
            loss = loss_function(outputs, labels)
            validation_losses.append(loss.item())
            dice_metric(y_pred=_logits_to_onehot(outputs), y=labels)
        mean_dice = dice_metric.aggregate().item()
        # Clear accumulated results so the next epoch's score only reflects
        # that epoch's predictions.
        dice_metric.reset()

    print(f"Training mean loss: {np.mean(training_losses)}")
    print(f"Validation mean loss: {np.mean(validation_losses)}")
    print(f"Validation Mean Dice: {mean_dice:.6f}")