In [1]:
from __future__ import print_function, absolute_import, division
from collections import namedtuple

import os
import sys
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.losses import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, LambdaCallback
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img
from tensorflow.keras import mixed_precision

from models.u2net import U2NET, U2NET_lite
from models.unet import unet, unet_small
from models.deeplabv3_xception import deeplabv3, deeplabv3_lite

# Ask TensorFlow to grow GPU memory on demand. The setting is applied to every
# visible GPU (TensorFlow requires it to be consistent across devices), and the
# loop is simply skipped on a CPU-only machine instead of failing on index 0.
physical_devices = tf.config.experimental.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
K.clear_session()


def enable_amp():
    """Set the global Keras dtype policy to ``mixed_float16``."""
    mixed_precision.set_global_policy("mixed_float16")


print("Tensorflow version: ", tf.__version__)
print(physical_devices,"\n")
enable_amp()

In [None]:
%config InlineBackend.figure_format = 'retina'
# Apply the ggplot theme, then enlarge the tick labels on both axes.
plt.style.use('ggplot')
for _tick_group in ('xtick', 'ytick'):
    plt.rc(_tick_group, labelsize=16)

In [None]:
def get_labels():
    """Build the table of segmentation labels used throughout the notebook.

    Returns:
        list: one ``Label`` namedtuple per class, in table order, with fields

        * ``name``         -- unique identifier of the label ('car', 'person', ...).
        * ``id``           -- integer ID used in ground-truth images. ``-1`` means
          the label has no ID and is ignored when ground truth is created
          (e.g. license plate). These IDs are the ones the evaluation server
          expects and must not be modified.
        * ``trainId``      -- ID used for training; may be adapted freely. Several
          labels may share one trainId (they are then merged into one class).
          Maximum value is 255; 255 marks labels one may want to ignore
          during training.
        * ``category``     -- name of the category the label belongs to.
        * ``categoryId``   -- ID of that category, used for category-level
          ground truth.
        * ``hasInstances`` -- whether single instances are distinguished.
        * ``ignoreInEval`` -- whether pixels of this class are ignored during
          evaluation.
        * ``color``        -- RGB colour of the label.
    """
    field_names = [
        'name',
        'id',
        'trainId',
        'category',
        'categoryId',
        'hasInstances',
        'ignoreInEval',
        'color',
    ]
    Label = namedtuple('Label', field_names)

    # The train IDs below are only a suggestion and can be adapted; results must
    # still be reported with the original IDs. Many IDs are ignored in
    # evaluation and never need to be predicted.
    # Want to evaluate model on ids [7,8,11,12,13,17,19,20,21,22,23,24,25,26,27,28,31,32,33]
    rows = [
        # name                    id  trainId  category        catId  hasInst  ignore  color
        ('unlabeled'            ,  0, 255, 'void'        , 0, False, True , (  0,  0,  0)),
        ('ego vehicle'          ,  1, 255, 'void'        , 0, False, True , (  0,  0,  0)),
        ('rectification border' ,  2, 255, 'void'        , 0, False, True , (  0,  0,  0)),
        ('out of roi'           ,  3, 255, 'void'        , 0, False, True , (  0,  0,  0)),
        ('static'               ,  4, 255, 'void'        , 0, False, True , (  0,  0,  0)),
        ('dynamic'              ,  5, 255, 'void'        , 0, False, True , (111, 74,  0)),
        ('ground'               ,  6, 255, 'void'        , 0, False, True , ( 81,  0, 81)),
        ('road'                 ,  7,   0, 'flat'        , 1, False, False, (128, 64,128)),
        ('sidewalk'             ,  8,   1, 'flat'        , 1, False, False, (244, 35,232)),
        ('parking'              ,  9, 255, 'flat'        , 0, False, True , (250,170,160)),
        ('rail track'           , 10, 255, 'flat'        , 0, False, True , (230,150,140)),
        ('building'             , 11,   2, 'construction', 2, False, False, ( 70, 70, 70)),
        ('wall'                 , 12,   3, 'construction', 2, False, False, (102,102,156)),
        ('fence'                , 13,   4, 'construction', 2, False, False, (190,153,153)),
        ('guard rail'           , 14, 255, 'construction', 0, False, True , (180,165,180)),
        ('bridge'               , 15, 255, 'construction', 0, False, True , (150,100,100)),
        ('tunnel'               , 16, 255, 'construction', 0, False, True , (150,120, 90)),
        ('pole'                 , 17,   5, 'object'      , 3, False, False, (153,153,153)),
        ('polegroup'            , 18, 255, 'object'      , 0, False, True , (153,153,153)),
        ('traffic light'        , 19,   6, 'object'      , 3, False, False, (250,170, 30)),
        ('traffic sign'         , 20,   7, 'object'      , 3, False, False, (220,220,  0)),
        ('vegetation'           , 21,   8, 'nature'      , 4, False, False, (107,142, 35)),
        ('terrain'              , 22,   9, 'nature'      , 4, False, False, (152,251,152)),
        ('sky'                  , 23,  10, 'sky'         , 5, False, False, ( 70,130,180)),
        ('person'               , 24,  11, 'human'       , 6, True , False, (220, 20, 60)),
        ('rider'                , 25,  12, 'human'       , 6, True , False, (255,  0,  0)),
        ('car'                  , 26,  13, 'vehicle'     , 7, True , False, (  0,  0,142)),
        ('truck'                , 27,  14, 'vehicle'     , 7, True , False, (  0,  0, 70)),
        ('bus'                  , 28,  15, 'vehicle'     , 7, True , False, (  0, 60,100)),
        ('caravan'              , 29, 255, 'vehicle'     , 0, True , True , (  0,  0, 90)),
        ('trailer'              , 30, 255, 'vehicle'     , 0, True , True , (  0,  0,110)),
        ('train'                , 31,  16, 'vehicle'     , 7, True , False, (  0, 80,100)),
        ('motorcycle'           , 32,  17, 'vehicle'     , 7, True , False, (  0,  0,230)),
        ('bicycle'              , 33,  18, 'vehicle'     , 7, True , False, (119, 11, 32)),
        ('license plate'        , -1,  -1, 'vehicle'     , 7, False, True , (  0,  0,142)),
    ]

    return [Label(*row) for row in rows]

In [None]:
def read_tfrecord(serialized_example):
    """Decode one serialized example into an ``(image, mask)`` pair of uint8 tensors.

    The record is expected to carry the features 'image', 'segmentation',
    'height', 'width', 'image_depth' and 'mask_depth'. The two depth features
    are parsed but not used: the image is reshaped to 3 channels and the mask
    to 1 channel, using the stored height and width.
    """
    bytes_feature = tf.io.FixedLenFeature((), tf.string)
    int_feature = tf.io.FixedLenFeature((), tf.int64)
    schema = {
        'image': bytes_feature,
        'segmentation': bytes_feature,
        'height': int_feature,
        'width': int_feature,
        'image_depth': int_feature,
        'mask_depth': int_feature,
    }
    parsed = tf.io.parse_single_example(serialized_example, schema)
    height, width = parsed['height'], parsed['width']

    def _decode(key, channels):
        # Payloads are serialized uint8 tensors; restore their (H, W, C) layout.
        flat = tf.io.parse_tensor(parsed[key], out_type = tf.uint8)
        return tf.reshape(flat, [height, width, channels])

    return _decode('image', 3), _decode('segmentation', 1)


def get_dataset_from_tfrecord(tfrecord_dir):
    """Open the TFRecord file(s) at ``tfrecord_dir`` and decode every record
    with ``read_tfrecord``, yielding ``(image, mask)`` pairs."""
    return tf.data.TFRecordDataset(tfrecord_dir).map(read_tfrecord)

In [None]:
# Select which annotation set to train on: fine (True) or coarse (False).
fine = True

# Prefix prepended to "plots/..." by the plotting cells further down when they
# save figures. It was referenced but never defined, so a fresh kernel failed
# with a NameError.
# NOTE(review): the empty string makes "plots/" relative to the working
# directory, like "records/" and "weights/" -- confirm the intended location.
DATA_ROOT = ""

if fine:
    train_tfrecord_dir = 'records/fine_train_cat.tfrecords'
    test_tfrecord_dir = 'records/fine_test_cat.tfrecords'
else:
    train_tfrecord_dir = 'records/coarse_train_cat.tfrecords'
    test_tfrecord_dir = 'records/coarse_test_cat.tfrecords'


# Size every image/mask is resized to, and the number of one-hot mask channels.
img_height = 512
img_width = 1024
n_classes = 8

labels = get_labels()
# Lookup tables keyed by trainId / categoryId. Several labels share the same
# key; with a dict comprehension the label defined LAST in the table wins.
trainId2label = { label.trainId : label for label in labels }
catId2label = { label.categoryId : label for label in labels }

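The `@tf.function` decorator compiles these preprocessing functions into TensorFlow graphs, which is intended to make them faster. (`Dataset.map` already traces the function it is given into a graph, so the extra speed-up here may be small.)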

In [None]:
@tf.function
def mask_to_categorical(image, mask):
    """One-hot encode ``mask`` into ``n_classes`` float32 channels.

    Size-1 dimensions of the mask are squeezed away before encoding; ``image``
    is passed through unchanged so the function can be used on (image, mask)
    pairs.
    """
    class_ids = tf.cast(tf.squeeze(mask), tf.int32)
    one_hot_mask = tf.cast(tf.one_hot(class_ids, n_classes), tf.float32)
    return image, one_hot_mask


@tf.function
def load_image_train(input_image, input_mask):
    """Prepare one training example.

    Resizes image and mask to ``(img_height, img_width)``, applies a random
    horizontal flip to both, scales the image to [0, 1] and one-hot encodes
    the mask.

    Returns:
        (image, mask): float32 image and float32 one-hot mask with
        ``n_classes`` channels.
    """
    target_size = (img_height, img_width)
    input_image = tf.image.resize(input_image, target_size)
    # The mask holds class ids, not intensities: it must be resized with
    # nearest-neighbour sampling. The default bilinear method blends the ids
    # of neighbouring classes into values that belong to unrelated classes.
    input_mask = tf.image.resize(input_mask, target_size, method="nearest")

    # Flip image and mask together so they stay aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image, input_mask = mask_to_categorical(input_image, input_mask)

    return input_image, input_mask


def load_image_test(input_image, input_mask):
    """Prepare one evaluation example (no augmentation).

    Resizes image and mask to ``(img_height, img_width)``, scales the image to
    [0, 1] and one-hot encodes the mask.

    Returns:
        (image, mask): float32 image and float32 one-hot mask with
        ``n_classes`` channels.
    """
    target_size = (img_height, img_width)
    input_image = tf.image.resize(input_image, target_size)
    # Class ids must not be interpolated: bilinear resizing (the default) would
    # invent wrong labels along class boundaries, so use nearest-neighbour.
    input_mask = tf.image.resize(input_mask, target_size, method="nearest")

    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_image, input_mask = mask_to_categorical(input_image, input_mask)

    return input_image, input_mask

In [None]:
# Split sizes for the selected annotation set (fine vs. coarse); used below to
# derive the number of steps per epoch and the number of evaluated samples.
TRAIN_LENGTH, TEST_LENGTH = (2780, 695) if fine else (16000, 3998)

BATCH_SIZE = 8       # examples per batch
BUFFER_SIZE = 1000   # shuffle buffer size for the training pipeline

In [None]:
# Decode both splits from their TFRecord files into (image, mask) datasets.
train_tfrecords_dataset, test_tfrecords_dataset = (
    get_dataset_from_tfrecord(record_path)
    for record_path in (train_tfrecord_dir, test_tfrecord_dir)
)

In [None]:
# Preprocessing: resize the images and masks, scale the images to [0, 1] and
# one-hot encode the masks; training examples are also flipped at random.
# Both splits are mapped in parallel (element order is still preserved).
train = train_tfrecords_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = test_tfrecords_dataset.map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Training pipeline: shuffle -> batch -> repeat forever -> prefetch.
train_dataset = (
    train
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
# Evaluation pipeline: plain batching, no shuffling or repetition.
test_dataset = test.batch(BATCH_SIZE)

In [None]:
def label_to_rgb(mask):
    """Colour a class-index mask using the category colours.

    Args:
        mask: array of shape (H, W, 1) (only channel 0 is read) holding class
            indices in ``range(n_classes)``.

    Returns:
        np.ndarray: uint8 RGB image of shape (H, W, 3). Pixels whose value is
        outside ``range(n_classes)`` stay black.
    """
    class_map = np.asarray(mask)[:, :, 0]
    # Size the output from the mask itself rather than from the global
    # img_height/img_width, so masks of any resolution can be coloured.
    mask_rgb = np.zeros(class_map.shape + (3,), dtype=np.uint8)
    for class_id in range(n_classes):
        mask_rgb[class_map == class_id] = catId2label[class_id].color
    return mask_rgb


def display(display_list, title=True):
    """Show the given images side by side in a single row.

    Args:
        display_list: images (arrays/tensors accepted by ``array_to_img``).
        title: ``True`` (default) labels the panels 'Input Image',
            'True Mask', 'Predicted Mask'; a falsy value shows no titles;
            a list/tuple of strings supplies custom titles. Panels without a
            matching title are left untitled, so passing more than three
            images no longer raises IndexError.
    """
    if isinstance(title, (list, tuple)):
        titles = list(title)
    elif title:
        titles = ['Input Image', 'True Mask', 'Predicted Mask']
    else:
        titles = []

    n_panels = len(display_list)
    plt.figure(figsize=(15, 5))
    for i, item in enumerate(display_list):
        plt.subplot(1, n_panels, i+1)
        if i < len(titles):
            plt.title(titles[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Keep one preprocessed training example around for visual sanity checks.
for image, mask in train.take(1):
    sample_image = image
    sample_mask = mask

# One-hot mask -> class indices with a trailing channel axis -> RGB colours.
sample_mask = label_to_rgb(
    tf.argmax(sample_mask, axis=-1)[..., tf.newaxis].numpy()
)
display([sample_image, sample_mask])

A standard way to train an image segmentation model is to use the categorical crossentropy loss. This works reasonably well, but it is biased towards classes that generally cover a larger part of the image (e.g. roads, trees). So here we define a custom loss function that combines the categorical crossentropy with the average IoU coefficient: `loss = CCE - IoU + 1`, so that a higher IoU lowers the loss.

In [None]:
def iou_coef(y_true, y_pred):
    """Smoothed (soft) IoU, averaged over the batch and class axes.

    With ``smooth = 1`` a class that is absent from both ``y_true`` and
    ``y_pred`` contributes 1.0 to the mean.

    Args:
        y_true: one-hot ground truth, indexed (batch, height, width, class).
        y_pred: predicted class probabilities with the same layout.

    Returns:
        Scalar float32 tensor.
    """
    smooth = 1.0
    # The notebook enables the mixed_float16 policy, so the inputs may arrive
    # as float16. Summing over a whole 512x1024 image can exceed the float16
    # maximum (65504) and turn the ratio into inf/NaN; accumulate in float32.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = K.sum(y_true * y_pred, axis=(1,2))
    mask_sum = K.sum(y_true, axis=(1,2)) + K.sum(y_pred, axis=(1,2))
    union = mask_sum - intersection
    iou = K.mean((intersection + smooth) / (union + smooth))
    return iou


def cce_iou_loss(y_true, y_pred):
    """Combined loss: per-pixel categorical crossentropy minus the IoU
    coefficient, shifted by +1.

    Both inputs are cast to float32 first because the notebook enables the
    mixed_float16 policy and the predictions may therefore be float16, which
    is too coarse for the crossentropy and IoU computations.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    cce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return (cce - iou_coef(y_true, y_pred)) + 1

In [None]:
# Reset Keras' global state before building a (new) model.
K.clear_session()

# Alternative architectures -- uncomment one of these instead of the active
# constructor below to try it:
# model = unet(input_height=img_height, input_width=img_width, n_classes=n_classes, act="relu")
# model = unet_small(input_height=img_height, input_width=img_width, n_classes=n_classes, act="relu")
# model = U2NET(input_height=img_height, input_width=img_width, n_classes=n_classes)
# model = U2NET_lite(input_height=img_height, input_width=img_width, n_classes=n_classes)
# model = deeplabv3(input_height=img_height, input_width=img_width, n_classes=n_classes)
model = deeplabv3_lite(
    input_height=img_height,
    input_width=img_width,
    n_classes=n_classes,
)

Note: the U2Net model has six different outputs (used for training), so its `model.predict()` function returns a list of six tensors; the code below uses the first one.

In [None]:
def create_mask(pred_mask):
    """Turn a model prediction into an RGB mask: squeeze size-1 dimensions,
    take the argmax over the last axis and colour the result with
    ``label_to_rgb``."""
    class_ids = tf.argmax(tf.squeeze(pred_mask), axis=-1)
    return label_to_rgb(class_ids[..., tf.newaxis].numpy())


def show_predictions():
    """Run the global ``model`` on ``sample_image`` and display the input, the
    true mask and the predicted mask side by side."""
    outputs = model.predict(sample_image[tf.newaxis, ...])
    # For the model named "u2net" only the first predicted output is shown.
    prediction = outputs[0] if model.name == "u2net" else outputs
    display([sample_image, sample_mask, create_mask(prediction)])

        
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Replace the previous cell output rather than appending to it.
        clear_output(wait=True)
        show_predictions()
        message = '\nSample Prediction after epoch {}\n'.format(epoch+1)
        print (message)

In [None]:
# Show the prediction of the freshly built model on the sample image.
show_predictions()

In [None]:
# Draw the model architecture, annotated with tensor shapes.
plot_model(model, show_shapes=True)

In [None]:
# Weights produced by (or to be produced by) a coarse-annotation run.
COARSE_WEIGHTS_PATH = "weights/"+model.name+"_coarse.h5"

# Where the best checkpoint of THIS run is written.
if fine:
    MODEL_PATH = "weights/"+model.name+".h5"
else:
    MODEL_PATH = COARSE_WEIGHTS_PATH

# Warm-start from the coarse weights when they exist. On a first run the file
# is not there yet, so fall back to the current initialization instead of
# failing inside load_weights.
if os.path.exists(COARSE_WEIGHTS_PATH):
    model.load_weights(COARSE_WEIGHTS_PATH)
else:
    print("No weights found at {}; keeping the current initialization.".format(COARSE_WEIGHTS_PATH))

In [None]:
# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
model.compile(
    optimizer = Adam(learning_rate=1e-4),
    loss = cce_iou_loss, 
    metrics = ['accuracy', iou_coef]
)

In [None]:
# Patience (in epochs as seen by Keras) for early stopping and for
# learning-rate reduction; the fine run is given more patience.
ES_patience, RLR_patience = (10, 5) if fine else (4, 2)
EPOCHS = 20

In [None]:
# All three monitoring callbacks watch the validation loss (lower is better).
callbacks = [
    # Redraw the sample prediction after each epoch.
    DisplayCallback(),
    # Keep only the weights of the best epoch so far.
    ModelCheckpoint(
        MODEL_PATH,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=True,
        verbose=0,
    ),
    # Stop once the validation loss has stalled for ES_patience epochs.
    EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=ES_patience,
        verbose=1,
    ),
    # Divide the learning rate by 10 after RLR_patience stalled epochs.
    ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        patience=RLR_patience,
        factor=0.1,
        min_lr=1e-10,
        verbose=1,
    ),
]

In [None]:
# Each Keras "epoch" covers only 1/SUBSPLITS of a pass over the data.
SUBSPLITS = 4
STEPS_PER_EPOCH = TRAIN_LENGTH // (BATCH_SIZE * SUBSPLITS)
VALIDATION_STEPS = TEST_LENGTH // (BATCH_SIZE * SUBSPLITS)

In [None]:
# Train for EPOCHS full passes, expressed as EPOCHS*SUBSPLITS shorter epochs.
results = model.fit(
    train_dataset,
    epochs=(EPOCHS*SUBSPLITS),
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=test_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Save the weights of the last completed epoch. PATH_fine / PATH_coarse were
# never defined, so this cell failed with a NameError. The file name is derived
# from MODEL_PATH with a "_last" suffix so the best checkpoint written by
# ModelCheckpoint is not overwritten.
# NOTE(review): confirm the "_last" naming suits the project layout.
FINAL_WEIGHTS_PATH = os.path.splitext(MODEL_PATH)[0] + "_last.h5"
model.save_weights(FINAL_WEIGHTS_PATH)

In [None]:
# results = model.history

In [None]:
def plot_history(results, model):
    """Plot loss, accuracy and IoU learning curves and save them as a PNG.

    Args:
        results: object with a ``history`` dict of per-epoch metric lists (as
            returned by ``model.fit``).
        model: the trained model; its ``name`` selects the history keys and is
            used in the titles and in the output file name.

    The figure is written to ``DATA_ROOT + "plots/"`` (the directory is
    created if missing) with a ``_coarse`` suffix when ``fine`` is False.
    """
    # For the model named "u2net" the curves are read from the "d0_"-prefixed
    # history entries; every other model uses the plain metric names.
    prefix = "d0_" if model.name == "u2net" else ""
    panels = [
        # (history key, title prefix, training label, validation label)
        ("loss", "Loss: ", "Training loss", "Validation loss"),
        ("accuracy", "Accuracy: ", "Training accuracy", "Validation accuracy"),
        ("iou_coef", "IoU Coefficient: ", "IoU coefficient", "Validation IoU coefficient"),
    ]

    plt.figure(figsize=(15,7))
    for position, (metric, heading, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(results.history[prefix + metric], 'r', label=train_label)
        plt.plot(results.history['val_' + prefix + metric], 'b', label=val_label)
        plt.title(heading + model.name, fontsize=16)
        plt.xlabel('Epoch', fontsize=16)
        plt.legend(prop={'size': 14})

    plot_dir = DATA_ROOT + "plots/"
    # savefig does not create missing directories.
    os.makedirs(plot_dir, exist_ok=True)
    suffix = "_learning_curves.png" if fine else "_learning_curves_coarse.png"
    plt.savefig(plot_dir + model.name + suffix)
    plt.show()

In [None]:
# Plot (and save) the learning curves of the training run above.
plot_history(results, model)

In [None]:
def get_mean_iou(y_true, y_pred, num_classes=None):
    """Smoothed per-class IoU and its mean, skipping class 0.

    Absent classes contribute to the class mean as 1.0 (smoothing term of 1).

    Args:
        y_true: one-hot ground truth indexed (batch, height, width, class).
        y_pred: predicted class scores with the same layout.
        num_classes: number of class channels; defaults to the notebook-level
            ``n_classes``.

    Returns:
        (iou_class, iou_mean): list with one IoU per class ``1..num_classes-1``
        and the mean of exactly those values.
    """
    if num_classes is None:
        num_classes = n_classes
    smooth = 1.0
    y_true = np.asarray(y_true, dtype=np.float32)
    y_pred = np.asarray(y_pred, dtype=np.float32)

    iou_class = []
    for i in range(1, num_classes):
        true_i = y_true[:, :, :, i]
        pred_i = y_pred[:, :, :, i]
        intersection = np.sum(true_i * pred_i, axis=(1, 2))
        union = np.sum(true_i + pred_i, axis=(1, 2)) - intersection
        iou_class.append(np.mean((intersection + smooth) / (union + smooth)))

    # Average over the classes that were actually evaluated. Dividing by
    # num_classes (which still counts the skipped class 0) understated the mean.
    iou_mean = np.mean(iou_class) if iou_class else np.float32(0.0)
    return iou_class, iou_mean


def evaluate_iou(model, dataset, n_samples):
    """Average the per-class and mean IoU over up to ``n_samples`` examples.

    Args:
        model: model exposing ``predict`` and ``name``.
        dataset: iterable of unbatched ``(image, mask)`` tensor pairs.
        n_samples: maximum number of examples to evaluate.

    Returns:
        (iou_class, iou_mean): array with the average IoU of classes
        ``1..n_classes-1`` and the average of the per-example mean IoU.
    """
    iou_class_scores = np.zeros((n_samples, n_classes-1))
    iou_mean_scores = np.zeros((n_samples,))
    n_done = 0

    for idx, (image, mask) in enumerate(dataset):
        # "\\" keeps the printed text unchanged while avoiding the invalid
        # "\ " escape sequence.
        print("\r Predicting {} \\ {} ".format(idx+1, n_samples), end='')

        # The model expects a batch axis; add one to both image and mask.
        X = image.numpy()
        y_true = np.expand_dims(mask.numpy(), axis=0)
        y_pred = model.predict(np.expand_dims(X, axis=0))
        if model.name == "u2net":
            y_pred = y_pred[0]

        iou_class, iou_mean = get_mean_iou(y_true, y_pred)
        iou_class_scores[idx] = iou_class
        iou_mean_scores[idx] = iou_mean
        n_done = idx + 1

        if idx == (n_samples-1):
            break

    if n_done == 0:
        # Nothing was evaluated: return zeros, as the pre-allocated arrays did.
        return np.zeros(n_classes-1), 0.0

    # Average only the rows that were filled in, so a dataset shorter than
    # n_samples does not drag the scores towards zero.
    return np.mean(iou_class_scores[:n_done], axis=0), np.mean(iou_mean_scores[:n_done])

In [None]:
# Score the model on the unbatched test examples, one image at a time.
iou_class, iou_mean = evaluate_iou(model=model, dataset=test, n_samples=TEST_LENGTH)

In [None]:
# Report the mean IoU over the evaluated examples.
print("IoU Score: {:.4f}".format(iou_mean))

In [None]:
def plot_iou_trainId(Id2label, n_classes, iou_class, model, iou_mean):
    """Horizontal bar chart of per-class IoU, one bar per id in
    ``range(n_classes)``, coloured by the category of each class.

    The figure is saved under ``DATA_ROOT + "plots/"`` (``_coarse`` suffix when
    ``fine`` is False) and then shown.

    NOTE(review): ``evaluate_iou`` returns ``n_classes - 1`` scores, whereas
    this function plots ``n_classes`` names -- confirm ``iou_class`` has one
    entry per plotted class before using it.
    """
    class_ids = range(n_classes)
    categories = [Id2label[i].category for i in class_ids]

    # Colours come from the active matplotlib colour cycle; 'void' is black.
    palette = [entry['color'] for entry in plt.rcParams['axes.prop_cycle']]
    category_order = ['flat', 'construction', 'object', 'nature', 'sky', 'human', 'vehicle']
    cat_colors = {'void': 'black'}
    for position, category_name in enumerate(category_order):
        cat_colors[category_name] = palette[position]
    colors = [cat_colors[category] for category in categories]

    names = [Id2label[i].name for i in class_ids]

    plt.figure(figsize=(14,10), dpi=200)
    plt.barh(names, iou_class, color=colors)
    plt.xlabel("IoU Coefficient: ", fontsize=18)
    plt.ylabel("Class Name", fontsize=18)
    plt.title("Class IoU Scores for {} - Average: {:.3f}".format(model.name, iou_mean), fontsize=22)
    plt.xlim([0, 1])
    file_suffix = "_iou_scores.png" if fine else "_iou_scores_coarse.png"
    plt.savefig(DATA_ROOT+"plots/"+model.name+file_suffix)
    plt.show()
    
    
def plot_iou_catId(Id2label, n_classes, iou_class, model, iou_mean):
    """Horizontal bar chart of per-category IoU for ids ``1..n_classes-1``.

    Args:
        Id2label: mapping from category id to a label with ``category`` and
            ``color`` attributes.
        n_classes: number of categories including id 0, which is not plotted.
        iou_class: one IoU score per plotted category, in id order.
        model: model whose ``name`` appears in the title and file name.
        iou_mean: mean IoU shown in the title.

    The figure is saved under ``DATA_ROOT + "plots/"`` (created if missing,
    ``_coarse`` suffix when ``fine`` is False) and then shown.
    """
    # matplotlib.colors is not imported at notebook level; the original
    # reference to a bare `colors` name raised a NameError.
    from matplotlib import colors as mcolors

    plotted_ids = range(1, n_classes)
    categories = [Id2label[i].category for i in plotted_ids]
    # Each bar uses its own label's 0-255 RGB colour, taken from the mapping
    # that was passed in rather than from the global catId2label.
    _colors = [
        mcolors.to_hex(tuple(channel / 255 for channel in Id2label[i].color))
        for i in plotted_ids
    ]

    plt.figure(figsize=(14,10))
    plt.barh(categories, iou_class, color=_colors)
    plt.xlabel("IoU Coefficient", fontsize=18)
    plt.ylabel("Category Name", fontsize=18)
    plt.title("Category IoU Scores for {} - Average: {:.3f}".format(model.name, iou_mean), fontsize=22)
    plt.xlim([0, 1])

    plot_dir = DATA_ROOT + "plots/"
    # savefig does not create missing directories.
    os.makedirs(plot_dir, exist_ok=True)
    suffix = "_iou_scores.png" if fine else "_iou_scores_coarse.png"
    plt.savefig(plot_dir + model.name + suffix)
    plt.show()

In [None]:
# plot_iou_trainId(Id2label=trainId2label, n_classes=n_classes, iou_class=iou_class, model)

In [None]:
# Plot (and save) the per-category IoU scores computed above.
plot_iou_catId(Id2label=catId2label, n_classes=n_classes, iou_class=iou_class, model=model, iou_mean=iou_mean)