In [1]:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import yaml
from torch.utils.data import DataLoader

from model.model import get_model
from model.ddim import Classifier

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from monai.apps import DecathlonDataset
from monai import transforms

In [2]:
# relevant for vs code
import matplotlib
matplotlib.use('Qt5Agg')
%matplotlib inline

## Load the config

In [3]:
with open('./configs/ddim_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Reproducibility
pl.seed_everything(config['seed'])

Seed set to 42


42

## Load Data for Classification
Use a different dataset for brain MRI which provides labels for disease classification.
We use the dataset as the paper.  

In [4]:
# get home directoray
home_dir = os.path.expanduser('~')
root_dir = os.path.join(home_dir, 'Downloads', 'datasets')
assert os.path.exists(root_dir), f'root_dir {root_dir} does not exist'

In [5]:
channel = 0  # 0 = Flair
assert channel in [0, 1, 2, 3], "Choose a valid channel"

train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.EnsureChannelFirstd(keys=["image", "label"]),
        transforms.Lambdad(keys=["image"], func=lambda x: x[channel, :, :, :]),
        transforms.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        transforms.EnsureTyped(keys=["image", "label"]),
        transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
        transforms.Spacingd(keys=["image", "label"], pixdim=(3.0, 3.0, 2.0), mode=("bilinear", "nearest")),
        transforms.CenterSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 44)),
        transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1),
        transforms.RandSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 1), random_size=False),
        transforms.Lambdad(keys=["image", "label"], func=lambda x: x.squeeze(-1)),
        transforms.CopyItemsd(keys=["label"], times=1, names=["slice_label"]),
        transforms.Lambdad(keys=["slice_label"], func=lambda x: 0.0 if x.sum() > 0 else 1.0),
    ]
)

In [6]:
batch_size = config['batch_size']

train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    section="training",  # validation
    cache_rate=3.0,  # you may need a few Gb of RAM... Set to 0 otherwise
    num_workers=4,
    download=False,  # Set download to True if the dataset hasnt been downloaded yet
    seed=0,
    transform=train_transforms,
)

print(f"Length of training data: {len(train_ds)}")  # this gives the number of patients in the training set
print(f'Train image shape {train_ds[0]["image"].shape}')

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)

Loading dataset:   0%|          | 0/388 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 388/388 [03:26<00:00,  1.88it/s]

Length of training data: 388
Train image shape torch.Size([1, 64, 64])





In [7]:
val_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    section="validation",
    cache_rate=3.0,  # you may need a few Gb of RAM... Set to 0 otherwise
    num_workers=4,
    download=False,  # Set download to True if the dataset hasnt been downloaded yet
    seed=0,
    transform=train_transforms,
)
print(f"Length of training data: {len(val_ds)}")
print(f'Validation Image shape {val_ds[0]["image"].shape}')

val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)

Loading dataset:   0%|          | 0/96 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 96/96 [00:55<00:00,  1.74it/s]

Length of training data: 96
Validation Image shape torch.Size([1, 64, 64])





## Tensorboard

In [8]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 25180), started 1 day, 10:41:15 ago. (Use '!kill 25180' to kill it.)

## Prepare model

In [9]:
# Init DDIM model
model = Classifier(config)

# Use tensorboard logger and CSV logger
trainer = pl.Trainer(
    max_epochs=config['num_epochs_cls'],
    logger=[
        pl.loggers.TensorBoardLogger(save_dir='./'),
        pl.loggers.CSVLogger(save_dir='./')
    ],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Run training

In [10]:
# Train the model ⚡
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# save the model
trainer.save_checkpoint('./checkpoints/best_classifier_model.ckpt')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                  | Params
-----------------------------------------------------
0 | classifier | DiffusionModelEncoder | 2.4 M 
1 | scheduler  | DDIMScheduler         | 0     
-----------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.630     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 999: 100%|██████████| 6/6 [00:01<00:00,  3.50it/s, v_num=2_23]

`Trainer.fit` stopped: `max_epochs=1000` reached.


Epoch 999: 100%|██████████| 6/6 [00:01<00:00,  3.49it/s, v_num=2_23]
