In [1]:
#Load the libraries
import os
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.config import print_config
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
    Resized
)

from monai.metrics import DiceMetric
from monai.networks.nets import UNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

print_config()


MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.2.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/.conda/envs/unetSSL/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.12.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.1
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSI

### Setup the file input and output locations

In [2]:
logdir = os.path.normpath("./logs/fine/")

if os.path.exists(logdir) is False:
    os.mkdir(logdir)

In [16]:
#Load the pre-trained model
use_pretrained = True
pretrained_path = os.path.normpath("./logs/best_model_Synth_500.pt")


In [17]:
#Convert the train and validation images into a list with locations

train_dir = "./Synth3D"
val_dir = "./Synth3DVal"

#train image file
timage_filenames = sorted([os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.startswith("im")])
tlabel_filenames = sorted([os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.startswith("seg")])

#validation image files
vimage_filenames = sorted([os.path.join(val_dir, f) for f in os.listdir(val_dir) if f.startswith("im")])
vlabel_filenames = sorted([os.path.join(val_dir, f) for f in os.listdir(val_dir) if f.startswith("seg")])

# Create a list of dictionaries containing the file paths
train_datalist = [{"image": img, "label": lbl} for img, lbl in zip(vimage_filenames, vlabel_filenames)]
validation_datalist = [{"image": img, "label": lbl} for img, lbl in zip(vimage_filenames, vlabel_filenames)]

# Print the datalist to verify
print(train_datalist, validation_datalist)

[{'image': './Synth3DVal/im0.nii.gz', 'label': './Synth3DVal/seg0.nii.gz'}, {'image': './Synth3DVal/im1.nii.gz', 'label': './Synth3DVal/seg1.nii.gz'}, {'image': './Synth3DVal/im2.nii.gz', 'label': './Synth3DVal/seg2.nii.gz'}, {'image': './Synth3DVal/im3.nii.gz', 'label': './Synth3DVal/seg3.nii.gz'}, {'image': './Synth3DVal/im4.nii.gz', 'label': './Synth3DVal/seg4.nii.gz'}, {'image': './Synth3DVal/im5.nii.gz', 'label': './Synth3DVal/seg5.nii.gz'}, {'image': './Synth3DVal/im6.nii.gz', 'label': './Synth3DVal/seg6.nii.gz'}, {'image': './Synth3DVal/im7.nii.gz', 'label': './Synth3DVal/seg7.nii.gz'}, {'image': './Synth3DVal/im8.nii.gz', 'label': './Synth3DVal/seg8.nii.gz'}, {'image': './Synth3DVal/im9.nii.gz', 'label': './Synth3DVal/seg9.nii.gz'}] [{'image': './Synth3DVal/im0.nii.gz', 'label': './Synth3DVal/seg0.nii.gz'}, {'image': './Synth3DVal/im1.nii.gz', 'label': './Synth3DVal/seg1.nii.gz'}, {'image': './Synth3DVal/im2.nii.gz', 'label': './Synth3DVal/seg2.nii.gz'}, {'image': './Synth3DVal

### Train and validation transforms

In [29]:
# Training Hyper-params
lr = 1e-4
max_iterations = 1000
eval_num = 100

# Transforms
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image", "label"],
        #     pixdim=(1.5, 1.5, 4.0),
        #     mode=("bilinear", "nearest"),
        # ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(32, 32, 32),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,),
        # ),
        # RandFlipd(
        #     keys=["image", "label"],
        #     spatial_axis=[0],
        #     prob=0.10,
        # ),
        # RandFlipd(
        #     keys=["image", "label"],
        #     spatial_axis=[1],
        #     prob=0.10,
        # ),
        # RandFlipd(
        #     keys=["image", "label"],
        #     spatial_axis=[2],
        #     prob=0.10,
        # ),
        # RandRotate90d(
        #     keys=["image", "label"],
        #     prob=0.10,
        #     max_k=3,
        # ),
        # RandShiftIntensityd(
        #     keys=["image"],
        #     offsets=0.10,
        #     prob=0.50,
        # ),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation transforms
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image", "label"],
        #     pixdim=(1.5, 1.5, 2.0),
        #     mode=("bilinear", "nearest"),
        # # ),
        # ScaleIntensityRanged(keys=["image"], a_min=-175,
        #                      a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ]
)


### Dataloaders for train and validation

In [30]:

train_ds = CacheDataset(
    data=train_datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=2,
)
train_loader = DataLoader(train_ds, batch_size=1,
                          shuffle=True, num_workers=4, pin_memory=True)
val_ds = CacheDataset(data=validation_datalist, transform=val_transforms,
                      cache_num=6, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1,
                        shuffle=False, num_workers=4, pin_memory=True)


Loading dataset: 100%|██████████| 10/10 [00:00<00:00, 63.28it/s]
Loading dataset: 100%|██████████| 6/6 [00:00<00:00, 58.52it/s]


In [31]:
# just conforming the image sizes for the data inside the folder

for case_num in range(len(val_ds)):
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

image shape: torch.Size([1, 57, 57, 57]), label shape: torch.Size([1, 57, 57, 57])
image shape: torch.Size([1, 57, 55, 55]), label shape: torch.Size([1, 57, 55, 55])
image shape: torch.Size([1, 59, 60, 60]), label shape: torch.Size([1, 59, 60, 60])
image shape: torch.Size([1, 57, 57, 58]), label shape: torch.Size([1, 57, 57, 58])
image shape: torch.Size([1, 53, 53, 53]), label shape: torch.Size([1, 53, 53, 53])
image shape: torch.Size([1, 60, 60, 60]), label shape: torch.Size([1, 60, 60, 60])
image shape: torch.Size([1, 57, 57, 57]), label shape: torch.Size([1, 57, 57, 57])
image shape: torch.Size([1, 45, 46, 45]), label shape: torch.Size([1, 45, 46, 45])
image shape: torch.Size([1, 53, 54, 53]), label shape: torch.Size([1, 53, 54, 53])
image shape: torch.Size([1, 58, 57, 57]), label shape: torch.Size([1, 58, 57, 57])


### Network

In [32]:
device = torch.device("cpu") # current GPU cannot handle this

model = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(32, 32, 32),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="conv",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
)

# Load ViT backbone weights into UNETR
if use_pretrained is True:
    print("Loading Weights from the Path {}".format(pretrained_path))
    vit_dict = torch.load(pretrained_path)
    vit_weights = vit_dict["state_dict"]
    model_dict = model.vit.state_dict()

    vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
    model_dict.update(vit_weights)
    model.vit.load_state_dict(model_dict)
    del model_dict, vit_weights, vit_dict
    print("Pretrained Weights Succesfully Loaded !")

elif use_pretrained is False:
    print("No weights were loaded, all weights being used are randomly initialized!")

model.to(device)

loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True,
                         reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []


Loading Weights from the Path logs/best_model_Synth_500.pt
Pretrained Weights Succesfully Loaded !


### Train and validation 

In [33]:
def validation(epoch_iterator_val):
    model.eval()
    dice_vals = []

    with torch.no_grad():
        for _step, batch in enumerate(epoch_iterator_val):
            val_inputs, val_labels = (
                batch["image"], batch["label"])
            val_outputs = sliding_window_inference(
                val_inputs, (32, 32, 32), 4, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(
                val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor)
                                  for val_pred_tensor in val_outputs_list]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            dice = dice_metric.aggregate().item()
            dice_vals.append(dice)
            epoch_iterator_val.set_description(
                "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, dice))

        dice_metric.reset()

    mean_dice_val = np.mean(dice_vals)
    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
        x, y = (batch["image"], batch["label"])
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss))

        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(
                val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)

            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(
                    logdir, "best_metric_model_synth_1000.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )

            plt.figure(1, (12, 6))
            plt.subplot(1, 2, 1)
            plt.title("Iteration Average Loss")
            x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
            y = epoch_loss_values
            plt.xlabel("Iteration")
            plt.plot(x, y)
            plt.grid()
            plt.subplot(1, 2, 2)
            plt.title("Val Mean Dice")
            x = [eval_num * (i + 1) for i in range(len(metric_values))]
            y = metric_values
            plt.xlabel("Iteration")
            plt.plot(x, y)
            plt.grid()
            plt.savefig(os.path.join(logdir, "btcv_finetune_quick_update_synth.png"))
            plt.clf()
            plt.close(1)

        global_step += 1
    return global_step, dice_val_best, global_step_best


while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(
    os.path.join(logdir, "best_metric_model_synth_1000.pth")))

print(
    f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")


Training (9 / 1000 Steps) (loss=1.05583): 100%|██████████| 10/10 [00:14<00:00,  1.43s/it]
Training (19 / 1000 Steps) (loss=0.88059): 100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
Training (29 / 1000 Steps) (loss=0.74115): 100%|██████████| 10/10 [00:18<00:00,  1.87s/it]
Training (39 / 1000 Steps) (loss=0.70103): 100%|██████████| 10/10 [00:22<00:00,  2.27s/it]
Training (49 / 1000 Steps) (loss=0.64819): 100%|██████████| 10/10 [00:18<00:00,  1.85s/it]
Training (59 / 1000 Steps) (loss=0.63909): 100%|██████████| 10/10 [00:18<00:00,  1.83s/it]
Training (69 / 1000 Steps) (loss=0.63445): 100%|██████████| 10/10 [00:18<00:00,  1.87s/it]
Training (79 / 1000 Steps) (loss=0.47958): 100%|██████████| 10/10 [00:18<00:00,  1.81s/it]
Training (89 / 1000 Steps) (loss=0.49246): 100%|██████████| 10/10 [00:18<00:00,  1.89s/it]
Training (99 / 1000 Steps) (loss=0.48077): 100%|██████████| 10/10 [00:17<00:00,  1.79s/it]
Validate (100 / 10 Steps) (dice=0.94254): 100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Model Was Saved ! Current Best Avg. Dice: 0.9392373144626618 Current Avg. Dice: 0.9392373144626618


Training (109 / 1000 Steps) (loss=0.56608): 100%|██████████| 10/10 [00:31<00:00,  3.13s/it]
Training (119 / 1000 Steps) (loss=0.48768): 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Training (129 / 1000 Steps) (loss=0.45849): 100%|██████████| 10/10 [00:17<00:00,  1.78s/it]
Training (139 / 1000 Steps) (loss=0.59224): 100%|██████████| 10/10 [00:18<00:00,  1.86s/it]
Training (149 / 1000 Steps) (loss=0.42707): 100%|██████████| 10/10 [00:18<00:00,  1.90s/it]
Training (159 / 1000 Steps) (loss=0.40931): 100%|██████████| 10/10 [00:17<00:00,  1.71s/it]
Training (169 / 1000 Steps) (loss=0.41161): 100%|██████████| 10/10 [00:18<00:00,  1.81s/it]
Training (179 / 1000 Steps) (loss=0.40412): 100%|██████████| 10/10 [00:19<00:00,  1.93s/it]
Training (189 / 1000 Steps) (loss=0.35481): 100%|██████████| 10/10 [00:19<00:00,  1.93s/it]
Training (199 / 1000 Steps) (loss=0.43654): 100%|██████████| 10/10 [00:18<00:00,  1.85s/it]
Validate (200 / 10 Steps) (dice=0.94182): 100%|██████████| 10/10 [00:10<00:00,  

Model Was Not Saved ! Current Best Avg. Dice: 0.9392373144626618 Current Avg. Dice: 0.9372654438018799


Training (209 / 1000 Steps) (loss=0.43049): 100%|██████████| 10/10 [00:29<00:00,  2.95s/it]
Training (219 / 1000 Steps) (loss=0.33585): 100%|██████████| 10/10 [00:17<00:00,  1.73s/it]
Training (229 / 1000 Steps) (loss=0.42044): 100%|██████████| 10/10 [00:17<00:00,  1.77s/it]
Training (239 / 1000 Steps) (loss=0.35464): 100%|██████████| 10/10 [00:17<00:00,  1.74s/it]
Training (249 / 1000 Steps) (loss=0.56065): 100%|██████████| 10/10 [00:17<00:00,  1.77s/it]
Training (259 / 1000 Steps) (loss=0.30176): 100%|██████████| 10/10 [00:19<00:00,  1.98s/it]
Training (269 / 1000 Steps) (loss=0.28971): 100%|██████████| 10/10 [00:17<00:00,  1.78s/it]
Training (279 / 1000 Steps) (loss=0.29634): 100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
Training (289 / 1000 Steps) (loss=0.25047): 100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
Training (299 / 1000 Steps) (loss=0.27577): 100%|██████████| 10/10 [00:18<00:00,  1.82s/it]
Validate (300 / 10 Steps) (dice=0.94829): 100%|██████████| 10/10 [00:11<00:00,  

Model Was Saved ! Current Best Avg. Dice: 0.943657511472702 Current Avg. Dice: 0.943657511472702


Training (309 / 1000 Steps) (loss=0.33832): 100%|██████████| 10/10 [00:32<00:00,  3.24s/it]
Training (319 / 1000 Steps) (loss=0.28621): 100%|██████████| 10/10 [00:18<00:00,  1.80s/it]
Training (X / X Steps) (loss=X.X):   0%|          | 0/10 [00:00<?, ?it/s]