This code is based on an [example](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb) provided by Project MONAI.

## Set up imports

In [None]:
import os
import json
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.config import print_config
from monai.data import DataLoader, Dataset, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss, GeneralizedWassersteinDiceLoss, HausdorffDTLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.metrics import HausdorffDistanceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism
from monai.transforms.transform import MapTransform
from monai.utils.enums import TransformBackends
from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
import numpy as np
import nibabel as nib
from collections.abc import Mapping, Hashable

import torch

## Set deterministic training for reproducibility

In [None]:
# Seed MONAI's global random state (monai.utils.set_determinism) so that runs are repeatable.
# NOTE(review): torch.backends.cudnn.benchmark is set to True later in this notebook, which may
# reintroduce nondeterminism via cuDNN algorithm selection — confirm that this is acceptable.
set_determinism(seed=0)

# Generate JSON data lists

In [None]:
import os
import json

root = '../data/segmentation/train/'
n_folds = 5  # number of cross-validation folds

# Sort the subject directories: os.listdir order is arbitrary, and a stable order is needed for
# a reproducible subject -> fold assignment. Stray files in `root` are skipped.
subject_dirs = sorted(
    name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
)

# Amount of subjects
print(len(subject_dirs))

d = {'training': []}

# Split the subjects into `n_folds` contiguous folds whose sizes differ by at most one:
# (index * n_folds) // n_subjects maps indices 0..n-1 onto folds 0..n_folds-1.
for index, traindir in enumerate(subject_dirs):
    # Sorted so that the image (channel) order is identical for every subject.
    # NOTE(review): checkpoints trained on the previous, unsorted lists may have seen a
    # different channel order.
    files = sorted(os.listdir(root + traindir))
    labels = [root + traindir + '/' + f for f in files if 'seg' in f]
    if not labels:
        raise FileNotFoundError(f"no segmentation file found in {root + traindir}")
    v = {'fold': index * n_folds // len(subject_dirs),
         'image': [root + traindir + '/' + f for f in files if 'seg' not in f],
         'label': labels[0],
        }
    d['training'].append(v)

with open('data.json', 'w') as f:
    json.dump(d, f)

In [None]:
root = '../data/segmentation/test/'
# Folds are not used for testing, but the loader cell keeps only entries that have a 'fold' key.
n_folds = 5

# Sort for a reproducible order (os.listdir order is arbitrary); stray files are skipped.
subject_dirs = sorted(
    name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
)

# Amount of subjects
print(len(subject_dirs))

d = {'testing': []}

# Assign contiguous, balanced folds: (index * n_folds) // n_subjects lies in 0..n_folds-1.
for index, testdir in enumerate(subject_dirs):
    # Sorted so the image (channel) order matches the sorted order used for training subjects.
    files = sorted(os.listdir(root + testdir))
    d['testing'].append({
        'fold': index * n_folds // len(subject_dirs),
        # Test subjects have no ground-truth segmentation, so only the images are listed.
        'image': [root + testdir + '/' + f for f in files if 'seg' not in f],
    })

with open('test.json', 'w') as f:
    json.dump(d, f)

## Set up transforms for the data

In [None]:
class MakeLabelDimsd(MapTransform):
    """Split an integer label map into one binary channel per foreground label.

    Output channels, in order: label 1, label 2, label 4. Label 0 (background) gets no
    channel, which matches the 3 output channels of the network. When the input is 4-D
    with a leading axis of size 1, that axis is squeezed first, so the result is
    (3, ...) rather than (3, 1, ...). Tensors are stacked with torch, anything else with NumPy.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        out = dict(data)
        for key in self.key_iterator(out):
            label = out[key]
            has_singleton_channel = label.ndim == 4 and label.shape[0] == 1
            if has_singleton_channel:
                label = label.squeeze(0)

            masks = [label == value for value in (1, 2, 4)]
            if isinstance(label, torch.Tensor):
                out[key] = torch.stack(masks, dim=0)
            else:
                out[key] = np.stack(masks, axis=0)
        return out
    
# Variant of MakeLabelDimsd that also keeps a background channel.
class TestMakeLabelDimsd(MapTransform):
    """Split an integer label map into four binary channels, background included.

    Output channels, in order: label 0 (background), label 1, label 2, label 4.
    A 4-D input whose first axis has size 1 is squeezed before splitting.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        result_dict = dict(data)
        for key in self.key_iterator(result_dict):
            seg = result_dict[key]
            if seg.ndim == 4 and seg.shape[0] == 1:
                seg = seg.squeeze(0)

            channels = []
            for label_value in (0, 1, 2, 4):
                channels.append(seg == label_value)
            # torch.stack(tensors, 0) / np.stack(arrays, 0): stack along a new leading axis
            stacker = torch.stack if isinstance(seg, torch.Tensor) else np.stack
            result_dict[key] = stacker(channels, 0)
        return result_dict

## Set up transforms for training, validation and testing

In [None]:
# Training pipeline: reorient to RAS, resample to 1 mm, random 128x128x80 crop, random flips
# on every axis, per-channel intensity normalisation plus random intensity scale/shift.
train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        MakeLabelDimsd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 80], random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)

# Validation / inference pipeline, shared by labelled validation data and unlabelled test data.
# The label steps use allow_missing_keys=True, so entries without a "label" key pass through.
# (Previously a second, image-only `val_transform` overwrote the labelled one, which made
# val_data["label"] unavailable during validation.)
# Unlike train_transform there is no Orientationd/Spacingd here, so volumes stay on their native
# voxel grid; the mask-export cells rely on this when they reuse the source image's affine.
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"], allow_missing_keys=True),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"], allow_missing_keys=True),
        MakeLabelDimsd(keys="label", allow_missing_keys=True),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# Pipeline with 4 label channels (background included); it serves as the reference transform
# for Invertd in post_transforms. It requires a "label" key.
test_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        TestMakeLabelDimsd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

## Set up data loaders

In [None]:
datalist = "./data.json"
testlist = "./test.json"
batch_size = 1
fold = (0, 1)  # only fold[1] is used: subjects in that fold form the validation set

with open(datalist) as f:
    json_data = json.load(f)['training']

# Subjects in the validation fold go to `val`; every other subject is used for training.
tr, val = [], []
for d in json_data:
    if d.get("fold") == fold[1]:
        val.append(d)
    else:
        tr.append(d)

# Fail early: an empty validation set would otherwise only surface when the Dice metric is
# aggregated at the first validation epoch.
if not val:
    raise ValueError(f"no subjects found for validation fold {fold[1]} in {datalist}")

with open(testlist) as f:
    json_data = json.load(f)['testing']
te = [d for d in json_data if "fold" in d]

train_ds = Dataset(data=tr, transform=train_transform)
train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

# Not shuffled: the export cells zip this loader with `val` and rely on matching order.
val_ds = Dataset(data=val, transform=val_transform)
val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

# Test subjects have no labels; they use val_transform as well.
test_ds = Dataset(data=te, transform=val_transform)
test_loader = DataLoader(
        test_ds,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

## Create Model, Loss, Optimizer

In [None]:
# --- Run configuration ---------------------------------------------------------------------
max_epochs = 800        # training epochs; also the cosine-annealing period (T_max)
val_interval = 5        # validate every `val_interval` epochs
VAL_AMP = True          # run sliding-window inference under CUDA autocast
use_pretrained = True   # warm-start from a saved checkpoint before training

# Expose only GPU 1 to this process. CUDA_VISIBLE_DEVICES is only honoured if it is set before
# the CUDA runtime is initialised, so it is assigned before the other CUDA calls in this cell.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Network -------------------------------------------------------------------------------
# 4 input channels (one per non-segmentation image of a subject) -> 3 output channels
# (one per label value 1, 2 and 4, as produced by MakeLabelDimsd).
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

if use_pretrained:
    pretrained_path = os.path.join('./trained models', "best_metric_model_800.pth")
    model.load_state_dict(torch.load(pretrained_path, map_location=device))

# --- Optimisation --------------------------------------------------------------------------
# Labels are already one binary channel per class (to_onehot_y=False); sigmoid=True scores the
# channels independently.
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# --- Metrics and post-processing -----------------------------------------------------------
# "mean" reduces to one scalar; "mean_batch" keeps one value per channel.
dice_metric, dice_metric_batch = (
    DiceMetric(include_background=True, reduction=reduction)
    for reduction in ("mean", "mean_batch")
)
hausdorff_metric, hausdorff_metric_batch = (
    HausdorffDistanceMetric(include_background=True, reduction=reduction)
    for reduction in ("mean", "mean_batch")
)

# Sigmoid, then threshold at 0.5 -> one binary mask per channel.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input, model=None):
    """Run sliding-window inference on a batch of volumes.

    Args:
        input: image batch handed to ``sliding_window_inference``.
        model: network used as the predictor. Defaults to the module-level ``model``, so
            call sites that pass only the input (``inference(val_inputs)``) also work.

    Returns:
        The stitched network output from ``sliding_window_inference``. When ``VAL_AMP`` is
        true, the computation runs under ``torch.cuda.amp.autocast``.
    """
    if model is None:
        # The parameter shadows the global name, so look the notebook's model up explicitly.
        model = globals()["model"]

    def _compute(x):
        return sliding_window_inference(
            inputs=x,
            # NOTE(review): presumably chosen to cover a whole volume in one window — confirm.
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    return _compute(input)


# Let cuDNN benchmark and cache the fastest convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Gradient scaler for mixed-precision (AMP) training; used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

# Train the model

In [None]:

# Post-processing for a "pred" key produced from data loaded with test_transform.
# Invertd applies the inverse of test_transform to "pred", using the transform history recorded
# on the "image" key; the result is a tensor on the CPU. Activationsd then applies a sigmoid,
# and AsDiscreted thresholds each channel at 0.5.
# NOTE(review): post_transforms is not referenced anywhere else in the visible notebook.
post_transforms = Compose(
    [
        Invertd(
            keys="pred",
            transform=test_transform,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
            device="cpu",
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
    ]
)



best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]  # [best metric, its epoch, seconds since start]
epoch_loss_values = []
val_loss_values = []
metric_values = []
metric_values_ncr = []
metric_values_ed = []
metric_values_et = []

# Make sure the checkpoint directory exists, so torch.save cannot fail after hours of training.
checkpoint_dir = "./trained models"
os.makedirs(checkpoint_dir, exist_ok=True)


def _train_one_epoch():
    """Run one mixed-precision pass over train_loader and return the mean batch loss."""
    model.train()
    epoch_loss = 0.0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        # GradScaler scales the loss for backward and unscales the gradients before the step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    if step == 0:
        raise RuntimeError("train_loader yielded no batches")
    return epoch_loss / step


def _validate():
    """Compute Dice over val_loader.

    Returns:
        (mean Dice over all channels, Dice of channel 0, 1, 2), i.e. ncr, ed, et.
    """
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            val_outputs = inference(val_inputs, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)
            dice_metric_batch(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        metric_batch = dice_metric_batch.aggregate()
        per_channel = [metric_batch[c].item() for c in range(3)]
        # Clear the accumulated buffers so the next validation round starts fresh.
        dice_metric.reset()
        dice_metric_batch.reset()
    return (metric, *per_channel)


total_start = time.time()
for epoch in range(max_epochs):
    torch.cuda.empty_cache()
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    epoch_loss = _train_one_epoch()
    lr_scheduler.step()
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        metric, metric_ncr, metric_ed, metric_et = _validate()
        metric_values.append(metric)
        metric_values_ncr.append(metric_ncr)
        metric_values_ed.append(metric_ed)
        metric_values_et.append(metric_et)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            best_metrics_epochs_and_time[0].append(best_metric)
            best_metrics_epochs_and_time[1].append(best_metric_epoch)
            best_metrics_epochs_and_time[2].append(time.time() - total_start)
            torch.save(
                model.state_dict(),
                os.path.join(checkpoint_dir, "best_metric_model_diceCE_800.pth"),
            )
            print("saved new best metric model")
        print(
            f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
            f" ncr: {metric_ncr:.4f} ed: {metric_ed:.4f} et: {metric_et:.4f}"
            f"\nbest mean dice: {best_metric:.4f}"
            f" at epoch: {best_metric_epoch}"
        )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

In [None]:
# Summary of the run: best validation mean Dice, the epoch it was reached, and the total training
# wall-clock time in seconds (total_time is a time.time() difference).
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")

## Write results

In [None]:
# Save the training loss (one value per epoch) and the validation Dice scores (one value per
# validation round) as single-column CSV files.
for csv_path, series in (
    ("epoch_loss_values.csv", epoch_loss_values),
    ("metric_values.csv", metric_values),
    ("metric_values_ncr.csv", metric_values_ncr),
    ("metric_values_ed.csv", metric_values_ed),
    ("metric_values_et.csv", metric_values_et),
):
    np.savetxt(csv_path, series, delimiter=',')

# Make evaluation segmentation masks

In [None]:
from monai.utils import convert_data_type
from monai.metrics.utils import ignore_background
from monai.transforms import MeanEnsemble, VoteEnsemble
from pathlib import Path 
from shutil import copytree

def get_background(output):
    """Binarise a batch of raw network outputs and prepend a background channel.

    Each decollated item is passed through ``post_trans`` (sigmoid + 0.5 threshold), which
    gives binary foreground channels. The background channel is 1 where no foreground channel
    is set and 0 elsewhere. An ``argmax`` over the channel axis therefore yields 0 for
    background and 1..C for the foreground channels, in order.

    Args:
        output: batched network output; it is split into items with ``decollate_batch``.

    Returns:
        Tensor stacked over the batch; each item has one extra leading (background) channel
        and keeps the device and dtype of the post-processed output.
    """
    all_labels_tensor = []
    for out in decollate_batch(output):
        out = torch.as_tensor(post_trans(out))
        # Computed in torch rather than NumPy, so CUDA tensors work. The previous
        # np.subtract(..., where=...) had no `out=` array, which leaves the masked-out entries
        # uninitialised. 1 - max over the (binary) channels is 1 exactly where none is set.
        background = 1 - out.amax(dim=0, keepdim=True)
        all_labels_tensor.append(torch.cat([background, out], dim=0))
    return torch.stack(all_labels_tensor)

# Load the checkpoint once; the weights do not change between subjects.
# NOTE(review): the training loop saves "best_metric_model_diceCE_800.pth", but this loads
# "best_metric_model_800.pth" — confirm which checkpoint is intended.
model.load_state_dict(torch.load(os.path.join('./trained models', "best_metric_model_800.pth"), map_location=device))
model.eval()

# Argmax channel index -> label value written to the mask (channel 0 is the background
# prepended by get_background).
label_values = np.array([0, 1, 2, 4])

with torch.no_grad():
    # val_loader is not shuffled, so its batches line up with the entries of `val`.
    for val_data, vald in zip(val_loader, val):
        torch.cuda.empty_cache()

        val_inputs = val_data["image"].to(device)
        out = inference(val_inputs, model)
        out = get_background(out)

        # Highest-scoring channel per voxel (ties resolve to the lowest channel index),
        # then mapped back to the original label values (0, 1, 2, 4).
        reversed_mask = np.argmax(out[0].detach().cpu().numpy(), axis=0)
        original_mask = label_values[reversed_mask]

        # Subject id = name of the directory that holds the subject's images.
        subject = vald["image"][0].split('/')[-2]
        Path(f'./pred/segmentation/{subject}').mkdir(parents=True, exist_ok=True)
        shutil.copy(vald["label"], f'./pred/segmentation/{subject}')
        Path(f'./results/segmentation/{subject}').mkdir(parents=True, exist_ok=True)

        # val_transform does not resample, so the source image's affine applies to the mask.
        aff = nib.load(vald["image"][0]).affine
        ni_img = nib.Nifti1Image(original_mask.astype('float64'), aff)
        nib.save(ni_img, f'./results/segmentation/{subject}/{subject}_seg.nii.gz')
        

# Make final submission segmentation masks

In [None]:
from monai.utils import convert_data_type
from monai.metrics.utils import ignore_background
from pathlib import Path
from shutil import copytree
from monai.transforms import MeanEnsemble, VoteEnsemble

def get_background(output):
    """Binarise a batch of raw network outputs and prepend a background channel.

    Each decollated item is passed through ``post_trans`` (sigmoid + 0.5 threshold), which
    gives binary foreground channels. The background channel is 1 where no foreground channel
    is set and 0 elsewhere. An ``argmax`` over the channel axis therefore yields 0 for
    background and 1..C for the foreground channels, in order.

    Args:
        output: batched network output; it is split into items with ``decollate_batch``.

    Returns:
        Tensor stacked over the batch; each item has one extra leading (background) channel
        and keeps the device and dtype of the post-processed output.
    """
    all_labels_tensor = []
    for out in decollate_batch(output):
        out = torch.as_tensor(post_trans(out))
        # Computed in torch rather than NumPy, so CUDA tensors work. The previous
        # np.subtract(..., where=...) had no `out=` array, which leaves the masked-out entries
        # uninitialised. 1 - max over the (binary) channels is 1 exactly where none is set.
        background = 1 - out.amax(dim=0, keepdim=True)
        all_labels_tensor.append(torch.cat([background, out], dim=0))
    return torch.stack(all_labels_tensor)


ensemble = VoteEnsemble()  # NOTE(review): not used below

# Load the checkpoint once. map_location lets a GPU-saved checkpoint load on a CPU-only machine
# (consistent with the other torch.load calls in this notebook).
model.load_state_dict(torch.load(os.path.join('./trained models', "best_metric_model_800.pth"), map_location=device))
model.eval()

# Argmax channel index -> label value written to the mask (channel 0 is the background
# prepended by get_background).
label_values = np.array([0, 1, 2, 4])

with torch.no_grad():
    # test_loader is not shuffled, so its batches line up with the entries of `te`.
    for val_data, vald in zip(test_loader, te):
        torch.cuda.empty_cache()

        val_inputs = val_data["image"].to(device)
        out = inference(val_inputs, model)
        out = get_background(out)

        # Highest-scoring channel per voxel, mapped back to the label values (0, 1, 2, 4).
        reversed_mask = np.argmax(out[0].detach().cpu().numpy(), axis=0)
        original_mask = label_values[reversed_mask]

        # Subject id = name of the directory that holds the subject's images.
        subject = vald["image"][0].split('/')[-2]
        out_dir = f'./results/testing/segmentation/{subject}'
        Path(out_dir).mkdir(parents=True, exist_ok=True)

        # No resampling happened in val_transform, so the source image's affine applies.
        aff = nib.load(vald["image"][0]).affine
        ni_img = nib.Nifti1Image(original_mask.astype('float64'), aff)
        nib.save(ni_img, f'{out_dir}/{subject}_seg.nii.gz')