In [1]:
import os
import json
import time
import torch
import matplotlib.pyplot as plt

from torch.nn import L1Loss
from monai.utils import set_determinism, first
from monai.networks.nets import ViTAutoEnc
from monai.losses import ContrastiveLoss
from monai.data import DataLoader, Dataset
from monai.config import print_config
from monai.transforms import (
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd,
    OneOf,
    ScaleIntensityRanged,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled,
)

print_config()


MONAI version: 1.4.dev2414
Numpy version: 1.26.4
Pytorch version: 2.2.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 5b248f6a0dd29cb9c2a9545f980a88de16a6b753
MONAI __file__: /home/<username>/virtenvs/SSLUnet/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.13.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.7.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN V

In [2]:
json_path = os.path.normpath("./split_data.json")
data_root = os.path.normpath("./PKG - CT-Covid19-August2020-V1/")
logdir_path = os.path.normpath("./logs/")


In [3]:
if os.path.exists(logdir_path) is False:
    os.mkdir(logdir_path)

with open(json_path, "r") as json_f:
    json_data = json.load(json_f)

train_data = json_data["training"]
val_data = json_data["validation"]

for idx, _each_d in enumerate(train_data):
    train_data[idx]["image"] = os.path.join(
        data_root, train_data[idx]["image"])

for idx, _each_d in enumerate(val_data):
    val_data[idx]["image"] = os.path.join(data_root, val_data[idx]["image"])

print("Total Number of Training Data Samples: {}".format(len(train_data)))
print(train_data)
print("#" * 10)
print("Total Number of Validation Data Samples: {}".format(len(val_data)))
print(val_data)
print("#" * 10)

# Set Determinism
set_determinism(seed=123)


Total Number of Training Data Samples: 520
[{'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0176.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0587.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0467.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0184_1.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0213.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0391.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0256_1.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0236.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0284.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0553.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0077.nii.gz'}, {'image': 'PKG - CT-Covid19-August2020-V1/data/volume-covid19-A-0670.nii.gz'}, {'im

In [4]:
# Define Training Transforms
train_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(32, 32, 32)),
        RandSpatialCropSamplesd(keys=["image"], roi_size=(
            32, 32, 32), random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=[
                   "gt_image", "image_2"], allow_missing_keys=False),
        OneOf(
            transforms=[
                RandCoarseDropoutd(
                    keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
                ),
                RandCoarseDropoutd(
                    keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
                ),
            ]
        ),
        RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
        # Please note that that if image, image_2 are called via the same transform call because of the determinism
        # they will get augmented the exact same way which is not the required case here, hence two calls are made
        OneOf(
            transforms=[
                RandCoarseDropoutd(
                    keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
                ),
                RandCoarseDropoutd(
                    keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
                ),
            ]
        ),
        RandCoarseShuffled(keys=["image_2"], prob=0.8,
                           holes=10, spatial_size=8),
    ]
)


check_ds = Dataset(data=train_data, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image = check_data["image"][0][0]
print(f"image shape: {image.shape}")




image shape: torch.Size([32, 32, 32])


In [7]:
# Training Config

# Define Network ViT backbone & Loss & Optimizer
device = torch.device("cpu")
model = ViTAutoEnc(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(16, 16, 16),
    proj_type="conv",
    hidden_size=768,
    mlp_dim=3072,
)

model = model.to(device)

# Define Hyper-paramters for training loop
max_epochs = 100
val_interval = 2
batch_size = 4
lr = 1e-4
epoch_loss_values = []
step_loss_values = []
epoch_cl_loss_values = []
epoch_recon_loss_values = []
val_loss_values = []
best_val_loss = 1000.0

recon_loss = L1Loss()
contrastive_loss = ContrastiveLoss(temperature=0.05)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


# Define DataLoader using MONAI, CacheDataset needs to be used
train_ds = Dataset(data=train_data, transform=train_transforms)
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

val_ds = Dataset(data=val_data, transform=train_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size,
                        shuffle=True, num_workers=1)


In [8]:
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    epoch_cl_loss = 0
    epoch_recon_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        start_time = time.time()

        inputs, inputs_2, gt_input = (
            batch_data["image"].to(device),
            batch_data["image_2"].to(device),
            batch_data["gt_image"].to(device),
        )
        optimizer.zero_grad()
        outputs_v1, hidden_v1 = model(inputs)
        outputs_v2, hidden_v2 = model(inputs_2)

        flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
        flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

        r_loss = recon_loss(outputs_v1, gt_input)
        cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

        # Adjust the CL loss by Recon Loss
        total_loss = r_loss + cl_loss * r_loss

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        step_loss_values.append(total_loss.item())

        # CL & Recon Loss Storage of Value
        epoch_cl_loss += cl_loss.item()
        epoch_recon_loss += r_loss.item()

        end_time = time.time()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {total_loss.item():.4f}, "
            f"time taken: {end_time-start_time}s"
        )

    epoch_loss /= step
    epoch_cl_loss /= step
    epoch_recon_loss /= step

    epoch_loss_values.append(epoch_loss)
    epoch_cl_loss_values.append(epoch_cl_loss)
    epoch_recon_loss_values.append(epoch_recon_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if epoch % val_interval == 0:
        print("Entering Validation for epoch: {}".format(epoch + 1))
        total_val_loss = 0
        val_step = 0
        model.eval()
        for val_batch in val_loader:
            val_step += 1
            start_time = time.time()
            inputs, gt_input = (
                val_batch["image"].to(device),
                val_batch["gt_image"].to(device),
            )
            print("Input shape: {}".format(inputs.shape))
            outputs, outputs_v2 = model(inputs)
            val_loss = recon_loss(outputs, gt_input)
            total_val_loss += val_loss.item()
            end_time = time.time()

        total_val_loss /= val_step
        val_loss_values.append(total_val_loss)
        print(
            f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

        if total_val_loss < best_val_loss:
            print(
                f"Saving new model based on validation loss {total_val_loss:.4f}")
            best_val_loss = total_val_loss
            checkpoint = {"epoch": max_epochs, "state_dict": model.state_dict(
            ), "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, os.path.join(logdir_path, "best_model_32.pt"))

        plt.figure(1, figsize=(8, 8))
        plt.subplot(2, 2, 1)
        plt.plot(epoch_loss_values)
        plt.grid()
        plt.title("Training Loss")

        plt.subplot(2, 2, 2)
        plt.plot(val_loss_values)
        plt.grid()
        plt.title("Validation Loss")

        plt.subplot(2, 2, 3)
        plt.plot(epoch_cl_loss_values)
        plt.grid()
        plt.title("Training Contrastive Loss")

        plt.subplot(2, 2, 4)
        plt.plot(epoch_recon_loss_values)
        plt.grid()
        plt.title("Training Recon Loss")

        plt.savefig(os.path.join(logdir_path, "loss_plots.png"))
        plt.close(1)

print("Done")


----------
epoch 1/100
1/130, train_loss: 1.0361, time taken: 2.914113759994507s
2/130, train_loss: 0.8577, time taken: 0.9740486145019531s
3/130, train_loss: 0.7127, time taken: 0.887160062789917s
4/130, train_loss: 0.8338, time taken: 0.8886680603027344s
5/130, train_loss: 0.7300, time taken: 0.7869091033935547s
6/130, train_loss: 0.7450, time taken: 0.6051368713378906s
7/130, train_loss: 0.7372, time taken: 1.8034999370574951s
8/130, train_loss: 0.6568, time taken: 0.9007139205932617s
9/130, train_loss: 1.0088, time taken: 1.3712882995605469s
10/130, train_loss: 0.6792, time taken: 0.9545109272003174s
11/130, train_loss: 0.9428, time taken: 0.8415629863739014s
12/130, train_loss: 0.5523, time taken: 0.858116865158081s
13/130, train_loss: 0.7515, time taken: 0.7572906017303467s
14/130, train_loss: 0.7768, time taken: 0.7852821350097656s
15/130, train_loss: 0.6727, time taken: 0.7954914569854736s
16/130, train_loss: 0.8000, time taken: 0.8133792877197266s
17/130, train_loss: 0.7617, t

KeyboardInterrupt: 

In [None]:
import torch
import gc
gc.collect()
torch.cuda.empty_cache()


In [None]:
import torch
gpu_available = torch.cuda.is_available()
if gpu_available:
    print("GPU is available.")
else:
    print("GPU is not available. Using CPU for computation.")


GPU is available.
