<a href="https://colab.research.google.com/github/dudeurv/SAM_MRI/blob/main/U_NET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical Image Segmentation with U-Net Tutorial

In this tutorial, you will develop and train a convolutional neural network for brain tumour image segmentation.

In [2]:
# Import libraries
import tarfile
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt
from matplotlib import colors

## Download the imaging dataset

The dataset is curated from the brain imaging dataset of the [Medical Segmentation Decathlon](http://medicaldecathlon.com/) challenge. To save storage and reduce the computational cost of this tutorial, we extract 2D image slices from the T1-Gd contrast-enhanced 3D brain volumes and downsample them.

The dataset consists of a training set and a test set. Each image is 120 x 120 pixels, with a corresponding label map of the same size. The label map contains four classes:

- 0: background
- 1: edema
- 2: non-enhancing tumour
- 3: enhancing tumour

In [None]:
# Download the dataset
!wget https://www.dropbox.com/s/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz

# Extract the '.tar.gz' archive into the current directory
with tarfile.open('Task01_BrainTumour_2D.tar.gz') as datafile:
    datafile.extractall()

## Implement a dataset class

#### 1. Implement `normalise_intensity` Function

- **Goal**: Normalize the intensity of an image, focusing on the Region of Interest (ROI).
  
  **Inputs**:
  1. `image`: A NumPy array (2D) of the image.
  2. `thres_roi`: A float percentile in the range 0 to 100 used as the ROI threshold (default 1.0, i.e. the 1st percentile).

  **Process**:
  1. Calculate the intensity threshold with `np.percentile(image, thres_roi)`. Pixels below this value are treated as background and excluded from the statistics.
  2. Create a boolean ROI mask that is True where a pixel is greater than or equal to the threshold.
  3. Compute the mean and standard deviation of the pixel values within the ROI.
  4. Normalize the whole image: subtract the ROI mean and divide by the ROI standard deviation. This puts every image on a comparable intensity scale before it is fed to the network.

  **Outputs**:
  - Normalized image: A NumPy array with adjusted intensity values.

  **Documentation**:
  - NumPy Percentile: [https://numpy.org/doc/stable/reference/generated/numpy.percentile.html](https://numpy.org/doc/stable/reference/generated/numpy.percentile.html)
  - NumPy Mean: [https://numpy.org/doc/stable/reference/generated/numpy.mean.html](https://numpy.org/doc/stable/reference/generated/numpy.mean.html)
  - NumPy Standard Deviation: [https://numpy.org/doc/stable/reference/generated/numpy.std.html](https://numpy.org/doc/stable/reference/generated/numpy.std.html)
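
The four steps above can be sketched as follows. This is a minimal illustration rather than the reference solution; the small constant `eps` is an assumption added here to avoid dividing by zero when every ROI pixel has the same value.

```python
import numpy as np


def normalise_intensity(image, thres_roi=1.0):
    """Normalise an image by the mean and std of its region of interest."""
    # Intensity value at the given percentile (thres_roi is in [0, 100])
    threshold = np.percentile(image, thres_roi)
    # Boolean mask of the pixels treated as foreground
    roi = image >= threshold
    # Statistics are computed over the ROI only
    mu = np.mean(image[roi])
    sigma = np.std(image[roi])
    # Assumption of this sketch: eps guards against a zero standard deviation
    eps = 1e-6
    return (image - mu) / (sigma + eps)


# Usage on a synthetic 120 x 120 image
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(120, 120)).astype(np.float32)
normalised = normalise_intensity(image)
print(normalised.shape)
```

After normalisation, the pixels inside the ROI have a mean close to 0 and a standard deviation close to 1.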

#### 2. Define `BrainImageSet` Class

- **Goal**: Manage and preprocess a dataset of brain images for neural network models.

  **Constructor `__init__`**:
  - **Inputs**:
    1. `image_path`: Path to image directory.
    2. `label_path`: Path to label directory (optional, default empty).
    3. `deploy`: Boolean for deployment mode (optional, default False).
  - **Process**:
    1. Initialize instance variables for paths, deployment status, and image/label lists.
    2. Load and sort image filenames to ensure a consistent order for data processing.
    3. Read images and optionally labels using `imageio.imread`, converting files into processable NumPy arrays.
    4. Skip loading labels if `deploy` is True, as labels are not used in deployment mode.

  **Method `__len__`**:
  - **Outputs**: Total number of images in the dataset (Integer).

  **Method `__getitem__`**:
  - **Inputs**: `idx` (Integer index for an image/label pair).
  - **Outputs**: Tuple of a normalized image and its label (both NumPy arrays).
  - **Process**: Fetch and preprocess an image for the given index.

  **Method `get_random_batch`**:
  - **Inputs**: `batch_size` (Integer for the number of pairs to include).
  - **Outputs**: Tuple of a batch of images and labels (both NumPy arrays).
  - **Process**:
    1. Select `batch_size` random indices (for example with `random.sample`) so that each training step sees a different mini-batch.
    2. Retrieve the normalized image and the label for each index.
    3. Convert the lists of images and labels to NumPy arrays.
    4. Use `np.expand_dims` to add a channel dimension to the images at axis 1, giving shape `(N, 1, H, W)`. Convolutional layers in PyTorch expect a channel dimension even for grayscale images. The labels keep shape `(N, H, W)`.

  **Documentation**:
  - os.listdir: [https://docs.python.org/3/library/os.html#os.listdir](https://docs.python.org/3/library/os.html#os.listdir)
  - os.path.join: [https://docs.python.org/3/library/os.path.html#os.path.join](https://docs.python.org/3/library/os.path.html#os.path.join)
  - imageio.imread: [https://imageio.readthedocs.io/en/stable/userapi.html#imageio.imread](https://imageio.readthedocs.io/en/stable/userapi.html#imageio.imread)
  - random.sample: [https://docs.python.org/3/library/random.html#random.sample](https://docs.python.org/3/library/random.html#random.sample)
  - numpy.expand_dims: [https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html](https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html)
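
One possible shape for the class is sketched below. It assumes that each label file has the same filename as its image, and that `__getitem__` and `get_random_batch` are only called when labels were loaded (`deploy=False`); neither is stated above. `normalise_intensity` is repeated in compact form so that the block runs on its own.

```python
import os
import random

import imageio
import numpy as np
from torch.utils.data import Dataset


def normalise_intensity(image, thres_roi=1.0):
    """ROI-based normalisation from step 1 (the 1e-6 term is an assumption)."""
    roi = image >= np.percentile(image, thres_roi)
    mu, sigma = np.mean(image[roi]), np.std(image[roi])
    return (image - mu) / (sigma + 1e-6)


class BrainImageSet(Dataset):
    def __init__(self, image_path, label_path='', deploy=False):
        self.image_path = image_path
        self.label_path = label_path
        self.deploy = deploy
        self.images = []
        self.labels = []

        # Sorted filenames give a reproducible ordering
        for name in sorted(os.listdir(image_path)):
            self.images.append(imageio.imread(os.path.join(image_path, name)))
            # Assumption: the label file shares the image's filename
            if not deploy:
                self.labels.append(imageio.imread(os.path.join(label_path, name)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Normalise on access; labels are returned unchanged
        image = normalise_intensity(self.images[idx])
        label = self.labels[idx]
        return image, label

    def get_random_batch(self, batch_size):
        # Draw distinct random indices for this batch
        indices = random.sample(range(len(self)), batch_size)
        images, labels = [], []
        for idx in indices:
            image, label = self[idx]
            images.append(image)
            labels.append(label)
        images = np.expand_dims(np.array(images), axis=1)  # (N, 1, H, W)
        labels = np.array(labels)                          # (N, H, W)
        return images, labels
```

Because `random.sample` draws without replacement, `batch_size` must not exceed the number of images in the set.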

## Build a U-Net architecture
#### 1. Set Up the Basic Class Structure
- Define a Python class named `UNet` that inherits from `nn.Module`.
- In the `__init__` method, set up the basic structure:
  - Call `super().__init__()` first so that PyTorch can register the layers.
  - Accept the constructor arguments `input_channel`, `output_channel` and `num_filter`.

#### 2. Create Your First Encoder Block
- Start with creating your first encoder block (`self.conv1`):
  - Use `nn.Sequential` to chain layers.
  - Add two `nn.Conv2d` layers, each followed by `nn.BatchNorm2d` and `nn.ReLU`. The first Conv2d layer should change the channel size from `input_channel` to `num_filter`, and the second Conv2d layer keeps the channel size at `num_filter`.
  - For both Conv2d layers, use `kernel_size=3` and `padding=1`.
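
As a minimal sketch, the first encoder block could look like this when built outside the class. The values `input_channel = 1` and `num_filter = 16` are example choices, not values given in the text.

```python
import torch
import torch.nn as nn

input_channel, num_filter = 1, 16  # example values (assumptions)

conv1 = nn.Sequential(
    nn.Conv2d(input_channel, num_filter, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_filter),
    nn.ReLU(),
    nn.Conv2d(num_filter, num_filter, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_filter),
    nn.ReLU(),
)

# A batch of two single-channel 120 x 120 images
x = torch.randn(2, input_channel, 120, 120)
y = conv1(x)
print(y.shape)  # torch.Size([2, 16, 120, 120])
```

With `kernel_size=3` and `padding=1` the spatial size is preserved, so only the channel count changes.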

#### 3. Add Subsequent Encoder Blocks
- Create the next encoder blocks (`self.conv2`, `self.conv3`, `self.conv4`):
  - For each block, double the number of filters relative to the previous block (`num_filter * 2`, `num_filter * 4`, `num_filter * 8`).
  - Add two Conv2d layers in each block, similar to `self.conv1`, but with the updated number of filters. For the first Conv2d layer in each block, use `stride=2` to halve the spatial resolution.
  - Remember to include BatchNorm and ReLU layers after each Conv2d layer.
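
The effect of the strided blocks on tensor shapes can be checked with a short sketch. `conv_block` is a hypothetical helper introduced here to avoid repeating the layer list, and `num_filter = 16` is an example value.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch, stride=1):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.

    Only the first convolution uses the given stride.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


n = 16  # example value for num_filter (assumption)
conv2 = conv_block(n, 2 * n, stride=2)
conv3 = conv_block(2 * n, 4 * n, stride=2)
conv4 = conv_block(4 * n, 8 * n, stride=2)

# Pretend this is the output of self.conv1 for two 120 x 120 images
x = torch.randn(2, n, 120, 120)
shapes = []
for block in (conv2, conv3, conv4):
    x = block(x)
    shapes.append(tuple(x.shape))
print(shapes)  # [(2, 32, 60, 60), (2, 64, 30, 30), (2, 128, 15, 15)]
```

Each block doubles the channels and halves the height and width, so a 120 x 120 input reaches 15 x 15 at the bottom of the encoder.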

#### 4. Construct Decoder Blocks
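- Build three decoder stages, each made of an upsampling layer and a convolution block (`self.up3` and `self.conv_up3`, `self.up2` and `self.conv_up2`, `self.up1` and `self.conv_up1`):
  - Start with a `nn.ConvTranspose2d` for upsampling (the reverse of the strided downsampling). Its output should have half as many channels as its input.
  - Follow it with a `nn.Sequential` containing two Conv2d layers, each followed by BatchNorm and ReLU as in the encoder. Because the upsampled features are concatenated with the matching encoder output in the forward pass, the first Conv2d receives twice as many input channels as the upsampling layer produces. The number of filters halves from one decoder stage to the next.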
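
A single decoder stage might be sketched as below. The choice of `kernel_size=2, stride=2` for `nn.ConvTranspose2d` is an assumption (it exactly doubles the spatial size); the channel counts assume `num_filter = 16`, and the two input tensors are random stand-ins for encoder outputs.

```python
import torch
import torch.nn as nn

n = 128  # channels at the bottom of the encoder, assuming num_filter = 16

# Upsampling layer: halves the channels and doubles height and width
up3 = nn.ConvTranspose2d(n, n // 2, kernel_size=2, stride=2)

# Convolution block: takes the concatenated tensor (n // 2 + n // 2 = n channels)
conv_up3 = nn.Sequential(
    nn.Conv2d(n, n // 2, kernel_size=3, padding=1),
    nn.BatchNorm2d(n // 2),
    nn.ReLU(),
    nn.Conv2d(n // 2, n // 2, kernel_size=3, padding=1),
    nn.BatchNorm2d(n // 2),
    nn.ReLU(),
)

deep = torch.randn(2, 128, 15, 15)  # stand-in for the conv4 output
skip = torch.randn(2, 64, 30, 30)   # stand-in for the conv3 output

up = up3(deep)                          # (2, 64, 30, 30)
merged = torch.cat([up, skip], dim=1)   # (2, 128, 30, 30)
out = conv_up3(merged)                  # (2, 64, 30, 30)
print(up.shape, merged.shape, out.shape)
```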

#### 5. Define the Output Convolution
- Create the final output layer (`self.out`) using `nn.Conv2d` with `output_channel` filters and `kernel_size=1`. This layer maps the feature vector at each pixel to one score (logit) per output class.
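
To illustrate what the 1 x 1 convolution does, the sketch below applies it to random features and converts the logits to a label map. The sizes (16 feature channels, 4 classes) are example values.

```python
import torch
import torch.nn as nn

num_filter, output_channel = 16, 4  # example values (assumptions)

out_layer = nn.Conv2d(num_filter, output_channel, kernel_size=1)

features = torch.randn(2, num_filter, 120, 120)  # stand-in for decoder output
logits = out_layer(features)                     # (2, 4, 120, 120)

# The predicted class at each pixel is the channel with the highest logit
prediction = logits.argmax(dim=1)                # (2, 120, 120)
print(logits.shape, prediction.shape)
```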

#### 6. Define the Forward Pass

1. **Encoder Forward Pass**:
   - Sequentially pass input `x` through each encoder block.
   - Save the output of each block for skip connections.

2. **Decoder Forward Pass**:
   - Upsample the output of the last encoder block.
   - For each decoder stage:
     - Concatenate the upsampled output with the corresponding encoder output along the channel dimension using `torch.cat(..., dim=1)`. The two tensors must have the same batch size and spatial size.
     - Pass the concatenated tensor through the stage's convolution block, then through the next upsampling layer if there is one.

3. **Final Output**:
   - Pass the output of the last decoder block through the output layer to obtain the final result.
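
Putting steps 1 to 6 together, a complete model could be sketched as follows. This is one way to satisfy the description, not the reference solution: `conv_block` is a hypothetical helper, the default argument values are assumptions, and so is `kernel_size=2, stride=2` for the transposed convolutions.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch, stride=1):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


class UNet(nn.Module):
    def __init__(self, input_channel=1, output_channel=1, num_filter=16):
        super().__init__()
        n = num_filter
        # Encoder: channels double, resolution halves from conv2 onwards
        self.conv1 = conv_block(input_channel, n)
        self.conv2 = conv_block(n, 2 * n, stride=2)
        self.conv3 = conv_block(2 * n, 4 * n, stride=2)
        self.conv4 = conv_block(4 * n, 8 * n, stride=2)
        # Decoder: each conv_up block sees upsampled + skip channels
        self.up3 = nn.ConvTranspose2d(8 * n, 4 * n, kernel_size=2, stride=2)
        self.conv_up3 = conv_block(8 * n, 4 * n)
        self.up2 = nn.ConvTranspose2d(4 * n, 2 * n, kernel_size=2, stride=2)
        self.conv_up2 = conv_block(4 * n, 2 * n)
        self.up1 = nn.ConvTranspose2d(2 * n, n, kernel_size=2, stride=2)
        self.conv_up1 = conv_block(2 * n, n)
        # Per-pixel classifier
        self.out = nn.Conv2d(n, output_channel, kernel_size=1)

    def forward(self, x):
        # Encoder pass, keeping each output for the skip connections
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        # Decoder pass: upsample, concatenate with the skip, convolve
        y = self.up3(x4)
        y = self.conv_up3(torch.cat([y, x3], dim=1))
        y = self.up2(y)
        y = self.conv_up2(torch.cat([y, x2], dim=1))
        y = self.up1(y)
        y = self.conv_up1(torch.cat([y, x1], dim=1))
        return self.out(y)


model = UNet(input_channel=1, output_channel=4, num_filter=16)
logits = model(torch.randn(2, 1, 120, 120))
print(logits.shape)  # torch.Size([2, 4, 120, 120])
```

In this sketch the skip connections only line up when the input height and width are divisible by 8, which holds for the 120 x 120 images used here.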

#### Documentation

- nn.Module: [https://pytorch.org/docs/stable/nn.html#module](https://pytorch.org/docs/stable/nn.html#module)
- Conv2d: [https://pytorch.org/docs/stable/nn.html#conv2d](https://pytorch.org/docs/stable/nn.html#conv2d)
- BatchNorm2d: [https://pytorch.org/docs/stable/nn.html#batchnorm2d](https://pytorch.org/docs/stable/nn.html#batchnorm2d)
- ReLU: [https://pytorch.org/docs/stable/nn.html#relu](https://pytorch.org/docs/stable/nn.html#relu)
- ConvTranspose2d: [https://pytorch.org/docs/stable/nn.html#convtranspose2d](https://pytorch.org/docs/stable/nn.html#convtranspose2d)
- torch.cat: [https://pytorch.org/docs/stable/generated/torch.cat.html](https://pytorch.org/docs/stable/generated/torch.cat.html)

## Train the segmentation model
#### 1. Build the Model
- **Goal**: Instantiate the U-Net model.
- **Action**:
  - Define `num_class`, the number of output classes for the segmentation task.
  - Create an instance of the `UNet` class with specified input and output channels, and the number of filters.
  - Transfer the model to the chosen device with `model.to(device)`, where `device` is a `torch.device` such as the GPU if one is available and the CPU otherwise.
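
A minimal sketch of this step is shown below. To keep the block self-contained it uses `TinySegNet`, a hypothetical stand-in that takes the same constructor arguments as `UNet`; in the notebook you would instantiate your own `UNet` instead. The values `input_channel=1` and `num_filter=16` are example choices.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for UNet with the same constructor arguments."""

    def __init__(self, input_channel=1, output_channel=1, num_filter=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_channel, num_filter, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(num_filter, output_channel, kernel_size=1),
        )

    def forward(self, x):
        return self.net(x)


# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_class = 4  # background, edema, non-enhancing tumour, enhancing tumour
model = TinySegNet(input_channel=1, output_channel=num_class, num_filter=16)
model = model.to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f'Device: {device}, parameters: {num_params}')
```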

#### 2. Prepare for Saving Models
- **Goal**: Set up a directory to save trained model parameters.
- **Action**: Check if a directory (e.g., `saved_models`) exists; if not, create it.
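
For example, using the directory name suggested above, `os.makedirs` with `exist_ok=True` covers both the check and the creation in one call:

```python
import os

model_dir = 'saved_models'
# Creates the directory if it is missing; does nothing if it already exists
os.makedirs(model_dir, exist_ok=True)
print(os.path.isdir(model_dir))  # True
```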

#### 3. Define the Optimizer
- **Goal**: Set up the optimizer for training.
- **Action**: Use `optim.Adam(model.parameters(), lr=1e-3)`, passing the model's parameters and a learning rate.
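
As a small illustration, here `params` is simply the list of the model's trainable tensors. The single convolution is a stand-in for the U-Net so that the block runs on its own.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in for the U-Net (assumption): any nn.Module works the same way
model = nn.Conv2d(1, 4, kernel_size=3, padding=1)

params = list(model.parameters())  # weight and bias tensors
optimizer = optim.Adam(params, lr=1e-3)

print(len(params), optimizer.param_groups[0]['lr'])  # 2 0.001
```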

#### 4. Set Up the Loss Function
- **Goal**: Define the loss function for the segmentation task.
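- **Action**: Use `nn.CrossEntropyLoss()` as the criterion. It takes the raw logits of shape `(N, C, H, W)` and integer class labels of shape `(N, H, W)`, and applies the softmax internally.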
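
The expected tensor shapes and dtypes can be checked with random data, as in this sketch. The all-zero logits case is included as a sanity check: a uniform prediction over 4 classes should give a loss of ln(4).

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

num_class = 4
logits = torch.randn(2, num_class, 120, 120)          # raw scores, float32
labels = torch.randint(0, num_class, (2, 120, 120))   # class indices, int64

loss = criterion(logits, labels)  # scalar, averaged over all pixels
print(loss.shape, labels.dtype)

# Sanity check: equal logits for every class give a loss of ln(num_class)
uniform_loss = criterion(torch.zeros(2, num_class, 120, 120), labels)
print(abs(uniform_loss.item() - math.log(num_class)) < 1e-5)  # True
```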

#### 5. Load Datasets
- **Goal**: Prepare training and test datasets.
- **Action**: Instantiate `BrainImageSet` for both training and test sets with appropriate image and label paths.

#### 6. Train the Model
- **Goal**: Implement the training loop.
- **Actions**:
  - Iterate over a specified number of iterations (`num_iter`).
  - In each iteration:
    - Set the model to training mode (`model.train()`).
    - Fetch a batch of training data, convert the NumPy arrays to tensors (float32 images, int64 labels), and transfer them to the device.
    - Perform a forward pass through the model (`logits = model(images)`).
    - Clear previous gradients (`optimizer.zero_grad()`).
    - Compute loss using `criterion`.
    - Perform backpropagation (`loss.backward()`).
    - Update model parameters (`optimizer.step()`).
    - Print training loss.
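
The loop described above could be sketched as follows. To stay self-contained, the block uses a single convolution as a stand-in for the U-Net and a hypothetical `get_random_batch` function that returns synthetic NumPy arrays in place of the training set's method; `num_iter = 5` and the batch size of 8 are example values.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Conv2d(1, 4, kernel_size=3, padding=1).to(device)  # stand-in for UNet
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


def get_random_batch(batch_size):
    """Hypothetical stand-in for the training set: synthetic images and labels."""
    images = np.random.randn(batch_size, 1, 120, 120)
    labels = np.random.randint(0, 4, size=(batch_size, 120, 120))
    return images, labels


num_iter = 5   # example value
losses = []
for it in range(1, 1 + num_iter):
    model.train()

    # Fetch a batch and convert it to tensors on the chosen device
    images, labels = get_random_batch(8)
    images = torch.from_numpy(images).to(device, dtype=torch.float32)
    labels = torch.from_numpy(labels).to(device, dtype=torch.long)

    logits = model(images)            # forward pass
    optimizer.zero_grad()             # clear gradients from the previous step
    loss = criterion(logits, labels)  # compute the loss
    loss.backward()                   # backpropagation
    optimizer.step()                  # parameter update

    losses.append(loss.item())
    print(f'--- Iteration {it}: training loss = {loss.item():.4f} ---')
```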

#### 7. Evaluate the Model
- **Goal**: Periodically evaluate the model on the test set.
- **Actions**:
  - Every few iterations (e.g., `it % 10 == 0`):
    - Set the model to evaluation mode (`model.eval()`).
    - Disable gradient calculations (`with torch.no_grad():`).
    - Fetch a batch of test data and transfer it to the device.
    - Perform a forward pass and compute the test loss.
    - Print the test loss.
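
A sketch of the periodic evaluation is given below, with the training step left out. As before, the single convolution stands in for the U-Net and `get_test_batch` is a hypothetical function returning synthetic data in place of the test set.

```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Conv2d(1, 4, kernel_size=3, padding=1).to(device)  # stand-in for UNet
criterion = nn.CrossEntropyLoss()


def get_test_batch(batch_size):
    """Hypothetical stand-in for the test set: synthetic images and labels."""
    images = np.random.randn(batch_size, 1, 120, 120)
    labels = np.random.randint(0, 4, size=(batch_size, 120, 120))
    return images, labels


test_losses = {}
for it in range(1, 21):
    # ... one training step would go here ...

    if it % 10 == 0:
        model.eval()               # BatchNorm layers switch to running statistics
        with torch.no_grad():      # no gradients are tracked during evaluation
            images, labels = get_test_batch(8)
            images = torch.from_numpy(images).to(device, dtype=torch.float32)
            labels = torch.from_numpy(labels).to(device, dtype=torch.long)
            logits = model(images)
            test_loss = criterion(logits, labels)
        test_losses[it] = test_loss.item()
        print(f'--- Iteration {it}: test loss = {test_loss.item():.4f} ---')
```

Remember that the training loop calls `model.train()` at the start of each iteration, which switches the model back out of evaluation mode.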

#### 8. Save the Model
- **Goal**: Save the model's state at certain intervals.
- **Action**: Every few iterations (e.g., `it % 5000 == 0`), save the model's state dictionary.
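
Saving and reloading a state dictionary might look like the sketch below. The filename pattern `model_{it}.pt` is a hypothetical choice, and the single convolution again stands in for the U-Net.

```python
import os

import torch
import torch.nn as nn

model_dir = 'saved_models'
os.makedirs(model_dir, exist_ok=True)

model = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # stand-in for UNet

it = 5000  # pretend the training loop has reached this iteration
if it % 5000 == 0:
    # Hypothetical filename pattern; only the parameters are stored
    path = os.path.join(model_dir, f'model_{it}.pt')
    torch.save(model.state_dict(), path)

# To reuse the weights later, build the same architecture and load them
restored = nn.Conv2d(1, 4, kernel_size=3, padding=1)
restored.load_state_dict(torch.load(path, map_location='cpu'))
print(torch.equal(model.weight, restored.weight))  # True
```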

#### Documentation
- Model Saving and Loading: [PyTorch Saving & Loading](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- Optimizers: [PyTorch Optim](https://pytorch.org/docs/stable/optim.html)
- Loss Functions: [PyTorch Losses](https://pytorch.org/docs/stable/nn.html#loss-functions)