<a href="https://colab.research.google.com/github/elliemci/vision-transformer-models/blob/main/image_segmentation/transformer_seg_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MRI image segmentation with SegFormer

Train a deep-learning-based semantic segmentation transformer model to identify tumor locations and boundaries in MRI brain images, using the brain images and their corresponding ground truth masks. Once trained, SegFormer can predict the boundaries of tumor regions. The Mean Intersection over Union (Mean IoU) segmentation metric is applied to assess model quality.

In [None]:
!pip install jedi

In [None]:
!apt-get install libcairo2-dev pkg-config python3-dev

In [None]:
!pip install pycairo

In [None]:
!pip install --upgrade fsspec==2024.10.0



In [None]:
!pip install datasets transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/ColabNotebooks/ExplainableAI/image_segmentation

/content/drive/MyDrive/ColabNotebooks/ExplainableAI/image_segmentation


In [None]:
!ls

data_preproces.ipynb  mri_data_seg  segmentation_dataset  transformer_seg_traiining.ipynb


## 1. Load MRI dataset

The dataset is a dictionary containing the images and corresponding masks, split into train, test and validate datasets.

In [None]:
from datasets import load_from_disk

dataset = load_from_disk("segmentation_dataset")
dataset

DatasetDict({
    train: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 400
    })
    test: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 200
    })
    validate: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 100
    })
})

## 2. Training Function

### Load Hugging Face SegFormer model

In [None]:
import torch
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# The pretrained ADE20K checkpoint has a 150-class head; passing num_labels=2 with
# ignore_mismatched_sizes=True below re-initializes the classifier for the 2-class task
from transformers import SegformerForSemanticSegmentation
from transformers import get_linear_schedule_with_warmup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def list_and_set_gpu_memory_growth():
    """
    Lists available GPUs and lets this process use up to 100% of each GPU's memory.
    PyTorch's caching allocator already reserves memory on demand, so there is no
    TensorFlow-style "memory growth" switch; expandable_segments (set above) can
    help reduce fragmentation instead.
    """

    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Number of available GPUs: {num_gpus}")

        for i in range(num_gpus):
            gpu_properties = torch.cuda.get_device_properties(i)
            print(f"GPU {i}: {gpu_properties.name}")

            # Allow this process to use the full device memory (effectively the default)
            torch.cuda.set_per_process_memory_fraction(1.0, device=i)
    else:
        print("CUDA is not available.")

# Call the function to list GPUs and set the memory fraction
list_and_set_gpu_memory_growth()

# initialize the model
model_checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
# load SegformerForSemanticSegmentation with a new 2-class head
model = SegformerForSemanticSegmentation.from_pretrained(
    model_checkpoint,
    num_labels = 2,
    id2label = {0: 'background', 1: 'tumor'},
    label2id = {'background': 0, 'tumor': 1},
    ignore_mismatched_sizes = True
)

# set hyperparameters for training
batch_size = 2
num_epochs = 2
num_training_steps = (len(dataset['train'])//batch_size) * num_epochs
learning_rate = 6e-5
weight_decay_rate = 0.01

# set up the optimizer, learning-rate scheduler and loss function
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=learning_rate,
                              weight_decay=weight_decay_rate)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps = 0,
    num_training_steps = num_training_steps)

criterion = torch.nn.CrossEntropyLoss()

Number of available GPUs: 1
GPU 0: Tesla T4



Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([2]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([2, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
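The linear schedule only changes the learning rate if `scheduler.step()` is called once per optimizer step. As a rough illustration, the pure-Python sketch below reproduces the multiplier that, as we understand the `transformers` implementation, `get_linear_schedule_with_warmup` applies to the base learning rate: a linear ramp over the warmup steps, then a linear decay to zero at `num_training_steps`. `linear_warmup_decay` is a hypothetical helper, not part of `transformers`.

```python
def linear_warmup_decay(step, num_warmup_steps, num_training_steps):
    """Hypothetical helper: learning-rate multiplier for linear warmup followed by linear decay."""
    if step < num_warmup_steps:
        # ramp linearly from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # decay linearly from 1 to 0 between the end of warmup and num_training_steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# values used in this notebook: 400 training images, batch_size=2, num_epochs=2, no warmup
base_lr = 6e-5
total_steps = (400 // 2) * 2  # 400 optimizer steps
for step in (0, 100, 200, 400):
    lr = base_lr * linear_warmup_decay(step, 0, total_steps)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

Without the per-batch `scheduler.step()` call, the learning rate simply stays at its initial value for the whole run.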


### Preprocess the images and masks and create DataLoaders

In [None]:
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

batch_size = 2

# define a transform to convert PIL images to PyTorch tensors scaled to [0, 1]
transform = transforms.Compose([
    transforms.ToTensor(),  # convert PIL image (H, W, C) to tensor (C, H, W)
])

def process_train_batch(batch):
  """ Read a batch of train images and masks and transform them into model inputs.
      Since the transformer is pre-trained on RGB images, convert from grayscale.
      The masks are converted too, so they carry 3 identical channels."""

  batch_imgs = [transform(img.convert("RGB")) for img in batch['pixel_values']]
  batch_masks = [transform(mask.convert("RGB")) for mask in batch['label']]

  # return a dictionary with image and mask data
  return {'pixel_values': batch_imgs, 'mask': batch_masks}

def process_val_batch(batch):
  """ Read a batch of validation or test images and masks and transform them into model inputs. """
  batch_imgs = [transform(img.convert("RGB")) for img in batch['pixel_values']]
  batch_masks = [transform(label.convert("RGB")) for label in batch['label']]

  # return a dictionary with image and mask data
  return {'pixel_values': batch_imgs, 'mask': batch_masks}

dataset["train"].set_transform(process_train_batch)
dataset["test"].set_transform(process_val_batch)
dataset["validate"].set_transform(process_val_batch)

# shuffle only the training data; keep evaluation order fixed
train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset['test'], batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset['validate'], batch_size=batch_size, shuffle=False)


for batch in train_loader:
  # inspect the first batch: shapes, data types and value ranges

  print(f"batch data type: {type(batch)}")
  print(f"batch images size: {batch['pixel_values'].shape}")
  print(f"batch masks size: {batch['mask'].shape}")
  print()
  image = batch['pixel_values'][0]
  mask = batch['mask'][0]
  print(f"image dimensions: {image.shape}")
  print(f"image data type: {image.dtype}")
  print(f"image min value: {image.min()}")
  print(f"image max value: {image.max()}")
  print(f"labels in mask: {np.unique(mask)}")
  print(f"mask dimensions: {mask.shape}")
  print(f"mask data type: {mask.dtype}")
  print(f"mask min value: {mask.min()}")
  print(f"mask max value: {mask.max()}")
  print()
  break

batch data type: <class 'dict'>
batch images size: torch.Size([2, 3, 512, 512])
batch masks size: torch.Size([2, 3, 512, 512])

image dimensions: torch.Size([3, 512, 512])
image data type: torch.float32
image min value: 0.0
image max value: 1.0
labels in mask: [0. 1.]
mask dimensions: torch.Size([3, 512, 512])
mask data type: torch.float32
mask min value: 0.0
mask max value: 1.0
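Because each mask is a grayscale image copied into three identical channels, taking `argmax` over the channel axis cannot separate tumor from background: every pixel is a tie, and ties resolve to the first index (0). The NumPy sketch below uses a toy 3-channel mask (hypothetical data shaped like one DataLoader mask) to contrast that with thresholding a single channel, which yields class indices 0 (background) and 1 (tumor).

```python
import numpy as np

# toy 3-channel mask shaped like one DataLoader mask: (C, H, W) with values 0.0/1.0
toy_mask = np.zeros((3, 4, 4), dtype=np.float32)
toy_mask[:, 1:3, 1:3] = 1.0  # a 2x2 "tumor" region, identical in all three channels

# argmax over channels: all channels are equal, so the first index (0) wins at every pixel
argmax_labels = np.argmax(toy_mask, axis=0)

# thresholding one channel recovers the class index of each pixel
threshold_labels = (toy_mask[0] > 0.5).astype(np.int64)

print("argmax labels:\n", argmax_labels)
print("thresholded labels:\n", threshold_labels)
```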



## 3. Accuracy Estimation

In [None]:
def calculate_iou(true_class, pred_class):
    """
    Calculate Intersection over Union (IoU) for a specific class.
    Args:
        true_class (torch.Tensor): Boolean mask for the true class.
        pred_class (torch.Tensor): Boolean mask for the predicted class.
    Returns:
        float: IoU score for the class.
    """

    intersection = torch.logical_and(true_class, pred_class)
    union = torch.logical_or(true_class, pred_class)

    if torch.any(union):
        iou = (torch.sum(intersection).float() / torch.sum(union).float()).item()
    else:
        iou = 0.0  # class absent from both masks (empty union)
    return iou


def calculate_mean_iou(true_mask, pred_mask):
    """
    Calculate the mean Intersection over Union (mIoU) score.
    Args:
        true_mask (torch.Tensor): Ground truth mask.
        pred_mask (torch.Tensor): Predicted mask.
    Returns:
        float: Mean IoU score across all classes.
    """

    class_iou = []

    # Get the maximum class value (assuming contiguous classes starting from 0)
    max_value = true_mask.max().item()
    #print(f"Max class value: {max_value}")

    for i in range(1, max_value + 1):  # Skip background (class 0)
        true_i = (true_mask == i)
        pred_i = (pred_mask == i)
        #print(f"Class {i}: true_i sum: {torch.sum(true_i)}, pred_i sum: {torch.sum(pred_i)}")

        iou = calculate_iou(true_i, pred_i)
        #print(f"Class {i} IoU: {iou}")
        class_iou.append(iou)

    # Compute the mean IoU
    if class_iou:
        mean_iou = torch.mean(torch.tensor(class_iou))
    else:
        mean_iou = 0.0  # Handle case with no classes
        #print(f"Mean IoU: {mean_iou}")
    return mean_iou
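For intuition, the IoU of a class is |A ∩ B| / |A ∪ B|: the number of pixels where both masks mark the class, divided by the number of pixels where either does. The NumPy sketch below mirrors the logic of `calculate_iou` without PyTorch on hypothetical 4x4 masks; `iou_numpy` and the toy masks are illustrative only.

```python
import numpy as np

def iou_numpy(true_class, pred_class):
    """NumPy counterpart of calculate_iou: IoU of two boolean masks (0.0 if both are empty)."""
    intersection = np.logical_and(true_class, pred_class).sum()
    union = np.logical_or(true_class, pred_class).sum()
    return float(intersection / union) if union > 0 else 0.0

# hypothetical ground truth and prediction (1 = tumor, 0 = background)
toy_true = np.array([[0, 0, 0, 0],
                     [0, 1, 1, 0],
                     [0, 1, 1, 0],
                     [0, 0, 0, 0]])
toy_pred = np.array([[0, 0, 0, 0],
                     [0, 1, 1, 1],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]])

# tumor: intersection = 2 pixels, union = 5 pixels, so IoU = 0.4
print("tumor IoU:", iou_numpy(toy_true == 1, toy_pred == 1))
# background: intersection = 11 pixels, union = 14 pixels
print("background IoU:", iou_numpy(toy_true == 0, toy_pred == 0))
```

Averaging the tumor and background values would give the conventional two-class mean IoU; `calculate_mean_iou` skips background class 0 and therefore reports the tumor IoU only.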


## 4. Callback Function

In [None]:
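import torch
import torch.nn.functional as F  # for the interpolate function

def create_mask(pred_logits, target_size):
    """Create a class-index mask (B, H, W) from model logits (B, num_labels, h, w)."""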

    # resize logits to match the ground truth size
    pred_logits_resized = F.interpolate(
        pred_logits,
        size=target_size,
        mode='bilinear',
        align_corners=False
    )
    pred_mask = torch.argmax(pred_logits_resized, dim=1)
    return pred_mask
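SegFormer's decode head produces logits at a quarter of the input resolution, so `create_mask` upsamples them before taking the per-pixel argmax over the class axis. The NumPy sketch below traces the shapes with random, hypothetical logits; for simplicity it swaps in nearest-neighbour upsampling via `np.repeat` in place of the bilinear `F.interpolate` used above.

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical logits: batch of 2 images, 2 classes, at 1/4 of a 16x16 input -> (B, C, H/4, W/4)
logits = rng.normal(size=(2, 2, 4, 4))

# nearest-neighbour upsampling by 4 along height and width (stand-in for bilinear interpolation)
upsampled = np.repeat(np.repeat(logits, 4, axis=2), 4, axis=3)  # (2, 2, 16, 16)

# per-pixel class index: argmax over the class axis, like torch.argmax(..., dim=1)
toy_pred_mask = np.argmax(upsampled, axis=1)  # (2, 16, 16)

print(logits.shape, "->", upsampled.shape, "->", toy_pred_mask.shape)
```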

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def display_image(image, ground_truth_mask, predicted_mask=None):
    """
    Display the input image, ground truth mask, and (optionally) predicted mask side by side.

    Args:
        image (torch.Tensor): The input image tensor (C, H, W).
        ground_truth_mask (torch.Tensor): The ground truth mask tensor (H, W) or (C, H, W).
        predicted_mask (torch.Tensor or None): The predicted mask tensor (H, W); None skips that panel.
    """
    # Move the image and masks to the CPU and convert to NumPy
    image = image.cpu().numpy().transpose(1, 2, 0)  # Convert (C, H, W) -> (H, W, C)
    ground_truth_mask = ground_truth_mask.cpu().numpy()

    # A raw mask from the DataLoader has 3 identical channels: keep the first one
    if ground_truth_mask.ndim == 3:
        ground_truth_mask = ground_truth_mask[0]

    # Normalize the image to [0, 1] for display (guard against a constant image)
    image = (image - image.min()) / max(image.max() - image.min(), 1e-8)

    # One panel each for the input and ground truth, plus one for the prediction if given
    num_panels = 2 if predicted_mask is None else 3
    fig, axs = plt.subplots(1, num_panels, figsize=(5 * num_panels, 5))
    axs[0].imshow(image)
    axs[0].set_title("Input Image")
    axs[0].axis("off")

    axs[1].imshow(ground_truth_mask, cmap="jet")
    axs[1].set_title("Ground Truth Mask")
    axs[1].axis("off")

    if predicted_mask is not None:
        axs[2].imshow(predicted_mask.cpu().numpy(), cmap="jet")
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
print("Sample mask unique values:", torch.unique(batch['mask'][0]))
display_image(
    image=batch['pixel_values'][0],
    ground_truth_mask=batch['mask'][0],
    predicted_mask=None  # Skip predictions for raw data inspection
)

In [None]:
@torch.no_grad()
def custom_callback(epoch, model, val_loader, device):
    """
    Callback function executed after each training epoch.
    Evaluates the model on the validation set and calculates mean IoU.
    """
    model.eval()
    print(f"Running validation for epoch {epoch + 1}...")

    total_mean_iou = 0.0
    num_batches = 0

    for batch in val_loader:
        pixel_values = batch['pixel_values'].to(device)
        # Masks have 3 identical 0/1 channels: threshold one channel to get class indices
        true_mask = (batch['mask'][:, 0] > 0.5).long().to(device)

        # Get model predictions
        outputs = model(pixel_values)
        pred_mask = create_mask(outputs.logits, target_size=true_mask.shape[-2:])

        # Compute mean IoU for the batch
        mean_iou = calculate_mean_iou(true_mask, pred_mask)
        total_mean_iou += mean_iou
        num_batches += 1

    avg_mean_iou = total_mean_iou / num_batches
    print(f"Epoch {epoch + 1}: Mean IoU: {avg_mean_iou:.4f}")


    print(f"Unique values in ground truth mask (batch): {torch.unique(batch['mask'])}")
    print(f"Unique values in ground truth mask (per image): {torch.unique(true_mask[0])}")

    display_image(
        image=pixel_values[0],  # First image in the batch
        ground_truth_mask=true_mask[0],  # Ground truth mask for the first image
        predicted_mask=pred_mask[0]  # Predicted mask for the first image
    )
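The callback averages per-batch IoU values, so a batch with no tumor pixels in its ground truth contributes 0.0, and every batch counts equally regardless of how many tumor pixels it contains. One common alternative, sketched below under the assumption that masks are integer class-index arrays, is to accumulate intersection and union pixel counts over the whole validation set and divide once at the end. `IoUAccumulator` is a hypothetical NumPy helper for illustration, not part of any library used here.

```python
import numpy as np

class IoUAccumulator:
    """Hypothetical helper: dataset-level IoU per class from summed pixel counts."""

    def __init__(self, num_classes):
        self.intersection = np.zeros(num_classes, dtype=np.int64)
        self.union = np.zeros(num_classes, dtype=np.int64)

    def update(self, true_mask, pred_mask):
        # add this batch's per-class intersection and union pixel counts
        for c in range(len(self.intersection)):
            t, p = true_mask == c, pred_mask == c
            self.intersection[c] += np.logical_and(t, p).sum()
            self.union[c] += np.logical_or(t, p).sum()

    def per_class_iou(self):
        # classes never present in either mask get NaN instead of a misleading 0.0
        iou = np.full(len(self.union), np.nan)
        seen = self.union > 0
        iou[seen] = self.intersection[seen] / self.union[seen]
        return iou


acc = IoUAccumulator(num_classes=2)
# batch 1: tumor predicted exactly; batch 2: no tumor in ground truth or prediction
acc.update(np.array([[0, 1], [1, 0]]), np.array([[0, 1], [1, 0]]))
acc.update(np.zeros((2, 2), dtype=int), np.zeros((2, 2), dtype=int))
print("per-class IoU:", acc.per_class_iou())
print("mean IoU (background + tumor):", np.nanmean(acc.per_class_iou()))
```

In the callback, the torch masks would first be moved to the CPU with `.cpu().numpy()` before calling `update`.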


## 5. Run Training

In [None]:
import torch
import torch.nn.functional as F  # for the interpolate function

def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, custom_callback, scheduler=None):
    """
    Trains the given model on the provided training data and displays predictions on validation data after each epoch.

    Args:
        model: The model to train.
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        optimizer: Optimizer for updating model parameters.
        criterion: Loss function.
        num_epochs: Number of training epochs.
        device: Device to use for computation ('cuda' or 'cpu').
        custom_callback: Callback function to execute after each epoch.
        scheduler: Optional learning-rate scheduler, stepped after every optimizer step.
    """
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            # Move the pixel values to the device
            pixel_values = batch['pixel_values'].to(device)
            # Masks have 3 identical 0/1 channels, so argmax over them always returns 0;
            # threshold the first channel instead to get (B, H, W) class indices
            labels = (batch['mask'][:, 0] > 0.5).long().to(device)

            optimizer.zero_grad()
            outputs = model(pixel_values)

            # SegFormer logits are at 1/4 of the input resolution: upsample them to the label size
            outputs_resized = F.interpolate(
                outputs.logits,
                size=labels.shape[-2:],  # match height and width of labels
                mode='bilinear',
                align_corners=False
            )

            # Compute the loss and update the weights
            loss = criterion(outputs_resized, labels)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()  # advance the linear learning-rate decay

            epoch_loss += loss.item()

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_epoch_loss:.4f}")

        # Execute the custom callback function after each epoch
        custom_callback(epoch, model, val_loader, device)


In [None]:
train(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, custom_callback, scheduler=scheduler)