# Cityscapes

In [1]:
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms

In [2]:
import utils

In [3]:
class CityScapes(torch.utils.data.Dataset):
    """
    Cityscapes dataset from https://www.cityscapes-dataset.com/

    The ``*_dir`` class attributes are relative paths, presumably resolved
    against ``root_dir`` (TODO: confirm once file loading is implemented).
    Raw label values in ``full_classes`` are remapped position-wise to
    ``new_classes``, and ``color_encoding`` gives a name and an RGB color
    for each remapped value 0..19.

    NOTE: ``__len__`` and ``__getitem__`` are not implemented yet, so
    instances cannot be iterated or indexed.
    """

    # training dataset root directories
    train_dir = 'leftImg8bit_trainvaltest/leftImg8bit/train'
    train_label_dir = 'gtFine_trainvaltest/gtFine/train'

    # validation dataset root directories
    valid_dir = 'leftImg8bit_trainvaltest/leftImg8bit/val'
    valid_label_dir = 'gtFine_trainvaltest/gtFine/val'

    # test dataset root directories
    test_dir = 'leftImg8bit_trainvaltest/leftImg8bit/test'
    test_label_dir = 'gtFine_trainvaltest/gtFine/test'

    # file extension of the images
    img_extension = '.png'
    # presumably a substring used to pick the label files out of the
    # gtFine folders (e.g. names containing 'labelIds') -- TODO confirm
    label_name_filter = 'labelIds'

    # The 35 raw label values: 0..33 plus -1.
    full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                    17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
                    32, 33, -1)

    # Position-wise remapping of ``full_classes``: full_classes[i] becomes
    # new_classes[i]. Each of 1..19 appears exactly once; every other raw
    # value collapses to 0.
    new_classes = (0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 5, 0, 0, 0, 6, 0, 7,
                   8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 17, 18, 19, 0)

    # Default encoding of pixel value -> (class name, class color). There are
    # 20 entries, one per remapped value 0..19; presumably the position in
    # this mapping is the pixel value (0 = 'unlabeled', 1 = 'road', ...).
    color_encoding = OrderedDict([
        ('unlabeled', (0, 0, 0)),
        ('road', (128, 64, 128)),
        ('sidewalk', (244, 35, 232)),
        ('building', (70, 70, 70)),
        ('wall', (102, 102, 156)),
        ('fence', (190, 153, 153)),
        ('pole', (153, 153, 153)),
        ('traffic_light', (250, 170, 30)),
        ('traffic_sign', (220, 220, 0)),
        ('vegetation', (107, 142, 35)),
        ('terrain', (152, 251, 152)),
        ('sky', (70, 130, 180)),
        ('person', (220, 20, 60)),
        ('rider', (255, 0, 0)),
        ('car', (0, 0, 142)),
        ('truck', (0, 0, 70)),
        ('bus', (0, 60, 100)),
        ('train', (0, 80, 100)),
        ('motorcycle', (0, 0, 230)),
        ('bicycle', (119, 11, 32))
    ])

    def __init__(self,
                 root_dir,
                 mode='train',
                 data_transform=None,
                 label_transform=None,
                 loader=utils.pil_loader):
        """Store the dataset configuration on the instance.

        Args:
            root_dir: dataset root directory; presumably the folder that
                contains the relative ``*_dir`` paths defined on the class.
            mode: which split to use (default ``'train'``). The set of
                accepted values is not defined yet -- TODO decide between
                e.g. 'val' and 'valid' and validate here.
            data_transform: optional callable, presumably applied to each
                input image.
            label_transform: optional callable, presumably applied to each
                label map.
            loader: callable used to read files from disk; defaults to
                ``utils.pil_loader``.
        """
        super().__init__()
        # Keep every argument so that __len__/__getitem__ (not written yet)
        # have the configuration available.
        self.root_dir = root_dir
        self.mode = mode
        self.data_transform = data_transform
        self.label_transform = label_transform
        self.loader = loader

---