UAVid dataset: https://uavid.nl/

Steps:
 - Install Pytorch via Pip/Conda https://pytorch.org/get-started/locally/
 - Download dataset via link: https://disk.yandex.ru/d/Mw39uGBey44azw
 - Extract dataset zip archive
 - Change `DATASET_DIR` accordingly
 - Implement `UAVIDDataset`.
 - Run the whole notebook to assess SegNet of `UAVIDSemanticSegmentationModel`
 - Modify code of `UAVIDSemanticSegmentationModel` to boost semantic segmentation quality quality of test images. You can choose UNet, PSPnet, DeepLab or something else.

In [1]:
import os
import cv2
import numpy as np
import ipywidgets
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from typing import Optional

import matplotlib.pyplot as plt
%matplotlib inline

import albumentations as A


ModuleNotFoundError: ignored

# Dataset

In [None]:
class UAVIDDataset(Dataset):
    def __init__(self, dataset_dir: str, image_height: int, image_width: int, train: bool = False):
        super(UAVIDDataset, self).__init__()
        self.dataset_dir = dataset_dir
        self.image_height = image_height
        self.image_width = image_width
        self.train = train
        
        dir_suffix = 'train/seq1' if self.train else 'val/seq16'
        
        self.image_dir = f'{self.dataset_dir}/uavid_{dir_suffix}/Images/'
        self.label_path = f'{self.dataset_dir}/uavid_{dir_suffix}/Labels/'

        self.index = list(sorted(os.listdir(self.image_dir)))
        
        # color table.
        self.clr_tab = self.createColorTable()

        # id table.
        id_tab = {}
        for k, v in self.clr_tab.items():
            id_tab[k] = self.clr2id(v)
        self.id_tab = id_tab        
        
        self.transform = transforms.Compose(
            [ 
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
                transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2))
             
                # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.2, hue=0.2),
                # transforms.RandomHorizontalFlip(),
            ]
        )

        
    def __getitem__(self, index):
        image = self.get_np_image(index)
        image = np.array(image / 255, dtype=np.float32)
        image = self.transform(image)
        
        label = self.get_np_label(index)
        label = self.label_transform(label).astype(np.compat.long)
        label = torch.from_numpy(label)
        
        return image, label
        
    def get_np_image(self, index): #+
        image = cv2.imread(f'{self.image_dir}/{self.index[index]}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_width, self.image_height), interpolation=cv2.INTER_CUBIC)
        return image
    
    def get_np_label(self, index): #+
        image = cv2.imread(f'{self.label_path}/{self.index[index]}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_width, self.image_height), interpolation=cv2.INTER_CUBIC)
        
        return image
    
    def __len__(self):
        return len(self.index)
    
    def createColorTable(self):
        clr_tab = {}
        clr_tab['Clutter'] = [0, 0, 0]
        clr_tab['Building'] = [128, 0, 0]
        clr_tab['Road'] = [128, 64, 128]
        clr_tab['Static_Car'] = [192, 0, 192]
        clr_tab['Tree'] = [0, 128, 0]
        clr_tab['Vegetation'] = [128, 128, 0]
        clr_tab['Human'] = [64, 64, 0]
        clr_tab['Moving_Car'] = [64, 0, 128]
        
        return clr_tab

    def colorTable(self):
        return self.clr_tab

    def clr2id(self, clr):
        return clr[0]+clr[1]*255+clr[2]*255*255
    
    #transform to uint8 integer label
    def label_transform(self,label, dtype=np.int32):
        height,width = label.shape[:2]
        # default value is index of clutter.
        newLabel = np.zeros((height, width), dtype=dtype)
        id_label = label.astype(np.int64)
        id_label = id_label[:,:,0]+id_label[:,:,1]*255+id_label[:,:,2]*255*255
        for tid,key in enumerate(self.clr_tab.keys()):
            val = self.id_tab[key]
            mask = (id_label == val)
            newLabel[mask] = tid
            
        return newLabel

    #transform back to 3 channels uint8 label
    def inverse_transform(self, label):
        label_img = np.zeros(shape=(label.shape[0], label.shape[1],3),dtype=np.uint8)
        values = list(self.clr_tab.values())
        for tid,val in enumerate(values):
            mask = (label==tid)
            label_img[mask] = val
            
        return label_img

def create_dataloader(dataset: Dataset, batch_size: int, num_workers: int, train: bool = False):
    return DataLoader(
        dataset=dataset, batch_size=batch_size, 
        shuffle=train, num_workers=num_workers,
        pin_memory=True, drop_last=False)

# Metric

In [None]:
class IOU(nn.Module):
    def __init__(self, eps=1.0, activation=None):
        super().__init__()
        self.eps = eps
        if activation is None:
            self.activation = torch.sigmoid
        else:
            self.activation = activation

    def forward(self, y_pr, y_gt):
        if self.activation is not None:
            y_pr = self.activation(y_pr)
        y_pr = y_pr > 0.5
        y_gt = gt_to_stacked(y_pr, y_gt)
        return get_iou(y_pr, y_gt, eps=self.eps)

    
def gt_to_stacked(pred, gt):
    stacked = torch.zeros_like(pred)
    for i in range(pred.shape[1]):
        stacked[:, i, ...] = gt == i
    return stacked
    
    
def get_iou(pr, gt, eps=1e-6):    
    ious = list()
    
    for i in range(pr.shape[1]):
        gt_i = gt[:, i, ...]
        pr_i = pr[:, i, ...]
        intersection = torch.sum(gt_i * pr_i)
        union = torch.sum(gt_i) + torch.sum(pr_i) - intersection + eps
        iou_i = (intersection + eps) / union
        ious.append(iou_i)
    
    return sum(ious) / len(ious)

# Loss function

In [None]:
class FocalLoss(nn.Module):
    '''
    Multi-class Focal loss implementation
    '''
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1-pt)**self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight)
        return loss


class DiceLoss(nn.Layer):
    """
    The implements of the dice loss.
    Args:
        weight (list[float], optional): The weight for each class. Default: None.
        ignore_index (int64): ignore_index (int64, optional): Specifies a target value that
            is ignored and does not contribute to the input gradient. Default ``255``.
        smooth (float32): Laplace smoothing to smooth dice loss and accelerate convergence.
            Default: 1.0
    """

    def __init__(self, weight=None, ignore_index=255, smooth=1.0):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.smooth = smooth
        self.eps = 1e-8

    def forward(self, logits, labels):
        num_class = logits.shape[1]
        if self.weight is not None:
            assert num_class == len(self.weight), \
                "The lenght of weight should be euqal to the num class"

        logits = F.softmax(logits, axis=1)
        labels_one_hot = F.one_hot(labels, num_class)
        labels_one_hot = paddle.transpose(labels_one_hot, [0, 3, 1, 2])

        mask = labels != self.ignore_index
        mask = paddle.cast(paddle.unsqueeze(mask, 1), 'float32')

        dice_loss = 0.0
        for i in range(num_class):
            dice_loss_i = dice_loss_helper(logits[:, i], labels_one_hot[:, i],
                                           mask, self.smooth, self.eps)
            if self.weight is not None:
                dice_loss_i *= self.weight[i]
            dice_loss += dice_loss_i
        dice_loss = dice_loss / num_class

        return dice_loss


def dice_loss_helper(logit, label, mask, smooth, eps):
    assert logit.shape == label.shape, \
        "The shape of logit and label should be the same"
    logit = paddle.reshape(logit, [0, -1])
    label = paddle.reshape(label, [0, -1])
    mask = paddle.reshape(mask, [0, -1])
    logit *= mask
    label *= mask
    intersection = paddle.sum(logit * label, axis=1)
    cardinality = paddle.sum(logit + label, axis=1)
    dice_loss = 1 - (2 * intersection + smooth) / (cardinality + smooth + eps)
    dice_loss = dice_loss.mean()
    return dice_loss

# Loss

In [None]:
# class Loss(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # self.criterion = FocalLoss()
#         self.criterion = nn.CrossEntropyLoss()

#     def forward(self, outputs, labels):
#         return self.criterion(outputs, labels)

# Utils

In [None]:
def evaluate(model: nn.Module, dataloader: DataLoader, train: bool = False):
    model.eval()
    
    iou_list = list()
    iou_metric = IOU()
    
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            current_iou = iou_metric(outputs, labels).item()
            iou_list.append(current_iou)

    images = len(dataloader) * dataloader.batch_size
    prefix = "train" if train else "test"
    print(f'mIOU of the network on the {images} {prefix} images: {(sum(iou_list) / len(iou_list)):.3f} %')
    

def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    fig, axs = plt.subplots(
        nrows=num_rows, ncols=num_cols, squeeze=False, figsize=(num_rows * 16, num_cols * 16))

    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

# Model Segnet

In [None]:
class Conv2DBatchNormRelu(nn.Module):
    def __init__(
        self,
        in_channels,
        n_filters,
        k_size,
        stride,
        padding,
        bias=True,
        dilation=1,
        is_batchnorm=True,
    ):
        super(Conv2DBatchNormRelu, self).__init__()

        conv_mod = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=(dilation, dilation)
        )

        if is_batchnorm:
            self.cbr_unit = nn.Sequential(
                conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True)
            )
        else:
            self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))

    def forward(self, inputs):
        outputs = self.cbr_unit(inputs)
        return outputs


class SegNetDown2(nn.Module):
    def __init__(self, in_size, out_size):
        super(SegNetDown2, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = Conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class SegNetUp2(nn.Module):
    def __init__(self, in_size, out_size):
        super(SegNetUp2, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = Conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = Conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        return outputs


class SegNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=1, is_unpooling=True):
        super(SegNet, self).__init__()
        self.in_channels = in_channels
        self.is_unpooling = is_unpooling

        self.down1 = SegNetDown2(self.in_channels, 128)
        self.down2 = SegNetDown2(128, 256)

        self.up2 = SegNetUp2(256, 128)
        self.up1 = SegNetUp2(128, 64)

        self.classification = nn.Conv2d(64, num_classes, kernel_size=(3, 3), padding=1)

    def forward(self, inputs):
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        up2 = self.up2(down2, indices_2, unpool_shape2)
        up1 = self.up1(up2, indices_1, unpool_shape1)
        semantic = self.classification(up1)
        return semantic
    
    
class UAVIDSemanticSegmentationModel(nn.Module):
    def __init__(self, num_classes: int):
        super(UAVIDSemanticSegmentationModel, self).__init__()
        self.model = SegNet(num_classes=num_classes)

    def forward(self, x):
        y = self.model(x)
        return y

# Model Unet

In [None]:
def convrelu(in_channels, out_channels, kernel, padding):
  return nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
    nn.ReLU(inplace=True),
  )


class ResNetUNet(nn.Module):
  def __init__(self, n_class):
    super().__init__()

    self.base_model = torchvision.models.resnet18(pretrained=True)
    self.base_layers = list(self.base_model.children())

    self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
    self.layer0_1x1 = convrelu(64, 64, 1, 0)
    self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
    self.layer1_1x1 = convrelu(64, 64, 1, 0)
    self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
    self.layer2_1x1 = convrelu(128, 128, 1, 0)
    self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
    self.layer3_1x1 = convrelu(256, 256, 1, 0)
    self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
    self.layer4_1x1 = convrelu(512, 512, 1, 0)

    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
    self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
    self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
    self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

    self.conv_original_size0 = convrelu(3, 64, 3, 1)
    self.conv_original_size1 = convrelu(64, 64, 3, 1)
    self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

    self.conv_last = nn.Conv2d(64, n_class, 1)

  def forward(self, input):
    x_original = self.conv_original_size0(input)
    x_original = self.conv_original_size1(x_original)

    layer0 = self.layer0(input)
    layer1 = self.layer1(layer0)
    layer2 = self.layer2(layer1)
    layer3 = self.layer3(layer2)
    layer4 = self.layer4(layer3)

    layer4 = self.layer4_1x1(layer4)
    x = self.upsample(layer4)
    layer3 = self.layer3_1x1(layer3)
    x = torch.cat([x, layer3], dim=1)
    x = self.conv_up3(x)

    x = self.upsample(x)
    layer2 = self.layer2_1x1(layer2)
    x = torch.cat([x, layer2], dim=1)
    x = self.conv_up2(x)

    x = self.upsample(x)
    layer1 = self.layer1_1x1(layer1)
    x = torch.cat([x, layer1], dim=1)
    x = self.conv_up1(x)

    x = self.upsample(x)
    layer0 = self.layer0_1x1(layer0)
    x = torch.cat([x, layer0], dim=1)
    x = self.conv_up0(x)

    x = self.upsample(x)
    x = torch.cat([x, x_original], dim=1)
    x = self.conv_original_size2(x)

    out = self.conv_last(x)

    return out

# ResNet18

In [None]:
class DilatedCNN(nn.Module):
  def __init__(self):
    super(DilatedCNN,self).__init__()
    self.convlayers = nn.Sequential(
      nn.Conv2d(in_channels = 3, out_channels = 6, kernel_size = 9, stride = 1, padding = 0, dilation=2),
      nn.ReLU(),
      nn.Conv2d(in_channels=6, out_channels=16, kernel_size = 3, stride = 1, padding= 0, dilation = 2),
      nn.ReLU(),
    )
    self.fclayers = nn.Sequential(
      nn.Linear(2304,120),
      nn.ReLU(),
      nn.Linear(120,84),
      nn.ReLU(),
      nn.Linear(84,8)
    )
  def forward(self,x):
    x = self.convlayers(x)
    x = x.view(-1,2304)
    x = self.fclayers(x)
    return x

# Parameters

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
DATASET_DIR = '/content/drive/MyDrive/Colab Notebooks/uavid'

NUM_CLASSES = 8

IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]

BATCH_SIZE = 8
NUM_WORKERS = 2
EPOCHS = 8

CHECKPOINT_PATH = f'./uavid_semantic_segmentation_model_{IMAGE_HEIGHT}_{IMAGE_WIDTH}.pth'

# Datasets

In [None]:
train_dataset = UAVIDDataset(
    dataset_dir=DATASET_DIR, image_height=IMAGE_HEIGHT, image_width=IMAGE_WIDTH, train=True)

val_dataset = UAVIDDataset(
    dataset_dir=DATASET_DIR, image_height=IMAGE_HEIGHT, image_width=IMAGE_WIDTH)

train_dataloader = create_dataloader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, train=True)

val_dataloader = create_dataloader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Data examples

In [None]:
indices = [0, 100, 200]
plot([train_dataset.get_np_image(index) for index in indices]);

In [None]:
indices = [0, 100, 200]
plot([train_dataset.get_np_label(index) for index in indices]);

In [None]:
indices = [0, 100, 200]
plot([train_dataset[index][1].squeeze(0).numpy() for index in indices]);

In [None]:
images, _ = next(iter(val_dataloader))
plt.figure(figsize=(16, 8))
plt.imshow(torchvision.utils.make_grid(images).permute(1, 2, 0));

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# model = UAVIDSemanticSegmentationModel(num_classes=NUM_CLASSES).to(device)
# model = DilatedCNN().to(device)
model = ResNetUNet(NUM_CLASSES).to(device)

# criterion = nn.CrossEntropyLoss()
# criterion = FocalLoss()
criterion = DiceLoss()

# optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training

In [None]:
model.train()

for epoch in range(EPOCHS):
    running_loss = list()

    for i, data in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
    
    print(f'epoch: {epoch:2d}, iteration: {i + 1:4d}, loss: {sum(running_loss) / len(running_loss):.4f}')
            
model.eval();

# Saving model

In [None]:
torch.save(model.state_dict(), CHECKPOINT_PATH)

# Loading model

In [None]:
model.load_state_dict(torch.load(CHECKPOINT_PATH))
model.eval();

# Evaluation

In [None]:
evaluate(model, train_dataloader, train=True)

In [None]:
evaluate(model, val_dataloader, train=False)

# Visualization

In [None]:
def plot_model_preds(index):
    color_transformer = UAVIDDataset(dataset_dir=DATASET_DIR, image_height=IMAGE_HEIGHT, image_width=IMAGE_WIDTH)

    with torch.no_grad():
        image, label = val_dataset[index]
        image = image.to(device)
        label = label.to(device)
        outputs = model(image.unsqueeze(0))
        _, outputs = torch.max(outputs, 1)
        outputs = outputs.squeeze(0).cpu().numpy()
        outputs_colored = color_transformer.inverse_transform(outputs)
        image = image.cpu().numpy()
        label = label.cpu().numpy()
        label_colored = color_transformer.inverse_transform(label)

    plot([val_dataset.get_np_image(index), outputs_colored, label_colored]);

In [None]:
ipywidgets.interact(
    plot_model_preds,    
    index=ipywidgets.IntText(min=0, max=len(val_dataset), step=1, value=0));

# What's next?

In order to boost segmentation accuracy you can:
 - Use State-Of-The-Art model architecture
 - Change Loss function
 - Play with learning_rate of optimizer
 - Increase image resolution
 - Add image augmenations

In [None]:
print (model.__class__.__name__)
print (criterion.__class__.__name__)
print (optimizer.__class__.__name__)

base: image size 256x256 | 8 epoch | 

| Model | Loss  | Optimizer | train | test | additional | conclusion |
| :-----: | :-: | :-: | :-: | :-: | :-: | :-: |
| Segnet | CrossEntropyLoss | SGD | 0.182 | 0.177 | - | 
| Segnet | Focal | SGD | 0.183 | 0.175 | - | 
| Segnet | Focal | Adam | 0.235 | 0.219 | - | 
| Segnet | CrossEntropyLoss | Adam | 0.261 (0.283) | 0.251 (0.267)| - | Adam > SGD|
| Unet | CrossEntropyLoss | Adam | 0.173 | 0.164 | - | 
| Unet | Focal | Adam | 0.233 (0.199) | 0.216 (0.188) | - | focal loss for Unet
| Unet | Focal | Adam | 0.162 | 0.172 | aug |
| Unet | Focal | Adam | 0.216 | 0.201 | 512 image size | increase image size not giving better results
| Unet | Focal | Adam | 0.21 | 0. | 512 image size + 20 epoch |
| Unet | Focal | Adam | 0.17 | 0.167 | 256 image size + 16 epoch |
| Unet | Focal | Adam | 0.17 | 0.162 | 256 image size + 20 epoch + lr=0.0001|