In [1]:
import os
import json
import time
import torch
import matplotlib.pyplot as plt

from torch.nn import L1Loss
from monai.utils import set_determinism, first
from monai.networks.nets import ViTAutoEnc
from monai.losses import ContrastiveLoss
from monai.data import DataLoader, Dataset
from monai.config import print_config
from monai.transforms import (
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd,
    OneOf,
    ScaleIntensityRanged,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled,
)

print_config()


MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.2.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/.conda/envs/unetSSL/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.12.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.1
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSI

In [2]:
### Create synthetic dataset

In [3]:
logdir_path = os.path.normpath("./logs/")

#Convert the train and validation images into a list with locations
train_dir = "./Synth3DTrain"
val_dir = "./Synth3DValTrain"

#train image file
timage_filenames = sorted([os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.startswith("im")])
tlabel_filenames = sorted([os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.startswith("seg")])

#validation image files
vimage_filenames = sorted([os.path.join(val_dir, f) for f in os.listdir(val_dir) if f.startswith("im")])
vlabel_filenames = sorted([os.path.join(val_dir, f) for f in os.listdir(val_dir) if f.startswith("seg")])

# Create a list of dictionaries containing the file paths
train_datalist = [{"image": img, "label": lbl} for img, lbl in zip(vimage_filenames, vlabel_filenames)]
validation_datalist = [{"image": img, "label": lbl} for img, lbl in zip(vimage_filenames, vlabel_filenames)]

# Print the datalist to verify
print(train_datalist, validation_datalist)

[{'image': './Synth3DValTrain/im0.nii.gz', 'label': './Synth3DValTrain/seg0.nii.gz'}, {'image': './Synth3DValTrain/im1.nii.gz', 'label': './Synth3DValTrain/seg1.nii.gz'}, {'image': './Synth3DValTrain/im2.nii.gz', 'label': './Synth3DValTrain/seg2.nii.gz'}, {'image': './Synth3DValTrain/im3.nii.gz', 'label': './Synth3DValTrain/seg3.nii.gz'}, {'image': './Synth3DValTrain/im4.nii.gz', 'label': './Synth3DValTrain/seg4.nii.gz'}, {'image': './Synth3DValTrain/im5.nii.gz', 'label': './Synth3DValTrain/seg5.nii.gz'}, {'image': './Synth3DValTrain/im6.nii.gz', 'label': './Synth3DValTrain/seg6.nii.gz'}, {'image': './Synth3DValTrain/im7.nii.gz', 'label': './Synth3DValTrain/seg7.nii.gz'}, {'image': './Synth3DValTrain/im8.nii.gz', 'label': './Synth3DValTrain/seg8.nii.gz'}, {'image': './Synth3DValTrain/im9.nii.gz', 'label': './Synth3DValTrain/seg9.nii.gz'}] [{'image': './Synth3DValTrain/im0.nii.gz', 'label': './Synth3DValTrain/seg0.nii.gz'}, {'image': './Synth3DValTrain/im1.nii.gz', 'label': './Synth3DVa

In [4]:
# Define Training Transforms
train_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        # Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-157,
            a_max=256,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(32, 32, 32)),
        RandSpatialCropSamplesd(keys=["image"], roi_size=(
            32, 32, 32), random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=[
                   "gt_image", "image_2"], allow_missing_keys=False),
        OneOf(
            transforms=[
                RandCoarseDropoutd(
                    keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
                ),
                RandCoarseDropoutd(
                    keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
                ),
            ]
        ),
        RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
        # Please note that that if image, image_2 are called via the same transform call because of the determinism
        # they will get augmented the exact same way which is not the required case here, hence two calls are made
        OneOf(
            transforms=[
                RandCoarseDropoutd(
                    keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
                ),
                RandCoarseDropoutd(
                    keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
                ),
            ]
        ),
        RandCoarseShuffled(keys=["image_2"], prob=0.8,
                           holes=10, spatial_size=8),
    ]
)


check_ds = Dataset(data=train_datalist, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image = check_data["image"][0][0]
print(f"image shape: {image.shape}")


image shape: torch.Size([32, 32, 32])




In [6]:
# Training Config

# Define Network ViT backbone & Loss & Optimizer
device = torch.device("cpu")
model = ViTAutoEnc(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(16, 16, 16),
    proj_type="conv",
    hidden_size=768,
    mlp_dim=3072,
)

model = model.to(device)

# Define Hyper-paramters for training loop
max_epochs = 500
val_interval = 2
batch_size = 4
lr = 1e-4
epoch_loss_values = []
step_loss_values = []
epoch_cl_loss_values = []
epoch_recon_loss_values = []
val_loss_values = []
best_val_loss = 1000.0

recon_loss = L1Loss()
contrastive_loss = ContrastiveLoss(temperature=0.05)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


# Define DataLoader using MONAI, CacheDataset needs to be used
train_ds = Dataset(data=train_datalist, transform=train_transforms)
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

val_ds = Dataset(data=validation_datalist, transform=train_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size,
                        shuffle=True, num_workers=1)


In [7]:
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    epoch_cl_loss = 0
    epoch_recon_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        start_time = time.time()

        inputs, inputs_2, gt_input = (
            batch_data["image"].to(device),
            batch_data["image_2"].to(device),
            batch_data["gt_image"].to(device),
        )
        optimizer.zero_grad()
        outputs_v1, hidden_v1 = model(inputs)
        outputs_v2, hidden_v2 = model(inputs_2)

        flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
        flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

        r_loss = recon_loss(outputs_v1, gt_input)
        cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

        # Adjust the CL loss by Recon Loss
        total_loss = r_loss + cl_loss * r_loss

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        step_loss_values.append(total_loss.item())

        # CL & Recon Loss Storage of Value
        epoch_cl_loss += cl_loss.item()
        epoch_recon_loss += r_loss.item()

        end_time = time.time()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {total_loss.item():.4f}, "
            f"time taken: {end_time-start_time}s"
        )

    epoch_loss /= step
    epoch_cl_loss /= step
    epoch_recon_loss /= step

    epoch_loss_values.append(epoch_loss)
    epoch_cl_loss_values.append(epoch_cl_loss)
    epoch_recon_loss_values.append(epoch_recon_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if epoch % val_interval == 0:
        print("Entering Validation for epoch: {}".format(epoch + 1))
        total_val_loss = 0
        val_step = 0
        model.eval()
        for val_batch in val_loader:
            val_step += 1
            start_time = time.time()
            inputs, gt_input = (
                val_batch["image"].to(device),
                val_batch["gt_image"].to(device),
            )
            print("Input shape: {}".format(inputs.shape))
            outputs, outputs_v2 = model(inputs)
            val_loss = recon_loss(outputs, gt_input)
            total_val_loss += val_loss.item()
            end_time = time.time()

        total_val_loss /= val_step
        val_loss_values.append(total_val_loss)
        print(
            f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

        if total_val_loss < best_val_loss:
            print(
                f"Saving new model based on validation loss {total_val_loss:.4f}")
            best_val_loss = total_val_loss
            checkpoint = {"epoch": max_epochs, "state_dict": model.state_dict(
            ), "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, os.path.join(logdir_path, "best_model_Synth_500.pt"))

        plt.figure(1, figsize=(8, 8))
        plt.subplot(2, 2, 1)
        plt.plot(epoch_loss_values)
        plt.grid()
        plt.title("Training Loss")

        plt.subplot(2, 2, 2)
        plt.plot(val_loss_values)
        plt.grid()
        plt.title("Validation Loss")

        plt.subplot(2, 2, 3)
        plt.plot(epoch_cl_loss_values)
        plt.grid()
        plt.title("Training Contrastive Loss")

        plt.subplot(2, 2, 4)
        plt.plot(epoch_recon_loss_values)
        plt.grid()
        plt.title("Training Recon Loss")

        plt.savefig(os.path.join(logdir_path, "loss_plots.png"))
        plt.close(1)

print("Done")


----------
epoch 1/500
1/2, train_loss: 3.1335, time taken: 3.6087486743927s
2/2, train_loss: 2.4974, time taken: 1.179062843322754s
3/2, train_loss: 1.9204, time taken: 0.964601993560791s
epoch 1 average loss: 2.5171
Entering Validation for epoch: 1
Input shape: torch.Size([8, 1, 32, 32, 32])
Input shape: torch.Size([8, 1, 32, 32, 32])
Input shape: torch.Size([4, 1, 32, 32, 32])
epoch 1 Validation avg loss: 0.7492, time taken: 0.4864940643310547s
Saving new model based on validation loss 0.7492
----------
epoch 2/500
1/2, train_loss: 2.3736, time taken: 13.566534996032715s
2/2, train_loss: 2.8517, time taken: 2.8003063201904297s
3/2, train_loss: 2.2988, time taken: 1.5801873207092285s
epoch 2 average loss: 2.5080
----------
epoch 3/500
1/2, train_loss: 2.7715, time taken: 9.56297516822815s
2/2, train_loss: 2.7706, time taken: 1.468001365661621s
3/2, train_loss: 1.9593, time taken: 1.4327049255371094s
epoch 3 average loss: 2.5005
Entering Validation for epoch: 3
Input shape: torch.Size