In [None]:
import sys
import numpy as np
from time import time
from tqdm import tqdm
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
plt.style.use('ggplot')
plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.optimizers.schedules import PolynomialDecay
from tensorflow.keras.layers import *
from tensorflow.keras.activations import gelu
from tensorflow.keras.utils import plot_model
from utils.plot_utils import plot_iou_trainId, plot_iou_catId, label_to_rgb, display, create_mask
from utils.data_utils import get_labels, parse_record, get_dataset_from_tfrecord
from utils.train_utils import (weighted_cross_entropy_loss, TrainAccumilator, SETRMTrainAccumilator,
                              SETRLTrainAccumilator)
from utils.custom_callbacks import ReduceLROnPlateau
from tensorflow.keras import mixed_precision
# from data_loaders import CityscapesLoader
# from setr import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock
# from tensorflow_addons.layers import GroupNormalization
from models.seg.setr import SETR_PUP

K.clear_session()
physical_devices = tf.config.list_physical_devices("GPU")
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front. Applied to every
# visible GPU (not only the first), and a no-op on CPU-only hosts where indexing [0] would raise.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

def enable_amp():
    """Set the global Keras mixed-precision policy to "mixed_float16"."""
    mixed_precision.set_global_policy("mixed_float16")

print("Tensorflow version: ", tf.__version__)
print(physical_devices,"\n")
enable_amp()

In [None]:
# echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | 
#    sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list

# apt-get install apt-transport-https ca-certificates gnupg

# curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | 
# sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# sudo apt-get update && sudo apt-get install google-cloud-sdk

# gcloud init --console-only

# gsutil cp gs://cl_datasets_01/cityscapes/records/trainIds_train.record /mnt/vol_b/records/trainIds_train.record 
# gsutil cp gs://cl_datasets_01/cityscapes/records/trainIds_val.record /mnt/vol_b/records/trainIds_val.record 
# gsutil cp gs://cl_datasets_01/cityscapes/records/ /mnt/vol_b/weights/

In [None]:
class CityscapesLoader():
    """tf.data preprocessing for TFDS ``cityscapes/semantic_segmentation`` examples.

    Every ``load_image_*`` method takes one example dict with ``'image_left'`` and
    ``'segmentation_label'`` entries and returns ``(img, seg)``:

    * ``img`` -- float image resized (bilinear) to ``(img_height, img_width)``, divided by 255
      and standardized with the per-channel ``MEAN`` / ``STD``.
    * ``seg`` -- 2-D int32 label map whose raw ids are remapped through ``id2label`` to values
      in 0..19 (several raw ids map to 0).
    """

    # Candidate crop scales for ``random_crop``: 0.25 .. 1.0 in steps of 0.0625
    # (fraction of the input height and width).
    CROP_SCALES = (0.25, 0.3125, 0.375, 0.4375, 0.5, 0.5625, 0.625,
                   0.6875, 0.75, 0.8125, 0.875, 0.9375, 1.0)

    def __init__(self, img_height, img_width, n_classes):

        self.n_classes = n_classes
        self.img_height = img_height
        self.img_width = img_width
        # Per-channel mean / std applied after dividing pixel values by 255.
        self.MEAN = np.array([0.485, 0.456, 0.406])
        self.STD = np.array([0.229, 0.224, 0.225])
        # Lookup table indexed by raw label id (35 entries, ids 0..34) -> class index 0..19.
        self.id2label = tf.constant([0,  0,  0,  0,  0,  0,  0,  1,  2,  0,  0,  3,
                                     4,  5,  0,  0,  0,  6,  0,  7,  8,  9,  10, 11,
                                    12, 13, 14, 15, 16,  0,  0, 17, 18, 19,  0], tf.int32)


    @tf.function
    def random_crop(self, img, seg):
        """Crop ``img`` and ``seg`` with the same random window.

        A scale is drawn uniformly from ``CROP_SCALES`` (0.25 to 1.0). The crop height and width
        are the input height and width multiplied by that scale. Image and mask are concatenated
        along the channel axis before cropping, so both receive the identical window.

        Expects a 3-channel ``img`` and a 1-channel ``seg`` (the crop size uses 4 channels).
        Returns the cropped image (channels 0..2) and the cropped mask as a 2-D tensor.
        """
        scales = tf.constant(self.CROP_SCALES, dtype=tf.float32)
        # maxval is derived from the table so adding/removing scales cannot go out of range.
        idx = tf.random.uniform(shape=[], minval=0, maxval=len(self.CROP_SCALES), dtype=tf.int32)
        scale = scales[idx]

        shape = tf.cast(tf.shape(img), tf.float32)
        h = tf.cast(shape[0] * scale, tf.int32)
        w = tf.cast(shape[1] * scale, tf.int32)
        combined_tensor = tf.concat([img, seg], axis=2)
        combined_tensor = tf.image.random_crop(combined_tensor, size=[h, w, 4])
        return combined_tensor[:, :, 0:3], combined_tensor[:, :, -1]


    @tf.function
    def normalize(self, img):
        """Divide by 255, then subtract ``MEAN`` and divide by ``STD`` per channel."""
        img = img / 255.0
        img = img - self.MEAN
        img = img / self.STD
        return img


    @tf.function
    def load_image_train(self, datapoint):
        """Training transform: random flip / colour jitter / scale-crop, then resize and normalize."""
        img = datapoint['image_left']
        seg = datapoint['segmentation_label']

        # Each augmentation fires independently with probability 0.5. The flip is applied to image
        # and mask together; the colour jitter touches the image only.
        if tf.random.uniform(()) > 0.5:
            img = tf.image.flip_left_right(img)
            seg = tf.image.flip_left_right(seg)
        if tf.random.uniform(()) > 0.5:
            img = tf.image.random_brightness(img, 0.1)
        if tf.random.uniform(()) > 0.5:
            img = tf.image.random_saturation(img, 0.7, 1.3)
        if tf.random.uniform(()) > 0.5:
            img = tf.image.random_contrast(img, 0.7, 1.3)
        if tf.random.uniform(()) > 0.5:
            img = tf.image.random_hue(img, 0.05)

        img, seg = self.random_crop(img, seg)
        seg = tf.expand_dims(seg, axis=-1)  # tf.image.resize needs a channel axis

        img = tf.image.resize(img, (self.img_height, self.img_width), method='bilinear')
        # Nearest-neighbour keeps label ids discrete (no blending between classes).
        seg = tf.image.resize(seg, (self.img_height, self.img_width), method='nearest')
        img = self.normalize(tf.cast(img, tf.float32))

        seg = tf.squeeze(seg, axis=-1)
        seg = tf.gather(self.id2label, tf.cast(seg, tf.int32))

        return img, seg


    @tf.function
    def load_image_test(self, datapoint):
        """Validation transform: deterministic resize (bilinear image, nearest mask) and normalize."""
        img = datapoint['image_left']
        seg = datapoint['segmentation_label']

        img = tf.image.resize(img, (self.img_height, self.img_width), method='bilinear')
        seg = tf.image.resize(seg, (self.img_height, self.img_width), method='nearest')
        img = self.normalize(tf.cast(img, tf.float32))

        seg = tf.squeeze(seg, axis=-1)
        seg = tf.gather(self.id2label, tf.cast(seg, tf.int32))

        return img, seg


    @tf.function
    def load_image_eval(self, datapoint):
        """Evaluation transform: image resized to model resolution; mask kept at its own resolution."""
        img = datapoint['image_left']
        seg = datapoint['segmentation_label']
        img = tf.image.resize(img, (self.img_height, self.img_width), method='bilinear')
        img = self.normalize(tf.cast(img, tf.float32))
        # Drop the trailing channel axis of the (H, W, 1) mask; no resize, only id remapping.
        seg = tf.squeeze(seg, axis=-1)
        seg = tf.gather(self.id2label, tf.cast(seg, tf.int32))
        return img, seg

In [None]:
# Model / data resolution. Labels produced by CityscapesLoader.id2label take values 0..19.
n_classes, img_size, patch_size = 20, 768, 16

BATCH_SIZE = 2      # examples per forward/backward pass
ACCUM_STEPS = 4     # passed to the trainer as accum_steps (gradient accumulation)
ADJ_BATCH_SIZE = ACCUM_STEPS * BATCH_SIZE   # effective examples per accumulated update
BUFFER_SIZE = 256   # shuffle buffer size, in examples

# Lookup tables from label metadata; when several labels share a key, the last one wins.
labels = get_labels()
trainid2label = dict((label.trainId, label) for label in labels)
catid2label = dict((label.categoryId, label) for label in labels)

# Square inputs: height and width both equal img_size.
pipeline = CityscapesLoader(img_height=img_size, img_width=img_size, n_classes=n_classes)

In [None]:
# Load the TFDS Cityscapes semantic-segmentation config from a local data_dir, together with its
# DatasetInfo (split sizes are read from `info` below). shuffle_files=True shuffles the order of
# the underlying record files.
dataset, info = tfds.load(
    'cityscapes/semantic_segmentation',
    with_info=True,
    shuffle_files=True,
    data_dir='/workspace/tensorflow_datasets/',
)

In [None]:
# Split sizes drive the step bookkeeping below.
TRAIN_LENGTH, VALID_LENGTH = (info.splits[split].num_examples for split in ('train', 'validation'))

EPOCHS = 120
# Presumably one optimizer update consumes ADJ_BATCH_SIZE (= BATCH_SIZE * ACCUM_STEPS) examples.
STEPS_PER_EPOCH = TRAIN_LENGTH // ADJ_BATCH_SIZE
# Validation batches use the plain BATCH_SIZE (no accumulation).
VALIDATION_STEPS = VALID_LENGTH // BATCH_SIZE
# Total updates over the whole run = length of the LR decay. (A variant also divided by ACCUM_STEPS.)
DECAY_STEPS = EPOCHS * STEPS_PER_EPOCH
print("Decay steps: {}".format(DECAY_STEPS))

In [None]:
# Map raw TFDS examples to (image, label) pairs:
#   train   -- random flip / colour jitter / scale-crop (CityscapesLoader.load_image_train)
#   valid   -- deterministic resize of image and mask to the model resolution
#   eval_ds -- image resized, mask left at its original resolution (unbatched)
train = dataset['train'].map(pipeline.load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
valid = dataset['validation'].map(pipeline.load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = dataset['validation'].map(pipeline.load_image_eval, num_parallel_calls=tf.data.AUTOTUNE)

# drop_remainder=True keeps every batch at exactly BATCH_SIZE. Both pipelines prefetch so input
# preparation overlaps with the model step (tf.data.AUTOTUNE is the non-deprecated alias).
train_dataset = (
    train.shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
valid_dataset = valid.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

In [None]:
# Preview one example from the augmented training pipeline; the mask is drawn in colour.
for image, mask in train.take(1):
    pass
sample_image, sample_mask = image, mask
sample_mask = label_to_rgb(sample_mask[..., tf.newaxis].numpy())
display([sample_image, sample_mask])

In [None]:
# Draw again: the training transform is random, so this may show a different augmentation.
for image, mask in train.take(1):
    sample_image = image
    sample_mask = mask
sample_mask = label_to_rgb(sample_mask.numpy()[..., np.newaxis])
display([sample_image, sample_mask])

In [None]:
# Another random draw from the training pipeline (augmentations are re-sampled each time).
for image, mask in train.take(1):
    pass
sample_image, sample_mask = image, mask
sample_mask = label_to_rgb(sample_mask.numpy()[..., np.newaxis])
display([sample_image, sample_mask])

In [None]:
# One more augmented training sample, image next to its colourized mask.
for image, mask in train.take(1):
    sample_image = image
    sample_mask = mask
sample_mask = label_to_rgb(sample_mask[..., tf.newaxis].numpy())
display([sample_image, sample_mask])

In [None]:
# Preview a validation example: the loop walks two elements and keeps the last one.
# Validation preprocessing applies no random augmentation.
for image, mask in valid.take(2):
    pass
sample_image, sample_mask = image, mask
sample_mask = label_to_rgb(sample_mask[..., tf.newaxis].numpy())
display([sample_image, sample_mask])

In [None]:
# Second element of the validation pipeline, shown with its colourized mask.
for image, mask in valid.take(2):
    sample_image = image
    sample_mask = mask
sample_mask = label_to_rgb(sample_mask.numpy()[..., np.newaxis])
display([sample_image, sample_mask])

In [None]:
# Validation preview (last of the first two elements); no random augmentation is applied here.
for image, mask in valid.take(2):
    pass
sample_image, sample_mask = image, mask
sample_mask = label_to_rgb(sample_mask.numpy()[..., np.newaxis])
display([sample_image, sample_mask])

In [None]:
# Validation preview; the final sample_image / sample_mask are what show_predictions() uses.
for image, mask in valid.take(2):
    sample_image = image
    sample_mask = mask
sample_mask = label_to_rgb(sample_mask[..., tf.newaxis].numpy())
display([sample_image, sample_mask])

In [None]:
# Backbone hyper-parameters for two SETR variants plus the path of the pretrained ViT weights
# loaded (by layer name) into the model.
CONFIG_B = dict(
    dropout=0.1,
    mlp_dim=3072,
    num_heads=12,
    num_layers=12,
    hidden_size=768,
    aux_layers=None,                # no auxiliary heads
    name="SETR-B_16",
    pretrained="weights/vit_b16_imagenet21k_imagenet2012.h5",
)

CONFIG_M = dict(
    dropout=0.1,
    mlp_dim=4096,
    num_heads=16,
    num_layers=18,
    hidden_size=1024,
    aux_layers=[9, 14],             # transformer layers that feed auxiliary heads
    name="SETR-M_16",
    pretrained="weights/vit_l16_imagenet21k_imagenet2012.h5",
)

# Active configuration used by the model-building cells.
config = CONFIG_B

In [None]:
# Reset Keras' global state before the model is (re)built in the cells below, so re-running
# these cells starts from a clean session.
K.clear_session()

In [None]:
def get_model():
    """Build, initialize and compile a SETR_PUP model for the active ``config``.

    Reads the notebook globals ``img_size``, ``patch_size``, ``n_classes``, ``DECAY_STEPS`` and
    ``config``. Every architecture hyper-parameter, including dropout, comes from ``config``.
    Pretrained weights at ``config["pretrained"]`` are loaded by layer name.

    The learning rate decays polynomially (power 0.9) from 5e-3 to 5e-6 over ``DECAY_STEPS``
    optimizer steps.

    Returns:
        The compiled Keras model: SGD (momentum 0.9) with ``weighted_cross_entropy_loss``, and
        ``accuracy`` and ``iou_coef`` as metrics.
    """

    # with strategy.scope():

    learning_rate_fn = PolynomialDecay(
        initial_learning_rate = 5e-3,
        decay_steps = DECAY_STEPS,
        end_learning_rate = 5e-6,
        power = 0.9
    )

    model = SETR_PUP(
        image_size = img_size,
        patch_size = patch_size,
        num_classes = n_classes,
        num_layers = config["num_layers"],
        hidden_size = config["hidden_size"],
        aux_layers = config["aux_layers"],
        num_heads = config["num_heads"],
        name = config["name"],
        mlp_dim = config["mlp_dim"],
        # Previously a hard-coded 0.1 that silently ignored config["dropout"].
        dropout = config["dropout"],
    )
    # by_name=True: only layers whose names match the checkpoint receive weights.
    model.load_weights(config["pretrained"], by_name=True)

    model.compile(
        optimizer = SGD(learning_rate=learning_rate_fn, momentum=0.9),
        loss = weighted_cross_entropy_loss,
        metrics = ['accuracy', iou_coef]
        )

    return model


def show_predictions():
    """Predict on ``sample_image`` and display the image, its mask and the predicted mask.

    Relies on the notebook globals ``model``, ``sample_image`` and ``sample_mask`` (set by the
    sample-preview cells).
    """
    pred_mask = model.predict(sample_image[tf.newaxis, ...])
    # Keras returns a list when a model has several outputs (presumably the auxiliary-head
    # variants); the first entry is taken as the main head. Checking the returned type instead
    # of hard-coding model names keeps this working for any multi-output configuration.
    if isinstance(pred_mask, (list, tuple)):
        pred_mask = pred_mask[0]
    display([sample_image, sample_mask, create_mask(pred_mask)])

        
def iou_coef(y_true, y_pred):
    """Mean per-class overlap score, used as the ``iou_coef`` Keras metric.

    NOTE: despite the name, the formula is the smoothed Dice / F1 coefficient
    ``(2*|A∩B| + 1) / (|A| + |B| + 1)``, not the Jaccard IoU ``|A∩B| / |A∪B|``. It is kept as
    is so reported values stay comparable with earlier runs.

    Args:
        y_true: integer label map of rank 3, presumably (batch, H, W), with values in
            [0, n_classes). Uses the global ``n_classes``.
        y_pred: per-class scores of rank 4 (batch, H, W, n_classes), argmax-ed over the last axis.

    Returns:
        Scalar tensor: the score is averaged over the batch for each class, then over classes
        1..n_classes-1 (class 0 is excluded). Because of the +1 smoothing, a class absent from
        both the ground truth and the prediction in an image scores 1 for that image.
    """
    y_true = tf.one_hot(tf.cast(y_true, tf.int32), depth=n_classes)
    y_pred = tf.one_hot(tf.cast(tf.math.argmax(y_pred, axis=-1), tf.int32), depth=n_classes)
    smooth = 1.0

    # Per-image, per-class pixel counts summed over the spatial axes -> (batch, n_classes).
    intersection = tf.math.reduce_sum(y_true * y_pred, axis=(1, 2))
    total = tf.math.reduce_sum(y_true + y_pred, axis=(1, 2))
    # total + smooth >= 1, so a plain division can never divide by zero.
    dice = (2.0 * intersection + smooth) / (total + smooth)

    # Batch mean per class, then the mean over classes 1..n_classes-1.
    per_class = tf.math.reduce_mean(dice, axis=0)
    return tf.math.reduce_mean(per_class[1:])

In [None]:
# Learning rate: polynomial decay (power 0.9) from 1e-2 to 1e-5 over DECAY_STEPS optimizer steps.
learning_rate_fn = PolynomialDecay(
    initial_learning_rate = 1e-2,
    decay_steps = DECAY_STEPS,
    end_learning_rate = 1e-5,
    power = 0.9,
)

opt = SGD(learning_rate=learning_rate_fn, momentum=0.9)

# Every architecture hyper-parameter, dropout included, follows the active `config`.
model = SETR_PUP(
    image_size = img_size,
    patch_size = patch_size,
    num_classes = n_classes,
    num_layers = config["num_layers"],
    hidden_size = config["hidden_size"],
    aux_layers = config["aux_layers"],
    num_heads = config["num_heads"],
    name = config["name"],
    mlp_dim = config["mlp_dim"],
    dropout = config["dropout"],
)
# by_name=True: only layers whose names match the pretrained checkpoint receive weights.
model.load_weights(config["pretrained"], by_name=True)

# Alternative trainer (commented out). NOTE(review): presumably meant for the multi-output
# SETR-M configuration -- confirm in utils.train_utils.
# trainer = SETRMTrainAccumilator(
#     model = model,
#     optimizer = mixed_precision.LossScaleOptimizer(opt),
#     loss_fn = weighted_cross_entropy_loss,
#     n_classes = n_classes,
#     reduce_lr_on_plateau = None,
#     accum_steps = ACCUM_STEPS,
# )

# The optimizer is wrapped in a LossScaleOptimizer because the global "mixed_float16" policy is
# enabled at the top of the notebook.
trainer = TrainAccumilator(
    model = model,
    optimizer = mixed_precision.LossScaleOptimizer(opt),
    loss_fn = weighted_cross_entropy_loss,
    n_classes = n_classes,
    reduce_lr_on_plateau = None,
    accum_steps = ACCUM_STEPS,
)

In [None]:
# model = get_model()
# model, trainer = get_ga_model()

In [None]:
# plot_model(model, show_shapes=True, dpi=64, expand_nested=True)

In [None]:
# Weights file named after the model, e.g. weights/SETR-B_16.h5.
MODEL_PATH = f"weights/{model.name}.h5"

In [None]:
# model.load_weights(MODEL_PATH, by_name=True)

In [None]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

In [None]:
# Sanity check before training: prediction of the freshly initialized model on the sample image.
show_predictions()

In [None]:
# Train with the gradient-accumulating custom trainer.
# NOTE(review): weights_path is presumably where the trainer saves weights -- confirm.
results = trainer.fit(
    train_dataset=train_dataset,
    test_dataset=valid_dataset,
    epochs=EPOCHS,
    weights_path=MODEL_PATH,
)

In [None]:
# callbacks = [
#     # EarlyStopping(monitor='val_iou_coef', mode='max', patience=40, verbose=2),
#     # ReduceLROnPlateau(monitor='val_iou_coef', mode='max', patience=10, factor=0.5, min_lr=1e-5, verbose=2),
#     tf.keras.callbacks.ModelCheckpoint(MODEL_PATH, monitor='val_iou_coef', mode='max', 
#                     verbose=2, save_best_only=True, save_weights_only=True)    
# ]

In [None]:
# results = model.fit(
#     train_dataset,
#     steps_per_epoch=STEPS_PER_EPOCH,
#     validation_steps=VALIDATION_STEPS,
#     epochs=EPOCHS,
#     validation_data=valid_dataset,
#     callbacks=callbacks,
#     verbose=1
# )

In [None]:
# `trainer.fit` above already returned the training history in `results`. `model.history` is set
# by Keras' own `model.fit` (the commented-out path). Take it only when it is actually set, so the
# custom trainer's results are not overwritten with an empty value.
# NOTE(review): confirm whether TrainAccumilator ever populates model.history.
if getattr(model, "history", None) is not None:
    results = model.history

In [None]:
def plot_history(results, model, fine=True):
    """Plot training/validation loss, accuracy and ``iou_coef`` side by side, then save and show.

    Args:
        results: object whose ``history`` dict holds the per-epoch lists ``loss``, ``val_loss``,
            ``accuracy``, ``val_accuracy``, ``iou_coef`` and ``val_iou_coef`` (e.g. a Keras
            ``History``).
        model: the trained model; only ``model.name`` is used (titles and file name).
        fine: selects the output file name. True (default) saves
            ``plots/<name>_learning_curves.png``; False saves
            ``plots/<name>_learning_curves_coarse.png``. Previously this was read from an
            undefined global, which raised NameError. The default assumes the TFDS
            ``cityscapes/semantic_segmentation`` config (presumably fine annotations).
    """
    import os

    history = results.history
    # (history key, training-curve label, validation-curve label, title prefix)
    panels = [
        ('loss', 'Training loss', 'Validation loss', "Loss: "),
        ('accuracy', 'Training accuracy', 'Validation accuracy', 'Accuracy: '),
        ('iou_coef', 'IoU coefficient', 'Validation IoU coefficient', 'IoU Coefficient: '),
    ]

    plt.figure(figsize=(15,7))
    for position, (key, train_label, val_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(history[key], 'r', label=train_label)
        plt.plot(history['val_' + key], 'b', label=val_label)
        plt.title(title + model.name, fontsize=16)
        plt.xlabel('Epoch', fontsize=16)
        plt.legend(prop={'size': 14})

    suffix = "_learning_curves.png" if fine else "_learning_curves_coarse.png"
    # savefig does not create missing directories.
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/" + model.name + suffix)
    plt.show()

In [None]:
# Plot and save the loss / accuracy / iou_coef learning curves for the run above.
plot_history(results, model)