In [1]:
from PIL import Image
from PIL import ImageFilter
import numpy as np


from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
# Root directory of the local Cityscapes copy; edit this one line to run elsewhere.
CITYSCAPES_ROOT = "/ssd_scratch/cvit/dksingh/cityscapes"
# Split / city / frame stem shared by the three files of the sample frame.
SAMPLE_STEM = "train/aachen/aachen_000173_000019"

# Image, colour annotation and label-id annotation of the same frame.
img_path = f"{CITYSCAPES_ROOT}/leftImg8bit/{SAMPLE_STEM}_leftImg8bit.png"
color_path = f"{CITYSCAPES_ROOT}/gtFine/{SAMPLE_STEM}_gtFine_color.png"
labelids_path = f"{CITYSCAPES_ROOT}/gtFine/{SAMPLE_STEM}_gtFine_labelIds.png"

In [3]:
from collections import namedtuple
# Record type describing one annotation class. Fields, in order:
#
#   name          Identifier of the label (e.g. 'car', 'person'); uniquely
#                 names a class.
#   id            Integer ID representing the label in ground truth images.
#                 -1 means the label has no ID and is ignored when creating
#                 ground truth images (e.g. license plate). Do not modify
#                 these IDs: the evaluation server expects exactly them.
#   trainId       ID used for training; may be changed to suit a method.
#                 Ground truth images with train IDs are created with the
#                 tools in the 'preparation' folder, but results must still
#                 be validated/submitted using the regular IDs. Several
#                 labels may share a trainId, in which case they map to the
#                 same class; the inverse mapping uses the label defined
#                 first in the list. Max value is 255!
#   category      Name of the category the label belongs to.
#   categoryId    ID of that category; used to create category-level ground
#                 truth images.
#   hasInstances  Whether the label distinguishes between single instances.
#   ignoreInEval  Whether pixels with this ground truth class are ignored
#                 during evaluation.
#   color         The color of the label.
Label = namedtuple(
    'Label',
    'name id trainId category categoryId hasInstances ignoreInEval color',
)
# Table of all label definitions, one Label per row, with the columns in the
# order shown by the header comment on the first line of the list.
# In this table every row whose trainId is 255 (and the -1 'license plate'
# row) has ignoreInEval=True; the remaining rows carry trainIds 0..18.
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]

In [4]:
# Lookup tables built from `labels`: label id -> label name, and the reverse.
id2name = dict((entry.id, entry.name) for entry in labels)
name2id = dict((entry.name, entry.id) for entry in labels)

In [5]:
# for label_val in uniq_lid_labels:

#     temp = lid.point(lambda i: i == label_val)
#     temp.save(id2name[label_val]+".png")
    

In [6]:
# person_mask = lid_data == name2id["person"]

# print(f"img_data.shape:{img_data.shape}")
# print(f"person_mask.shape:{person_mask.shape}")

# person_mask_stack = np.stack((person_mask, person_mask, person_mask), axis=-1)
# print(f"person_mask_stack.shape: {person_mask_stack.shape}")

# # img_person = np.ma.array(img_data, mask = person_mask_stack)

# # plt.imshow(img_data*person_mask_stack)
# # plt.show()
# # Image.paste(im = im, mask=person_mask_stack)

In [7]:
# for uniq_lbl in uniq_lid_labels_set:
#     chosen_lbl_lid = lid_data == uniq_lbl
    
#     chosen_mask_stack = np.stack((chosen_lbl_lid,chosen_lbl_lid, chosen_lbl_lid), axis=-1)
    
#     chosen_lbl_img = Image.fromarray(img_data*chosen_mask_stack)
#     chosen_lbl_img.save("results/"+id2name[uniq_lbl]+".png")
    
    

In [8]:
# test = Image.open("./results/person.png")
# test

In [57]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class CityscapesDataset(Dataset):
    """Dataset yielding an image together with its two gtFine annotation images.

    Every sample is a dict with four tensors:

    * ``"leftImg8bit"``      - pixels of the image at the given path
    * ``"labelIds"``         - pixels of the matching ``gtFine_labelIds.png``
    * ``"color"``            - pixels of the matching ``gtFine_color.png``
    * ``"unique_label_ids"`` - distinct values found in ``labelIds``, in
      ascending order, with the blacklisted ids removed

    Parameters
    ----------
    img_path_list : sequence of str
        Paths of the ``*leftImg8bit.png`` files. The annotation paths are
        derived from each of them by :meth:`get_paths`.
    transform : callable, optional
        When given, it is called with the sample dict and its return value is
        what ``__getitem__`` returns. ``None`` (default) returns the dict as is.
    """

    # File-name suffix that is swapped for the annotation suffixes.
    _IMG_SUFFIX = "leftImg8bit.png"

    def __init__(self,img_path_list, transform=None):
        self.len = len(img_path_list)
        self.img_path_list = img_path_list
        # Previously accepted but dropped; now stored and applied per sample.
        self.transform = transform

        # Label ids never reported in "unique_label_ids".
        self.blacklisted_label_ids = {3, -1}

    def get_paths(self,img_path):
        """Return the image path and the two derived annotation paths.

        The first ``leftImg8bit`` occurrence outside the file-name suffix is
        replaced by ``gtFine`` and the ``leftImg8bit.png`` suffix is replaced
        by ``gtFine_labelIds.png`` / ``gtFine_color.png``.

        Returns a dict with keys ``"img_path"``, ``"label_img_path"`` and
        ``"color_img_path"``.
        """
        suffix = self._IMG_SUFFIX
        if img_path.endswith(suffix):
            # Strip the suffix first so that a path containing "leftImg8bit"
            # only in its file name is still handled correctly.
            stem = img_path[:-len(suffix)].replace("leftImg8bit", "gtFine", 1)
        else:
            # Unexpected file name: keep the historical fixed-width slicing.
            stem = img_path.replace("leftImg8bit", "gtFine", 1)[:-len(suffix)]

        return {
            "img_path": img_path,
            "label_img_path": stem + "gtFine_labelIds.png",
            "color_img_path": stem + "gtFine_color.png",
        }

    @staticmethod
    def _load_array(path):
        """Read the image at ``path`` into a numpy array and close the file."""
        with Image.open(path) as opened:
            return np.array(opened)

    def _build_sample(self, img_arr, label_arr, color_arr):
        """Assemble the sample dict from already-decoded numpy arrays."""
        present = np.unique(label_arr)  # ascending, so the order is deterministic
        keep = np.array(
            [int(value) not in self.blacklisted_label_ids for value in present],
            dtype=bool,
        )

        sample = {
            "leftImg8bit": torch.from_numpy(img_arr),
            "labelIds": torch.from_numpy(label_arr),
            "color": torch.from_numpy(color_arr),
            "unique_label_ids": torch.from_numpy(present[keep]),
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns the sample dict, or ``None`` (after printing the error) when
        one of the three files cannot be read.
        """
        paths = self.get_paths(self.img_path_list[index])
        try:
            img_arr = self._load_array(paths["img_path"])
            label_arr = self._load_array(paths["label_img_path"])
            color_arr = self._load_array(paths["color_img_path"])
        except IOError as e:
            # Kept from the original behaviour: report and return None.
            # NOTE(review): a None sample is likely to fail later in the
            # DataLoader's default collation - confirm whether raising here
            # would be preferable.
            print(e)
            return None

        return self._build_sample(img_arr, label_arr, color_arr)

    def __len__(self):
        """Number of image paths the dataset was created with."""
        return self.len
    


In [58]:
# Eight entries that all point at the same training frame.
file_path_list = [
    "/ssd_scratch/cvit/dksingh/cityscapes/leftImg8bit/train/aachen/aachen_000173_000019_leftImg8bit.png"
] * 8
# Dataset over the paths listed above (no transform).
cd = CityscapesDataset(img_path_list=file_path_list)

In [61]:
# Loader over `cd`: batches of eight samples, shuffled.
cd_loader = DataLoader(
    cd,
    batch_size=8,
    shuffle=True,
)

In [62]:
# Iterate over the loader and print every batch it yields.
# NOTE(review): print(data) writes out the whole batch contents; printing
# per-key shapes/dtypes may be what is wanted here - confirm.
for data in cd_loader:
    print(data)

{'leftImg8bit': tensor([[[[ 41,  64,  62],
          [ 33,  59,  53],
          [ 46,  63,  43],
          ...,
          [ 50,  63,  54],
          [ 50,  63,  53],
          [ 50,  63,  52]],

         [[ 49,  66,  72],
          [ 45,  75,  78],
          [ 83, 107, 100],
          ...,
          [ 49,  64,  54],
          [ 47,  63,  53],
          [ 47,  62,  51]],

         [[ 46,  59,  67],
          [ 47,  73,  77],
          [107, 131, 127],
          ...,
          [ 48,  64,  55],
          [ 47,  64,  55],
          [ 46,  63,  53]],

         ...,

         [[ 46,  60,  50],
          [ 46,  60,  49],
          [ 45,  59,  49],
          ...,
          [ 33,  47,  38],
          [ 35,  48,  39],
          [ 38,  47,  39]],

         [[ 46,  60,  50],
          [ 46,  60,  49],
          [ 45,  59,  49],
          ...,
          [ 37,  49,  40],
          [ 35,  47,  38],
          [ 36,  47,  37]],

         [[ 46,  60,  50],
          [ 46,  60,  49],
          [ 45,  59,

In [28]:
# Open ./results/person.png and display the shape of its pixel array.
test_img = Image.open("./results/person.png")
np.array(test_img).shape

(1024, 2048, 3)