# Medical Image Segmentation Using 🤗 HuggingFace & PyTorch

Medical image segmentation is an innovative process that enables surgeons to have a virtual "x-ray vision." It is a highly valuable tool in healthcare, providing non-invasive diagnostics and in-depth analysis. With this in mind, in this post, we will explore the UW-Madison GI Tract Image Segmentation Kaggle challenge dataset. As part of this project, we will utilize PyTorch along with PyTorch-Lightning. We will use 🤗 HuggingFace transformers to load and fine-tune the Segformer transformer-based model on the medical segmentation dataset. Finally, we will create a Gradio app for image inference and deploy it on HuggingFace spaces.

<img src="https://learnopencv.com/wp-content/uploads/2023/07/medical-image-segmentation_feature_Image.png">

## Table of Contents


* [1 Install & Import Necessary Libraries](#1-Install-&-Import-Necessary-Libraries)
* [2 Set Hyperparameters For The Project](#2-Set-Hyperparameters-For-The-Project)
* [3 Loading The Medical Image Segmentation Dataset](#3-Loading-The-Medical-Image-Segmentation-Dataset)
    * [3.1 Defining A Custom PyTorch Dataset Class For Medical Image Segmentation](#3.1-Defining-A-Custom-PyTorch-Dataset-Class-For-Medical-Image-Segmentation)
    * [3.2 Defining The Custom LightningDataModule Class](#3.2-Defining-The-Custom-LightningDataModule-Class)
    * [3.3 Visualization Helper Functions](#3.3-Visualization-Helper-Functions)
    * [3.4 Display Sample Images From The Dataset](#3.4-Display-Sample-Images-From-The-Dataset)
* [4 Loading SegFormer From 🤗 HuggingFace](#4-Loading-SegFormer-From-🤗-HuggingFace)
* [5 Evaluation Metric & Loss Function](#5-Evaluation-Metric--&-Loss-Function)
    * [5.1 Custom Loss Functions - Smooth Dice + Cross-Entropy](#5.1-Custom-Loss-Functions---Smooth-Dice-+-Cross-Entropy)
    * [5.2 Evaluation Metric - Dice Coefficient (F1-Score)](#5.2-Evaluation-Metric---Dice-Coefficient-(F1-Score))
* [6 Creating The Custom LightningModule Class](#6-Creating-The-Custom-LightningModule-Class)
* [7 Start Training](#7-Start-Training)
* [8 Inference on the Medical Segmentation Dataset](#8-Inference-on-the-Medical-Segmentation-Dataset)
    * [8.1 Load The Best Trained Model](#8.1-Load-The-Best-Trained-Model)
    * [8.2 Evaluate Model On Validation Dataset](#8.2-Evaluate-Model-On-Validation-Dataset)
    * [8.3 Image Inference Using DataLoader Objects](#8.3-Image-Inference-Using-DataLoader-Objects)
* [9 Summary](#9-Summary)


## What is Medical Image Segmentation?

Medical image segmentation is a process that involves dividing medical images, such as CT scans or MRI scans, into distinct regions or structures of interest. This technique is used to identify and isolate specific areas within the image, which is crucial for diagnosis, treatment planning, and monitoring of diseases. It can be done manually by experts or automated using computer algorithms and machine learning. Medical image segmentation plays a vital role in various medical specialties and enables quantitative analysis and precise measurements.

The dataset for this project is taken from the <a href="https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/overview" target="_blank">UW-Madison GI Tract Image Segmentation</a> Kaggle competition. The dataset consists of 3 classes: the stomach, small bowel, and large bowel.

<img src="https://learnopencv.com/wp-content/uploads/2023/07/medical-image-segmentation_competition_dataset_example.png">

**Note:** In this notebook, we'll work with the final processed dataset.

In [None]:
!nvidia-smi

## 1 Install & Import Necessary Libraries

Before we begin the coding part, we need to ensure we have all the required libraries installed. For this project, apart from PyTorch, we are installing additional tools to help ease the implementation process. 

The major ones are:

<img src="https://learnopencv.com/wp-content/uploads/2023/07/medical-image-segmentation_tool_logos.png">

1. `transformers`: To load the SegFormer transformer model.
2. `lightning`: To simplify and structure code implementations.
3. `torchmetrics`: For evaluating the model's performance.
4. `wandb`: For experiment tracking. 
5. `albumentations`:  For applying augmentations. 

In [None]:
# Install libraries and restart kernel.
%pip install -qqqU wandb transformers lightning albumentations torchmetrics torchinfo
%pip install -qqq requests gradio
%pip install -qqq ipywidgets chardet charset-normalizer==3.1.0

In [None]:
import os
import zipfile
import platform
import warnings
from glob import glob
from dataclasses import dataclass

# To filter UserWarning.
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
import cv2
import requests
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# For data augmentation and preprocessing.
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Imports required SegFormer classes
from transformers import SegformerForSemanticSegmentation

# Importing lightning along with the logger and built-in callbacks it provides.
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# Importing torchmetrics modular and functional implementations.
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassF1Score

# To print model summary.
from torchinfo import summary

In [None]:
# Sets the internal precision of float32 matrix multiplications.
torch.set_float32_matmul_precision('high')

# To enable determinism.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# To render the matplotlib figure in the notebook.
%matplotlib inline

For this project, instead of the default tensorboard used by pytorch-lightning for tracking experiments, we will use a proper MLOps tool: Weights & Biases (WandB). 

The following code cell will help us to log into our `wandb` account.

In [None]:
import wandb

wandb.login()

The code cell will ask you to paste your API key into the dialog box. To obtain the key, click the <a href="https://wandb.ai/authorize" target="_blank">Sign In with Auth0</a> link provided.

## 2 Set Hyperparameters For The Project

Next, we will declare all the different hyperparameters used for the project. For this, we are defining four dataclasses. They will be used throughout the notebook.


1. `DatasetConfig`  – A class that holds all the hyperparameters we will use to process images. It contains the following information:
    1. Image size to use.
    2. Number of classes present in the dataset.
    3. The mean and standard deviation to use for image normalization.
    4. URL of the preprocessed dataset.
    5. Directory path to download the dataset to. 
    
2. `Paths` – This class contains the locations of the images and masks of the train and validation sets. It uses the “root dataset path” set in `DatasetConfig` as the base.

3. `TrainingConfig` –  A class that holds all the hyperparameters we will use for training and evaluation.  It contains the following information:
    1. Batch size.
    2. Initial learning rate.
    3. The number of epochs to train the model.
    4. The number of workers to use for data loading.
    5. Model, optimizer & learning rate scheduler-related configurations.

4. `InferenceConfig` – This class contains the (optional) batch size and the number of batches we will use to display our inference results at the end.
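
As a minimal sketch of how such configuration classes behave (`DemoConfig` is a hypothetical name used only for illustration, not part of this project), the snippet below shows that defaults can be read directly from the class and that `frozen=True` makes *instances* read-only. Class-level attributes can still be reassigned, which is an easy detail to miss.

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class DemoConfig:  # hypothetical class, for illustration only
    NUM_CLASSES: int = 4
    IMAGE_SIZE: tuple = (288, 288)


# Defaults can be read straight from the class, without creating an instance.
default_num_classes = DemoConfig.NUM_CLASSES

cfg = DemoConfig()
try:
    cfg.NUM_CLASSES = 10  # instances of a frozen dataclass reject assignment
    instance_is_frozen = False
except FrozenInstanceError:
    instance_is_frozen = True

# frozen=True only guards instances; the class attribute itself can be rebound.
DemoConfig.NUM_CLASSES = 5

print(default_num_classes, instance_is_frozen, DemoConfig.NUM_CLASSES)
```

Reading the values from the class keeps the rest of the notebook free of config objects that need to be passed around, at the cost of the class attributes behaving like module-level globals.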

Note: We’ve uploaded the preprocessed dataset to our Dropbox and <a href="https://www.kaggle.com/datasets/learnopencvblog/uwm-gi-tract-segmentation-img-msk-split" target="_blank">Kaggle</a> accounts. There are two options. You can manually download the dataset and move it to your workstation or utilize the data download code we’ve written below to do it automatically.

In [None]:
@dataclass(frozen=True)
class DatasetConfig:
    NUM_CLASSES:   int = 4 # including background.
    IMAGE_SIZE: tuple[int,int] = (288, 288) # W, H
    MEAN: tuple = (0.485, 0.456, 0.406)
    STD:  tuple = (0.229, 0.224, 0.225)
    BACKGROUND_CLS_ID: int = 0
    URL: str = r"https://www.dropbox.com/scl/fi/r0685arupp33sy31qhros/dataset_UWM_GI_Tract_train_valid.zip?rlkey=w4ga9ysfiuz8vqbbywk0rdnjw&dl=1"
    DATASET_PATH: str = os.path.join(os.getcwd(), "dataset_UWM_GI_Tract_train_valid")

@dataclass(frozen=True)
class Paths:
    DATA_TRAIN_IMAGES: str = os.path.join(DatasetConfig.DATASET_PATH, "train", "images", r"*.png")
    DATA_TRAIN_LABELS: str = os.path.join(DatasetConfig.DATASET_PATH, "train", "masks",  r"*.png")
    DATA_VALID_IMAGES: str = os.path.join(DatasetConfig.DATASET_PATH, "valid", "images", r"*.png")
    DATA_VALID_LABELS: str = os.path.join(DatasetConfig.DATASET_PATH, "valid", "masks",  r"*.png")
        
@dataclass
class TrainingConfig:
    BATCH_SIZE:      int = 48 # 8
    NUM_EPOCHS:      int = 100
    INIT_LR:       float = 3e-4
    NUM_WORKERS:     int = 0 if platform.system() == "Windows" else os.cpu_count()

    OPTIMIZER_NAME:  str = "AdamW"
    WEIGHT_DECAY:  float = 1e-4
    USE_SCHEDULER:  bool = True # Use learning rate scheduler?
    SCHEDULER:       str = "MultiStepLR" # Name of the scheduler to use.
    MODEL_NAME:str = "nvidia/segformer-b4-finetuned-ade-512-512" 
    

@dataclass
class InferenceConfig:
    BATCH_SIZE:  int = 10
    NUM_BATCHES: int = 2

## 3 Loading The Medical Image Segmentation Dataset

**Set class ID to RGB color mapping and vice versa.**

In [None]:
# Create a mapping of class ID to RGB value.
id2color = {
    0: (0, 0, 0),    # background pixel
    1: (0, 0, 255),  # Stomach
    2: (0, 255, 0),  # Small Bowel
    3: (255, 0, 0),  # Large Bowel
}


DatasetConfig.NUM_CLASSES = len(id2color)

print("Number of classes", DatasetConfig.NUM_CLASSES)

# Reverse id2color mapping.
# Used for converting RGB mask to a single channel (grayscale) representation.
rev_id2color = {value: key for key, value in id2color.items()}
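
To illustrate how the reverse mapping might be used, here is a small sketch that converts an RGB mask into a single-channel class-ID mask. The helper name `rgb_to_class_ids` and the 2x2 toy mask are assumptions made for this example; the notebook itself loads its masks as single-channel images.

```python
import numpy as np

id2color = {0: (0, 0, 0), 1: (0, 0, 255), 2: (0, 255, 0), 3: (255, 0, 0)}
rev_id2color = {value: key for key, value in id2color.items()}


def rgb_to_class_ids(rgb_mask, color_to_id):
    """Hypothetical helper: map every RGB pixel to its class ID."""
    class_ids = np.zeros(rgb_mask.shape[:2], dtype=np.uint8)
    for color, class_id in color_to_id.items():
        # A pixel matches only if all three channels agree with the color.
        matches = np.all(rgb_mask == np.array(color, dtype=rgb_mask.dtype), axis=-1)
        class_ids[matches] = class_id
    return class_ids


# Toy 2x2 RGB mask containing one pixel of each class.
rgb_mask = np.array(
    [[(0, 0, 0), (0, 0, 255)],
     [(0, 255, 0), (255, 0, 0)]],
    dtype=np.uint8,
)
class_mask = rgb_to_class_ids(rgb_mask, rev_id2color)
print(class_mask)
```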

### 3.1 Defining A Custom PyTorch Dataset Class For Medical Image Segmentation

First, we will define our custom PyTorch `Dataset` class. This custom class is designed to load each image along with its corresponding mask. The `Dataset` class is essential for efficient and organized data handling in machine learning tasks. It provides a standardized interface to load and preprocess data samples from various sources. Encapsulating the dataset into a single object simplifies data management. It enables seamless integration with other PyTorch components like data loaders and models. 

The custom class performs the following functions:

1. Load each image-mask pair.
2. Apply geometric and pixel augmentations if the pair belongs to the training set.
3. Apply preprocessing transformations such as normalization and standardization.
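
The normalization step can be sketched with plain NumPy. This is only an approximation of what the preprocessing does, assuming 8-bit inputs that are scaled by a maximum pixel value of 255 before the per-channel mean and standard deviation are applied; the toy image is made up for the example.

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406), dtype=np.float32)
std = np.array((0.229, 0.224, 0.225), dtype=np.float32)

# A toy 2x2 RGB image (H, W, C) with every pixel set to mid-gray.
image = np.full((2, 2, 3), 128, dtype=np.uint8)

# Scale to [0, 1], then standardize each channel.
normalized = (image.astype(np.float32) / 255.0 - mean) / std

# Invert the operation to recover displayable values in [0, 1].
restored = np.clip(normalized * std + mean, 0.0, 1.0)

print(normalized[0, 0])
print(restored[0, 0])
```

The inverse step is what makes the images viewable again after they have been normalized for the model.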

In [None]:
# Custom Class for creating training and validation (segmentation) dataset objects.

class MedicalDataset(Dataset):
    def __init__(self, *, image_paths, mask_paths, img_size, ds_mean, ds_std, is_train=False):
        self.image_paths = image_paths
        self.mask_paths  = mask_paths  
        self.is_train    = is_train
        self.img_size    = img_size
        self.ds_mean = ds_mean
        self.ds_std = ds_std
        self.transforms  = self.setup_transforms(mean=self.ds_mean, std=self.ds_std)

    def __len__(self):
        return len(self.image_paths)

    def setup_transforms(self, *, mean, std):
        transforms = []

        # Augmentation to be applied to the training set.
        if self.is_train:
            transforms.extend([
                A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5),
                A.ShiftScaleRotate(scale_limit=0.12, rotate_limit=0.15, shift_limit=0.12, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.CoarseDropout(max_holes=8, max_height=self.img_size[1]//20, max_width=self.img_size[0]//20, min_holes=5, fill_value=0, mask_fill_value=0, p=0.5)
            ])

        # Preprocess transforms - Normalization and converting to PyTorch tensor format (HWC --> CHW).
        transforms.extend([
                A.Normalize(mean=mean, std=std, always_apply=True),
                ToTensorV2(always_apply=True),  # (H, W, C) --> (C, H, W)
        ])
        return A.Compose(transforms)

    def load_file(self, file_path, depth=0):
        file = cv2.imread(file_path, depth)
        if depth == cv2.IMREAD_COLOR:
            file = file[:, :, ::-1]
        return cv2.resize(file, (self.img_size), interpolation=cv2.INTER_NEAREST)

    def __getitem__(self, index):
        # Load image and mask file.
        image = self.load_file(self.image_paths[index], depth=cv2.IMREAD_COLOR)
        mask  = self.load_file(self.mask_paths[index],  depth=cv2.IMREAD_GRAYSCALE)
        
        # Apply Preprocessing (+ Augmentations) transformations to image-mask pair
        transformed = self.transforms(image=image, mask=mask)
        image, mask = transformed["image"], transformed["mask"].to(torch.long)
        return image, mask
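
Two array-layout details in the class above are easy to overlook: the channel reversal in `load_file` and the (H, W, C) to (C, H, W) conversion before batching. The following NumPy sketch reproduces both on a made-up 2x2 image; it does not use OpenCV or Albumentations.

```python
import numpy as np

# A toy 2x2 image in OpenCV's BGR channel order: pure blue everywhere.
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255

# Reversing the last axis swaps BGR to RGB, as load_file does for color images.
rgb = bgr[:, :, ::-1]

# Moving channels first mirrors the (H, W, C) -> (C, H, W) conversion.
chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))

print(rgb[0, 0], chw.shape)
```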

### 3.2 Defining The Custom LightningDataModule Class

In this section, we will define the custom `MedicalSegmentationDataModule` class inherited from Lightning’s `LightningDataModule` class. It helps organize and encapsulate all the data-related operations and logic in a PyTorch project. It acts as a bridge between your data and Lightning’s training pipeline. It is a convenient abstraction that encapsulates data-related operations, promotes code organization, and facilitates seamless integration with other Lightning components for efficient and reproducible deep-learning experiments.

The class will perform the following functions:

1. Download the dataset from Dropbox.
2. Create a MedicalDataset class object for each set.
3. Create and return the DataLoader objects for each set.


The class methods we need to define are as follows:

1. `prepare_data(...)`: This method is used for data preparation, such as downloading the dataset and one-time preprocessing. In a distributed setting, Lightning calls it from a single process (by default, once per node), which makes it the right place for downloads.
2. `setup(...)`: This method is called on every GPU (process), so it is the right place for data operations that each process needs, such as creating the train/val/test splits.
3. `train_dataloader(...)`: This method returns the train dataloader.
4. `val_dataloader(...)`: This method returns the validation dataloader.

In [None]:
class MedicalSegmentationDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_classes=10,
        img_size=(384, 384),
        ds_mean=(0.485, 0.456, 0.406),
        ds_std=(0.229, 0.224, 0.225),
        batch_size=32,
        num_workers=0,
        pin_memory=False,
        shuffle_validation=False,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.img_size    = img_size
        self.ds_mean     = ds_mean
        self.ds_std      = ds_std
        self.batch_size  = batch_size
        self.num_workers = num_workers
        self.pin_memory  = pin_memory
        
        self.shuffle_validation = shuffle_validation

    def prepare_data(self):
        # Download dataset.
        dataset_zip_path = f"{DatasetConfig.DATASET_PATH}.zip"

        # Download if dataset does not exist.
        if not os.path.exists(DatasetConfig.DATASET_PATH):

            print("Downloading and extracting assets...", end="")
            file = requests.get(DatasetConfig.URL)
            # Use a context manager so the file handle is closed before unzipping.
            with open(dataset_zip_path, "wb") as f:
                f.write(file.content)

            try:
                with zipfile.ZipFile(dataset_zip_path) as z:
                    z.extractall(os.path.split(dataset_zip_path)[0]) # Unzip where downloaded.
                    print("Done")
            except zipfile.BadZipFile:
                print("Invalid file")

            os.remove(dataset_zip_path) # Remove the ZIP file to free storage space.

    def setup(self, *args, **kwargs):
        # Collect the training image and mask paths.
        train_imgs = sorted(glob(f"{Paths.DATA_TRAIN_IMAGES}"))
        train_msks = sorted(glob(f"{Paths.DATA_TRAIN_LABELS}"))

        # Collect the validation image and mask paths.
        valid_imgs = sorted(glob(f"{Paths.DATA_VALID_IMAGES}"))
        valid_msks = sorted(glob(f"{Paths.DATA_VALID_LABELS}"))

        self.train_ds = MedicalDataset(image_paths=train_imgs, mask_paths=train_msks, img_size=self.img_size,  
                                       is_train=True, ds_mean=self.ds_mean, ds_std=self.ds_std)

        self.valid_ds = MedicalDataset(image_paths=valid_imgs, mask_paths=valid_msks, img_size=self.img_size, 
                                       is_train=False, ds_mean=self.ds_mean, ds_std=self.ds_std)

    def train_dataloader(self):
        # Create train dataloader object with drop_last flag set to True.
        return DataLoader(
            self.train_ds, batch_size=self.batch_size,  pin_memory=self.pin_memory, 
            num_workers=self.num_workers, drop_last=True, shuffle=True
        )    

    def val_dataloader(self):
        # Create validation dataloader object.
        return DataLoader(
            self.valid_ds, batch_size=self.batch_size,  pin_memory=self.pin_memory, 
            num_workers=self.num_workers, shuffle=self.shuffle_validation
        )
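
The train loader above sets `drop_last=True` while the validation loader does not. As a small illustration of the difference, the sketch below counts batches for a made-up dataset of 10 samples with a batch size of 4; the dataset and sizes are assumptions chosen for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset with 10 samples, which is not a multiple of the batch size.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))

# drop_last=True discards the final, incomplete batch (10 // 4 full batches).
train_batches = len(DataLoader(dataset, batch_size=4, drop_last=True, shuffle=True))

# drop_last=False keeps it, so every sample is seen during validation.
valid_batches = len(DataLoader(dataset, batch_size=4, drop_last=False))

print(train_batches, valid_batches)
```

Dropping the incomplete batch keeps every training step at the same batch size, while validation keeps all samples so the reported metrics cover the whole set.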

**Usage**: Let's download the dataset and initialize train and validation data loaders. We’ll use them to visualize the dataset.

In [None]:
%%time

dm = MedicalSegmentationDataModule(
    num_classes=DatasetConfig.NUM_CLASSES,
    img_size=DatasetConfig.IMAGE_SIZE,
    ds_mean=DatasetConfig.MEAN,
    ds_std=DatasetConfig.STD,
    batch_size=InferenceConfig.BATCH_SIZE,
    num_workers=0,
    shuffle_validation=True,
)

# Download dataset.
dm.prepare_data()

# Create training & validation dataset.
dm.setup()

train_loader, valid_loader = dm.train_dataloader(), dm.val_dataloader()

### 3.3 Visualization Helper Functions

To help visualize our dataset, we need to define some additional helper functions. They are as follows:

A) `num_to_rgb(...)`: This function converts a single-channel mask into an RGB mask for visualization purposes.

In [None]:
def num_to_rgb(num_arr, color_map=id2color):
    single_layer = np.squeeze(num_arr)
    output = np.zeros(num_arr.shape[:2] + (3,))
 
    for k in color_map.keys():
        output[single_layer == k] = color_map[k]
 
    # return a floating point array in range [0.0, 1.0]
    return np.float32(output) / 255.0

B) `image_overlay(...)`: This function overlays an RGB segmentation map on top of an RGB image.

In [None]:
# Function to overlay a segmentation map on top of an RGB image.
def image_overlay(image, segmented_image):
    alpha = 1.0  # Transparency for the original image.
    beta = 0.7  # Transparency for the segmentation map.
    gamma = 0.0  # Scalar added to each sum.

    segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)

    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    image = cv2.addWeighted(image, alpha, segmented_image, beta, gamma, image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return np.clip(image, 0.0, 1.0)
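
The blend performed by `cv2.addWeighted` is the weighted sum `alpha * image + beta * mask + gamma`. With `alpha = 1.0` and `beta = 0.7`, the result can exceed `1.0` for floating-point images, which is why the function clips at the end. A NumPy sketch on a single made-up pixel:

```python
import numpy as np

alpha, beta, gamma = 1.0, 0.7, 0.0

# One gray image pixel and one mask pixel with a fully saturated third channel.
image_pixel = np.array([0.6, 0.6, 0.6], dtype=np.float32)
mask_pixel = np.array([0.0, 0.0, 1.0], dtype=np.float32)

# Same arithmetic as the weighted sum used for the overlay.
blended = alpha * image_pixel + beta * mask_pixel + gamma

# Values above 1.0 are outside the displayable range and must be clipped.
clipped = np.clip(blended, 0.0, 1.0)

print(blended, clipped)
```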

C) `display_image_and_mask(...)`: The convenience function below will display the original image, the ground truth mask, and the ground truth mask overlaid on the original image.

In [None]:
def display_image_and_mask(*, images, masks, color_map=id2color):
    title = ["GT Image", "Color Mask", "Overlayed Mask"]

    for idx in range(images.shape[0]):
        image = images[idx]
        grayscale_gt_mask = masks[idx]

        fig = plt.figure(figsize=(15, 4))

        # Create RGB segmentation map from grayscale segmentation map.
        rgb_gt_mask = num_to_rgb(grayscale_gt_mask, color_map=color_map)

        # Create the overlayed image.
        overlayed_image = image_overlay(image, rgb_gt_mask)

        plt.subplot(1, 3, 1)
        plt.title(title[0])
        plt.imshow(image)
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.title(title[1])
        plt.imshow(rgb_gt_mask)
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.title(title[2])
        plt.imshow(overlayed_image)
        plt.axis("off")

        plt.tight_layout()
        plt.show()

    return

D) `denormalize(...)`: This function is used to denormalize the image tensors and clip values between `0` and `1`. It is used to denormalize the images for visualization.

In [None]:
def denormalize(tensors, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    for c in range(3):
        tensors[:, c, :, :].mul_(std[c]).add_(mean[c])

    return torch.clamp(tensors, min=0.0, max=1.0)
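
One caveat worth illustrating: `mul_` and `add_` operate in place, so `denormalize` also changes the tensor that is passed in. The sketch below copies the function from the cell above and shows that passing a `.clone()` leaves the original batch untouched; the all-zeros batch is a made-up input.

```python
import torch


def denormalize(tensors, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    # Same logic as the cell above: in-place scale and shift per channel.
    for c in range(3):
        tensors[:, c, :, :].mul_(std[c]).add_(mean[c])
    return torch.clamp(tensors, min=0.0, max=1.0)


batch = torch.zeros(1, 3, 2, 2)

# Passing a clone keeps the original normalized batch intact.
safe_copy = denormalize(batch.clone())
untouched = bool(torch.all(batch == 0))

# Passing the tensor directly modifies it as a side effect.
denormalize(batch)
mutated = bool(torch.any(batch != 0))

print(untouched, mutated)
```

This matters only if the normalized batch is needed again afterwards, for example to feed it to the model after visualizing it.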

### 3.4 Display Sample Images From The Dataset

In the code cell below, we loop over the first batch in the validation dataset and display the ground truth image, ground truth mask, and the ground truth mask overlaid on the image. The overlay helps us better visualize the segmented classes in the context of the original image.

In [None]:
for batch_images, batch_masks in valid_loader:

    batch_images = denormalize(batch_images, mean=DatasetConfig.MEAN, std=DatasetConfig.STD).permute(0, 2, 3, 1).numpy()
    batch_masks  = batch_masks.numpy()

    print("batch_images shape:", batch_images.shape)
    print("batch_masks shape: ", batch_masks.shape)
    
    display_image_and_mask(images=batch_images, masks=batch_masks)

    break

## 4 Loading SegFormer From 🤗 HuggingFace

The SegFormer model was proposed in the paper titled <a href="https://arxiv.org/abs/2105.15203" target="_blank">SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers.</a> The model consists of a hierarchical <a href="https://learnopencv.com/the-future-of-image-recognition-is-here-pytorch-vision-transformer/" target="_blank">Transformer</a> encoder made of efficient multi-head attention modules and a final lightweight all-MLP decoder head.

Abstract from the paper:

> We present SegFormer, a simple, efficient yet powerful semantic segmentation framework which unifies Transformers with lightweight multilayer perception (MLP) decoders. SegFormer has two appealing features: 1) SegFormer comprises a novel hierarchically structured Transformer encoder which outputs multiscale features. It does not need positional encoding, thereby avoiding the interpolation of positional codes which leads to decreased performance when the testing resolution differs from training. 2) SegFormer avoids complex decoders. The proposed MLP decoder aggregates information from different layers, and thus combining both local attention and global attention to render powerful representations. We show that this simple and lightweight design is the key to efficient segmentation on Transformers. We scale our approach up to obtain a series of models from SegFormer-B0 to SegFormer-B5, reaching significantly better performance and efficiency than previous counterparts. For example, SegFormer-B4 achieves 50.3% mIoU on ADE20K with 64M parameters, being 5x smaller and 2.2% better than the previous best method. Our best model, SegFormer-B5, achieves 84.0% mIoU on Cityscapes validation set and shows excellent zero-shot robustness on Cityscapes-C.

<img src="https://learnopencv.com/wp-content/uploads/2023/07/medical-image-segmentation_Segformer_architecture.png" width="75%" align="center">

Source: Arxiv paper

**You can check all the trained weights available for SegFormer model on HuggingFace <a href="https://huggingface.co/models?pipeline_tag=image-segmentation&sort=downloads&search=nvidia%2Fsegformer" target="_blank">over here.</a>**

Loading a pre-trained model version and getting it ready for inference or finetuning is very easy, thanks to HuggingFace. We only have to pass the following:

1. `pretrained_model_name_or_path`: (string) The ID or path of a pre-trained model hosted on the HuggingFace model hub.
2. `num_labels`: (int) The number of channels (one for each class) we want the model to output. If this number differs from the one in the pre-trained checkpoint, the final classification layer is replaced with a new layer with randomly initialized weights.
3. `ignore_mismatched_sizes`: (bool) Whether to ignore the weight size mismatch. Here, the mismatch occurs because we change the `num_labels` value.

In [None]:
def get_model(*, model_name, num_classes):
    model = SegformerForSemanticSegmentation.from_pretrained(
        model_name,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
    return model

**Usage**

In [None]:
# Define model
model = get_model(model_name=TrainingConfig.MODEL_NAME, num_classes=DatasetConfig.NUM_CLASSES)

* The model's forward pass takes multiple arguments <a href="https://huggingface.co/docs/transformers/v4.15.0/model_doc/segformer#transformers.SegformerForSemanticSegmentation.forward" target="_blank">[SegFormer Documentation]</a>. The two important ones are `pixel_values` and `labels`.
* The `pixel_values` argument refers to the input images. The `labels` argument is for passing the ground-truth mask.  
* The model's forward pass also calculates the cross-entropy loss if `labels` are passed.
* The output logits are smaller than the input image size (for SegFormer, 1/4 of the input height and width). To get the **outputs** to match the input image size, we simply need to **upsample** them.

In [None]:
# Create dummy inputs.
print(DatasetConfig.IMAGE_SIZE[::-1])
data    = torch.randn(1, 3, *DatasetConfig.IMAGE_SIZE[::-1])
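# torch.rand(...) cast to long would give an all-zero mask, so sample random class IDs instead.
target  = torch.randint(0, DatasetConfig.NUM_CLASSES, (1, *DatasetConfig.IMAGE_SIZE[::-1]))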

In [None]:
# Generate dummy outputs.
outputs = model(pixel_values=data, labels=target, return_dict=True)

# Upsample model outputs to match input image size.
upsampled_logits = nn.functional.interpolate(outputs["logits"], size=target.shape[-2:], mode="bilinear", align_corners=False)

To access the model's output, we have to use the `["logits"]` key. Similarly, we can access the loss via the `"loss"` key.

In [None]:
print("Model Outputs: outputs['logits']:", outputs["logits"].shape)

print("Model Outputs Resized:", upsampled_logits.shape)

print("Loss: outputs['loss']:", outputs["loss"])

In this project, we won’t be using the CE loss returned by the model for training. Instead, we will define our custom combo loss function that combines the Smooth Dice coefficient & CE to compute the loss.

In [None]:
summary(model, input_size=(1, 3, *DatasetConfig.IMAGE_SIZE[::-1]), depth=2, device="cpu")

## 5 Evaluation Metric  & Loss Function

The **Dice Coefficient** (otherwise known as the *F1-Score*) is a function that is commonly used in the context of segmentation and is often specifically used as the basis for a loss function for segmentation problems. We will write the custom loss function next based on the Dice Coefficient, but let's first provide the motivation for why this might be a good idea. 

For a binary classification problem, the metric is defined as follows using set notation, where `A` and `B` are segmentation masks representing the ground truth mask and the predicted segmentation map. 
<br>

$$ 
Dice = \frac{2*|A\cap B\hspace{1mm}|}{|A\hspace{1mm}| + |B\hspace{1mm}|} \hspace{2mm}
$$

Simply put, the metric is twice the overlap area divided by the total number of pixels in both images. As you can see, the Dice Coefficient is very similar to IoU. Both metrics range from `0` to `1` and are positively correlated with each other. In terms of confusion matrix components, the metric can also be defined as follows:
<br>
$$ Dice =  \hspace{2mm} \frac{2TP}{2TP + FP + FN}
$$
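
To make the two formulations concrete, the sketch below computes the Dice coefficient for a pair of made-up 2x3 binary masks, once from the set definition and once from the confusion-matrix counts, and compares it with IoU.

```python
import numpy as np

# Toy binary masks: A is the ground truth, B is the prediction.
ground_truth = np.array([[1, 1, 0],
                         [0, 1, 0]], dtype=bool)
prediction = np.array([[1, 0, 0],
                       [0, 1, 1]], dtype=bool)

tp = np.sum(ground_truth & prediction)    # predicted positive, actually positive
fp = np.sum(~ground_truth & prediction)   # predicted positive, actually negative
fn = np.sum(ground_truth & ~prediction)   # missed positives

# Set form: 2|A ∩ B| / (|A| + |B|)
dice_sets = 2 * tp / (ground_truth.sum() + prediction.sum())

# Confusion-matrix form: 2TP / (2TP + FP + FN)
dice_counts = 2 * tp / (2 * tp + fp + fn)

# IoU for comparison: TP / (TP + FP + FN)
iou = tp / (tp + fp + fn)

print(dice_sets, dice_counts, iou)
```

Both Dice forms give the same value, and Dice is never lower than IoU for the same pair of masks, since Dice = 2·IoU / (1 + IoU).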
<br>

However, the Dice Coefficient is not quite as intuitive as IoU. To better understand the formulation, we need to consider two important quantities that lead to its development: Precision and Recall, as defined below. 
<br>

$$P:= \frac{TP}{TP + FP}  \hspace{10mm} R:= \frac{TP}{TP + FN}$$


<br>
Precision is a measure of how precise the model is in making predictions (quality or purity of the positive predictions), and Recall considers what we missed or describes the completeness of the positive predictions. 
This is the motivation that gave rise to the development of the Dice Coefficient (F1-Score) below, defined as the harmonic mean of the two quantities (a balancing between the two quantities):
<br>

$$Dice = (\frac{2}{\frac{1}{P} + \frac{1}{R}}) \hspace{2mm} =  \hspace{2mm} \frac{2TP}{2TP + FP + FN}$$
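
The equality above follows in two short steps. Substituting the definitions of $P$ and $R$ into the denominator of the harmonic mean and then inverting gives:

$$\frac{1}{P} + \frac{1}{R} = \frac{TP + FP}{TP} + \frac{TP + FN}{TP} = \frac{2TP + FP + FN}{TP} \hspace{5mm} \Longrightarrow \hspace{5mm} \frac{2}{\frac{1}{P} + \frac{1}{R}} = \frac{2TP}{2TP + FP + FN}$$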
<br>

Another way to look at each component is by referring to the following figure from Wikipedia for the <a href="https://en.wikipedia.org/wiki/F-score" target="_blank">F1-Score</a>. Here we see that it's important to consider which elements are relevant and which elements are retrieved. In this context, it is easy to see that both Precision and Recall are essential components for quantifying the accuracy of a model.

<img src='https://opencv.org/wp-content/uploads/2022/07/c4-05-precision-recall.png' align='center' width="60%">

---

Note that the Dice Coefficient can also be used as an evaluation metric and is used in the Kaggle competition as an evaluation metric along with 3D Hausdorff distance. But since, for this project, we are focusing on 2D images, we will stick with the Dice coefficient as our primary evaluation metric.

### 5.1 Custom Loss Functions - Smooth Dice + Cross-Entropy

Below, we define a custom loss function often used in segmentation problems when there is an imbalance in the classes within the dataset. The loss is based on the Dice metric and combined with Cross-entropy (CE) loss. 

Dice + CE is a good loss function for semantic segmentation as it combines pixel-wise accuracy with boundary alignment, encouraging precise object localization. It addresses the class imbalance issue by incorporating the Dice coefficient, promoting balanced predictions and improving overall segmentation performance.

In practice, we found that using a combined loss (Dice loss + CE loss) works better than Dice loss alone. This is also supported by our experiments:

<img src="https://learnopencv.com/wp-content/uploads/2023/07/medical-image-segmentation_run_f1_compare.png">

The gray curve corresponds to *Dice + CE loss*, and the green curve to *Dice loss only*.

In [None]:
def dice_coef_loss(predictions, ground_truths, num_classes=2, dims=(1, 2), smooth=1e-8):
    """Smooth Dice coefficient + Cross-entropy loss function."""

    ground_truth_oh = F.one_hot(ground_truths, num_classes=num_classes)
    prediction_norm = F.softmax(predictions, dim=1).permute(0, 2, 3, 1)

    intersection = (prediction_norm * ground_truth_oh).sum(dim=dims)
    summation = prediction_norm.sum(dim=dims) + ground_truth_oh.sum(dim=dims)

    dice = (2.0 * intersection + smooth) / (summation + smooth)
    dice_mean = dice.mean()


    CE = F.cross_entropy(predictions, ground_truths)

    return (1.0 - dice_mean) + CE
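
As a quick sanity check of the loss, the sketch below copies the function from the cell above and evaluates it on made-up tensors: random logits versus logits that put a very large score on the correct class of every pixel. The tensor sizes and the seed are arbitrary choices for this example.

```python
import torch
import torch.nn.functional as F


def dice_coef_loss(predictions, ground_truths, num_classes=2, dims=(1, 2), smooth=1e-8):
    """Smooth Dice coefficient + Cross-entropy loss (same logic as the cell above)."""
    ground_truth_oh = F.one_hot(ground_truths, num_classes=num_classes)
    prediction_norm = F.softmax(predictions, dim=1).permute(0, 2, 3, 1)
    intersection = (prediction_norm * ground_truth_oh).sum(dim=dims)
    summation = prediction_norm.sum(dim=dims) + ground_truth_oh.sum(dim=dims)
    dice = (2.0 * intersection + smooth) / (summation + smooth)
    return (1.0 - dice.mean()) + F.cross_entropy(predictions, ground_truths)


torch.manual_seed(0)
num_classes = 4
targets = torch.randint(0, num_classes, (2, 8, 8))   # (B, H, W) class IDs
random_logits = torch.randn(2, num_classes, 8, 8)    # (B, C, H, W) raw scores

# Near-perfect logits: a large score on the correct class for every pixel.
perfect_logits = 100.0 * F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

random_loss = dice_coef_loss(random_logits, targets, num_classes=num_classes)
perfect_loss = dice_coef_loss(perfect_logits, targets, num_classes=num_classes)

print(f"random: {random_loss.item():.4f}, near-perfect: {perfect_loss.item():.4f}")
```

Note the expected shapes: logits are `(B, C, H, W)` and targets are `(B, H, W)` with integer class IDs, which is why the logits must be upsampled to the mask size before the loss is computed.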

### 5.2 Evaluation Metric - Dice Coefficient (F1-Score) 


To calculate the Dice score for the medical image segmentation task, we will use the `MulticlassF1Score` class from the `torchmetrics` library with the "`macro`" average reduction method.

**Macro** average refers to a method of calculating average performance in multiclass or multilabel classification problems, which treats all classes equally.
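
To see what macro averaging means in practice, here is a hand-rolled sketch in plain Python on six made-up labels. It is not how `torchmetrics` computes the score internally, and scoring a class with no true or predicted samples as `0.0` is an assumption made for this example.

```python
def per_class_f1(y_true, y_pred, num_classes):
    """Compute one F1 (Dice) score per class from flat label lists."""
    scores = []
    for cls in range(num_classes):
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
        fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # Assumption for this sketch: an absent class scores 0.0.
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


# Class 0 dominates; the single class-2 sample is misclassified as class 1.
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 0, 1, 1]

f1_scores = per_class_f1(y_true, y_pred, num_classes=3)
macro_f1 = sum(f1_scores) / len(f1_scores)   # unweighted mean over classes
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

print(f1_scores, round(macro_f1, 4), round(accuracy, 4))
```

Because every class contributes equally, missing the rare class pulls the macro score well below the plain accuracy, which is the behavior we want when the background dominates the masks.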

## 6 Creating The Custom LightningModule Class

The final custom class we need to create is the `MedicalSegmentationModel` which inherits its functionalities from Lightning’s `LightningModule` class.

The `LightningModule` class in pytorch-lightning is a higher-level abstraction that simplifies the training and organizing of PyTorch models. It provides a structured, standardized interface for defining and training deep learning models. It separates the concerns of model definition, optimization, and training loop, making the code more modular and readable.


The class methods we need to define are as follows:

1. Model initialization: `__init__(...)` method where the model and its parameters are defined. This method also includes the initialization of the loss and metric calculation methods.
2. Forward pass: `forward(...)` method where the forward pass of the model is defined.
3. Training step: `training_step(...)` method where the training step for each batch is defined. It includes calculating loss and metrics, which are logged for tracking.
4. Validation step: `validation_step(...)` method where the validation step for each batch is defined. It also includes the calculation of loss and metrics.
5. Optimizer configuration: `configure_optimizers(...)` method where the optimizer and, optionally, the learning rate scheduler are defined.

Moreover, two methods, `on_train_epoch_end(...)` and `on_validation_epoch_end(...)`, are defined to log the average loss and F1 score after each epoch for training and validation, respectively.
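
Of the methods listed above, `forward(...)` does slightly more than call the network: SegFormer's segmentation head predicts logits at one quarter of the input resolution, so the class below resizes them back to the input size with `F.interpolate`. The sketch below shows only that resizing step. The batch size, number of classes, and the 288×288 image size are assumed values, and a random tensor stands in for the model output.

```python
import torch
import torch.nn.functional as F

# Assumed sizes for illustration only: a batch of 2 images at 288x288 and 4 classes.
batch_size, num_classes, height, width = 2, 4, 288, 288
images = torch.randn(batch_size, 3, height, width)

# Stand-in for `outputs["logits"]`: SegFormer predicts at 1/4 of the input resolution.
low_res_logits = torch.randn(batch_size, num_classes, height // 4, width // 4)

# Same call as in `forward`: resize the logits to the spatial size of the input.
upsampled_logits = F.interpolate(
    low_res_logits, size=images.shape[-2:], mode="bilinear", align_corners=False
)

print(low_res_logits.shape)    # torch.Size([2, 4, 72, 72])
print(upsampled_logits.shape)  # torch.Size([2, 4, 288, 288])
```

Upsampling inside `forward` means the loss and the F1 metric can compare the logits with the full-resolution masks directly, without resizing the targets.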

In [None]:
class MedicalSegmentationModel(pl.LightningModule):
    def __init__(
        self,
        model_name: str,
        num_classes: int = 10,
        init_lr: float = 0.001,
        optimizer_name: str = "Adam",
        weight_decay: float = 1e-4,
        use_scheduler: bool = False,
        scheduler_name: str = "multistep_lr",
        num_epochs: int = 100,
    ):
        super().__init__()

        # Save the arguments as hyperparameters.
        self.save_hyperparameters()

        # Loading model using the function defined above.
        self.model = get_model(model_name=self.hparams.model_name, num_classes=self.hparams.num_classes)

        # Initializing the required metric objects.
        self.mean_train_loss = MeanMetric()
        self.mean_train_f1 = MulticlassF1Score(num_classes=self.hparams.num_classes, average="macro")
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_f1 = MulticlassF1Score(num_classes=self.hparams.num_classes, average="macro")

    def forward(self, data):
        outputs = self.model(pixel_values=data, return_dict=True)
        upsampled_logits = F.interpolate(outputs["logits"], size=data.shape[-2:], mode="bilinear", align_corners=False)
        return upsampled_logits
    
    def training_step(self, batch, *args, **kwargs):
        data, target = batch
        logits = self(data)

        # Calculate Combo loss (Segmentation specific loss (Dice) + cross entropy)
        loss = dice_coef_loss(logits, target, num_classes=self.hparams.num_classes)
        
        self.mean_train_loss(loss, weight=data.shape[0])
        self.mean_train_f1(logits.detach(), target)

        self.log("train/batch_loss", self.mean_train_loss, prog_bar=True, logger=False)
        self.log("train/batch_f1", self.mean_train_f1, prog_bar=True, logger=False)
        return loss

    def on_train_epoch_end(self):
        # Computing and logging the training mean loss & mean f1.
        self.log("train/loss", self.mean_train_loss, prog_bar=True)
        self.log("train/f1", self.mean_train_f1, prog_bar=True)
        self.log("epoch", self.current_epoch)

    def validation_step(self, batch, *args, **kwargs):
        data, target = batch
        logits = self(data)
        
        # Calculate Combo loss (Segmentation specific loss (Dice) + cross entropy)
        loss = dice_coef_loss(logits, target, num_classes=self.hparams.num_classes)

        self.mean_valid_loss.update(loss, weight=data.shape[0])
        self.mean_valid_f1.update(logits, target)

    def on_validation_epoch_end(self):
        
        # Computing and logging the validation mean loss & mean f1.
        self.log("valid/loss", self.mean_valid_loss, prog_bar=True)
        self.log("valid/f1", self.mean_valid_f1, prog_bar=True)
        self.log("epoch", self.current_epoch)

    def configure_optimizers(self):
        LR = self.hparams.init_lr
        WD = self.hparams.weight_decay

        # Optimize only the trainable parameters of this module's own network
        # (`self.model`), not a global `model` variable.
        params = [p for p in self.model.parameters() if p.requires_grad]

        if self.hparams.optimizer_name in ("AdamW", "Adam"):
            optimizer = getattr(torch.optim, self.hparams.optimizer_name)(params, lr=LR, 
                                                                          weight_decay=WD, amsgrad=True)
        else:
            optimizer = torch.optim.SGD(params, lr=LR, weight_decay=WD)

        if self.hparams.use_scheduler:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[self.trainer.max_epochs // 2,], gamma=0.1)

            # The lr_scheduler_config is a dictionary that contains the scheduler
            # and its associated configuration.
            lr_scheduler_config = {"scheduler": lr_scheduler, "interval": "epoch", "name": "multi_step_lr"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

        else:
            return optimizer
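
To make the scheduler branch of `configure_optimizers` concrete, the sketch below runs `MultiStepLR` on its own with a single milestone at the halfway point. The values `max_epochs = 10` and an initial learning rate of `1e-3` are assumptions for illustration, and a dummy parameter stands in for the model.

```python
import torch

# Dummy parameter standing in for the model weights (assumption for this sketch).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)

max_epochs = 10  # assumed value; the class above reads it from `self.trainer.max_epochs`
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[max_epochs // 2], gamma=0.1
)

lrs = []
for epoch in range(max_epochs):
    # Record the learning rate used during this epoch.
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # the training steps of the epoch would happen here
    scheduler.step()   # with interval="epoch", the scheduler steps once per epoch

print(lrs)
```

With these assumed values, the first five epochs run at `1e-3` and the remaining five at roughly `1e-4`, i.e. the learning rate is multiplied by `gamma` once training reaches the halfway milestone.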

## 7 Start Training

Once we have organized the `LightningModule` and `LightningDataModule` classes, we can utilize Lightning's `Trainer` class to automate the remaining tasks effortlessly.

The `Trainer` offers a range of valuable deep-learning training functionalities, such as mixed-precision training, distributed training, deterministic training, profiling, gradient accumulation, batch overfitting, and more. Implementing these functionalities correctly can be time-consuming, but it becomes a swift process with the `Trainer` class.

By initializing our `MedicalSegmentationModel` and `MedicalSegmentationDataModule` classes and passing them to the `.fit(...)` method of the `Trainer` class instance, we can promptly commence training. This streamlined approach eliminates the need to implement various training aspects manually, providing convenience and efficiency.

In [None]:
# Seed everything for reproducibility.
pl.seed_everything(42, workers=True)

model = MedicalSegmentationModel(
    model_name=TrainingConfig.MODEL_NAME,
    num_classes=DatasetConfig.NUM_CLASSES,
    init_lr=TrainingConfig.INIT_LR,
    optimizer_name=TrainingConfig.OPTIMIZER_NAME,
    weight_decay=TrainingConfig.WEIGHT_DECAY,
    use_scheduler=TrainingConfig.USE_SCHEDULER,
    scheduler_name=TrainingConfig.SCHEDULER,
    num_epochs=TrainingConfig.NUM_EPOCHS,
)

data_module = MedicalSegmentationDataModule(
    num_classes=DatasetConfig.NUM_CLASSES,
    img_size=DatasetConfig.IMAGE_SIZE,
    ds_mean=DatasetConfig.MEAN,
    ds_std=DatasetConfig.STD,
    batch_size=TrainingConfig.BATCH_SIZE,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

Next, we will define a `ModelCheckpoint` callback for saving the best model during training and a `LearningRateMonitor` callback for logging the learning rate at each epoch.

In [None]:
# Creating ModelCheckpoint callback. 
# We'll save the model on the basis of the validation F1-score.
model_checkpoint = ModelCheckpoint(
    monitor="valid/f1",
    mode="max",
    filename="ckpt_{epoch:03d}-vloss_{valid/loss:.4f}_vf1_{valid/f1:.4f}",
    auto_insert_metric_name=False,
)

# Creating a learning rate monitor callback which will be plotted/added in the default logger.
lr_rate_monitor = LearningRateMonitor(logging_interval="epoch")

We will also initialize the `WandbLogger` to upload the training metrics to your wandb project.
 
During the logger initialization, we set two parameters:
1. `log_model=True` - Upload the model as an artifact when the training is completed.
2. `project` - The project name to use on WandB. A project typically contains logs from multiple experiments and their checkpoints.

In [None]:
# Initialize logger.
wandb_logger = WandbLogger(log_model=True, project="UM_medical_segmentation")

When the logger is initialized, it will also print the link for the current experiment, which you can open on any device to monitor the training process and share with your team.

**Train**

In [None]:
# Initializing the Trainer class object.
trainer = pl.Trainer(
    accelerator="auto",  # Auto select the best hardware accelerator available
    devices="auto",  # Auto select available devices for the accelerator (e.g., multiple GPUs)
    strategy="auto",  # Auto select the distributed training strategy.
    max_epochs=TrainingConfig.NUM_EPOCHS,  # Maximum number of epochs to train for.
    enable_model_summary=False,  # Disable printing of model summary as we are using torchinfo.
    callbacks=[model_checkpoint, lr_rate_monitor],  # Declaring callbacks to use.
    precision="16-mixed",  # Using Mixed Precision training.
    logger=wandb_logger
)

# Start training
trainer.fit(model, data_module)

## 8 Inference on the Medical Segmentation Dataset

For inference, we will use the same validation data as we did during training. We will plot the ground truth images, the ground truth masks, and the predicted segmentation maps overlaid on the ground truth images.

### 8.1 Load The Best Trained Model

In [None]:
# Get the path of the best saved model.
CKPT_PATH = model_checkpoint.best_model_path
CKPT_PATH

Initialize the model with trained weights.

In [None]:
model = MedicalSegmentationModel.load_from_checkpoint(CKPT_PATH)

### 8.2 Evaluate Model On Validation Dataset

In [None]:
# Get the validation dataloader.

data_module.setup()
valid_loader = data_module.val_dataloader()

Get the best evaluation metrics using the saved model.

In [None]:
# Initialize trainer class for inference.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,        
    enable_checkpointing=False,
    inference_mode=True,
)

# Run evaluation.
results = trainer.validate(model=model, dataloaders=valid_loader)

<img src="https://learnopencv.com/wp-content/uploads/2023/07/medical-image-segmentation_best_run_results_chart.png">

*Blue* - Train, *Orange* - Valid

Log them as experiment summary metrics to WandB.

In [None]:
if os.environ.get("LOCAL_RANK", None) is None:
    wandb.run.summary["best_valid_f1"] = results[0]["valid/f1"]
    wandb.run.summary["best_valid_loss"] = results[0]["valid/loss"]

### 8.3 Image Inference Using DataLoader Objects

In the code below, we define a helper function, `inference`, that runs a trained model over a dataloader object and plots the results. The model predictions will also be uploaded to wandb.

In [None]:
@torch.inference_mode()
def inference(model, loader, img_size, device="cpu"):
    num_batches_to_process = InferenceConfig.NUM_BATCHES

    for idx, (batch_img, batch_mask) in enumerate(loader):
        # Stop before doing any work once the requested number of batches is processed.
        if idx == num_batches_to_process:
            break

        predictions = model(batch_img.to(device))

        pred_all = predictions.argmax(dim=1).cpu().numpy()

        batch_img = denormalize(batch_img.cpu(), mean=DatasetConfig.MEAN, std=DatasetConfig.STD)
        batch_img = batch_img.permute(0, 2, 3, 1).numpy()

        for i in range(0, len(batch_img)):
            fig = plt.figure(figsize=(20, 8))

            # Display the original image.
            ax1 = fig.add_subplot(1, 4, 1)
            ax1.imshow(batch_img[i])
            ax1.title.set_text("Actual frame")
            plt.axis("off")

            # Display the ground truth mask.
            true_mask_rgb = num_to_rgb(batch_mask[i], color_map=id2color)
            ax2 = fig.add_subplot(1, 4, 2)
            ax2.set_title("Ground truth labels")
            ax2.imshow(true_mask_rgb)
            plt.axis("off")

            # Display the predicted segmentation mask.
            pred_mask_rgb = num_to_rgb(pred_all[i], color_map=id2color)
            ax3 = fig.add_subplot(1, 4, 3)
            ax3.set_title("Predicted labels")
            ax3.imshow(pred_mask_rgb)
            plt.axis("off")

            # Display the predicted segmentation mask overlayed on the original image.
            overlayed_image = image_overlay(batch_img[i], pred_mask_rgb)
            ax4 = fig.add_subplot(1, 4, 4)
            ax4.set_title("Overlayed image")
            ax4.imshow(overlayed_image)
            plt.axis("off")
            plt.show()
            
            # Upload predictions to WandB.
            images = wandb.Image(fig, caption=f"Prediction Sample {idx}_{i}")
            
            if os.environ.get("LOCAL_RANK", None) is None:
                wandb.log({"Predictions": images})

In [None]:
# Use GPU if available.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(DEVICE)
model.eval()

inference(model, valid_loader, device=DEVICE, img_size=DatasetConfig.IMAGE_SIZE)
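
For readers who want to see what the `argmax` step inside `inference` produces, here is a minimal sketch on hand-written logits. The palette below is hypothetical and is indexed directly with the class-id map. The article's own `num_to_rgb` helper and `id2color` mapping, defined earlier, play this role in the real code.

```python
import torch

# Assumed toy logits: 1 image, 3 classes, 2x2 pixels (values chosen by hand).
logits = torch.tensor([[[[5.0, 0.0], [0.0, 0.0]],     # class 0 scores
                        [[0.0, 5.0], [0.0, 1.0]],     # class 1 scores
                        [[0.0, 0.0], [5.0, 2.0]]]])   # class 2 scores

# Highest-scoring class per pixel -> (batch, height, width) map of class ids.
class_ids = logits.argmax(dim=1)

# Hypothetical palette: one RGB color per class id.
palette = torch.tensor([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=torch.uint8)

# Indexing the palette with the id map yields an RGB image per sample.
rgb_mask = palette[class_ids]   # (batch, height, width, 3)

print(class_ids)
print(rgb_mask.shape)
```

The class-id map has one integer per pixel, and the color lookup turns it into an image that can be displayed next to the input or blended over it.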

Terminate the wandb experiment run.

In [None]:
if os.environ.get("LOCAL_RANK", None) is None:
    wandb.run.finish()

You can use the Gradio inference demo app here --> <a href="https://huggingface.co/spaces/veb-101/UWMGI_Medical_Image_Segmentation" target="_blank">Medical Image Segmentation Gradio App</a>

All the files required can be accessed from here --> <a href="https://huggingface.co/spaces/veb-101/UWMGI_Medical_Image_Segmentation/tree/main" target="_blank">Gradio App Files</a>

## 9 Summary

Medical image segmentation using deep learning offers several advantages. Deep learning models can capture complex patterns and features, which often leads to more accurate segmentation than traditional methods. They also automate segmentation, improving efficiency and making it practical to analyze large volumes of medical image data. Finally, when trained on suitable data, these models can adapt to different image characteristics, imaging modalities, patient populations, and clinical settings, which broadens their usefulness in medical imaging applications.

To summarise this article📜, we covered a comprehensive list of related topics:

1. Medical Image Segmentation: Explored the definition and challenges of medical image segmentation.
2. Dataset Preparation: Used the UW-Madison GI Tract segmentation dataset, made observations, and created preprocessed training and validation sets.
3. We defined a few essential functions and classes for PyTorch and PyTorch-Lightning frameworks to facilitate ease of training.
4. We learned how to use the Segformer model from Hugging Face transformers for segmentation and fine-tuned it on our dataset.
5. We defined a custom loss function combining the Dice coefficient with cross-entropy for improved segmentation performance.
6. Training and Metrics Tracking: Trained the model, monitored metrics using WandB, and uploaded the model as an artifact for future use.
7. We designed a user-friendly interface using Gradio, making our medical image segmentation model accessible to everyone.