In [None]:
import os
import sys
import cv2
import PIL
import glob
import random
import imageio
import sklearn
import itertools
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from skimage.transform import resize
from skimage.morphology import label

import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.losses import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import optimizers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from tensorflow.keras.mixed_precision import experimental as mixed_precision

from label_utils import get_labels

# Reset Keras' global state before building anything in this kernel.
K.clear_session()
physical_devices = tf.config.experimental.list_physical_devices("GPU")


def enable_amp():
    """Install ``mixed_float16`` as the global Keras dtype policy."""
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))


print("Tensorflow version: ", tf.__version__)
print(physical_devices)
enable_amp()

In [None]:
# Spatial size every image and mask is resized to when loaded.
img_height, img_width = 256, 512
# Depth of the one-hot masks and number of output channels of the model.
n_classes = 34

labels = get_labels()
# Lookup table: label id -> label object.
id2label = {entry.id: entry for entry in labels}

In [None]:
def get_data(path, shuffle=False, subset=None):
    """Load image / label-id mask pairs from disk into NumPy arrays.

    Masks are discovered in ``<path>/annotations`` by the suffix
    ``_gtFine_labelIds.png``; the matching image is read from
    ``<path>/images/<id>_leftImg8bit.png``. Both are resized to the
    module-level ``(img_height, img_width)``.

    Args:
        path: Dataset root folder (with or without a trailing separator).
        shuffle: If True, the sample order is permuted with a fixed seed
            (2019) so repeated runs select the same samples.
        subset: Optional maximum number of samples to load. Values larger
            than the number of available samples are clamped, so the
            returned arrays never contain zero-filled placeholder entries.

    Returns:
        Tuple ``(X, y)`` where ``X`` is float32 of shape
        ``(n, img_height, img_width, 3)`` with the pixel values produced by
        ``img_to_array`` (not rescaled here) and ``y`` is uint8 of shape
        ``(n, img_height, img_width, 1)`` holding label ids.

    Raises:
        FileNotFoundError: If the annotations folder does not exist.
    """
    mask_suffix = "_gtFine_labelIds.png"
    image_suffix = "_leftImg8bit.png"
    image_dir = os.path.join(path, "images")
    mask_dir = os.path.join(path, "annotations")

    # Sorting makes the sample order independent of the file system's
    # directory listing order, which the seeded shuffle below relies on.
    ids = sorted(
        name[:-len(mask_suffix)]
        for name in os.listdir(mask_dir)
        if name.endswith(mask_suffix)
    )

    if shuffle:
        # A local, seeded generator keeps the permutation reproducible
        # without touching the global NumPy / random state.
        order = np.random.RandomState(2019).permutation(len(ids))
        ids = [ids[int(k)] for k in order]

    if subset is not None:
        ids = ids[:max(0, min(subset, len(ids)))]

    n_images = len(ids)
    X = np.zeros((n_images, img_height, img_width, 3), dtype=np.float32)
    y = np.zeros((n_images, img_height, img_width, 1), dtype=np.uint8)
    print("Number of images: " + str(n_images))

    for n, id_ in enumerate(ids):

        print("\r Loading %s \\ %s " % (n + 1, n_images), end='')

        # load images
        img = load_img(os.path.join(image_dir, id_ + image_suffix))
        x_img = img_to_array(img)
        x_img = resize(x_img, (img_height, img_width, 3), mode='constant', preserve_range=True)

        # load masks; nearest-neighbour resizing keeps label ids intact
        # (no interpolated, non-existent class values).
        mask = load_img(os.path.join(mask_dir, id_ + mask_suffix), color_mode="grayscale")
        mask = img_to_array(mask)
        mask = cv2.resize(mask, (img_width, img_height), interpolation=cv2.INTER_NEAREST)
        mask = np.expand_dims(mask, 2)

        # save images
        X[n] = x_img
        y[n] = mask.astype(np.uint8)

    # X and y are already freshly allocated arrays; returning them directly
    # avoids duplicating the whole dataset in memory.
    return X, y
    

In [None]:
# os.path.join(..., "") appends the host OS's path separator, so the dataset
# root is no longer tied to the Windows-only "\\" separator.
X_all, y_all = get_data(path=os.path.join("Cityscapes", ""), shuffle=True, subset=1000)

In [None]:
def label_to_rgb(mask):
    """Colourise a label-id mask using the colours stored in ``id2label``.

    Args:
        mask: Array of label ids, either ``(H, W, 1)`` (only channel 0 is
            read) or ``(H, W)``.

    Returns:
        uint8 array of shape ``(H, W, 3)``. Pixels whose id is outside
        ``range(n_classes)`` stay black.
    """
    mask = np.asarray(mask)
    label_ids = mask[..., 0] if mask.ndim == 3 else mask

    # Size the output from the mask itself so the function also works for
    # masks that were not resized to the global (img_height, img_width).
    mask_rgb = np.zeros(label_ids.shape + (3,), dtype=np.uint8)

    for i in range(n_classes):
        mask_rgb[label_ids == i] = id2label[i].color

    return mask_rgb


def display(display_list, title=False):
    """Show the given images side by side in a single row.

    Args:
        display_list: Sequence of image-like arrays/tensors; each one is
            converted with ``array_to_img`` before being drawn.
        title: When truthy, the panels are captioned 'Input Image',
            'True Mask' and 'Predicted Mask' in that order.
    """
    captions = ['Input Image', 'True Mask', 'Predicted Mask'] if title else None
    n_panels = len(display_list)

    plt.figure(figsize=(15, 5))
    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        if captions:
            plt.title(captions[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    # plt.tight_layout()
    plt.show()

In [None]:
# Index of the loaded sample to preview.
img_num = 5

# Pair the chosen image with a colourised version of its label mask.
sample_image = X_all[img_num]
sample_mask = label_to_rgb(y_all[img_num])

display([sample_image, sample_mask])

In [None]:
# Hold out 20% of the loaded samples; the fixed random_state keeps the split
# identical between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X_all,
    y_all,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Samples per training / validation batch.
BATCH_SIZE = 4
# Number of elements held by the tf.data shuffle buffer.
BUFFER_SIZE = 800
# Sample counts of the two splits.
TRAIN_LENGTH, TEST_LENGTH = len(X_train), len(X_test)

In [None]:
# Rescale both splits by 1/255 as float32.
# NOTE: the results are written back to the same names, so running this cell
# a second time rescales the data again.
X_train, X_test = (
    split.astype(np.float32) / 255 for split in (X_train, X_test)
)

In [None]:
@tf.function
def mask_to_categorical(image, mask):
    """Turn a label-id mask into a float32 one-hot mask; the image passes through.

    Args:
        image: Returned unchanged.
        mask: Tensor of label ids; all size-1 dimensions are squeezed away
            before encoding.

    Returns:
        ``(image, one_hot_mask)`` where the mask gained a trailing axis of
        length ``n_classes``.
    """
    label_ids = tf.cast(tf.squeeze(mask), tf.int32)
    one_hot_mask = tf.one_hot(label_ids, depth=n_classes, dtype=tf.float32)
    return image, one_hot_mask

In [None]:
# Lazily evaluated datasets over the in-memory arrays.
train_slices = tf.data.Dataset.from_tensor_slices((X_train, y_train))
valid_slices = tf.data.Dataset.from_tensor_slices((X_test, y_test))

# Un-batched (image, one-hot mask) views, kept under their original names
# because later cells draw samples from `valid_ds`.
train_ds = train_slices.map(mask_to_categorical, num_parallel_calls=tf.data.experimental.AUTOTUNE)
valid_ds = valid_slices.map(mask_to_categorical, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Shuffle BEFORE the one-hot mapping: the shuffle buffer stores whatever its
# upstream dataset yields, and a float32 one-hot mask with n_classes=34
# channels is 136x larger than the single-channel uint8 label mask.
train_dataset = (
    train_slices
    .shuffle(BUFFER_SIZE)
    .map(mask_to_categorical, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
valid_dataset = valid_ds.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [None]:
# Walk over the first two validation samples; the last one visited is kept
# as the reference sample for the prediction previews.
for image, mask in valid_ds.take(2):
    sample_image = image
    sample_mask = mask

# One-hot -> label ids with a trailing channel axis -> RGB colours.
sample_label_ids = tf.expand_dims(tf.argmax(sample_mask, axis=-1), axis=-1)
sample_mask = label_to_rgb(sample_label_ids.numpy())
display([sample_image, sample_mask])

In [None]:
def unet_model(input_height=img_height,  input_width=img_width, n_classes = 3):
    """Build a U-Net style encoder/decoder for per-pixel classification.

    The encoder halves the resolution four times (2x2 max pooling) and the
    decoder doubles it four times (2x up-sampling), concatenating the
    matching encoder feature map at every level. ``input_height`` and
    ``input_width`` therefore have to be divisible by 16 for the skip
    connections to line up.

    Args:
        input_height: Height of the RGB input images.
        input_width: Width of the RGB input images.
        n_classes: Number of output channels (one softmax score per class).

    Returns:
        An uncompiled ``tf.keras.Model`` mapping
        ``(batch, input_height, input_width, 3)`` images to
        ``(batch, input_height, input_width, n_classes)`` class probabilities.
    """

    img_input = tf.keras.layers.Input(shape=(input_height, input_width, 3))

    # -------------------------- Encoder --------------------------

    c1 = Conv2D(64, 3, padding='same', activation="selu")(img_input)
    c1 = Conv2D(64, 3, padding='same', activation="selu")(c1)
    p1 = MaxPooling2D((2,2))(c1)

    c2 = Conv2D(128, 3, padding='same', activation="selu")(p1)
    c2 = Conv2D(128, 3, padding='same', activation="selu")(c2)
    p2 = MaxPooling2D((2,2))(c2)
    p2 = Dropout(0.1)(p2)

    c3 = Conv2D(256, 3, padding='same', activation="selu")(p2)
    c3 = Conv2D(256, 3, padding='same', activation="selu")(c3)
    p3 = MaxPooling2D((2,2))(c3)
    p3 = Dropout(0.2)(p3)

    c4 = Conv2D(512, 3, padding='same', activation="selu")(p3)
    c4 = Conv2D(512, 3, padding='same', activation="selu")(c4)
    p4 = MaxPooling2D((2,2))(c4)
    p4 = Dropout(0.2)(p4)

    # ------------------------ Bottleneck -------------------------

    c5 = Conv2D(1024, 3, padding='same', activation="selu")(p4)
    c5 = Conv2D(1024, 3, padding='same', activation="selu")(c5)
    c5 = Dropout(0.5)(c5)

    # -------------------------- Decoder --------------------------

    u6 = concatenate([UpSampling2D(2)(c5), c4])
    c6 = Conv2D(512, 3, padding='same')(u6)
    c6 = BatchNormalization()(c6)
    c6 = Activation('selu')(c6)
    c6 = Conv2D(256, 3, padding='same')(c6)
    c6 = BatchNormalization()(c6)
    c6 = Activation('selu')(c6)
    c6 = Dropout(0.2)(c6)

    u7 = concatenate([UpSampling2D(2)(c6), c3])
    c7 = Conv2D(256, 3, padding='same')(u7)
    c7 = BatchNormalization()(c7)
    c7 = Activation('selu')(c7)
    c7 = Conv2D(128, 3, padding='same')(c7)
    c7 = BatchNormalization()(c7)
    c7 = Activation('selu')(c7)
    c7 = Dropout(0.2)(c7)

    u8 = concatenate([UpSampling2D(2)(c7), c2])
    c8 = Conv2D(128, 3, padding='same')(u8)
    c8 = BatchNormalization()(c8)
    c8 = Activation('selu')(c8)
    c8 = Conv2D(64, 3, padding='same')(c8)
    c8 = BatchNormalization()(c8)
    c8 = Activation('selu')(c8)
    c8 = Dropout(0.1)(c8)

    u9 = concatenate([UpSampling2D(2)(c8), c1])
    c9 = Conv2D(64, 3, padding='same')(u9)
    c9 = BatchNormalization()(c9)
    c9 = Activation('selu')(c9)
    # The second convolution must consume the first stage's output (c9).
    # Feeding it u9 again would leave the stage above disconnected from the
    # model output, i.e. silently dropped from the network.
    c9 = Conv2D(64, 3, padding='same')(c9)
    c9 = BatchNormalization()(c9)
    c9 = Activation('selu')(c9)
    c9 = Conv2D(n_classes, 3, padding='same')(c9)

    # The softmax is pinned to float32 so the class probabilities keep full
    # precision even when a float16 mixed-precision policy is active.
    output = Activation("softmax", dtype='float32')(c9)

    return tf.keras.Model(inputs=img_input, outputs=output)

In [None]:
def dice_coef(y_true, y_pred):
    """Mean soft Dice coefficient over all classes except class 0.

    Works for channels-last tensors of any rank >= 3, e.g.
    ``(batch, pixels, classes)`` or ``(batch, height, width, classes)``.
    The class is always taken from the LAST axis; indexing a fixed axis 2
    would slice image columns instead of classes for 4-D segmentation
    tensors.

    Args:
        y_true: One-hot ground truth, same shape as ``y_pred``.
        y_pred: Predicted class probabilities.

    Returns:
        Scalar tensor: the smoothed Dice score averaged over the batch and
        over classes ``1 .. C-1`` (``C`` = size of the last axis).
    """
    smooth = 1.0
    y_true = K.cast(y_true, y_pred.dtype)

    # Drop class 0 and reduce over every axis between batch and class.
    y_true_fg = y_true[..., 1:]
    y_pred_fg = y_pred[..., 1:]
    reduce_axes = list(range(1, K.ndim(y_pred) - 1))

    intersection = K.sum(y_true_fg * y_pred_fg, axis=reduce_axes)
    all_ = K.sum(y_true_fg + y_pred_fg, axis=reduce_axes)

    # Shape (batch, C-1): one smoothed Dice score per sample and class.
    dice = (2. * intersection + smooth) / (all_ + smooth)
    return K.mean(dice)


def cce_dice_loss(y_true, y_pred):
    """Categorical cross-entropy combined with the Dice term: ``cce - dice + 1``."""
    cross_entropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    dice = dice_coef(y_true, y_pred)
    return (cross_entropy - dice) + 1


def dice_loss(y_true, y_pred):
    """Dice loss: one minus the value returned by ``dice_coef``."""
    dice = dice_coef(y_true, y_pred)
    return 1 - dice

In [None]:
# Use the shared n_classes constant (also the one-hot depth of the masks) so
# the model's output channels cannot drift from the label encoding.
model = unet_model(input_height=img_height, input_width=img_width, n_classes=n_classes)

In [None]:
def create_mask(pred_mask):
    """Convert a model prediction into an RGB label image.

    All size-1 dimensions of ``pred_mask`` are squeezed away, the highest
    scoring class is picked along the last axis, and the resulting label ids
    are colourised with ``label_to_rgb``.
    """
    class_ids = tf.argmax(tf.squeeze(pred_mask), axis=-1)
    return label_to_rgb(tf.expand_dims(class_ids, axis=-1).numpy())


def show_predictions():
    """Run the global ``model`` on ``sample_image`` and plot image, reference mask and prediction."""
    single_image_batch = sample_image[tf.newaxis, ...]
    predicted_rgb = create_mask(model.predict(single_image_batch))
    display([sample_image, sample_mask, predicted_rgb])

        
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Replace the previous cell output instead of stacking figures.
        clear_output(wait=True)
        show_predictions()
        finished_epoch = epoch + 1
        print('\nSample Prediction after epoch {}\n'.format(finished_epoch))
        
# Preview the model's prediction for the sample image as the model stands now.
show_predictions()

In [None]:
# Checkpoint file. os.path.join uses the host OS's separator, and the folder
# is created up front so saving the first checkpoint cannot fail on a missing
# directory.
model_name = os.path.join("saved_models", "unet_pets.h5")
os.makedirs(os.path.dirname(model_name), exist_ok=True)

# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
model.compile(optimizer = Adam(learning_rate=1e-4),
              loss = cce_dice_loss,
              metrics = ['accuracy', dice_coef])

callbacks = [
    # Redraw the sample prediction after each epoch.
    DisplayCallback(),
    # Stop when val_loss has not improved for 9 epochs.
    EarlyStopping(monitor='val_loss', mode='min', patience=9, verbose=1),
    # Cut the learning rate by 10x after 3 epochs without val_loss improvement.
    ReduceLROnPlateau(monitor='val_loss', mode='min', patience=3, factor=0.1, min_lr=1e-10, verbose=1),
    # Keep only the weights of the best model seen so far.
    ModelCheckpoint(model_name, monitor='val_loss', verbose=1, mode='min', save_best_only=True, save_weights_only=True)
]
# model.load_weights("big_unet_model.h5", by_name=True)

In [None]:
# Upper bound on training epochs.
EPOCHS = 50
# Whole batches per pass over each split (any remainder is dropped).
STEPS_PER_EPOCH, VALIDATION_STEPS = TRAIN_LENGTH // BATCH_SIZE, TEST_LENGTH // BATCH_SIZE

In [None]:
# The training dataset repeats indefinitely, so the epoch length is set
# explicitly through steps_per_epoch.
results = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=valid_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
    verbose=1,
)