# Project on the CamVid dataset

Implementation of several semantic segmentation networks and a comparison of their performance on the CamVid dataset.

In [None]:
# Mount Google Drive at /content/drive (Colab only); later cells read the
# dataset archive and the converted label folders from there.
from google.colab import drive

drive.mount("/content/drive")

### Imports and deterministic code

In [None]:
import torch
import torchvision
from torch import nn
import zipfile
import os
import shutil
import random
import numpy as np
import torch.nn.functional as F
import PIL
from torch.utils.data import Dataset, DataLoader
import cv2
import matplotlib.pyplot as plt
from torchvision.transforms import v2 as T
from torchsummary import summary
import tqdm
!pip install wandb
import wandb
from copy import deepcopy
from pathlib import Path


In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device}")


def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, current CUDA device and
    all CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN in
    deterministic mode with autotuning disabled.

    Args:
        seed: value handed to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

## Data

### Loading dataset and obtaining new classes

In [None]:
src_zip = "/content/drive/MyDrive/CamVid/CamVid.zip"
dst_zip = "/content/CamVid"

# Unpack the archive stored on Drive into the local /content/CamVid folder.
with zipfile.ZipFile(src_zip, "r") as zip_ref:
    zip_ref.extractall(dst_zip)

# Image folders of the three splits ...
TEST_PATH = "/content/CamVid/CamVid/test"
TRAIN_PATH = "/content/CamVid/CamVid/train"
VALID_PATH = "/content/CamVid/CamVid/val"
# ... and the matching label folders.
LABEL_TEST_PATH = "/content/CamVid/CamVid/test_labels"
LABEL_TRAIN_PATH = "/content/CamVid/CamVid/train_labels"
LABEL_VALID_PATH = "/content/CamVid/CamVid/val_labels"

# Report how many entries each image folder contains.
print(f"There are {len(os.listdir(TRAIN_PATH))} train images.")
print(f"There are {len(os.listdir(TEST_PATH))} test images.")
print(f"There are {len(os.listdir(VALID_PATH))} validation images.")

In [None]:
# Label colours grouped under the merged class they are mapped to.
# The insertion order below defines the order of RGBLabel2LabelName.
_COLORS_BY_CLASS = {
    "Sky": [(128, 128, 128)],
    "Building": [(0, 128, 64), (128, 0, 0), (64, 192, 0), (64, 0, 64), (192, 0, 128)],
    "Pole": [(192, 192, 128), (0, 0, 64)],
    "Road": [(128, 64, 128), (128, 0, 192), (192, 0, 64)],
    "Sidewalk": [(0, 0, 192), (64, 192, 128), (128, 128, 192)],
    "Tree": [(128, 128, 0), (192, 192, 0)],
    "SignSymbol": [(192, 128, 128), (128, 128, 64), (0, 64, 64)],
    "Fence": [(64, 64, 128)],
    "Car": [(64, 0, 128), (64, 128, 192), (192, 128, 192), (192, 64, 128), (128, 64, 64)],
    "Pedestrian": [(64, 64, 0), (192, 128, 64), (64, 0, 192), (64, 128, 64)],
    "Bicyclist": [(0, 128, 192), (192, 0, 192)],
    "Void": [(0, 0, 0)],
}

# Flat lookup {RGB tuple: class name} used when converting the label images.
RGBLabel2LabelName = {
    rgb: class_name
    for class_name, rgb_list in _COLORS_BY_CLASS.items()
    for rgb in rgb_list
}

# Merged class names; the list position is the class id (except for Void, see below).
CAMVID_CLASSES = [
    'Sky', 'Building', 'Pole', 'Road', 'Sidewalk', 'Tree',
    'SignSymbol', 'Fence', 'Car', 'Pedestrian', 'Bicyclist',
    'Void',
]

# Class name -> label id written into the masks; Void is stored as 255.
Class2LabelId = {class_name: class_id for class_id, class_name in enumerate(CAMVID_CLASSES)}
Class2LabelId['Void'] = 255

# Note: this count includes the Void entry.
NUM_CLASSES = len(CAMVID_CLASSES)
print(NUM_CLASSES)

# Display colour of each class, in the same order as CAMVID_CLASSES.
CAMVID_CLASS_COLORS = [
    (128, 128, 128),  # Sky
    (128, 0, 0),      # Building
    (192, 192, 128),  # Pole
    (128, 64, 128),   # Road
    (0, 0, 192),      # Sidewalk
    (128, 128, 0),    # Tree
    (192, 128, 128),  # SignSymbol
    (64, 64, 128),    # Fence
    (64, 0, 128),     # Car
    (64, 64, 0),      # Pedestrian
    (0, 128, 192),    # Bicyclist
    (0, 0, 0),        # Void / background
]

# {list index: RGB tuple} lookup used to colour masks for display.
COLORMAP = dict(enumerate(CAMVID_CLASS_COLORS))

In [None]:
def convert32to11(src_label_dir, dst_label_dir):
  """Convert RGB label images into single-channel class-id masks.

  Every pixel whose colour is a key of ``RGBLabel2LabelName`` is written as
  ``Class2LabelId[<class name>]``; every other colour becomes 255. Each
  result is saved as an 8-bit image under the same file name.

  Args:
    src_label_dir: directory containing the RGB label images.
    dst_label_dir: output directory. If it already exists, nothing is
      converted and the function returns immediately.
  """
  if os.path.exists(dst_label_dir):
    print(f"{dst_label_dir} already exists")
    return
  os.makedirs(dst_label_dir)

  # Pack every known colour into one integer so a whole image can be compared
  # at once instead of looping over pixels in Python.
  packed_to_id = {
      (r << 16) | (g << 8) | b: Class2LabelId[class_name]
      for (r, g, b), class_name in RGBLabel2LabelName.items()
  }

  for img in sorted(os.listdir(src_label_dir)):
    src_path = os.path.join(src_label_dir, img)
    print(src_path)

    # convert("RGB") guarantees exactly three channels (drops alpha, expands
    # palettes), so colours can always be matched against the lookup table.
    with PIL.Image.open(src_path) as image:
      rgb = np.array(image.convert("RGB")).astype(np.uint32)
    packed = (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

    # Start from "void everywhere" and overwrite the pixels of known colours.
    ret_img = np.full(packed.shape, 255, dtype=np.uint8)
    for code, label_id in packed_to_id.items():
      ret_img[packed == code] = label_id

    PIL.Image.fromarray(ret_img).save(os.path.join(dst_label_dir, img))
    print(f"Converted {img}")
  print("Done")


In [None]:
DRIVE_BASE = "/content/drive/MyDrive/CamVid/CamVid_Labels"

# One-off conversion of the RGB labels; left commented out so the notebook
# uses the already converted folders below.
# convert32to11(LABEL_TEST_PATH, f"{DRIVE_BASE}/test_labels_11")
# convert32to11(LABEL_TRAIN_PATH, f"{DRIVE_BASE}/train_labels_11")
# convert32to11(LABEL_VALID_PATH, f"{DRIVE_BASE}/valid_labels_11")

# From here on the label paths point at the converted masks on Drive.
LABEL_TRAIN_PATH = f"{DRIVE_BASE}/train_labels_11"
LABEL_TEST_PATH = f"{DRIVE_BASE}/test_labels_11"
LABEL_VALID_PATH = f"{DRIVE_BASE}/valid_labels_11"

### Add train images

In [None]:
num_to_move = 132  # number of images to move

# Directories
images_root = "/content/CamVid/CamVid"
labels_root = "/content/drive/My Drive/CamVid/CamVid_Labels"

# Subdirectories
train_img_dir = os.path.join(images_root, "train")
test_img_dir = os.path.join(images_root, "test")
train_label_dir = os.path.join(labels_root, "train_labels_11_v2")
test_label_dir = os.path.join(labels_root, "test_labels_11_v2")

train_2_dir = "/content/CamVid/CamVid/train_2"
test_2_dir = "/content/CamVid/CamVid/test_2"

# Work on copies (train_2 / test_2) so the extracted train / test folders stay untouched.
# NOTE(review): re-running this cell after images were moved from test_2 to
# train_2 copies them into test_2 again, so they would sit in both splits.
for src, dst in [(train_img_dir, train_2_dir), (test_img_dir, test_2_dir)]:
    os.makedirs(dst, exist_ok=True)
    for file in os.listdir(src):
        src_path = os.path.join(src, file)
        dst_path = os.path.join(dst, file)
        if os.path.isfile(src_path):
            shutil.copy(src_path, dst_path)

# The counts are those of the two copied folders, so label them as such.
print("Files in train_2:", len(os.listdir(train_2_dir)))
print("Files in test_2:", len(os.listdir(test_2_dir)))

In [None]:
# EXECUTE ONLY FIRST TIME

# save_list_path = "/content/drive/My Drive/CamVid/CamVid_Labels/masks_to_move.txt"

# test_masks = [f for f in os.listdir(test_label_dir) if f.endswith("png")]
# print(f"Found {len(test_masks)} masks in test_labels_11_v2")

# selected_masks = random.sample(test_masks, min(num_to_move, len(test_masks)))
# with open(save_list_path, "w") as f:
#     for name in selected_masks:
#         f.write(name + "\n")

# print(f"Saved list of {len(selected_masks)} randomly chosen masks to: {save_list_path}")
# print("Example of first 5 chosen files:", selected_masks[:5])

# with open("/content/drive/My Drive/CamVid/CamVid_Labels/masks_to_move.txt", 'r') as f:
#     filenames = [line.strip() for line in f if line.strip()]

# for name in filenames:
#     src_path = os.path.join(test_label_dir, name)
#     dst_path = os.path.join(train_label_dir, name)

#     if os.path.exists(src_path):
#         shutil.move(src_path, dst_path)
#     else:
#         print(f"⚠️ File not found: {name}")

In [None]:
import re

# Mask file names (one per line) that were moved from the test labels to the train labels.
with open("/content/drive/My Drive/CamVid/CamVid_Labels/masks_to_move.txt", 'r') as f:
    raw_names = [os.path.splitext(line.strip())[0] for line in f if line.strip()]

def clean_name(name):
    """Return the lower-cased part of *name* before the first ``_L``, whitespace or ``(``."""
    return re.split(r'_L|\s|\(', name)[0].lower()

base_names = [clean_name(n) for n in raw_names]
print(base_names[:5])

# Set for O(1) membership tests inside the directory loop below.
base_name_lookup = set(base_names)

# Ensure destination exists
os.makedirs(train_2_dir, exist_ok=True)

# Move every test image whose base name matches one of the listed masks.
moved = 0
for filename in os.listdir(test_2_dir):
    name_no_ext, ext = os.path.splitext(filename)
    if name_no_ext.lower() in base_name_lookup:
        shutil.move(os.path.join(test_2_dir, filename), os.path.join(train_2_dir, filename))
        moved += 1

print(f"Moved {moved} files from {test_2_dir} to {train_2_dir}")

# Image / label locations of the re-balanced split used from here on.
NEW_TRAIN_LABELS = "/content/drive/My Drive/CamVid/CamVid_Labels/train_labels_11_v2"
NEW_TEST_LABELS = "/content/drive/My Drive/CamVid/CamVid_Labels/test_labels_11_v2"
NEW_TRAIN_PATH = "/content/CamVid/CamVid/train_2"
NEW_TEST_PATH = "/content/CamVid/CamVid/test_2"

### Dataset class

In [None]:
class CamVidDataset(Dataset):
  """Paired image / label-mask dataset read from two directories.

  Images and masks are paired by position after sorting the file names of
  each directory, so both directories must list matching files in the same
  order.

  Args:
    img_dir: directory containing the input images.
    mask_dir: directory containing the single-channel label masks.
    width: width every image and mask is resized to.
    height: height every image and mask is resized to.
    transforms: optional callable applied jointly as ``transforms(image, mask)``.
    color_transforms: optional callable applied to the image only, after
      ``transforms``.
  """

  def __init__(self, img_dir, mask_dir, width, height, transforms=None, color_transforms=None):

    self.img_dir = img_dir
    self.mask_dir = mask_dir

    self.images = sorted(os.listdir(img_dir))
    self.masks = sorted(os.listdir(mask_dir))

    self.transforms = transforms
    self.color_transforms = color_transforms
    self.width = width
    self.height = height

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    """Return ``(image, mask)`` for sample *idx*.

    Before any transform the image is a float tensor of shape (3, height,
    width) scaled to [0, 1] and the mask is a long tensor of shape
    (1, height, width).

    Raises:
      FileNotFoundError: if OpenCV cannot read the image or the mask.
    """
    img_path = os.path.join(self.img_dir, self.images[idx])
    mask_path = os.path.join(self.mask_dir, self.masks[idx])

    # cv2.imread returns None instead of raising; fail with a clear message.
    img = cv2.imread(img_path)
    if img is None:
      raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
      raise FileNotFoundError(f"Could not read mask: {mask_path}")

    img_resized = cv2.resize(img, (self.width, self.height))
    # Nearest-neighbour keeps label ids intact (no blended, invalid ids).
    mask_resized = cv2.resize(mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST)

    img_resized = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0
    mask_resized = torch.from_numpy(mask_resized).long().unsqueeze(0)

    if self.transforms is not None:
      # Wrap as tv_tensors so the v2 transforms treat image and mask jointly.
      img_tv = torchvision.tv_tensors.Image(img_resized)
      mask_tv = torchvision.tv_tensors.Mask(mask_resized)
      img_resized, mask_resized = self.transforms(img_tv, mask_tv)

    if self.color_transforms is not None:
      # Image-only transform: must use color_transforms, not the joint one.
      img_resized = self.color_transforms(img_resized)

    return img_resized, mask_resized

### Data transformation

In [None]:
def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    ``[(img0, mask0), (img1, mask1)]`` becomes ``((img0, img1), (mask0, mask1))``;
    nothing is stacked here.
    """
    return (*zip(*batch),)

def get_train_transform():
  """Build the joint image/mask augmentation pipeline used for training.

  Random horizontal flip, rotation (up to 30 degrees) and perspective warp,
  followed by conversion to float32 and per-channel normalisation.
  Expects the image and mask wrapped as ``tv_tensors.Image`` /
  ``tv_tensors.Mask``.
  """
  # Pixels uncovered by the rotation / perspective warp are black in the
  # image but must be 255 (void) in the mask; the default fill of 0 would
  # label them as class 0.
  fill = {torchvision.tv_tensors.Image: 0, torchvision.tv_tensors.Mask: 255}

  return T.Compose([
      T.ToDtype(torch.uint8, scale=True),
      T.RandomHorizontalFlip(p=0.5),
      T.RandomRotation(degrees=30, fill=fill),
      T.RandomPerspective(distortion_scale=0.2, p=0.5, fill=fill),
      T.ToDtype(torch.float32, scale=True),
      T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])

def get_valid_transform():
  """Build the deterministic pipeline used for validation / test data.

  Only converts to float32 (with scaling) and normalises each channel; no
  random augmentation is applied.
  """
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]

  steps = [
      T.ToDtype(torch.float32, scale=True),
      T.Normalize(mean=channel_mean, std=channel_std),
  ]
  return T.Compose(steps)

### Class weights

In [None]:
# Count class pixels on the un-augmented training masks: the random rotation /
# perspective warps of the training pipeline pad the masks with fill values,
# which would distort the counts and make them depend on the random state.
t_d = CamVidDataset(NEW_TRAIN_PATH, NEW_TRAIN_LABELS, 520, 520, get_valid_transform())
t_loader = DataLoader(t_d, batch_size=8, shuffle=False, collate_fn=collate_fn)

num_classes = 11
class_counts = torch.zeros(num_classes)

for _, targets in t_loader:  # targets: tuple of B masks, each (1, H, W)
    targets = torch.stack(targets, dim=0).squeeze(1)  # (B, 1, H, W) -> (B, H, W)

    # Keep only real class ids (drops the 255 void label) and count them in one pass.
    valid_labels = targets[targets < num_classes].flatten()
    class_counts += torch.bincount(valid_labels, minlength=num_classes).float()

# Inverse-frequency weights, scaled so the rarest class gets weight 1.
# Classes that never occur get weight 0 instead of dominating the normalisation.
present = class_counts > 0
class_weights = torch.zeros_like(class_counts)
class_weights[present] = 1.0 / class_counts[present]
if present.any():
    class_weights = class_weights / class_weights.max()

# Median frequency balancing
# frequencies = class_counts / class_counts.sum()
# median_freq = torch.median(frequencies[frequencies > 0])
# class_weights = median_freq / (frequencies + 1e-6)

# Append a zero weight for the Void entry so the vector lines up with CAMVID_CLASSES.
CLASS_WEIGHTS = torch.cat([class_weights, torch.tensor([0.0])])
for class_name, class_weight in zip(CAMVID_CLASSES, CLASS_WEIGHTS):
  print(f"Weights for class {class_name} is : {class_weight:.3f}.")

In [None]:
# Relative pixel frequency of each class
freq_np = class_counts.cpu().numpy()
freq_np = freq_np / freq_np.sum()

# Most frequent class first
sorted_idx = np.argsort(freq_np)[::-1]
freq_np = freq_np[sorted_idx]
CAMVID_CLASSES_SORTED = [CAMVID_CLASSES[idx] for idx in sorted_idx]

plt.figure(figsize=(10, 6))

# One viridis shade per bar
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(freq_np)))

bars = plt.bar(
    CAMVID_CLASSES_SORTED[:11],
    freq_np,
    color=colors,
    edgecolor='black',
    alpha=0.85,
)

# Write the percentage just above every bar
for bar, val in zip(bars, freq_np):
    x_center = bar.get_x() + bar.get_width()/2
    plt.text(x_center, val + 0.005, f"{val*100:.1f}%", ha='center', va='bottom', fontsize=10)

# Title and axis labels
plt.title("CamVid Class Frequency Distribution", fontsize=16, weight='bold')
plt.ylabel("Frequency", fontsize=12)
plt.xlabel("Classes", fontsize=12)

# Slanted class names and a light horizontal grid for readability
plt.xticks(rotation=40, ha='right')
plt.grid(axis='y', linestyle='--', alpha=0.6)

plt.tight_layout()
plt.show()

## Visualization utilities

In [None]:
def colorize_mask(mask):
    """Turn a label mask tensor into an RGB ``uint8`` array using ``COLORMAP``.

    A leading dimension of size 1 is dropped first. Pixels whose value is not
    a key of ``COLORMAP`` stay black.
    """
    labels = mask.squeeze(0).cpu().numpy()
    rows, cols = labels.shape
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id in COLORMAP:
        rgb[labels == class_id] = COLORMAP[class_id]
    return rgb

def denormalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Undo per-channel normalisation IN PLACE and return the same tensor.

    Each channel ``c`` of *image* is replaced by ``c * std + mean``; the
    argument itself is modified.
    """
    for channel, (offset, scale) in zip(image, zip(mean, std)):
        channel.mul_(scale)
        channel.add_(offset)
    return image

def visualize(image, mask, predicted=None):
  """Show an image next to its ground-truth mask and, optionally, a prediction.

  Args:
    image: normalised image tensor of shape (3, H, W); it is not modified.
    mask: label mask accepted by ``colorize_mask``.
    predicted: optional predicted label mask, shown as a third panel.
  """
  n_panels = 2 if predicted is None else 3

  # denormalize() works in place, so operate on a CPU copy to leave the
  # caller's tensor untouched (calling visualize twice must give the same picture).
  image = denormalize(image.detach().clone().cpu())

  plt.figure(figsize=(16,12))

  plt.subplot(1, n_panels, 1)
  # Clamp to the valid display range so imshow does not have to clip.
  plt.imshow(image.permute(1, 2, 0).clamp(0, 1))
  plt.axis('off')
  plt.title('Image')

  plt.subplot(1, n_panels, 2)
  plt.imshow(colorize_mask(mask))
  plt.axis('off')
  plt.title('Ground Truth')

  if predicted is not None:
    plt.subplot(1, n_panels, n_panels)
    plt.imshow(colorize_mask(predicted))
    plt.axis('off')
    plt.title('Prediction')
  plt.show()

def plot_curves(train_loss, val_loss, val_metric):
    """Plot loss curves and the validation metric side by side.

    Args:
        train_loss: per-epoch training loss values.
        val_loss: per-epoch validation loss values.
        val_metric: per-epoch validation metric values (axis labelled "Mean IoU").
    """
    epochs = range(1, len(train_loss) + 1)

    def _decorate(y_label, title):
        # Shared axis labelling for both panels.
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.figure(figsize=(10, 5))

    # Left panel: training and validation loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, label='Train Loss', linewidth=2)
    plt.plot(epochs, val_loss, label='Validation Loss', linewidth=2)
    _decorate('Loss', 'Training and Validation Loss')

    # Right panel: validation metric
    plt.subplot(1, 2, 2)
    plt.plot(epochs, val_metric, color='green', label='Validation Metric', linewidth=2)
    _decorate('Mean IoU', 'Validation Metric')

    plt.tight_layout()
    plt.show()

def print_trainable_params(model):
  """Print how many of the model's parameters require gradients, out of the total."""
  trainable_params = 0
  total_params = 0
  for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
      trainable_params += count
  print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")


In [None]:
# Same training images, once with the deterministic and once with the augmenting pipeline.
data = CamVidDataset(NEW_TRAIN_PATH, NEW_TRAIN_LABELS, 520, 520, get_valid_transform())
data_transf = CamVidDataset(NEW_TRAIN_PATH, NEW_TRAIN_LABELS, 520, 520, get_train_transform())

img, mask = data[0]
img_transf, mask_transf = data_transf[0]

print(f"Image shape: {img.shape}")
print(f"Mask shape: {mask.shape}")

# A few un-augmented samples ...
visualize(img, mask)
for sample_idx in (5, 20, 100, 300):
    sample_img, sample_mask = data[sample_idx]
    visualize(sample_img, sample_mask)

# ... followed by the augmented version of the first sample.
visualize(img_transf, mask_transf)


## Evaluation functions

I evaluate the models with mean Intersection over Union (mIoU) and also report pixel accuracy. This section also implements a function that runs the complete evaluation of a model on a test set.

In [None]:
def multiclass_iou(preds: torch.Tensor, targets: torch.Tensor, num_classes: int = 11, void_label: int = 255):
    """Compute per-class IoU and its mean for one batch.

    Args:
        preds: class scores; the arg-max over dim 1 is the predicted class.
        targets: ground-truth class ids, same shape as the arg-max of *preds*.
        num_classes: ids ``0 .. num_classes - 1`` are evaluated.
        void_label: target value whose pixels are excluded.

    Returns:
        ``(iou_per_class, mean_iou)``. Classes that occur neither in the
        predictions nor in the targets are NaN and are skipped by the mean.
    """
    # Drop void pixels; boolean indexing already yields flat 1-D tensors.
    keep = targets != void_label
    predicted = preds.argmax(dim=1)[keep].view(-1)
    actual = targets[keep].view(-1)

    intersection = torch.zeros(num_classes, dtype=torch.float32, device=preds.device)
    union = torch.zeros(num_classes, dtype=torch.float32, device=preds.device)

    for class_id in range(num_classes):
        in_pred = predicted == class_id
        in_true = actual == class_id

        overlap = (in_pred & in_true).sum()
        intersection[class_id] = overlap
        union[class_id] = in_pred.sum() + in_true.sum() - overlap

    iou_per_class = intersection / (union + 1e-6)
    # A class absent from both prediction and target has no defined IoU.
    iou_per_class[union == 0] = float('nan')

    return iou_per_class, torch.nanmean(iou_per_class)


def compute_pixel_accuracy_multiclass(model, dataloader, device='cuda', pretrained=False, void_label=255):
    """Return the fraction of correctly classified pixels over a dataloader.

    Args:
        model: network called as ``model(images)``; switched to eval mode.
        dataloader: yields ``(images, masks)`` as sequences of per-sample tensors.
        device: device the batches are moved to.
        pretrained: if True the logits are read from ``outputs['out']``.
        void_label: mask value excluded from the statistic, consistent with
            the IoU computation; pass ``None`` to count every pixel.

    Returns:
        Pixel accuracy as a float, or NaN when no pixel was evaluated.
    """
    model.eval()
    correct_pixels = 0
    total_pixels = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = torch.stack(images, dim=0).to(device)
            masks = torch.stack(masks, dim=0).to(device)

            # Ensure masks have shape [B, H, W]
            if masks.ndim == 4 and masks.shape[1] == 1:
                masks = masks.squeeze(1)

            outputs = model(images)
            if pretrained:
                outputs = outputs['out']

            preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            # Void pixels can never be predicted correctly, so leaving them in
            # the denominator would systematically lower the accuracy.
            if void_label is None:
                valid = torch.ones_like(masks, dtype=torch.bool)
            else:
                valid = masks != void_label

            correct_pixels += ((preds == masks) & valid).sum().item()
            total_pixels += valid.sum().item()

    if total_pixels == 0:
        return float('nan')
    return correct_pixels / total_pixels

In [None]:
@torch.no_grad()
def evaluate_model(model, test_loader, device, num_classes=11, void_label=255, pretrained=False, class_weights=None):
    """Evaluate a model on a whole loader and print IoU and pixel accuracy.

    Intersections and unions are accumulated over all batches before the IoU
    is computed. Pixels labelled ``void_label`` are excluded from both the IoU
    and the pixel accuracy.

    Args:
        model: network called as ``model(images)``; switched to eval mode.
        test_loader: yields ``(images, targets)`` as sequences of per-sample tensors.
        device: device used for the batches and the accumulators.
        num_classes: ids ``0 .. num_classes - 1`` are evaluated.
        void_label: target value to ignore.
        pretrained: if True the logits are read from ``preds['out']``.
        class_weights: optional per-class weights; only the first
            ``num_classes`` entries are used and the result is the weighted
            average of the per-class IoUs.

    Returns:
        ``(iou_per_class, mean_iou)``; classes never seen are NaN.
    """
    model.eval()

    total_intersection = torch.zeros(num_classes, dtype=torch.float32, device=device)
    total_union = torch.zeros(num_classes, dtype=torch.float32, device=device)

    correct_pixels = 0
    total_pixels = 0

    for images, targets in tqdm.tqdm(test_loader, desc="Evaluating", leave=False):
        images = torch.stack(images, dim=0).to(device)
        targets = torch.stack(targets, dim=0).to(device).squeeze(1)

        preds = model(images)  # (B, C, H, W)
        if pretrained:
          preds = preds['out']

        # Keep only non-void pixels for every statistic below.
        valid_mask = targets != void_label
        pred_classes = preds.argmax(dim=1)[valid_mask]
        targets_flat = targets[valid_mask]

        correct_pixels += (pred_classes == targets_flat).sum().item()
        total_pixels += targets_flat.numel()

        for cls in range(num_classes):
            pred_mask = pred_classes == cls
            target_mask = targets_flat == cls

            inter = (pred_mask & target_mask).sum()
            u = pred_mask.sum() + target_mask.sum() - inter

            total_intersection[cls] += inter
            total_union[cls] += u

    # Compute final IoU
    iou_per_class = total_intersection / (total_union + 1e-6)
    iou_per_class[total_union == 0] = float('nan')

    if class_weights is None:
      mean_iou = torch.nanmean(iou_per_class)
    else:
      # Slice so a weight vector with an extra (void) entry still lines up.
      weights = class_weights.to(device).clone().float()[:num_classes]
      weights[torch.isnan(iou_per_class)] = 0
      weights = weights / weights.sum()
      # Weights already sum to 1, so the weighted average is a sum, not a mean.
      mean_iou = torch.nansum(iou_per_class * weights)

    pixel_accuracy = correct_pixels / total_pixels if total_pixels else float('nan')

    print("\nPer-class IoU:")
    for i, iou in enumerate(iou_per_class):
        print(f"  Class {i} - {CAMVID_CLASSES[i]}: {iou.item():.4f}")
    print(f"\nMean IoU: {mean_iou.item():.4f}")
    print(f"\nPixel Accuracy: {pixel_accuracy:.4f}")

    return iou_per_class, mean_iou

## Focal loss

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    For every non-ignored pixel with true class ``t`` the loss is
    ``-alpha[t] * (1 - p_t) ** gamma * log(p_t)``; the result is the mean over
    those pixels.

    Args:
      alpha: optional 1-D tensor of per-class weights indexed by class id;
        ``None`` disables class weighting.
      gamma: focusing exponent.
      ignore_index: target value excluded from the loss.
      device: device ``alpha`` is moved to; ``None`` selects CUDA when
        available, otherwise the CPU.
    """

    def __init__(self, alpha=None, gamma=2, ignore_index=255, device=None):
      super().__init__()
      if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
      self.alpha = None if alpha is None else alpha.to(device)
      self.gamma = gamma
      self.ignore_index = ignore_index
      self.device = device

    def forward(self, inputs, targets):
      """Return the mean focal loss of logits ``inputs`` against ``targets``.

      ``targets`` is left unmodified. Returns a zero that is still attached to
      the graph when every pixel is ignored.
      """
      log_prob = F.log_softmax(inputs, dim=1)

      valid_mask = targets != self.ignore_index
      # gather needs in-range indices: point ignored pixels at class 0 in a
      # copy (they are dropped below) instead of rewriting the caller's tensor.
      safe_targets = targets.masked_fill(~valid_mask, 0)

      idx = safe_targets.unsqueeze(1)  #[B,1,H,W]
      log_prob_true_classes = torch.gather(log_prob, dim=1, index=idx).squeeze(1) #[B,H,W]
      prob_true_classes = torch.exp(log_prob_true_classes)

      focal = - ((1 - prob_true_classes) ** self.gamma * log_prob_true_classes)
      if self.alpha is not None:
        focal = self.alpha.to(focal.device)[safe_targets] * focal

      focal = focal[valid_mask]
      if focal.numel() == 0:
        # Nothing to average; mean() of an empty tensor would be NaN.
        return inputs.sum() * 0.0

      return focal.mean()

## Models Implementation

I will implement different models and compare their performances on the dataset. I will try:
- U-Net with traditional VGG backbone
- U-Net with ResNets backbones
- DeepLab V3+
- Use the pretrained DeepLab V3 Pytorch provides

In [None]:
!pip install thop

from thop import profile

### U-net

In [None]:
class DoubleConvBlock(nn.Module):
  """Two 3x3 convolutions, each followed by ReLU and then batch norm.

  Spatial size is preserved (padding 1). When ``dropout > 0`` a ``Dropout2d``
  layer is inserted between the two convolution stages.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by both convolutions.
    dropout: drop probability; 0 disables dropout.
  """

  def __init__(self, in_channels, out_channels, dropout=0.0):

    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
    self.relu = nn.ReLU()
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.use_dropout = dropout > 0
    if self.use_dropout:
      self.dropout = nn.Dropout2d(p=dropout)

  def forward(self, inputs):

    # First stage: conv -> ReLU -> BN
    x = self.bn1(self.relu(self.conv1(inputs)))
    if self.use_dropout:
      x = self.dropout(x)
    # Second stage: conv -> ReLU -> BN
    return self.bn2(self.relu(self.conv2(x)))


class EncoderLayer(nn.Module):
  """Encoder stage: double convolution followed by 2x2 max pooling.

  ``forward`` returns ``(pooled, skip)`` where ``skip`` is the feature map
  before pooling.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.cnv = DoubleConvBlock(in_channels, out_channels)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, inputs):

    features = self.cnv(inputs)
    return self.pool(features), features


class DecoderLayer(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concatenation, double conv.

  Args:
    in_channels: channels of the incoming feature map; the upsampling halves
      them, so the skip connection is expected to carry ``in_channels // 2``.
    out_channels: channels of the output feature map.
  """

  def __init__(self, in_channels, out_channels):

    super().__init__()

    self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
    self.cnv = DoubleConvBlock(in_channels, out_channels)

  def forward(self, inputs, skip):

    x = self.up(inputs)
    # Pooling floors odd sizes (e.g. 65 -> 32), so the upsampled map can be a
    # pixel short of the skip; align it to the skip before concatenating.
    if x.shape[-2:] != skip.shape[-2:]:
      x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
    x = torch.cat([x, skip], dim=1)
    x = self.cnv(x)

    return x


class UNet(nn.Module):
  """U-Net with four encoder stages, a bottleneck and four decoder stages.

  Args:
    n_classes: number of output channels of the final 1x1 convolution.
  """

  def __init__(self, n_classes):

    super().__init__()

    # Contracting path: 3 -> 64 -> 128 -> 256 -> 512 channels
    self.encoder_1 = EncoderLayer(3, 64)
    self.encoder_2 = EncoderLayer(64, 128)
    self.encoder_3 = EncoderLayer(128, 256)
    self.encoder_4 = EncoderLayer(256, 512)

    # Bottleneck with dropout
    self.bottleneck = DoubleConvBlock(512, 1024, dropout=0.5)

    # Expanding path: 1024 -> 512 -> 256 -> 128 -> 64 channels
    self.decoder_1 = DecoderLayer(1024, 512)
    self.decoder_2 = DecoderLayer(512, 256)
    self.decoder_3 = DecoderLayer(256, 128)
    self.decoder_4 = DecoderLayer(128, 64)

    # Per-pixel class scores
    self.last = nn.Conv2d(64, n_classes, kernel_size=1)

  def forward(self, inputs):

    encoders = (self.encoder_1, self.encoder_2, self.encoder_3, self.encoder_4)
    decoders = (self.decoder_1, self.decoder_2, self.decoder_3, self.decoder_4)

    # Encode, remembering every pre-pooling feature map.
    skips = []
    x = inputs
    for encoder in encoders:
      x, skip = encoder(x)
      skips.append(skip)

    x = self.bottleneck(x)

    # Decode, consuming the skips from deepest to shallowest.
    for decoder in decoders:
      x = decoder(x, skips.pop())

    return self.last(x)

# Instantiate the plain U-Net with NUM_CLASSES output channels on the selected device.
unet = UNet(NUM_CLASSES).to(device)
# summary(unet, input_size=(3, 512, 512))

# Profile one forward pass on a random 1 x 3 x 512 x 512 input with thop.
# NOTE(review): thop's profile() is documented as returning MACs — confirm
# whether the value printed as "FLOPs" should be doubled.
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

### ResNet Unet

#### ResNet-34

In [None]:
class DecoderBlock(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concatenation, two conv-BN-ReLU layers.

  Args:
    up_input_c: channels of the feature map to upsample.
    output_c: channels after upsampling and of the block output.
    in_conv: channels of the concatenated (upsampled + skip) tensor;
      defaults to ``up_input_c``.
    dropout: drop probability applied between the two convolutions.
  """

  def __init__(self, up_input_c, output_c, in_conv=None, dropout=0.1):

    super().__init__()
    if in_conv is None:
      in_conv = up_input_c

    self.up = nn.ConvTranspose2d(up_input_c, output_c, kernel_size=2, stride=2, padding=0)
    self.conv = nn.Sequential(
        nn.Conv2d(in_conv, output_c, kernel_size=3, padding=1),
        nn.BatchNorm2d(output_c),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Conv2d(output_c, output_c, kernel_size=3, padding=1),
        nn.BatchNorm2d(output_c),
        nn.ReLU()
    )

  def forward(self, inputs, skip):
    x = self.up(inputs)
    # Strided encoder stages round odd sizes, so the upsampled map may not
    # match the skip exactly; align it to the skip before concatenating.
    if x.shape[-2:] != skip.shape[-2:]:
      x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
    x = torch.cat([x, skip], dim=1)
    x = self.conv(x)

    return x

class UNetResNet34(nn.Module):
  """U-Net whose encoder is a torchvision ResNet-34.

  Args:
    n_classes: number of output channels.
    dropout: drop probability used in the bottleneck and decoder blocks.
    pretrained: load the default torchvision weights when True.
  """

  def __init__(self, n_classes, dropout=0.1, pretrained=True):

    super().__init__()

    weights = torchvision.models.ResNet34_Weights.DEFAULT if pretrained else None
    backbone = torchvision.models.resnet34(weights=weights)

    # ENCODER (shapes relative to a 3 x H x W input)

    # Stem: 3 x H x W ---> 64 x H/2 x W/2
    self.conv1 = backbone.conv1
    self.bn1 = backbone.bn1
    self.relu = backbone.relu

    # 64 x H/2 x W/2 ---> 64 x H/4 x W/4
    self.maxpool = backbone.maxpool

    # 64 x H/4 x W/4 ---> 64 x H/4 x W/4
    self.enc1 = backbone.layer1

    # 64 x H/4 x W/4 ---> 128 x H/8 x W/8
    self.enc2 = backbone.layer2

    # 128 x H/8 x W/8 ---> 256 x H/16 x W/16
    self.enc3 = backbone.layer3

    # 256 x H/16 x W/16 ---> 512 x H/32 x W/32
    self.enc4 = backbone.layer4

    # 512 x H/32 x W/32 ---> 1024 x H/64 x W/64
    self.bottleneck = nn.Sequential(
        nn.Conv2d(512, 1024, kernel_size=3, padding=1),
        nn.BatchNorm2d(1024),
        nn.ReLU(),
        nn.Dropout(p=dropout),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    )

    # DECODER (each block: upsample x2, concatenate the skip, convolve)

    # 1024 x H/64 + skip 512 x H/32 ---> 512 x H/32
    self.decoder1 = DecoderBlock(1024, 512, dropout=dropout)

    # 512 x H/32 + skip 256 x H/16 ---> 256 x H/16
    self.decoder2 = DecoderBlock(512, 256, dropout=dropout)

    # 256 x H/16 + skip 128 x H/8 ---> 128 x H/8
    self.decoder3 = DecoderBlock(256, 128, dropout=dropout)

    # 128 x H/8 + skip 64 x H/4 ---> 64 x H/4
    self.decoder4 = DecoderBlock(128, 64, dropout=dropout)

    # 64 x H/4 + stem skip 64 x H/2 ---> 64 x H/2
    self.decoder5 = DecoderBlock(64, 64, dropout=dropout, in_conv=128)

    # 64 x H/2 x W/2 ---> n_classes x H x W
    self.last = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0),
            nn.Conv2d(64, n_classes, kernel_size=3, padding=1)
    )

  def forward(self, inputs):

    # Stem (kept as the shallowest skip connection)
    e0 = self.relu(self.bn1(self.conv1(inputs)))

    # Encoder
    enc1 = self.enc1(self.maxpool(e0))
    enc2 = self.enc2(enc1)
    enc3 = self.enc3(enc2)
    enc4 = self.enc4(enc3)

    x = self.bottleneck(enc4)

    # Decoder: pair every block with its skip, deepest first
    stages = (
        (self.decoder1, enc4),
        (self.decoder2, enc3),
        (self.decoder3, enc2),
        (self.decoder4, enc1),
        (self.decoder5, e0),
    )
    for decoder, skip in stages:
      x = decoder(x, skip)

    return self.last(x)

In [None]:
# Instantiate the ResNet-34 U-Net with NUM_CLASSES output channels on the selected device.
unet = UNetResNet34(NUM_CLASSES).to(device)
# summary(unet, input_size=(3, 512, 512))
# Profile one forward pass on a random 1 x 3 x 512 x 512 input with thop.
# NOTE(review): thop's profile() is documented as returning MACs — confirm
# whether the value printed as "FLOPs" should be doubled.
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

#### ResNet-50 and ResNet-101

In [None]:
class DecoderBlock(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concatenation, two conv-BN-ReLU layers.

  Args:
    up_input_c: channels of the feature map to upsample.
    output_c: channels after upsampling and of the block output.
    in_conv: channels of the concatenated (upsampled + skip) tensor;
      defaults to ``up_input_c``.
    dropout: drop probability applied between the two convolutions.
  """

  def __init__(self, up_input_c, output_c, in_conv=None, dropout=0.1):

    super().__init__()
    if in_conv is None:
      in_conv = up_input_c

    self.up = nn.ConvTranspose2d(up_input_c, output_c, kernel_size=2, stride=2, padding=0)
    self.conv = nn.Sequential(
        nn.Conv2d(in_conv, output_c, kernel_size=3, padding=1),
        nn.BatchNorm2d(output_c),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Conv2d(output_c, output_c, kernel_size=3, padding=1),
        nn.BatchNorm2d(output_c),
        nn.ReLU()
    )

  def forward(self, inputs, skip):
    x = self.up(inputs)
    # Strided encoder stages round odd sizes, so the upsampled map may not
    # match the skip exactly; align it to the skip before concatenating.
    if x.shape[-2:] != skip.shape[-2:]:
      x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
    x = torch.cat([x, skip], dim=1)
    x = self.conv(x)

    return x

class UNetResNet(nn.Module):
  """U-Net whose encoder is a pretrained torchvision ResNet-50 or ResNet-101.

  Args:
    n_classes: number of output channels.
    backbone: ``'resnet101'`` or ``'resnet50'``.
    dropout: drop probability used in the decoder blocks.

  Raises:
    ValueError: if *backbone* is neither of the two supported names.
  """

  def __init__(self, n_classes, backbone='resnet101', dropout=0.1):

    super().__init__()

    if backbone not in ('resnet101', 'resnet50'):
      raise ValueError('Unknown backbone')

    if backbone == 'resnet101':
      backbone = torchvision.models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT)
    else:
      backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    # Channel widths used by the decoder (the same for both backbones).
    filters = [64, 256, 512, 1024]

    # ENCODER (shapes relative to a 3 x H x W input)
    # Stem: 3 x H x W ---> 64 x H/2 x W/2
    self.conv1 = backbone.conv1
    self.bn1 = backbone.bn1
    self.relu = backbone.relu

    # 64 x H/2 x W/2 ---> 64 x H/4 x W/4
    self.maxpool = backbone.maxpool

    # 64 x H/4 x W/4 ---> 256 x H/4 x W/4
    self.enc1 = backbone.layer1

    # 256 x H/4 x W/4 ---> 512 x H/8 x W/8
    self.enc2 = backbone.layer2

    # 512 x H/8 x W/8 ---> 1024 x H/16 x W/16
    self.enc3 = backbone.layer3

    # 1024 x H/16 x W/16 ---> 2048 x H/32 x W/32
    self.enc4 = backbone.layer4

    # DECODER (each block: upsample x2, concatenate the skip, convolve)
    # 2048 x H/32 + skip 1024 x H/16 ---> 1024 x H/16
    self.decoder1 = DecoderBlock(filters[3]*2, filters[3], dropout=dropout)

    # 1024 x H/16 + skip 512 x H/8 ---> 512 x H/8
    self.decoder2 = DecoderBlock(filters[3], filters[2], dropout=dropout)

    # 512 x H/8 + skip 256 x H/4 ---> 256 x H/4
    self.decoder3 = DecoderBlock(filters[2], filters[1], dropout=dropout)

    # 256 x H/4 + stem skip 64 x H/2 ---> 64 x H/2
    self.decoder4 = DecoderBlock(filters[1], filters[0], dropout=dropout, in_conv=128)

    # 64 x H/2 x W/2 ---> n_classes x H x W
    self.last = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0),
            nn.Conv2d(64, n_classes, kernel_size=3, padding=1)
    )

  def forward(self, inputs):

    # Stem (kept as the shallowest skip connection)
    e0 = self.relu(self.bn1(self.conv1(inputs)))

    # Encoder
    enc1 = self.enc1(self.maxpool(e0))
    enc2 = self.enc2(enc1)
    enc3 = self.enc3(enc2)
    x = self.enc4(enc3)

    # Decoder: pair every block with its skip, deepest first
    stages = (
        (self.decoder1, enc3),
        (self.decoder2, enc2),
        (self.decoder3, enc1),
        (self.decoder4, e0),
    )
    for decoder, skip in stages:
      x = decoder(x, skip)

    return self.last(x)

In [None]:
# Instantiate the U-Net with the default backbone ('resnet101') on the selected device.
unet = UNetResNet(NUM_CLASSES).to(device)
# summary(unet, input_size=(3, 512, 512))
# Profile one forward pass on a random 1 x 3 x 512 x 512 input with thop.
# NOTE(review): thop's profile() is documented as returning MACs — confirm
# whether the value printed as "FLOPs" should be doubled.
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

In [None]:
# Instantiate the U-Net with the 'resnet50' backbone on the selected device.
unet = UNetResNet(NUM_CLASSES, backbone='resnet50').to(device)
# summary(unet, input_size=(3, 512, 512))
# Profile one forward pass on a random 1 x 3 x 512 x 512 input with thop.
# NOTE(review): thop's profile() is documented as returning MACs — confirm
# whether the value printed as "FLOPs" should be doubled.
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

### DeepLab V3+

#### Dilated Convolution Block

In [None]:
class AtrousConvolution(nn.Module):
  """Dilated convolution followed by batch norm and ReLU.

  Args:
    in_c: input channels.
    out_c: output channels.
    kernel_size: convolution kernel size.
    pad: padding passed to the convolution.
    dilation_rate: dilation passed to the convolution.
  """

  def __init__(self, in_c, out_c, kernel_size, pad, dilation_rate):

    super().__init__()

    self.conv = nn.Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        dilation=dilation_rate,
        padding=pad,
    )
    self.bn = nn.BatchNorm2d(out_c)
    self.relu = nn.ReLU()

  def forward(self, x):

    return self.relu(self.bn(self.conv(x)))

#### ASPP module

In [None]:
class ASPP(nn.Module):
  """Atrous spatial pyramid pooling head.

  Five parallel branches are applied to the same input: a 1x1 convolution,
  three 3x3 convolutions with dilation 6, 12 and 18, and a global-average-pool
  branch. Their outputs are concatenated on the channel axis and fused by a
  1x1 convolution with batch norm, ReLU and dropout (p=0.5).
  """

  def __init__(self, in_c, out_c):
    """
    Args:
      in_c: channels of the incoming feature map.
      out_c: channels produced by every branch and by the fused output.
    """
    super().__init__()

    # For the 3x3 branches padding equals dilation, so the spatial size is kept.
    self.conv1 = AtrousConvolution(in_c, out_c, kernel_size=1, pad=0, dilation_rate=1)
    self.conv6 = AtrousConvolution(in_c, out_c, kernel_size=3, pad=6, dilation_rate=6)
    self.conv12 = AtrousConvolution(in_c, out_c, kernel_size=3, pad=12, dilation_rate=12)
    self.conv18 = AtrousConvolution(in_c, out_c, kernel_size=3, pad=18, dilation_rate=18)

    # Image-level branch: pool to 1x1, then project to out_c channels.
    pool_layers = [
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0, dilation=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
    ]
    self.image_pool = nn.Sequential(*pool_layers)

    # Fuses the five concatenated branches back to out_c channels.
    fuse_layers = [
        nn.Conv2d(5 * out_c, out_c, kernel_size=1, stride=1, padding=0, dilation=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
        nn.Dropout(0.5),
    ]
    self.final_conv = nn.Sequential(*fuse_layers)

  def forward(self, x):
    """Return the fused multi-rate features for ``x``."""
    branches = [self.conv1(x), self.conv6(x), self.conv12(x), self.conv18(x)]

    # Broadcast the pooled 1x1 descriptor back to the branches' spatial size.
    pooled = self.image_pool(x)
    pooled = F.interpolate(
        pooled,
        size=branches[-1].shape[2:],
        mode='bilinear',
        align_corners=True,
    )

    stacked = torch.cat([*branches, pooled], dim=1)
    return self.final_conv(stacked)



#### Decoder

In [None]:
class Decoder(nn.Module):
  """DeepLab V3+ style decoder head.

  The low-level feature map is projected to 48 channels with a 1x1 convolution,
  concatenated with the high-level features, refined by two 3x3 convolution
  blocks, upsampled by a factor of 4 and classified per pixel.
  """

  def __init__(self, num_classes, low_level_channels=256, aspp_channels=256):
    """
    Args:
      num_classes: number of output channels of the final 1x1 classifier.
      low_level_channels: channels of ``low_level_features`` (default 256,
        the value that used to be hard-coded).
      aspp_channels: channels of the high-level features ``x`` (default 256).
    """
    super().__init__()

    reduced_channels = 48

    self.conv1 = nn.Conv2d(low_level_channels, reduced_channels, kernel_size=1, stride=1, padding=0)
    self.bn1 = nn.BatchNorm2d(reduced_channels)
    self.relu = nn.ReLU()

    # Input width is the concatenation of both branches (256 + 48 = 304 by default).
    self.last_conv = nn.Sequential(
        nn.Conv2d(aspp_channels + reduced_channels, 256, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(),
        nn.Dropout(0.1),
    )
    self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)

  def forward(self, low_level_features, x):
    """Fuse low- and high-level features and return per-pixel class scores.

    Args:
      low_level_features: feature map with ``low_level_channels`` channels.
      x: high-level feature map with ``aspp_channels`` channels.

    Returns:
      Tensor with ``num_classes`` channels at 4x the spatial size of
      ``low_level_features``.
    """
    low_level = self.relu(self.bn1(self.conv1(low_level_features)))

    # Concatenation needs equal spatial sizes; align x to the low-level map
    # when the caller's upsampling did not land exactly on it.
    if x.shape[2:] != low_level.shape[2:]:
      x = F.interpolate(x, size=low_level.shape[2:], mode='bilinear', align_corners=True)

    x = torch.cat((x, low_level), dim=1)
    x = self.last_conv(x)
    x = F.interpolate(x, scale_factor=(4,4), mode='bilinear', align_corners=True)
    x = self.classifier(x)

    return x

#### Backbone

In [None]:
class Backbone(nn.Module):
    """torchvision ResNet-50/101 feature extractor exposing two tap points.

    ``forward`` returns the feature maps selected by ``low_level`` and
    ``high_level``. The last residual stage uses dilation instead of stride.
    """

    # Names accepted for ``low_level`` / ``high_level``.
    _LEVELS = ('conv1', 'layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self, backbone_name='resnet101', low_level='layer1', high_level='layer4', pretrained=True):
        """
        Args:
            backbone_name: 'resnet50' or 'resnet101'.
            low_level: tap point returned first by ``forward``.
            high_level: tap point returned second by ``forward``.
            pretrained: load torchvision's DEFAULT weights when True.

        Raises:
            ValueError: on an unknown tap-point name or unsupported backbone.
        """
        super().__init__()

        # Validate both tap points up front; previously a bad high_level only
        # surfaced as a KeyError inside forward().
        for arg_name, level in (('low_level', low_level), ('high_level', high_level)):
            if level not in self._LEVELS:
                raise ValueError(f"Invalid {arg_name} '{level}', choose from {list(self._LEVELS)}")

        # To create a total stride of 16 in the backbone
        dilation = [False, False, True]

        if backbone_name == 'resnet50':
            backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None,
                                                   replace_stride_with_dilation=dilation)
        elif backbone_name == 'resnet101':
            backbone = torchvision.models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT if pretrained else None,
                                                    replace_stride_with_dilation=dilation)
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}")

        self.conv1 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu
        )
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.low_level = low_level
        self.high_level = high_level

    def forward(self, x):
        """Return ``(low_level_feat, high_level_feat)`` for input ``x``."""
        x = self.conv1(x)
        x = self.maxpool(x)
        l1 = self.layer1(x)
        l2 = self.layer2(l1)
        l3 = self.layer3(l2)
        l4 = self.layer4(l3)

        level_dict = {
            # NOTE: the 'conv1' tap is the stem output *after* max-pooling.
            'conv1': x,
            'layer1': l1,
            'layer2': l2,
            'layer3': l3,
            'layer4': l4,
        }
        low_level_feat = level_dict[self.low_level]
        high_level_feat = level_dict[self.high_level]

        return low_level_feat, high_level_feat

#### Final architecture

In [None]:
class DeepLabV3Plus(nn.Module):
  """DeepLab V3+ network: ResNet backbone -> ASPP -> decoder.

  NOTE: with its default arguments the decoder projects a 256-channel
  low-level map, which matches ``low_level='layer1'`` of ResNet-50/101.
  """

  # Output channels of each backbone tap point for torchvision's bottleneck
  # ResNets (ResNet-50 / ResNet-101), used to size the ASPP input.
  _LEVEL_CHANNELS = {
      'conv1': 64,
      'layer1': 256,
      'layer2': 512,
      'layer3': 1024,
      'layer4': 2048,
  }

  def __init__(self, n_classes, backbone='resnet101', low_level='layer1', high_level='layer4', pretrained=True):
    """
    Args:
      n_classes: number of output classes.
      backbone: 'resnet50' or 'resnet101'.
      low_level: backbone tap point fed to the decoder as skip connection.
      high_level: backbone tap point fed to the ASPP module.
      pretrained: forwarded to the backbone.
    """
    super().__init__()

    self.backbone = Backbone(backbone_name=backbone, low_level=low_level, pretrained=pretrained, high_level=high_level)

    # Size the ASPP input from the chosen tap point instead of assuming layer4.
    self.aspp = ASPP(self._LEVEL_CHANNELS.get(high_level, 2048), 256)

    self.decoder = Decoder(n_classes)

  def forward(self, inputs):
    """Return per-pixel class scores at the spatial size of ``inputs``."""
    low_level, high_level = self.backbone(inputs)
    aspp = self.aspp(high_level)

    # Resize to the low-level map itself rather than by a fixed x4 factor, so
    # the concatenation in the decoder also works when the input size is not a
    # multiple of 16 (identical result when it is).
    aspp = F.interpolate(aspp, size=low_level.shape[2:], mode='bilinear', align_corners=True)
    out = self.decoder(low_level, aspp)

    # The decoder upsamples by a fixed factor; correct any residual mismatch.
    if out.shape[2:] != inputs.shape[2:]:
      out = F.interpolate(out, size=inputs.shape[2:], mode='bilinear', align_corners=True)

    return out

# unet = DeepLabV3Plus(NUM_CLASSES, 'resnet101', high_level='layer4').to(device)
# summary(unet, input_size=(3, 512, 512))

In [None]:
# Cost report for DeepLabV3Plus with a ResNet-50 backbone on one random
# 1 x 3 x 512 x 512 input (the variable is still called `unet`).
unet = DeepLabV3Plus(NUM_CLASSES, 'resnet50', high_level='layer4').to(device)
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

In [None]:
# Cost report for DeepLabV3Plus with a ResNet-101 backbone on one random
# 1 x 3 x 512 x 512 input.
unet = DeepLabV3Plus(NUM_CLASSES, 'resnet101', high_level='layer4').to(device)
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

### PyTorch's DeepLab V3

In [None]:
from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights, DeepLabV3_ResNet50_Weights

class DeepLabV3Pytorch(nn.Module):
  """Wrapper around torchvision's DeepLabV3 with a replaced output layer.

  The model is created with torchvision's DEFAULT segmentation weights and
  element 4 of its ``classifier`` is swapped for a fresh 1x1 convolution with
  ``num_classes`` outputs. ``forward`` returns only the ``'out'`` entry.
  """

  def __init__(self, num_classes, backbone='resnet101'):
    """
    Args:
      num_classes: number of output channels of the new final layer.
      backbone: 'resnet50' or 'resnet101'.

    Raises:
      ValueError: if ``backbone`` is neither of the two supported names.
    """
    super().__init__()

    if backbone not in ('resnet50', 'resnet101'):
      raise ValueError('Backbone not recognized')

    segmentation = torchvision.models.segmentation
    if backbone == 'resnet50':
      net = segmentation.deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
    else:
      net = segmentation.deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)

    # Replace the last classifier layer so it predicts num_classes channels.
    net.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
    self.model = net

  def forward(self, x):
    """Return the main segmentation output of the wrapped model."""
    return self.model(x)['out']

# Cost report for the torchvision DeepLabV3 wrapper with a ResNet-101 backbone
# on one random 1 x 3 x 512 x 512 input.
unet = DeepLabV3Pytorch(NUM_CLASSES, 'resnet101').to(device)
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

# Same report for the ResNet-50 backbone.
unet = DeepLabV3Pytorch(NUM_CLASSES, 'resnet50').to(device)
input_tensor = torch.randn(1, 3, 512, 512).to(device)
flops, params = profile(unet, inputs=(input_tensor, ))
print(f"FLOPs: {flops/1e9:.2f} GFLOPs")
print(f"Params: {params/1e6:.2f} M")

## Training procedures

In [None]:
class EarlyStopping:
    """Track a validation score and signal when it has stopped improving.

    Call the instance once per epoch with the current score and the model.
    A copy of the model's ``state_dict`` is kept in ``best_weights`` whenever
    the score improves by more than ``min_delta``; after ``patience``
    consecutive calls without improvement ``early_stop`` becomes True.
    """

    def __init__(self, patience=5, min_delta=0.0, mode='max', verbose=True):
        """
        Args:
            patience: non-improving calls tolerated before stopping.
            min_delta: margin the score must beat the best score by.
            mode: 'max' if larger scores are better, 'min' if smaller are.
            verbose: print the counter and the stop message.

        Raises:
            ValueError: if ``mode`` is neither 'min' nor 'max'.
        """
        # An unknown mode would otherwise never register an improvement and
        # silently stop training after `patience` epochs.
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_weights = None

    def _is_improvement(self, current_score):
        """Return True if ``current_score`` beats the best score by > min_delta."""
        if self.mode == 'min':
            return current_score < self.best_score - self.min_delta
        return current_score > self.best_score + self.min_delta

    def __call__(self, current_score, model):
        """Record ``current_score`` and update the stopping state."""
        if self.best_score is None or self._is_improvement(current_score):
            self.best_score = current_score
            # Deep copy: state_dict() tensors alias the live parameters.
            self.best_weights = deepcopy(model.state_dict())
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            if self.verbose:
                print("Early stopping triggered.")
            self.early_stop = True

In [None]:
def train_loop(model, num_epochs, optimizer, scheduler, criterion, metric, early_stop, device, train_loader, val_loader=None):
  """Train ``model`` for up to ``num_epochs`` epochs with optional validation.

  Args:
    model: network to optimise.
    num_epochs: maximum number of epochs.
    optimizer: optimiser stepped once per training batch.
    scheduler: learning-rate scheduler stepped once per epoch, or None.
    criterion: loss called as ``criterion(predictions, masks.squeeze(1))``.
    metric: callable returning a pair; the second element's ``.item()`` is
      recorded as the validation metric.
    early_stop: callable invoked as ``early_stop(metric, model)`` after each
      validation pass; its ``early_stop`` / ``best_score`` / ``best_weights``
      attributes are read.
    device: device the batches are moved to.
    train_loader: yields ``(images, masks)`` as sequences of tensors.
    val_loader: same format as ``train_loader``; validation is skipped if None.

  Returns:
    ``(train_losses, val_losses, val_metric, best_weights)`` where the first
    three are per-epoch lists and ``best_weights`` is a state dict.
  """
  train_losses = []
  val_losses = []
  val_metric = []

  global_step = 0

  for epoch in tqdm.tqdm(range(num_epochs)):

    model.train()
    train_loss_epoch = []
    for idx, (images, masks) in enumerate(train_loader):
      # Stack the list of tensors into a single tensor
      images = torch.stack(images, dim=0).to(device)
      masks = torch.stack(masks, dim=0).to(device)

      predicted_mask = model(images)
      loss = criterion(predicted_mask, masks.squeeze(1))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_loss_epoch.append(loss.item())

      # if idx % 10 == 0:
      #   wandb.log({
      #     "train/loss": loss
      #   }, global_step)
      #   global_step += 10
    train_loss = np.mean(train_loss_epoch)
    train_losses.append(train_loss)

    if scheduler is not None:
      # Only ReduceLROnPlateau takes a monitored value. For every other
      # scheduler the positional argument of step() is the (deprecated) epoch
      # index, so passing the loss there would corrupt the schedule.
      if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(train_loss)
      else:
        scheduler.step()

    if val_loader is not None:
      model.eval()
      val_loss_epoch = []
      val_metric_epoch = []
      with torch.no_grad():
        for images, masks in val_loader:
          # Stack the list of tensors into a single tensor
          images = torch.stack(images, dim=0).to(device)
          masks = torch.stack(masks, dim=0).to(device)

          predictions = model(images)
          loss = criterion(predictions, masks.squeeze(1))
          _, metric_value = metric(predictions, masks.squeeze(1))
          val_loss_epoch.append(loss.item())
          val_metric_epoch.append(metric_value.item())

      new_loss_ep = np.mean(val_loss_epoch)
      new_metric_ep = np.mean(val_metric_epoch)

      val_losses.append(new_loss_ep)
      val_metric.append(new_metric_ep)
      # wandb.log({
      #       "val/loss": np.mean(val_loss_epoch),
      #       "val/metric": np.mean(val_metric_epoch)
      #   })
      print(f"Epoch {epoch}: training loss {train_loss} -- validation metric {new_metric_ep} -- validation loss {new_loss_ep}")
      early_stop(new_metric_ep, model)
      if early_stop.early_stop:
        print(f"Stopped early with best mAP: {early_stop.best_score:.2f}")
        break

  # Without a validation pass early_stop never recorded anything; fall back to
  # the final weights so the caller always receives a usable state dict.
  best_weights = getattr(early_stop, "best_weights", None)
  if best_weights is None:
    best_weights = deepcopy(model.state_dict())

  return train_losses, val_losses, val_metric, best_weights


def train(model, num_epochs, optimizer, criterion, metric, scheduler, early_stop, device, project_name, run_name, train_loader, val_loader=None):
  """Run ``train_loop`` and save the best weights as ``<run_name>.pt``.

  Args:
    model, num_epochs, optimizer, criterion, metric, scheduler, early_stop,
    device, train_loader, val_loader: forwarded to ``train_loop``.
    project_name: only used by the (disabled) wandb logging.
    run_name: checkpoint file stem; also the wandb run name when enabled.

  Returns:
    ``(train_losses, val_losses, val_metric)`` as returned by ``train_loop``.

  NOTE: ``model`` itself is left with the weights of the last epoch that ran;
  the best weights are only written to disk.
  """
  # wandb.init(project=project_name, name=run_name)

  train_losses, val_losses, val_metric, best_model_state = train_loop(model, num_epochs, optimizer, scheduler, criterion, metric, early_stop, device, train_loader, val_loader)

  # wandb.finish()

  # No best state recorded (e.g. no validation loader): save the final weights
  # instead of pickling None.
  if best_model_state is None:
    best_model_state = model.state_dict()

  path_ckpts = Path("/content/drive/My Drive/CamVid/ckpts")
  # parents=True so a missing parent folder cannot fail the save after a full
  # training run.
  path_ckpts.mkdir(parents=True, exist_ok=True)
  torch.save(best_model_state, path_ckpts / f"{run_name}.pt")

  return train_losses, val_losses, val_metric

## Train models

In [None]:
# Project name passed to train(); only used by the wandb logging there.
wandb_proj_name = "CamVid2"

# Batch size for the three data loaders built in the next cell.
BATCH_SIZE = 16

In [None]:
# Datasets for the three splits (512, 512 is presumably the target image size
# -- confirm against CamVidDataset).
# NOTE(review): train/test use the NEW_* paths while validation uses
# VALID_PATH / LABEL_VALID_PATH -- confirm this is intended.
train_data = CamVidDataset(NEW_TRAIN_PATH, NEW_TRAIN_LABELS, 512, 512, transforms=get_train_transform())
val_data = CamVidDataset(VALID_PATH, LABEL_VALID_PATH, 512, 512, transforms=get_valid_transform())
test_data = CamVidDataset(NEW_TEST_PATH, NEW_TEST_LABELS, 512, 512, transforms=get_valid_transform())

# Only the training loader shuffles. All loaders share the custom collate_fn,
# whose batches the training code stacks with torch.stack.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=collate_fn)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, collate_fn=collate_fn)

### First comparison of models

In [None]:
# Hyper-parameters shared by every run of the first comparison.
NUM_EPOCHS = 60
# Rebinding BATCH_SIZE here does not affect the loaders created above.
BATCH_SIZE = 16
WEIGHT_DECAY = 1e-4

#### VGG Unet

In [None]:
# Plain U-Net trained with a single learning rate for all parameters.
unet = UNet(NUM_CLASSES).to(device)
lr=1e-4
# Use `lr` (not a repeated literal) so the optimizer and run_name cannot drift apart.
optimizer = torch.optim.AdamW(unet.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# Early stopping monitors the validation metric (higher is better).
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Pixels labelled 255 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"unet_AdamW_LR_{lr}_eps_{NUM_EPOCHS}"

In [None]:
# Train the U-Net; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(unet, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
plot_curves(train_losses, val_losses, val_metric)
# Rebuild the model and load the checkpoint written by train(), so the
# evaluation below uses the saved best weights.
unet = UNet(NUM_CLASSES).to(device)
path_ckpts = Path("/content/drive/My Drive/CamVid/ckpts")
# Must match the run_name used for training (lr 1e-4, 60 epochs).
run_name = "unet_AdamW_LR_0.0001_eps_60"
weight_path = path_ckpts / f"{run_name}.pt"
unet.load_state_dict(torch.load(weight_path, map_location=torch.device(device)))
# Metrics on the test, validation and training loaders.
evaluate_model(unet, test_loader, device, num_classes=11)
evaluate_model(unet, val_loader, device, num_classes=11)
evaluate_model(unet, train_loader, device, num_classes=11)

In [None]:
# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

unet.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = unet(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### From scratch ResNet-34

In [None]:
# U-Net/ResNet-34 built with pretrained=False; all parameters share one learning rate.
unet34 = UNetResNet34(NUM_CLASSES, pretrained=False).to(device)

base_lr = 1e-4

optimizer = torch.optim.AdamW(unet34.parameters(), lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# Early stopping monitors the validation metric (higher is better).
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Pixels labelled 255 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"scratch_unet34_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Train the from-scratch U-Net/ResNet-34; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(unet34, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `unet34` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(unet34, test_loader, device, num_classes=11)
evaluate_model(unet34, val_loader, device, num_classes=11)
evaluate_model(unet34, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

unet34.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = unet34(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### Unet ResNet-34

In [None]:
# U-Net/ResNet-34 built with the class defaults, fine-tuned with per-module learning rates.
unet34 = UNetResNet34(NUM_CLASSES).to(device)

base_lr = 1e-4

# Freeze the stem: its parameters get no gradient and are in no param group below.
for layer in [unet34.conv1, unet34.bn1, unet34.relu, unet34.maxpool]:
    for param in layer.parameters():
        param.requires_grad = False

# Encoder stages use a fraction of base_lr (0.01 / 0.2 / 0.25 / 0.5);
# every other listed module uses base_lr.
param_groups = [
    {"params": unet34.enc1.parameters(), "lr": base_lr * 0.01},
    {"params": unet34.enc2.parameters(), "lr": base_lr * 0.2},
    {"params": unet34.enc3.parameters(), "lr": base_lr * 0.25},
    {"params": unet34.enc4.parameters(), "lr": base_lr * 0.5},
    {"params": unet34.bottleneck.parameters(), "lr": base_lr},
    {"params": unet34.decoder1.parameters(), "lr": base_lr},
    {"params": unet34.decoder2.parameters(), "lr": base_lr},
    {"params": unet34.decoder3.parameters(), "lr": base_lr},
    {"params": unet34.decoder4.parameters(), "lr": base_lr},
    {"params": unet34.decoder5.parameters(), "lr": base_lr},
    {"params": unet34.last.parameters(), "lr": base_lr},
]

print_trainable_params(unet34)

# `lr=base_lr` is only the default for groups without their own "lr".
optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Pixels labelled 255 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"unet34_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune U-Net/ResNet-34; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(unet34, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `unet34` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(unet34, test_loader, device, num_classes=11)
evaluate_model(unet34, val_loader, device, num_classes=11)
evaluate_model(unet34, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

unet34.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = unet34(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### Unet ResNet-50

In [None]:
# U-Net with a ResNet-50 encoder, fine-tuned with per-module learning rates.
unet50 = UNetResNet(NUM_CLASSES, backbone='resnet50').to(device)

base_lr = 1e-4

# Freeze the stem: its parameters get no gradient and are in no param group below.
for layer in [unet50.conv1, unet50.bn1, unet50.relu, unet50.maxpool]:
    for param in layer.parameters():
        param.requires_grad = False

# Encoder stages use a fraction of base_lr (0.01 / 0.2 / 0.25 / 0.5);
# decoder blocks and the output head use base_lr.
param_groups = [
    {"params": unet50.enc1.parameters(), "lr": base_lr * 0.01},
    {"params": unet50.enc2.parameters(), "lr": base_lr * 0.2},
    {"params": unet50.enc3.parameters(), "lr": base_lr * 0.25},
    {"params": unet50.enc4.parameters(), "lr": base_lr * 0.5},
    {"params": unet50.decoder1.parameters(), "lr": base_lr},
    {"params": unet50.decoder2.parameters(), "lr": base_lr},
    {"params": unet50.decoder3.parameters(), "lr": base_lr},
    {"params": unet50.decoder4.parameters(), "lr": base_lr},
    {"params": unet50.last.parameters(), "lr": base_lr},
]

print_trainable_params(unet50)

# `lr=base_lr` is only the default for groups without their own "lr".
optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Pixels labelled 255 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"unet50_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune U-Net/ResNet-50; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(unet50, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `unet50` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(unet50, test_loader, device, num_classes=11)
evaluate_model(unet50, val_loader, device, num_classes=11)
evaluate_model(unet50, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

unet50.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = unet50(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### Unet ResNet-101

In [None]:
# U-Net with a ResNet-101 encoder, fine-tuned with per-module learning rates.
unet101 = UNetResNet(NUM_CLASSES, backbone='resnet101').to(device)

base_lr = 1e-4

# Freeze the stem: its parameters get no gradient and are in no param group below.
for layer in [unet101.conv1, unet101.bn1, unet101.relu, unet101.maxpool]:
    for param in layer.parameters():
        param.requires_grad = False

# Encoder stages use a fraction of base_lr (0.01 / 0.2 / 0.25 / 0.5);
# decoder blocks and the output head use base_lr.
param_groups = [
    {"params": unet101.enc1.parameters(), "lr": base_lr * 0.01},
    {"params": unet101.enc2.parameters(), "lr": base_lr * 0.2},
    {"params": unet101.enc3.parameters(), "lr": base_lr * 0.25},
    {"params": unet101.enc4.parameters(), "lr": base_lr * 0.5},
    {"params": unet101.decoder1.parameters(), "lr": base_lr},
    {"params": unet101.decoder2.parameters(), "lr": base_lr},
    {"params": unet101.decoder3.parameters(), "lr": base_lr},
    {"params": unet101.decoder4.parameters(), "lr": base_lr},
    {"params": unet101.last.parameters(), "lr": base_lr},
]

print_trainable_params(unet101)

optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Bind the loss to `criterion`, the name the training cell passes to train()
# (it was previously assigned to a misspelled name and never used).
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"unet101_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune U-Net/ResNet-101; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(unet101, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `unet101` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(unet101, test_loader, device, num_classes=11)
evaluate_model(unet101, val_loader, device, num_classes=11)
evaluate_model(unet101, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

unet101.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = unet101(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### DeepLab V3+ 50

In [None]:
# DeepLab V3+ with a ResNet-50 backbone, fine-tuned with per-module learning rates.
deeplab = DeepLabV3Plus(NUM_CLASSES, backbone='resnet50').to(device)

base_lr = 1e-4

# Freeze the backbone stem: its parameters get no gradient and are in no param group below.
for layer in [deeplab.backbone.conv1, deeplab.backbone.maxpool]:
    for param in layer.parameters():
        param.requires_grad = False

# Backbone stages use a fraction of base_lr (0.01 / 0.2 / 0.25 / 0.5);
# the ASPP module and the decoder use base_lr.
param_groups = [
    {"params": deeplab.backbone.layer1.parameters(), "lr": base_lr * 0.01},
    {"params": deeplab.backbone.layer2.parameters(), "lr": base_lr * 0.2},
    {"params": deeplab.backbone.layer3.parameters(), "lr": base_lr * 0.25},
    {"params": deeplab.backbone.layer4.parameters(), "lr": base_lr * 0.5},
    {"params": deeplab.aspp.parameters(), "lr": base_lr},
    {"params": deeplab.decoder.parameters(), "lr": base_lr}
]

print_trainable_params(deeplab)

optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Bind the loss to `criterion`, the name the training cell passes to train()
# (it was previously assigned to a misspelled name and never used).
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"deeplab_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune DeepLab V3+ (ResNet-50); train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(deeplab, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `deeplab` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(deeplab, test_loader, device, num_classes=11)
evaluate_model(deeplab, val_loader, device, num_classes=11)
evaluate_model(deeplab, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

deeplab.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = deeplab(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### DeepLab V3+ 50 from scratch

In [None]:
# DeepLab V3+ (ResNet-50) built with pretrained=False; all parameters share one learning rate.
deeplab = DeepLabV3Plus(NUM_CLASSES, backbone='resnet50', pretrained=False).to(device)

base_lr = 1e-4

optimizer = torch.optim.AdamW(deeplab.parameters(), lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Bind the loss to `criterion`, the name the training cell passes to train()
# (it was previously assigned to a misspelled name and never used).
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"scratch_deeplab_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Train DeepLab V3+ (ResNet-50) from scratch; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(deeplab, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `deeplab` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(deeplab, test_loader, device, num_classes=11)
evaluate_model(deeplab, val_loader, device, num_classes=11)
evaluate_model(deeplab, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

deeplab.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = deeplab(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### DeepLab V3+ 101

In [None]:
# DeepLab V3+ with a ResNet-101 backbone, fine-tuned with per-module learning rates.
deeplab101 = DeepLabV3Plus(NUM_CLASSES, backbone='resnet101').to(device)

base_lr = 1e-4

# Freeze the backbone stem: its parameters get no gradient and are in no param group below.
for layer in [deeplab101.backbone.conv1, deeplab101.backbone.maxpool]:
    for param in layer.parameters():
        param.requires_grad = False

# Backbone stages use a fraction of base_lr (0.01 / 0.2 / 0.25 / 0.5);
# the ASPP module and the decoder use base_lr.
param_groups = [
    {"params": deeplab101.backbone.layer1.parameters(), "lr": base_lr * 0.01},
    {"params": deeplab101.backbone.layer2.parameters(), "lr": base_lr * 0.2},
    {"params": deeplab101.backbone.layer3.parameters(), "lr": base_lr * 0.25},
    {"params": deeplab101.backbone.layer4.parameters(), "lr": base_lr * 0.5},
    {"params": deeplab101.aspp.parameters(), "lr": base_lr},
    {"params": deeplab101.decoder.parameters(), "lr": base_lr}
]

print_trainable_params(deeplab101)

optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Bind the loss to `criterion`, the name the training cell passes to train()
# (it was previously assigned to a misspelled name and never used).
criterion = nn.CrossEntropyLoss(ignore_index=255)

run_name = f"deeplab101_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune DeepLab V3+ (ResNet-101); train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(deeplab101, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `deeplab101` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(deeplab101, test_loader, device, num_classes=11)
evaluate_model(deeplab101, val_loader, device, num_classes=11)
evaluate_model(deeplab101, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

deeplab101.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = deeplab101(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### Pretrained DeepLab V3

In [None]:
# torchvision DeepLabV3 (default ResNet-101 backbone) with a new output layer.
pytorchdeeplab = DeepLabV3Pytorch(NUM_CLASSES).to(device)

base_lr = 1e-4

# Freeze every backbone parameter outside layer2 / layer3 / layer4.
for name, param in pytorchdeeplab.model.backbone.named_parameters():
    if not any(f in name for f in ["layer2", "layer3", "layer4"]):
        param.requires_grad = False

# Split the trainable parameters into head and backbone groups.
# Iterate over the wrapped torchvision model: on the wrapper every name is
# prefixed with "model.", so the startswith() checks below would never match
# and the head would end up in the low-learning-rate backbone group.
classifier_params = []
backbone_params = []
for name, param in pytorchdeeplab.model.named_parameters():
    if not param.requires_grad:
        continue
    if name.startswith("aux_classifier"):
        # The auxiliary head is not part of the wrapper's output; leave it out.
        continue
    if name.startswith("classifier"):
        classifier_params.append(param)
    else:
        backbone_params.append(param)

optimizer = torch.optim.Adam([
    {'params': backbone_params, 'lr': base_lr * 0.1},   # smaller LR for pretrained backbone
    {'params': classifier_params, 'lr': base_lr}        # main head: larger LR
], weight_decay=WEIGHT_DECAY)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Pixels labelled 255 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

print_trainable_params(pytorchdeeplab)

# NOTE: the name says AdamW although the optimizer above is Adam; it is kept
# unchanged because it is the checkpoint file name.
run_name = f"pytorch_deeplab101_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune the torchvision DeepLabV3 wrapper; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(pytorchdeeplab, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# Training curves, then metrics on each split.
# NOTE: train() saves the best weights to disk but does not load them back, so
# `pytorchdeeplab` is evaluated with the weights of the last epoch that ran.
plot_curves(train_losses, val_losses, val_metric)
evaluate_model(pytorchdeeplab, test_loader, device, num_classes=11)
evaluate_model(pytorchdeeplab, val_loader, device, num_classes=11)
evaluate_model(pytorchdeeplab, train_loader, device, num_classes=11)

# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

pytorchdeeplab.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = pytorchdeeplab(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

#### Pretrained DeepLab V3

A higher base learning rate (1e-3 instead of 1e-4) is used here because convergence was too slow with the lower one.

In [None]:
# torchvision DeepLabV3 (default ResNet-101 backbone) with a new output layer.
pytorchdeeplab = DeepLabV3Pytorch(NUM_CLASSES).to(device)

base_lr = 1e-3

# Freeze every backbone parameter outside layer2 / layer3 / layer4.
for name, param in pytorchdeeplab.model.backbone.named_parameters():
    if not any(f in name for f in ["layer2", "layer3", "layer4"]):
        param.requires_grad = False

# Split the trainable parameters into head and backbone groups.
# Iterate over the wrapped torchvision model: on the wrapper every name is
# prefixed with "model.", so the startswith() checks below would never match
# and the head would end up in the low-learning-rate backbone group.
classifier_params = []
backbone_params = []
for name, param in pytorchdeeplab.model.named_parameters():
    if not param.requires_grad:
        continue
    if name.startswith("aux_classifier"):
        # The auxiliary head is not part of the wrapper's output; leave it out.
        continue
    if name.startswith("classifier"):
        classifier_params.append(param)
    else:
        backbone_params.append(param)

optimizer = torch.optim.Adam([
    {'params': backbone_params, 'lr': base_lr * 0.1},   # smaller LR for pretrained backbone
    {'params': classifier_params, 'lr': base_lr}        # main head: larger LR
], weight_decay=WEIGHT_DECAY)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
# Pixels labelled 255 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

print_trainable_params(pytorchdeeplab)

# NOTE: the name says AdamW although the optimizer above is Adam; it is kept
# unchanged because the evaluation cell loads the checkpoint by this name.
run_name = f"pytorch_deeplab101_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune the torchvision DeepLabV3 wrapper with the higher base learning rate;
# train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(pytorchdeeplab, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
# plot_curves(train_losses, val_losses, val_metric)
# Rebuild the model and load the checkpoint written by train(), so the
# evaluation below uses the saved best weights.
pytorchdeeplab = DeepLabV3Pytorch(NUM_CLASSES).to(device)
path_ckpts = Path("/content/drive/My Drive/CamVid/ckpts")
# Must match the run_name used for training (base_lr 1e-3, 60 epochs).
run_name = "pytorch_deeplab101_AdamW_LR_0.001_eps_60"
weight_path = path_ckpts / f"{run_name}.pt"
pytorchdeeplab.load_state_dict(torch.load(weight_path, map_location=torch.device(device)))
# Metrics on the test, validation and training loaders.
evaluate_model(pytorchdeeplab, test_loader, device, num_classes=11)
evaluate_model(pytorchdeeplab, val_loader, device, num_classes=11)
evaluate_model(pytorchdeeplab, train_loader, device, num_classes=11)


In [None]:
# Test-set samples to visualise (assumes test_data has at least 100 items -- TODO confirm).
indexes = [0,20,50,80,99]

pytorchdeeplab.eval()
with torch.no_grad():
  for i in indexes:
    img, mask = test_data[i]
    # Add a batch dimension for the forward pass.
    img = img.unsqueeze(0)
    pred = pytorchdeeplab(img.to(device))
    # Predicted class per pixel = argmax over dim 1 of the model output.
    pred_classes = pred.argmax(dim=1)
    visualize(img.squeeze(0), mask, predicted=pred_classes)

### Change loss

In [None]:
# Hyper-parameters for the loss experiments (same values as the first comparison).
NUM_EPOCHS = 60
WEIGHT_DECAY = 1e-4

#### Class weights

In [None]:
# Count the pixels of each of the 11 classes in one pass over train_loader.
num_classes = 11
class_counts = torch.zeros(num_classes)

for _, targets in train_loader:
    # The loader yields a sequence of mask tensors; stack them into one batch tensor.
    targets = torch.stack(targets, dim=0)
    targets = targets.squeeze(1)

    for cls in range(num_classes):
        # The `!= 255` term is redundant here because cls is always below 11.
        mask = (targets == cls) & (targets != 255)
        class_counts[cls] += mask.sum().item()

# Normalizing with max
# Inverse-frequency weights, scaled so that the largest weight equals 1.
# NOTE(review): a class with zero counted pixels would get weight 1 and push
# all other weights towards 0 -- confirm every class occurs in the training set.
class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights / class_weights.max()

# Normalizing with sum
# class_weights = 1.0 / (class_counts + 1e-6)
# class_weights = class_weights / class_weights.sum()

# Median frequency balancing
# frequencies = class_counts / class_counts.sum()
# median_freq = torch.median(frequencies[frequencies > 0])
# class_weights = median_freq / (frequencies + 1e-6)

# Append a zero weight for an extra 12th index (presumably the ignored/void
# class -- confirm against CAMVID_CLASSES and the loss that consumes this).
CLASS_WEIGHTS = torch.cat([class_weights, torch.tensor([0.0])])
for i in range(len(CLASS_WEIGHTS)):
  print(f"Weights for class {CAMVID_CLASSES[i]} is : {CLASS_WEIGHTS[i]:.3f}.")

#### Unet ResNet-50

In [None]:
# U-Net with a ResNet-50 encoder, trained with a class-weighted focal loss.
unet50 = UNetResNet(NUM_CLASSES, backbone='resnet50').to(device)

base_lr = 1e-4

# Freeze the stem: its parameters get no gradient and are in no param group below.
for layer in [unet50.conv1, unet50.bn1, unet50.relu, unet50.maxpool]:
    for param in layer.parameters():
        param.requires_grad = False

# Encoder stages use a fraction of base_lr (0.01 / 0.2 / 0.25 / 0.5);
# decoder blocks and the output head use base_lr.
param_groups = [
    {"params": unet50.enc1.parameters(), "lr": base_lr * 0.01},
    {"params": unet50.enc2.parameters(), "lr": base_lr * 0.2},
    {"params": unet50.enc3.parameters(), "lr": base_lr * 0.25},
    {"params": unet50.enc4.parameters(), "lr": base_lr * 0.5},
    {"params": unet50.decoder1.parameters(), "lr": base_lr},
    {"params": unet50.decoder2.parameters(), "lr": base_lr},
    {"params": unet50.decoder3.parameters(), "lr": base_lr},
    {"params": unet50.decoder4.parameters(), "lr": base_lr},
    {"params": unet50.last.parameters(), "lr": base_lr},
]

print_trainable_params(unet50)

optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')

# TODO: ignore_index is 11 here as a workaround for an internal FocalLoss
# limitation, whereas the cross-entropy runs in this notebook ignore 255.
# NOTE(review): confirm which label value the masks use for ignored pixels.
criterion = FocalLoss(alpha=CLASS_WEIGHTS, ignore_index=11)

run_name = f"focal_unet50_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Fine-tune U-Net/ResNet-50 with the focal loss; train() also saves the best weights under run_name.
train_losses, val_losses, val_metric = train(unet50, NUM_EPOCHS, optimizer, criterion, multiclass_iou, scheduler, early_stop, device, wandb_proj_name, run_name, train_loader, val_loader)

In [None]:
plot_curves(train_losses, val_losses, val_metric)

# Evaluate on every split: test first, then validation, then train.
for split_loader in (test_loader, val_loader, train_loader):
    evaluate_model(unet50, split_loader, device, num_classes=11)

# Qualitative check on a fixed handful of test samples.
indexes = [0, 20, 50, 80, 99]

unet50.eval()
with torch.no_grad():
    for sample_idx in indexes:
        img, mask = test_data[sample_idx]
        batch = img.unsqueeze(0)  # the model is fed a batch of one image
        pred_classes = unet50(batch.to(device)).argmax(dim=1)
        visualize(batch.squeeze(0), mask, predicted=pred_classes)

#### UNet ResNet-50 with median-frequency alpha

In [None]:
# Count how many training pixels carry each of the 11 class labels.
num_classes = 11
class_counts = torch.zeros(num_classes)

for _, targets in train_loader:
    # The loader yields a sequence of masks: stack them and drop axis 1.
    targets = torch.stack(targets, dim=0).squeeze(1)
    not_void = targets != 255

    for cls in range(num_classes):
        class_counts[cls] += (not_void & (targets == cls)).sum().item()

# Alternative (disabled): inverse class counts, normalised to sum to 1.
# class_weights = 1.0 / (class_counts + 1e-6)
# class_weights = class_weights / class_weights.sum()

# Median frequency balancing: median of the non-zero class frequencies
# divided by each class frequency (1e-6 avoids a division by zero).
frequencies = class_counts / class_counts.sum()
median_freq = frequencies[frequencies > 0].median()
class_weights_median = median_freq / (frequencies + 1e-6)

# One extra 0.0 entry is appended after the per-class weights.
CLASS_WEIGHTS_MEDIAN = torch.cat([class_weights_median, torch.tensor([0.0])])
for i, weight in enumerate(CLASS_WEIGHTS_MEDIAN):
    print(f"Weights for class {CAMVID_CLASSES[i]} is : {weight:.3f}.")

In [None]:
# Focal-loss run with median-frequency alpha: fresh U-Net on 'resnet50'.
base_lr = 1e-4
unet50 = UNetResNet(NUM_CLASSES, backbone='resnet50').to(device)

# Freeze conv1 / bn1 / relu / maxpool: their parameters (if any) stop
# requiring gradients, and none of them is handed to the optimizer below.
for frozen_module in (unet50.conv1, unet50.bn1, unet50.relu, unet50.maxpool):
    frozen_module.requires_grad_(False)

# Encoder stages train at a fraction of base_lr ...
scaled_lr_modules = (
    (unet50.enc1, 0.01),
    (unet50.enc2, 0.2),
    (unet50.enc3, 0.25),
    (unet50.enc4, 0.5),
)
# ... while the decoder blocks and the final layer use base_lr itself.
full_lr_modules = (
    unet50.decoder1,
    unet50.decoder2,
    unet50.decoder3,
    unet50.decoder4,
    unet50.last,
)
param_groups = [
    {"params": module.parameters(), "lr": base_lr * scale}
    for module, scale in scaled_lr_modules
] + [
    {"params": module.parameters(), "lr": base_lr}
    for module in full_lr_modules
]

print_trainable_params(unet50)

optimizer = torch.optim.AdamW(
    param_groups,
    lr=base_lr,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')

# TODO: index 11 is passed as ignore_index to work around an internal problem
# of FocalLoss - fix the loss itself.
criterion = FocalLoss(alpha=CLASS_WEIGHTS_MEDIAN, ignore_index=11)

run_name = f"focal_median_unet50_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Run training; the three returned series are plotted in the next cell.
train_losses, val_losses, val_metric = train(
    unet50,
    NUM_EPOCHS,
    optimizer,
    criterion,
    multiclass_iou,
    scheduler,
    early_stop,
    device,
    wandb_proj_name,
    run_name,
    train_loader,
    val_loader,
)

In [None]:
plot_curves(train_losses, val_losses, val_metric)

# Evaluate on every split: test first, then validation, then train.
for split_loader in (test_loader, val_loader, train_loader):
    evaluate_model(unet50, split_loader, device, num_classes=11)

# Qualitative check on a fixed handful of test samples.
indexes = [0, 20, 50, 80, 99]

unet50.eval()
with torch.no_grad():
    for sample_idx in indexes:
        img, mask = test_data[sample_idx]
        batch = img.unsqueeze(0)  # the model is fed a batch of one image
        pred_classes = unet50(batch.to(device)).argmax(dim=1)
        visualize(batch.squeeze(0), mask, predicted=pred_classes)

#### DeepLabV3+ ResNet-101

In [None]:
# Rebuild the datasets with the two size arguments set to 512.
train_data = CamVidDataset(
    NEW_TRAIN_PATH, NEW_TRAIN_LABELS, 512, 512,
    transforms=get_train_transform(),
)
# NOTE(review): validation reads VALID_PATH / LABEL_VALID_PATH while train and
# test read the NEW_* directories - confirm this is intended.
val_data = CamVidDataset(
    VALID_PATH, LABEL_VALID_PATH, 512, 512,
    transforms=get_valid_transform(),
)
test_data = CamVidDataset(
    NEW_TEST_PATH, NEW_TEST_LABELS, 512, 512,
    transforms=get_valid_transform(),
)

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=8, num_workers=2, collate_fn=collate_fn)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

In [None]:
# Focal-loss run: DeepLabV3+ built on the 'resnet101' backbone.
base_lr = 1e-4
deeplab101 = DeepLabV3Plus(NUM_CLASSES, backbone='resnet101').to(device)

# Freeze backbone.conv1 and backbone.maxpool.
# NOTE(review): the U-Net cells also freeze bn1; here a backbone.bn1 (if the
# backbone has one) is neither frozen nor part of any optimizer group - confirm.
for frozen_module in (deeplab101.backbone.conv1, deeplab101.backbone.maxpool):
    frozen_module.requires_grad_(False)

# Backbone stages train at a fraction of base_lr ...
scaled_lr_modules = (
    (deeplab101.backbone.layer1, 0.01),
    (deeplab101.backbone.layer2, 0.2),
    (deeplab101.backbone.layer3, 0.25),
    (deeplab101.backbone.layer4, 0.5),
)
# ... while ASPP and the decoder use base_lr itself.
full_lr_modules = (deeplab101.aspp, deeplab101.decoder)
param_groups = [
    {"params": module.parameters(), "lr": base_lr * scale}
    for module, scale in scaled_lr_modules
] + [
    {"params": module.parameters(), "lr": base_lr}
    for module in full_lr_modules
]

print_trainable_params(deeplab101)

optimizer = torch.optim.AdamW(
    param_groups,
    lr=base_lr,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
)
early_stop = EarlyStopping(patience=10, min_delta=0.001, mode='max')
criterion = FocalLoss(alpha=CLASS_WEIGHTS, ignore_index=11)

run_name = f"focal_deeplab101_AdamW_LR_{base_lr}_eps_{NUM_EPOCHS}"

In [None]:
# Run training; the three returned series are plotted in the next cell.
train_losses, val_losses, val_metric = train(
    deeplab101,
    NUM_EPOCHS,
    optimizer,
    criterion,
    multiclass_iou,
    scheduler,
    early_stop,
    device,
    wandb_proj_name,
    run_name,
    train_loader,
    val_loader,
)

In [None]:
plot_curves(train_losses, val_losses, val_metric)

# Evaluate on every split: test first, then validation, then train.
for split_loader in (test_loader, val_loader, train_loader):
    evaluate_model(deeplab101, split_loader, device, num_classes=11)

# Qualitative check on a fixed handful of test samples.
indexes = [0, 20, 50, 80, 99]

deeplab101.eval()
with torch.no_grad():
    for sample_idx in indexes:
        img, mask = test_data[sample_idx]
        batch = img.unsqueeze(0)  # the model is fed a batch of one image
        pred_classes = deeplab101(batch.to(device)).argmax(dim=1)
        visualize(batch.squeeze(0), mask, predicted=pred_classes)

## Final visualization

In [None]:
@torch.no_grad()
def compare_models_visualization(models: dict,
                                 test_dataset,
                                 device: str = "cuda",
                                 num_samples: int = 3,
                                 void_label: int = 255):
    """
    Show input, ground truth and each model's prediction for random samples.

    One figure is drawn per sample, laid out as a single row:
    input image | ground truth | one panel per entry of ``models``.

    Args:
        models (dict): Dictionary {model_name: model_object}. Every model is
            switched to eval mode and called on a batch holding one image; a
            dict output with an "out" key (torchvision style) is unwrapped.
        test_dataset: Indexable, sized dataset returning (image, target).
        device (str): Device the image is moved to before inference. The
            models themselves are not moved by this function.
        num_samples (int): Number of random samples to visualize; capped at
            the dataset size.
        void_label (int): Not used by the current implementation; kept so
            existing callers keep working.
    """

    # Set models to eval mode
    for m in models.values():
        m.eval()

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the dataset size (and at zero from below).
    dataset_size = len(test_dataset)
    sample_count = max(0, min(num_samples, dataset_size))
    indices = random.sample(range(dataset_size), sample_count)
    n_cols = 2 + len(models)  # image + GT + predictions

    for idx in indices:
        image, target = test_dataset[idx]
        image = image.to(device).unsqueeze(0)  # add a leading batch axis
        target = target.squeeze().cpu()

        plt.figure(figsize=(5 * n_cols, 4))

        # Show input image.
        # squeeze(0) removes only the batch axis, so a single-channel image
        # keeps its channel axis for permute(). The clone guarantees that the
        # display tensor never shares memory with the tensor fed to the models
        # (on CPU, .cpu() alone returns a view of `image`).
        img_disp = image.squeeze(0).cpu().clone()
        # NOTE(review): the return value is discarded, so this relies on
        # denormalize() modifying its argument in place - confirm.
        denormalize(img_disp)
        plt.subplot(1, n_cols, 1)
        plt.imshow(img_disp.permute(1, 2, 0))
        plt.axis("off")
        plt.title("Input Image", fontsize=14, weight="bold")

        # Show ground truth
        plt.subplot(1, n_cols, 2)
        plt.imshow(colorize_mask(target))
        plt.axis("off")
        plt.title("Ground Truth", fontsize=14, weight="bold")

        # Model predictions occupy panels 3, 4, ...
        for i, (name, model) in enumerate(models.items(), start=3):
            preds = model(image)
            if isinstance(preds, dict) and "out" in preds:
                preds = preds["out"]  # handle torchvision models
            # argmax over axis 1 picks the class; squeeze(0) drops the batch axis.
            pred_mask = preds.argmax(dim=1).squeeze(0).cpu()

            plt.subplot(1, n_cols, i)
            plt.imshow(colorize_mask(pred_mask))
            plt.axis("off")
            plt.title(f"{name} Prediction", fontsize=14)

        # Spacing is set once, after all panels exist.
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.02, hspace=0.02)
        plt.tight_layout()
        plt.show()


In [None]:
# Instantiate every architecture on `device`; trained weights are loaded below.
baseline = UNet(NUM_CLASSES).to(device)
path_ckpts = Path("/content/drive/My Drive/CamVid/ckpts")
resnet34 = UNetResNet34(NUM_CLASSES).to(device)
resnet50 = UNetResNet(NUM_CLASSES, backbone='resnet50').to(device)
deeplab_custom = DeepLabV3Plus(NUM_CLASSES, backbone='resnet101').to(device)
deeplab_torch = DeepLabV3Pytorch(NUM_CLASSES).to(device)
focal_unet50 = UNetResNet(NUM_CLASSES, backbone='resnet50').to(device)

# (model, run name): the checkpoint file is "<run name>.pt" inside path_ckpts.
checkpoints = (
    (baseline, "unet_AdamW_LR_0.0001_eps_60"),
    (resnet34, "unet34_AdamW_LR_0.0001_eps_60"),
    (resnet50, "unet50_AdamW_LR_0.0001_eps_60"),
    (deeplab_custom, "deeplab101_AdamW_LR_0.0001_eps_60"),
    (deeplab_torch, "pytorch_deeplab101_AdamW_LR_0.001_eps_60"),
    (focal_unet50, "focal_unet50_AdamW_LR_0.0001_eps_60"),
)
for ckpt_model, run_name in checkpoints:
    weight_path = path_ckpts / f"{run_name}.pt"
    # map_location lets checkpoints load on whichever device is in use.
    ckpt_model.load_state_dict(
        torch.load(weight_path, map_location=torch.device(device))
    )

models = {
    "Baseline": baseline,
    "U-Net (ResNet-34)": resnet34,
    "U-Net (ResNet-50)": resnet50,
    "DeepLabV3+": deeplab_custom,
    "DeepLabV3 (Torch)": deeplab_torch,
    "Focal U-Net (ResNet-50)": focal_unet50,
}

# Pass the notebook-wide `device` (the one the models were moved to) instead
# of a hard-coded "cuda", so the cell also runs on a CPU-only runtime.
compare_models_visualization(models, test_data, device=device, num_samples=3)

In [None]:
# Display name -> trained model, in the order the panels are drawn.
models = {
    "Baseline": baseline,
    "U-Net (ResNet-34)": resnet34,
    "U-Net (ResNet-50)": resnet50,
    "DeepLabV3+": deeplab_custom,
    "DeepLabV3 (Torch)": deeplab_torch,
    "Focal U-Net (ResNet-50)": focal_unet50,
}

# Use the notebook-wide `device` (where the models live) rather than a
# hard-coded "cuda", which fails on a CPU-only runtime.
compare_models_visualization(models, test_data, device=device, num_samples=12)

In [None]:
# Display name -> trained model, in the order the panels are drawn.
models = {
    "Baseline": baseline,
    "U-Net (ResNet-34)": resnet34,
    "U-Net (ResNet-50)": resnet50,
    "DeepLabV3+": deeplab_custom,
    "DeepLabV3 (Torch)": deeplab_torch,
    "Focal U-Net (ResNet-50)": focal_unet50,
}

# Use the notebook-wide `device` (where the models live) rather than a
# hard-coded "cuda", which fails on a CPU-only runtime.
compare_models_visualization(models, test_data, device=device, num_samples=6)

In [None]:
@torch.no_grad()
def compare_models_visualization_grid(models: dict,
                                 test_dataset,
                                 device: str = "cuda",
                                 num_samples: int = 3,
                                 void_label: int = 255):
    """
    Visualize predictions from multiple models in one multi-row figure.

    Each row = one random sample from the dataset.
    Columns = Input | Ground Truth | Predictions for each model.
    Column titles are drawn on the first row only.

    Args:
        models (dict): Dictionary {model_name: model_object}. Every model is
            switched to eval mode and called on a batch holding one image; a
            dict output with an "out" key (torchvision style) is unwrapped.
        test_dataset: Indexable, sized dataset returning (image, target).
        device (str): Device the image is moved to before inference. The
            models themselves are not moved by this function.
        num_samples (int): Number of random samples (rows); capped at the
            dataset size. Nothing is drawn when no sample can be taken.
        void_label (int): Not used by the current implementation; kept so
            existing callers keep working.
    """

    # Set all models to evaluation mode
    for m in models.values():
        m.eval()

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the dataset size (and at zero from below).
    dataset_size = len(test_dataset)
    n_rows = max(0, min(num_samples, dataset_size))
    if n_rows == 0:
        # plt.subplots cannot build a grid with zero rows.
        return

    indices = random.sample(range(dataset_size), n_rows)
    n_cols = 2 + len(models)  # image + GT + predictions

    # squeeze=False always returns a 2D axes array, even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(3.5 * n_cols, 3.5 * n_rows),
                             squeeze=False)

    for row, idx in enumerate(indices):
        image, target = test_dataset[idx]
        image = image.to(device).unsqueeze(0)  # add a leading batch axis
        target = target.squeeze().cpu()
        first_row = row == 0

        # Input image.
        # squeeze(0) removes only the batch axis, so a single-channel image
        # keeps its channel axis for permute(). The clone guarantees that the
        # display tensor never shares memory with the tensor fed to the models
        # (on CPU, .cpu() alone returns a view of `image`).
        img_disp = image.squeeze(0).cpu().clone()
        # NOTE(review): the return value is discarded, so this relies on
        # denormalize() modifying its argument in place - confirm.
        denormalize(img_disp)
        axes[row, 0].imshow(img_disp.permute(1, 2, 0))
        axes[row, 0].axis("off")
        if first_row:
            axes[row, 0].set_title("Input", fontsize=13, weight="bold")

        # Ground truth
        axes[row, 1].imshow(colorize_mask(target))
        axes[row, 1].axis("off")
        if first_row:
            axes[row, 1].set_title("Ground Truth", fontsize=13, weight="bold")

        # Model predictions occupy columns 2, 3, ...
        for col, (name, model) in enumerate(models.items(), start=2):
            preds = model(image)
            if isinstance(preds, dict) and "out" in preds:
                preds = preds["out"]
            # argmax over axis 1 picks the class; squeeze(0) drops the batch axis.
            pred_mask = preds.argmax(dim=1).squeeze(0).cpu()

            axes[row, col].imshow(colorize_mask(pred_mask))
            axes[row, col].axis("off")
            if first_row:
                axes[row, col].set_title(f"{name} Prediction", fontsize=12, pad=4)

    # Tighten layout (almost no gaps)
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.02, hspace=0.02)
    plt.show()


In [None]:
# Display name -> trained model, in the order the columns are drawn.
models = {
    "Baseline": baseline,
    "U-Net (ResNet-34)": resnet34,
    "U-Net (ResNet-50)": resnet50,
    "DeepLabV3+": deeplab_custom,
    "DeepLabV3 (Torch)": deeplab_torch,
    "Focal U-Net (ResNet-50)": focal_unet50,
}

# Use the notebook-wide `device` (where the models live) rather than a
# hard-coded "cuda", which fails on a CPU-only runtime.
compare_models_visualization_grid(models, test_data, device=device, num_samples=6)