## Setup environment

In [None]:
# Install MONAI (weekly build with the pillow and tqdm extras) only if it cannot be imported.
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
# Install matplotlib only if it cannot be imported.
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Setup imports

In [None]:
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.config import print_config
from monai.data import (CacheDataset, DataLoader, decollate_batch,
                        load_decathlon_datalist)
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import (AsDiscrete, Compose, EnsureChannelFirstd,
                              LoadImaged, NormalizeIntensityd, Orientationd,
                              RandAdjustContrastd, RandBiasFieldd,
                              RandCropByPosNegLabeld, RandFlipd,
                              RandGibbsNoised, RandHistogramShiftd,
                              RandKSpaceSpikeNoised, RandRotate90d, Spacingd,
                              SpatialPadd, ToTensord)
from monai.utils import set_determinism
from tqdm import tqdm

from swin_unetr.cascaded_unet import CascadedUNet

# Print MONAI's configuration report so the environment used for this run is recorded in the output.
print_config()

##### Note:
Define the training data as a Decathlon-style JSON datalist with a training/validation split. Also set the directories and detect the number of labels that the network must predict.

Each entry of the `feature_nets` list is a `(num_output_channels, checkpoint_path)` tuple: the number of output channels predicted by a pre-trained feature net, followed by the path to that network's weights.

##### Define file paths & output directory path

In [None]:
# Dataset description (Decathlon-style JSON), dataset root and output directory.
json_path = os.path.normpath("D:/lloyd/datasets/CC/charm_corrected6_t1.json")
data_dir = os.path.normpath("D:/lloyd/datasets/CC")
logdir = os.path.normpath("D:/lloyd/datasets/CC/charm_corrected6_cascade_unet")

# The JSON "labels" mapping is keyed by label index stored as a string; convert the keys to int.
labels = json.loads(Path(json_path).read_text())["labels"]
labels = {int(k): v for k, v in labels.items()}
# One output channel per label index from 0 up to the largest index present.
num_classes = max(labels.keys()) + 1

# Pre-trained feature nets as (number of output channels, path to weights) tuples.
feature_nets = [
    (
        4,
        "D:/lloyd/datasets/CC/WM-GM-CSF_train/training/3/epoch=488-val_loss=0.06-val_dice=0.9217.pth",
    ),
]

# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail if the directory already exists (or appears between a check and the call).
os.makedirs(logdir, exist_ok=True)

set_determinism(seed=960)

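##### Define a flag to load previously saved weights for the whole network, and the path to those weights. If the flag is set to `False`, only the feature nets are loaded from their pre-trained checkpoints; no weights are loaded for the rest of the network.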

In [None]:
# When True, the full model is initialised from `pretrained_path` before training.
use_pretrained = False
# Checkpoint location: the best-metric file inside the output directory.
_best_checkpoint = os.path.join(logdir, "best_metric_model.pth")
pretrained_path = os.path.normpath(_best_checkpoint)

##### MONAI Transforms for training and validation, training configuration

In [None]:
# --- Training hyper-parameters -------------------------------------------------
lr = 4e-4  # learning rate handed to the optimizer
max_iterations = 30000  # total number of optimisation steps
eval_num = 100  # run validation every `eval_num` steps

# --- Transforms ----------------------------------------------------------------
_both_keys = ["image", "label"]
_patch_size = (96, 96, 96)


def _shared_preprocessing():
    """Build fresh instances of the deterministic steps used by both pipelines."""
    return [
        LoadImaged(keys=_both_keys, reader="ITKReader"),
        EnsureChannelFirstd(keys=_both_keys),
        Orientationd(keys=_both_keys, axcodes="RAS"),
        # Resample to 1 mm isotropic; nearest-neighbour for the label map.
        Spacingd(keys=_both_keys, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        NormalizeIntensityd(keys="image", nonzero=False, channel_wise=True),
        # CropForegroundd(keys=["image", "label"], source_key="image"),
    ]


train_transforms = Compose(
    [
        *_shared_preprocessing(),
        # Pad first so that every volume can provide a full-size patch.
        SpatialPadd(keys=_both_keys, spatial_size=_patch_size),
        RandCropByPosNegLabeld(
            keys=_both_keys,
            label_key="label",
            spatial_size=_patch_size,
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # One independent random flip per spatial axis, in axis order 0, 1, 2.
        *[
            RandFlipd(keys=_both_keys, spatial_axis=[axis], prob=0.10)
            for axis in (0, 1, 2)
        ],
        RandRotate90d(keys=_both_keys, prob=0.10, max_k=3),
        # Intensity / artefact augmentations act on the image only.
        RandAdjustContrastd(keys="image", prob=0.2, gamma=(0.5, 4.5)),
        RandHistogramShiftd(keys="image", prob=0.2, num_control_points=10),
        RandBiasFieldd(keys="image", prob=0.2),
        # RandRicianNoised(keys="image", prob=0.1),
        RandGibbsNoised(keys="image", prob=0.2, alpha=(0.0, 1.0)),
        RandKSpaceSpikeNoised(keys="image", prob=0.2),
        ToTensord(keys=_both_keys),
    ]
)

# Validation uses only the deterministic preprocessing, on whole volumes.
val_transforms = Compose(
    [
        *_shared_preprocessing(),
        ToTensord(keys=_both_keys),
    ]
)

##### Load data list and create dataloaders for training

Since there is a mismatch between the spacing and the affine matrix in the BTCV dataset, users will see warnings "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1".
This is expected and does not affect the results in this notebook.

In [None]:
# Arguments shared by both datalist queries; only the split key differs.
_datalist_args = dict(
    base_dir=data_dir,
    data_list_file_path=json_path,
    is_segmentation=True,
)
datalist = load_decathlon_datalist(data_list_key="training", **_datalist_args)
val_files = load_decathlon_datalist(data_list_key="validation", **_datalist_args)

# Cached datasets for the training and validation splits.
train_ds = CacheDataset(
    data=datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=4,
)
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_num=6,
    cache_rate=1.0,
    num_workers=4,
)

# Shuffled mini-batches for training; one volume at a time, in order, for validation.
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Sanity check: report the shapes of one validation case.
case_num = 0
_sample = val_ds[case_num]
img, label = _sample["image"], _sample["label"]
img_shape, label_shape = img.shape, label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")

In [None]:
# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU, or to the CPU when CUDA is unavailable. A hard-coded "cuda:1"
# fails on single-GPU machines.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

model = CascadedUNet(
    in_channels=1,
    feature_channel_list=[p[0] for p in feature_nets],
    out_channels=num_classes,
)

# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
# map_location="cpu" lets checkpoints saved on any device be loaded here; the
# weights are moved to `device` together with the model below.
if use_pretrained is True:
    print(f"Loading Weights from the Path {pretrained_path}")
    model_state_dict = torch.load(pretrained_path, map_location="cpu")
    model.load_state_dict(model_state_dict, strict=True)
else:
    # Only the feature nets are initialised from their pre-trained checkpoints.
    for idx, (_num_channels, weights_path) in enumerate(feature_nets):
        model_state_dict = torch.load(weights_path, map_location="cpu")
        model.feature_nets[idx].load_state_dict(model_state_dict, strict=True)

model.to(device)

# loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

# Post-processing used during validation: one-hot labels, and argmax + one-hot predictions.
post_label = AsDiscrete(to_onehot=num_classes)
post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

# Training state shared with train() / validation().
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []

In [None]:
def validation(epoch_iterator_val):
    """Evaluate the model on the validation set and return the mean Dice.

    Uses the notebook-level ``model``, ``device``, ``post_label``,
    ``post_pred`` and ``dice_metric``. The model is left in eval mode.

    Args:
        epoch_iterator_val: tqdm-wrapped validation loader yielding batches
            with "image" and "label" entries.

    Returns:
        float: Dice aggregated by ``dice_metric`` over every validated case,
        or NaN when the iterator yields no batches.
    """
    model.eval()
    # Start from an empty buffer in case an earlier run was interrupted.
    dice_metric.reset()
    num_batches = 0

    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = (
                batch["image"].to(device),
                batch["label"].to(device),
            )
            val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
            val_labels_convert = [
                post_label(val_label_tensor)
                for val_label_tensor in decollate_batch(val_labels)
            ]
            val_output_convert = [
                post_pred(val_pred_tensor)
                for val_pred_tensor in decollate_batch(val_outputs)
            ]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            num_batches += 1

            # Running value for the progress bar only; it is not the returned score.
            # `global_step` is the notebook-level variable, which is only
            # reassigned when train() returns, so it can lag the true step.
            running_dice = dice_metric.aggregate().item()
            epoch_iterator_val.set_description(
                "Validate (%d / %d Steps) (dice=%2.5f)"
                % (global_step, max_iterations, running_dice)
            )

    if num_batches == 0:
        return float("nan")

    # Aggregate once over all cases. Averaging the per-batch running means
    # would weight early cases more heavily than late ones.
    mean_dice_val = dice_metric.aggregate().item()
    dice_metric.reset()
    return mean_dice_val


def _save_training_curves():
    """Redraw the loss and validation-Dice curves and save them as a PNG in ``logdir``."""
    plt.figure(1, (12, 6))
    panels = (
        ("Iteration Average Loss", epoch_loss_values),
        ("Val Mean Dice", metric_values),
    )
    for panel_idx, (title, values) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.title(title)
        plt.xlabel("Iteration")
        # One point per evaluation, spaced `eval_num` iterations apart.
        plt.plot([eval_num * (i + 1) for i in range(len(values))], values)
        plt.grid()
    plt.savefig(os.path.join(logdir, "btcv_finetune_quick_update.png"))
    plt.clf()
    plt.close(1)


def train(global_step, train_loader, dice_val_best, global_step_best):
    """Run one pass over ``train_loader``, validating periodically.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_function``,
    ``device`` and ``val_loader``. Validation runs whenever ``global_step`` is
    a non-zero multiple of ``eval_num`` or equals ``max_iterations``. Each
    validation appends to ``epoch_loss_values`` and ``metric_values``, saves
    the model when the Dice improves, and refreshes the curve image.

    Args:
        global_step: Number of optimisation steps completed before this call.
        train_loader: Loader yielding batches with "image" and "label" entries.
        dice_val_best: Best validation Dice seen so far.
        global_step_best: Step at which ``dice_val_best`` was obtained.

    Returns:
        tuple: Updated ``(global_step, dice_val_best, global_step_best)``.
    """
    model.train()
    epoch_loss = 0.0
    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
    )
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        # Read the scalar once and reuse it for the running sum and the progress text.
        loss_value = loss.item()
        epoch_loss += loss_value
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)"
            % (global_step, max_iterations, loss_value)
        )

        if (
            global_step % eval_num == 0 and global_step != 0
        ) or global_step == max_iterations:
            epoch_iterator_val = tqdm(
                val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True
            )
            dice_val = validation(epoch_iterator_val)
            # validation() leaves the model in eval mode; switch back so the
            # remaining steps of this pass train normally.
            model.train()

            # Mean loss over the steps of this pass so far. The running sum is
            # left untouched so a later evaluation in the same pass stays correct.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(
                    model.state_dict(), os.path.join(logdir, "best_metric_model.pth")
                )
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )

            _save_training_curves()

        global_step += 1
    return global_step, dice_val_best, global_step_best


# train() consumes the whole loader on each call, so the final step count can
# exceed max_iterations by up to one pass.
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best
    )

# Always keep the weights from the last iteration.
torch.save(model.state_dict(), os.path.join(logdir, "final_iteration_model.pth"))

# Restore the best-scoring weights for later cells. The checkpoint is only
# written when a validation improves on the best Dice, so it may not exist.
# NOTE(review): a file left over from an earlier run would also be loaded here.
best_model_path = os.path.join(logdir, "best_metric_model.pth")
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path))
else:
    print(
        f"No best-metric checkpoint found at {best_model_path}; "
        "keeping the final-iteration weights."
    )

print(
    f"train completed, best_metric: {dice_val_best:.4f} "
    f"at iteration: {global_step_best}"
)

##### Visualize the training curves

In [None]:
# Two side-by-side panels: mean training loss and validation Dice, each
# plotted against the iteration at which it was recorded.
plt.figure(1, (12, 6))
_panels = (
    ("Iteration Average Loss", epoch_loss_values),
    ("Val Mean Dice", metric_values),
)
for _panel_idx, (_title, y) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _panel_idx)
    plt.title(_title)
    # One point per evaluation, spaced `eval_num` iterations apart.
    x = [eval_num * k for k in range(1, len(y) + 1)]
    plt.xlabel("Iteration")
    plt.plot(x, y)
    plt.grid()
plt.show()