# Train a Kidney Segmentation Model with MONAI and PyTorch Lightning

## Kidney Segmentation with PyTorch Lightning and OpenVINO™ - Part 2

This tutorial is part of a series on how to train, optimize, quantize, and show live inference on a medical segmentation model. The goal is to accelerate inference on a kidney segmentation model. The [UNet](https://arxiv.org/abs/1505.04597) model is trained from scratch, and the data is from [KiTS19](https://github.com/neheller/kits19).

This second tutorial in the series shows how to:

- train a 2D segmentation model with MONAI and PyTorch Lightning,
- define a Metric that can be used for both training and quantization,
- visualize training results,
- convert the trained model to OpenVINO Intermediate Representation (IR).

All notebooks in this series:

- [Data Preparation for 2D Segmentation of 3D Medical Data](data-preparation-ct-scan.ipynb)
- Train a 2D-UNet Medical Imaging Model with PyTorch Lightning (this notebook)
- [Convert and Quantize a Segmentation Model and Show Live Inference](110-ct-segmentation-quantize-nncf.ipynb)
- [Live Inference and Benchmark CT-scan data](110-ct-scan-live-inference.ipynb) 

## Instructions

This notebook needs the KiTS19 dataset, prepared according to the instructions in the [Data Preparation for 2D Segmentation of 3D Medical Data](data-preparation-ct-scan.ipynb) tutorial. In the cell below, set `BASEDIR` to the `kits19_frames` directory, which contains one `case_*` subdirectory per CT scan.
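
A quick way to check the expected layout is to mirror the glob patterns that `DataModule.setup()` uses further down. The sketch below builds a tiny fake `kits19_frames` tree in a temporary directory; the file names are hypothetical, and only the `case_*/imaging_frames` and `case_*/segmentation_frames` structure is taken from the notebook code. It also shows that pointing one level too high finds no frames.

```python
import tempfile
from pathlib import Path


def count_frames(basedir: Path) -> tuple[int, int]:
    """Count frames using the same glob patterns as DataModule.setup()."""
    images = sorted(basedir.glob("case_*/imaging_frames/*jpg"))
    segs = sorted(basedir.glob("case_*/segmentation_frames/*png"))
    return len(images), len(segs)


with tempfile.TemporaryDirectory() as tmp:
    basedir = Path(tmp) / "kits19_frames"
    for case in ("case_00000", "case_00001"):
        # Hypothetical file names; only the directory structure matters here.
        (basedir / case / "imaging_frames").mkdir(parents=True)
        (basedir / case / "segmentation_frames").mkdir(parents=True)
        for i in range(3):
            (basedir / case / "imaging_frames" / f"{case}_{i:04d}.jpg").touch()
            (basedir / case / "segmentation_frames" / f"{case}_{i:04d}.png").touch()

    n_images, n_segs = count_frames(basedir)
    # The parent of kits19_frames is one level too high: no case_* directories there.
    empty = count_frames(Path(tmp))

print(n_images, n_segs, empty)
```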

To install the requirements for running this notebook, follow the instructions in the [`README.md` file](README.md).

> **TIP:** Training the model can take a long time. If you want to run the code with a script instead of in the notebook, you can export the notebook to a Python script, using the following command: `jupyter nbconvert --TagRemovePreprocessor.remove_cell_tags=hide --to script pytorch-monai-training.ipynb`. This will export the code without the visualization and TensorBoard cells.


### Table of content:
- [Settings](#Settings-Uparrow)
- [Imports](#Imports-Uparrow)
- [Define Dataset and DataModule](#Define-Dataset-and-DataModule-Uparrow)
- [PyTorch Lightning Model](#PyTorch-Lightning-Model-Uparrow)
- [TensorBoard](#TensorBoard-Uparrow)
- [Train the Model](#Train-the-Model-Uparrow)
- [Show Inference on the Trained Model](#Show-Inference-on-the-Trained-Model-Uparrow)
- [Convert PyTorch model to OpenVINO IR](#Convert-PyTorch-model-to-OpenVINO-IR-Uparrow)
- [Next Steps](#Next-Steps-Uparrow)

## Settings [$\Uparrow$](#Table-of-content:)

In [None]:
%pip install -q "monai>=0.9.1,<1.0.0" nibabel pytorch_lightning "openvino>=2023.1.0"

In [None]:
from pathlib import Path

import torch

# The directory with the dataset, as prepared in the data preparation notebook.
BASEDIR = Path("~/kits19/kits19_frames/").expanduser()
# Use CUDA for training if it is available. This requires an NVIDIA GPU, installed CUDA drivers
# and a PyTorch version with CUDA enabled.
USE_CUDA = torch.cuda.is_available()
# A directory where the saved model will be stored.
MODEL_DIR = "model"

# Check that BASEDIR contains imaging and segmentation frames.
assert len(list(BASEDIR.glob("**/segmentation_frames*/*png"))) > 0, f"No segmentation frames found in {BASEDIR}"
assert len(list(BASEDIR.glob("**/imaging_frames*/*jpg"))) > 0, f"No imaging frames found in {BASEDIR}"

## Imports [$\Uparrow$](#Table-of-content:)


In [None]:
import datetime
import random
import time

import dateutil.relativedelta
import matplotlib.pyplot as plt
import monai
import pytorch_lightning as pl
import torch
from monai.data import ImageDataset
from monai.transforms import (
    AddChannel,
    Compose,
    EnsureType,
    LabelToMask,
    RandGaussianNoise,
    RandHistogramShift,
    RandRotate,
    Resize,
)
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from skimage import measure
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryF1Score

## Define Dataset and DataModule [$\Uparrow$](#Table-of-content:)

In this step, create a MONAI [`ImageDataset`](https://docs.monai.io/en/latest/data.html#imagedataset) to load and transform the data, and a [PyTorch Lightning DataModule](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html) to access this data during training. Because `image_only=False`, the dataset returns each item as a tuple `(image, mask, image_metadata, mask_metadata)`.

Use [MONAI Transforms](https://docs.monai.io/en/latest/transforms.html) to augment the training data: randomly rotate the images, add Gaussian noise, and shift the intensity histogram. Noise and histogram shifts are applied only to the image; the rotation is applied to both the image and the mask. MONAI's `ImageDataset` uses the same random seed for the image and segmentation mask transforms, so an image and its mask are rotated in the same way. The validation transforms apply no augmentation; they only add a channel dimension and ensure the correct data type.
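
The seed synchronization can be illustrated with a simplified model in plain Python. This is not MONAI code: `RandomDraw` and `SimplePipeline` are hypothetical stand-ins. The sketch assumes, as we understand MONAI's behavior, that `ImageDataset` draws one seed per item and passes it to both transform chains, and that `Compose` derives one child seed per random transform, in order. Under that assumption the rotation gets the same seed in both chains because it is the first random transform in each, even though the image chain contains extra random transforms after it.

```python
import random

MAX_SEED = 2**32


class RandomDraw:
    """Hypothetical stand-in for a randomizable transform: it only returns the value it draws."""

    def __init__(self, name, low, high):
        self.name, self.low, self.high = name, low, high
        self.rng = random.Random()

    def set_random_state(self, seed):
        self.rng = random.Random(seed)

    def __call__(self):
        return self.name, self.rng.uniform(self.low, self.high)


class SimplePipeline:
    """Simplified Compose: derives one child seed per random transform, in order, from a parent seed."""

    def __init__(self, transforms):
        self.transforms = transforms

    def set_random_state(self, seed):
        parent = random.Random(seed)
        for transform in self.transforms:
            transform.set_random_state(parent.randrange(MAX_SEED))

    def __call__(self):
        return [transform() for transform in self.transforms]


# The rotation is the first random transform in both pipelines, so it receives the same child seed.
image_pipeline = SimplePipeline([RandomDraw("rotate", -0.78, 0.78), RandomDraw("noise", 0.0, 0.1)])
mask_pipeline = SimplePipeline([RandomDraw("rotate", -0.78, 0.78)])

dataset_rng = random.Random(0)
results = []
for _ in range(3):
    item_seed = dataset_rng.randrange(MAX_SEED)  # one seed per dataset item
    image_pipeline.set_random_state(item_seed)
    mask_pipeline.set_random_state(item_seed)
    results.append((image_pipeline(), mask_pipeline()))

for image_draws, mask_draws in results:
    print(image_draws[0], mask_draws[0])  # identical rotation angles for image and mask
```

In this model, reordering the transforms so that the rotation is no longer the first random transform in one of the chains would break the alignment.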

In [None]:
class DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir

    def setup(self, stage=None):
        random.seed(1.414213)
        data_path = Path(self.data_dir)
        segs = sorted(data_path.glob("case_*/segmentation_frames/*png"))
        images = sorted(data_path.glob("case_*/imaging_frames/*jpg"))
        val_indices = [7, 10, 14, 15, 30, 60, 71, 74, 75, 81, 92, 93,
                       106, 115, 117, 119, 134, 161, 192, 196, 203]  # fmt: skip
        test_indices = [2, 5, 37, 103, 174]
        train_indices = set(list(range(210))) - set(val_indices) - set(test_indices)

        val_cases = [f"case_{i:05d}" for i in val_indices]
        train_cases = [f"case_{i:05d}" for i in train_indices]

        train_segs = [seg for seg in segs if Path(seg).parents[1].name in train_cases]
        val_segs = [seg for seg in segs if Path(seg).parents[1].name in val_cases]

        train_images = [im for im in images if Path(im).parents[1].name in train_cases]
        val_images = [im for im in images if Path(im).parents[1].name in val_cases]

        # Define transforms for an image and a segmentation mask.
        train_imtrans = Compose(
            [
                AddChannel(),
                Resize((512, 512)),
                RandRotate(0.78, prob=0.5),
                RandGaussianNoise(prob=0.5),
                RandHistogramShift(),
                EnsureType(),
            ]
        )
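        # Use nearest-neighbor interpolation for the mask so that label values stay discrete
        # and LabelToMask can still select them exactly.
        train_segtrans = Compose(
            [
                AddChannel(),
                Resize((512, 512), mode="nearest"),
                RandRotate(0.78, prob=0.5, mode="nearest"),
                LabelToMask(select_labels=[1]),
                EnsureType(),
            ]
        )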
        val_imtrans = Compose([AddChannel(), EnsureType()])
        val_segtrans = Compose(
            [
                AddChannel(),
                LabelToMask(select_labels=[1]),
                EnsureType(),
            ]
        )

        self.dataset_train = ImageDataset(
            image_files=train_images,
            seg_files=train_segs,
            transform=train_imtrans,
            seg_transform=train_segtrans,
            image_only=False,
            transform_with_metadata=False,
        )
        self.dataset_val = ImageDataset(
            image_files=val_images,
            seg_files=val_segs,
            transform=val_imtrans,
            seg_transform=val_segtrans,
            image_only=False,
            transform_with_metadata=False,
        )

        print(f"Setup train dataset: {len(self.dataset_train)} items")
        print(f"Setup val dataset: {len(self.dataset_val)} items")

        assert len(self.dataset_train) > 0, "Train dataset is empty."
        assert len(self.dataset_val) > 0, "Val dataset is empty."

    def train_dataloader(self):
        return DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            # Set num_workers to 0 to prevent issues in Jupyter. Increase this in production code.
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
        )

    def test_dataloader(self):
        return self.val_dataloader()
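
`setup()` splits the data per case (per CT scan) rather than per frame, so slices from one scan never appear in both the training and the validation set. The following sketch reproduces that filtering on a hypothetical, much smaller set of 12 cases with 2 frames each; the case counts and file names are illustrative only. It uses sets for the case names, which gives constant-time membership checks when filtering many frames.

```python
from pathlib import PurePosixPath

# Hypothetical small example: 12 cases with 2 frames each.
num_cases, frames_per_case = 12, 2
val_indices = [7, 10]
test_indices = [2]
train_indices = set(range(num_cases)) - set(val_indices) - set(test_indices)

train_cases = {f"case_{i:05d}" for i in train_indices}
val_cases = {f"case_{i:05d}" for i in val_indices}

frames = [
    PurePosixPath(f"kits19_frames/case_{c:05d}/imaging_frames/case_{c:05d}_{s:04d}.jpg")
    for c in range(num_cases)
    for s in range(frames_per_case)
]


def case_name(frame_path):
    # kits19_frames/case_00007/imaging_frames/frame.jpg -> parents[1] is the case directory
    return frame_path.parents[1].name


train_frames = [f for f in frames if case_name(f) in train_cases]
val_frames = [f for f in frames if case_name(f) in val_cases]
print(len(train_frames), len(val_frames))
```

Frames from the test case (`case_00002`) end up in neither list.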

To check that the dataset looks correct, create an instance of `DataModule` and plot a few augmented samples from the training dataset.

In [None]:
def plotdataset(dataset, num=4):
    assert num < 16
    fig, ax = plt.subplots(3, num, figsize=(3 * num, 9))
    dataset_items = random.choices(dataset, k=num)

    for i, (image, mask, im_meta, mask_meta) in enumerate(dataset_items):
        image = image.long().squeeze(0).cpu()
        image_name = Path(im_meta["filename_or_obj"]).stem
        mask = mask.long().squeeze(0).cpu()
        contours = measure.find_contours(mask.cpu().numpy(), 0.4)

        ax[0, i].imshow(image, cmap="gray")
        ax[0, i].set_title(image_name)
        ax[1, i].imshow(mask, cmap="gray")
        # Show the image once, then overlay all mask contours on it.
        ax[2, i].imshow(image, cmap="gray")
        for contour in contours:
            ax[2, i].plot(contour[:, 1], contour[:, 0], linewidth=2, color="red")

    for axi in ax.ravel():
        axi.axis("off")
    fig.suptitle(f"Dataset directory: {Path(im_meta['filename_or_obj']).parents[2]}")

In [None]:
dm = DataModule(data_dir=BASEDIR, batch_size=8)
dm.setup()
ds = dm.dataset_train
ds.transform.transforms

In [None]:
# Run this cell again to see different randomly selected images.
plotdataset(ds)

## PyTorch Lightning Model [$\Uparrow$](#Table-of-content:)

Create a PyTorch Lightning [`LightningModule`](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html) to train a MONAI [BasicUNet](https://docs.monai.io/en/latest/networks.html#basicunet) model.

Use binary cross-entropy with logits (`torch.nn.BCEWithLogitsLoss`) as the loss function, and the Adam optimizer with its default learning rate of 0.001. The evaluation metric is the binary F1 score (`BinaryF1Score` from `torchmetrics`), which for binary segmentation is equal to the Dice score.
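
To make the loss weighting and the F1/Dice equivalence concrete, here is a plain-Python restatement of the formulas. It is a sketch for illustration, not the PyTorch or torchmetrics implementation. `bce_with_logits` computes the mean of `-[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]`, the formula documented for `BCEWithLogitsLoss` with the default `"mean"` reduction. With `pos_weight=0.5`, a positive (kidney) pixel contributes half as much loss as a negative pixel with the same logit, which pushes the model toward fewer false positives. `f1_and_dice` assumes at least one predicted and one true positive, to avoid division by zero.

```python
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean of -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    # log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
    losses = [pos_weight * y * softplus(-x) + (1 - y) * softplus(x) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def f1_and_dice(pred, target):
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    dice = 2 * tp / (2 * tp + fp + fn)
    return f1, dice


# A positive pixel with logit 0 costs half as much as a negative pixel with logit 0.
pos_loss = bce_with_logits([0.0], [1], pos_weight=0.5)  # 0.5 * ln 2
neg_loss = bce_with_logits([0.0], [0], pos_weight=0.5)  # ln 2
f1, dice = f1_and_dice(pred=[1, 1, 0, 0, 1], target=[1, 0, 0, 1, 1])
print(pos_loss, neg_loss, f1, dice)
```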

In [None]:
class MonaiModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Set the seeds before creating the network so that weight initialization is reproducible.
        # https://docs.monai.io/en/latest/highlights.html?deterministic-training-for-reproducibility
        monai.utils.set_determinism(seed=2.71828, additional_settings=None)
        self._model = monai.networks.nets.BasicUNet(spatial_dims=2, in_channels=1, out_channels=1)

        # https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
        # Set pos_weight to 0.5 to favor precision over recall
        self.loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor([0.5]))
        self.metric = BinaryF1Score()

        self.best_val_dice = 0
        self.best_val_epoch = 0

        self.validation_step_outputs = []

    def forward(self, x):
        return self._model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters())
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels, _, _ = batch
        labels = labels.float()
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        self.log("train_loss", loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels, _, _ = batch
        labels = labels.float()
        output = self.forward(images)
        loss = self.loss_function(output, labels)

        # Compute statistics for metric computation
        y_true = labels.long()
        y_pred = torch.sigmoid(output).round().long()

        # torchmetrics expects the arguments in the order (preds, target).
        self.metric.update(y_pred, y_true)
        self.log("val_loss", loss)
        self.validation_step_outputs.append({"val_loss": loss, "val_number": len(output)})
        return {"val_loss": loss, "val_number": len(output)}

    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0

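        for output in self.validation_step_outputs:
            # val_loss is a batch mean, so weight it by the batch size before averaging over items.
            val_loss += output["val_loss"].item() * output["val_number"]
            num_items += output["val_number"]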
        mean_val_dice = self.metric.compute()
        self.metric.reset()

        mean_val_loss = torch.tensor(val_loss / num_items)

        self.logger.experiment.add_scalar("Loss/Validation", mean_val_loss, self.current_epoch)
        self.logger.experiment.add_scalar("F1-score/Validation", mean_val_dice, self.current_epoch)
        self.log("F1", mean_val_dice, prog_bar=True, logger=False)

        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        self.validation_step_outputs.clear()
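
`BCEWithLogitsLoss` returns the mean loss over a batch, so averaging the logged batch losses over all validation items requires weighting each batch mean by its batch size; otherwise a smaller last batch counts as much as a full one, and dividing a plain sum of batch means by the item count underestimates the loss. The sketch below shows that aggregation with plain floats instead of tensors; the batch sizes and loss values are hypothetical. Because every item has the same number of pixels (512 x 512), the weighted result is also the mean loss per pixel.

```python
def mean_loss_per_item(step_outputs):
    """Average per-batch mean losses over all items, weighting each batch by its size."""
    total_loss, total_items = 0.0, 0
    for output in step_outputs:
        total_loss += output["val_loss"] * output["val_number"]
        total_items += output["val_number"]
    return total_loss / total_items


# Hypothetical validation run: a full batch of 8 items and a last, smaller batch of 2.
outputs = [{"val_loss": 0.2, "val_number": 8}, {"val_loss": 0.5, "val_number": 2}]
weighted = mean_loss_per_item(outputs)  # (0.2 * 8 + 0.5 * 2) / 10 = 0.26
# Summing the batch means and dividing by the item count gives about 0.07, which is far too low.
unweighted = sum(o["val_loss"] for o in outputs) / sum(o["val_number"] for o in outputs)
print(weighted, unweighted)
```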

## TensorBoard [$\Uparrow$](#Table-of-content:)

During the training loop, loss and metric information is logged to [TensorBoard](https://www.tensorflow.org/tensorboard/get_started). With the TensorBoard Jupyter extension, you can see these in the notebook. 

When running this cell for the first time, there will be a message that reads: *No dashboards are active for the current data set.* Once training has started, click the *reload* button to see the data.

In [None]:
%load_ext tensorboard

In [None]:
# The --bind_all parameter enables you to access TensorBoard if you run this notebook on a remote computer.
%tensorboard --logdir tb_logs --bind_all

## Train the Model [$\Uparrow$](#Table-of-content:)

Create instances of a PyTorch Lightning Model and DataModule, as well as a Logger and a `ModelCheckpoint`. If your GPU has enough memory, you can increase the DataModule batch size.

The `TensorBoardLogger` enables logging to TensorBoard; the `ModelCheckpoint` saves the top three models with the best F1 score.
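
Conceptually, `ModelCheckpoint(monitor="F1", mode="max", save_top_k=3)` retains the checkpoints with the three highest monitored scores seen so far. The following sketch illustrates that retention policy with a min-heap over hypothetical per-epoch F1 scores. It is not Lightning's implementation; details such as tie handling and file deletion may differ.

```python
import heapq


def keep_top_k(scores_by_epoch, k=3):
    """Return (score, epoch) pairs that a 'keep the k highest scores' policy retains, best first."""
    kept = []  # min-heap: the worst retained score is always kept[0]
    for epoch, score in enumerate(scores_by_epoch):
        if len(kept) < k:
            heapq.heappush(kept, (score, epoch))
        elif score > kept[0][0]:
            heapq.heapreplace(kept, (score, epoch))  # drop the worst checkpoint, keep the new one
    return sorted(kept, reverse=True)


# Hypothetical validation F1 scores for six epochs.
scores = [0.60, 0.72, 0.70, 0.81, 0.65, 0.79]
kept = keep_top_k(scores, k=3)
print(kept)  # [(0.81, 3), (0.79, 5), (0.72, 1)]
```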

For more information on the PyTorch Lightning options, refer to the [PyTorch Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/).

In [None]:
model = MonaiModel()
data = DataModule(data_dir=BASEDIR, batch_size=8)

logger = TensorBoardLogger("tb_logs", name="kits19_monai")
checkpoint_callback = ModelCheckpoint(monitor="F1", mode="max", save_top_k=3)

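Running the next cell starts the training. By default, training runs for 15 epochs on GPU and for 1 epoch on CPU, and each epoch uses only half of the training and validation batches (`limit_train_batches=0.5`, `limit_val_batches=0.5`). Training for 1 epoch will not result in a good model; it is only useful for testing purposes.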
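
To estimate how long training takes, it helps to know how many optimizer steps an epoch contains. The sketch below computes this from a hypothetical training set size (the real size is printed by `DataModule.setup()`). It assumes the DataLoader keeps the last partial batch (`drop_last=False`, the default) and that Lightning truncates the fractional batch limit to an integer.

```python
import math


def batches_per_epoch(num_items, batch_size, limit_fraction):
    full_batches = math.ceil(num_items / batch_size)  # DataLoader default: drop_last=False
    return int(full_batches * limit_fraction)  # assumption: the fraction is truncated to an int


# Hypothetical training set size.
num_train_frames, batch_size = 10_000, 8
steps = batches_per_epoch(num_train_frames, batch_size, 0.5)
print(steps, steps * 15)  # 625 9375
```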

During training, a progress bar shows the F1 score of the most recent validation run. You can stop the training before it reaches `max_epochs` by clicking the *stop* button in the Jupyter toolbar at the top of the notebook. The training will stop gracefully, and the best checkpoint saved during training will be used in the next cells to show inference and to convert the model to OpenVINO IR.

In [None]:
trainer = pl.Trainer(
    max_epochs=15 if USE_CUDA else 1,
    accelerator="gpu" if USE_CUDA else "cpu",
    devices="auto",
    logger=logger,
    precision=16 if USE_CUDA else 32,
    limit_train_batches=0.5,
    limit_val_batches=0.5,
    callbacks=[checkpoint_callback],
    fast_dev_run=False,  # Set to True to quickly test the PyTorch Lightning model.
)

start = datetime.datetime.now()
print(start.strftime("%H:%M:%S"))
try:
    trainer.fit(model, data)
finally:
    end = datetime.datetime.now()
    print(end.strftime("%H:%M:%S"))
    delta = dateutil.relativedelta.relativedelta(end, start)
    print(f"Training duration: {delta.hours:02d}:{delta.minutes:02d}:{delta.seconds:02d}")


## Show Inference on the Trained Model [$\Uparrow$](#Table-of-content:)

The F1 score gives an indication of the quality of the model. However, it is useful to show model outputs on a few random images as well. Therefore, load the model from the best checkpoint, and visualize model outputs on four randomly selected images.
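
Both `validation_step` and the inference cell turn the model's logits into a binary mask with `torch.sigmoid(...).round()`. The plain-Python sketch below shows that this is equivalent to thresholding the logits at 0, except for logits so close to 0 that the sigmoid evaluates to exactly 0.5. Python's `round()`, like `torch.round()`, rounds halves to even, so a logit of exactly 0 maps to background.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


logits = [-3.0, -0.1, 0.0, 0.1, 3.0]
from_sigmoid = [round(sigmoid(x)) for x in logits]  # sigmoid(0) = 0.5 rounds to 0
from_logits = [int(x > 0) for x in logits]
print(from_sigmoid, from_logits)  # [0, 0, 0, 1, 1] [0, 0, 0, 1, 1]
```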

In [None]:
checkpoint_path = checkpoint_callback.best_model_path
assert checkpoint_path != "", "No checkpoint saved. Please train the model for at least one epoch."
print(f"checkpoint_path: {checkpoint_path}")

best_model = MonaiModel.load_from_checkpoint(checkpoint_path, map_location="cpu")
valmodel = best_model._model
valmodel.eval().cpu();

In [None]:
# Set `seed` to the current time. To reproduce specific results, copy the printed seed
# and manually set `seed` to that value. The seed must be set before selecting the images.
seed = int(time.time())
random.seed(seed)
print(f"Visualizing results with seed {seed}")
dataset_items = random.choices(data.dataset_val, k=4)

fig, ax = plt.subplots(nrows=4, ncols=3, figsize=(24, 16))
for i, (image, mask, im_meta, mask_meta) in enumerate(dataset_items):
    input_image = image.unsqueeze(0)
    image_name = Path(im_meta["filename_or_obj"]).stem
    with torch.no_grad():
        res = valmodel(input_image)
    target_mask = mask.short()[0, ::]
    result_mask = torch.sigmoid(res).round().short()[0, 0, ::]

    ax[i, 0].imshow(image[0, ::], cmap="gray")
    ax[i, 1].imshow(target_mask, cmap="gray")
    ax[i, 2].imshow(result_mask, cmap="gray")
    ax[i, 0].set_title(image_name)
    ax[i, 1].set_title("Annotation")
    ax[i, 2].set_title("Prediction")
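
A printed seed only reproduces a random selection if the random generator is seeded before the selection is drawn. The sketch below shows this with Python's `random` module; `population` is a hypothetical stand-in for `data.dataset_val`, and `pick_items` is a hypothetical helper.

```python
import random
import time

population = list(range(100))  # hypothetical stand-in for data.dataset_val


def pick_items(seed, k=4):
    random.seed(seed)  # seed first...
    return random.choices(population, k=k)  # ...then draw


seed = int(time.time())
print(f"seed {seed}: {pick_items(seed)}")
# Passing the printed seed back in reproduces the same selection.
print(pick_items(seed) == pick_items(seed))  # True
```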

## Convert PyTorch model to OpenVINO IR [$\Uparrow$](#Table-of-content:)
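
Convert the best model to OpenVINO IR with `ov.convert_model()`, using a dummy input tensor with shape (1, 1, 512, 512) as the example input, and save it with `ov.save_model()`. This writes an `.xml` file with the model topology and a `.bin` file with the weights.
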
In [None]:
import openvino as ov

Path(MODEL_DIR).mkdir(exist_ok=True)
ir_path = Path(MODEL_DIR) / "unet_kits19.xml"
dummy_input = torch.randn(1, 1, 512, 512)
ov_model = ov.convert_model(valmodel, example_input=dummy_input)
ov.save_model(ov_model, ir_path)
print(f"OpenVINO model saved to {ir_path}")

## Next Steps [$\Uparrow$](#Table-of-content:)


Open the [110-ct-segmentation-quantize](../notebooks/110-ct-segmentation-quantize/110-ct-segmentation-quantize.ipynb) notebook to quantize the model with NNCF, using the [Post-training Quantization with NNCF Tool](https://docs.openvino.ai/nightly/basic_quantization_flow.html) API in OpenVINO, and to show live inference in the notebook.