# All Together Now!

We've been going through each individual step by hand, but there are plenty of helpful libraries out there that can do what we've done (and probably better).
Let's cover some of the more popular libraries for data manipulation and machine learning in Python.

## Quick Note

We're using `tqdm` to show progress bars in this notebook.
It may not always display nicely, depending on your environment or preferred IDE/editor.
It should be purely cosmetic, so it shouldn't affect the code execution.

## Lightning Round!

[`lightning`](https://lightning.ai/docs/pytorch/stable/), originally PyTorch Lightning, is a Python library built on top of `torch` that provides a simplified interface for training models.
Whereas previously we've been implementing our own training loops manually, `lightning` provides a lot of the boilerplate code for us.
Simply put, it codifies the common steps in training a model, so you can focus on the specifics of your model.

For example, here's how you might train a simple mockup model with `lightning`:

```python
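import torch
import lightning as pl
from sklearn.model_selection import train_test_split

class MyModel(pl.LightningModule):
    def __init__(
        self,
        loss_fn=torch.nn.functional.cross_entropy,
    ):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.loss_fn = loss_fn
    
    def forward(self, x):
        # Return raw logits: cross_entropy applies the softmax internally
        return self.l1(x.view(x.size(0), -1))
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", self.loss_fn(self(x), y))

    def test_step(self, batch, batch_idx):
        # trainer.test() raises an error if no test_step is defined
        x, y = batch
        self.log("test_loss", self.loss_fn(self(x), y))
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

class MyDataModule(pl.LightningDataModule):
    def __init__(self, samples, batch_size=32):
        super().__init__()
        self.samples = samples  # A tuple of tensors: (inputs, targets)
        self.batch_size = batch_size

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
    
    def setup(self, stage=None):
        # setup() runs again for every stage (fit, test, ...), so fix random_state
        # to get the same split each time and avoid leaking test data into training.
        # train_test_split returns the splits interleaved: x_train, x_test, y_train, y_test
        x_train, x_test, y_train, y_test = train_test_split(*self.samples, test_size=0.2, random_state=0)
        x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=0)

        def to_dataset(x, y):
            # as_tensor is a no-op for tensors and converts NumPy arrays if sklearn returns those
            return torch.utils.data.TensorDataset(torch.as_tensor(x), torch.as_tensor(y))

        self.train_dataset = to_dataset(x_train, y_train)
        self.val_dataset = to_dataset(x_val, y_val)
        self.test_dataset = to_dataset(x_test, y_test)
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size)

# Mock data: 1000 random 28x28 "images", each with one of 10 class labels
samples = (torch.randn(1000, 28, 28), torch.randint(0, 10, (1000,)))

model = MyModel()
datamodule = MyDataModule(samples)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)
result = trainer.test(model, datamodule=datamodule)

print(result)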
```

This is a very simple example, but it shows how `lightning` can simplify the process of training a model.
Note how the `MyModel` class is a subclass of `pl.LightningModule`, but seems very similar to the `nn.Module` we've been using.
The primary difference is that you implement hook methods, `training_step`, `validation_step`, and `test_step`, which Lightning calls for each batch of data, and a `configure_optimizers` method that returns the optimizer to use.
These are the functions that the `pl.Trainer` class uses to train the model.

The `MyDataModule` class is a subclass of `pl.LightningDataModule`, and provides a unified interface for loading data.
It has a `setup` method that Lightning calls before each stage (fitting, testing, ...), and `train_dataloader`, `val_dataloader`, and `test_dataloader` methods that return the appropriate `DataLoader` for each stage.
This is, again, passed to the `pl.Trainer` class to train the model.

Finally, the `pl.Trainer` class is used to train the model.
It replaces the manual training loop we've been using, and provides a lot of useful features like early stopping, logging, and checkpointing.
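
To make concrete what `pl.Trainer` takes off your hands, here is a heavily simplified sketch of the loop behind `fit`. The `mini_fit` function and `ToyModule` class are our own hypothetical names; this is not Lightning's actual implementation, which also handles validation, devices, precision, callbacks, checkpointing, and logging. It only shows how the `training_step` and `configure_optimizers` hooks slot into the familiar manual loop.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def mini_fit(module: nn.Module, dataloader: DataLoader, max_epochs: int) -> list[float]:
    """Hypothetical, heavily simplified stand-in for Trainer.fit (illustration only)."""
    optimizer = module.configure_optimizers()  # The module chooses its own optimizer
    epoch_losses = []
    for _ in range(max_epochs):
        module.train()
        batch_losses = []
        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            loss = module.training_step(batch, batch_idx)  # Model-specific logic lives here
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        epoch_losses.append(sum(batch_losses) / len(batch_losses))
    return epoch_losses


class ToyModule(nn.Module):
    """A plain nn.Module exposing the two hooks mini_fit relies on."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        x, y = batch
        return nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.parameters(), lr=0.1)


torch.manual_seed(0)
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)  # A linear target the model can learn exactly
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
losses = mini_fit(ToyModule(), loader, max_epochs=20)
print(f"first epoch: {losses[0]:.3f}, last epoch: {losses[-1]:.3f}")
```

Everything model-specific lives in the module's hooks and everything generic lives in the loop; Lightning's `Trainer` applies the same separation at a much larger scale.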

Some people are not fans of `lightning`, as it can be a bit heavy-handed and abstract away too much of the details.
We'll leave it up to you to decide if it's right for you.

## MONAI

[`MONAI`](https://monai.io/) is a library for deep learning in medical imaging.
It provides a lot of the same functionality as `torch` and `lightning`, but is specifically tailored for medical imaging.
In particular, the `monai.data` module provides useful tools for loading and caching medical imaging data, and the `monai.transforms` module provides tools for preprocessing and augmenting it.
A strong point in favour of `monai` is that it provides the dictionary transform paradigm, which allows you to apply a series of transforms to a dictionary of data.

What does that mean?

Well, let's say you have a dataset of medical images, and you want to apply a series of transforms, or augmentations, to each image.
You could manually write a series of functions that take an image as input and return an augmented image as output.
Or, you can set up a pipeline of transforms that take a dictionary of data as input and return a dictionary of data as output.

For example:

```python
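import numpy as np
from monai.data import CacheDataset
from monai import transforms as mt
from torch.utils.data import DataLoader

# path_to_image1 etc. are placeholders for your own file paths
samples = [
    {"image": path_to_image1, "label": path_to_label1},
    {"image": path_to_image2, "label": path_to_label2},
    # ... one dictionary per sample
]

transforms = mt.Compose([
    mt.LoadImaged(keys=["image", "label"]),
    # The spatial transforms below expect channel-first data, so add a channel to both keys
    mt.EnsureChannelFirstd(keys=["image", "label"]),
    mt.NormalizeIntensityd(keys=["image"]),
    # Nearest-neighbour interpolation keeps the label values discrete
    mt.Resized(keys=["image", "label"], spatial_size=(256, 256), mode=["area", "nearest"]),
    mt.RandRotated(
        keys=["image", "label"],
        prob=0.5,
        range_x=np.pi,  # Angles are in radians: sampled uniformly from (-pi, pi)
        # For 3D volumes, also set range_y/range_z and use a 3D spatial_size above
        mode=["bilinear", "nearest"],
    ),
    mt.ToTensord(keys=["image", "label"]),
])

dataset = CacheDataset(data=samples, transform=transforms)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
...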
```

Note how the transforms are composed with `mt.Compose` and then applied to the dataset.
This way you set up your pipeline once, including all of the parameters for each transform, and then apply it to every sample.
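
Conceptually, a dictionary transform is just a function from a dictionary to a dictionary that only touches the keys it was given, and `Compose` chains such functions. The pure-Python sketch below illustrates that idea; `scale_d`, `flip_d`, and `compose` are hypothetical stand-ins rather than MONAI's classes, and plain lists stand in for images.

```python
from typing import Any, Callable

Sample = dict[str, Any]
Transform = Callable[[Sample], Sample]


def scale_d(keys: list[str], factor: float) -> Transform:
    """Hypothetical intensity transform: multiply the values under `keys` by `factor`."""
    def apply(sample: Sample) -> Sample:
        out = dict(sample)  # Shallow copy, so the caller's dictionary is not modified
        for key in keys:
            out[key] = [v * factor for v in out[key]]
        return out
    return apply


def flip_d(keys: list[str]) -> Transform:
    """Hypothetical spatial transform: reverse the values under `keys`."""
    def apply(sample: Sample) -> Sample:
        out = dict(sample)
        for key in keys:
            out[key] = list(reversed(out[key]))
        return out
    return apply


def compose(transforms: list[Transform]) -> Transform:
    """Chain dictionary transforms: each one receives the previous one's output."""
    def apply(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(sample)
        return sample
    return apply


pipeline = compose([
    scale_d(keys=["image"], factor=2.0),  # Intensity change: image only
    flip_d(keys=["image", "label"]),      # Spatial change: image and label together
])

sample = {"image": [1.0, 2.0, 3.0], "label": [0, 0, 1], "id": "case_001"}
result = pipeline(sample)
print(result)  # {'image': [6.0, 4.0, 2.0], 'label': [1, 0, 0], 'id': 'case_001'}
```

The spatial transform is applied to image and label in the same call, so they stay aligned, while the intensity transform leaves the label (and any extra keys such as `"id"`) alone.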

The `CacheDataset` class is a subclass of `torch.utils.data.Dataset`, and its `transform` parameter takes the transform (here, the composed pipeline) to apply to each sample.
What sets `CacheDataset` apart from a plain `torch` `Dataset` is that it caches the output of the deterministic (non-random) transforms, so those only need to run once.
In our example, it runs every transform up to and including `Resized` once and caches the result.
After that, each time a sample is accessed, only the `RandRotated` and `ToTensord` transforms are run.
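
The caching behaviour can be sketched in a few lines of plain Python. `CachingDataset`, `fake_load`, and `random_jitter` are hypothetical names for illustration only; MONAI's `CacheDataset` has more features (for example a `cache_rate` parameter to cache only part of the data), so treat this as a model of the idea, not of the library.

```python
import random
from typing import Any, Callable

Sample = dict[str, Any]
Transform = Callable[[Sample], Sample]


class CachingDataset:
    """Hypothetical sketch of the CacheDataset idea, not MONAI's implementation.

    The deterministic transforms run once per sample when the dataset is built;
    the random transforms run again on every access.
    """

    def __init__(self, data: list[Sample], deterministic: list[Transform], randomized: list[Transform]):
        self.randomized = randomized
        self.cache = []
        for sample in data:
            for transform in deterministic:
                sample = transform(sample)
            self.cache.append(sample)

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, idx: int) -> Sample:
        sample = dict(self.cache[idx])  # Copy, so random transforms never touch the cache
        for transform in self.randomized:
            sample = transform(sample)
        return sample


calls = {"load": 0}


def fake_load(sample: Sample) -> Sample:
    """Stands in for an expensive step such as reading a NIfTI file from disk."""
    calls["load"] += 1
    return {**sample, "image": [float(v) for v in sample["image"]]}


def random_jitter(sample: Sample) -> Sample:
    """Stands in for a random augmentation such as RandRotated."""
    return {**sample, "image": [v + random.uniform(-0.1, 0.1) for v in sample["image"]]}


ds = CachingDataset([{"image": [1, 2]}, {"image": [3, 4]}], [fake_load], [random_jitter])
for epoch in range(3):
    for i in range(len(ds)):
        _ = ds[i]

print(calls["load"])  # 2: one load per sample, not one per sample per epoch
```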

## Interoperability

The presented libraries are not mutually exclusive, and can be used together.
However, they are also not necessarily built to work together, so you may need to do some work to get them to play nice.

Let's go through the training of a toy model with the Medical Decathlon Spleen segmentation dataset.
The first version will be pure `torch`, the second version will use `monai` for data loading and preprocessing, and the third version will use `lightning` for training.

Let's download the dataset first:

In [None]:
import tarfile
from pathlib import Path

import requests
from tqdm.auto import tqdm

def download_medical_decathlon(url: str, path: Path):
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Fail early on HTTP errors instead of saving an error page

    data_size = int(response.headers.get('Content-Length', 0))
    block_size = 1024

    with tqdm(total=data_size, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
        with open(path, 'wb') as f:
            for data in response.iter_content(block_size):
                f.write(data)
                pbar.update(len(data))

def extract_tar(path: Path, dest: Path):
    with tarfile.open(path, 'r') as tar:
        tar.extractall(dest)


url = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
path = Path(r"./Data/Task09_Spleen.tar").resolve()
dest = Path(r"./Data/Task09_Spleen").resolve()
if not dest.exists():
    dest.mkdir(parents=True)
    download_medical_decathlon(url, path)
    extract_tar(path, dest)
else:
    print(f"{dest} already exists, skipping download and extraction")

First we'll tackle the model architecture.
We'll be using a [U-Net](https://arxiv.org/abs/1505.04597). This architecture has shot to fame since its introduction in 2015, forming the basis for many publications since.
While there have been a few variations, such as U-Net++, U-Net3+, and others, the basic U-Net still performs well to this day.
The very popular [nnUNet](https://www.nature.com/articles/s41592-020-01008-z), for example, still mainly uses the basic U-Net architecture.

We define our model based on the original paper, with some slight alterations to keep things simple.
First up, the double convolutions.

As per the article:

> The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step we double the number of feature channels.

We don't follow the exact specifications: we pad the convolutions (`padding=1`) so that the spatial dimensions stay the same.
We can implement this as follows:

In [None]:
import torch
from torch import nn

class DoubleConv(nn.Module):
    def __init__(
        self,
        n_dims: int,
        in_channels: int,
        out_channels: int,
        use_normalization: bool = True,
    ) -> None:
        super().__init__()

        # This is a little Python trick to make the code more readable
        # You can set a variable to contain a function or class, and then call it later
        conv = None
        norm = None

        if n_dims == 2:
            conv = nn.Conv2d
            norm = nn.BatchNorm2d if use_normalization else nn.Identity
        elif n_dims == 3:
            conv = nn.Conv3d
            norm = nn.BatchNorm3d if use_normalization else nn.Identity
        else:
            raise ValueError("Invalid number of dimensions")
        
        layers = [
            conv(in_channels, out_channels, kernel_size=3, padding=1), # For kernel_size=3, padding=1, and the default stride=1, the output size equals the input size.
            norm(out_channels),                                        # Does normalization change the dimensions of the image?
            nn.ReLU(inplace=True),                                     # Does a ReLU activation layer change the dimensions of the image?
            conv(out_channels, out_channels, kernel_size=3, padding=1),                                         
            norm(out_channels),
            nn.ReLU(inplace=True),
        ]

        # This is a Python specific syntax to "unpack" the list of layers
        # and pass them as arguments to nn.Sequential
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
    
class EncoderBlock(nn.Module):
    def __init__(
        self,
        n_dims: int,
        in_channels: int,
        out_channels: int,
        use_normalization: bool = True,
    ) -> None:
        super().__init__()

        pool = None
        if n_dims == 2:
            pool = nn.MaxPool2d
        elif n_dims == 3:
            pool = nn.MaxPool3d
        else:
            raise ValueError("Invalid number of dimensions")

        self.encode = nn.Sequential(
            pool(kernel_size=2, stride=2),
            DoubleConv(n_dims, in_channels, out_channels, use_normalization),
        )

    def forward(self, x):
        return self.encode(x)
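
The comments in `DoubleConv` ask whether each layer changes the image dimensions. The self-contained check below uses the `torch` layers directly rather than the classes above, and `conv_output_size` is our own helper implementing the standard output-size formula for convolution and pooling (dilation 1), `out = (in + 2 * padding - kernel_size) // stride + 1`. It confirms that `padding=1` keeps the size, that normalization and ReLU (which only rescale or clip values) leave the shape alone, and that the stride-2 max pool halves each spatial axis.

```python
import torch
from torch import nn


def conv_output_size(size: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Output length along one axis for a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


x = torch.randn(1, 1, 28, 28, 28)  # [batch, channels, D, H, W]

conv = nn.Conv3d(1, 8, kernel_size=3, padding=1)
norm = nn.BatchNorm3d(8)
pool = nn.MaxPool3d(kernel_size=2, stride=2)

features = torch.relu(norm(conv(x)))
pooled = pool(features)
print(features.shape)  # torch.Size([1, 8, 28, 28, 28]): only the channel count changed
print(pooled.shape)    # torch.Size([1, 8, 14, 14, 14]): pooling halves each spatial axis

print(conv_output_size(28, kernel_size=3, padding=1))  # 28, our padded convolution
print(conv_output_size(28, kernel_size=3, padding=0))  # 26, the paper's unpadded convolution
print(conv_output_size(28, kernel_size=2, stride=2))   # 14, max pooling
```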

Now for the decoder part of the network.

As per the article:
> Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in every convolution.

Again, we pad the convolutions to keep the spatial dimensions the same, so we don't need to crop the feature maps.

In [None]:
class DecoderBlock(nn.Module):
    def __init__(
        self,
        n_dims: int,
        in_channels: int,
        out_channels: int,
        use_transpose: bool = False,
        use_normalization: bool = True,
    ) -> None:
        super().__init__()

        # Our earlier trick to make the code more readable
        # Unfortunately doesn't work quite as nicely here because 
        # nn.Upsample and nn.ConvTranspose have different signatures
        conv = None

        if n_dims == 2:
            conv = nn.Conv2d

            # Two methods to upsample: transpose or with interpolation
            ## Upsample the spatial dimensions (height, width) and reduce the number of channels by half.
            if use_transpose:
                self.upsample = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)  # We cut the number of in-channels in half. How does this operation affect the other dimensions?
            else:
                self.upsample = nn.Sequential(  # We want the same output shape as the use_transpose branch
                    nn.Upsample(   # What is it changing here? 
                        scale_factor=2,
                        mode="bilinear",
                        align_corners=True
                    ),
                    conv(in_channels, in_channels // 2, kernel_size=1, padding=0),  # We perform a 1x1 convolution. Why is this useful? Did the image change shape?
                )
        elif n_dims == 3:
            conv = nn.Conv3d

            if use_transpose:
                self.upsample = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            else:
                self.upsample = nn.Sequential(
                    nn.Upsample(
                        scale_factor=2,
                        mode="trilinear",  # Why do we use trilinear interpolation here and not bilinear? What is the difference?
                        align_corners=True
                    ),
                    conv(in_channels, in_channels // 2, kernel_size=1, padding=0),
                )
        
        self.decode = DoubleConv(n_dims, in_channels, out_channels, use_normalization)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        
        # This is the skip connection
        x = torch.cat((x, skip), dim=1)
        x = self.decode(x)
        return x
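
The comments in `DecoderBlock` ask what each upsampling option does to the tensor shape. The self-contained check below (plain `torch` layers, 3D case, example sizes of our choosing) shows that both options produce the same shape: the transposed convolution with kernel 2 and stride 2 doubles every spatial axis and halves the channels in one learnable step, while `nn.Upsample` doubles the spatial axes by interpolation and the 1x1x1 convolution then changes only the channel count, mixing channels at each voxel. Trilinear interpolation is used because the 3D case has three spatial axes; bilinear covers only two.

```python
import torch
from torch import nn

x = torch.randn(1, 32, 7, 7, 7)        # Bottleneck features: [batch, channels, D, H, W]
skip = torch.randn(1, 16, 14, 14, 14)  # Matching encoder features for the skip connection

# Option 1: learnable transposed convolution
up_transpose = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)

# Option 2: fixed trilinear interpolation, then a 1x1x1 convolution to halve the channels
up_interp = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True),
    nn.Conv3d(32, 16, kernel_size=1),
)

a = up_transpose(x)
b = up_interp(x)
print(a.shape, b.shape)  # both torch.Size([1, 16, 14, 14, 14])

# Concatenating with the skip connection restores in_channels for the DoubleConv
merged = torch.cat((a, skip), dim=1)
print(merged.shape)      # torch.Size([1, 32, 14, 14, 14])
```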

Now we can put it all together in the U-Net architecture.

In [None]:
class UNet(nn.Module):
    def __init__(
        self,
        n_dims: int,
        in_channels: int,
        out_channels: int,
        base_channels: int = 8,
        depth: int = 4,
        use_transpose: bool = False,
        use_normalization: bool = True,
        final_activation: nn.Module | None = None,
    ) -> None:
        super().__init__()

        self.n_dims = n_dims
        self.depth = depth

        if depth < 2:
            raise ValueError("Model depth must be 2 or greater")
        
        # Define the input layer
        layers = [DoubleConv(n_dims, in_channels, base_channels, use_normalization)]
        
        # Define the encoder path: it progressively doubles the number of channels
        current_features = base_channels
        for _ in range(depth - 1):
            layers.append(EncoderBlock(n_dims, current_features, current_features * 2, use_normalization))
            current_features *= 2

        # Define the decoder path: progressively halves the number of channels
        for _ in range(depth - 1):
            layers.append(DecoderBlock(n_dims, current_features, current_features // 2, use_transpose, use_normalization))
            current_features //= 2
        
        # Define the output layer
        if n_dims == 2:
            layers.append(nn.Conv2d(current_features, out_channels, kernel_size=1))
        elif n_dims == 3:
            layers.append(nn.Conv3d(current_features, out_channels, kernel_size=1))
        else:
            raise ValueError("Invalid number of dimensions")
        
        self.layers = nn.ModuleList(layers)
        if final_activation is not None:
            self.final_activation = final_activation
        else:
            self.final_activation = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xi = [self.layers[0](x)]

        # Encoder path
        # Pretty simple, just loop over the encoder blocks
        for layer in self.layers[1:self.depth]:
            xi.append(layer(xi[-1]))
        
        # Decoder path
        # We need to loop over the decoder blocks, but also need to
        # keep track of the skip connections
        for i, layer in enumerate(self.layers[self.depth:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        
        return self.final_activation(self.layers[-1](xi[-1]))

Done!
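
Because `forward` concatenates each upsampled decoder output with the matching encoder output, their spatial sizes must match exactly. Each `EncoderBlock` halves the size with a stride-2 max pool (rounding down), so the input size along every axis should be divisible by `2 ** (depth - 1)`. The plain-Python helpers below (hypothetical names) check this, and show why the 28-voxel crops used later pair with `depth=3`.

```python
def smallest_feature_map(size: int, depth: int) -> int:
    """Spatial size at the bottleneck after depth - 1 max pools (kernel 2, stride 2)."""
    for _ in range(depth - 1):
        size //= 2  # Max pooling rounds odd sizes down
    return size


def is_valid_input_size(size: int, depth: int) -> bool:
    """True if every pooling step sees an even size, so upsampling restores it exactly."""
    return size % (2 ** (depth - 1)) == 0


for depth in (3, 4):
    print(depth, smallest_feature_map(28, depth), is_valid_input_size(28, depth))
# 3 7 True
# 4 3 False  (28 -> 14 -> 7 -> 3, and upsampling 3 gives 6, not 7, so torch.cat fails)
```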

Now we can move on to the data loading and preprocessing.

Let's define the preprocessing pipeline first.
We'll keep it very simple here, just normalizing the intensities and a random crop.

Note that for testing you'll want a deterministic pipeline, so the random crop should be turned off.
For simplicity, we'll just use the same pipeline for both training and testing.

In [None]:
def preprocess(
    image: torch.Tensor,
    label: torch.Tensor | None = None,
    crop_size: tuple[int, ...] = (28, 28, 28)
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Normalize the image (Z-score normalization)
    image = (image - image.mean()) / image.std()
    
    # Add a channel dimension to the image
    image = image.unsqueeze(0)

    # Random crop
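    crop_origin = [0, 0, 0]
    for dim in range(3):  # Remember, image.shape = [Ch, X, Y, Z], so spatial axis `dim` is at index dim + 1
        max_value = image.shape[dim + 1] - crop_size[dim]
        # randint's upper bound is exclusive, so + 1 allows crops that touch the far edge
        crop_origin[dim] = torch.randint(0, max_value + 1, (1,)).item()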
    
    image = image[
        ...,
        crop_origin[0]:crop_origin[0] + crop_size[0],
        crop_origin[1]:crop_origin[1] + crop_size[1],
        crop_origin[2]:crop_origin[2] + crop_size[2],
    ]
    
    # Add a channel dimension to the label
    if label is not None:
        label = label.unsqueeze(0)

        label = label[
            ...,
            crop_origin[0]:crop_origin[0] + crop_size[0],
            crop_origin[1]:crop_origin[1] + crop_size[1],
            crop_origin[2]:crop_origin[2] + crop_size[2],
        ]
    
    return image, label
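
As noted above, evaluation ideally uses a deterministic pipeline. One simple option is a center crop instead of the random crop. The sketch below is a hypothetical helper (`center_crop` is our name, not part of the notebook), assuming the same `[Ch, *spatial]` layout that `preprocess` produces after adding the channel dimension.

```python
import torch


def center_crop(image: torch.Tensor, crop_size: tuple[int, ...]) -> torch.Tensor:
    """Hypothetical deterministic alternative to the random crop in `preprocess`.

    Expects image.shape == [Ch, *spatial] with one crop_size entry per spatial axis.
    """
    slices = [slice(None)]  # Keep every channel
    for dim, size in enumerate(crop_size):
        extent = image.shape[dim + 1]
        if size > extent:
            raise ValueError(f"Crop size {size} exceeds image size {extent} on spatial axis {dim}")
        start = (extent - size) // 2  # Same origin on every call, so results are reproducible
        slices.append(slice(start, start + size))
    return image[tuple(slices)]


volume = torch.arange(6 * 6 * 6).reshape(1, 6, 6, 6)
crop = center_crop(volume, (2, 2, 2))
print(crop.shape)               # torch.Size([1, 2, 2, 2])
print(crop[0, 0, 0, 0].item())  # 86, i.e. volume[0, 2, 2, 2]
```

The same origin would have to be applied to the label, exactly as the random crop in `preprocess` reuses `crop_origin` for both.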

As we are using the Medical Decathlon Spleen segmentation dataset, we'll need to define a custom `Dataset` class to load the data.

In [None]:
import SimpleITK as sitk
from torch.utils.data import DataLoader, Dataset


def collect_samples(root: Path, test_set: bool = False) -> list[tuple[Path, ...]]:
    """
    Collects the samples from the Medical Decathlon dataset.

    Parameters
    ----------
    root : Path
        The root directory of the dataset.
    test_set : bool
        Whether to collect the test set or the training set.

    Returns
    -------
    list[tuple[Path, ...]]
        A list of tuples containing the image and label paths.
    """
    # We need to find the files in the dataset
    # The dataset is structured as follows, where root is the extracted Task09_Spleen folder:
    # root
    # ├── imagesTr
    # │   ├── spleen_1.nii.gz
    # │   ├── spleen_2.nii.gz
    # │   └── ...
    # ├── imagesTs
    # │   └── ...
    # └── labelsTr
    #     ├── spleen_1.nii.gz
    #     ├── spleen_2.nii.gz
    #     └── ...

    if test_set:
        image_dir = root / "imagesTs"
        label_dir = None
    else:
        image_dir = root / "imagesTr"
        label_dir = root / "labelsTr"

    if not image_dir.exists():
        raise ValueError(f"Could not find dataset in {root}")
    
    samples = []
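    # Skip hidden files (e.g. macOS "._" metadata files) and sort for a reproducible order
    images = sorted(x for x in image_dir.iterdir() if x.is_file() and not x.name.startswith("."))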
    for image in images:
        label = None
        if label_dir is not None:
            label = label_dir / image.name
        
        samples.append((image, label))
    
    return samples


class MedicalDecathlonDataset(Dataset):
    def __init__(self, samples: list[tuple[Path, ...]], test: bool = False) -> None:
        self.samples = samples
        
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        image_path, label_path = self.samples[idx]
        
        image = sitk.ReadImage(str(image_path))
        image = torch.tensor(sitk.GetArrayFromImage(image)).float() # Convert image to PyTorch tensor and cast it to float

        if label_path is None:
            image, _ = preprocess(image)
            return image

        label = sitk.ReadImage(str(label_path))
        label = torch.tensor(sitk.GetArrayFromImage(label)).float()

        return preprocess(image, label)


Now we can do our old reliable training loop.

This might take a while, as we didn't do much to speed things up.

In [None]:
n_epochs = 3
batch_size = 4
learning_rate = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(n_dims=3, in_channels=1, out_channels=1, depth=3).to(device)  # 3D U-Net. How many channels?
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

root = Path(r"./Data/Task09_Spleen/Task09_Spleen/").resolve()
samples = collect_samples(root)

# We're just doing a naive split here, you might want to do something more sophisticated
train_samples = samples[:int(len(samples) * 0.8)]  # from 0 to X
val_samples = samples[int(len(samples) * 0.8):]    # From X to the end

train_ds = MedicalDecathlonDataset(train_samples)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

val_ds = MedicalDecathlonDataset(val_samples)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

for epoch in (prog_bar := tqdm(range(n_epochs), desc="Training", unit="epoch", total=n_epochs, position=0)):

    model.train() # We set the model in training mode
    train_losses = []   # List to track training losses
    prog_bar.set_description(f"Training Loop")

    for i, (image, label) in tqdm(enumerate(train_dl), total=len(train_dl), desc="Training", unit="batch", position=1, leave=False):
        image, label = image.to(device), label.to(device)

        optimizer.zero_grad() # Clear gradients
        output = model(image) # Model forward pass
        loss = loss_fn(output, label) # Compute loss
        loss.backward() # Backpropagate loss
        optimizer.step() # Update model weights

        train_losses.append(loss.item()) # Append training loss for this batch
    
    prog_bar.set_postfix({"Training loss": sum(train_losses) / len(train_losses)})

    prog_bar.set_description(f"Validation Loop")
     
    model.eval() # We set the model in evaluation mode
    val_losses = [] 
    for i, (image, label) in tqdm(enumerate(val_dl), total=len(val_dl), desc="Validation", unit="batch", position=1, leave=False):
        image, label = image.to(device), label.to(device)

        with torch.no_grad():
            output = model(image)
            loss = loss_fn(output, label)
        
        val_losses.append(loss.item())
    
    prog_bar.set_postfix({"Training loss": sum(train_losses) / len(train_losses), "Validation loss": sum(val_losses) / len(val_losses)})
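
The cell above uses a naive split that takes the first 80% of files in directory order, and `Path.iterdir()` makes no guarantee about that order. A seeded shuffle makes the split random but reproducible. A minimal sketch with a hypothetical `split_samples` helper (scikit-learn's `train_test_split` with a fixed `random_state` does the same job):

```python
import random


def split_samples(samples: list, val_fraction: float = 0.2, seed: int = 42) -> tuple[list, list]:
    """Hypothetical helper: shuffle a copy of `samples` with a fixed seed, then split."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # Local RNG, so the global random state is untouched
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]


cases = [f"spleen_{i}" for i in range(10)]
train_cases, val_cases = split_samples(cases)
print(len(train_cases), len(val_cases))  # 8 2
```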

# Monai Integration

Now that we've trained a model using nothing but base `torch`, let's see how we can improve things with `monai`.

First, let's redefine our sample collection code to be more `monai`-friendly.
Because we want to use the dictionary transforms, our samples need to be dictionaries with keys for each data item.
So, we'll define a function to load the data and return a dictionary.

In our implementation, we'll standardize on the keys `"image"` and `"label"` for the image and label data, respectively.
You can use any keys you like, as long as they match the keys in the transforms.

In [None]:
from pathlib import Path


def collect_samples(root: Path, is_test: bool = False) -> list[dict[str, Path]]:
    """
    Collects the samples from the Medical Decathlon dataset.

    Parameters
    ----------
    root : Path
        The root directory of the dataset.
    is_test : bool
        Whether to collect the test set or the training set.

    Returns
    -------
    list[dict[str, Path]]
        A list of dictionaries containing the image and label paths.
    """

    if is_test:
        image_dir = root / "imagesTs"
        label_dir = None
    else:
        image_dir = root / "imagesTr"
        label_dir = root / "labelsTr"

    if not image_dir.exists():
        raise ValueError(f"Could not find dataset in {root}")
    
    samples = []
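    # Skip hidden files (e.g. macOS "._" metadata files) and sort for a reproducible order
    images = sorted(x for x in image_dir.iterdir() if x.is_file() and not x.name.startswith("."))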
    for image in images:
        sample = {"image": image}
        if label_dir is not None:
            sample["label"] = label_dir / image.name
        
        samples.append(sample)
    
    return samples

Next up, the transforms.

This is pretty straightforward: you define each step in order, and tell each transform which keys it should operate on.
Let's also take the opportunity to define a separate pipeline for training and inference.

Finally, we pass our transforms to a `CacheDataset`, which will apply the transforms to the data and cache the non-random results.
Note that the `CacheDataset` might take a while to run the first time, as it applies all of the transforms to the data.
Additionally, your laptop might not have enough memory to cache all of the data.
If that is the case, you can also import the `Dataset` class from `monai.data` and use that instead.
Be aware that this is _not_ the same as the `torch` `Dataset` class, but is very similar.

In [None]:
from monai import transforms as mt
from monai.data import CacheDataset

train_transforms = mt.Compose([
    mt.LoadImaged(keys=["image", "label"]),
    mt.EnsureChannelFirstd(keys=["image", "label"]),
    mt.NormalizeIntensityd(keys=["image"]),
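    # random_size=False gives fixed-size crops (older MONAI versions defaulted to random sizes),
    # so the samples can be stacked into batches
    mt.RandSpatialCropd(keys=["image", "label"], roi_size=[28, 28, 28], random_size=False),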
    # You can add more transforms here! See
    # https://docs.monai.io/en/stable/transforms.html#dictionary-transforms
])

val_transforms = mt.Compose([
    mt.LoadImaged(keys=["image", "label"]),
    mt.EnsureChannelFirstd(keys=["image", "label"]),
    mt.NormalizeIntensityd(keys=["image"]),
])

root = Path(r"./Data/Task09_Spleen/Task09_Spleen/").resolve()
samples = collect_samples(root)

train_samples = samples[:int(len(samples) * 0.8)]
val_samples = samples[int(len(samples) * 0.8):]

# You might get an error here, make sure you install the required extra dependencies
# https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
train_ds = CacheDataset(data=train_samples, transform=train_transforms)
val_ds = CacheDataset(data=val_samples, transform=val_transforms)

From here, it's just a matter of training the model as before.

In [None]:
from monai.inferers import sliding_window_inference

n_epochs = 3
batch_size = 4
sw_batch_size = 16
learning_rate = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

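model = UNet(
    n_dims=3,
    in_channels=1,
    out_channels=1,
    depth=3,
    # No final Sigmoid here: BCEWithLogitsLoss applies the sigmoid itself.
    # Apply torch.sigmoid to the outputs only when you need probabilities.
).to(device)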
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# We use the CacheDatasets here
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# We're doing a batch size 1 here because we're using sliding window inference
# and not all images are the same size, which is a requirement for batching
val_dl = DataLoader(val_ds, batch_size=1, shuffle=False)

for epoch in (prog_bar := tqdm(range(n_epochs), desc="Epochs", unit="epoch", total=n_epochs, position=0)):
    prog_bar.set_description("Training Loop")
    model.train()
    train_losses = []
    for i, batch in tqdm(enumerate(train_dl), total=len(train_dl), desc="Training", unit="batch", position=1, leave=False):
        # Note that we're using the batch dictionary here
        # Make sure to adjust the keys if you changed them
        image = batch["image"].to(device)
        label = batch["label"].to(device)

        optimizer.zero_grad()
        output = model(image)
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    
    prog_bar.set_postfix({"Training loss": sum(train_losses) / len(train_losses)})

    prog_bar.set_description("Validation Loop")
    model.eval()
    val_losses = []
    for i, batch in tqdm(enumerate(val_dl), total=len(val_dl), desc="Validation", unit="batch", position=1, leave=False):
        image = batch["image"].to(device)
        label = batch["label"].to(device)

        with torch.no_grad():
            # We're using the sliding window inference here
            # This is a simple way to handle larger images
            # without having to define a custom collate function
            output = sliding_window_inference(
                image,
                (28, 28, 28),
                sw_batch_size,
                model,
                overlap=0.0,  # No overlap between patches
                mode="constant",  # Overlap merging strategy
                progress=True,
            )
            loss = loss_fn(output, label)
        
        val_losses.append(loss.item())
    
    prog_bar.set_postfix({"Training loss": sum(train_losses) / len(train_losses), "Validation loss": sum(val_losses) / len(val_losses)})

Sweet!
Note how the training section is also much faster now, as the data is preprocessed and cached.
Overall the total time to train will probably be a bit longer, as we're doing a full prediction on the validation set after each epoch instead of just grabbing a small patch.
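
To see what sliding-window inference does conceptually, here is a naive sketch of the non-overlapping case: pad the volume up to a multiple of the patch size, run the predictor on each patch, write the predictions back, and crop the padding away. `naive_sliding_window` is a hypothetical name; MONAI's `sliding_window_inference` handles borders differently and adds batching of patches (`sw_batch_size`), overlap, and blending, so treat this only as an illustration of the idea.

```python
from typing import Callable

import torch
import torch.nn.functional as F


def naive_sliding_window(
    image: torch.Tensor,
    patch_size: tuple[int, int, int],
    predictor: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Hypothetical sketch of non-overlapping sliding-window inference for 3D inputs.

    image has shape [B, C, D, H, W]; the predictor must keep the spatial size
    of each patch (as our padded U-Net does). Call it under torch.no_grad().
    """
    spatial = image.shape[2:]
    # Pad each spatial axis at the end up to a multiple of the patch size
    pads = [(-size) % patch for size, patch in zip(spatial, patch_size)]
    # F.pad lists the last axis first: (W_before, W_after, H_before, H_after, D_before, D_after)
    padded = F.pad(image, (0, pads[2], 0, pads[1], 0, pads[0]))

    output = None
    pd, ph, pw = patch_size
    for z in range(0, padded.shape[2], pd):
        for y in range(0, padded.shape[3], ph):
            for x in range(0, padded.shape[4], pw):
                prediction = predictor(padded[..., z:z + pd, y:y + ph, x:x + pw])
                if output is None:
                    # Allocate once we know how many channels the predictor returns
                    output = padded.new_zeros((prediction.shape[0], prediction.shape[1], *padded.shape[2:]))
                output[..., z:z + pd, y:y + ph, x:x + pw] = prediction

    # Remove the padding again so the output matches the input's spatial size
    return output[..., : spatial[0], : spatial[1], : spatial[2]]


volume = torch.randn(1, 1, 30, 30, 30)
result = naive_sliding_window(volume, (28, 28, 28), predictor=lambda patch: patch * 2)
print(result.shape)  # torch.Size([1, 1, 30, 30, 30])
```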

# Lightning Integration

Now, finally, we will include `lightning` into the mix.
While it is not specifically designed to work with `monai`, both are built on top of `torch`, so combining the two is relatively straightforward.

First up, let's take our model, and convert it to a `lightning` model.
There are a few ways we could tackle this.
We could encapsulate, or wrap, our model as an element of a `lightning.LightningModule` class.
We could also rewrite our original `UNet` class to inherit from `lightning.LightningModule`.
Alternatively, we could inherit from both `UNet` and `lightning.LightningModule`, but that might be a bit overkill.

For simplicity, we'll go with the first option.

In [None]:
# pl is the old way to shorten pytorch_lightning
# I'm using it here because it is what I'm used to
# But the new way is to
# import lightning as L
import lightning as pl


class LightningUNet(pl.LightningModule):
    def __init__(
        self,
        n_dims: int,
        in_channels: int,
        out_channels: int,
        base_channels: int = 8,
        depth: int = 4,
        use_transpose: bool = False,
        use_normalization: bool = True,
        final_activation: nn.Module | None = None,
        input_size: tuple[int, ...] = (28, 28, 28),
        lr: float = 1e-3,
        sw_batch_size: int = 4,
    ) -> None:
        super().__init__()

        if n_dims not in (2, 3):
            raise ValueError("Invalid number of dimensions")

        if n_dims != len(input_size):
            raise ValueError("Input size must match number of dimensions")

        # Saves the __init__ arguments to self.hparams and logs them.
        # Lightning warns that nn.Module arguments are already saved
        # during checkpointing and recommends ignoring them here.
        self.save_hyperparameters(ignore=["final_activation"])
        # Used by the ModelSummary callback to show input/output shapes; not mandatory but nice to have
        self.example_input_array = torch.rand(1, in_channels, *input_size)

        # The plain UNet, wrapped as an attribute
        self.model = UNet(
            n_dims=n_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            base_channels=base_channels,
            depth=depth,
            use_transpose=use_transpose,
            use_normalization=use_normalization,
            final_activation=final_activation,
        )

        # BCEWithLogitsLoss applies the sigmoid internally, so the model
        # should output raw logits (i.e. no final sigmoid activation)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.lr = lr
        self.input_size = input_size
        # Number of patches sliding_window_inference passes through the model at once
        self.sw_batch_size = sw_batch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        # Note how we are using the monai batch dictionary here.
        # If you weren't using monai, you would have to adjust this
        # to match your own batch structure.
        image = batch["image"]
        label = batch["label"]

        output = self(image)
        loss = self.loss_fn(output, label)

        log = {
            "train_loss": loss,
            # You can log more things here, for example metrics
        }

        # self.log() and self.log_dict() make sure the current values are
        # shown in the progress bar and written out by the Trainer's logger.
        # Using either is not mandatory.
        self.log_dict(log, prog_bar=True, on_epoch=True, on_step=True)

        # training_step() has to return either:
        #   - the loss as a scalar tensor
        #   - a dictionary containing the key "loss"
        # Lightning uses this return value internally for the backward pass
        return loss

    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
        image = batch["image"]
        label = batch["label"]

        # Predict on the full volume, patch by patch.
        # We pass self (not a global `model`) so this module is the one being validated.
        output = sliding_window_inference(
            image,
            self.input_size,
            self.sw_batch_size,
            self,
            overlap=0.0,  # No overlap between patches
            mode="constant",  # Overlap merging strategy
        )
        loss = self.loss_fn(output, label)

        log = {
            "val_loss": loss,
            # You can log more things here, for example metrics
        }

        self.log_dict(log, prog_bar=True, on_epoch=True)

        # validation_step() does _not_ need to return anything.
        # In fact, you don't even need to use the same
        # loss function as you do in training_step().

    def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
        # This is pretty much the same as validation_step,
        # but is used for the test set

        image = batch["image"]
        label = batch["label"]

        output = sliding_window_inference(
            image,
            self.input_size,
            self.sw_batch_size,
            self,
            overlap=0.0,  # No overlap between patches
            mode="constant",  # Overlap merging strategy
        )
        loss = self.loss_fn(output, label)

        log = {
            "test_loss": loss,
            # You can log more things here, for example metrics
        }

        self.log_dict(log, prog_bar=True, on_epoch=True)
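
We never call these methods ourselves; the `Trainer` calls them for us. The plain-Python sketch below is a heavily simplified approximation of the order in which `Trainer.fit()` calls the hooks. The real loop also handles devices, precision, logging, callbacks, the sanity-check validation run, and more. `ToyModule` and `toy_fit` are hypothetical names for illustration.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that records which hooks run."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def configure_optimizers(self) -> str:
        self.calls.append("configure_optimizers")
        return "optimizer"  # a real module returns a torch optimizer

    def training_step(self, batch: float, batch_idx: int) -> float:
        self.calls.append(f"training_step({batch_idx})")
        return batch  # pretend the batch value is the loss

    def validation_step(self, batch: float, batch_idx: int) -> None:
        self.calls.append(f"validation_step({batch_idx})")


def toy_fit(module: ToyModule, train_batches: list, val_batches: list, max_epochs: int) -> None:
    """Very rough sketch of the hook order in Trainer.fit()."""
    optimizer = module.configure_optimizers()
    for epoch in range(max_epochs):
        # The real Trainer switches the module to train mode here
        for batch_idx, batch in enumerate(train_batches):
            loss = module.training_step(batch, batch_idx)
            # The real Trainer now runs the backward pass and optimizer step on `loss`
        # By default, validation runs after every training epoch, in eval mode without gradients
        for batch_idx, batch in enumerate(val_batches):
            module.validation_step(batch, batch_idx)


module = ToyModule()
toy_fit(module, train_batches=[0.5, 0.4], val_batches=[0.3], max_epochs=1)
print(module.calls)
# ['configure_optimizers', 'training_step(0)', 'training_step(1)', 'validation_step(0)']
```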

Now we can define a `DataModule` class to handle the data loading and preprocessing.

In [None]:
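from pathlib import Path

import monai.transforms as mt
from monai.data import CacheDataset, DataLoader

# collect_samples() is the helper we wrote earlier in this notebook


class DecathlonDataModule(pl.LightningDataModule):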
    def __init__(
        self,
        root: Path,
        batch_size: int = 4,
        split: float = 0.8,
        patch_size: tuple[int, ...] = (28, 28, 28),
    ) -> None:
        super().__init__()

        self.root = root
        self.batch_size = batch_size
        self.split = split

        self.train_set = None
        self.val_set = None
        self.test_set = None

        self.train_transforms = mt.Compose([
            mt.LoadImaged(keys=["image", "label"]),
            mt.EnsureChannelFirstd(keys=["image", "label"]),
            mt.NormalizeIntensityd(keys=["image"]),
            mt.RandSpatialCropd(keys=["image", "label"], roi_size=patch_size),
        ])

        # We turn on allow_missing_keys here because our toy test set
        # has no labels, and we want to reuse these transforms for it.
        self.val_transforms = mt.Compose([
            mt.LoadImaged(keys=["image", "label"], allow_missing_keys=True),
            mt.EnsureChannelFirstd(keys=["image", "label"], allow_missing_keys=True),
            mt.NormalizeIntensityd(keys=["image"]),
        ])

    def setup(self, stage: str | None = None) -> None:
        samples = collect_samples(self.root)
        train_samples = samples[:int(len(samples) * self.split)]
        val_samples = samples[int(len(samples) * self.split):]
        test_samples = collect_samples(self.root, is_test=True)

        # Use the transforms defined in __init__, not globals from earlier cells
        self.train_set = CacheDataset(data=train_samples, transform=self.train_transforms)
        self.val_set = CacheDataset(data=val_samples, transform=self.val_transforms)
        self.test_set = CacheDataset(data=test_samples, transform=self.val_transforms)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_set, batch_size=1, shuffle=False)
    
    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_set, batch_size=1, shuffle=False)

Now it's just a matter of training the model.

Lightning's `Trainer` class runs the training for us.
It takes care of the boilerplate, such as the training and validation loops, device placement, mixed precision, logging, and checkpointing, while extras like early stopping are plugged in as callbacks.
Let's train our model with it.

In [None]:
from lightning.pytorch import callbacks as plc

patch_size = (28, 28, 28)
use_gpu = torch.cuda.is_available()

trainer = pl.Trainer(
    max_epochs=3,
    accelerator="gpu" if use_gpu else "cpu",
    devices=1,
    # fp16 mixed precision is a GPU feature; fall back to full precision on the CPU
    precision="16-mixed" if use_gpu else "32-true",
    callbacks=[
        plc.ModelSummary(max_depth=3),
        plc.EarlyStopping(monitor="val_loss"),
    ],
    log_every_n_steps=1,
)

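model = LightningUNet(
    n_dims=3,
    in_channels=1,
    out_channels=1,
    base_channels=8,
    depth=3,
    # No final activation: BCEWithLogitsLoss expects raw logits and applies
    # the sigmoid itself, so adding nn.Sigmoid() here would apply it twice
    final_activation=None,
    lr=1e-3,
    input_size=patch_size,
)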
datamodule = DecathlonDataModule(
    root=root,
    batch_size=4,
    patch_size=patch_size,
)

trainer.fit(model, datamodule=datamodule)

Just for completeness, let's plot a prediction.
Don't expect great results: we've only trained for a few epochs and done hardly any hyperparameter tuning.

In [None]:
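import matplotlib.pyplot as plt

# Grab the first test volume
test_loader = iter(datamodule.test_dataloader())
batch = next(test_loader)
image = batch["image"]

# Run inference on the CPU, in evaluation mode
model = model.to("cpu")
model.eval()

# no_grad has to wrap the inference itself, so no computation graph is built
with torch.no_grad():
    output = sliding_window_inference(
        image,
        model.input_size,
        sw_batch_size,
        model,
        overlap=0.0,  # No overlap between patches
        mode="constant",  # Overlap merging strategy
        progress=True,
    )

# If the model has no final activation, `output` holds logits;
# apply torch.sigmoid(output) first if you want probabilities
image = image.squeeze().numpy()
output = output.squeeze().numpy()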

middle_slice = image.shape[-1] // 2

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image[..., middle_slice], cmap="gray")
ax[0].set_title("Input")
ax[0].axis("off")

ax[1].imshow(output[..., middle_slice], cmap="gray")
ax[1].set_title("Output")
ax[1].axis("off")

plt.tight_layout()
plt.show()
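
`sliding_window_inference` tiles the full volume with `roi_size` patches, runs the model on `sw_batch_size` of them at a time, and stitches the outputs back together. With `overlap=0.0`, the step between windows equals the patch size along each axis. The sketch below computes the window start positions along one axis in plain Python. The `window_starts` helper is hypothetical, and shifting the last window back so it ends at the image border is our assumption of a typical scheme, not a copy of MONAI's internals.

```python
import math


def window_starts(image_size: int, roi_size: int, overlap: float = 0.0) -> list[int]:
    """Start indices of sliding windows along one axis (assumes image_size >= roi_size)."""
    step = max(1, int(roi_size * (1 - overlap)))
    n_windows = math.ceil((image_size - roi_size) / step) + 1
    # Clamp so the last window ends exactly at the image border
    return [min(i * step, image_size - roi_size) for i in range(n_windows)]


print(window_starts(56, 28))  # [0, 28]
print(window_starts(60, 28))  # [0, 28, 32]
print(window_starts(60, 28, overlap=0.5))  # [0, 14, 28, 32]

# In 3D the windows form a grid: one window per combination of per-axis starts
shape = (60, 56, 28)
n_patches = 1
for size in shape:
    n_patches *= len(window_starts(size, 28))
print(n_patches)  # 6
```

So even with `overlap=0.0`, the last window along an axis can overlap its neighbour when the image size isn't a multiple of the patch size; `mode="constant"` averages such overlapping predictions with equal weight.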

And that's it!
We've trained a model using `lightning` and `monai` together.
For your own assignments, you can decide which library to use, or mix and match as you see fit.

Also keep in mind that there are many other libraries that can help with your machine learning projects.
For example, `torchmetrics` offers a wide range of ready-made metrics, `monai` itself has a library of common models, losses, and metrics, and `lightning` has many advanced features, such as distributed training and support for several loggers.
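
As an example of what such a metric computes, here is the Dice coefficient for binary masks, 2|A ∩ B| / (|A| + |B|), the overlap measure commonly used for segmentation, written in plain Python. The `dice` helper is hypothetical and only meant to show the formula; in real training code, a tested implementation such as `monai.metrics.DiceMetric` is the safer choice. To use it on our model's output, you would first threshold the sigmoid probabilities (for example at 0.5) to get a binary mask.

```python
def dice(pred: list[int], target: list[int], eps: float = 1e-8) -> float:
    """Dice coefficient between two flattened binary masks (values 0 or 1)."""
    if len(pred) != len(target):
        raise ValueError("Masks must have the same number of elements")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # eps avoids division by zero when both masks are empty
    return (2 * intersection + eps) / (total + eps)


print(dice([1, 1, 0, 0], [1, 0, 0, 0]))  # about 0.667
print(dice([0, 0], [0, 0]))  # 1.0: two empty masks count as a perfect match
```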

Good luck with your projects!