# Self-Supervised Learning with Swin Transformer AutoEncoder on the AutoPET PET-CT Dataset

This notebook demonstrates self-supervised learning for medical imaging using a Swin Transformer-based AutoEncoder, following the approach described by Tang et al. [1]. The model is pre-trained on the AutoPET dataset [2], which consists of whole-body FDG-PET/CT images with manually annotated tumor lesions. The aim is to leverage both PET and CT modalities in a self-supervised framework to learn robust feature representations beneficial for downstream medical image analysis tasks.

---
## Prerequisites

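The following packages are required to run the notebook:

- `monai`
- `torch`
- `fire`
- `protobuf==3.20`
- `einops`
- `pytorch-ignite`
- `mlflow`
- `tensorboard`
- `matplotlib`
- `pyyaml`

The notebook also imports helper modules (`src.networks`, `src.utils`, `src.training`, `src.metrics`) from a local `LymphomaDetection` checkout, which must be present in the working directory. Training metrics are logged to an MLflow tracking server expected at `http://localhost:5000`.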

## References

[1] Y. Tang et al., ‘Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis’, arXiv [cs.LG]. 2021.  
[2] S. Gatidis and T. Kuestner, ‘A whole-body FDG-PET/CT dataset with manually annotated tumor lesions (FDG-PET-CT-Lesions)’. The Cancer Imaging Archive, 2022.

In [None]:
from monai.data import load_decathlon_datalist, Dataset, DataLoader
from monai.data.utils import partition_dataset
from pathlib import Path
from monai.utils import set_determinism, first
from monai.transforms import (
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd,
    OneOf,
    ScaleIntensityRanged,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled,
)
import sys
sys.path.append("LymphomaDetection")
import json
import os
import time
import torch
import matplotlib.pyplot as plt

import torch

from torch.nn import L1Loss
from monai.losses import ContrastiveLoss
from src.networks import SwinAutoEnc

from monai.handlers import (
    StatsHandler,
    from_engine,
    MeanDice,
    ValidationHandler,
    LrScheduleHandler,
    CheckpointSaver,
    CheckpointLoader,
    TensorBoardStatsHandler,
    MLFlowHandler,
    IgniteMetricHandler,
    TensorBoardImageHandler
)
from monai.engines import SupervisedTrainer, SupervisedEvaluator

from src.utils import create_image_list, threshold_CT, prepare_batch, prepare_val_batch, mlflow_transform, tb_batch_transform, tb_output_transform

from src.training import iteration

from src.metrics import TotalLoss, recon_loss_transform, recon_val_loss_transform, contrastive_loss_transform, AMPContrastiveLoss
import matplotlib.pyplot as plt

In [None]:
# Root directory of the AutoPET dataset; the preprocessed volumes and the
# datalist JSON used below are located underneath it.
data_dir = "Data/AutoPET/PSMA-FDG-PET-CT-Lesions"

In [None]:
# Directory with the preprocessed volumes; used as base_dir when resolving the
# relative paths stored in the datalist.
preprocess_dir = Path(data_dir).joinpath("preprocessed")

In [None]:
# Read the "training" section of the decathlon-style datalist, resolving its
# relative file paths against the preprocessed-data directory.
datalist_path = Path(data_dir) / "preprocess.json"
datalist = load_decathlon_datalist(datalist_path, data_list_key="training", base_dir=preprocess_dir)

In [None]:
# Convert the datalist entries into the per-sample dicts consumed by the
# transforms (project helper from src.utils — presumably yields dicts with an
# "image" key, as the transforms below expect; confirm against its definition).
image_data_list = create_image_list(datalist)

In [None]:
# 80/20 train/validation split. The shuffle is seeded with seed=0 (MONAI's
# default for partition_dataset), so the split is reproducible across runs.
train_data, val_data = partition_dataset(
    data=image_data_list,
    ratios=[0.8, 0.2],
    shuffle=True,
    seed=0,
)

In [None]:
patch_size = (96, 96, 96)

# Fill-value range handed to RandCoarseDropoutd for the dropped-out voxels.
fill_value = (0, 0.2)


def _masking_augmentations(key):
    """Build the random masking augmentations for the single view stored under ``key``.

    Exactly one of two RandCoarseDropoutd variants is applied:
    ``dropout_holes=True`` with holes of 5–32 voxels, or ``dropout_holes=False``
    with holes of 20–64 voxels. It is followed by a RandCoarseShuffled
    (probability 0.8) over ten 8-voxel patches.
    """
    return [
        OneOf(
            transforms=[
                RandCoarseDropoutd(
                    keys=[key], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32, fill_value=fill_value
                ),
                RandCoarseDropoutd(
                    keys=[key], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64, fill_value=fill_value
                ),
            ]
        ),
        RandCoarseShuffled(keys=[key], prob=0.8, holes=10, spatial_size=8),
    ]


# Training transforms:
# - Load the volume and crop it to the foreground selected by threshold_CT on
#   channel 0.
# - Pad, then take two random patches per volume.
# - Copy each patch into "gt_image" (kept unmasked; presumably the
#   reconstruction target) and "image_2".
# - Mask "image" and "image_2" independently.
# Resampling (Spacingd) and intensity windowing (ScaleIntensityRanged) are not
# applied here — presumably handled during preprocessing; confirm.
train_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        CropForegroundd(keys=["image"], source_key="image", channel_indices=[0], select_fn=threshold_CT),
        SpatialPadd(keys=["image"], spatial_size=patch_size),
        RandSpatialCropSamplesd(keys=["image"], roi_size=patch_size, random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
        # Separate transform instances per view: one transform applied to
        # keys=["image", "image_2"] would reuse a single random draw and mask
        # both views identically, which is not wanted here.
        *_masking_augmentations("image"),
        *_masking_augmentations("image_2"),
    ]
)


In [None]:
# Define the train/validation DataLoaders. A plain (non-caching) MONAI Dataset
# is used, so the whole transform chain — including LoadImaged — runs again
# every time an item is fetched.
# RandSpatialCropSamplesd yields 2 patches per item; MONAI's default collate
# presumably flattens them, giving batch_size * 2 patches per batch — confirm.
batch_size = 2
train_ds = Dataset(data=train_data, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

# NOTE(review): validation reuses train_transforms (random crops and masking),
# so the validation loss is stochastic from run to run — confirm this is intended.
val_ds = Dataset(data=val_data, transform=train_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
# Sanity check: push one training item through the full transform chain and
# inspect the shape of the resulting patch.
check_ds = Dataset(data=train_data, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

first_patch = check_data["image"][0]
image = first_patch[0]  # channel 0 of the first patch in the batch
print(f"image shape: {image.shape}")

In [None]:
# Plot the three central orthogonal slices of one patch/channel.
# Rows: masked view "image", masked view "image_2", unmasked "gt_image".
batch_id = 1
channel_id = 0
fig, axs = plt.subplots(3, 3, figsize=(20 * 3, 20))

volume_shape = check_data["image"][batch_id][channel_id].shape
x_center = int(volume_shape[-3] / 2)
y_center = int(volume_shape[-2] / 2)
z_center = int(volume_shape[-1] / 2)

for row, key in enumerate(["image", "image_2", "gt_image"]):
    volume = check_data[key][batch_id][channel_id]
    axs[row, 0].imshow(volume[:, :, z_center], cmap='gray')
    axs[row, 1].imshow(volume[:, y_center, :], cmap='gray')
    axs[row, 2].imshow(volume[x_center, :, :], cmap='gray')


In [None]:
# Define the Swin Transformer auto-encoder, move it to the GPU and create the
# optimizer (the losses are defined in the following cell).
device = torch.device("cuda:0")
model = SwinAutoEnc(
    in_chans=2,  # two input channels — presumably PET and CT; confirm against preprocessing
    out_channels=2,
    embed_dim=96,
    window_size=(4, 4, 4),
    patch_size=(2, 2, 2),
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
)
model = model.to(device)

# Define hyper-parameters for the training loop
max_epochs = 500
val_interval = 2

lr = 1e-4
best_val_loss = 1000.0  # NOTE(review): not referenced in the visible code — leftover from a manual training loop?


optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Self-supervised objective: an L1 reconstruction loss plus a contrastive loss
# (temperature 0.05), combined by the project's TotalLoss. How the two terms
# are weighted is defined in src.metrics and is not visible here.
recon_loss = L1Loss()
contrastive_loss = ContrastiveLoss(temperature=0.05)


total_loss = TotalLoss(
    recon_loss=recon_loss,
    contrastive_loss=contrastive_loss,
)

In [None]:
# NOTE(review): `logdir_path` is not referenced in the visible code; the
# handlers use the separately defined `log_dir` ("logs") instead.
logdir_path = os.path.normpath("logs")

In [None]:

# Key validation metric: the L1 reconstruction loss, wrapped as an Ignite
# metric. recon_val_loss_transform selects what is compared — presumably
# (reconstruction, gt_image); see src.metrics.
key_val_metric = IgniteMetricHandler(loss_fn=recon_loss,output_transform=recon_val_loss_transform)

In [None]:
# Output directories: model checkpoints and TensorBoard event files.
ckpt_dir = "models"
log_dir = "logs"

In [None]:
# Handlers for the evaluator: epoch-level console stats, checkpointing on the
# key validation metric, and epoch-level TensorBoard stats.
val_handlers = [
    StatsHandler(iteration_log=False),
    CheckpointSaver(
        save_dir=ckpt_dir,
        save_dict={"model": model, "optimizer_state": optimizer},
        save_interval=1,
        # The key metric is a loss (lower is better); negating it makes the
        # saver keep the checkpoint with the smallest loss.
        key_metric_negative_sign=True,
        save_key_metric=True,
        n_saved=1,
    ),
    TensorBoardStatsHandler(log_dir=log_dir, iteration_log=False),
]

In [None]:
# Also log validation images to TensorBoard; max_channels=2 matches the
# model's two input/output channels. The batch/output transforms (src.utils)
# choose which tensors are drawn.
val_handlers.append(TensorBoardImageHandler(log_dir=log_dir, batch_transform=tb_batch_transform, output_transform=tb_output_transform,max_channels=2))

In [None]:
def _lower_is_better(current_metric, prev_best):
    """Metric comparison for loss-type key metrics (smaller is better).

    MONAI's default ``metric_cmp_fn`` treats larger values as better, which
    would make ``evaluator.state.best_metric`` / ``best_metric_epoch`` (logged
    to MLflow) track the *worst* reconstruction loss. MONAI initialises
    ``state.best_metric`` to -1, and an L1 loss is never negative, so
    ``prev_best < 0`` means "no best value recorded yet".
    """
    return prev_best < 0 or current_metric < prev_best


evaluator = SupervisedEvaluator(
    amp=True,
    device=device,
    # NOTE(review): epoch_length=2 limits each validation run to two batches of
    # the (shuffled) val_loader — confirm this is not a debugging leftover.
    epoch_length=2,
    network=model,
    key_val_metric={"Val_Reconstruction_Loss": key_val_metric},
    metric_cmp_fn=_lower_is_better,
    prepare_batch=prepare_val_batch,
    val_data_loader=val_loader,
    val_handlers=val_handlers,
)

In [None]:
# Training handlers: log the training loss to the console, and run the
# evaluator every `val_interval` epochs (the hyper-parameter defined alongside
# the model) instead of a hard-coded 1. Each validation run also drives
# checkpointing through the evaluator's CheckpointSaver.
train_handlers = [
    StatsHandler(output_transform=from_engine(["loss"], first=True), tag_name="train_loss"),
    ValidationHandler(epoch_level=True, interval=val_interval, validator=evaluator),
]

In [None]:
# Mirror the training loss to TensorBoard under the same tag as the console log.
train_handlers.append(
    TensorBoardStatsHandler(log_dir=log_dir, output_transform=from_engine(["loss"], first=True), tag_name="train_loss")
)

In [None]:
# MLflow tracking: training and validation share one experiment and run name,
# so both handlers log into the same MLflow run.
mlflow_experiment_name = "SwinAutoEnc"
mlflow_run_name = "SwinAutoEnc_Train"
tracking_uri = "http://localhost:5000"

# Training-side handler. The trainer is constructed later from
# `train_handlers`, so appending here is enough for it to be attached.
train_handlers.append(
    MLFlowHandler(
        # `dataset_dict` maps a name to a dataset object (per the MLFlowHandler
        # API), so pass the Dataset wrapping `train_data` rather than the raw list.
        dataset_dict={"train": train_ds},
        dataset_keys="image",
        experiment_name=mlflow_experiment_name,
        output_transform=mlflow_transform,
        run_name=mlflow_run_name,
        state_attributes=["best_metric", "best_metric_epoch"],
        tag_name="Train_Loss",
        tracking_uri=tracking_uri,
    )
)

# Validation-side handler. `evaluator` already exists, and MONAI attaches the
# `val_handlers` passed to SupervisedEvaluator only when it is constructed.
# Appending to the list afterwards would never register this handler, so it
# must be attached explicitly.
val_mlflow_handler = MLFlowHandler(
    experiment_name=mlflow_experiment_name,
    iteration_log=False,
    output_transform=mlflow_transform,
    run_name=mlflow_run_name,
    state_attributes=["best_metric", "best_metric_epoch"],
    tracking_uri=tracking_uri,
)
val_mlflow_handler.attach(evaluator)
# Keep the list an accurate record of the handlers on the evaluator.
val_handlers.append(val_mlflow_handler)

In [None]:
# Training-time metrics reported by the trainer: the reconstruction (L1) loss
# as the key metric and the contrastive loss as an additional metric. They use
# their own loss instances, separate from the ones inside total_loss.
# AMPContrastiveLoss is a project variant of ContrastiveLoss (src.metrics) —
# presumably adapted for mixed-precision outputs; confirm.
recon_loss_metric = L1Loss()
contrastive_loss_metric = AMPContrastiveLoss(temperature=0.05)

additional_metric = IgniteMetricHandler(loss_fn=contrastive_loss_metric,output_transform=contrastive_loss_transform)
train_key_metric = IgniteMetricHandler(loss_fn=recon_loss_metric,output_transform=recon_loss_transform)

In [None]:
# Supervised-style trainer driving the self-supervised step. A custom
# iteration function (src.training.iteration) replaces the default update so
# both masked views and the combined loss can be handled.
# `device` and `max_epochs` reuse the values defined with the model, so
# changing them there takes effect here as well (previously "cuda" and 500
# were hard-coded).
trainer = SupervisedTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_loader,
    network=model,
    optimizer=optimizer,
    loss_function=total_loss,
    inferer=None,
    key_train_metric={"Reconstruction_Loss": train_key_metric},
    train_handlers=train_handlers,
    additional_metrics={"Contrastive_Loss": additional_metric},
    amp=True,
    prepare_batch=prepare_batch,
    iteration_update=iteration,
)

In [None]:
# Start training; validation and checkpointing run through the attached handlers.
trainer.run()

## Run Training as a MONAI Bundle

In [None]:
import json
import os

import yaml


def create_config(config_folder, output_file):
    """Merge every ``*.yaml`` file in *config_folder* into one config dict and save it.

    Files are merged with ``dict.update`` in sorted path order, so when several
    files define the same top-level key, the file whose path sorts last wins.
    Empty YAML files contribute no keys.

    Args:
        config_folder: Directory holding the YAML fragments (not searched
            recursively; only regular files ending in ``.yaml`` are read).
        output_file: Destination path. A ``.yaml`` suffix writes YAML and a
            ``.json`` suffix writes JSON; any other suffix writes nothing.

    Returns:
        The merged configuration dictionary.
    """
    # os.scandir yields entries in an arbitrary, filesystem-dependent order;
    # sorting makes key collisions resolve the same way on every machine.
    with os.scandir(config_folder) as entries:
        config_files = sorted(
            entry.path for entry in entries if entry.is_file() and entry.path.endswith(".yaml")
        )

    config = {}
    for config_file in config_files:
        with open(config_file, "r", encoding="utf-8") as file:
            # safe_load returns None for an empty document; treat it as no keys.
            config.update(yaml.safe_load(file) or {})

    if output_file.endswith(".yaml"):
        with open(output_file, "w", encoding="utf-8") as file:
            yaml.dump(config, file)
    elif output_file.endswith(".json"):
        with open(output_file, "w", encoding="utf-8") as file:
            json.dump(config, file)

    return config

In [None]:
%%bash
# Create the bundle's config directory (-p: no error if it already exists).
mkdir -p LymphomaDetection/SSL/configs/

In [None]:
# Merge the YAML fragments from the bundle template directory into a single
# train.yaml inside the bundle's configs folder.
config = create_config("LymphomaDetection/Bundles/SSL", "LymphomaDetection/SSL/configs/train.yaml")

In [None]:
%%bash
# Copy the project's src package into the bundle directory so the bundle is
# self-contained; the bundle-run cell puts the bundle root on PYTHONPATH.
cp -r LymphomaDetection/src LymphomaDetection/SSL/

In [None]:
%%bash

# Resolve every path to an absolute one *before* changing directory. Relative
# paths stop resolving after `cd "$BUNDLE_ROOT"`: for example,
# $BUNDLE_ROOT/configs/train.yaml would be looked up as
# LymphomaDetection/SSL/LymphomaDetection/SSL/configs/train.yaml.
export BUNDLE_ROOT="$PWD/LymphomaDetection/SSL"
export PYTHONPATH="$BUNDLE_ROOT"
export DATA_FOLDER="$PWD/Data/AutoPET/PSMA-FDG-PET-CT-Lesions"

cd "$BUNDLE_ROOT"
python -m monai.bundle run \
    --bundle_root "$BUNDLE_ROOT" \
    --data_dir "$DATA_FOLDER/preprocessed" \
    --decathlon_data_list "$DATA_FOLDER/preprocess.json" \
    --tracking_uri "http://localhost:5000" \
    --config_file "$BUNDLE_ROOT/configs/train.yaml"