# Task 1: preRT segmentation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import LoadImage
from monai.data import Dataset, DataLoader, ThreadDataLoader, CacheDataset, decollate_batch
from monai.losses import DiceLoss, DiceCELoss, FocalLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus, SegResNet, resnet152, UNETR
from monai.utils import set_determinism
import os.path
import random
import torch

from monai.transforms import (
    Compose,
    LoadImaged,
    Compose,
    LoadImaged,
    RandSpatialCropd,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    AsDiscreted,
    RandFlipd,
    CropForegroundd,
    NormalizeIntensityd,
    Spacingd,
    AsDiscrete,
    CenterSpatialCropd,
    EnsureTyped
)

In [None]:
# Root directory of the HNTS-MRG training set. The trailing slash matters:
# patient paths are built from this value by plain string concatenation.
data_path = "/cluster/projects/vc/data/mic/open/HNTS-MRG/train/"

In [None]:
# Build one {"image", "label"} path pair per patient folder.
# os.listdir() returns entries in arbitrary, filesystem-dependent order, so the
# names are sorted to keep the positional train/validation split reproducible.
data_preRT = []
for patient_num in sorted(os.listdir(data_path)):
    patient = f"{data_path}{patient_num}"
    # Skip stray files (e.g. README or hidden files) that are not patient folders.
    if not os.path.isdir(patient):
        continue
    image = f"{patient}/preRT/{patient_num}_preRT_T2.nii.gz"
    mask = f"{patient}/preRT/{patient_num}_preRT_mask.nii.gz"

    data_preRT.append({"image": image, "label": mask})

print(len(data_preRT))

In [None]:
set_determinism(seed=1)

# Positional split: the first 105 cases are used for training, the rest for
# validation.
training_data = data_preRT[:105]
validation_data = data_preRT[105:]
# Training patch size; the validation loop reuses it as the sliding-window size.
roi = (256, 256, 32)

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Normalise on the whole volume, *before* the random crop, so training
        # patches share the intensity statistics that val_transforms produces.
        # Normalising after the crop would use per-patch statistics instead.
        # Being deterministic and ahead of the first random transform, this
        # step is also computed once and cached by CacheDataset.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandSpatialCropd(
            keys=["image", "label"],
            roi_size=list(roi),
            random_center=True,
            random_size=False,
        ),
        # Independent random flips along each spatial axis.
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation uses the full, un-augmented volume.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ]
)

In [None]:
# Wrap both splits in CacheDataset with cache_rate=1.0 (cache every item).
train_ds = CacheDataset(
    data=training_data,
    transform=train_transforms,
    cache_rate=1.0,
)
val_ds = CacheDataset(
    data=validation_data,
    transform=val_transforms,
    cache_rate=1.0,
)

In [None]:
# Training batches are shuffled pairs of patches; validation runs one full
# volume at a time in a fixed order. Both load in the main process.
train_dataloader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=0,
)
val_dataloader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Placeholder; one of the model cells below assigns the actual network.
model = None

## Choose a model

In [None]:
# Run this cell for UNet
# 3D U-Net with one input channel and three output channels.
unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
)
model = UNet(**unet_kwargs)
model = model.to(device)

model_name = "U-Net"

In [None]:
# Run this cell for SwinUNETR
# The network is configured for the training patch size `roi`.
# NOTE(review): newer MONAI releases may deprecate `img_size` - confirm
# against the installed version.
swin_kwargs = dict(
    img_size=roi,
    in_channels=1,
    out_channels=3,
)
model = SwinUNETR(**swin_kwargs)
model = model.to(device)

model_name = "Swin-UNet"

In [None]:
# Run this cell for Basic U-Net++
# Move the network to `device` like every other model cell does; without this
# the model stays on the CPU while the training loop sends batches to `device`.
# NOTE(review): BasicUNetPlusPlus.forward appears to return a list of tensors
# rather than a single tensor - confirm the training loop handles that.
model = BasicUNetPlusPlus(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
).to(device)

model_name = "Basic U-Net++"

In [None]:
# Run this cell for SegResNet
# One input channel, three output channels, dropout probability 0.2.
segresnet_kwargs = dict(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=1,
    out_channels=3,
    dropout_prob=0.2,
)
model = SegResNet(**segresnet_kwargs)
model = model.to(device)

model_name = "SegResNet"

## Run weighted loss
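The class weights follow the effective-number-of-samples scheme from "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019): https://arxiv.org/pdf/1901.05555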

In [None]:
from collections import Counter

# Tally how many voxels of each class the training samples contain.
class_counts = Counter()

# Items are fetched by explicit index, one sample at a time.
for idx in range(len(train_ds)):
    label = train_ds[idx]["label"]

    # A label with more than one channel is treated as one-hot and collapsed
    # back to class indices.
    if label.ndim > 1 and label.shape[0] > 1:
        label = np.argmax(label, axis=0)
    unique_values, counts = np.unique(label, return_counts=True)
    # Key by plain int so float-typed label values (0.0, 1.0, 2.0) land on the
    # integer class indices looked up below.
    class_counts.update(
        {int(value): int(count) for value, count in zip(unique_values, counts)}
    )

class_counts = np.array([class_counts.get(cls, 0) for cls in range(3)])

# A class with no voxels would make the normalisation divide by zero and
# silently turn the class weights into NaN/inf; fail loudly instead.
if class_counts.min() == 0:
    missing = [cls for cls, count in enumerate(class_counts) if count == 0]
    raise ValueError(
        f"No voxels found for class(es) {missing}; cannot normalise class counts."
    )

# Express every count relative to the rarest class.
class_counts = class_counts / class_counts.min()
print(f"Class counts: {class_counts}")


In [None]:
beta = 0.99

# Effective number of samples per class: (1 - beta^n) / (1 - beta).
effective_num = (1 - beta ** class_counts) / (1 - beta)

# Weight each class by the inverse of its effective number, then rescale the
# weights so they sum to one.
inverse_effective_num = 1.0 / effective_num
class_weights = inverse_effective_num / inverse_effective_num.sum()

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)
print(class_weights_tensor)

## Load model weights

In [None]:
# Load the saved state dictionary. map_location remaps the stored tensors onto
# the active device, so a checkpoint written on a GPU also loads on a CPU-only
# machine.
state_dict = torch.load("model_weights.pth", map_location=device)

# Copy the loaded weights into the currently selected model.
model.load_state_dict(state_dict)

## Train the model

In [None]:
import gc
from monai.inferers import sliding_window_inference
from tqdm import tqdm
from collections import defaultdict
import time

# Loss, optimiser and metric are defined at notebook level: the
# cross-validation helpers further down read `loss_function` and `dice_metric`.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters())
# Alternative (class-weighted focal loss):
# loss_function = FocalLoss(include_background=True, to_onehot_y=True, weight=class_weights_tensor)
dice_metric = DiceMetric(include_background=True, reduction="mean")
max_epochs = 200
# Post-processing applied before computing Dice: labels become one-hot,
# predictions become argmax followed by one-hot.
post_label = AsDiscrete(to_onehot=3)
post_pred = AsDiscrete(argmax=True, to_onehot=3)

print(f"Training: {model_name}")

# Per-epoch history.
training_loss_pr_epoch = []
validation_loss_pr_epoch = []
dice_metric_pr_epoch = []
training_dice_pr_epoch = []

start = time.time()

for epoch in range(max_epochs):
    # Release cached GPU memory and collect garbage between epochs.
    torch.cuda.empty_cache()
    gc.collect()

    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training pass -------------------------------------------------
    model.train()
    training_losses = []
    for batch_data in tqdm(train_dataloader):
        images, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(images)

        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        training_losses.append(loss.item())

        # The Dice bookkeeping needs no gradients: detach the outputs so the
        # metric does not keep the autograd graph alive.
        train_labels_convert = [post_label(item) for item in decollate_batch(labels)]
        train_output_convert = [post_pred(item) for item in decollate_batch(outputs.detach())]
        dice_metric(y_pred=train_output_convert, y=train_labels_convert)

    training_dice = dice_metric.aggregate().item()
    training_dice_pr_epoch.append(training_dice)
    dice_metric.reset()

    # ---- validation pass -----------------------------------------------
    validation_losses = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            images, labels = batch["image"].to(device), batch["label"].to(device)

            # Validation volumes are not cropped, so predict them in
            # roi-sized windows with 50% overlap.
            outputs = sliding_window_inference(
                images,
                roi_size=(roi[0], roi[1], roi[2]),
                sw_batch_size=4,
                predictor=model,
                overlap=0.5,
            )

            loss = loss_function(outputs, labels)
            validation_losses.append(loss.item())

            val_labels_convert = [post_label(item) for item in decollate_batch(labels)]
            val_output_convert = [post_pred(item) for item in decollate_batch(outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)

        validation_dice = dice_metric.aggregate().item()
        dice_metric.reset()
        dice_metric_pr_epoch.append(validation_dice)

    # ---- epoch summary -------------------------------------------------
    mean_training_loss = np.mean(training_losses)
    mean_validation_loss = np.mean(validation_losses)
    training_loss_pr_epoch.append(mean_training_loss)
    validation_loss_pr_epoch.append(mean_validation_loss)
    print(f"Training mean loss: {mean_training_loss}")
    print(f"Validation mean loss: {mean_validation_loss}")
    print(f"Training dice {training_dice}")
    print(f"Validation Mean Dice: {validation_dice}")

end = time.time()
# Elapsed time is end - start; the reversed subtraction printed a negative value.
print(f"{max_epochs} took {end - start} time.")

In [None]:
# The file name combines the model name with the number of recorded training
# epochs; only the state dict is written, not the whole module.
save_name = f"{model_name} {len(training_dice_pr_epoch)}"
trained_weights = model.state_dict()
torch.save(trained_weights, save_name)

## Cross validation code
This section was used to explore hyperparameters (crop sizes) with k-fold cross-validation.

In [None]:

def train(model, device, train_loader, optimizer, epoch, max_epochs):
    """Run one training epoch and return the loss of every batch.

    Args:
        model: network to optimise; switched to training mode here.
        device: device each batch's "image" and "label" tensors are moved to.
        train_loader: iterable of batches, each a mapping with "image" and
            "label" entries.
        optimizer: optimiser stepped once per batch.
        epoch: index of the current epoch (unused; kept for existing callers).
        max_epochs: total number of epochs (unused; kept for existing callers).

    Returns:
        A list with one float loss per batch, in iteration order.

    The loss is computed with the notebook-level ``loss_function``.
    """
    # Imported locally so the function does not depend on another notebook
    # cell having imported gc first.
    import gc

    torch.cuda.empty_cache()
    gc.collect()
    model.train()

    training_losses = []
    for batch_data in train_loader:
        images = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        training_losses.append(loss.item())
    return training_losses


In [None]:

def evaluate(model, device, val_loader, optimizer):
    """Evaluate the model on a loader and return (mean loss, mean Dice).

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device each batch's "image" and "label" tensors are moved to.
        val_loader: iterable of batches, each a mapping with "image" and
            "label" entries.
        optimizer: unused; kept so existing callers keep working.

    Returns:
        Tuple of the mean batch loss and the aggregated Dice score.

    Uses the notebook-level ``loss_function`` and ``dice_metric``.
    """
    validation_losses = []
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            images, labels = batch["image"].to(device), batch["label"].to(device)
            outputs = model(images)
            loss = loss_function(outputs, labels)
            validation_losses.append(loss.item())

            # DiceMetric expects binarised one-hot predictions, not raw
            # logits: take the argmax class and re-encode it as one-hot, the
            # same post-processing the main training loop applies.
            num_classes = outputs.shape[1]
            to_onehot_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
            predictions = [to_onehot_pred(item) for item in decollate_batch(outputs)]

            targets = decollate_batch(labels)
            # Single-channel labels hold class indices and need one-hot
            # encoding; multi-channel labels are taken to be one-hot already.
            if labels.shape[1] == 1:
                to_onehot_label = AsDiscrete(to_onehot=num_classes)
                targets = [to_onehot_label(item) for item in targets]

            dice_metric(y_pred=predictions, y=targets)
        mean_dice = dice_metric.aggregate().item()
        # Clear the accumulated state so a later call starts from scratch.
        dice_metric.reset()
    return np.mean(validation_losses), mean_dice


In [None]:
from itertools import product

from sklearn.model_selection import KFold
# Import the torch loader under an alias: importing it as `DataLoader` would
# shadow the MONAI DataLoader for every cell run after this one.
from torch.utils.data import DataLoader as TorchDataLoader, SubsetRandomSampler

splits = 5
kf = KFold(n_splits=splits, shuffle=True, random_state=42)
grid_crop_x_y = [64, 128, 256]
grid_crop_z = [16, 32]
max_epochs = 10

# NOTE: this cell overwrites the notebook-level train_transforms, train_ds,
# model, optimizer, dice_metric and max_epochs defined by earlier cells.
# NOTE(review): labels are one-hot encoded here (AsDiscreted), while the
# DiceCELoss created in the training cell uses to_onehot_y=True - confirm
# which loss_function this search was meant to run with.

# Grid search over every (x, y, z) crop size, in the same order as nested loops.
for x, y, z in product(grid_crop_x_y, grid_crop_x_y, grid_crop_z):
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            RandSpatialCropd(
                keys=["image", "label"],
                roi_size=[x, y, z],
                random_center=True,
                random_size=False,
            ),
            AsDiscreted(keys=["label"], to_onehot=3),
            ToTensord(keys=["image", "label"]),
        ]
    )

    train_ds = CacheDataset(data=training_data, transform=train_transforms, cache_rate=1.0)

    # Per-fold results for this crop size.
    fold_val_loss = []
    fold_dice_metric = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_ds)):
        # Both loaders read the same dataset; the samplers restrict each one
        # to its own fold indices.
        train_loader = TorchDataLoader(
            dataset=train_ds,
            sampler=SubsetRandomSampler(train_idx),
            batch_size=4,
        )
        val_loader = TorchDataLoader(
            dataset=train_ds,
            sampler=SubsetRandomSampler(val_idx),
            batch_size=4,
        )

        # Fresh model, optimiser and metric for every fold.
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters())
        # evaluate() reads this notebook-level metric object.
        dice_metric = DiceMetric(include_background=True, reduction="mean")

        for epoch in range(max_epochs):
            _ = train(model, device, train_loader, optimizer, epoch, max_epochs)
        mean_validation_loss, mean_dice = evaluate(model, device, val_loader, optimizer)

        fold_val_loss.append(mean_validation_loss)
        fold_dice_metric.append(mean_dice)

    mean_val_loss = np.mean(fold_val_loss)
    mean_dice_metric = np.mean(fold_dice_metric)

    print("-------")
    print(f"{splits} fold cross validation for x={x}, y={y}, z={z}")
    print(f"Mean validation loss across folds: {mean_val_loss}")
    print(f"Mean dice metric: {mean_dice_metric}")

## Inspect slices of images

In [None]:
# Take the image and its label from the SAME sample in a single lookup.
# Reading them from different datasets or indices (or via two separate lookups,
# each drawing its own random crop) gives a ground truth that does not match
# the image the model later predicts on.
# NOTE(review): to inspect a validation case instead, read from val_ds - its
# volumes are not cropped, so the direct model call in the inference cell may
# need sliding-window inference; confirm before switching.
sample = train_ds[2]
image = sample["image"]
label = sample["label"]
image_with_batch = np.expand_dims(image, axis=0)
image_with_batch = torch.from_numpy(image_with_batch).float()

# Reduce the label to a map of class indices: a single-channel label already
# holds class indices (argmax over one channel would be all zeros), while a
# multi-channel label is treated as one-hot and collapsed with argmax.
label_array = np.asarray(label)
if label_array.shape[0] == 1:
    label_remove_one_hot = label_array[0]
else:
    label_remove_one_hot = np.argmax(label_array, axis=0)

unique_values, counts = np.unique(label_remove_one_hot, return_counts=True)

# Print the unique values and their counts
print(f"Unique values: {unique_values}")
print(f"Counts: {counts}")

In [None]:
# Drop the leading (channel) axis of the inspected image and show the slice at
# index 10 of the last axis in greyscale.
image_vis = np.squeeze(image, axis=0)
plt.imshow(image_vis[:, :, 10], cmap='gray')

In [None]:
image_with_batch = image_with_batch.to(device)

# Predict without tracking gradients and bring the result back as a NumPy array.
model.eval()
with torch.no_grad():
    output = model(image_with_batch).cpu().numpy()

# Remove the batch axis, then pick the highest-scoring channel per voxel.
output_label_vis = np.squeeze(output, axis=0)
output_remove_one_hot = output_label_vis.argmax(axis=0)
print(np.unique(output_remove_one_hot))

# Coordinates of every voxel predicted as class 2.
positions = np.nonzero(output_remove_one_hot == 2)
print(positions)

In [None]:
import matplotlib.pyplot as plt

# Index along the last axis of the slice shown in both panels.
layer = 3

plt.figure(figsize=(10, 10))

# Ground truth and prediction side by side in the top row of a 2x2 grid.
panels = (
    ("Ground Truth", label_remove_one_hot),
    ("Model Output", output_remove_one_hot),
)
for panel_index, (panel_title, volume) in enumerate(panels, start=1):
    plt.subplot(2, 2, panel_index)
    plt.title(panel_title)
    plt.imshow(volume[:, :, layer], cmap='gray')
    plt.axis('off')

plt.tight_layout()
plt.show()
