Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Self-Supervised Learning

## Setup environment

## Setup imports

In [1]:
import os
import time
import torch
import matplotlib.pyplot as plt

from torch.nn import L1Loss
from monai.utils import set_determinism, first
from monai.losses import ContrastiveLoss
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd,
    OneOf,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled,
    MapTransform,
)
import os
import glob
import logging
from importlib import reload
from ssl_head import SSLHead

# print_config()

##### Define file paths & output directory path

In [2]:
# Spatial size of every 3-D patch that is cropped/padded by the transforms.
roi_size = (96, 96, 64)

# Dataset root (training volumes are globbed from its "train/" sub-folder)
# and the folder that receives the training log, checkpoints and loss plots.
root_dir, logdir_path = "../PETCT/", "./log/"

##### Create result logging directories, manage data paths & set determinism

In [3]:
# Create the logging directory (including missing parents). exist_ok=True makes
# this a no-op when the directory is already there, so re-running the cell is safe.
os.makedirs(logdir_path, exist_ok=True)

# Collect the training volumes in a stable (sorted) order so the split below is reproducible.
train_images = sorted(glob.glob(root_dir + "train/*petct.nii.gz"))
# train_labels = sorted(glob.glob(root_dir+"train/*seg.nii.gz"))

# data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
# One dictionary per volume. Iterate the list directly: zip() over a single
# list yields 1-tuples, which would store ('path',) instead of the path string.
data_dicts = [{"image": image_name} for image_name in train_images]

# The first 320 entries are used for training, everything after that for validation.
train_files, val_files = data_dicts[:320], data_dicts[320:]
print(len(train_files), end='+')
print(len(val_files))

# Set Determinism
set_determinism(seed=0)

320+80


##### Define MONAI Transforms 

In [14]:
class NormalizeFrom0to1(MapTransform):
    """Min-max rescale the selected items of a data dictionary to [0, 1].

    For every key in ``self.keys`` the stored tensor ``x`` is replaced by
    ``(x - min(x)) / (max(x) - min(x))``, with min and max taken over the
    whole tensor. A constant tensor (``max == min``) is mapped to all zeros
    instead of being divided by zero, which would fill it with NaNs.
    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the selected items rescaled.

        Args:
            data: mapping whose values under ``self.keys`` are tensors
                accepted by ``torch.min`` / ``torch.max``.

        Returns:
            dict: new dictionary; entries under ``self.keys`` are replaced by
            their rescaled version, all other entries are carried over as-is.
        """
        d = dict(data)
        for key in self.keys:
            img = d[key]
            lo = torch.min(img)
            span = torch.max(img) - lo
            # Guard the degenerate case: dividing by one keeps the result
            # finite (all zeros) and keeps the output dtype the same as in
            # the regular branch.
            denom = span if span > 0 else torch.ones_like(span)
            d[key] = (img - lo) / denom
        return d
    
# Define Training Transforms
def _view_corruption(key):
    """Build the corruption steps for the view stored under ``key``.

    Returns a two-element list: a ``OneOf`` over two ``RandCoarseDropoutd``
    variants (``dropout_holes=True`` with small holes, ``dropout_holes=False``
    with larger ones), followed by a ``RandCoarseShuffled``. Every call
    creates fresh transform instances.
    """
    dropout_choice = OneOf(
        transforms=[
            RandCoarseDropoutd(
                keys=[key], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
            ),
            RandCoarseDropoutd(
                keys=[key], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
            ),
        ]
    )
    shuffle = RandCoarseShuffled(keys=[key], prob=0.8, holes=10, spatial_size=8)
    return [dropout_choice, shuffle]


train_transforms = Compose(
    [
        # Load, move channels first and resample to a common voxel spacing.
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(2.0, 2.0, 1.0), mode="bilinear"),
        # A fixed-window intensity scaling step is currently disabled:
        # ScaleIntensityRanged(
        #     keys=["image"],
        #     a_min=-1000,
        #     a_max=1000,
        #     b_min=0.0,
        #     b_max=1.0,
        #     clip=True,
        # ),
        # Crop to the foreground, pad up to roi_size, then draw two fixed-size patches per volume.
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=roi_size),
        RandSpatialCropSamplesd(keys=["image"], roi_size=roi_size, random_size=False, num_samples=2),
        # Min-max rescale each patch, then duplicate it: "gt_image" is not touched by
        # the corruption steps below, "image_2" becomes the second corrupted view.
        NormalizeFrom0to1(keys=["image"]),
        CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
        # Please note that if image and image_2 are passed through the same transform call then, because of
        # the determinism, they get augmented in exactly the same way, which is not what is wanted here;
        # hence a separate set of corruption transforms is built for each of the two keys.
        *_view_corruption("image"),
        *_view_corruption("image_2"),
    ]
)



In [None]:
import numpy as np

# Push the first training item through the transform chain as a quick sanity check.
check_ds = Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

# Report the shape of the collated "image" entry, then pick element [0][0]
# (presumably first sample, first channel - confirm against the printed shape).
batch_images = check_data["image"]
print(batch_images.shape)
image = batch_images[0][0]
print(f"image shape: {image.shape}")

# Distinct intensity values present in the selected volume.
print(np.unique(image))

# One subplot row per slice along the last axis; only the first of the two
# grid columns is filled.
cols = 2
rows = image.shape[2]
fig = plt.figure("check", (4, 100))
for i in range(rows):
    first_column_slot = i * cols + 1
    fig.add_subplot(rows, cols, first_column_slot)
    plt.title("image")
    plt.imshow(image[:, :, i], cmap="gray")
    plt.axis("off")

##### Training Configuration

In [16]:
# Training Config

# Define Network backbone & Loss & Optimizer
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook can still be executed on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Alternative backbone, currently disabled:
# model = ViTAutoEnc(
#     in_channels=1,
#     img_size=roi_size,
#     patch_size=(16, 16, 16),
#     pos_embed="conv",
#     hidden_size=768,
#     mlp_dim=3072,
# )
model = SSLHead()

model = model.to(device)

# Define Hyper-parameters for training loop
max_epochs = 500
val_interval = 2
batch_size = 1
lr = 1e-4

# Per-epoch / per-step histories that the training loop appends to and plots.
epoch_loss_values = []
step_loss_values = []
epoch_cl_loss_values = []
epoch_recon_loss_values = []
val_loss_values = []

# Start from +inf so the first validation pass always produces a checkpoint,
# whatever the scale of the loss.
best_val_loss = float("inf")

# L1 reconstruction loss plus a contrastive loss between the two corrupted views.
recon_loss = L1Loss()
contrastive_loss = ContrastiveLoss(temperature=0.05)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


# Define DataLoader using MONAI; CacheDataset with cache_rate=1.0 caches the whole dataset.
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# NOTE(review): validation reuses train_transforms, so validation inputs go
# through the same random corruption steps as the training inputs.
val_ds = CacheDataset(data=val_files, transform=train_transforms, cache_rate=1.0)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)

Loading dataset: 100%|██████████| 320/320 [02:38<00:00,  2.02it/s]
Loading dataset: 100%|██████████| 80/80 [00:40<00:00,  1.96it/s]


In [11]:
# Restore the model weights from the checkpoint file in the log folder.
# map_location moves the stored tensors onto the device the model lives on,
# so a checkpoint written on one device can be loaded on another.
swinvit_dict = torch.load('./log/best_model.pt', map_location=device)
# The checkpoint is a dictionary; the model parameters are stored under "state_dict".
swinvit_weights = swinvit_dict["state_dict"]
model.load_state_dict(swinvit_weights)

<All keys matched successfully>

##### Training loop with validation

In [None]:
# Reload the logging module before configuring it, so that basicConfig takes
# effect even when the root logger was already configured by an earlier run.
logging.shutdown()
reload(logging)
logging.basicConfig(filename=logdir_path+'training_log.txt', filemode='a', format='%(message)s', level=logging.DEBUG)

# Number of batches per epoch; used only for the progress print-out.
steps_per_epoch = len(train_loader)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    epoch_cl_loss = 0
    epoch_recon_loss = 0
    step = 0
    start_time = time.time()

    for batch_data in train_loader:
        step += 1

        # Two corrupted views of the same patch plus the uncorrupted target.
        inputs, inputs_2, gt_input = (
            batch_data["image"].to(device),
            batch_data["image_2"].to(device),
            batch_data["gt_image"].to(device),
        )
        optimizer.zero_grad()
        outputs_v1 = model(inputs)
        outputs_v2 = model(inputs_2)

        # Collapse dimensions 1..4 so each sample becomes one vector for the contrastive loss.
        flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
        flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

        r_loss = recon_loss(outputs_v1, gt_input)
        cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

        # Adjust the CL loss by Recon Loss
        total_loss = r_loss + cl_loss * r_loss

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        step_loss_values.append(total_loss.item())

        # CL & Recon Loss Storage of Value
        epoch_cl_loss += cl_loss.item()
        epoch_recon_loss += r_loss.item()

        print(
            f"{step}/{steps_per_epoch}, "
            f"train_loss: {total_loss.item():.4f}"
        )

    # Turn the running sums into per-step averages for this epoch.
    epoch_loss /= step
    epoch_cl_loss /= step
    epoch_recon_loss /= step

    epoch_loss_values.append(epoch_loss)
    epoch_cl_loss_values.append(epoch_cl_loss)
    epoch_recon_loss_values.append(epoch_recon_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    logging.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # Validation runs on epochs whose zero-based index is a multiple of val_interval.
    if epoch % val_interval == 0:
        print("Entering Validation for epoch: {}".format(epoch + 1))
        total_val_loss = 0
        val_step = 0
        model.eval()
        # No gradients are needed for validation; disabling autograd avoids
        # building a graph for every forward pass.
        with torch.no_grad():
            for val_batch in val_loader:
                val_step += 1
                inputs, gt_input = (
                    val_batch["image"].to(device),
                    val_batch["gt_image"].to(device),
                )
                print("Input shape: {}".format(inputs.shape))
                # The training step above uses model(inputs) as a single tensor,
                # so it is used the same way here. Unpacking it into two names
                # would split that tensor along its first dimension and compute
                # the loss on a slice of the output.
                outputs = model(inputs)
                val_loss = recon_loss(outputs, gt_input)
                total_val_loss += val_loss.item()

        total_val_loss /= val_step
        val_loss_values.append(total_val_loss)
        print(f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}")

        if total_val_loss < best_val_loss:
            print(f"Saving new model based on validation loss {total_val_loss:.4f}")
            logging.info(f"Saving new model based on validation loss {total_val_loss:.4f}")
            best_val_loss = total_val_loss
            # Record the (1-based) epoch this checkpoint was taken at, not the
            # configured maximum number of epochs.
            checkpoint = {"epoch": epoch + 1, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, os.path.join(logdir_path, "best_model.pt"))

        # Refresh the 2x2 loss overview figure and overwrite the file on disk.
        plt.figure(1, figsize=(8, 8))
        plt.subplot(2, 2, 1)
        plt.plot(epoch_loss_values)
        plt.grid()
        plt.title("Training Loss")

        plt.subplot(2, 2, 2)
        plt.plot(val_loss_values)
        plt.grid()
        plt.title("Validation Loss")

        plt.subplot(2, 2, 3)
        plt.plot(epoch_cl_loss_values)
        plt.grid()
        plt.title("Training Contrastive Loss")

        plt.subplot(2, 2, 4)
        plt.plot(epoch_recon_loss_values)
        plt.grid()
        plt.title("Training Recon Loss")

        plt.savefig(os.path.join(logdir_path, "loss_plots.png"))
        # Close the figure so repeated validation passes do not accumulate open figures.
        plt.close(1)

    end_time = time.time()
    print(f"time taken: {end_time-start_time}s")
print("Done")