In [None]:
import os
import sys
import sklearn
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.mixed_precision import experimental as mixed_precision

from label_utils import get_labels, get_train_labels
from custom_models import unet, segnet, unet_xception
from deeplabV3_xception import deeplabv3
from plot_utils import plot_history, plot_dice_and_iou

# Reset Keras' global state so re-running the notebook starts from a clean session.
K.clear_session()
# GPUs visible to TensorFlow; printed below as a quick sanity check.
physical_devices = tf.config.experimental.list_physical_devices(device_type="GPU")

def enable_amp():
    """Set the global Keras mixed-precision policy to 'mixed_float16'.

    Uses the `experimental` mixed-precision API imported at the top of the
    notebook (`mixed_precision.Policy` / `mixed_precision.set_policy`).
    """
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))
    
# Environment report: TensorFlow version and the GPUs listed above.
print("Tensorflow version: ", tf.__version__, sep=" ")
print(physical_devices)
# Mixed precision is off by default; uncomment to switch to the mixed_float16 policy.
# enable_amp()

In [None]:
def read_tfrecord(serialized_example):
    """Decode one serialized tf.train.Example into an (image, mask) pair.

    The record must carry the features listed in `schema` below. The
    'image' and 'segmentation' bytes are decoded with tf.io.parse_tensor
    as uint8 and reshaped to (height, width, 3) and (height, width, 1).
    'image_depth' and 'mask_depth' are required by the schema but their
    values are not used; the channel counts are fixed at 3 and 1.

    Returns:
        (image, mask): uint8 tensors of shape (height, width, 3) and
        (height, width, 1).
    """
    bytes_feature = tf.io.FixedLenFeature((), tf.string)
    int_feature = tf.io.FixedLenFeature((), tf.int64)
    schema = {
        'image': bytes_feature,
        'segmentation': bytes_feature,
        'height': int_feature,
        'width': int_feature,
        'image_depth': int_feature,
        'mask_depth': int_feature,
    }
    parsed = tf.io.parse_single_example(serialized_example, schema)
    height, width = parsed['height'], parsed['width']

    def _decode(key, channels):
        # Stored as a serialized uint8 tensor; restore its spatial layout.
        flat = tf.io.parse_tensor(parsed[key], out_type=tf.uint8)
        return tf.reshape(flat, [height, width, channels])

    image = _decode('image', 3)
    mask = _decode('segmentation', 1)
    return image, mask


def get_dataset_from_tfrecord(tfrecord_dir):
    """Build a tf.data pipeline of decoded (image, mask) pairs.

    Despite its name, `tfrecord_dir` is passed straight to
    tf.data.TFRecordDataset, which takes TFRecord file name(s), not a
    directory. Each record is decoded with `read_tfrecord`.
    """
    return tf.data.TFRecordDataset(tfrecord_dir).map(read_tfrecord)

In [None]:
# Dataset locations. os.path.join picks the platform's separator, so the
# notebook runs on Linux/macOS as well as Windows.
train_tfrecord_dir = os.path.join('Cityscapes', 'fine_train.tfrecords')
test_tfrecord_dir = os.path.join('Cityscapes', 'fine_test.tfrecords')

# Network input size: every image and mask is resized to this.
img_height = 256
img_width = 512
# Number of training classes; masks are one-hot encoded to this depth.
n_classes = 19

labels = get_labels()
# Lookups from label id / train id to the label record. If several labels
# share a trainId (e.g. an ignore value), the one appearing last in
# `labels` wins.
id2label = {label.id: label for label in labels}
trainId2label = {label.trainId: label for label in labels}

In [None]:
@tf.function
def mask_to_categorical(image, mask):
    """One-hot encode a class-id mask; the image is passed through unchanged.

    All size-1 dimensions of `mask` are squeezed away (so an (H, W, 1) mask
    becomes (H, W)), the ids are cast to int32 and encoded to depth
    `n_classes` as float32. Ids outside 0..n_classes-1 encode to all-zero
    vectors, which is tf.one_hot's behaviour for out-of-range indices.

    Returns:
        (image, one_hot_mask)
    """
    class_ids = tf.cast(tf.squeeze(mask), tf.int32)
    one_hot_mask = tf.one_hot(class_ids, depth=n_classes, dtype=tf.float32)
    return image, one_hot_mask


@tf.function
def load_image_train(input_image, input_mask):
    """Training-time preprocessing for one (image, mask) pair.

    - Resizes both to (img_height, img_width).
    - Applies the same random horizontal flip to both, with probability 0.5.
    - Scales the image to float32 in [0, 1].
    - One-hot encodes the mask via `mask_to_categorical`.

    The mask holds integer class ids, so it is resized with nearest-neighbour
    interpolation. Bilinear (tf.image.resize's default) would blend
    neighbouring ids into values that belong to neither class (e.g. ids 2 and
    8 averaging to 5) before the int cast in `mask_to_categorical`.

    Returns:
        (image, one_hot_mask): float32 tensors of shape
        (img_height, img_width, 3) and (img_height, img_width, n_classes).
    """
    input_image = tf.image.resize(input_image, (img_height, img_width))
    input_mask = tf.image.resize(
        input_mask, (img_height, img_width),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # One draw decides the flip for both tensors, keeping them aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image = tf.cast(input_image, tf.float32) / 255.0
    # The one-hot mask is already (H, W, n_classes); no extra squeeze, which
    # would drop the class axis if n_classes were 1.
    return mask_to_categorical(input_image, input_mask)


def load_image_test(input_image, input_mask):
    """Evaluation-time preprocessing for one (image, mask) pair (no augmentation).

    - Resizes both to (img_height, img_width).
    - Scales the image to float32 in [0, 1].
    - One-hot encodes the mask via `mask_to_categorical`.

    The mask holds integer class ids, so it is resized with nearest-neighbour
    interpolation. Bilinear would invent intermediate ids at class
    boundaries and corrupt the ground truth used for the Dice/IoU scores.

    Returns:
        (image, one_hot_mask): float32 tensors of shape
        (img_height, img_width, 3) and (img_height, img_width, n_classes).
    """
    input_image = tf.image.resize(input_image, (img_height, img_width))
    input_mask = tf.image.resize(
        input_mask, (img_height, img_width),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    input_image = tf.cast(input_image, tf.float32) / 255.0
    return mask_to_categorical(input_image, input_mask)

In [None]:
# Decoded (image, mask) pipelines as produced by read_tfrecord, before resizing/encoding.
train_tfrecords_dataset, test_tfrecords_dataset = (
    get_dataset_from_tfrecord(path) for path in (train_tfrecord_dir, test_tfrecord_dir)
)

In [None]:
# Preprocessing: resize images and masks to (img_height, img_width), randomly
# flip the training pairs, scale images to [0, 1] and one-hot encode masks.
# A parallel map still yields elements in input order (tf.data's default is
# deterministic), so "the first N test examples" is stable across runs.
train = train_tfrecords_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = test_tfrecords_dataset.map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
def id_to_trainid(mask):
    """Translate a label-id mask into train ids using `id2label`.

    `mask` is indexed as mask[:, :, 0], i.e. an (H, W, C) array whose channel
    0 holds the label id. Every pixel whose id is in 0..33 gets
    id2label[id].trainId written to all C channels of the result. Pixels with
    any other value stay 0.

    Returns:
        uint8 array of shape mask.shape[:3].
    """
    train_ids = np.zeros(mask.shape[:3], dtype=np.uint8)
    id_channel = mask[:, :, 0]
    for label_id in range(34):
        train_ids[id_channel == label_id] = id2label[label_id].trainId
    return train_ids


def label_to_rgb(mask):
    """Colour a train-id label map with the colours stored in `trainId2label`.

    `mask` is an array (or eager tensor) whose channel 0 holds train ids,
    indexed as mask[:, :, 0] (e.g. the (H, W, 1) argmax maps built in this
    notebook). The output takes its spatial size from `mask` itself, so
    masks of any resolution work, not just (img_height, img_width).

    Returns:
        uint8 array of shape (H, W, 3). Pixels with a train id in
        0..n_classes-1 get trainId2label[id].color. Any other value
        (e.g. an ignore label) stays black.
    """
    label_ids = np.asarray(mask)[:, :, 0]
    mask_rgb = np.zeros(label_ids.shape + (3,), dtype=np.uint8)
    for train_id in range(n_classes):
        mask_rgb[label_ids == train_id] = trainId2label[train_id].color
    return mask_rgb


def display(display_list, title=False):
    """Show images side by side in a single row and render the figure.

    Note: this name shadows IPython's built-in `display` in the notebook.

    Args:
        display_list: sequence of images accepted by
            tf.keras.preprocessing.image.array_to_img.
        title: if truthy, the first three panels are titled 'Input Image',
            'True Mask' and 'Predicted Mask'; any further panels stay
            untitled instead of raising IndexError.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask'] if title else []
    n_panels = len(display_list)
    plt.figure(figsize=(15, 7))
    for i, img in enumerate(display_list):
        plt.subplot(1, n_panels, i + 1)
        if i < len(panel_titles):
            plt.title(panel_titles[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(img))
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Walk up to three test examples and keep the last one seen (the third,
# when the test set has at least three).
for image, mask in test.take(3):
    pass
sample_image, sample_mask = image, mask

# One-hot mask -> per-pixel train id with a trailing channel axis -> RGB.
test_label_map = tf.argmax(sample_mask, axis=-1)[..., tf.newaxis]
sample_mask = label_to_rgb(test_label_map.numpy())

display([sample_image, sample_mask])

In [None]:
# Alternative architectures to compare against. Any of them must use the same
# n_classes as the one-hot masks, or the metrics below will not line up:
# model = unet(input_height=img_height, input_width=img_width, n_classes=n_classes)
# model = segnet(input_height=img_height, input_width=img_width, n_classes=n_classes)
# model = unet_xception(input_height=img_height, input_width=img_width, n_classes=n_classes)
model = deeplabv3(
    n_classes=n_classes,
    input_height=img_height,
    input_width=img_width,
    # presumably skips pretrained weights; trained weights come from model.load_weights below
    load_weights=False,
)
plot_model(model, dpi=64, show_shapes=True)

In [None]:
# Saved HDF5 weights; os.path.join keeps the path portable across operating systems.
model_name = os.path.join("saved_models", "deeplab_xception_cityscapes.h5")
# by_name=True loads weights only into layers whose names match entries in the
# file; layers without a match keep their current (freshly initialised) weights.
model.load_weights(model_name, by_name=True)

In [None]:
def arrays_from_dataset(dataset, n_samples):
    """Materialise up to `n_samples` (image, mask) pairs as float64 NumPy arrays.

    Args:
        dataset: iterable of (image, mask) pairs whose elements expose
            `.numpy()` (e.g. an eager tf.data.Dataset).
        n_samples: maximum number of examples to read.

    Returns:
        (X_samples, y_samples): arrays stacked along a new leading axis, with
        per-sample shapes taken from the first example. If the dataset holds
        fewer than `n_samples` examples, the arrays are truncated to the
        number actually read. They are not padded with all-zero rows, which
        would otherwise be scored as spurious "perfect" predictions by
        smoothed Dice/IoU.
    """
    import itertools

    n_wanted = max(n_samples, 0)
    X_samples = y_samples = None
    n_read = 0
    for image, mask in itertools.islice(dataset, n_wanted):
        image, mask = image.numpy(), mask.numpy()
        if X_samples is None:
            # Allocate once, sized from the first example.
            X_samples = np.zeros((n_wanted,) + image.shape)
            y_samples = np.zeros((n_wanted,) + mask.shape)
        X_samples[n_read] = image
        y_samples[n_read] = mask
        n_read += 1

    if X_samples is None:
        # Nothing read (empty dataset or n_samples <= 0): empty arrays with the
        # configured per-sample shape.
        return (np.zeros((0, img_height, img_width, 3)),
                np.zeros((0, img_height, img_width, n_classes)))
    return X_samples[:n_read], y_samples[:n_read]

In [None]:
n_samples = 100  # number of test examples pulled into memory for evaluation
X_test, y_test = arrays_from_dataset(test, n_samples)
print("X_test.shape: {} , y_test.shape: {}".format(X_test.shape, y_test.shape))

In [None]:
img_num = 3  # index into X_test / y_test of the example shown below
sample_image = X_test[img_num]
# One-hot ground truth -> train id per pixel, with a trailing channel axis -> RGB.
sample_label_map = np.argmax(y_test[img_num], axis=-1)[..., np.newaxis]
sample_mask = label_to_rgb(sample_label_map)

def create_mask(pred_mask):
    """Turn a batch-of-one class-score prediction into an RGB label image.

    The batch axis (0) is squeezed, the argmax over the last axis gives the
    predicted train id per pixel, and `label_to_rgb` colours the result.
    """
    class_map = tf.argmax(tf.squeeze(pred_mask, axis=0), axis=-1)
    return label_to_rgb(class_map[..., tf.newaxis].numpy())


def show_predictions():
    """Predict on the global `sample_image` and show it with its true and predicted masks."""
    batch = sample_image[tf.newaxis, ...]
    predicted_rgb = create_mask(model.predict(batch))
    display([sample_image, sample_mask, predicted_rgb])

show_predictions()

In [None]:
def mean_dice(y_true, y_pred, hard=False, smooth=1.0):
    """Per-class and mean Dice coefficient for one-hot segmentation maps.

    For each class c, Dice is computed per image as
        (2 * sum(T_c * P_c) + smooth) / (sum(T_c + P_c) + smooth)
    then averaged over the images. The mean is taken over classes.

    Args:
        y_true: one-hot ground truth, indexed as [N, H, W, C]. The number
            of classes C is taken from y_true.shape[-1].
        y_pred: per-class scores with the same layout (e.g. softmax output).
        hard: if False (default), Dice is computed on the raw scores ("soft"
            Dice). If True, y_pred is first replaced by the one-hot encoding
            of its argmax over the class axis, i.e. Dice of the predicted
            labels.
        smooth: added to numerator and denominator. A class absent from both
            truth and prediction of an image therefore scores 1.0 for that
            image.

    Returns:
        (class_dice, mean): list of C per-class scores, and their mean
        rounded to 4 decimals.

    Raises:
        ValueError: if y_true has no class channels.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_cls = y_true.shape[-1]
    if n_cls == 0:
        raise ValueError("y_true has no class channels")
    if hard:
        # (N, H, W) argmax indices -> (N, H, W, C) one-hot.
        y_pred = np.eye(n_cls, dtype=y_pred.dtype)[np.argmax(y_pred, axis=-1)]

    class_dice = []
    for c in range(n_cls):
        # Per-class slices keep peak memory at one (N, H, W) product.
        t = y_true[:, :, :, c]
        p = y_pred[:, :, :, c]
        intersection = np.sum(t * p, axis=(1, 2))
        total = np.sum(t + p, axis=(1, 2))
        class_dice.append(np.mean((2. * intersection + smooth) / (total + smooth)))
    mean_score = sum(class_dice) / n_cls
    return class_dice, round(mean_score, 4)


def mean_iou(y_true, y_pred, hard=False, smooth=1.0):
    """Per-class and mean IoU (Jaccard index) for one-hot segmentation maps.

    For each class c, IoU is computed per image as
        (I + smooth) / (U + smooth), with I = sum(T_c * P_c) and
        U = sum(T_c + P_c) - I,
    then averaged over the images. The mean is taken over classes.

    Args:
        y_true: one-hot ground truth, indexed as [N, H, W, C]. The number
            of classes C is taken from y_true.shape[-1].
        y_pred: per-class scores with the same layout (e.g. softmax output).
        hard: if False (default), IoU is computed on the raw scores ("soft"
            IoU). If True, y_pred is first replaced by the one-hot encoding
            of its argmax over the class axis, i.e. IoU of the predicted
            labels.
        smooth: added to numerator and denominator. A class absent from both
            truth and prediction of an image therefore scores 1.0 for that
            image.

    Returns:
        (class_iou, mean): list of C per-class scores, and their mean
        rounded to 4 decimals.

    Raises:
        ValueError: if y_true has no class channels.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_cls = y_true.shape[-1]
    if n_cls == 0:
        raise ValueError("y_true has no class channels")
    if hard:
        # (N, H, W) argmax indices -> (N, H, W, C) one-hot.
        y_pred = np.eye(n_cls, dtype=y_pred.dtype)[np.argmax(y_pred, axis=-1)]

    class_iou = []
    for c in range(n_cls):
        # Per-class slices keep peak memory at one (N, H, W) product.
        t = y_true[:, :, :, c]
        p = y_pred[:, :, :, c]
        intersection = np.sum(t * p, axis=(1, 2))
        union = np.sum(t + p, axis=(1, 2)) - intersection
        class_iou.append(np.mean((intersection + smooth) / (union + smooth)))
    mean_score = sum(class_iou) / n_cls
    return class_iou, round(mean_score, 4)

In [None]:
# Class-score maps for the evaluation subset (assumed (N, H, W, n_classes), matching how the metrics index them).
y_pred = model.predict(X_test[:n_samples])

In [None]:
# Keep the scores under names that do not shadow the metric functions
# mean_dice / mean_iou, so this cell can be re-run without a TypeError.
class_dice, mean_dice_score = mean_dice(y_test[0:n_samples], y_pred)
print("MEAN DICE")
print("Best: {} \nWorst: {}\nAverage: {}".format(max(class_dice), min(class_dice), mean_dice_score))
class_iou, mean_iou_score = mean_iou(y_test[0:n_samples], y_pred)
print("MEAN IOU")
print("Best: {} \nWorst: {}\nAverage: {}".format(max(class_iou), min(class_iou), mean_iou_score))

In [None]:
# Per-class scores handed to the project helper plot_dice_and_iou, with class names/colours from trainId2label.
plot_dice_and_iou(class_dice=class_dice, class_iou=class_iou, n_classes=n_classes, trainId2label=trainId2label)