In [None]:
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
# import tensorflow_addons as tfa
import tensorflow_datasets as tfds

import tensorflow as tf 
from tensorflow.keras import backend as K
from tensorflow.keras import mixed_precision
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

from data_loaders import CityscapesLoader
from utils.plot_utils import plot_iou_trainId, plot_iou_catId
from utils.data_utils import get_labels
from models.seg.setr_pup import setr_pup

%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn-white')
plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)

# Reset any Keras state left over from earlier runs in this kernel.
K.clear_session()
physical_devices = tf.config.experimental.list_physical_devices("GPU")
# Turn on on-demand GPU memory allocation for every visible GPU.
# Iterating (instead of indexing [0]) keeps the notebook usable on CPU-only
# machines, where the list is empty, and on multi-GPU machines, where
# TensorFlow requires the memory-growth setting to be the same on all GPUs.
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)

def enable_amp():
    """Set the global Keras dtype policy to ``mixed_float16``."""
    policy_name = "mixed_float16"
    mixed_precision.set_global_policy(policy_name)


# Report the runtime environment, then switch mixed precision on.
print("Tensorflow version: ", tf.__version__)
print(physical_devices, "\n")
enable_amp()

### Data Prep

In [None]:
# Segmentation target: number of classes and the network input resolution.
n_classes = 20
img_height, img_width = 768, 768

# Input-pipeline sizes. ACCUM_STEPS is forwarded to the model builder and
# divides the decay-step count further down (presumably the number of
# gradient-accumulation steps -- confirm against models.seg.setr_pup).
BATCH_SIZE, ACCUM_STEPS, BUFFER_SIZE = 1, 8, 512

# Look-up tables from train id / category id to the label record.
labels = get_labels()
trainid2label = dict((lbl.trainId, lbl) for lbl in labels)
catid2label = dict((lbl.categoryId, lbl) for lbl in labels)

In [None]:
# Preprocessing helper configured for the target resolution and class count.
pipeline = CityscapesLoader(img_height=img_height, img_width=img_width, n_classes=n_classes)

# Load the Cityscapes semantic-segmentation builder together with its metadata;
# input files are shuffled at read time.
dataset, info = tfds.load(
    'cityscapes/semantic_segmentation',
    data_dir='/workspace/tensorflow_datasets',
    shuffle_files=True,
    with_info=True,
)

In [None]:
# Split sizes as reported by the dataset metadata.
TRAIN_LENGTH, VALID_LENGTH = (
    info.splits[split_name].num_examples for split_name in ('train', 'validation')
)

# Per-example preprocessing, parallelised by tf.data.
train = dataset['train'].map(pipeline.load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
valid = dataset['validation'].map(pipeline.load_image_test, num_parallel_calls=tf.data.AUTOTUNE)

# Training stream: shuffle -> batch -> repeat indefinitely -> prefetch.
train_dataset = (
    train
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Validation stream: batched only, consumed in a single pass.
valid_dataset = valid.batch(BATCH_SIZE)

In [None]:
def label_to_rgb(mask):
    """Colourise a label mask with the colours stored in ``trainid2label``.

    Args:
        mask: NumPy array of train ids. Its first two axes give the output
            height and width; the callers in this notebook pass (H, W, 1).

    Returns:
        ``uint8`` array of shape (H, W, 3). Pixels whose id is not a key of
        ``trainid2label`` stay black (zero).
    """
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for train_id, label in trainid2label.items():
        # Boolean selector of the pixels carrying this id; size-1 axes are
        # dropped so that it can index the (H, W) plane of the output.
        selector = (mask == train_id).squeeze()
        rgb[selector] = label.color
    return rgb


def display(display_list, title=True):
    """Show the given images side by side in a single row.

    Args:
        display_list: Sequence of image arrays/tensors accepted by
            ``tf.keras.preprocessing.image.array_to_img``.
        title: When truthy, the panels are captioned 'Input Image',
            'True Mask' and 'Predicted Mask' in that order, so at most three
            images can be shown with titles.
    """
    captions = ['Input Image', 'True Mask', 'Predicted Mask'] if title else None
    n_panels = len(display_list)
    plt.figure(figsize=(15, 5), dpi=150)
    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        if captions:
            plt.title(captions[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Walk the first three preprocessed training examples; the loop targets keep
# the last one, which becomes the running sample for the preview plots.
for sample_image, sample_mask in train.take(3):
    pass

# Add a trailing channel axis, convert to NumPy and colourise the label mask.
sample_mask = label_to_rgb(sample_mask[..., tf.newaxis].numpy())
display([sample_image, sample_mask])

###  Hyperparameter configuration

In [None]:
# Patch grid: the (square) input is cut into patch_size x patch_size tiles.
patch_size = 16
image_size = img_width
patches_per_side = image_size // patch_size
num_patches = patches_per_side * patches_per_side

# Transformer encoder size.
projection_dim, num_heads, transformer_layers = 1024, 16, 24

In [None]:
# Build the SETR-PUP network from the hyperparameters defined above.
model = setr_pup(
    # input / output geometry
    img_size=image_size,
    patch_size=patch_size,
    num_patches=num_patches,
    n_classes=n_classes,
    # transformer encoder (the MLP width is set equal to the embedding width)
    dim=projection_dim,
    mlp_dim=projection_dim,
    heads=num_heads,
    transformer_layers=transformer_layers,
    # regularisation
    dropout_rate=0.1,
    attn_dropout_rate=0.1,
    # forwarded unchanged to the model builder
    ACCUM_STEPS=ACCUM_STEPS,
)

In [None]:
# Smoke test: push one random, image-sized batch through the model.
inp_test = tf.random.normal((1, img_height, img_width, 3))
out_test = model(inp_test)

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

In [None]:
# The weights are written as an HDF5 file under weights/. Create the folder up
# front (no-op when it already exists) so that saving does not fail on a fresh
# checkout where the directory is missing.
tf.io.gfile.makedirs("weights")
MODEL_PATH = "weights/"+model.name+".h5"

In [None]:
# Persist the model's current weights to MODEL_PATH. In the notebook's cell
# order this runs before model.fit; the ModelCheckpoint callback defined
# further down writes to the same path.
model.save_weights(MODEL_PATH)

In [None]:
def create_mask(pred_mask):
    """Convert raw network output into an RGB class map.

    All size-1 axes are squeezed away, the arg-max over the last axis picks
    the winning class per pixel, and the resulting ids are colourised with
    ``label_to_rgb``.

    Args:
        pred_mask: Model output for a single image (assumed to be
            (1, H, W, n_classes) or (H, W, n_classes) -- confirm with the
            model definition).

    Returns:
        The ``uint8`` RGB array produced by ``label_to_rgb``.
    """
    class_ids = tf.argmax(tf.squeeze(pred_mask), axis=-1)
    return label_to_rgb(class_ids[..., tf.newaxis].numpy())


def show_predictions():
    """Predict on the global ``sample_image`` and plot input, truth and prediction."""
    batch = sample_image[tf.newaxis, ...]
    prediction = model.predict(batch)
    # For models with "U2Net" in their name only the first element of the
    # prediction is visualised.
    prediction = prediction[0] if "U2Net" in model.name else prediction
    display([sample_image, sample_mask, create_mask(prediction)])

        
def iou_coef(y_true, y_pred):
    """Smoothed mean intersection-over-union across classes ``1..n_classes-1``.

    Both inputs are turned into one-hot maps (predictions via arg-max over the
    last axis). For every class except class 0 the per-image IoU

        (|A ∩ B| + smooth) / (|A ∪ B| + smooth)

    is computed over axes 1 and 2, averaged over the batch, and the class
    scores are then averaged.

    Args:
        y_true: Integer class ids per pixel (assumed shape (batch, H, W) --
            confirm against the data loader).
        y_pred: Per-class scores with the class axis last.

    Returns:
        Scalar tensor with the mean IoU over the evaluated classes.
    """
    y_true = tf.one_hot(tf.cast(y_true, tf.int32), depth=n_classes)
    y_pred = tf.math.argmax(y_pred, axis=-1)
    y_pred = tf.one_hot(tf.cast(y_pred, tf.int32), depth=n_classes)
    smooth = 1.0
    iou_total = 0.0
    # Class 0 is deliberately left out of the average.
    for i in range(1, n_classes):
        true_i = y_true[:, :, :, i]
        pred_i = y_pred[:, :, :, i]
        intersection = tf.math.reduce_sum(true_i * pred_i, axis=(1, 2))
        # |A ∪ B| = |A| + |B| - |A ∩ B|. Without the subtraction (and with a
        # factor 2 on the intersection) the score is the Dice coefficient,
        # not IoU.
        union = tf.math.reduce_sum(true_i + pred_i, axis=(1, 2)) - intersection
        iou = tf.math.reduce_mean(
            tf.math.divide_no_nan(intersection + smooth, union + smooth), axis=0
        )
        iou_total += iou
    return iou_total / (n_classes - 1)

In [None]:
# Visualise the model's prediction on the sample image; this cell runs before
# the training cell below.
show_predictions()

In [None]:
# Training length.
EPOCHS = 200
# One pass over each split, measured in batches.
STEPS_PER_EPOCH, VALIDATION_STEPS = TRAIN_LENGTH // BATCH_SIZE, VALID_LENGTH // BATCH_SIZE
# Total number of training batches over the whole run, divided by ACCUM_STEPS.
DECAY_STEPS = (EPOCHS * STEPS_PER_EPOCH) // ACCUM_STEPS
print("Decay steps: {}".format(DECAY_STEPS))

In [None]:
def weighted_cross_entropy_loss(y_true_labels, y_pred_logits):
    """Class-weighted sparse softmax cross-entropy, averaged over all pixels.

    Each pixel's cross-entropy is multiplied by the weight of its true class.
    Class 0 has weight 0.0, so pixels labelled 0 contribute nothing to the
    numerator (they still count in the mean's denominator).

    Args:
        y_true_labels: Per-pixel class ids in ``[0, 20)``; cast to ``int32``.
        y_pred_logits: Unnormalised class scores with the class axis last;
            cast to ``float32``.

    Returns:
        Scalar ``float32`` loss.
    """
    c_weights = [0.0,    2.602,  6.707,  3.522,  9.877, 9.685,  9.398,  10.288, 9.969,  4.336,
                 9.454,  7.617,  9.405,  10.359, 6.373, 10.231, 10.262, 10.264, 10.394, 10.094]
    # Integer labels are required by both the sparse loss and tf.gather.
    labels = tf.cast(y_true_labels, tf.int32)
    # The notebook enables the mixed_float16 policy, so the logits may arrive
    # as float16. Computing in float32 keeps the per-pixel losses and the
    # float32 weights on the same dtype and avoids a half-precision reduction
    # over a full-resolution image.
    logits = tf.cast(y_pred_logits, tf.float32)
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    weights = tf.gather(tf.constant(c_weights, dtype=tf.float32), labels)
    return tf.math.reduce_mean(losses * weights)

In [None]:
# Adam with a fixed initial rate of 1e-3. The previously tried alternative is
# kept for reference:
#   SGD(learning_rate=learning_rate_fn, momentum=0.9, decay=0.0005)
model.compile(
    loss=weighted_cross_entropy_loss,
    metrics=['accuracy', iou_coef],
    optimizer=Adam(learning_rate=1e-3),
)

In [None]:
# Both callbacks watch the validation value of the `iou_coef` metric
# (higher is better).
lr_on_plateau = ReduceLROnPlateau(
    monitor='val_iou_coef', mode='max',
    factor=0.1, patience=10, min_lr=1e-5, verbose=2,
)
best_checkpoint = ModelCheckpoint(
    MODEL_PATH, monitor='val_iou_coef', mode='max',
    save_best_only=True, save_weights_only=True, verbose=2,
)
# Disabled: EarlyStopping(monitor='val_iou_coef', mode='max', patience=40, verbose=2)
callbacks = [lr_on_plateau, best_checkpoint]

In [None]:
# Train the model; the History object returned by fit() is kept in `results`.
results = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=valid_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# `plot_history` is neither imported nor defined anywhere in this notebook, so
# calling it raises NameError on a fresh kernel. Draw the learning curves
# directly from the Keras History object instead: one panel per logged
# quantity, with the matching `val_` series overlaid when it exists.
history = results.history
curve_names = [key for key in history if not key.startswith('val_')]
fig, axes = plt.subplots(1, len(curve_names), figsize=(6 * len(curve_names), 5), squeeze=False)
for ax, curve_name in zip(axes[0], curve_names):
    ax.plot(history[curve_name], label='train')
    if 'val_' + curve_name in history:
        ax.plot(history['val_' + curve_name], label='validation')
    ax.set_title(model.name + ': ' + curve_name)
    ax.set_xlabel('epoch')
    ax.set_ylabel(curve_name)
    ax.legend()
fig.tight_layout()
plt.show()