# Import

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import Adam
from torchvision.io import decode_image
import cv2 as cv
import matplotlib.pyplot as plt
import pandas as pd
import torchvision.transforms.functional as TF
from torchvision import transforms
from PIL import Image
from tqdm.notebook import tqdm
import numpy as np
from sklearn.model_selection import train_test_split

from dataclasses import dataclass
from clearml import Task
import os

# Config

In [None]:
@dataclass
class Config:
    """Paths and hyper-parameters for the Cityscapes segmentation experiment.

    The annotation and image sub-paths are derived from ``DATA_PATH``. Their
    class-level defaults are computed from the default ``DATA_PATH``; when an
    instance overrides ``DATA_PATH`` (or an intermediate directory), every
    derived path that was left at its default is re-derived in
    ``__post_init__`` so the whole tree follows the new root. A derived path
    that is passed explicitly is kept as given.
    """

    DATA_PATH: str = "/home/hxastur/vscode-projects/cityscapes-segmentation/dataset"
    ANNOTATIONS_DATA_PATH: str = os.path.join(DATA_PATH, "gtFine/gtFine")
    ANNOTATIONS_TRAIN_PATH: str = os.path.join(ANNOTATIONS_DATA_PATH, "train")
    ANNOTATIONS_TEST_PATH: str = os.path.join(ANNOTATIONS_DATA_PATH, "test")
    ANNOTATIONS_VAL_PATH: str = os.path.join(ANNOTATIONS_DATA_PATH, "val")
    # Plain class attribute (not a dataclass field): a list cannot be used as
    # a field default, and it is shared by all instances - do not mutate it.
    ANNOTATION_TYPES = ["color.png", "instanceIds.png", "labelIds.png", "polygons.json"]
    IMAGE_DATA_PATH: str = os.path.join(DATA_PATH, "left/leftImg8bit")
    IMAGE_TRAIN_PATH: str = os.path.join(IMAGE_DATA_PATH, "train")
    IMAGE_TEST_PATH: str = os.path.join(IMAGE_DATA_PATH, "test")
    IMAGE_VAL_PATH: str = os.path.join(IMAGE_DATA_PATH, "val")
    IMAGE_TYPE: str = "leftImg8bit"
    ANNOTATIONS_PREFIX: str = "gtFine"
    SAVE_PATH: str = (
        "/home/hxastur/vscode-projects/cityscapes-segmentation/saved_models"
    )
    batch_size: int = 1
    learning_rate: float = 3e-4
    epochs: int = 5
    # Annotated so that they are real dataclass fields and can be overridden
    # per instance, e.g. ``Config(IMAGE_SIZE=(128, 256))``.
    IMAGE_SIZE: tuple = (64, 128)  # (height, width)
    evalInterval: int = 1

    def __post_init__(self):
        """Re-derive default sub-paths from the instance's actual parent paths."""
        # (field, parent field, path below the parent); parents come first so
        # that children are derived from already-updated values.
        derived_paths = (
            ("ANNOTATIONS_DATA_PATH", "DATA_PATH", "gtFine/gtFine"),
            ("ANNOTATIONS_TRAIN_PATH", "ANNOTATIONS_DATA_PATH", "train"),
            ("ANNOTATIONS_TEST_PATH", "ANNOTATIONS_DATA_PATH", "test"),
            ("ANNOTATIONS_VAL_PATH", "ANNOTATIONS_DATA_PATH", "val"),
            ("IMAGE_DATA_PATH", "DATA_PATH", "left/leftImg8bit"),
            ("IMAGE_TRAIN_PATH", "IMAGE_DATA_PATH", "train"),
            ("IMAGE_TEST_PATH", "IMAGE_DATA_PATH", "test"),
            ("IMAGE_VAL_PATH", "IMAGE_DATA_PATH", "val"),
        )
        defaults = type(self)
        for name, parent, leaf in derived_paths:
            # Only touch values the caller did not customise.
            if getattr(self, name) == getattr(defaults, name):
                setattr(self, name, os.path.join(getattr(self, parent), leaf))


config = Config()
# `is_available` has to be *called*: the bare function object is always
# truthy, which selected "cuda" even on machines without a GPU. The fallback
# device name is "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
# Number of label ids (0..33) in the tables below.
NUM_CLASSES = 34

# One row per label id, in id order: (class name, RGB colour used for display).
# Presumably mirrors the label table of cityscapesScripts - verify against it
# before changing any entry.
_CITYSCAPES_LABELS = (
    ("unlabeled", (0, 0, 0)),
    ("ego vehicle", (0, 0, 0)),
    ("rectification border", (0, 0, 0)),
    ("out of roi", (0, 0, 0)),
    ("static", (0, 0, 0)),
    ("dynamic", (111, 74, 0)),
    ("ground", (81, 0, 81)),
    ("road", (128, 64, 128)),
    ("sidewalk", (244, 35, 232)),
    ("parking", (250, 170, 160)),
    ("rail track", (230, 150, 140)),
    ("building", (70, 70, 70)),
    ("wall", (102, 102, 156)),
    ("fence", (190, 153, 153)),
    ("guard rail", (180, 165, 180)),
    ("bridge", (150, 100, 100)),
    ("tunnel", (150, 120, 90)),
    ("pole", (153, 153, 153)),
    ("polegroup", (153, 153, 153)),
    ("traffic light", (250, 170, 30)),
    ("traffic sign", (220, 220, 0)),
    ("vegetation", (107, 142, 35)),
    ("terrain", (152, 251, 152)),
    ("sky", (70, 130, 180)),
    ("person", (220, 20, 60)),
    ("rider", (255, 0, 0)),
    ("car", (0, 0, 142)),
    ("truck", (0, 0, 70)),
    ("bus", (0, 60, 100)),
    ("caravan", (0, 0, 90)),
    ("trailer", (0, 0, 110)),
    ("train", (0, 80, 100)),
    ("motorcycle", (0, 0, 230)),
    ("bicycle", (119, 11, 32)),
)

# label id -> human-readable class name
CITYSCAPES_MASK_CLASSES = {
    label_id: name for label_id, (name, _) in enumerate(_CITYSCAPES_LABELS)
}

# label id -> (R, G, B) colour
CITYSCAPES_MASK_COLORS = {
    label_id: color for label_id, (_, color) in enumerate(_CITYSCAPES_LABELS)
}

# Processor

[Cityscapes scripts on GitHub](https://github.com/mcordts/cityscapesScripts)

The folder structure of the Cityscapes dataset is as follows:

`{root}/{type}{video}/{split}/{city}/{city}_{seq:0>6}_{frame:0>6}_{type}{ext}`

The meaning of the individual elements is:

**root** the root folder of the Cityscapes dataset. Many of the official Cityscapes scripts check if an environment variable CITYSCAPES_DATASET pointing to this folder exists and use this as the default choice.

**type** the type/modality of data, e.g. gtFine for fine ground truth, or leftImg8bit for left 8-bit images.

**split** the split, i.e. train/val/test/train_extra/demoVideo. Note that not all kinds of data exist for all splits. Thus, do not be surprised to occasionally find empty folders.

**city** the city in which this part of the dataset was recorded.

**seq** the sequence number using 6 digits.

**frame** the frame number using 6 digits. Note that in some cities very few, albeit very long sequences were recorded, while in some cities many short sequences were recorded, of which only the 19th frame is annotated.

**ext** the extension of the file and optionally a suffix, e.g. _polygons.json for ground truth files

In [None]:
class Processor:
    """Build an index of image files and their annotation file paths.

    The image root is expected to contain one sub-directory per city, each
    holding files named ``{city}_{seq}_{frame}_{suffix}``. Annotation paths
    are constructed from the file name only; their existence on disk is not
    checked.
    """

    def __init__(
        self,
        IMAGE_DATA_PATH,
        ANNOTATIONS_DATA_PATH,
        ANNOTATIONS_TYPES,
        ANNOTATIONS_PREFIX="_gtFine",
    ):
        """Store the image root, annotation root, annotation file suffixes and prefix."""
        self.ANNOTATIONS_TYPES = ANNOTATIONS_TYPES
        self.ANNOTATIONS_PREFIX = ANNOTATIONS_PREFIX
        self.ANNOTATIONS_DATA_PATH = ANNOTATIONS_DATA_PATH
        self.IMAGE_DATA_PATH = IMAGE_DATA_PATH

    def get_images(self):
        """Return ``{image_id: {"left": image_path, <annotation kind>: path, ...}}``.

        ``image_id`` is ``{city}_{seq}_{frame}``; each annotation kind is the
        part of an ``ANNOTATIONS_TYPES`` entry before its first dot.

        Raises:
            ValueError: if a file name does not split into exactly four
                ``_``-separated parts, or if an entry does not end up with
                exactly five paths (the image plus four annotations).
        """
        index = {}
        for city_dir in os.listdir(self.IMAGE_DATA_PATH):
            city_root = os.path.join(self.IMAGE_DATA_PATH, city_dir)
            for filename in os.listdir(city_root):
                parts = filename.split("_")
                if len(parts) != 4:
                    raise ValueError("Len of splitted != 4")
                # The city used for annotation paths comes from the file
                # name, not from the directory name.
                city_name, sequence, frame = parts[:3]
                image_id = f"{city_name}_{sequence}_{frame}"

                entry = index.setdefault(image_id, {})
                entry["left"] = os.path.join(city_root, filename)
                for annotation_file in self.ANNOTATIONS_TYPES:
                    kind = annotation_file.split(".")[0]
                    entry[kind] = os.path.join(
                        self.ANNOTATIONS_DATA_PATH,
                        f"{city_name}/{image_id}{self.ANNOTATIONS_PREFIX}_{annotation_file}",
                    )

        # Every entry must hold the image path plus four annotation paths.
        if any(len(entry) != 5 for entry in index.values()):
            raise ValueError("Len of arr %5 != 0")

        return index

# Dataset

In [None]:
class CityscapesDataset(Dataset):
    """Dataset yielding ``(image_tensor, mask_tensor)`` pairs.

    Args:
        images: mapping ``image_id -> {"left": image_path, "labelIds":
            mask_path, ...}`` (the structure returned by
            ``Processor.get_images``).
        keys: image ids this dataset should expose, in order. ``None`` means
            all ids of ``images``. An empty sequence gives an empty dataset.
        size: ``(height, width)`` every image and mask is resized to.
    """

    def __init__(self, images: dict, keys=None, size=(256, 512)):
        self.images = images
        # Test for None explicitly: with a truthiness test an empty split
        # silently turned into "use every image".
        self.keys = list(images.keys()) if keys is None else keys
        self.size = size

    def __len__(self):
        """Number of image ids exposed by this dataset."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Load, resize and tensorise the sample at position ``idx``.

        Returns:
            ``(image_tensor, mask_tensor)``: the image converted with
            ``TF.to_tensor`` (channels first) and the label-id mask as an
            ``int64`` tensor of shape ``size``.
        """
        paths = self.images[self.keys[idx]]

        # The context managers close the underlying files; the resized
        # copies are independent of them.
        with Image.open(paths["left"]) as raw_image:
            # Photos are interpolated smoothly; nearest-neighbour downscaling
            # (as used for the mask) produces aliasing on natural images.
            image = TF.resize(
                raw_image,
                self.size,
                interpolation=transforms.InterpolationMode.BILINEAR,
            )
        with Image.open(paths["labelIds"]) as raw_mask:
            # Label ids are categorical: nearest-neighbour is the only
            # interpolation that cannot invent ids that are not in the mask.
            mask = TF.resize(
                raw_mask,
                self.size,
                interpolation=transforms.InterpolationMode.NEAREST,
            )

        image_tensor = TF.to_tensor(image)  # C,H,W
        mask_tensor = torch.from_numpy(np.array(mask)).long()

        return image_tensor, mask_tensor

In [None]:
# Index the training split: one entry per image id holding the path of the
# left image and the paths of its annotation files. ANNOTATIONS_PREFIX is not
# passed, so the Processor default ("_gtFine") is used.
processor = Processor(
    IMAGE_DATA_PATH=config.IMAGE_TRAIN_PATH,
    ANNOTATIONS_DATA_PATH=config.ANNOTATIONS_TRAIN_PATH,
    ANNOTATIONS_TYPES=config.ANNOTATION_TYPES,
)
images = processor.get_images()

## Split

In [None]:
def get_index_splits(images: dict):
    """Split the keys of ``images`` into a train part and a test part.

    Returns ``(train_keys, test_keys)``: 20% of the keys go to the test part,
    with a fixed random seed. The two lists are meant to be passed to the
    dataset when it is created.
    """
    train_keys, test_keys = train_test_split(
        list(images), random_state=42, test_size=0.2
    )
    return train_keys, test_keys


# Hold out part of the indexed training images for evaluation.
train_images_idx, test_images_idx = get_index_splits(images)

In [None]:
# Sanity check: sizes of the train and test key lists.
len(train_images_idx), len(test_images_idx)

## Visualize

In [None]:
def visualise(CITYSCAPES_MASK_COLORS, images, val_index):
    """Plot the first sample of ``val_index`` next to its mask overlay.

    Args:
        CITYSCAPES_MASK_COLORS: mapping ``label id -> (R, G, B)`` with 0-255
            colour components.
        images: image index passed on to ``CityscapesDataset``.
        val_index: image ids to build the dataset from; only the first one
            is shown.

    The left axis shows the image, the right axis a 50/50 blend of the image
    and the colour-coded mask.
    """
    image, mask = CityscapesDataset(images, keys=val_index)[0]
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

    def decode_segmap(mask, colormap=CITYSCAPES_MASK_COLORS):
        """Turn an (H, W) array of label ids into an (H, W, 3) uint8 colour image."""
        h, w = mask.shape
        color_mask = np.zeros((h, w, 3), dtype=np.uint8)
        # Iterate the mapping itself instead of range(len(...)) so that the
        # ids do not have to be the contiguous range 0..n-1. Ids without a
        # colour stay black.
        for label, color in colormap.items():
            color_mask[mask == label] = color
        return color_mask

    color_mask = decode_segmap(mask.numpy(), CITYSCAPES_MASK_COLORS)
    image_hwc = image.permute(1, 2, 0).numpy()  # channels last for matplotlib
    blended = (0.5 * image_hwc + 0.5 * (color_mask / 255.0)).clip(0, 1)

    axes[0].imshow(image_hwc)
    axes[1].imshow(blended)


# Preview the first held-out sample with its colour-coded mask overlaid.
visualise(CITYSCAPES_MASK_COLORS, images, val_index=test_images_idx)

# Net

In [None]:
class CNNBlock(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_channels``
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        features = self.cnn(x)
        normalized = self.batchnorm(features)
        return self.relu(normalized)


class Block(nn.Module):
    """A stack of ``CNNBlock`` layers with optional unpooling, pooling or softmax.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every ``CNNBlock`` of the stack.
        num_cnn: number of ``CNNBlock`` layers (at least 1). The first maps
            ``in_channels -> out_channels``, the rest keep ``out_channels``.
        pool: if true, finish with 2x2 max pooling and return
            ``(pooled, indices)``.
        upsample: if true, start by max-unpooling the input with the indices
            given to ``forward``.
        softmax: if true (and ``pool`` is false), apply a softmax over the
            channel dimension to the output.

    Raises:
        ValueError: if ``num_cnn`` is smaller than 1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_cnn,
        pool=False,
        upsample=False,
        softmax=False,
    ):
        super().__init__()
        # Previously only 2 and 3 were handled; any other value left
        # `self.block` undefined and failed later inside `forward`.
        if num_cnn < 1:
            raise ValueError(f"num_cnn must be >= 1, got {num_cnn}")
        self.softmax = softmax
        self.pool = pool
        self.upsample = upsample

        # Channel count before/after each layer: in, out, out, ...
        channels = [in_channels] + [out_channels] * num_cnn
        self.block = nn.ModuleList(
            [
                CNNBlock(in_channels=c_in, out_channels=c_out)
                for c_in, c_out in zip(channels, channels[1:])
            ]
        )

        # return_indices=True: the pooling indices are needed by the
        # matching unpooling block.
        self.mp = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.mup = nn.MaxUnpool2d(2, stride=2)
        self.sm = nn.Softmax(dim=1)

    def forward(self, x, ind=None):
        """Run the block.

        Args:
            x: input feature map.
            ind: pooling indices; required when the block was built with
                ``upsample=True``.

        Returns:
            ``(x, indices)`` when ``pool`` is set, otherwise ``x`` (softmaxed
            over dim 1 when ``softmax`` is set).

        Raises:
            ValueError: if ``upsample`` is set but ``ind`` is ``None``.
        """
        if self.upsample:
            if ind is None:
                raise ValueError("An upsampling Block needs pooling indices (ind)")
            x = self.mup(x, ind)

        for module in self.block:
            x = module(x)

        if self.pool:
            x, ind = self.mp(x)
            return x, ind

        if self.softmax:
            x = self.sm(x)
        return x


class SegNet(nn.Module):
    """Encoder/decoder network built from pooling and unpooling ``Block`` stacks.

    The encoder consists of ``num_blocks`` pooling blocks whose width doubles
    each time (``channel_step``, ``2 * channel_step``, ...). The decoder has
    the same number of unpooling blocks; each one reuses the pooling indices
    of the encoder block it mirrors and halves the width again. The last
    decoder block outputs ``out_channels`` channels with a softmax over the
    channel dimension.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output (one per class).
        num_two: blocks with index below this value use two ``CNNBlock``
            layers, all others three. The same rule is applied to encoder
            and decoder positions.
        num_blocks: number of encoder blocks and of decoder blocks.
        channel_step: width of the first encoder block.

    NOTE(review): the input height and width presumably have to be divisible
    by ``2 ** num_blocks`` for the output to have the input's spatial size -
    confirm against the MaxPool2d / MaxUnpool2d documentation.
    NOTE(review): the last decoder block still ends with batch norm + ReLU
    before the softmax, so its pre-softmax scores are non-negative.
    """

    def __init__(
        self, in_channels=3, out_channels=32, num_two=2, num_blocks=5, channel_step=64
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.poolblock = nn.ModuleList(
            [
                Block(
                    in_channels=(
                        in_channels if i == 0 else channel_step * 2 ** (i - 1)
                    ),
                    out_channels=channel_step * 2**i,
                    num_cnn=2 if i < num_two else 3,
                    pool=True,
                )
                for i in range(num_blocks)
            ]
        )

        # Index of the deepest block. The decoder widths were hard-coded for
        # five blocks (exponents `4 - i` / `3 - i`), which broke any other
        # `num_blocks`; they are now derived from it.
        last = num_blocks - 1
        self.unpoolblock = nn.ModuleList(
            [
                Block(
                    in_channels=channel_step * 2 ** (last - i),
                    out_channels=(
                        channel_step * 2 ** (last - 1 - i) if i != last else out_channels
                    ),
                    num_cnn=2 if i < num_two else 3,
                    upsample=True,
                    softmax=(i == last),
                )
                for i in range(num_blocks)
            ]
        )

    def forward(self, x):
        """Encode ``x``, then decode it using the stored pooling indices."""
        index_list = []
        for module in self.poolblock:
            x, ind = module(x)
            index_list.append(ind)
        # Decoder block i undoes the pooling of encoder block (last - i), so
        # the indices are consumed in reverse order.
        for module, ind in zip(self.unpoolblock, reversed(index_list)):
            x = module(x, ind)
        return x

# Loss

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over classes.

    Args:
        num_classes: number of classes; used to one-hot encode the target.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, pred, target):
        """Return the mean over classes of ``1 - dice``.

        Args:
            pred: per-class scores, laid out like the one-hot target
                ``[B, C, H, W]``.
            target: integer class ids ``[B, H, W]``.
        """
        smooth = 1
        # Sum over batch and both spatial axes, keeping one value per class.
        reduce_dims = (0, 2, 3)

        one_hot = nn.functional.one_hot(target.long(), num_classes=self.num_classes)
        one_hot = one_hot.permute(0, 3, 1, 2).float()  # [B, C, H, W]

        overlap = (pred * one_hot).sum(dim=reduce_dims)
        denominator = pred.sum(dim=reduce_dims) + one_hot.sum(dim=reduce_dims) + smooth

        per_class_loss = 1 - ((2.0 * overlap + smooth) / denominator)
        return per_class_loss.mean()

# Train

In [None]:
class Trainer:
    """Training loop with periodic evaluation and model saving.

    Args:
        trainDataloader: iterable of ``(images, masks)`` training batches;
            must support ``len()``.
        testDataloader: iterable of ``(images, masks)`` evaluation batches.
        evalInterval: run the evaluation every this many epochs; a falsy
            value disables it.
        savePath: directory ``save_model`` writes to.
    """

    def __init__(self, trainDataloader, testDataloader, evalInterval, savePath):
        self.trainDataloader = trainDataloader
        self.testDataloader = testDataloader
        self.evalInterval = evalInterval
        self.savePath = savePath

    @staticmethod
    def _device_of(net):
        """Return the device of ``net``'s first parameter (CPU if it has none)."""
        first_param = next(net.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def train(self, net, optimizer, epochs, criterion):
        """Train ``net`` for ``epochs`` epochs.

        Batches are moved to the device the model lives on. The mean training
        loss is printed after every epoch; every ``evalInterval`` epochs the
        model is also evaluated on the test loader.

        Args:
            net: model to optimise.
            optimizer: optimiser over ``net``'s parameters.
            epochs: number of passes over the training loader.
            criterion: callable ``criterion(output, mask) -> loss``.
        """
        target_device = self._device_of(net)
        num_batches = len(self.trainDataloader)
        for epoch in range(epochs):
            net.train()  # evaluation below switches the model to eval mode
            epoch_loss = 0.0
            for batch_image, batch_mask in tqdm(
                self.trainDataloader, total=num_batches
            ):
                batch_image = batch_image.to(target_device)
                batch_mask = batch_mask.to(target_device)
                optimizer.zero_grad()

                output = net(batch_image)

                loss = criterion(output, batch_mask)
                loss.backward()
                optimizer.step()

                # Accumulate a Python float: adding the tensor itself keeps
                # every batch's autograd graph alive for the whole epoch.
                batch_loss = loss.item()
                print(f"loss: {batch_loss}")
                epoch_loss += batch_loss
            if num_batches:
                epoch_loss /= num_batches
            print(f"train loss: {epoch_loss}")

            if self.evalInterval and (epoch + 1) % self.evalInterval == 0:
                test_loss = self.test(net, criterion)
                print(f"test loss: {test_loss}")

    def test(self, net, criterion=None):
        """Run ``net`` over the test loader without gradient tracking.

        The model is put into eval mode for the pass and restored to its
        previous mode afterwards.

        Args:
            net: model to evaluate.
            criterion: optional callable ``criterion(output, mask) -> loss``.

        Returns:
            The mean loss over the test batches, or ``None`` when no
            ``criterion`` is given or the loader is empty.
        """
        target_device = self._device_of(net)
        was_training = net.training
        net.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for images, masks in self.testDataloader:
                    images, masks = images.to(target_device), masks.to(target_device)
                    outputs = net(images)
                    if criterion is not None:
                        total_loss += criterion(outputs, masks).item()
                    num_batches += 1
        finally:
            net.train(was_training)

        if criterion is None or num_batches == 0:
            return None
        return total_loss / num_batches

    def save_model(self, model):
        """Save ``model`` with ``torch.save`` as ``<savePath>/model``.

        The target directory is created if it does not exist yet.
        """
        filename = "model"
        os.makedirs(self.savePath, exist_ok=True)
        filepath = os.path.join(self.savePath, filename)
        torch.save(model, filepath)
        print(f"Saved model with name: {filename}")

# Execute

In [None]:
# Index the training split and divide its image ids into train / test keys.
processor = Processor(
    IMAGE_DATA_PATH=config.IMAGE_TRAIN_PATH,
    ANNOTATIONS_DATA_PATH=config.ANNOTATIONS_TRAIN_PATH,
    ANNOTATIONS_TYPES=config.ANNOTATION_TYPES,
)
images = processor.get_images()
train_images_idx, test_images_idx = get_index_splits(images)

# NOTE(review): no `size` is passed, so both datasets use the dataset's
# default size and config.IMAGE_SIZE stays unused - confirm which resolution
# is intended.
trainDataset = CityscapesDataset(images, train_images_idx)
testDataset = CityscapesDataset(images, test_images_idx)

# Shuffle the training data so every epoch sees the samples in a new order;
# the evaluation loader keeps a fixed order.
trainDataloader = DataLoader(
    dataset=trainDataset, batch_size=config.batch_size, shuffle=True
)
testDataloader = DataLoader(dataset=testDataset, batch_size=1)

trainer = Trainer(
    trainDataloader=trainDataloader,
    testDataloader=testDataloader,
    evalInterval=config.evalInterval,
    savePath=config.SAVE_PATH,
)

# One output channel per label id; the loss one-hot encodes the target with
# the same number of classes.
net = SegNet(in_channels=3, out_channels=NUM_CLASSES).to(device)
optimizer = Adam(net.parameters(), lr=config.learning_rate)
criterion = DiceLoss(NUM_CLASSES)

trainer.train(net=net, optimizer=optimizer, epochs=config.epochs, criterion=criterion)

In [None]:
# processor = Processor(
#     IMAGE_DATA_PATH=config.IMAGE_TRAIN_PATH,
#     ANNOTATIONS_DATA_PATH=config.ANNOTATIONS_TRAIN_PATH,
#     ANNOTATIONS_TYPES=config.ANNOTATION_TYPES,
# )
# dataset = CityscapesDataset(processor)
# dataloader = DataLoader(dataset=dataset, batch_size=config.batch_size)
# data = next(iter(dataloader))