<a href="https://colab.research.google.com/github/dudeurv/SAM_MRI/blob/main/U_NET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical Image Segmentation with U-NET Tutorial

In this tutorial, you will develop and train a convolutional neural network for brain tumour image segmentation.

In [2]:
# Import libraries
import tarfile
import imageio.v3 as iio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt
from matplotlib import colors

## Download the imaging dataset

The dataset is curated from the brain imaging dataset in [Medical Decathlon Challenge](http://medicaldecathlon.com/). To save the storage and reduce the computational cost for this tutorial, we extract 2D image slices from T1-Gd contrast enhanced 3D brain volumes and downsample the images.

The dataset consists of a training set and a test set. Each image is of dimension 120 x 120, with a corresponding label map of the same dimension. There are four number of classes in the label map:

- 0: background
- 1: edema
- 2: non-enhancing tumour
- 3: enhancing tumour

In [3]:
# Download the dataset
!wget https://www.dropbox.com/s/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz

# Unzip the '.tar.gz' file to the current directory
datafile = tarfile.open('Task01_BrainTumour_2D.tar.gz')
datafile.extractall()
datafile.close()

--2023-11-26 21:07:12--  https://www.dropbox.com/s/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz
Resolving www.dropbox.com (www.dropbox.com)... 162.125.65.18, 2620:100:6017:18::a27d:212
Connecting to www.dropbox.com (www.dropbox.com)|162.125.65.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz [following]
--2023-11-26 21:07:13--  https://www.dropbox.com/s/raw/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucee335ccf10ef66ab61c598eb5e.dl.dropboxusercontent.com/cd/0/inline/CITEE5ropihrQGPKYCHZTlVUQgWyO8s8YbU5XZpdOnqxlUG9yjSYQ6O8BW8ZisBQJCrQXsrepXQajaRkTScAMV4sWDXEehEjdYg2Xf5Xt6kcIK4qv9ySFC97FanYk7xS2n4/file# [following]
--2023-11-26 21:07:14--  https://ucee335ccf10ef66ab61c598eb5e.dl.dropboxusercontent.com/cd/0/inline/CITEE5ropihrQGPKYCHZTlVUQgWyO8s8YbU5XZpdOnqxlUG9yjSYQ6O8BW8ZisBQJCrQXsrep

## Implement a dataset class

  **Documentation**:
  - os.listdir: [https://docs.python.org/3/library/os.html#os.listdir](https://docs.python.org/3/library/os.html#os.listdir)
  - os.path.join: [https://docs.python.org/3/library/os.path.html#os.path.join](https://docs.python.org/3/library/os.path.html#os.path.join)
  - imageio.imread: [https://imageio.readthedocs.io/en/stable/userapi.html#imageio.imread](https://imageio.readthedocs.io/en/stable/userapi.html#imageio.imread)
  - random.sample: [https://docs.python.org/3/library/random.html#random.sample](https://docs.python.org/3/library/random.html#random.sample)
  - numpy.expand_dims: [https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html](https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html)

In [4]:
def normalise_intensity(image, ROI_thres=0.1):
    """
    The function identifies the ROI in the image by applying a percentile-based threshold,
    then standardizes the pixel values in this region by subtracting the mean and dividing
    by the standard deviation.

    Args:
        image (np.array): Input image as a NumPy array.
        ROI_thres (float): Percentile threshold for defining the ROI (default is 0.1).

    Returns:
        np.array: Image array with normalized intensity in the ROI.
    """
    pixel_thres = np.percentile(image, ROI_thres)
    ROI = np.where(image > pixel_thres, image, 0) # If image value is greater than pixel threshold, return image value, otherwise return 0
    mean = np.mean(ROI)
    std = np.std(ROI)
    ROI_norm = (ROI - mean)/std # Normalise ROI
    return ROI_norm

class BrainImage():
    def __init__(self, image_path, label_path, deploy=False):
      # Initialise instant variables
      self.image_path = image_path
      self.label_path = label_path
      self.deploy = deploy # If deploy=True this means model is in testing mode
      self.images = [] # List of loaded image arrays
      self.labels = [] # List of loaded label arrays

      image_names = sorted(os.listdir(image_path)) # Sorted list containing image filenames to ensure a consistent order for data processing.
      for image_name in image_names:
        full_image_path = os.path.join(image_path, image_name)
        image = iio.imread(full_image_path) # Loads image into a processable NumPy array.
        self.images.append(image)

        if deploy == False: # If model is in training mode, load the labels as well
          full_label_path = os.path.join(label_path, image_name)
          label = iio.imread(full_label_path)
          self.labels.append(label)

    def __getitem__(self, idx):
        # Get an image and perform intensity normalisation
        image = normalise_intensity(self.images[idx])

        # Get its label map
        label = self.labels[idx]
        return image, label

    def get_random_batch(self, batch_size):
        # Get a batch of paired images and label maps
        images_batch, labels_batch = [], []

        idx_array = range(0, len(self.images)) # Creates an array of indices ranging from 0 to len(self.images)
        batch_idx = random.sample(idx_array, batch_size) # Randomly selects a batch_size number of indices
        for i in range(batch_size):
            image, label = self.__getitem__(batch_idx[i])
            images_batch.append(image)
            labels_batch.append(label)

        images_batch, labels_batch = np.array(images_batch), np.array(labels_batch)
        images_batch = np.expand_dims(images_batch, 1)
        return images_batch, labels_batch

## Build a U-net architecture
ters. This layer maps the deep features to the output classes or segments.

#### Documentation

- nn.Module: [https://pytorch.org/docs/stable/nn.html#module](https://pytorch.org/docs/stable/nn.html#module)
- Conv2d: [https://pytorch.org/docs/stable/nn.html#conv2d](https://pytorch.org/docs/stable/nn.html#conv2d)
- BatchNorm2d: [https://pytorch.org/docs/stable/nn.html#batchnorm2d](https://pytorch.org/docs/stable/nn.html#batchnorm2d)
- ReLU: [https://pytorch.org/docs/stable/nn.html#relu](https://pytorch.org/docs/stable/nn.html#relu)
- ConvTranspose2d: [https://pytorch.org/docs/stable/nn.html#convtranspose2d](https://pytorch.org/docs/stable/nn.html#convtranspose2d)
- torch.cat: [https://pytorch.org/docs/stable/generated/torch.cat.html](https://pytorch.org/docs/stable/generated/torch.cat.html)

In [5]:
class Double_Convolution(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    # 2 sets of: 3x3 Convolution layers, a ReLU activation to add non-linearity, and Batch Normalisation
    # Bias is set as false, as Batch Normalisation would remove the bias
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_channels)
    )
  def forward(self, x):
    return self.conv(x)

class UNETModel(nn.Module):
  # U-Net Model involves an encoder, a bottleneck and decoder section
  def __init__(self):
    super().__init__()
    # Encoder with 4 blocks of Double_Convolution layers
    # Includes 4 sets of max pooling operations with kernel 2x2 and stride 2 for downsampling
    self.down_1 = nn.Sequential(
        Double_Convolution(1, 64),
        nn.MaxPool2d(2, 2)
    )
    self.down_2 = nn.Sequential(
        Double_Convolution(64, 128),
        nn.MaxPool2d(2, 2)
    )
    self.down_3 = nn.Sequential(
        Double_Convolution(128, 256),
        nn.MaxPool2d(2, 2)
    )
    self.down_4 = nn.Sequential(
        Double_Convolution(256, 512),
        nn.MaxPool2d(2, 2)
    )
    # Bottleneck
    # Includes a Double_Convolution layer followed by upsampling with ConvTranspose2d
    self.bottleneck = nn.Sequential(
        Double_Convolution(512, 1024),
        nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
    )
    # Decoder
    # Concatinating with skip connections causes feature channels to double
    # Double_Convolution and ConvTranspose2d each causes feature channels to halve
    self.up_1 = nn.Sequential(
        Double_Convolution(1024, 512),
        nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
    )
    self.up_2 = nn.Sequential(
        Double_Convolution(512, 256),
        nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
    )
    self.up_3 = nn.Sequential(
        Double_Convolution(256, 128),
        nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
    )
    # Final Output Layer
    self.out = nn.Sequential(
        Double_Convolution(128, 64),
        nn.Conv2d(64, 1, 3, 1, 1, bias=False)
    )
  def forward(self, x):
    # Encoder
    x1 = self.down_1(x)  # Size [1, 64, 263, 263]
    x2 = self.down_2(x1) # Size [1, 128, 131, 131]
    x3 = self.down_3(x2) # Size [1, 256, 65, 65]
    x4 = self.down_4(x3) # Size [1, 512, 32, 32]

    # Bottleneck
    x5 = self.bottleneck(x4) # Size [1, 512, 64, 64]

    # Decoder
    skip_x4 = F.interpolate(x4, size=x5.size()[2:])
    x6 = self.up_1(torch.cat([x5, skip_x4], dim=1))

    skip_x3 = F.interpolate(x3, size=x6.size()[2:])
    x7 = self.up_2(torch.cat([x6, skip_x3], dim=1))

    skip_x2 = F.interpolate(x2, size=x7.size()[2:])
    x8 = self.up_3(torch.cat([x7, skip_x2], dim=1))

    # Final Output Layer
    skip_x1 = F.interpolate(x1, size=x8.size()[2:])
    x9 = self.out(torch.cat([x8, skip_x1], dim=1))

    return x9

model = UNETModel()
model(torch.rand(1, 1, 527, 527))

tensor([[[[ 4.0134e-01,  8.0189e-01, -4.9929e-01,  ...,  1.1166e+00,
            8.7047e-01,  7.9254e-01],
          [ 1.5676e-01,  1.1310e+00,  9.4324e-01,  ...,  1.2408e+00,
            1.4608e+00,  6.9818e-01],
          [ 4.0672e-01,  4.6674e-01,  7.0603e-01,  ...,  5.7690e-01,
           -8.1200e-03,  1.2300e-01],
          ...,
          [ 7.2112e-01,  5.4620e-01, -4.2740e-02,  ...,  1.1756e-01,
           -1.2450e+00, -1.6777e-01],
          [ 3.2768e-01,  7.6127e-01, -8.4678e-01,  ...,  1.6581e-01,
           -6.0674e-01, -1.0053e-03],
          [-1.8598e-01,  7.9071e-03, -3.7258e-01,  ..., -5.8541e-01,
           -2.7512e-01, -4.4655e-01]]]], grad_fn=<ConvolutionBackward0>)

## Train the segmentation model

#### 6. Train the Model
- **Goal**: Implement the training loop.
- **Actions**:
  - Iterate over a specified number of iterations (`num_iter`).
  - In each iteration:
    - Set the model to training mode (`model.train()`).
    - Fetch a batch of training data and transfer it to the device.
    - Perform a forward pass through the model (`logits = model(images)`).
    - Clear previous gradients (`optimizer.zero_grad()`).
    - Compute loss using `criterion`.
    - Perform backpropagation (`loss.backward()`).
    - Update model parameters (`optimizer.step()`).
    - Print training loss.

#### 7. Evaluate the Model
- **Goal**: Periodically evaluate the model on the test set.
- **Actions**:
  - Every few iterations (e.g., `it % 10 == 0`):
    - Set the model to evaluation mode (`model.eval()`).
    - Disable gradient calculations (`with torch.no_grad():`).
    - Fetch a batch of test data and transfer it to the device.
    - Perform a forward pass and compute the test loss.
    - Print the test loss.

#### 8. Save the Model
- **Goal**: Save the model's state at certain intervals.
- **Action**: Every few iterations (e.g., `it % 5000 == 0`), save the model's state dictionary.

#### Documentation
- Model Saving and Loading: [PyTorch Saving & Loading](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- Optimizers: [PyTorch Optim](https://pytorch.org/docs/stable/optim.html)
- Loss Functions: [PyTorch Losses](https://pytorch.org/docs/stable/nn.html#loss-functions)

In [None]:
# Use GPU if cuda is available
device = "cuda"
print(f"Device = {device}")

# Instantiate the UNET model
model = UNETModel()
model = model.to(device)
params = model.parameters()

# Set up optimizer and loss function
optimizer = optim.Adam(params, lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()

# Instantiate BrainImageSet for both training and test sets with appropriate image and label paths
BrainImage_train = BrainImage('Task01_BrainTumour_2D/training_images', 'Task01_BrainTumour_2D/training_labels')
BrainImage_test = BrainImage('Task01_BrainTumour_2D/test_images', 'Task01_BrainTumour_2D/test_labels')

epochs = 5000
for epoch in range(epochs):
  # Fetch a batch of training data and transfer it to the device
  train_images, train_labels = BrainImage_train.get_random_batch(batch_size=30)
  train_images, train_labels = torch.from_numpy(train_images).to(device, dtype=torch.float32), torch.from_numpy(train_labels).to(device, dtype=torch.float32)

  # Perform a forward pass through the model
  model_labels = model(train_images)

  # Match model labels size to target labels
  model_labels = F.interpolate(model_labels, size=train_labels.shape[1:])

  # Ensure model output and labels have the same shape
  model_labels = model_labels.squeeze()  # Remove the channel dimension if it's 1
  train_labels = train_labels.squeeze()  # Same for labels

  optimizer.zero_grad() # Clear previous gradients
  loss = loss_fn(model_labels, train_labels) # Compute loss
  loss.backward() # Carry out backpropagation and calculate gradients
  optimizer.step() # Update model parameters

  # Evaluate model
  if epoch % 500 == 0:
    model.eval()
    with torch.inference_mode():
      # Fetch a batch of testing data and transfer it to the device
      test_images, test_labels = BrainImage_test.get_random_batch(batch_size=30)
      test_images, test_labels = torch.from_numpy(test_images).to(device, dtype=torch.float32), torch.from_numpy(test_labels).to(device, dtype=torch.float32)

      pred_labels = model(test_images)

      pred_labels = F.interpolate(pred_labels, size=test_labels.shape[1:])

      # Ensure they have the same shape
      pred_labels = pred_labels.squeeze()  # Remove the channel dimension if it's 1
      test_labels = test_labels.squeeze()  # Same for labels

      test_loss = loss_fn(pred_labels, test_labels)
      print(f"Loss during training is {loss}. Loss turing testing is {test_loss} ")


Device = cuda
Loss during training is 0.750813364982605. Loss turing testing is 0.6852824091911316 
