# Cityscapes

In [None]:
import os
import random
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

In [None]:
import utils

In [None]:
class CityScapes(torch.utils.data.Dataset):
    """Cityscapes semantic-segmentation dataset (https://www.cityscapes-dataset.com/).

    Each item is built from one ``leftImg8bit`` image and one ``gtFine`` label
    map of the selected split. Label values listed in ``full_classes`` are
    remapped with ``utils.remap`` to the matching entries of ``new_classes``
    (ids 1..19 for the kept classes, 0 for every other id in ``full_classes``).

    Args:
        root_dir: directory that contains the ``leftImg8bit_trainvaltest`` and
            ``gtFine_trainvaltest`` folders.
        mode: split to load, one of ``'train'``, ``'valid'`` or ``'test'``
            (case-insensitive).
        data_transform: optional callable applied to the loaded image.
        label_transform: optional callable applied to the remapped label.
        loader: optional callable ``(data_path, label_path) -> (image, label)``.
            Defaults to opening both files with PIL (image converted to RGB).

    Raises:
        RuntimeError: if ``mode`` is not supported, or if the split does not
            contain as many ``label_name_filter`` label files as images.
    """

    # training dataset root directories
    train_dir = 'leftImg8bit_trainvaltest/leftImg8bit/train'
    train_label_dir = 'gtFine_trainvaltest/gtFine/train'

    # validation dataset root directories
    valid_dir = 'leftImg8bit_trainvaltest/leftImg8bit/val'
    valid_label_dir = 'gtFine_trainvaltest/gtFine/val'

    # test dataset root directories
    test_dir = 'leftImg8bit_trainvaltest/leftImg8bit/test'
    test_label_dir = 'gtFine_trainvaltest/gtFine/test'

    # images extension
    img_extension = '.png'
    # only label files whose name contains this substring are used
    label_name_filter = 'labelIds'

    # split names accepted for ``mode`` (compared case-insensitively)
    _SUPPORTED_MODES = ('train', 'valid', 'test')

    # the values associated with the 35 classes
    full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                    17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
                    32, 33, -1)

    # the values above are remapped position-by-position to the following
    new_classes = (0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 5, 0, 0, 0, 6, 0, 7,
                   8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 17, 18, 19, 0)

    # class name -> RGB colour.
    # NOTE(review): new_classes yields 20 distinct ids (0..19) while only 19
    # names are listed here ('unlabeled' is commented out); confirm how
    # consumers size one-hot targets from len(color_encoding).
    color_encoding = OrderedDict([
        ('road', (128, 64, 128)),  # RGB format
        ('sidewalk', (244, 35, 232)),
        ('building', (70, 70, 70)),
        ('wall', (102, 102, 156)),
        ('fence', (190, 153, 153)),
        ('pole', (153, 153, 153)),
        ('traffic_light', (250, 170, 30)),
        ('traffic_sign', (220, 220, 0)),
        ('vegetation', (107, 142, 35)),
        ('terrain', (152, 251, 152)),
        ('sky', (70, 130, 180)),
        ('person', (220, 20, 60)),
        ('rider', (255, 0, 0)),
        ('car', (0, 0, 142)),
        ('truck', (0, 0, 70)),
        ('bus', (0, 60, 100)),
        ('train', (0, 80, 100)),
        ('motorcycle', (0, 0, 230)),
        ('bicycle', (119, 11, 32)),
        # ('unlabeled', (0, 0, 0))
    ])

    def __init__(self,
                 root_dir,
                 mode='train',
                 data_transform=None,
                 label_transform=None,
                 loader=None):

        self.root_dir = root_dir
        self.mode = mode
        self.data_transform = data_transform
        self.label_transform = label_transform
        # __getitem__ reads self.loader, so it must always be set.
        self.loader = self._pil_loader if loader is None else loader

        split = self._split_key()  # raises RuntimeError for unsupported modes
        data_dir = os.path.join(root_dir, getattr(self, f'{split}_dir'))
        label_dir = os.path.join(root_dir, getattr(self, f'{split}_label_dir'))

        data = utils.get_files(data_dir, extension_filter=self.img_extension)
        # Label folders may hold other PNGs besides the label-id maps (Cityscapes
        # ships e.g. *_gtFine_color.png); keep only files matching
        # label_name_filter so labels pair one-to-one with images.
        labels = self._select_label_files(
            utils.get_files(label_dir, extension_filter=self.img_extension))

        # Pairs are formed by list position -- assumes utils.get_files returns
        # both lists in matching (e.g. sorted) order; TODO confirm.
        if len(data) != len(labels):
            raise RuntimeError(
                f'Found {len(data)} images but {len(labels)} '
                f'"{self.label_name_filter}" label files in the {split} split')

        # Keep the split-specific attribute names (train_data, valid_labels, ...).
        setattr(self, f'{split}_data', data)
        setattr(self, f'{split}_labels', labels)

    def __getitem__(self, index):
        """Return ``(image, label, target)`` for sample ``index``.

        ``label`` is the remapped (and optionally transformed) label map and
        ``target`` is ``utils.one_hot_encode(label)`` as a float32 tensor.
        """
        data_paths, label_paths = self._split_files()
        image, label = self.loader(data_paths[index], label_paths[index])

        # remap class labels
        label = utils.remap(label, self.full_classes, self.new_classes)

        if self.data_transform is not None:
            image = self.data_transform(image)

        if self.label_transform is not None:
            label = self.label_transform(label)

        # perform one-hot-encoding; as_tensor accepts lists, arrays and tensors
        target = torch.as_tensor(utils.one_hot_encode(label), dtype=torch.float32)

        return image, label, target

    def __len__(self):
        """Return the number of images in the active split."""
        data_paths, _ = self._split_files()
        return len(data_paths)

    def _split_key(self):
        """Return ``self.mode`` lower-cased, raising RuntimeError if unsupported."""
        split = self.mode.lower()
        if split not in self._SUPPORTED_MODES:
            raise RuntimeError('Unexpected dataset mode. Supported modes are: train, valid and test')
        return split

    def _split_files(self):
        """Return the ``(data_paths, label_paths)`` lists of the active split."""
        split = self._split_key()
        return getattr(self, f'{split}_data'), getattr(self, f'{split}_labels')

    @classmethod
    def _select_label_files(cls, paths):
        """Keep, in order, the paths whose file name contains ``label_name_filter``."""
        return [p for p in paths if cls.label_name_filter in os.path.basename(p)]

    @staticmethod
    def _pil_loader(data_path, label_path):
        """Default loader: open the image (as RGB) and the label map with PIL.

        Both files are fully read and their handles closed before returning.
        """
        from PIL import Image  # also imported at file top; local keeps this self-contained

        with Image.open(data_path) as img:
            image = img.convert('RGB')
        with Image.open(label_path) as lbl:
            label = lbl.copy()
        return image, label

## Sanity Check

In [None]:
# Settings for the sanity check below.
DATASET_DIR = './data'  # root that CityScapes joins its split sub-folders onto
HEIGHT = 360            # resize target height (pixels)
WIDTH = 480             # resize target width (pixels)
BATCH_SIZE = 10         # samples per DataLoader batch
WORKERS = 4             # DataLoader worker processes

In [None]:
# Input images: resize to (HEIGHT, WIDTH) with the default interpolation,
# then convert to a tensor.
data_transform = transforms.Compose([
    transforms.Resize(size=(HEIGHT, WIDTH)),
    transforms.ToTensor(),
])

# Label maps: nearest-neighbour resizing so no interpolated (non-existent)
# class ids are introduced.
# NOTE(review): ToTensor() rescales 8-bit images to [0, 1], so if the remapped
# label is an 8-bit image its class ids arrive as id/255 floats -- confirm
# utils.one_hot_encode expects that, or switch to an integer-preserving
# conversion (e.g. transforms.PILToTensor()).
label_transform = transforms.Compose([
    transforms.Resize(size=(HEIGHT, WIDTH), interpolation=Image.NEAREST),
    transforms.ToTensor(),
])

In [None]:
# Training split with the transforms defined above.
cityscapes_set = CityScapes(
    DATASET_DIR,
    mode='train',
    data_transform=data_transform,
    label_transform=label_transform,
)

# Class name -> RGB colour table exposed by the dataset.
class_encoding = cityscapes_set.color_encoding
# NOTE(review): this counts the named classes only; the remapped labels also
# use id 0 for every other class -- confirm which count consumers need.
num_classes = len(class_encoding)

In [None]:
# Shuffled mini-batches of (image, label, target) from the training split.
cityscapes_loader = torch.utils.data.DataLoader(
    dataset=cityscapes_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
)

In [None]:
# Pull one batch and report the tensor shapes.
# Use the builtin next(): Python 3 iterators are only required to implement
# __next__, so calling a `.next()` method is not portable.
images, labels, targets = next(iter(cityscapes_loader))
print("Images size:", images.size())
print("Labels size:", labels.size())
print("Targets size:", targets.size())
print("Number of class:", num_classes)

In [None]:
# Visualise a few samples: input image, label map, and one randomly chosen
# channel of the sanity-check one-hot encoding.

# The encoding depends only on `labels`, so compute it once rather than once per
# plotted sample. This rebinds num_classes/targets as the original loop did --
# assumes the helper returns the same result when re-called with its own
# returned num_classes; TODO confirm.
num_classes, targets = utils.one_hot_encode_for_sanity_check(num_classes, labels)
targets = torch.as_tensor(targets)  # no copy-construct warning if already a tensor

num_samples = min(5, images.size(0))  # the batch may hold fewer than 5 samples
for i in range(num_samples):
    plt.figure(figsize=(20, 20))

    plt.subplot(1, 3, 1)
    plt.title(f'Data ({i})')
    plt.axis('off')
    image = images[i]
    if image.size(0) == 3:
        # (C, H, W) -> (H, W, C) so matplotlib renders the colour image
        plt.imshow(image.permute(1, 2, 0))
    else:
        plt.imshow(image[0], cmap='gray')

    plt.subplot(1, 3, 2)
    plt.imshow(labels[i, 0])
    plt.title(f'Label ({i})')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    c = random.randint(0, num_classes - 1)
    plt.imshow(targets[i, c])
    plt.title(f'One-Hot-Label ({c})')
    plt.axis('off')
    plt.show()

---