In [2]:
import os
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.config import print_config
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

from monai.metrics import DiceMetric
from monai.networks.nets import UNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

print_config()


MONAI version: 1.4.dev2414
Numpy version: 1.26.4
Pytorch version: 2.2.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 5b248f6a0dd29cb9c2a9545f980a88de16a6b753
MONAI __file__: /home/<username>/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.13.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.7.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN V

In [3]:
json_path = os.path.normpath("./dataset_0_init_6.json")
data_dir = os.path.normpath("./BTCV")
logdir = os.path.normpath("./logs/fine/")

if os.path.exists(logdir) is False:
    os.mkdir(logdir)

In [4]:
use_pretrained = True
pretrained_path = os.path.normpath("./logs/best_model_32.pt")

In [17]:
# Training Hyper-params
lr = 1e-4
max_iterations = 30000
eval_num = 100

# Transforms
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # RandCropByPosNegLabeld(
        #     keys=["image", "label"],
        #     label_key="label",
        #     spatial_size=(32, 32, 32),
        #     pos=1,
        #     neg=1,
        #     num_samples=4,
        #     image_key="image",
        #     image_threshold=0,
        # ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(keys=["image"], a_min=-175,
                             a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ]
)




In [18]:
datalist = load_decathlon_datalist(
    base_dir=data_dir, data_list_file_path=json_path, is_segmentation=True, data_list_key="training"
)

val_files = load_decathlon_datalist(
    base_dir=data_dir, data_list_file_path=json_path, is_segmentation=True, data_list_key="validation"
)

train_ds = CacheDataset(
    data=datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=2,
)
train_loader = DataLoader(train_ds, batch_size=1,
                          shuffle=True, num_workers=4, pin_memory=True)
val_ds = CacheDataset(data=val_files, transform=val_transforms,
                      cache_num=6, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1,
                        shuffle=False, num_workers=4, pin_memory=True)

# Sanity check for shapes from data loaders
case_num = 0
img = val_ds[case_num]["image"]
label = val_ds[case_num]["label"]
img_shape = img.shape
label_shape = label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")


Loading dataset: 100%|██████████| 24/24 [00:02<00:00, 11.28it/s]
Loading dataset: 100%|██████████| 6/6 [00:00<00:00, 10.85it/s]

image shape: torch.Size([1, 241, 241, 57]), label shape: torch.Size([1, 241, 241, 57])





In [19]:
for case_num in range(len(val_ds)):
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

image shape: torch.Size([1, 241, 241, 57]), label shape: torch.Size([1, 241, 241, 57])
image shape: torch.Size([1, 306, 306, 57]), label shape: torch.Size([1, 306, 306, 57])
image shape: torch.Size([1, 253, 253, 45]), label shape: torch.Size([1, 253, 253, 45])
image shape: torch.Size([1, 300, 300, 45]), label shape: torch.Size([1, 300, 300, 45])
image shape: torch.Size([1, 227, 198, 31]), label shape: torch.Size([1, 227, 198, 31])
image shape: torch.Size([1, 240, 202, 31]), label shape: torch.Size([1, 240, 202, 31])
image shape: torch.Size([1, 213, 213, 37]), label shape: torch.Size([1, 213, 213, 37])
image shape: torch.Size([1, 213, 213, 37]), label shape: torch.Size([1, 213, 213, 37])
image shape: torch.Size([1, 266, 266, 28]), label shape: torch.Size([1, 266, 266, 28])
image shape: torch.Size([1, 267, 267, 46]), label shape: torch.Size([1, 267, 267, 46])
image shape: torch.Size([1, 226, 184, 31]), label shape: torch.Size([1, 226, 184, 31])
image shape: torch.Size([1, 241, 241, 49]),

In [22]:
case_num = 60
img = train_ds[case_num]["image"]
label = train_ds[case_num]["label"]
img_shape = img.shape
label_shape = label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")


image shape: torch.Size([1, 253, 245, 22]), label shape: torch.Size([1, 253, 245, 22])


In [24]:
print(len(train_ds))


66


In [None]:
device = torch.device("cpu") # current GPU cannot handle this

model = UNETR(
    in_channels=1,
    out_channels=14,
    img_size=(32, 32, 32),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="conv",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
)

# Load ViT backbone weights into UNETR
if use_pretrained is True:
    print("Loading Weights from the Path {}".format(pretrained_path))
    vit_dict = torch.load(pretrained_path)
    vit_weights = vit_dict["state_dict"]

    # Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
    # For example, some variables names like conv3d_transpose.weight, conv3d_transpose.bias,
    # conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
    # while pretraining with ViTAutoEnc and are not a part of ViT backbone.
    model_dict = model.vit.state_dict()

    vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
    model_dict.update(vit_weights)
    model.vit.load_state_dict(model_dict)
    del model_dict, vit_weights, vit_dict
    print("Pretrained Weights Succesfully Loaded !")

elif use_pretrained is False:
    print("No weights were loaded, all weights being used are randomly initialized!")

model.to(device)

loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True,
                         reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []


Loading Weights from the Path logs/best_model_32.pt
Pretrained Weights Succesfully Loaded !


: 

In [None]:
def validation(epoch_iterator_val):
    model.eval()
    dice_vals = []

    with torch.no_grad():
        for _step, batch in enumerate(epoch_iterator_val):
            val_inputs, val_labels = (
                batch["image"], batch["label"])
            val_outputs = sliding_window_inference(
                val_inputs, (32, 32, 32), 4, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(
                val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor)
                                  for val_pred_tensor in val_outputs_list]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            dice = dice_metric.aggregate().item()
            dice_vals.append(dice)
            epoch_iterator_val.set_description(
                "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, dice))

        dice_metric.reset()

    mean_dice_val = np.mean(dice_vals)
    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
        x, y = (batch["image"], batch["label"])
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss))

        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(
                val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)

            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(
                    logdir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )

            plt.figure(1, (12, 6))
            plt.subplot(1, 2, 1)
            plt.title("Iteration Average Loss")
            x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
            y = epoch_loss_values
            plt.xlabel("Iteration")
            plt.plot(x, y)
            plt.grid()
            plt.subplot(1, 2, 2)
            plt.title("Val Mean Dice")
            x = [eval_num * (i + 1) for i in range(len(metric_values))]
            y = metric_values
            plt.xlabel("Iteration")
            plt.plot(x, y)
            plt.grid()
            plt.savefig(os.path.join(logdir, "btcv_finetune_quick_update.png"))
            plt.clf()
            plt.close(1)

        global_step += 1
    return global_step, dice_val_best, global_step_best


while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(
    os.path.join(logdir, "best_metric_model.pth")))

print(
    f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")


Training (8 / 30000 Steps) (loss=3.65203):  14%|█▎        | 9/66 [00:09<01:01,  1.08s/it]


RuntimeError: Caught RuntimeError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/transform.py", line 141, in apply_transform
    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/transform.py", line 98, in _apply_transform
    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/croppad/dictionary.py", line 997, in __call__
    self.randomize(d.get(self.label_key), fg_indices, bg_indices, d.get(self.image_key))
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/croppad/dictionary.py", line 979, in randomize
    self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image)
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 1158, in randomize
    self.centers = generate_pos_neg_label_crop_centers(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/utils.py", line 617, in generate_pos_neg_label_crop_centers
    centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/utils.py", line 541, in correct_crop_centers
    raise ValueError(
ValueError: The size of the proposed random crop ROI is larger than the image size, got ROI size (32, 32, 32) and label image size (253, 245, 22) respectively.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/transform.py", line 141, in apply_transform
    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/transform.py", line 98, in _apply_transform
    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/compose.py", line 335, in __call__
    result = execute_compose(
             ^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/compose.py", line 111, in execute_compose
    data = apply_transform(
           ^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/transform.py", line 171, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.croppad.dictionary.RandCropByPosNegLabeld object at 0x7fad3198f5d0>

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/data/dataset.py", line 112, in __getitem__
    return self._transform(index)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/data/dataset.py", line 916, in _transform
    return super()._transform(index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/data/dataset.py", line 98, in _transform
    return apply_transform(self.transform, data_i) if self.transform is not None else data_i
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erattakulangara/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/transforms/transform.py", line 171, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7fac89c0e090>


: 

: 

: 