# Building TFRecord Dataset


### Notes

By default, the pixel values of the masks are integers representing IDs [0,1,2,3,4 ... , 32, 33] (see `utils.label_utils.py`). For training, we want the integer values to be "trainId" or "catId". So we have some functions to map the pixel values before saving the mask.


*Note on `utils.label_utils.py`: The way "trainId" is defined inside the `get_labels()` function is a bit confusing: the trainId numbers go from -1 to 18, and some labels additionally use 255. I decided to create a second function, `get_train_labels()`, that replaces 255 with 19. I also changed the trainId of "license plate" from -1 to 0.*

In [6]:
from __future__ import print_function, absolute_import, division

import os
import re
import sys
import cv2
import PIL
import json
import math
import time
import random
import sklearn
import numpy as np
from IPython import display
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img
from utils.cityscapes_utils import dump_rgb_data, get_images

from collections import namedtuple
print("Tensorflow version: ", tf.__version__)

Tensorflow version:  2.0.0


In [2]:
# Number of raw label IDs iterated over (0..33) when remapping mask pixels.
n_classes_total = 34
# n_train_classes = 20
# Depth of the one-hot masks and number of class indices coloured when plotting.
n_train_classes = 8

In [7]:
#--------------------------------------------------------------------------------
# Definitions
#--------------------------------------------------------------------------------

def get_labels():
    """Build the Cityscapes label table.

    Returns:
        A list of ``Label`` named tuples in table order. Fields:

        name          Identifier of the label, e.g. 'car', 'person'. Used to
                      uniquely name a class.
        id            Integer ID that represents the label in ground truth
                      images. An ID of -1 means the label has no ID and is
                      ignored when creating ground truth images (e.g.
                      license plate). Do not modify these IDs: exactly these
                      values are expected by the evaluation server.
        trainId       ID used for training; free to adapt to your method.
                      Several labels may share a trainId, in which case they
                      are mapped to the same class in the ground truth. For
                      the inverse mapping, the label defined first in the
                      table is used. Results must still be validated or
                      submitted using the regular IDs. Max value is 255.
        category      Name of the category the label belongs to.
        categoryId    ID of that category; used to create ground truth images
                      on category level.
        hasInstances  Whether the label distinguishes single instances.
        ignoreInEval  Whether pixels with this ground truth label are ignored
                      during evaluation.
        color         The (R, G, B) colour of the label.
    """
    field_names = [
        'name',
        'id',
        'trainId',
        'category',
        'categoryId',
        'hasInstances',
        'ignoreInEval',
        'color',
    ]
    Label = namedtuple('Label', field_names)

    # Adapt the train IDs as appropriate for your approach; the current
    # values are only a suggestion. You might want to ignore labels with
    # ID 255 during training. Many IDs are ignored in evaluation and never
    # need to be predicted. Always report results with the original IDs.
    #
    # Want to evaluate model on ids
    # [7,8,11,12,13,17,19,20,21,22,23,24,25,26,27,28,31,32,33]
    rows = [
        # name                    id  trainId  category        catId  hasInst  ignore  color
        ('unlabeled',              0,  255,    'void',          0,     False,   True,   (0, 0, 0)),
        ('ego vehicle',            1,  255,    'void',          0,     False,   True,   (0, 0, 0)),
        ('rectification border',   2,  255,    'void',          0,     False,   True,   (0, 0, 0)),
        ('out of roi',             3,  255,    'void',          0,     False,   True,   (0, 0, 0)),
        ('static',                 4,  255,    'void',          0,     False,   True,   (0, 0, 0)),
        ('dynamic',                5,  255,    'void',          0,     False,   True,   (111, 74, 0)),
        ('ground',                 6,  255,    'void',          0,     False,   True,   (81, 0, 81)),
        ('road',                   7,  0,      'flat',          1,     False,   False,  (128, 64, 128)),
        ('sidewalk',               8,  1,      'flat',          1,     False,   False,  (244, 35, 232)),
        ('parking',                9,  255,    'flat',          0,     False,   True,   (250, 170, 160)),
        ('rail track',             10, 255,    'flat',          0,     False,   True,   (230, 150, 140)),
        ('building',               11, 2,      'construction',  2,     False,   False,  (70, 70, 70)),
        ('wall',                   12, 3,      'construction',  2,     False,   False,  (102, 102, 156)),
        ('fence',                  13, 4,      'construction',  2,     False,   False,  (190, 153, 153)),
        ('guard rail',             14, 255,    'construction',  0,     False,   True,   (180, 165, 180)),
        ('bridge',                 15, 255,    'construction',  0,     False,   True,   (150, 100, 100)),
        ('tunnel',                 16, 255,    'construction',  0,     False,   True,   (150, 120, 90)),
        ('pole',                   17, 5,      'object',        3,     False,   False,  (153, 153, 153)),
        ('polegroup',              18, 255,    'object',        0,     False,   True,   (153, 153, 153)),
        ('traffic light',          19, 6,      'object',        3,     False,   False,  (250, 170, 30)),
        ('traffic sign',           20, 7,      'object',        3,     False,   False,  (220, 220, 0)),
        ('vegetation',             21, 8,      'nature',        4,     False,   False,  (107, 142, 35)),
        ('terrain',                22, 9,      'nature',        4,     False,   False,  (152, 251, 152)),
        ('sky',                    23, 10,     'sky',           5,     False,   False,  (70, 130, 180)),
        ('person',                 24, 11,     'human',         6,     True,    False,  (220, 20, 60)),
        ('rider',                  25, 12,     'human',         6,     True,    False,  (255, 0, 0)),
        ('car',                    26, 13,     'vehicle',       7,     True,    False,  (0, 0, 142)),
        ('truck',                  27, 14,     'vehicle',       7,     True,    False,  (0, 0, 70)),
        ('bus',                    28, 15,     'vehicle',       7,     True,    False,  (0, 60, 100)),
        ('caravan',                29, 255,    'vehicle',       0,     True,    True,   (0, 0, 90)),
        ('trailer',                30, 255,    'vehicle',       0,     True,    True,   (0, 0, 110)),
        ('train',                  31, 16,     'vehicle',       7,     True,    False,  (0, 80, 100)),
        ('motorcycle',             32, 17,     'vehicle',       7,     True,    False,  (0, 0, 230)),
        ('bicycle',                33, 18,     'vehicle',       7,     True,    False,  (119, 11, 32)),
        ('license plate',          -1, -1,     'vehicle',       7,     False,   True,   (0, 0, 142)),
    ]

    return [Label(*row) for row in rows]

In [13]:
labels = get_labels()  # a list of named tuples, in table order

# Raw label ID -> Label. IDs are unique, so this mapping is one-to-one.
id2label = {label.id: label for label in labels}

# Category ID -> representative Label, used to draw each category in a single
# colour. Several labels share a category ID, so one of them has to be chosen.
# Iterating the table in reverse makes the FIRST label defined for a category
# win (void -> 'unlabeled' / black, vehicle -> 'car'). A plain forward
# comprehension would keep the last one instead, which renders void pixels in
# trailer's dark blue, almost indistinguishable from the vehicle colour.
catid2label = {label.categoryId: label for label in reversed(labels)}
# trainId2label = { label.trainId : label for label in labels }

In [14]:
def id_to_categoryid(mask):
    """Remap a mask of raw label IDs to category IDs.

    Args:
        mask: array indexed as ``mask[row, col, channel]``. Only channel 0 is
            read; it is compared against the raw label IDs
            ``0 .. n_classes_total - 1``.

    Returns:
        A new ``uint8`` array with the same first three dimensions as
        ``mask``. Every channel of a pixel holds ``id2label[id].categoryId``
        for that pixel's raw ID; pixels whose value is outside the iterated
        ID range stay 0.
    """
    mask_cat = np.zeros((mask.shape[0], mask.shape[1], mask.shape[2]), dtype=np.uint8)
    raw_ids = mask[:, :, 0]  # loop-invariant, so slice it once
    for raw_id in range(n_classes_total):
        mask_cat[raw_ids == raw_id] = id2label[raw_id].categoryId
    return mask_cat

In [15]:
def label_to_rgb(mask):
    """Colour a class-index mask for display.

    Channel 0 of ``mask`` is compared against each index in
    ``range(n_train_classes)``; matching pixels receive
    ``catid2label[index].color``. Pixels matching no index stay black.
    Returns a new ``(rows, cols, 3)`` uint8 array.
    """
    n_rows, n_cols = mask.shape[0], mask.shape[1]
    rgb = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)
    class_plane = mask[:, :, 0]
    for class_index in range(n_train_classes):
        rgb[class_plane == class_index] = catid2label[class_index].color
    return rgb

In [16]:
def id_to_trainid(mask):
    """Remap a mask of raw label IDs to train IDs.

    Channel 0 of ``mask`` is compared against every raw ID in
    ``range(n_classes_total)``; matching pixels are set (in all channels)
    to ``id2label[id].trainId``. Unmatched pixels stay 0. Returns a new
    uint8 array with the same first three dimensions as ``mask``.
    """
    n_rows, n_cols, n_channels = mask.shape[0], mask.shape[1], mask.shape[2]
    remapped = np.zeros((n_rows, n_cols, n_channels), dtype=np.uint8)
    for raw_id in range(n_classes_total):
        hits = mask[:, :, 0] == raw_id
        remapped[hits] = id2label[raw_id].trainId
    return remapped.astype(np.uint8)

In [17]:
def display(display_list):
    """Show the given arrays side by side in one matplotlib figure.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask', so at most three items can be shown. Each item is
    converted with ``array_to_img`` before plotting.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(15, 5))
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(panel_titles[idx])
        picture = tf.keras.preprocessing.image.array_to_img(item)
        plt.imshow(picture)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
    
    
def load_image_rgb_data(fp):
    """Load per-channel RGB statistics from a JSON file.

    Args:
        fp: path of a JSON file containing the keys ``R_MEAN``, ``G_MEAN``,
            ``B_MEAN``, ``R_STD``, ``G_STD`` and ``B_STD``.

    Returns:
        A dict with exactly those six keys, each value converted to
        ``float``.

    Raises:
        KeyError: if one of the six keys is missing from the file.
    """
    with open(fp, 'r') as openfile:
        image_info = json.load(openfile)

    # Every output key reads the identically named input key. The previous
    # hand-written dict had G_STD and B_STD crossed over.
    stat_keys = ("R_MEAN", "G_MEAN", "B_MEAN", "R_STD", "G_STD", "B_STD")
    return {key: float(image_info[key]) for key in stat_keys}


def normalize_image_channels(x_img, rgb_data):
    """Standardise each colour channel of an image in place.

    For channel ``c`` (0 = R, 1 = G, 2 = B) this computes
    ``(x_img[:, :, c] - MEAN_c) / STD_c``.

    Args:
        x_img: array indexed as ``x_img[row, col, channel]`` with at least
            three channels. It is modified in place, so it must have a
            floating-point dtype.
        rgb_data: mapping with the keys ``R_MEAN``, ``G_MEAN``, ``B_MEAN``,
            ``R_STD``, ``G_STD`` and ``B_STD``.

    Returns:
        The same ``x_img`` object, after normalisation.
    """
    # Index the channel axis explicitly. The former ``x_img[:,: 1]`` parsed as
    # ``x_img[:, :1]`` (the first column, all channels), so G and B were
    # never normalised and column 0 was normalised repeatedly.
    for channel, prefix in enumerate(("R", "G", "B")):
        x_img[:, :, channel] -= rgb_data[prefix + "_MEAN"]
        x_img[:, :, channel] /= rgb_data[prefix + "_STD"]

    return x_img
    
    
def extract_cityscape_data_info(path, img_height=None, img_width=None, subset=None, coarse=True):
    """Index the image/mask pairs of a Cityscapes-style directory.

    The directory ``path`` must contain an ``annotations`` folder with the
    ``*_labelIds.png`` masks and an ``images`` folder with the matching
    ``*_leftImg8bit.png`` images. Sample IDs are sorted and then shuffled
    with a fixed seed, so the returned order (and any train/validation split
    cut from it) is the same on every run.

    Args:
        path: dataset root directory.
        img_height: height recorded for every sample. If ``None`` it is
            measured once, from the first image, and reused afterwards.
        img_width: width recorded for every sample; handled like
            ``img_height``.
        subset: if not ``None``, stop after this many samples.
        coarse: select ``_gtCoarse_labelIds.png`` masks when true,
            ``_gtFine_labelIds.png`` masks otherwise.

    Returns:
        A list of dicts with the keys ``image_filename``, ``mask_filename``,
        ``height``, ``width``, ``image_depth`` and ``mask_depth``.
    """
    mask_suffix = "_gtCoarse_labelIds.png" if coarse else "_gtFine_labelIds.png"

    annotation_files = next(os.walk(os.path.join(path, "annotations")))[2]
    # Directory listing order is filesystem dependent; sort so the seeded
    # shuffle below always starts from the same sequence.
    sample_ids = sorted(
        name[:-len(mask_suffix)]
        for name in annotation_files
        if name.endswith(mask_suffix)
    )

    # Kept so the global `random` state is seeded exactly as before.
    random.seed(2019)
    # The shuffle itself uses NumPy, which `random.seed` does not affect;
    # a dedicated seeded generator makes the order reproducible.
    order = np.random.RandomState(2019).permutation(len(sample_ids))
    ids = [sample_ids[int(k)] for k in order]

    print("Number of images: " + str(len(ids)))

    image_data = []

    for n, id_ in enumerate(ids):
        print("\r Processing %s \\ %s " % (n+1, len(ids)), end='')

        image_filename = os.path.join(path, "images", id_ + "_leftImg8bit.png")
        mask_filename = os.path.join(path, "annotations", id_ + mask_suffix)

        img_depth = 3
        mask_depth = 1
        if img_height is None or img_width is None:
            # No target size given: read it from the files instead of
            # resizing to an undefined shape. Only the missing dimension(s)
            # are filled in, so this runs for the first sample only.
            x_img = img_to_array(load_img(image_filename))
            mask = img_to_array(load_img(mask_filename, color_mode="grayscale"))
            if img_height is None:
                img_height = int(x_img.shape[0])
            if img_width is None:
                img_width = int(x_img.shape[1])
            img_depth = int(x_img.shape[2])
            mask_depth = int(mask.shape[2])

        image_data.append({
            "image_filename": image_filename,
            "mask_filename": mask_filename,
            "height": img_height,
            "width": img_width,
            "image_depth": img_depth,
            "mask_depth": mask_depth,
        })

        if (subset is not None) and (n == subset-1):
            break

    return image_data

In [18]:
def _bytes_feature(value):
    """Wrap a string / byte value in a ``tf.train.Feature`` (bytes_list)."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList won't unpack a string from an EagerTensor, so pull the
        # underlying value out first.
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)


def _float_feature(value):
    """Wrap a float / double in a ``tf.train.Feature`` (float_list)."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)


def _int64_feature(value):
    """Wrap a bool / enum / int / uint in a ``tf.train.Feature`` (int64_list)."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)


def serialize_example(image, mask, image_shape, mask_shape):
    """Pack one image/mask pair into a serialized ``tf.train.Example``.

    ``image`` and ``mask`` are stored as bytes features ('image' and
    'segmentation'). 'height', 'width' and 'image_depth' come from the first
    three entries of ``image_shape``; 'mask_depth' is ``mask_shape[2]``.
    Returns the serialized proto string.
    """
    height, width, image_depth = image_shape[0], image_shape[1], image_shape[2]
    mask_depth = mask_shape[2]

    features = tf.train.Features(feature={
        'image': _bytes_feature(image),
        'segmentation':  _bytes_feature(mask),
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'image_depth': _int64_feature(image_depth),
        'mask_depth': _int64_feature(mask_depth),
    })
    return tf.train.Example(features=features).SerializeToString()


def write_tfrecord(
    tfrecord_dir, 
    image_data, 
    img_height=512, 
    img_width=1024, 
    id_type = "trainId", # train 
    normalize=False, 
    rgb_data=None
):
    """Write image/mask pairs to a single TFRecord file.

    Args:
        tfrecord_dir: path of the TFRecord file to create (a file, despite
            the parameter name).
        image_data: sequence of dicts with the keys ``image_filename`` and
            ``mask_filename``.
        img_height: height every image and mask is resized to.
        img_width: width every image and mask is resized to.
        id_type: ``"trainId"`` remaps mask pixels with ``id_to_trainid``,
            ``"catId"`` with ``id_to_categoryid``; any other value stores
            the raw label IDs unchanged.
        normalize: if true, images are standardised per channel with
            ``rgb_data`` and stored as float32; otherwise they are stored
            as uint8.
        rgb_data: channel statistics passed to ``normalize_image_channels``;
            required when ``normalize`` is true.

    Raises:
        ValueError: if ``normalize`` is true but ``rgb_data`` is ``None``.

    Note:
        ``read_tfrecord`` decodes images as uint8, so records written with
        ``normalize=True`` need a reader that decodes float32.
    """
    # Fail before the output file is created, rather than with an opaque
    # TypeError on the first image.
    if normalize and rgb_data is None:
        raise ValueError("rgb_data must be provided when normalize=True")

    with tf.io.TFRecordWriter(tfrecord_dir) as writer:
        for n, datapoint in enumerate(image_data):
            print("\r Writing %s \\ %s " % (n+1, len(image_data)), end='')

            # get image
            img = load_img(datapoint["image_filename"])
            img_array = img_to_array(img)
            img_array = resize(img_array, (img_height, img_width, 3),
                               mode='constant', preserve_range = True)
            if normalize:
                img_array = normalize_image_channels(img_array, rgb_data)
                # Pin the stored dtype so a reader can decode it reliably.
                img_array = img_array.astype(np.float32)
            else:
                img_array = img_array.astype(np.uint8)

            img_bytes = tf.io.serialize_tensor(img_array)
            image_shape = img_array.shape

            # get mask; nearest-neighbour keeps the pixel values valid IDs
            mask = load_img(datapoint["mask_filename"], color_mode="grayscale")
            mask_array = img_to_array(mask)
            mask_array = cv2.resize(mask_array, (img_width, img_height),
                                    interpolation = cv2.INTER_NEAREST)
            mask_array = np.expand_dims(mask_array, 2)
            mask_array = mask_array.astype(np.uint8)

            if id_type == "trainId":
                mask_array = id_to_trainid(mask_array)
            elif id_type == "catId":
                mask_array = id_to_categoryid(mask_array)

            mask_bytes = tf.io.serialize_tensor(mask_array)
            mask_shape = mask_array.shape

            example = serialize_example(img_bytes, mask_bytes, image_shape, mask_shape)
            writer.write(example)

        
def read_tfrecord(serialized_example):
    """Decode one serialized example into an ``(image, mask)`` tensor pair.

    Both tensors are parsed as uint8. The image is reshaped to
    ``[height, width, 3]`` and the mask to ``[height, width, 1]`` using the
    'height' and 'width' values stored in the record.
    """
    scalar_string = tf.io.FixedLenFeature((), tf.string)
    scalar_int64 = tf.io.FixedLenFeature((), tf.int64)
    feature_description = {
        'image': scalar_string,
        'segmentation': scalar_string,
        'height': scalar_int64,
        'width': scalar_int64,
        'image_depth': scalar_int64,
        'mask_depth': scalar_int64,
    }
    parsed = tf.io.parse_single_example(serialized_example, feature_description)
    height, width = parsed['height'], parsed['width']

    # image_raw = tf.io.parse_tensor(parsed['image'], out_type = tf.float32)
    image_raw = tf.io.parse_tensor(parsed['image'], out_type = tf.uint8)
    image = tf.reshape(image_raw, [height, width, 3])

    mask_raw = tf.io.parse_tensor(parsed['segmentation'], out_type = tf.uint8)
    mask = tf.reshape(mask_raw, [height, width, 1])

    return image, mask


def get_dataset_from_tfrecord(tfrecord_dir):
    """Open a TFRecord file as a dataset of decoded ``(image, mask)`` pairs.

    ``tfrecord_dir`` is handed to ``tf.data.TFRecordDataset`` and every
    record is decoded with ``read_tfrecord``.
    """
    raw_records = tf.data.TFRecordDataset(tfrecord_dir)
    return raw_records.map(read_tfrecord)

In [19]:
def dump_dict_to_json(data, fp):
    """Serialise ``data`` as JSON into the file at ``fp`` (overwriting it)."""
    with open(fp, 'w') as json_file:
        json.dump(data, json_file)
        
def load_dict_from_json(fp):
    """Read the JSON file at ``fp`` and return the decoded object."""
    with open(fp, 'r') as json_file:
        return json.load(json_file)

In [20]:
# Dataset root; extract_cityscape_data_info looks for "annotations" and
# "images/" underneath it.
path = "raw_data/fine/"

In [21]:
# Index every image/mask pair of the fine annotations. An explicit size is
# passed, so it is recorded as-is for every sample.
image_info = extract_cityscape_data_info(path=path, img_height=1024, img_width=2048, coarse=False)

Number of images: 3475


 Processing 1 \ 3475  Processing 2 \ 3475  Processing 3 \ 3475  Processing 4 \ 3475  Processing 5 \ 3475  Processing 6 \ 3475  Processing 7 \ 3475  Processing 8 \ 3475  Processing 9 \ 3475  Processing 10 \ 3475  Processing 11 \ 3475  Processing 12 \ 3475  Processing 13 \ 3475  Processing 14 \ 3475  Processing 15 \ 3475  Processing 16 \ 3475  Processing 17 \ 3475  Processing 18 \ 3475  Processing 19 \ 3475  Processing 20 \ 3475  Processing 21 \ 3475  Processing 22 \ 3475  Processing 23 \ 3475  Processing 24 \ 3475  Processing 25 \ 3475  Processing 26 \ 3475  Processing 27 \ 3475  Processing 28 \ 3475  Processing 29 \ 3475  Processing 30 \ 3475  Processing 31 \ 3475  Processing 32 \ 3475  Processing 33 \ 3475  Processing 34 \ 3475  Processing 35 \ 3475  Processing 36 \ 3475  Processing 37 \ 3475  Processing 38 \ 3475  Processing 39 \ 3475  Processing 40 \ 3475  Processing 41 \ 3475  Processing 42 \ 3475  Processing 43 \ 3475  Processing 44 \ 34

 Processing 3143 \ 3475  Processing 3144 \ 3475  Processing 3145 \ 3475  Processing 3146 \ 3475  Processing 3147 \ 3475  Processing 3148 \ 3475  Processing 3149 \ 3475  Processing 3150 \ 3475  Processing 3151 \ 3475  Processing 3152 \ 3475  Processing 3153 \ 3475  Processing 3154 \ 3475  Processing 3155 \ 3475  Processing 3156 \ 3475  Processing 3157 \ 3475  Processing 3158 \ 3475  Processing 3159 \ 3475  Processing 3160 \ 3475  Processing 3161 \ 3475  Processing 3162 \ 3475  Processing 3163 \ 3475  Processing 3164 \ 3475  Processing 3165 \ 3475  Processing 3166 \ 3475  Processing 3167 \ 3475  Processing 3168 \ 3475  Processing 3169 \ 3475  Processing 3170 \ 3475  Processing 3171 \ 3475  Processing 3172 \ 3475  Processing 3173 \ 3475  Processing 3174 \ 3475  Processing 3175 \ 3475  Processing 3176 \ 3475  Processing 3177 \ 3475  Processing 3178 \ 3475  Processing 3179 \ 3475  Processing 3180 \ 3475  Processing 3181 \ 3475  Processing 3182 \ 3475 

In [22]:
# For "fine" dataset
TRAIN_LENGTH = 2780
VALID_LENGTH = 695

# For "coarse" dataset
# TRAIN_LENGTH = 16000
# valid_LENGTH = 3998

In [23]:
# Split the shuffled sample list: the first TRAIN_LENGTH entries are used for
# training, everything after them for validation. Both lists are saved as JSON.
train_info = image_info[0:TRAIN_LENGTH]
valid_info = image_info[TRAIN_LENGTH:]
dump_dict_to_json(data=train_info, fp='raw_data/fine_train_info.json')
dump_dict_to_json(data=valid_info, fp='raw_data/fine_valid_info.json')

In [24]:
# Reload the two sample lists from their JSON files.
train_info = load_dict_from_json(fp='raw_data/fine_train_info.json')
valid_info = load_dict_from_json(fp='raw_data/fine_valid_info.json')

In [25]:
# Paths of the TFRecord files to write and read back (files, despite the
# `_dir` suffix).
train_tfrecord_dir = 'records/fine_train_cat.tfrecords'
valid_tfrecord_dir = 'records/fine_valid_cat.tfrecords'

In [None]:
# Write the training samples at 2048x1024. Images are stored as uint8
# (normalize=False) and masks are remapped to category IDs ("catId").
write_tfrecord(
    tfrecord_dir = train_tfrecord_dir,
    image_data = train_info,
    img_height=1024, 
    img_width=2048,
    normalize = False,
    rgb_data = None,
    id_type = "catId",
)

 Writing 1703 \ 2780 

In [None]:
# Write the validation samples with the same settings as the training file.
write_tfrecord(
    tfrecord_dir = valid_tfrecord_dir,
    image_data = valid_info,
    img_height=1024, 
    img_width=2048,
    normalize = False,
    rgb_data = None,
    id_type = "catId",
)

## Check dataset and input pipeline

In [None]:
# Datasets of decoded (image, mask) pairs read back from the TFRecord files.
train_dataset = get_dataset_from_tfrecord(train_tfrecord_dir)
valid_dataset = get_dataset_from_tfrecord(valid_tfrecord_dir)

In [None]:
# Pull one (image, mask) pair out of the training dataset as NumPy arrays.
for i, (image, mask) in enumerate(train_dataset.take(1)):
    sample_image, sample_mask = image.numpy(), mask.numpy()

In [None]:
# Inspect the shape of the decoded sample mask.
print(sample_mask.shape)

In [None]:
# Size that load_image_train / load_image_test resize images and masks to.
img_height = 512
img_width = 1024

In [None]:
def label_to_rgb(mask):
    """Colour a class-index mask for display.

    Channel 0 of ``mask`` is compared against each index in
    ``range(n_train_classes)``; matching pixels receive
    ``catid2label[index].color``. Pixels matching no index stay black.
    Returns a new ``(rows, cols, 3)`` uint8 array.
    """
    n_rows, n_cols = mask.shape[0], mask.shape[1]
    rgb = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)
    class_plane = mask[:, :, 0]
    for class_index in range(n_train_classes):
        rgb[class_plane == class_index] = catid2label[class_index].color
    return rgb


@tf.function
def mask_to_categorical(image, mask):
    """One-hot encode ``mask`` and pass ``image`` through unchanged.

    Size-1 dimensions of ``mask`` are squeezed away, the values are cast to
    int32 and expanded to ``n_train_classes`` one-hot channels, returned as
    float32.
    """
    class_ids = tf.cast(tf.squeeze(mask), tf.int32)
    one_hot_mask = tf.one_hot(class_ids, n_train_classes)
    # one_hot_mask = tf.one_hot(class_ids, n_classes_total)
    return image, tf.cast(one_hot_mask, tf.float32)

@tf.function
def load_image_train(input_image, input_mask):
    """Prepare one training sample: resize, random flip, scale and one-hot.

    Args:
        input_image: image tensor; resized to the module-level
            ``(img_height, img_width)`` and divided by 255.
        input_mask: class-index mask tensor; resized to the same size.

    Returns:
        ``(image, mask)`` where ``image`` is float32 and ``mask`` is the
        one-hot encoding produced by ``mask_to_categorical``.
    """
    input_image = tf.image.resize(input_image, (img_height, img_width))
    # Masks hold class indices, not intensities: the default bilinear
    # interpolation would blend neighbouring IDs into values that belong to
    # other classes. Nearest-neighbour only ever copies existing labels.
    input_mask = tf.image.resize(
        input_mask,
        (img_height, img_width),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    # Flip image and mask together so they stay aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image, input_mask = mask_to_categorical(input_image, input_mask)
    input_mask = tf.squeeze(input_mask)

    return input_image, input_mask


def load_image_test(input_image, input_mask):
    """Prepare one evaluation sample: resize, scale and one-hot (no flip).

    Args:
        input_image: image tensor; resized to the module-level
            ``(img_height, img_width)`` and divided by 255.
        input_mask: class-index mask tensor; resized to the same size.

    Returns:
        ``(image, mask)`` where ``image`` is float32 and ``mask`` is the
        one-hot encoding produced by ``mask_to_categorical``.
    """
    input_image = tf.image.resize(input_image, (img_height, img_width))
    # Masks hold class indices, not intensities: the default bilinear
    # interpolation would blend neighbouring IDs into values that belong to
    # other classes. Nearest-neighbour only ever copies existing labels.
    input_mask = tf.image.resize(
        input_mask,
        (img_height, img_width),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image, input_mask = mask_to_categorical(input_image, input_mask)
    input_mask = tf.squeeze(input_mask)

    return input_image, input_mask

In [None]:
# Colour the sample mask and show it next to its image.
sample_mask = label_to_rgb(sample_mask)
display([sample_image, sample_mask])

In [None]:
# Apply the preprocessing functions: training samples go through
# load_image_train (with parallel calls), validation samples through
# load_image_test.
train = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
valid = valid_dataset.map(load_image_test)

In [None]:
# Take one preprocessed (image, one-hot mask) pair; kept as tensors here.
for i, (image, mask) in enumerate(train.take(1)):
    sample_image, sample_mask = image, mask

In [None]:
# Collapse the one-hot mask back to class indices, restore a trailing channel
# axis (label_to_rgb reads channel 0), then colour and display it.
sample_mask = tf.argmax(sample_mask, axis=-1)
sample_mask = sample_mask[..., tf.newaxis]
sample_mask = label_to_rgb(sample_mask.numpy())
display([sample_image, sample_mask])