# Checking TF Dataset Pipeline

In [2]:
import os
import sys
import cv2
import json
import numpy as np
from IPython import display
import matplotlib.pyplot as plt
# from skimage.transform import resize

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img

from collections import namedtuple
print("Tensorflow version: ", tf.__version__)

Tensorflow version:  2.0.0


In [3]:
# Ground-truth annotation files for a single Cityscapes training frame.
path_labelId = "cityscapes/gtFine/train/bochum/bochum_000000_027057_gtFine_labelIds.png"
path_trainId = "cityscapes/gtFine/train/bochum/bochum_000000_027057_gtFine_labelTrainIds.png"
path_color =   "cityscapes/gtFine/train/bochum/bochum_000000_027057_gtFine_color.png"

# The two id maps are read as single-channel images (IMREAD_GRAYSCALE is the
# named form of flag 0); the colour rendering is read with the default flags.
img_labelId, img_trainId = (
    cv2.imread(id_map_path, cv2.IMREAD_GRAYSCALE)
    for id_map_path in (path_labelId, path_trainId)
)
img_color = cv2.imread(path_color)

In [5]:
# img_labelId.shape

In [None]:
# img_labelId[:,1200].tolist()

In [None]:
# img_trainId[:,1200].tolist()

In [None]:
# cv2.imread returns channels in BGR order while matplotlib interprets arrays
# as RGB, so convert before plotting; otherwise red and blue are swapped.
plt.figure(figsize=(16,8), dpi=150)
plt.imshow(cv2.cvtColor(img_color, cv2.COLOR_BGR2RGB))
plt.show()

In [None]:
# Class counts matching the label table in get_labels():
#   * train ids run from 0 to 19 -> 20 training classes
#   * regular ids run from 0 to 33 -> 34 classes in total
n_train_classes = 20
n_classes_total = 34

In [None]:
#--------------------------------------------------------------------------------
# Definitions
#--------------------------------------------------------------------------------

def get_labels():
    """Return the label table as a list of ``Label`` named tuples.

    Each ``Label`` carries these fields, in order:

    name          Unique, human-readable identifier ('car', 'person', ...).
    id            Integer id used in the ground-truth label images. -1 means
                  the label has no id of its own (e.g. license plate). These
                  ids must not be changed: the evaluation server expects them.
    trainId       Id used for training. May be remapped freely, and several
                  labels may share one value (they then collapse into the same
                  training class). Max value is 255. In this table every label
                  that is ignored in evaluation uses train id 0, except
                  'license plate', which uses -1.
    category      Name of the category the label belongs to.
    categoryId    Id of that category, for category-level ground truth.
    hasInstances  Whether single instances of the label are distinguished.
    ignoreInEval  Whether pixels of this class are ignored during evaluation.
    color         RGB colour triple used to visualise the label.

    Results must be reported using the regular ``id`` values, not the train
    ids. Many ids are ignored in evaluation and never need to be predicted.
    """
    field_names = [
        'name',
        'id',
        'trainId',
        'category',
        'categoryId',
        'hasInstances',
        'ignoreInEval',
        'color',
    ]
    Label = namedtuple('Label', field_names)

    # One row per label, columns in the same order as ``field_names``.
    rows = (
        # name                    id  trainId  category        catId  hasInst  ignore  color
        ('unlabeled',             0,  0,       'void',         0,     False,   True,   (0, 0, 0)),
        ('ego vehicle',           1,  0,       'void',         0,     False,   True,   (0, 0, 0)),
        ('rectification border',  2,  0,       'void',         0,     False,   True,   (0, 0, 0)),
        ('out of roi',            3,  0,       'void',         0,     False,   True,   (0, 0, 0)),
        ('static',                4,  0,       'void',         0,     False,   True,   (0, 0, 0)),
        ('dynamic',               5,  0,       'void',         0,     False,   True,   (111, 74, 0)),
        ('ground',                6,  0,       'void',         0,     False,   True,   (81, 0, 81)),
        ('road',                  7,  1,       'flat',         1,     False,   False,  (128, 64, 128)),
        ('sidewalk',              8,  2,       'flat',         1,     False,   False,  (244, 35, 232)),
        ('parking',               9,  0,       'flat',         1,     False,   True,   (250, 170, 160)),
        ('rail track',            10, 0,       'flat',         1,     False,   True,   (230, 150, 140)),
        ('building',              11, 3,       'construction', 2,     False,   False,  (70, 70, 70)),
        ('wall',                  12, 4,       'construction', 2,     False,   False,  (102, 102, 156)),
        ('fence',                 13, 5,       'construction', 2,     False,   False,  (190, 153, 153)),
        ('guard rail',            14, 0,       'construction', 2,     False,   True,   (180, 165, 180)),
        ('bridge',                15, 0,       'construction', 2,     False,   True,   (150, 100, 100)),
        ('tunnel',                16, 0,       'construction', 2,     False,   True,   (150, 120, 90)),
        ('pole',                  17, 6,       'object',       3,     False,   False,  (153, 153, 153)),
        ('polegroup',             18, 0,       'object',       3,     False,   True,   (153, 153, 153)),
        ('traffic light',         19, 7,       'object',       3,     False,   False,  (250, 170, 30)),
        ('traffic sign',          20, 8,       'object',       3,     False,   False,  (220, 220, 0)),
        ('vegetation',            21, 9,       'nature',       4,     False,   False,  (107, 142, 35)),
        ('terrain',               22, 10,      'nature',       4,     False,   False,  (152, 251, 152)),
        ('sky',                   23, 11,      'sky',          5,     False,   False,  (70, 130, 180)),
        ('person',                24, 12,      'human',        6,     True,    False,  (220, 20, 60)),
        ('rider',                 25, 13,      'human',        6,     True,    False,  (255, 0, 0)),
        ('car',                   26, 14,      'vehicle',      7,     True,    False,  (0, 0, 142)),
        ('truck',                 27, 15,      'vehicle',      7,     True,    False,  (0, 0, 70)),
        ('bus',                   28, 16,      'vehicle',      7,     True,    False,  (0, 60, 100)),
        ('caravan',               29, 0,       'vehicle',      7,     True,    True,   (0, 0, 90)),
        ('trailer',               30, 0,       'vehicle',      7,     True,    True,   (0, 0, 110)),
        ('train',                 31, 17,      'vehicle',      7,     True,    False,  (0, 80, 100)),
        ('motorcycle',            32, 18,      'vehicle',      7,     True,    False,  (0, 0, 230)),
        ('bicycle',               33, 19,      'vehicle',      7,     True,    False,  (119, 11, 32)),
        ('license plate',         -1, -1,      'vehicle',      7,     False,   True,   (0, 0, 142)),
    )

    return [Label._make(row) for row in rows]

In [None]:
labels = get_labels() # a list of named tuples

# Lookup tables from an id to its Label.
# Regular ids are unique, so iteration order does not matter for id2label.
id2label = { label.id : label for label in labels }
# Category ids and train ids are shared by several labels. A dict comprehension
# keeps the last value written for a key, so iterate the list in reverse to make
# the label that is defined FIRST win (e.g. train id 0 -> 'unlabeled', black,
# rather than 'trailer').
catid2label = { label.categoryId : label for label in reversed(labels) }
trainId2label = { label.trainId : label for label in reversed(labels) }

In [None]:
def label_to_rgb(mask):
    """Colourise a train-id mask.

    ``mask`` is indexed as ``mask[:, :, 0]``, i.e. the class ids are read from
    the first channel of a 3-D array. Returns a uint8 array with the same
    height and width and 3 colour channels. Pixels whose value is not one of
    the train ids ``0 .. n_train_classes - 1`` stay zero.
    """
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for train_id in range(n_train_classes):
        # Paint every pixel of this class with the colour of its label.
        rgb[mask[:, :, 0] == train_id] = trainId2label[train_id].color
    return rgb


def display(display_list):
    """Plot the given images side by side in a single row.

    The panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask', so at most three images can be shown.
    """
    # NOTE: defining this function rebinds the name `display` that the import
    # cell took from IPython.
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    plt.figure(figsize=(15, 5), dpi=200)
    n_panels = len(display_list)
    for position, panel in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(panel_titles[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
        #plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
def parse_record(raw_record):
    """Parse one serialized example into an ``(image, label)`` tensor pair.

    The image is decoded with 3 channels and cast to float32 (not rescaled);
    the label is decoded with 1 channel and cast to int32. Both get a static
    shape of ``[None, None, channels]``.
    """
    scalar_string = tf.io.FixedLenFeature((), tf.string)
    scalar_int64 = tf.io.FixedLenFeature((), tf.int64)
    feature_spec = {
      'image/encoded': scalar_string,
      'image/format': scalar_string,
      'image/height': scalar_int64,
      'image/width': scalar_int64,
      'image/channels': scalar_int64,
      'label/encoded': scalar_string,
      'label/format': scalar_string,
    }
    fields = tf.io.parse_single_example(raw_record, feature_spec)

    def _decode(key, channels, dtype):
        # Decode the encoded bytes stored under `key` and pin the channel count.
        encoded = tf.reshape(fields[key], shape=[])
        decoded = tf.cast(tf.image.decode_image(encoded, channels), dtype)
        decoded.set_shape([None, None, channels])
        return decoded

    image = _decode('image/encoded', 3, tf.float32)
    label = _decode('label/encoded', 1, tf.int32)
    return image, label


def read_tfrecord(serialized_example):
    """Parse an example whose image and mask are stored as serialized tensors.

    Returns ``(image, mask)``: both are uint8 tensors reshaped to
    ``[height, width, 3]`` and ``[height, width, 1]`` respectively, using the
    height and width stored in the record.
    """
    bytes_feature = tf.io.FixedLenFeature((), tf.string)
    int_feature = tf.io.FixedLenFeature((), tf.int64)
    spec = {
        'image': bytes_feature,
        'segmentation': bytes_feature,
        'height': int_feature,
        'width': int_feature,
        'image_depth': int_feature,
        'mask_depth': int_feature,
    }
    record = tf.io.parse_single_example(serialized_example, spec)
    rows, cols = record['height'], record['width']

    # The payloads are parsed as uint8 tensors (not float32).
    image = tf.reshape(
        tf.io.parse_tensor(record['image'], out_type=tf.uint8),
        [rows, cols, 3],
    )
    mask = tf.reshape(
        tf.io.parse_tensor(record['segmentation'], out_type=tf.uint8),
        [rows, cols, 1],
    )
    return image, mask


def get_dataset_from_tfrecord(tfrecord_dir):
    """Build a dataset of parsed ``(image, label)`` pairs from a TFRecord file.

    Records are decoded with ``parse_record``; ``read_tfrecord`` is the
    alternative parser for records that store serialized tensors.
    """
    return tf.data.TFRecordDataset(tfrecord_dir).map(parse_record)

## Check dataset and input pipeline

In [None]:
# Spatial size (height, width) of the crops / resized images fed to the model.
img_height, img_width = 512, 1024

In [None]:
@tf.function
def random_crop(input_image, input_mask):
    """Take the same random ``img_height`` x ``img_width`` crop from both inputs.

    Image and mask are concatenated along axis 2 so that one crop window is
    applied to both. The crop size is hard-wired to 4 channels in total; the
    first three are returned as the image and the last one as the mask (the
    mask therefore comes back without a channel axis).
    """
    joint = tf.concat([input_image, input_mask], axis=2)
    patch = tf.image.random_crop(joint, size=[img_height, img_width, 4])
    image_patch = patch[:, :, 0:3]
    mask_patch = patch[:, :, -1]
    return image_patch, mask_patch


@tf.function
def mask_to_categorical(image, mask):
    """One-hot encode the mask over ``n_train_classes``; the image passes through.

    Size-1 dimensions of the mask are squeezed away before encoding, and the
    encoded mask is returned as float32.
    """
    class_ids = tf.cast(tf.squeeze(mask), tf.int32)
    encoded = tf.one_hot(class_ids, n_train_classes)
    return image, tf.cast(encoded, tf.float32)


@tf.function
def load_image_train(input_image, input_mask):
    """Training-time preprocessing for one ``(image, mask)`` pair.

    Steps: cast both inputs to uint8, resize them to 768 x 1536, take a joint
    random crop (see ``random_crop``), scale the image by 1/255 and one-hot
    encode the mask over ``n_train_classes``.

    Returns:
        ``(image, mask)`` where the image is float32 and the mask is the
        float32 one-hot encoding.
    """
    input_image = tf.cast(input_image, tf.uint8)
    input_mask = tf.cast(input_mask, tf.uint8)

    input_image = tf.image.resize(input_image, (768, 1536))
    # Class ids are categorical: the default bilinear resize would blend
    # neighbouring ids into values that belong to neither region, so the mask
    # must be resized with nearest-neighbour interpolation.
    input_mask = tf.image.resize(
        input_mask, (768, 1536), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    # Nearest-neighbour resize keeps the uint8 dtype, while the resized image is
    # float32; random_crop concatenates the two, so the dtypes have to match.
    input_mask = tf.cast(input_mask, tf.float32)

    #if tf.random.uniform(()) > 0.5:
    #    input_image = tf.image.flip_left_right(input_image)
    #    input_mask = tf.image.flip_left_right(input_mask)

    input_image = tf.squeeze(input_image)
    input_image, input_mask = random_crop(input_image, input_mask)

    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image, input_mask = mask_to_categorical(input_image, input_mask)
    input_mask = tf.squeeze(input_mask)

    return input_image, input_mask


def load_image_test(input_image, input_mask):
    """Evaluation-time preprocessing for one ``(image, mask)`` pair.

    Both inputs are resized to ``img_height`` x ``img_width`` (no cropping),
    the image is scaled by 1/255 and the mask is one-hot encoded over
    ``n_train_classes``.

    Returns:
        ``(image, mask)`` where the image is float32 and the mask is the
        float32 one-hot encoding.
    """
    input_image = tf.image.resize(input_image, (img_height, img_width))
    # Class ids are categorical: the default bilinear resize would blend
    # neighbouring ids into values that belong to neither region, so the mask
    # must be resized with nearest-neighbour interpolation.
    input_mask = tf.image.resize(
        input_mask,
        (img_height, img_width),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image, input_mask = mask_to_categorical(input_image, input_mask)
    input_mask = tf.squeeze(input_mask)
    return input_image, input_mask

In [None]:
# TFRecord files for the validation and training splits.
# (Despite the `_dir` suffix, each name holds the path of a single .record file.)
valid_tfrecord_dir = "records/trainIds_val.record"
train_tfrecord_dir = "records/trainIds_train.record"

In [None]:
# Parsed (image, label) datasets for the training and validation splits.
train_dataset, valid_dataset = (
    get_dataset_from_tfrecord(record_path)
    for record_path in (train_tfrecord_dir, valid_tfrecord_dir)
)

In [None]:
# Walk the first five raw records; only the last one visited is kept.
for i, (image, mask) in enumerate(train_dataset.take(5)):
    sample_image = image.numpy()
    sample_mask = mask.numpy()

# Show the kept sample next to its colourised label mask.
sample_mask = label_to_rgb(sample_mask)
display([sample_image, sample_mask])

In [None]:
# Apply the preprocessing functions. Both splits are mapped in parallel; the
# validation map previously ran sequentially, unlike the training one.
train = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
valid = valid_dataset.map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Walk the first five preprocessed samples; only the last one visited is kept.
for i, (image, mask) in enumerate(train.take(5)):
    preprocessed_image = image
    preprocessed_mask = mask

# Shape of the one-hot encoded mask of the kept sample.
print(preprocessed_mask.numpy().shape)

In [None]:
# Collapse the one-hot mask back to class ids, restore a trailing channel axis
# (label_to_rgb reads channel 0), colourise it and plot it beside the image.
preprocessed_mask = tf.expand_dims(tf.argmax(preprocessed_mask, axis=-1), axis=-1)
preprocessed_mask = label_to_rgb(preprocessed_mask.numpy())
display([preprocessed_image, preprocessed_mask])