<a href="https://colab.research.google.com/github/AndyCatruna/DSM/blob/main/Lab_05a_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Semantic Segmentation

Semantic Segmentation refers to the task of assigning a specific class to each pixel in an image.

<img src="https://i0.wp.com/cdn-images-1.medium.com/max/850/1*f6Uhb8MI4REGcYKkYM9OUQ.png?w=850&resize=850,662&ssl=1" width=600>

In [None]:
!pip install -q torchmetrics

In [None]:
import sys

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import VOCSegmentation
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torchmetrics.segmentation import MeanIoU

We will use the [Pascal VOC Dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) for this lab.

The dataset has 20 classes (21 if including background).

We will keep only 3 classes (Background, Person, and Dog) to make training easier. Pixels of every other class are relabeled as background, and samples whose mask then contains only background are discarded.

The following code performs this filtering. You can ignore it.

In [None]:
''' Code for filtering the dataset '''

# Original dataset label id -> label id used in this lab
CLASS_MAPPING = {
    0: 0,    # background is kept as background
    15: 1,   # person
    12: 2,   # dog
}
num_classes = len(CLASS_MAPPING)


def remap_mask(mask):
    """Relabel a segmentation mask in place according to CLASS_MAPPING.

    Every value that is not a key of CLASS_MAPPING is first set to 0; each
    remaining value is then replaced by its mapped id. The tensor that was
    passed in is modified and returned (no copy is made).
    """
    kept_ids = torch.tensor(list(CLASS_MAPPING))
    is_kept = torch.isin(mask, kept_ids)
    mask[~is_kept] = 0

    for source_id, target_id in CLASS_MAPPING.items():
        mask[mask == source_id] = target_id
    return mask

class FilteredVOCSegmentation(Dataset):
    """VOCSegmentation wrapper that only exposes samples with a non-empty mask.

    A sample is kept when the unique values of its (already transformed) mask
    are anything other than exactly ``[0]``. For the ``'train'`` image set the
    per-class weights are additionally computed, stored in ``class_weights``
    and printed.
    """

    def __init__(self, root='./data', image_set='train', download=True, transform=None, target_transform=None):
        super().__init__()
        voc_kwargs = dict(
            root=root,
            image_set=image_set,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )
        self.dataset = VOCSegmentation(**voc_kwargs)
        self.filtered_dataset = self.filter_dataset()

        # Class weights are only needed (and only computed) for the training split
        if image_set != 'train':
            return
        self.class_weights = self.calculate_class_weights()
        print(f'Class weights: {self.class_weights}')

    def calculate_class_weights(self):
        """Return, per class, (all counted pixels) / (pixels of that class)."""
        n_classes = len(CLASS_MAPPING)
        pixels_per_class = torch.zeros(n_classes)
        for _, mask in self.filtered_dataset:
            for label in range(n_classes):
                pixels_per_class[label] += (mask == label).sum()

        return pixels_per_class.sum() / pixels_per_class

    def filter_dataset(self):
        """Return a Subset of the wrapped dataset without the all-zero masks."""
        kept_indices = [
            position
            for position, (_, mask) in enumerate(self.dataset)
            if torch.unique(mask).tolist() != [0]
        ]
        return torch.utils.data.Subset(self.dataset, kept_indices)

    def __len__(self):
        """Number of samples that survived the filtering."""
        return len(self.filtered_dataset)

    def __getitem__(self, idx):
        """Return the idx-th kept sample as produced by the wrapped dataset."""
        return self.filtered_dataset[idx]

In [None]:
''' Code for obtaining the dataloaders '''

# Training transforms - You can modify these to obtain better results, but any
# geometric change (resize, flip, crop, ...) must be applied to the masks as well
# so that image and label stay aligned.
train_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Test transforms - currently identical to the training transforms
test_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Mask transforms - the masks are resized to the same 128x128 size as the images.
# NEAREST interpolation is used so that resizing only copies existing label ids
# instead of blending neighbouring ones; remap_mask then relabels the classes.
mask_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    remap_mask
])

# Building the 'train' split also computes and prints the class weights
# (see FilteredVOCSegmentation.__init__).
train_dataset = FilteredVOCSegmentation(
    root='./data',
    image_set='train',
    download=True,
    transform=train_transform,
    target_transform=mask_transform
)

# Training batches are shuffled
trainloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4
)

# The 'val' image set is used as the test set in this lab
test_dataset = FilteredVOCSegmentation(
    root='./data',
    image_set='val',
    download=True,
    transform=test_transform,
    target_transform=mask_transform
)

# Evaluation batches keep the dataset order
testloader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4
)

In [None]:
''' Code for visualizing images, ground truth masks, and segmentation predictions '''

def visualize_images_and_masks(images, masks, predicted_masks=None, num_samples=5):
  """Plot images next to their ground-truth masks and, optionally, predictions.

  One row is drawn per sample: image, mask and (when given) predicted mask.

  Args:
    images: batch of normalized images; each item is shown after
      ``permute(1, 2, 0)``, so items are expected to be channel-first.
    masks: batch of ground-truth masks; a leading singleton dimension of each
      item is squeezed away before plotting.
    predicted_masks: optional batch of predicted masks, plotted like ``masks``.
    num_samples: maximum number of rows; clamped to the number of images in
      the batch. Nothing is drawn when this ends up being zero or negative.
  """
  images = images.cpu()
  masks = masks.cpu()

  num_cols = 2
  if predicted_masks is not None:
    predicted_masks = predicted_masks.cpu()
    num_cols = 3

  # Never ask for more rows than the batch holds (avoids an IndexError below)
  num_samples = min(num_samples, images.shape[0])
  if num_samples <= 0:
    return

  # De-normalize image (same mean/std values as the input transforms)
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
  images = images * std + mean
  images = torch.clamp(images, 0, 1)

  # squeeze=False keeps `axes` two-dimensional even for a single row, so the
  # per-row unpacking below also works when num_samples == 1
  fig, axes = plt.subplots(num_samples, num_cols, figsize=(6, 3 * num_samples), squeeze=False)
  for i in range(num_samples):
    if num_cols == 3:
      ax_img, ax_mask, ax_pred_mask = axes[i]
    else:
      ax_img, ax_mask = axes[i]

    # Plot the image
    ax_img.imshow(images[i].permute(1, 2, 0))
    ax_img.set_title("Image")
    ax_img.axis("off")

    # Plot the mask
    ax_mask.imshow(masks[i].squeeze(0), cmap="Accent", vmin=0, vmax=num_classes)
    ax_mask.set_title("Mask")
    ax_mask.axis("off")

    # Plot the predicted masks
    if predicted_masks is not None:
      ax_pred_mask.imshow(predicted_masks[i].squeeze(0), cmap="Accent", vmin=0, vmax=num_classes)
      ax_pred_mask.set_title("Predicted Mask")
      ax_pred_mask.axis("off")

  plt.tight_layout()
  plt.show()

In [None]:
# Visualize 5 images and masks from the trainloader.
# next(iter(...)) pulls a single batch; trainloader was built with shuffle=True,
# so re-running this cell shows a different batch.
images, masks = next(iter(trainloader))
visualize_images_and_masks(images, masks, num_samples=5)

#### U-Net

<img src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png" width=600>

We will build a segmentation model from scratch.

We will reproduce U-Net with one small difference: at each stage we use half as many feature maps as the original model (32 instead of 64, 64 instead of 128, etc.) so that the model has fewer parameters.

Below is the code for the core modules of the U-Net.

In [None]:
class ConvBlock(nn.Module):
  ''' 3x3 convolution (padding 1, no bias) -> BatchNorm -> ReLU - blue arrow in image '''
  def __init__(self, in_channels, out_channels):
    super().__init__()
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    ''' Apply the conv / norm / activation stack to x '''
    out = self.conv(x)
    return out

class DoubleConvBlock(nn.Module):
  ''' Two ConvBlocks applied one after the other - double blue arrow in image '''
  def __init__(self, in_channels, out_channels, intermediary_channels=None):
    super().__init__()
    # A falsy value (None or 0) means: use out_channels between the two blocks
    mid_channels = intermediary_channels or out_channels

    first = ConvBlock(in_channels, mid_channels)
    second = ConvBlock(mid_channels, out_channels)
    self.double_conv = nn.Sequential(first, second)

  def forward(self, x):
    ''' Run x through both convolution blocks '''
    out = self.double_conv(x)
    return out

class Downscale(nn.Module):
  ''' 2x2 max pooling (stride 2) followed by a double conv block - red arrow followed by double blue arrow in image '''
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.maxpool = nn.MaxPool2d(2, stride=2)
    self.conv = DoubleConvBlock(in_channels, out_channels)

  def forward(self, x):
    ''' Pool first, then convolve the pooled feature map '''
    pooled = self.maxpool(x)
    out = self.conv(pooled)
    return out

class Upscale(nn.Module):
  '''Upscale with Transpose convolution then apply conv block - green arrow followed by double blue arrow in image

  The transpose convolution (kernel 2, stride 2) maps in_channels to
  in_channels // 2 channels. Its output is concatenated with the skip
  connection along the channel dimension and passed to a DoubleConvBlock
  that takes in_channels input channels.
  '''

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
    self.conv = DoubleConvBlock(in_channels, out_channels)

  def forward(self, x1, x2):
    ''' Input is: 1. output of previous module and 2. output of corresponding Downscale block - gray arrow in the image

    x1 is upsampled, padded to the spatial size of x2 when the two differ,
    concatenated with x2 on dim 1 and sent through the double conv block.
    '''
    x1 = self.up(x1)

    # The upsampled map and the skip map can differ in height/width, e.g. when
    # an odd-sized map was pooled on the way down. Pad so that torch.cat works;
    # nothing happens when the sizes already match.
    pad_h = x2.size(2) - x1.size(2)
    pad_w = x2.size(3) - x1.size(3)
    if pad_h or pad_w:
      x1 = nn.functional.pad(
          x1,
          [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
      )

    x = torch.cat([x2, x1], dim=1)

    return self.conv(x)

class OutputModule(nn.Module):
  '''Final 1x1 convolution producing one output channel per class - light blue arrow in image '''
  def __init__(self, in_channels, out_channels):
    super().__init__()
    # out_channels is the number of classes
    self.conv = nn.Conv2d(in_channels, out_channels, 1)

  def forward(self, x):
    ''' Return the output of the 1x1 convolution (no activation is applied) '''
    logits = self.conv(x)
    return logits

The UNet has 4 Downscale modules and 4 Upscale modules.

The Upscale operations also take as input the features obtained by the corresponding Downscale modules (the gray arrows in the image).

In [None]:
class UNet(nn.Module):
  ''' U-Net with an initial double conv, 4 Downscale stages and 4 Upscale stages (32-64-128-256-512 channels) '''
  def __init__(self, n_channels, n_classes):
    super().__init__()

    widths = (32, 64, 128, 256, 512)

    # Initial convolutions
    self.initial = DoubleConvBlock(n_channels, widths[0])

    # Downsampling part: down1 ... down4, doubling the channels each time
    for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
      setattr(self, f'down{stage}', Downscale(c_in, c_out))

    # Upsampling part: up1 ... up4, halving the channels each time
    reversed_widths = widths[::-1]
    for stage, (c_in, c_out) in enumerate(zip(reversed_widths, reversed_widths[1:]), start=1):
      setattr(self, f'up{stage}', Upscale(c_in, c_out))

    self.out = OutputModule(widths[0], n_classes)

  def forward(self, x):
    # Encoder: keep every intermediate output, the decoder needs them as skips
    features = [self.initial(x)]
    for down in (self.down1, self.down2, self.down3, self.down4):
      features.append(down(features[-1]))

    # Decoder: start from the deepest feature map and consume the skips from
    # the deepest to the shallowest
    x = features.pop()
    for up in (self.up1, self.up2, self.up3, self.up4):
      x = up(x, features.pop())

    return self.out(x)

In [None]:
# 3 input channels for the images; one output channel per class kept in CLASS_MAPPING
unet = UNet(n_channels=3, n_classes=num_classes)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Move the model to the selected device
unet.to(device)

In [None]:
def count_parameters(model):
  """Return the total number of elements in the model's parameters that require gradients."""
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total

# Report the size of the U-Net (the `:,` format adds thousands separators)
print(f"The model has {count_parameters(unet):,} trainable parameters")

The code for training is very similar to that for classification. The only difference is that now we apply the cross-entropy loss at the pixel level.

For validation we utilize the Mean Intersection over Union (mIoU) - you can read more about it [here](https://www.jeremyjordan.me/evaluating-image-segmentation-models/)

Intersection over Union is simply computed as:

<img src="https://miro.medium.com/v2/resize:fit:640/format:webp/1*2w493Z_V6-sE_3aYa48a9w.png">



The metric can take values in the range 0-1:

<img src="https://miro.medium.com/v2/resize:fit:1100/format:webp/1*kK0G-BmCqigHrc1rXs7tYQ.jpeg" width=500>

In [None]:
def train_epoch(model, dataloader, device, optimizer, criterion, epoch):
    """Train ``model`` for one pass over ``dataloader`` and return the average loss.

    Args:
        model: network to train; it is switched to training mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to.
        optimizer: optimizer whose gradients are reset and which is stepped
            once per batch.
        criterion: loss, called as ``criterion(pred, labels)``.
        epoch: epoch index, only displayed in the progress bar.

    Returns:
        The sample-weighted mean loss over all batches seen, rounded to two
        decimals, or ``0.0`` when ``dataloader`` yields no batches.
    """
    # We set the model to be in training mode
    model.train()

    total_train_loss = 0.0
    dataset_size = 0
    # Defined up front so that an empty dataloader does not leave it unbound
    epoch_loss = 0.0

    # This is only for showing the progress bar
    bar = tqdm(enumerate(dataloader), total=len(dataloader), colour='cyan', file=sys.stdout)

    # We iterate through all batches - 1 step is 1 batch of batch_size images
    for step, (images, labels) in bar:
        # We take the images and their labels and push them on the device
        images = images.to(device)
        # Drop dim 1 of the labels (presumably a singleton channel dim - the
        # masks come from PILToTensor) and cast to integer class indices
        labels = labels.to(device).squeeze(1)
        labels = labels.long()
        batch_size = images.shape[0]

        # Reset gradients
        optimizer.zero_grad()

        # Obtain predictions
        pred = model(images)
        # Compute loss for this batch
        loss = criterion(pred, labels)

        # Compute gradients for each weight (backpropagation)
        loss.backward()

        # Update weights based on gradients (gradient descent)
        optimizer.step()

        # We keep track of the average training loss; each batch loss is
        # weighted by its size so a smaller last batch does not skew the mean
        total_train_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = np.round(total_train_loss / dataset_size, 2)
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss)

    return epoch_loss

def valid_epoch(model, dataloader, device, criterion, epoch):
    """Evaluate ``model`` on ``dataloader`` and return ``(mean_iou, average_loss)``.

    The whole evaluation runs with gradient tracking disabled, so the function
    can also be called on its own without an outer ``torch.no_grad()``.

    Args:
        model: network to evaluate; it is switched to evaluation mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        device: device the batches and the metric are moved to.
        criterion: loss, called as ``criterion(pred, labels)``.
        epoch: epoch index, only displayed in the progress bar.

    Returns:
        A tuple ``(mean_iou, epoch_loss)``: the value computed by the
        ``MeanIoU`` metric over all batches, and the sample-weighted mean loss
        rounded to two decimals (``0.0`` when ``dataloader`` yields no batches).
    """
    # We set the model in evaluation mode
    model.eval()

    total_val_loss = 0.0
    dataset_size = 0
    # Defined up front so that an empty dataloader does not leave it unbound
    epoch_loss = 0.0

    # Predictions and targets are passed to the metric as class-index tensors
    iou_metric = MeanIoU(num_classes=num_classes, input_format='index').to(device)

    # This is only for showing the progress bar
    bar = tqdm(enumerate(dataloader), total=len(dataloader), colour='cyan', file=sys.stdout)

    # No gradients are needed for evaluation
    with torch.no_grad():
        for step, (images, labels) in bar:
            images = images.to(device)
            labels = labels.to(device).squeeze(1)
            labels = labels.long()
            batch_size = images.shape[0]

            pred = model(images)
            loss = criterion(pred, labels)

            # Predicted class = index of the largest score along dim 1
            _, predicted = torch.max(pred, 1)

            # Accumulate the IoU statistics for this batch
            iou_metric.update(predicted, labels)

            total_val_loss += (loss.item() * batch_size)
            dataset_size += batch_size

            epoch_loss = np.round(total_val_loss / dataset_size, 2)

            bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss)

    mean_iou = iou_metric.compute()
    print(f"Mean IoU: {mean_iou}")
    return mean_iou, epoch_loss

def run_training(model, num_epochs, learning_rate):
    """Train ``model`` on ``trainloader`` and validate on ``testloader`` every epoch.

    A class-weighted cross-entropy loss (weights taken from ``train_dataset``)
    and an Adam optimizer with the given learning rate are created here. After
    each epoch the validation mean IoU is compared with the best value so far
    and a message is printed when it improves.
    """
    # Optimizer and loss are independent of each other
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(weight=train_dataset.class_weights).to(device)

    # Tell the user which GPU is in use, if any
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

    # Best validation mean IoU seen so far
    best_miou = 0.0

    # We train the model for a number of epochs
    for epoch in range(num_epochs):
        train_epoch(model, trainloader, device, optimizer, criterion, epoch)

        # For validation we do not keep track of gradients
        with torch.no_grad():
            mean_iou, _ = valid_epoch(model, testloader, device, criterion, epoch)

        if mean_iou > best_miou:
            print(f"Mean IoU Improved ({best_miou} ---> {mean_iou})")
            best_miou = mean_iou
        print()

In [None]:
# Train the U-Net from scratch: 25 epochs with a learning rate of 1e-4
run_training(unet, num_epochs=25, learning_rate=0.0001)

In [None]:
# Visualize 5 predictions from training set - re-run this cell to look at other predictions
images, masks = next(iter(trainloader))
images = images.to(device)
masks = masks.to(device)

# Inference only: make sure the model is in evaluation mode and do not build
# the autograd graph for the forward pass
unet.eval()
with torch.no_grad():
    predicted_masks = unet(images)
# Predicted class per pixel = index of the largest score along the class dim
predicted_masks = torch.argmax(predicted_masks, dim=1)

visualize_images_and_masks(images, masks, predicted_masks, num_samples=5)

In [None]:
# Visualize 15 predictions from testing set
images, masks = next(iter(testloader))
images = images.to(device)
masks = masks.to(device)

# Inference only: make sure the model is in evaluation mode and do not build
# the autograd graph for the forward pass
unet.eval()
with torch.no_grad():
    predicted_masks = unet(images)
# Predicted class per pixel = index of the largest score along the class dim
predicted_masks = torch.argmax(predicted_masks, dim=1)

visualize_images_and_masks(images, masks, predicted_masks, num_samples=15)

**Exercise** - Fine-tune a pre-trained model for image segmentation on this dataset

Our model could not learn very much because the number of samples is low and we did not use augmentations. Use a pre-trained model to obtain better results.

You can use one from [here](https://github.com/qubvel-org/segmentation_models.pytorch)


In [None]:
!pip install -q segmentation-models-pytorch

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load a pre-trained DeepLabV3+ model with ResNet backbone
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",      # Pretrained ResNet50 backbone
    encoder_weights="imagenet",   # Use ImageNet pre-trained weights
    classes=num_classes,          # Number of output classes in segmentation
    activation=None               # No activation, since we use CrossEntropyLoss
).to(device)

learning_rate = 0.001
# NOTE(review): run_training() creates its own CrossEntropyLoss and Adam
# optimizer from its arguments, so the two objects defined below are not used
# by the call that follows.
criterion = nn.CrossEntropyLoss(weight=train_dataset.class_weights.to(device))
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Fine-tune the pre-trained model with the same training loop as the U-Net
run_training(model, num_epochs=20, learning_rate=learning_rate)