# Task 1: preRT segmentation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import LoadImage
from monai.data import Dataset, DataLoader
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.utils import set_determinism
import os.path
import random
import torch

from monai.transforms import (
    Compose,
    LoadImaged,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    AsDiscreted
)

In [None]:
# Root directory of the HNTS-MRG training data. It ends with "/" so patient
# folder names can be appended to it directly.
data_path = (
    "/cluster/projects/vc/data/mic/open/"
    "HNTS-MRG/train/"
)

In [None]:
# Build the MONAI-style list of {"image": ..., "label": ...} dicts for the
# preRT T2 scans, one entry per patient folder under data_path.
#
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the ordering, and therefore the index-based train/validation split
# done later, reproducible across machines and runs.
data_preRT = []
for patient_num in sorted(os.listdir(data_path)):
    patient_dir = os.path.join(data_path, patient_num)
    if not os.path.isdir(patient_dir):
        # Skip stray files that are not patient folders; they would otherwise
        # produce image/label paths that do not exist.
        continue
    preRT_dir = os.path.join(patient_dir, "preRT")
    data_preRT.append(
        {
            "image": os.path.join(preRT_dir, f"{patient_num}_preRT_T2.nii.gz"),
            "label": os.path.join(preRT_dir, f"{patient_num}_preRT_mask.nii.gz"),
        }
    )

print(len(data_preRT))

In [None]:
set_determinism(seed=0)

# Index-based split: the first 105 cases train, the remainder validate.
training_data = data_preRT[:105]
validation_data = data_preRT[105:]

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample every volume to a fixed grid so cases can be batched.
        # The image keeps Resized's default area interpolation. The label must
        # use nearest neighbour: any smoothing interpolation blends class
        # boundaries into fractional values that the one-hot step below would
        # misassign.
        Resized(
            keys=["image", "label"],
            spatial_size=(128, 128, 16),
            mode=("area", "nearest"),
        ),
        # One-hot encode the label into 3 channels (classes 0, 1, 2).
        AsDiscreted(keys=["label"], to_onehot=3),
        ToTensord(keys=["image", "label"]),
    ]
)

# Train only on training_data so the validation cases stay unseen.
train_ds = Dataset(data=training_data, transform=train_transforms)
val_ds = Dataset(data=validation_data, transform=train_transforms)

train_dataloader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import gc
import time
from tqdm import tqdm

# 3D U-Net with 1 input channel (the T2 image) and 3 output channels, one per
# class of the one-hot label.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)

# softmax=True: DiceLoss applies the softmax to the model outputs itself, so
# the raw network outputs are passed in directly.
loss_function = DiceLoss(softmax=True)
optimizer = torch.optim.Adam(model.parameters())

maxepochs = 2

for epoch in range(maxepochs):
    # Free cached GPU memory and Python garbage before each epoch.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    epoch_start = time.time()
    epoch_loss = []
    print("-" * 10)
    print(f"epoch {epoch + 1}/{maxepochs}")
    model.train()

    progress = tqdm(train_dataloader)
    for batch_data in progress:
        images = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
        # Show the running batch loss on the progress bar. Printing a line per
        # batch breaks tqdm's single-line display.
        progress.set_postfix(loss=f"{epoch_loss[-1]:.4f}")

    # Guard against an empty loader (e.g. no training cases were found).
    mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float("nan")
    print(
        f"epoch {epoch + 1} average loss: {mean_loss:.4f} "
        f"({time.time() - epoch_start:.1f}s)"
    )

