Adapated from Project-MONAI implementation:  
https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb  

Original papers:  
https://arxiv.org/abs/2111.14791  
https://arxiv.org/abs/2201.01266  

## Setup environment

In [1]:
!python -c "import monai; import nibabel; import tqdm" || pip install -q "monai-weekly[nibabel, tqdm, einops]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [2]:
import monai
import nibabel
import tqdm
import numpy as np
import gc
import optuna

  from .autonotebook import tqdm as notebook_tqdm


## Setup imports

In [2]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR

from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)


import torch
torch.cuda.empty_cache() 

print_config()

MONAI version: 1.4.dev2350
Numpy version: 1.26.2
Pytorch version: 2.1.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: e54c05d7a659e4d402f5194412889779e6856cbd
MONAI __file__: /home/<username>/anaconda3/envs/myenv/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.1.0
scikit-image version: 0.22.0
scipy version: 1.11.4
Pillow version: 10.1.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.1.4
einops version: 0.7.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version:

  from .autonotebook import tqdm as notebook_tqdm


## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [4]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmpc_ddpf_2


In [5]:
num_samples = 4

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image", "label"],
        #     pixdim=(1.5, 1.5, 2.0),
        #     mode=("bilinear", "nearest"),
        # ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image", "label"],
        #     pixdim=(1.5, 1.5, 2.0),
        #     mode=("bilinear", "nearest"),
        # ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)



In [7]:
data_dir = "data/"
split_json = "dataset_0.json"


datasets = data_dir + split_json
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")
train_ds = CacheDataset(
    data=datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=8,
)
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

# # as explained in the "Setup transforms" section above, we want cached training images to not have metadata, and validations to have metadata
# # the EnsureTyped transforms allow us to make this distinction
# # on the other hand, set_track_meta is a global API; doing so here makes sure subsequent transforms (i.e., random transforms for training)
# # will be carried out as Tensors, not MetaTensors
set_track_meta(False)

Loading dataset: 100%|██████████| 15/15 [00:03<00:00,  4.21it/s]
Loading dataset: 100%|██████████| 5/5 [00:01<00:00,  3.13it/s]


In [3]:
# uncomment to download the pre-trained weights
# !wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt
weight = torch.load("../model_weights/model_swinvit.pt")


### Execute a typical PyTorch training process

In [9]:
def validation(model, dice_metric, post_label, post_pred, global_step, epoch_iterator_val):
    model.eval()
    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
            with torch.cuda.amp.autocast():
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val


def train(model, loss_function, scaler, optimizer, dice_metric, post_label, 
          post_pred, global_step, train_loader, val_loader, dice_val_best, 
          global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True, disable = True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
        x, y = (batch["image"].cuda(), batch["label"].cuda())
        # print(x.shape)
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        epoch_loss += loss.item()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss:2.5f})")
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True, disable = True)
            dice_val = validation(model, dice_metric, post_label, post_pred, global_step, epoch_iterator_val)
            epoch_loss /= step
            # epoch_loss_values.append(epoch_loss)
            # metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
    return global_step, dice_val_best, global_step_best

In [10]:
max_iterations = 3000
eval_num = 500
post_label = AsDiscrete(to_onehot=17)
post_pred = AsDiscrete(argmax=True, to_onehot=17)

def objective(trial):
    params = {
        'learning_rate': trial.suggest_float('learning_rate', 1e-6, 1e-3, log=True),
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True),
        # 'feature_size' : trial.suggest_categorical('feature_size', [24, 48]),
    }

    model = SwinUNETR(
        img_size=(96, 96, 96),
        in_channels=1,
        out_channels=17,
        feature_size= 48,
        use_checkpoint=True,
    ).to(device)
    model.load_from(weights=weight)

    torch.backends.cudnn.benchmark = True
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr= params['learning_rate'] , weight_decay= params['weight_decay'])
    scaler = torch.cuda.amp.GradScaler()   
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) 

    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    # epoch_loss_values = []
    # metric_values = []
    
    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(model, loss_function, scaler, optimizer, dice_metric, post_label, 
                                                            post_pred, global_step, train_loader, val_loader, dice_val_best, global_step_best)
    
    del model, loss_function, optimizer, scaler, dice_metric
    torch.cuda.empty_cache()
    gc.collect()

    return dice_val_best

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=12)
print('Best trial:', study.best_trial.params)

[I 2023-12-11 07:55:25,481] A new study created in memory with name: no-name-1786e41f-3ec2-4745-b7c9-9c2b02e38fc6


Model Was Saved ! Current Best Avg. Dice: 0.13845957815647125 Current Avg. Dice: 0.13845957815647125
Model Was Saved ! Current Best Avg. Dice: 0.17365843057632446 Current Avg. Dice: 0.17365843057632446
Model Was Saved ! Current Best Avg. Dice: 0.2079610377550125 Current Avg. Dice: 0.2079610377550125
Model Was Saved ! Current Best Avg. Dice: 0.24421270191669464 Current Avg. Dice: 0.24421270191669464
Model Was Saved ! Current Best Avg. Dice: 0.2765645384788513 Current Avg. Dice: 0.2765645384788513


[I 2023-12-11 08:19:22,796] Trial 0 finished with value: 0.2765645384788513 and parameters: {'learning_rate': 4.7830071309333595e-05, 'weight_decay': 0.00017661539680879465}. Best is trial 0 with value: 0.2765645384788513.


Model Was Saved ! Current Best Avg. Dice: 0.1345311999320984 Current Avg. Dice: 0.1345311999320984
Model Was Saved ! Current Best Avg. Dice: 0.14748993515968323 Current Avg. Dice: 0.14748993515968323
Model Was Saved ! Current Best Avg. Dice: 0.17255418002605438 Current Avg. Dice: 0.17255418002605438
Model Was Saved ! Current Best Avg. Dice: 0.1792481243610382 Current Avg. Dice: 0.1792481243610382
Model Was Saved ! Current Best Avg. Dice: 0.1847773939371109 Current Avg. Dice: 0.1847773939371109


[I 2023-12-11 08:43:26,295] Trial 1 finished with value: 0.1847773939371109 and parameters: {'learning_rate': 1.2447803541195326e-05, 'weight_decay': 1.8925251785853914e-06}. Best is trial 0 with value: 0.2765645384788513.


Model Was Saved ! Current Best Avg. Dice: 0.18008962273597717 Current Avg. Dice: 0.18008962273597717
Model Was Saved ! Current Best Avg. Dice: 0.24802164733409882 Current Avg. Dice: 0.24802164733409882
Model Was Saved ! Current Best Avg. Dice: 0.3255198001861572 Current Avg. Dice: 0.3255198001861572
Model Was Saved ! Current Best Avg. Dice: 0.35230299830436707 Current Avg. Dice: 0.35230299830436707
Model Was Saved ! Current Best Avg. Dice: 0.37923064827919006 Current Avg. Dice: 0.37923064827919006


[I 2023-12-11 09:07:24,554] Trial 2 finished with value: 0.37923064827919006 and parameters: {'learning_rate': 0.00016710975741455968, 'weight_decay': 2.2818388889240655e-06}. Best is trial 2 with value: 0.37923064827919006.


Model Was Saved ! Current Best Avg. Dice: 0.15660706162452698 Current Avg. Dice: 0.15660706162452698
Model Was Saved ! Current Best Avg. Dice: 0.18313857913017273 Current Avg. Dice: 0.18313857913017273
Model Was Saved ! Current Best Avg. Dice: 0.21326899528503418 Current Avg. Dice: 0.21326899528503418
Model Was Saved ! Current Best Avg. Dice: 0.24154575169086456 Current Avg. Dice: 0.24154575169086456
Model Was Saved ! Current Best Avg. Dice: 0.2856752872467041 Current Avg. Dice: 0.2856752872467041


[I 2023-12-11 09:31:28,280] Trial 3 finished with value: 0.2856752872467041 and parameters: {'learning_rate': 3.949931476946886e-05, 'weight_decay': 6.425866209968602e-06}. Best is trial 2 with value: 0.37923064827919006.


Model Was Saved ! Current Best Avg. Dice: 0.11024899780750275 Current Avg. Dice: 0.11024899780750275
Model Was Saved ! Current Best Avg. Dice: 0.12684547901153564 Current Avg. Dice: 0.12684547901153564
Model Was Saved ! Current Best Avg. Dice: 0.13275131583213806 Current Avg. Dice: 0.13275131583213806
Model Was Saved ! Current Best Avg. Dice: 0.13383108377456665 Current Avg. Dice: 0.13383108377456665
Model Was Saved ! Current Best Avg. Dice: 0.13615910708904266 Current Avg. Dice: 0.13615910708904266


[I 2023-12-11 09:55:26,773] Trial 4 finished with value: 0.13615910708904266 and parameters: {'learning_rate': 4.802585515817835e-06, 'weight_decay': 6.362109261074949e-05}. Best is trial 2 with value: 0.37923064827919006.


Model Was Saved ! Current Best Avg. Dice: 0.08310196548700333 Current Avg. Dice: 0.08310196548700333
Model Was Saved ! Current Best Avg. Dice: 0.11384280025959015 Current Avg. Dice: 0.11384280025959015
Model Was Saved ! Current Best Avg. Dice: 0.12309964746236801 Current Avg. Dice: 0.12309964746236801
Model Was Saved ! Current Best Avg. Dice: 0.12577804923057556 Current Avg. Dice: 0.12577804923057556
Model Was Saved ! Current Best Avg. Dice: 0.12759996950626373 Current Avg. Dice: 0.12759996950626373


[I 2023-12-11 10:19:27,349] Trial 5 finished with value: 0.12759996950626373 and parameters: {'learning_rate': 1.356189384915904e-06, 'weight_decay': 3.340329439843701e-06}. Best is trial 2 with value: 0.37923064827919006.


Model Was Saved ! Current Best Avg. Dice: 0.08405785262584686 Current Avg. Dice: 0.08405785262584686
Model Was Saved ! Current Best Avg. Dice: 0.11446747928857803 Current Avg. Dice: 0.11446747928857803
Model Was Saved ! Current Best Avg. Dice: 0.12733714282512665 Current Avg. Dice: 0.12733714282512665
Model Was Saved ! Current Best Avg. Dice: 0.1373184770345688 Current Avg. Dice: 0.1373184770345688
Model Was Saved ! Current Best Avg. Dice: 0.14117325842380524 Current Avg. Dice: 0.14117325842380524


[I 2023-12-11 10:43:24,548] Trial 6 finished with value: 0.14117325842380524 and parameters: {'learning_rate': 1.8432292380554347e-06, 'weight_decay': 4.624292579977696e-05}. Best is trial 2 with value: 0.37923064827919006.


Model Was Saved ! Current Best Avg. Dice: 0.17551763355731964 Current Avg. Dice: 0.17551763355731964
Model Was Saved ! Current Best Avg. Dice: 0.2519058287143707 Current Avg. Dice: 0.2519058287143707
Model Was Saved ! Current Best Avg. Dice: 0.3105911314487457 Current Avg. Dice: 0.3105911314487457
Model Was Saved ! Current Best Avg. Dice: 0.38026162981987 Current Avg. Dice: 0.38026162981987
Model Was Saved ! Current Best Avg. Dice: 0.39890754222869873 Current Avg. Dice: 0.39890754222869873


[I 2023-12-11 11:07:22,624] Trial 7 finished with value: 0.39890754222869873 and parameters: {'learning_rate': 0.0003330356752774575, 'weight_decay': 7.128698747473174e-06}. Best is trial 7 with value: 0.39890754222869873.


Model Was Saved ! Current Best Avg. Dice: 0.166367769241333 Current Avg. Dice: 0.166367769241333
Model Was Saved ! Current Best Avg. Dice: 0.1990254819393158 Current Avg. Dice: 0.1990254819393158
Model Was Saved ! Current Best Avg. Dice: 0.26705077290534973 Current Avg. Dice: 0.26705077290534973
Model Was Saved ! Current Best Avg. Dice: 0.3656631112098694 Current Avg. Dice: 0.3656631112098694
Model Was Saved ! Current Best Avg. Dice: 0.39128023386001587 Current Avg. Dice: 0.39128023386001587


[I 2023-12-11 11:30:15,344] Trial 8 finished with value: 0.39128023386001587 and parameters: {'learning_rate': 0.00012142440916681518, 'weight_decay': 2.6195031036105918e-06}. Best is trial 7 with value: 0.39890754222869873.


Model Was Saved ! Current Best Avg. Dice: 0.17619851231575012 Current Avg. Dice: 0.17619851231575012
Model Was Saved ! Current Best Avg. Dice: 0.264008104801178 Current Avg. Dice: 0.264008104801178
Model Was Saved ! Current Best Avg. Dice: 0.3408814072608948 Current Avg. Dice: 0.3408814072608948
Model Was Saved ! Current Best Avg. Dice: 0.3781793713569641 Current Avg. Dice: 0.3781793713569641
Model Was Saved ! Current Best Avg. Dice: 0.3869604170322418 Current Avg. Dice: 0.3869604170322418


[I 2023-12-11 11:53:23,822] Trial 9 finished with value: 0.3869604170322418 and parameters: {'learning_rate': 0.00038315554203506325, 'weight_decay': 7.0416760141919685e-06}. Best is trial 7 with value: 0.39890754222869873.


Model Was Saved ! Current Best Avg. Dice: 0.17668679356575012 Current Avg. Dice: 0.17668679356575012
Model Was Saved ! Current Best Avg. Dice: 0.2176225632429123 Current Avg. Dice: 0.2176225632429123
Model Was Saved ! Current Best Avg. Dice: 0.31870120763778687 Current Avg. Dice: 0.31870120763778687
Model Was Saved ! Current Best Avg. Dice: 0.3497554659843445 Current Avg. Dice: 0.3497554659843445
Model Was Saved ! Current Best Avg. Dice: 0.40037065744400024 Current Avg. Dice: 0.40037065744400024


[I 2023-12-11 12:17:27,458] Trial 10 finished with value: 0.40037065744400024 and parameters: {'learning_rate': 0.0009649251413105664, 'weight_decay': 1.7889241681570157e-05}. Best is trial 10 with value: 0.40037065744400024.


Model Was Saved ! Current Best Avg. Dice: 0.17532141506671906 Current Avg. Dice: 0.17532141506671906
Model Was Saved ! Current Best Avg. Dice: 0.27209362387657166 Current Avg. Dice: 0.27209362387657166
Model Was Saved ! Current Best Avg. Dice: 0.3231876790523529 Current Avg. Dice: 0.3231876790523529
Model Was Not Saved ! Current Best Avg. Dice: 0.3231876790523529 Current Avg. Dice: 0.3141504228115082
Model Was Saved ! Current Best Avg. Dice: 0.39293235540390015 Current Avg. Dice: 0.39293235540390015


[I 2023-12-11 12:41:32,799] Trial 11 finished with value: 0.39293235540390015 and parameters: {'learning_rate': 0.0009564590139544854, 'weight_decay': 1.1591862714075823e-05}. Best is trial 10 with value: 0.40037065744400024.


Best trial: {'learning_rate': 0.0009649251413105664, 'weight_decay': 1.7889241681570157e-05}


In [11]:
if directory is None:
    shutil.rmtree(root_dir)