In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision.io import decode_image
import cv2 as cv
import matplotlib.pyplot as plt
import pandas as pd

from dataclasses import dataclass
from clearml import Task
import os

In [3]:
@dataclass
class Config:
    """Filesystem locations of the Cityscapes dataset used by this notebook.

    Split directories default to ``<DATA_PATH>/gtFine/{train,test,val}``.
    Overriding ``DATA_PATH`` (e.g. ``Config(DATA_PATH="/data/cityscapes")``)
    re-roots every split path that was left at its default. A split path
    passed explicitly is kept as given.
    """

    # Dataset root. Set CITYSCAPES_DATA_PATH in the environment before this cell
    # runs to change the default without editing the notebook (read once, here).
    DATA_PATH: str = os.environ.get(
        "CITYSCAPES_DATA_PATH",
        "/home/hxastur/vscode_projects/cityscapes-segmentation/dataset",
    )
    # These defaults are evaluated once, from the default DATA_PATH above.
    TRAIN_PATH: str = os.path.join(DATA_PATH, "gtFine/train")
    TEST_PATH: str = os.path.join(DATA_PATH, "gtFine/test")
    VAL_PATH: str = os.path.join(DATA_PATH, "gtFine/val")

    def __post_init__(self):
        """Re-root defaulted split paths when ``DATA_PATH`` was overridden.

        Without this, the class-body defaults would keep pointing at the
        default root even after ``Config(DATA_PATH=...)``.
        """
        default_root = Config.DATA_PATH
        if self.DATA_PATH == default_root:
            return
        for field_name, sub_dir in (
            ("TRAIN_PATH", "gtFine/train"),
            ("TEST_PATH", "gtFine/test"),
            ("VAL_PATH", "gtFine/val"),
        ):
            # Only touch paths still equal to their default; explicit ones win.
            if getattr(self, field_name) == os.path.join(default_root, sub_dir):
                setattr(self, field_name, os.path.join(self.DATA_PATH, sub_dir))


config = Config()

In [4]:
class Processor:
    """Index the annotation files of one dataset split directory.

    ``DATA_PATH`` is scanned one level deep: each sub-directory (a city) holds
    files whose names split on ``_`` into exactly five parts, e.g.
    ``<city>_<a>_<b>_<x>_<type>.<ext>``. Files sharing ``<city>_<a>_<b>`` are
    grouped into one image entry, keyed by ``<type>`` with ``.png``/``.json``
    removed.
    """

    def __init__(self, DATA_PATH):
        # Directory whose sub-directories are scanned by get_images().
        self.DATA_PATH = DATA_PATH

    def get_images(self):
        """Group the files under ``DATA_PATH`` by image id.

        Returns:
            dict[str, dict[str, str]]: ``{image_id: {type: full_path}}`` where
            ``image_id`` is ``"<city>_<a>_<b>"``. Keys are produced in sorted
            order so dataset indexing does not depend on directory-listing order.

        Raises:
            ValueError: if a file name does not split into exactly five
                ``_``-separated parts, or if an image id does not end up with
                exactly four files.
        """
        images = {}
        # Sorted: os.listdir order is filesystem-dependent, which would make the
        # resulting key order (and hence dataset indices) non-reproducible.
        for city in sorted(os.listdir(self.DATA_PATH)):
            city_path = os.path.join(self.DATA_PATH, city)
            if not os.path.isdir(city_path):
                # Skip stray files (README, .DS_Store, ...) next to the city folders.
                continue
            for file in sorted(os.listdir(city_path)):
                full_path = os.path.join(city_path, file)
                parts = file.split("_")
                if len(parts) != 5:
                    raise ValueError(
                        f"Expected 5 '_'-separated parts in file name {file!r} "
                        f"({full_path}), got {len(parts)}"
                    )
                image_type = parts[-1].replace(".png", "").replace(".json", "")
                image_name = f"{parts[0]}_{parts[1]}_{parts[2]}"
                images.setdefault(image_name, {})[image_type] = full_path

        for image_name, files in images.items():
            if len(files) != 4:
                raise ValueError(
                    f"Expected 4 files for image {image_name!r}, found "
                    f"{len(files)}: {sorted(files)}"
                )

        return images

In [5]:
class CityscapesDataset(Dataset):
    """Map-style dataset over the image entries indexed by a ``Processor``.

    Each item is the tensor returned by ``decode_image`` for the entry's
    ``"color"`` file. The other file types are indexed but not returned yet.

    NOTE(review): the notebook builds this from ``gtFine/...`` directories;
    confirm that the ``color`` file is the intended model input rather than a
    label rendering, and add targets (e.g. ``labelIds``) when training.
    """

    def __init__(self, processor: Processor):
        # {image_id: {file_type: path}} as returned by Processor.get_images().
        self.images = processor.get_images()
        # Frozen once so integer indices map to the same image id for the
        # lifetime of this dataset instance.
        self.keys = list(self.images.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """Decode and return the ``"color"`` file of the ``idx``-th image."""
        paths = self.images[self.keys[idx]]
        # The instanceIds / labelIds / polygons paths remain available in
        # ``paths``; they were previously looked up here but never used.
        return decode_image(paths["color"])

In [None]:
class CNNBlock(nn.Module):
    """3x3 convolution -> BatchNorm2d -> ReLU.

    With ``kernel_size=3``, ``stride=1`` and ``padding=1`` the output keeps the
    input's height and width; only the channel count changes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cnn = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x``.

        Args:
            x: input tensor; assumed ``(N, in_channels, H, W)`` — TODO confirm
                callers never pass unbatched input.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size.
        """
        # Bug fix: the convolution must be *applied* to x. Previously the Conv2d
        # module object itself was handed to BatchNorm2d, so forward always failed.
        return self.relu(self.batchnorm(self.cnn(x)))


class Block(nn.Module):
    """A stack of ``num_cnn`` CNNBlocks with optional pool / unpool / softmax.

    ``self.block`` is an ``nn.ModuleList`` holding, in order: the CNNBlocks,
    then ``MaxPool2d`` (if ``pool``), ``MaxUnpool2d`` (if ``upsample``) and
    ``Softmax(dim=1)`` (if ``softmax``).

    Args:
        in_channels: channels expected by the first CNNBlock.
        out_channels: channels produced by every CNNBlock.
        num_cnn: number of CNNBlocks (must be >= 1).
        pool: append a 2x2/stride-2 max-pool that also returns its indices.
        upsample: append a 2x2/stride-2 max-unpool (needs indices in forward).
        softmax: append a softmax over dim 1.

    Raises:
        ValueError: if ``num_cnn < 1``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_cnn,
        pool=False,
        upsample=False,
        softmax=False,
    ):
        super().__init__()
        if num_cnn < 1:
            # Previously any value other than 2 or 3 silently left self.block
            # undefined and failed later with an AttributeError.
            raise ValueError(f"num_cnn must be >= 1, got {num_cnn!r}")
        self.softmax = softmax
        # The first CNNBlock maps in_channels -> out_channels; the rest keep out_channels.
        channels = [in_channels] + [out_channels] * num_cnn
        self.block = nn.ModuleList(
            [
                CNNBlock(in_channels=c_in, out_channels=c_out)
                for c_in, c_out in zip(channels[:-1], channels[1:])
            ]
        )
        if pool:
            # return_indices=True so the indices can later drive a MaxUnpool2d.
            self.block.append(nn.MaxPool2d(2, stride=2, return_indices=True))
        if upsample:
            # NOTE(review): unpooling runs *after* the convolutions here; confirm
            # this is the intended order for decoder blocks.
            self.block.append(nn.MaxUnpool2d(2, stride=2))
        if softmax:
            # NOTE(review): if training with nn.CrossEntropyLoss (which takes raw
            # logits), keep softmax=False — confirm the loss choice.
            self.block.append(nn.Softmax(dim=1))

    def forward(self, x, indices=None, output_size=None):
        """Run ``x`` through every module in ``self.block``.

        Args:
            x: input tensor for the first CNNBlock.
            indices: max-pool indices for ``MaxUnpool2d``; required when the
                block was built with ``upsample=True``.
            output_size: optional target size forwarded to ``MaxUnpool2d``.

        Returns:
            The final tensor, or ``(tensor, pool_indices)`` when the block was
            built with ``pool=True``.

        Raises:
            ValueError: if the block unpools and ``indices`` is ``None``.
        """
        pool_indices = None
        for module in self.block:
            if isinstance(module, nn.MaxPool2d):
                # With return_indices=True the pool yields (values, indices);
                # feeding that tuple to the next module used to crash.
                x, pool_indices = module(x)
            elif isinstance(module, nn.MaxUnpool2d):
                if indices is None:
                    raise ValueError(
                        "Block built with upsample=True requires `indices` in forward()"
                    )
                x = module(x, indices, output_size=output_size)
            else:
                x = module(x)
        if pool_indices is None:
            return x
        return x, pool_indices


class SegNet(nn.Module):
    """SegNet model skeleton (work in progress).

    At this point it only records its channel configuration. No layers and no
    ``forward`` are defined yet — TODO.
    """

    def __init__(self, in_channels=3, out_channels=32):
        super().__init__()
        # Channel configuration, stored exactly as given.
        # NOTE(review): out_channels presumably becomes the number of predicted
        # classes — confirm once the layers are added.
        self.in_channels, self.out_channels = in_channels, out_channels

In [21]:
# Smoke-check Block construction. Ending with `block.block` displays the layer
# stack, so the cell output below is reproducible on Restart & Run All.
block = Block(in_channels=1, out_channels=2, num_cnn=3, pool=True)
block.block

ModuleList(
  (0): CNNBlock(
    (cnn): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (1-2): 2 x CNNBlock(
    (cnn): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


In [None]:
# Index the training split, then wrap it for iteration.
processor = Processor(config.TRAIN_PATH)
dataset = CityscapesDataset(processor=processor)
# batch_size=1 is DataLoader's default, spelled out for readability.
# NOTE(review): samples come in index order (no shuffling) — consider
# shuffle=True if this loader feeds training.
dataloader = DataLoader(dataset, batch_size=1)