## Import libraries

In [1]:
from typing import Any, List, Optional, Sequence, Union

import os 
import glob
import numpy as np

import torch

from pytorch_lightning import (
    LightningDataModule, 
    LightningModule, 
    Trainer, 
)
from pytorch_lightning.callbacks import ModelCheckpoint

from monai.data import CacheDataset, DataLoader, list_data_collate
from monai.transforms import (
    Activationsd,
    AddChanneld,
    AsDiscreted,
    BatchInverseTransform,
    Compose,
    CropForegroundd,
    Invertd,
    LoadImaged,
    Orientationd,
    Rand3DElasticd,
    RandAffined,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandGaussianNoised,
    RandShiftIntensityd,
    RandZoomd,
    SaveImaged, 
    ScaleIntensityRanged,
    Spacingd,
    ToTensord,
)
from monai.transforms import AsDiscrete
from monai.networks.nets import UNet
from monai.networks.layers import Act, Norm 
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, FocalLoss
from monai.inferers import sliding_window_inference
from monai.config import print_config

In [2]:
print_config()

MONAI version: 0.6.dev2123
Numpy version: 1.19.2
Pytorch version: 1.8.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 448c2527ad0c39aa7eb8a345ebd71e7f7985d374

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 8.2.0
Tensorboard version: 2.4.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.9.1
ITK version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.61.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Define DataModule

In [3]:
class KeriDataModule(LightningDataModule):
    """
    Example of LightningDataModule for MNIST dataset.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_dir: str = "data/",
        train_batch_size: int = 2,
        val_batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = False,
        pix_dim: Sequence[float] = (1.5, 1.5, 2),
        spatial_size: Union[Sequence[int],int] = (96, 96, 96),
        **kwargs,
    ):
        super().__init__()

        self.data_dir = data_dir
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.pix_dim = pix_dim
        self.spatial_size = spatial_size

        # define the data transforms
        self.train_transforms = Compose(
                [
                    # load the NIfTI files
                    LoadImaged(keys=["image", "label"]),
                    # convert image to "channel-first" format
                    AddChanneld(keys=["image", "label"]),
                    # resample to a consistent voxel format
                    Spacingd(
                        keys=["image", "label"],
                        pixdim=self.pix_dim, 
                        mode=("bilinear", "nearest"),
                    ),
                    # reorientate volumes to have a consistent axes orientation
                    Orientationd(keys=["image", "label"], axcodes="RAS"),
                    
                    ScaleIntensityRanged(
                        keys=["image"], a_min=-57, a_max=164,
                        b_min=0.0, b_max=1.0, clip=True,
                    ),
                    
                    CropForegroundd(keys=["image", "label"], source_key="image"),
                    # randomly crop out patch samples from
                    # big image based on pos / neg ratio
                    # the image centers of negative samples
                    # must be in valid image area
                    RandCropByPosNegLabeld(
                        keys=["image", "label"],
                        label_key="label",
                        spatial_size=self.spatial_size,
                        pos=1,
                        neg=1,
                        num_samples=4,
                        image_key="image",
                        image_threshold=0,
                    ),
                    ToTensord(keys=["image", "label"]),
                ]
        )


        self.val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AddChanneld(keys=["image", "label"]),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=self.pix_dim,
                    mode=("bilinear", "nearest"),
                ),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                ScaleIntensityRanged(
                    keys=["image"], a_min=-57, a_max=164,
                    b_min=0.0, b_max=1.0, clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                ToTensord(keys=["image", "label"]),
            ]
        )



    def setup(self, stage: Optional[str] = None):
        
        # Load data -> Assign train/val datasets for use in dataloaders
        
        # set up the correct data path
        train_images = sorted(
            glob.glob(os.path.join(self.data_dir, "imagesTr", "*.nii.gz")))
        train_labels = sorted(
            glob.glob(os.path.join(self.data_dir, "labelsTr", "*.nii.gz")))
        
        data_dicts = [
            {"image": image_name, "label": label_name}
            for image_name, label_name in zip(train_images, train_labels)
        ]
        self.train_files, self.val_files = data_dicts[:-9], data_dicts[-9:]

        # we use cached datasets - these are 10x faster than regular datasets
        self.val_ds = CacheDataset(
                data=self.val_files, transform=self.val_transforms,
                cache_rate=1.0, num_workers=4,
            )

    def train_dataloader(self):
        # apply transforms to train dataset in train_dataloader for data augmentation
        # setup() is not run when we reload the train and val dataloaders
        train_ds = CacheDataset(
            data=self.train_files, transform=self.train_transforms,
            cache_rate=1.0, num_workers=4,
        )

        train_loader = torch.utils.data.DataLoader(
            dataset=train_ds, 
            batch_size=self.train_batch_size, 
            num_workers=self.num_workers, 
            collate_fn=list_data_collate,
            pin_memory=self.pin_memory,
            shuffle=True,
        )
        return train_loader


    def val_dataloader(self):
        val_loader = torch.utils.data.DataLoader(
            dataset=self.val_ds, 
            batch_size=self.val_batch_size, 
            num_workers=self.num_workers,
            pin_memory=self.pin_memory, 
            shuffle=False
            )
        return val_loader

    def predict_dataloader(self):
        pred_loader = torch.utils.data.DataLoader(
            dataset=self.val_ds, 
            batch_size=self.val_batch_size, 
            num_workers=self.num_workers,
            pin_memory=self.pin_memory, 
            shuffle=False
            )

        self.post_transforms = Compose([
            AsDiscreted(keys="pred", argmax=True, to_onehot=False, n_classes=3),
            Invertd(
                keys="pred",  # invert the `pred` data field, also support multiple fields
                transform=self.val_transforms,
                loader=pred_loader,
                orig_keys="image",  # get the previously applied pre_transforms information on the `img` data field,
                                  # then invert `pred` based on this information. we can use same info
                                  # for multiple fields, also support different orig_keys for different fields
                meta_keys="pred_meta_dict",  # key field to save inverted meta data, every item maps to `keys`
                orig_meta_keys="image_meta_dict",  # get the meta data from `img_meta_dict` field when inverting,
                                                 # for example, may need the `affine` to invert `Spacingd` transform,
                                                 # multiple fields can use the same meta data to invert
                meta_key_postfix="meta_dict",  # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                                               # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                                               # otherwise, no need this arg during inverting
                nearest_interp=True,  # change to use "nearest" mode in interpolation when inverting
                to_tensor=True,  # convert to PyTorch Tensor after inverting
            ),
            SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
        ])
        return pred_loader

## Define model

In [4]:
class KeriLitModel(LightningModule):
    """
    Example of LightningModule for MNIST classification.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.BATCH,
        dropout: float = 0.0,
        lr: float = 1e-4,
        loss_function = "dice_loss",
        **kwargs
    ):
        super().__init__()

        # this line ensures params passed to LightningModule will be saved to ckpt
        # it also allows to access params with 'self.hparams' attribute
        self.save_hyperparameters()

        self._model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=3, # 3 classes = background + trapezium + metacarpial
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            kernel_size=3,
            up_kernel_size=3,
            num_res_units=self.hparams.num_res_units,
            act=self.hparams.act,
            norm=self.hparams.norm,
            dropout=self.hparams.dropout,
        )
        
        # define loss function
        if self.hparams.loss_function == "dice_loss":
            self.loss_function = DiceLoss(include_background=True,
                to_onehot_y=True, softmax=True)
        
        elif self.hparams.loss_function == "focal_loss_2":
            # give twice more weight to the trapezium class (class 1)
            # than to the background or the metacarpus (classes 0 & 2)
            self.loss_function = FocalLoss(include_background = True, 
                to_onehot_y = True, weight = [1.0, 2.0, 1.0])

        else:
            # give three times weight to the trapezium class (class 1)
            # than to the background or the metacarpus (classes 0 & 2)
            self.loss_function = FocalLoss(include_background = True, 
                to_onehot_y = True, weight = [1.0, 3.0, 1.0])

        # execute after model forward to transform model output  to discrete values
        self.post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=3)
        self.post_label = AsDiscrete(to_onehot=True, n_classes=3)

        self.dice_metric = DiceMetric(include_background=False, 
            reduction="mean")


    def forward(self, batch: Any):
        images = batch["image"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        logits = sliding_window_inference(
            images, roi_size, sw_batch_size, self._model
        )
        return logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), self.hparams.lr)
        return optimizer   

    def training_step(self, batch: Any, batch_idx: int):
        images, labels = batch["image"], batch["label"]
        
        # in Lightning, we recommend to separate training from inference
        # therefore, we don't call the forward method during training
        logits = self._model(images)
        loss = self.loss_function(logits, labels)
        self.log('train/loss', loss, on_step=False, on_epoch=True)
        
        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        images, labels = batch["image"], batch["label"]
        logits = self.forward(batch)
        loss = self.loss_function(logits, labels)
        self.log('val/loss', loss, on_step=False, on_epoch=True)
        
        # binarize predictions and convert them to onehot format for DiceMetric
        outputs = self.post_pred(logits)
        labels = self.post_label(labels)
        value, not_nans = self.dice_metric(
            y_pred=outputs,
            y=labels,
        )
        not_nans = not_nans.item()
        return {"val_loss": loss, "val_dice": value, "not_nans": not_nans}

    def validation_epoch_end(self, outputs: List[Any]):
        val_dice, num_items = 0, 0
        for output in outputs:
            val_dice += output["val_dice"].item() * output["not_nans"]
            num_items += output["not_nans"]
        mean_val_dice = torch.tensor(val_dice / num_items)
        self.log('val/dice', mean_val_dice, prog_bar=True)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]=None):
        batch["pred"] = self.forward(batch)
        self.trainer.datamodule.post_transforms(batch)
    

## Train model

You don't need to run the two following cells if you already have a checkpoint. 

In [4]:
# initialize datamodule
data_dir = "/media/diane/Shared Data/Data/Spleen_Data/Task09_Spleen"
data = KeriDataModule(data_dir=data_dir)

# initialize model
model = KeriLitModel()

In [None]:
# setup checkpoint callback
checkpoint_callback = ModelCheckpoint(monitor="val/dice",
                                      save_top_k=1,
                                      save_last=True,
                                      mode="max",
                                      dirpath='/media/diane/Shared Data/Data/Spleen_Data/checkpoints/')

# initialize trainer
trainer = Trainer(gpus=1, precision=32, amp_backend="native",
                  amp_level="02", progress_bar_refresh_rate=20,
                  max_epochs=300, callbacks=[checkpoint_callback])

# train
trainer.fit(model=model, datamodule=data)

## Make predictions

In [5]:
# load model from checkpoint
CKPT_PATH = "/home/diane/Documents/Invertd_demo/epoch=195-step=3135.ckpt"
trained_model = KeriLitModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)

# initialize datamodule
data_dir = "/media/diane/Shared Data/Data/Spleen_Data/Task09_Spleen"
data = KeriDataModule(data_dir=data_dir)

# initialize trainer for predict
trainer = Trainer(gpus=1, precision=32, amp_backend="native",
                  amp_level="02", progress_bar_refresh_rate=20)

# train
trainer.predict(model=trained_model, datamodule=data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Loading dataset: 100%|██████████| 9/9 [00:06<00:00,  1.45it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]
=== Transform input info -- method ===
image statistics:
Type: <class 'numpy.ndarray'>
Shape: (1, 226, 157, 113)
Value range: (0.0, 2.0)
image_transforms statistics:
Type: <class 'list'>
Value: [{'class': 'Spacingd', 'id': 140359533164240, 'orig_size': [512, 512, 33], 'extra_info': {'meta_key': 'image_meta_dict', 'old_affine': tensor([[   0.6992,    0.0000,    0.0000, -357.3009],
        [   0.0000,    0.6992,    0.0000, -357.3009],
        [   0.0000,    0.0000,    7.0000,    0.0000],
        [   0.0000,    0.0000,    0.0000,    1.0000]], device='cuda:0',
       dtype=torch.float64), 'mode': 'nearest', 'padding_mode': 'border', 'align_corners': 'none'}}, {'class': 'Orientationd', 'id': 140359533164336, 'orig_size': [239, 239, 113], 'extra_info': {'meta_key': 'image_meta_dict', 'old_affine': tensor([[   1.5000,    0.0000,    0.0000, -357.3009],
        [   0.0000,    1.5000,    0.0000, -357.3009],
        [   0.0000,    0.0000,    2.0000,    0.0000],
    

RuntimeError: applying transform <monai.transforms.post.dictionary.Invertd object at 0x7fa7fc7a8b20>