In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision.io import decode_image
import cv2 as cv
import matplotlib.pyplot as plt
import pandas as pd

from dataclasses import dataclass
from clearml import Task
import os

# Config

In [25]:
@dataclass
class Config:
    """Paths and hyper-parameters for the Cityscapes segmentation notebook.

    ``DATA_PATH`` defaults to the ``CITYSCAPES_DATASET`` environment variable
    (the same variable cityscapesScripts looks for) and falls back to the
    original local path, so the notebook runs on other machines without edits.

    NOTE: the derived ``*_PATH`` defaults are computed once, when the class body
    runs, from the default ``DATA_PATH``. Passing ``DATA_PATH=...`` to the
    constructor does NOT update them; pass the derived paths explicitly too.
    """

    DATA_PATH: str = os.environ.get(
        "CITYSCAPES_DATASET",
        "/home/hxastur/vscode-projects/cityscapes-segmentation/dataset",
    )
    # Fine annotations, one sub-folder per split.
    ANNOTATIONS_DATA_PATH: str = os.path.join(DATA_PATH, "gtFine/gtFine")
    ANNOTATIONS_TRAIN_PATH: str = os.path.join(ANNOTATIONS_DATA_PATH, "train")
    ANNOTATIONS_TEST_PATH: str = os.path.join(ANNOTATIONS_DATA_PATH, "test")
    ANNOTATIONS_VAL_PATH: str = os.path.join(ANNOTATIONS_DATA_PATH, "val")
    # Left 8-bit camera images, one sub-folder per split.
    LEFT_DATA_PATH: str = os.path.join(DATA_PATH, "left/leftImg8bit")
    LEFT_TRAIN_PATH: str = os.path.join(LEFT_DATA_PATH, "train")
    LEFT_TEST_PATH: str = os.path.join(LEFT_DATA_PATH, "test")
    LEFT_VAL_PATH: str = os.path.join(LEFT_DATA_PATH, "val")
    # `{type}` tokens used in Cityscapes file names.
    LEFT_TYPE: str = "leftImg8bit"
    ANNOTATIONS_PREFIX: str = "gtFine"
    batch_size: int = 8


config = Config()

# Processor

[Github Dataset Link](https://github.com/mcordts/cityscapesScripts)

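The folder structure of the Cityscapes dataset is as follows (in the local copy used by this notebook each top-level folder is nested one level deeper, i.e. `gtFine/gtFine/...` and `left/leftImg8bit/...`; see `Config`):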

**{root}/{type}{video}/{split}/{city}/{city}_{seq:0>6}_{frame:0>6}_{type}{ext}**

The meaning of the individual elements is:

**root** the root folder of the Cityscapes dataset. Many of the cityscapesScripts tools check whether an environment variable `CITYSCAPES_DATASET` pointing to this folder exists and, if so, use it as the default.

**type** the type/modality of data, e.g. gtFine for fine ground truth, or leftImg8bit for left 8-bit images.

**split** the split, i.e. train/val/test/train_extra/demoVideo. Note that not all kinds of data exist for all splits. Thus, do not be surprised to occasionally find empty folders.

**city** the city in which this part of the dataset was recorded.

**seq** the sequence number using 6 digits.

**frame** the frame number using 6 digits. Note that in some cities only a few, albeit very long, sequences were recorded, while in others many short sequences were recorded, of which only the 19th frame is annotated.

**ext** the extension of the file and optionally a suffix, e.g. _polygons.json for ground truth files

In [41]:
# Annotation file suffixes looked up for every image; the part before the "."
# becomes the key in the per-image dict built by Processor.get_images.
ANNOTATION_TYPES = ["color.png", "instanceIds.png", "labelIds.png", "polygons.json"]


class Processor:
    """Index one Cityscapes split: left images plus their fine-annotation paths.

    Args:
        LEFT_DATA_PATH: split folder of left images, laid out as ``<city>/<file>``.
        ANNOTATIONS_DATA_PATH: split folder of the matching annotations, laid
            out the same way (``<city>/<file>``).
        ANNOTATIONS_PREFIX: annotation type token used in file names; both
            ``"gtFine"`` and ``"_gtFine"`` are accepted.
    """

    def __init__(
        self, LEFT_DATA_PATH, ANNOTATIONS_DATA_PATH, ANNOTATIONS_PREFIX="_gtFine"
    ):
        self.LEFT_DATA_PATH = LEFT_DATA_PATH
        self.ANNOTATIONS_DATA_PATH = ANNOTATIONS_DATA_PATH
        self.ANNOTATIONS_PREFIX = ANNOTATIONS_PREFIX

    def _annotation_path(self, city, image_name, annotation_type):
        """Return the full path of one annotation file of ``image_name``."""
        # Normalise the prefix so exactly one "_" separates it from the image
        # name: <city>_<seq>_<frame>_gtFine_<annotation_type>.
        prefix = self.ANNOTATIONS_PREFIX.lstrip("_")
        file_name = f"{image_name}_{prefix}_{annotation_type}"
        return os.path.join(self.ANNOTATIONS_DATA_PATH, city, file_name)

    def get_images(self):
        """Map each image id ``<city>_<seq>_<frame>`` to its file paths.

        Returns:
            dict: ``{image_id: {"left": ..., "color": ..., "instanceIds": ...,
            "labelIds": ..., "polygons": ...}}`` holding full paths. Annotation
            paths are derived from the naming scheme; their existence is not
            checked here. Cities and files are visited in sorted order so the
            result (and hence dataset indexing) is reproducible.

        Raises:
            ValueError: if a left-image file name does not split into exactly
                four "_"-separated parts.
        """
        images = {}
        for city in sorted(os.listdir(self.LEFT_DATA_PATH)):
            city_left_path = os.path.join(self.LEFT_DATA_PATH, city)
            for file in sorted(os.listdir(city_left_path)):
                splitted = file.split("_")
                if len(splitted) != 4:
                    raise ValueError(f"Len of splitted != 4 for file {file!r}")
                image_city, sequence_number, frame_number, _ = splitted
                image_name = f"{image_city}_{sequence_number}_{frame_number}"

                image_arr = images.setdefault(image_name, {})
                image_arr["left"] = os.path.join(city_left_path, file)
                for annotation_type in ANNOTATION_TYPES:
                    key = annotation_type.split(".")[0]
                    image_arr[key] = self._annotation_path(
                        city, image_name, annotation_type
                    )
        return images

In [43]:
# Index the training split. The default ANNOTATIONS_PREFIX ("_gtFine") builds
# annotation names like "<city>_<seq>_<frame>_gtFine_color.png".
processor = Processor(
    config.LEFT_TRAIN_PATH,
    ANNOTATIONS_DATA_PATH=config.ANNOTATIONS_TRAIN_PATH,
)

# Dataset

In [4]:
class CityscapesDataset(Dataset):
    """Map-style dataset over the images indexed by a ``Processor``.

    Each item is the decoded ``color`` annotation of one image, as returned by
    ``decode_image``.
    NOTE(review): a segmentation model presumably needs the ``left`` image as
    input and ``labelIds`` as target; the color PNG may also carry an alpha
    channel — TODO decide what a training sample should be.
    """

    def __init__(self, processor: Processor):
        self.images = processor.get_images()
        # Freeze the key order so an integer index always maps to the same image.
        self.keys = list(self.images.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        entry = self.images[self.keys[idx]]
        return decode_image(entry["color"])

# Net

In [5]:
class CNNBlock(nn.Module):
    """3x3 convolution (stride 1, padding 1) -> BatchNorm2d -> ReLU.

    The padding keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(CNNBlock, self).__init__()
        self.cnn = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Convolve the input first; BatchNorm2d must receive a tensor, not the
        # Conv2d module.
        return self.relu(self.batchnorm(self.cnn(x)))


class Block(nn.Module):
    """A stack of ``num_cnn`` CNNBlocks with optional pooling / unpooling / softmax.

    * ``pool=True`` (encoder): convolutions, then 2x2 max-pool with indices;
      ``forward`` returns ``(x, indices)``.
    * ``upsample=True`` (decoder): 2x2 max-unpool with the indices passed as
      ``ind``, then convolutions; ``forward`` returns ``x``.
    * ``softmax=True``: softmax over dim 1 applied to the non-pooled output.

    Raises:
        ValueError: if ``num_cnn < 1``, or if ``upsample=True`` and ``forward``
            is called without ``ind``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_cnn,
        pool=False,
        upsample=False,
        softmax=False,
    ):
        super(Block, self).__init__()
        if num_cnn < 1:
            raise ValueError(f"num_cnn must be >= 1, got {num_cnn}")
        self.softmax = softmax
        self.pool = pool
        self.upsample = upsample
        # Only the first conv changes the channel count; the rest keep
        # out_channels. For num_cnn in {2, 3} this is the same layout (and the
        # same state_dict keys) as an explicit list.
        self.block = nn.ModuleList(
            [
                CNNBlock(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                )
                for i in range(num_cnn)
            ]
        )

        self.mp = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.mup = nn.MaxUnpool2d(2, stride=2)
        self.sm = nn.Softmax(dim=1)

    def forward(self, x, ind=None):
        # Unpool BEFORE the convolutions: MaxUnpool2d needs `ind` to have the
        # same shape as its input, and the indices were produced from a tensor
        # with this block's in_channels, not out_channels.
        if self.upsample:
            if ind is None:
                raise ValueError("Block with upsample=True needs pooling indices `ind`")
            x = self.mup(x, ind)

        for module in self.block:
            x = module(x)

        if self.pool:
            x, ind = self.mp(x)
            return x, ind

        if self.softmax:
            x = self.sm(x)

        return x


class SegNet(nn.Module):
    """SegNet-style encoder/decoder built from ``Block``s.

    Encoder level ``i`` (0 = shallowest) outputs ``channel_step * 2**i``
    channels and halves H and W with max-pooling. The decoder mirrors the
    encoder, unpooling with the saved indices deepest-first; its last block
    outputs ``out_channels`` maps followed by a softmax over dim 1.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (presumably the number of
            classes — confirm against the label set used).
        num_two: number of shallowest levels built from 2 convolutions; deeper
            levels use 3. The decoder uses the same count per level.
        num_blocks: number of encoder (and decoder) levels. Input H and W are
            expected to be divisible by ``2**num_blocks``; otherwise unpooled
            sizes will not match the saved indices.
        channel_step: channel count of the first encoder level.

    Raises:
        ValueError: if ``num_blocks < 1``.
    """

    def __init__(
        self, in_channels=3, out_channels=32, num_two=2, num_blocks=5, channel_step=64
    ):
        super(SegNet, self).__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        self.in_channels = in_channels
        self.out_channels = out_channels

        def level_channels(level):
            return channel_step * 2**level

        self.poolblock = nn.ModuleList(
            [
                Block(
                    in_channels=level_channels(i - 1) if i != 0 else in_channels,
                    out_channels=level_channels(i),
                    num_cnn=2 if i < num_two else 3,
                    pool=True,
                )
                for i in range(num_blocks)
            ]
        )
        # Decoder block j undoes encoder level (num_blocks - 1 - j), so the
        # channel schedule and conv counts mirror the encoder.
        self.unpoolblock = nn.ModuleList(
            [
                Block(
                    in_channels=level_channels(level),
                    out_channels=(
                        level_channels(level - 1) if level != 0 else out_channels
                    ),
                    num_cnn=2 if level < num_two else 3,
                    upsample=True,
                    softmax=level == 0,
                )
                for level in reversed(range(num_blocks))
            ]
        )

    def forward(self, x):
        index_list = []
        for module in self.poolblock:
            x, ind = module(x)
            index_list.append(ind)
        # The decoder consumes the pooling indices deepest level first.
        for module, ind in zip(self.unpoolblock, reversed(index_list)):
            x = module(x, ind)
        return x

In [31]:
# Config has no TRAIN_PATH: the train split is LEFT_TRAIN_PATH (images) plus
# ANNOTATIONS_TRAIN_PATH (labels), and Processor requires both.
processor = Processor(
    config.LEFT_TRAIN_PATH,
    ANNOTATIONS_DATA_PATH=config.ANNOTATIONS_TRAIN_PATH,
)
dataset = CityscapesDataset(processor)
dataloader = DataLoader(dataset=dataset, batch_size=config.batch_size)
model = SegNet()

In [None]:
# Smoke check: load and collate a single batch.
# TODO: replace with the training loop.
first_batch = next(iter(dataloader))
first_batch.shape