In [None]:
"""script for enet training"""

In [None]:
# !ipython nbconvert --to=python enet.ipynb

In [None]:
import glob
import os
import numpy as np
import functools
import random
import tensorflow as tf
# import tensorflow_addons as tfa
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras import losses
import tensorflow.keras.backend as K
from tensorflow.python.keras.callbacks import LambdaCallback
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, Input, Conv2DTranspose, \
    Add, Activation, BatchNormalization, Concatenate
from tensorflow.python.keras.models import Model

import segmentation_models as sm

sm.set_framework('tf.keras')

from cb.tbi_cb import TensorBoardImage
from cb.snapshot_cb_builder import SnapshotCallbackBuilder
from cb.sgdr_lr_scheduler import SGDRScheduler


In [None]:
# GPU check -- list the GPU devices TensorFlow can see. An empty list means the
# run will fall back to CPU.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
print(gpus)

In [None]:
# Data / output locations. The defaults are the original training-machine paths;
# set ROAD_SAT_DATA_ROOT and/or ROAD_SAT_MODEL_DIR to run elsewhere without editing
# the notebook.
DATA_ROOT = os.environ.get("ROAD_SAT_DATA_ROOT",
                           "/home/ubuntu/valdata/road_boundary_train/road_satellite")

img_dir = os.path.join(DATA_ROOT, "dataset", "training", "images")
label_dir = os.path.join(DATA_ROOT, "dataset", "training", "masks")

MODEL_DIR = os.environ.get("ROAD_SAT_MODEL_DIR", os.path.join(DATA_ROOT, "enet"))

# (height, width, channels) of the single-channel mask and the RGB input.
OUTPUT_SHAPE = (512, 512, 1)
INPUT_SHAPE = (512, 512, 3)

In [None]:
# Dataset

# glob() order is filesystem-dependent; sort it so that train_test_split's
# random_state yields the same train/val split on every machine.
train_files = sorted(glob.glob(img_dir + "/*.png"))
assert train_files, f"No .png images found in {img_dir}"

# A mask has the same file name as its image and lives in label_dir. Joining the
# basename (instead of str.replace("images", "masks") on the whole path) cannot
# corrupt other path components or file names that happen to contain "images".
train_label_files = [os.path.join(label_dir, os.path.basename(x)) for x in train_files]
missing_labels = [p for p in train_label_files if not os.path.isfile(p)]
assert not missing_labels, f"{len(missing_labels)} mask file(s) missing, e.g. {missing_labels[:3]}"

x_train_filenames, x_val_filenames, y_train_filenames, y_val_filenames = train_test_split(train_files,
                                                                                          train_label_files,
                                                                                          test_size=0.15,
                                                                                          random_state=42)

num_train_examples = len(x_train_filenames)
num_val_examples = len(x_val_filenames)

print("Number of training examples: {}".format(num_train_examples))
print("Number of validation examples: {}".format(num_val_examples))

img_shape = (512, 512, 3)
batch_size = 4
n_classes = 1
epochs = 30
BACKBONE = 'efficientnetb4'

# NOTE(review): preprocess_input is not applied anywhere in this notebook -- the
# input pipeline only rescales pixels by 1/255. Confirm whether the backbone's own
# preprocessing was intended for use with the 'imagenet' encoder weights.
preprocess_input = sm.get_preprocessing(BACKBONE)


def _process_pathnames(fname, label_path):
    """Load one (image, mask) pair from PNG files.

    :param fname: path of the RGB image.
    :param label_path: path of the mask image.
    :return: ``(img, label_img)`` -- the image decoded with 3 channels, and the
        mask reduced to its first channel with a trailing axis re-added, i.e.
        (H, W, 1). Both keep the dtype produced by ``tf.image.decode_png``.
    """
    img = tf.image.decode_png(tf.io.read_file(fname), channels=3)

    # The mask is decoded with its stored channel count; only channel 0 is kept.
    mask_all_channels = tf.image.decode_png(tf.io.read_file(label_path))
    label_img = tf.expand_dims(mask_all_channels[:, :, 0], axis=-1)
    return img, label_img

def flip_img(horizontal_flip, tr_img, label_img):
    """Randomly mirror an image/mask pair left-right.

    When ``horizontal_flip`` is truthy, both tensors are flipped together with
    probability 0.5 (one ``tf.random.uniform`` draw per call); otherwise the pair
    is returned untouched.
    """
    if not horizontal_flip:
        return tr_img, label_img

    def _mirrored():
        return tf.image.flip_left_right(tr_img), tf.image.flip_left_right(label_img)

    def _unchanged():
        return tr_img, label_img

    coin = tf.random.uniform([], 0.0, 1.0)
    flipped_img, flipped_label = tf.cond(tf.math.less(coin, 0.5), _mirrored, _unchanged)
    return flipped_img, flipped_label


def flip_img_vertically(vertical_flip, tr_img, label_img):
    """Randomly mirror an image/mask pair top-bottom.

    When ``vertical_flip`` is truthy, both tensors are flipped together with
    probability 0.5 (one ``tf.random.uniform`` draw per call); otherwise the pair
    is returned untouched.
    """
    if not vertical_flip:
        return tr_img, label_img

    coin = tf.random.uniform([], 0.0, 1.0)
    flipped_img, flipped_label = tf.cond(
        tf.math.less(coin, 0.5),
        lambda: (tf.image.flip_up_down(tr_img), tf.image.flip_up_down(label_img)),
        lambda: (tr_img, label_img),
    )
    return flipped_img, flipped_label


def _augment(img,
             label_img,
             resize=None,  # Resize the image to some size e.g. [256, 256]
             scale=1,  # Scale image e.g. 1 / 255.
             hue_delta=0,  # Adjust the hue of an RGB image by random factor
             horizontal_flip=False,  # Random left right flip,
             width_shift_range=0,  # Accepted but NOT implemented (ignored)
             height_shift_range=0,  # Accepted but NOT implemented (ignored)
             vertical_flip=None,  # Random up-down flip; None -> same as horizontal_flip
             random_brightness_contrast=None):  # None -> on only if another random aug is on
    """Resize, randomly augment and rescale one (image, mask) pair.

    Geometric ops (resize, flips) are applied identically to image and mask;
    photometric ops (brightness/contrast, hue) touch the image only.

    :param img: RGB image tensor as produced by ``_process_pathnames``.
    :param label_img: single-channel mask tensor, shape (H, W, 1).
    :param resize: optional [height, width]. The image is resized bilinearly, the
        mask with nearest-neighbour so it keeps its original label values.
    :param scale: multiplier applied to both tensors after casting to float32.
    :param hue_delta: if non-zero, ``tf.image.random_hue`` with this max delta.
    :param horizontal_flip: enable random left-right flips (p=0.5).
    :param width_shift_range: ignored -- no translation is implemented.
    :param height_shift_range: ignored -- no translation is implemented.
    :param vertical_flip: enable random up-down flips (p=0.5). ``None`` (default)
        follows ``horizontal_flip``, matching the previous hard-wired behaviour.
    :param random_brightness_contrast: with p=0.5, add 0.2 brightness and apply
        ``random_contrast(0.05, 0.5)``. ``None`` (default) enables this only when
        ``hue_delta`` or ``horizontal_flip`` is set, so a resize/scale-only config
        (such as the validation pipeline) stays deterministic.
    :return: ``(img, label_img)`` as float32 tensors multiplied by ``scale``.
    """
    if vertical_flip is None:
        vertical_flip = horizontal_flip
    if random_brightness_contrast is None:
        random_brightness_contrast = bool(hue_delta) or bool(horizontal_flip)

    if resize is not None:
        # Nearest-neighbour keeps mask values in the original label set; bilinear
        # interpolation would create fractional labels along object edges.
        label_img = tf.image.resize(label_img, resize, method='nearest')
        img = tf.image.resize(img, resize)

    if random_brightness_contrast:
        # NOTE(review): after a resize, img is float32 in the 0-255 range, and
        # adjust_brightness adds its delta as-is to float inputs, so +0.2 has almost
        # no effect there. Confirm the intended brightness magnitude.
        src_img = img

        def _brightness_contrast():
            brightened = tf.image.adjust_brightness(src_img, 0.2)
            return tf.image.random_contrast(brightened, lower=0.05, upper=0.5)

        brightness_prob = tf.random.uniform([], 0.0, 1.0)
        img = tf.cond(tf.math.less(brightness_prob, 0.5),
                      _brightness_contrast,
                      lambda: src_img)

    if hue_delta:
        img = tf.image.random_hue(img, hue_delta)

    img, label_img = flip_img(horizontal_flip, img, label_img)
    img, label_img = flip_img_vertically(vertical_flip, img, label_img)

    label_img = tf.cast(label_img, dtype=tf.float32) * scale
    img = tf.cast(img, dtype=tf.float32) * scale
    return img, label_img


def _tb_augment(img, label_img):
    """Deterministic preprocessing for the TensorBoard image samples.

    Resizes both tensors to ``img_shape[:2]`` -- bilinear for the image,
    nearest-neighbour for the mask so its label values are preserved -- and
    returns both as float32 multiplied by 1/255. No random augmentation.
    """
    target_hw = [img_shape[0], img_shape[1]]
    label_img = tf.image.resize(label_img, target_hw, method='nearest')
    img = tf.image.resize(img, target_hw)

    label_img = tf.cast(label_img, dtype=tf.float32) * (1 / 255.)
    img = tf.cast(img, dtype=tf.float32) * (1 / 255.)

    return img, label_img


def get_baseline_dataset(filenames,
                         labels,
                         preproc_fn=functools.partial(_augment),
                         threads=6,
                         batch_size=batch_size,
                         shuffle=False):
    """Build an infinite, batched tf.data pipeline of (image, mask) pairs.

    :param filenames: image file paths.
    :param labels: mask file paths, aligned with ``filenames``.
    :param preproc_fn: per-example ``fn(img, label) -> (img, label)``. If it is a
        ``functools.partial`` without a ``resize`` keyword, examples may differ in
        size, so ``batch_size`` must be 1.
    :param threads: ``num_parallel_calls`` for each ``map``.
    :param batch_size: examples per batch.
    :param shuffle: reshuffle the file list on every pass over the data.
    :return: dataset that repeats forever; bound each epoch with ``steps_per_epoch``.
    :raises AssertionError: if batching (>1) is requested without a resize.
    """
    num_x = len(filenames)
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

    # Shuffle the (path, path) string pairs *before* decoding: a buffer of num_x
    # filenames is tiny, whereas a buffer of num_x decoded float images can take
    # gigabytes. Both give a uniform shuffle each pass. buffer_size must be >= 1.
    if shuffle:
        dataset = dataset.shuffle(max(num_x, 1))

    dataset = dataset.map(_process_pathnames, num_parallel_calls=threads)

    # getattr: a plain function (not a functools.partial) has no `.keywords`.
    keywords = getattr(preproc_fn, 'keywords', None)
    if keywords is not None and 'resize' not in keywords:
        assert batch_size == 1, "Batching images must be of the same size"

    dataset = dataset.map(preproc_fn, num_parallel_calls=threads)

    # Batch, then repeat forever; fit() bounds each epoch via steps_per_epoch.
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset


def get_tb_dataset(fnames,
                   lnames,
                   preproc_fn=functools.partial(_tb_augment),
                   threads=6,
                   batch_size=1,
                   shuffle=True,
                   n_samples=300):
    """Build a small repeating dataset of (image, mask) pairs for TensorBoard images.

    :param fnames: candidate image paths.
    :param lnames: mask paths aligned with ``fnames``.
    :param preproc_fn: per-example ``fn(img, label) -> (img, label)``; a
        ``functools.partial`` without a ``resize`` keyword requires ``batch_size == 1``.
    :param threads: ``num_parallel_calls`` for each ``map``.
    :param batch_size: examples per batch.
    :param shuffle: reshuffle the sampled pairs on every pass.
    :param n_samples: number of pairs drawn at random (global ``random`` module);
        capped at the number available instead of raising.
    :return: dataset that repeats forever.
    :raises ValueError: if no pairs are available or ``n_samples`` < 1.
    """
    pairs = list(zip(fnames, lnames))
    # random.sample raises when asked for more items than exist (e.g. a validation
    # split with fewer than n_samples files), so cap the sample size.
    k = min(n_samples, len(pairs))
    if k < 1:
        raise ValueError("get_tb_dataset needs at least one (image, mask) pair")
    sampled = random.sample(pairs, k)
    filenames = [img_path for img_path, _ in sampled]
    labels = [mask_path for _, mask_path in sampled]
    num_x = len(filenames)

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

    # Shuffle the path strings before decoding -- far cheaper than buffering
    # decoded images.
    if shuffle:
        dataset = dataset.shuffle(num_x)

    dataset = dataset.map(_process_pathnames, num_parallel_calls=threads)

    # getattr: a plain function (not a functools.partial) has no `.keywords`.
    keywords = getattr(preproc_fn, 'keywords', None)
    if keywords is not None and 'resize' not in keywords:
        assert batch_size == 1, "Batching images must be of the same size"

    dataset = dataset.map(preproc_fn, num_parallel_calls=threads)

    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset


# Training-time preprocessing/augmentation.
# NOTE(review): width_shift_range / height_shift_range are accepted by _augment but
# it implements no translation, so these two entries currently have no effect.
tr_cfg = {
    'resize': [img_shape[0], img_shape[1]],
    'scale': 1 / 255.,
    'hue_delta': 0.2,
    'horizontal_flip': True,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1
}
tr_preprocessing_fn = functools.partial(_augment, **tr_cfg)

# Validation preprocessing requests only resize and rescale (no flips, no hue jitter).
val_cfg = {
    'resize': [img_shape[0], img_shape[1]],
    'scale': 1 / 255.,
}
val_preprocessing_fn = functools.partial(_augment, **val_cfg)

# shuffle=True: reshuffle the training files every pass. Without it every epoch
# replays the same fixed order produced by train_test_split.
train_ds = get_baseline_dataset(x_train_filenames,
                                y_train_filenames,
                                preproc_fn=tr_preprocessing_fn,
                                batch_size=batch_size,
                                shuffle=True)
val_ds = get_baseline_dataset(x_val_filenames,
                              y_val_filenames,
                              preproc_fn=val_preprocessing_fn,
                              batch_size=batch_size)
tb_ds = get_tb_dataset(x_val_filenames, y_val_filenames)


In [None]:
# LOSS functions

# Weight of the boundary term in `combined_loss`; stepped each epoch by
# `alpha_update_clb`. A non-trainable tf.Variable expresses this through the
# public API (instead of K.variable plus the private `_trainable` attribute), and
# K.get_value / K.set_value keep working on it.
# NOTE(review): the model is compiled with `bce_dice_loss`, which does not read
# alpha, so this schedule currently has no effect on training.
alpha = tf.Variable(0.1, trainable=False, dtype=K.floatx(), name='boundary_loss_alpha')


def update_alpha_value(epoch):
    """Per-epoch schedule for ``alpha``.

    * epoch 0: reset alpha to 0.1;
    * epochs 1-5: leave it unchanged;
    * epoch > 5: add 0.2, and fall back to 0.1 when the result is not below 0.5 --
      in practice alpha then alternates 0.3, 0.1, 0.3, ...

    Prints the value whenever it is set.
    """
    if epoch == 0:
        K.set_value(alpha, 0.1)
        print(f"Setting alpha to = {K.get_value(alpha)}")
    if epoch > 5:
        new_alpha = K.get_value(alpha) + 0.2
        if new_alpha < 0.5:
            K.set_value(alpha, new_alpha)
        else:
            K.set_value(alpha, 0.1)
        print(f"Setting alpha to = {K.get_value(alpha)}")


# Public tf.keras callback (rather than the private tensorflow.python.keras copy)
# so it matches the tf.keras framework selected for segmentation_models.
alpha_update_clb = tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: update_alpha_value(epoch))


def segmentation_boundary_loss(y_true, y_pred):
    """
    Boundary-F1 loss computed on the fly from binary segmentation masks.

    The boundary of a mask ``y`` is ``maxpool3x3(1 - y) - (1 - y)``: for a binary
    mask, the foreground pixels that touch background within a 3x3 window. A 5x5
    max-pool gives a wider ("extended") boundary used as a tolerance band.
    Precision is the share of the predicted boundary inside the extended true
    boundary; recall is the share of the true boundary inside the extended
    predicted boundary. Sums run over the whole batch, so one scalar is returned.

    :param y_true: ground-truth mask; assumed (batch, H, W, 1) in [0, 1] -- TODO confirm.
    :param y_pred: predicted probabilities, same shape as ``y_true``.
    :return: scalar tensor ``1 - F1``.
    """
    eps = 1e-7

    def _pool(x, size):
        # Stride-1 'SAME' max-pool, equivalent to the former MaxPooling2D layers
        # but without constructing new Keras layer objects on every loss call.
        return tf.nn.max_pool2d(x, ksize=size, strides=1, padding='SAME')

    bg_pred = 1 - y_pred
    bg_true = 1 - y_true

    y_pred_bd = _pool(bg_pred, 3) - bg_pred
    y_true_bd = _pool(bg_true, 3) - bg_true
    y_pred_bd_ext = _pool(bg_pred, 5) - bg_pred
    y_true_bd_ext = _pool(bg_true, 5) - bg_true

    # eps belongs in the denominator: added after the division (as before), a batch
    # without boundary pixels divided by zero and produced NaN/inf.
    P = K.sum(y_pred_bd * y_true_bd_ext) / (K.sum(y_pred_bd) + eps)
    R = K.sum(y_true_bd * y_pred_bd_ext) / (K.sum(y_true_bd) + eps)
    F1_Score = 2 * P * R / (P + R + eps)
    loss = K.mean(1 - F1_Score)
    return loss


def binary_focal_loss(y_true, y_pred, alpha=0.25, gamma=2):
    """
    Binary form of focal loss.
      FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
      where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    References:
        https://arxiv.org/pdf/1708.02002.pdf

    :param y_true: ground-truth mask. Only entries exactly equal to 1 or 0 contribute;
        fractional labels fall into neither term.
    :param y_pred: predicted probabilities (already passed through a sigmoid).
    :param alpha: weight of the positive term (negatives get ``1 - alpha``). This
        parameter shadows the module-level ``alpha`` schedule variable.
    :param gamma: focusing exponent.
    :return: scalar tensor (mean positive term + mean negative term).
    """
    # Positives keep p; every other entry becomes 1, which contributes ~0 after clipping.
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    # Negatives keep p; every other entry becomes 0, which contributes ~0 after clipping.
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    epsilon = K.epsilon()
    # clip to prevent NaN's and Inf's from log(0)
    pt_1 = K.clip(pt_1, epsilon, 1. - epsilon)
    pt_0 = K.clip(pt_0, epsilon, 1. - epsilon)

    return -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) \
           - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))


def combined_loss(y_true, y_pred):
    """Blend of region and boundary losses weighted by the global ``alpha``.

    ``(1 - alpha) * (BCE + log-cosh Dice) + alpha * boundary-F1 loss``, where
    ``alpha`` is the module-level variable stepped by ``alpha_update_clb``.
    """
    region_term = losses.binary_crossentropy(y_true, y_pred) + log_cosh_dce_loss(y_true, y_pred)
    boundary_term = segmentation_boundary_loss(y_true, y_pred)
    return (1 - alpha) * region_term + alpha * boundary_term


def dice_loss(y_true, y_pred):
    """Soft Dice loss: ``1 - dice_coeff(y_true, y_pred)``."""
    return 1 - dice_coeff(y_true, y_pred)


def log_cosh_dce_loss(y_true, y_pred):
    """
    Log-cosh Dice loss: ``log(cosh(1 - dice_coeff(y_true, y_pred)))``.

    Implementation suggested in https://arxiv.org/pdf/2006.14822.pdf

    The Dice term is computed from ``dice_coeff`` directly instead of through the
    module-level name ``dice_loss``. Globals are looked up at call time, so
    rebinding ``dice_loss`` elsewhere in the notebook (e.g. to a
    segmentation_models loss object) would otherwise silently change this loss.
    """
    dice_term = 1 - dice_coeff(y_true, y_pred)
    return tf.math.log(tf.math.cosh(dice_term))


def dice_coeff(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient over the whole batch.

    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` with both tensors
    flattened, so one scalar is returned for the batch.

    :param y_true: ground-truth mask (cast to ``y_pred``'s dtype).
    :param y_pred: predicted probabilities.
    :param smooth: additive smoothing; keeps the ratio defined (equal to 1) when
        both masks are empty.
    :return: scalar tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Match dtypes so integer/uint8 masks can be multiplied with float predictions.
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score


def bce_dice_loss(y_true, y_pred):
    """Training loss: pixel-wise binary cross-entropy plus log-cosh Dice loss."""
    bce = losses.binary_crossentropy(y_true, y_pred)
    return bce + log_cosh_dce_loss(y_true, y_pred)

In [None]:
# callbacks

# Ceil-divide so the last, partial batch of each pass is still visited.
steps_per_epoch = -(-num_train_examples // batch_size)
validation_steps = -(-num_val_examples // batch_size)

# TensorBoard output directory. This is a fixed path (not timestamped), so
# re-runs write into the same directory.
logdir = f"{MODEL_DIR}/logs/train_data/"

# Keep only the weights of the epoch with the lowest validation loss.
mcp_save = tf.keras.callbacks.ModelCheckpoint(f'{MODEL_DIR}/enetRoadSegV1.h5',
                                              monitor='val_loss',
                                              mode='min',
                                              save_best_only=True,
                                              save_weights_only=True)

tbCallBack = tf.keras.callbacks.TensorBoard(log_dir=logdir,
                                            histogram_freq=0,
                                            write_graph=True,
                                            write_images=True)
tbi_callback = TensorBoardImage(logdir, data_set=tb_ds)

# Stop once val_loss has not improved for 12 epochs.
earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                 mode='min',
                                                 patience=12,
                                                 verbose=1)

# Project-local SGDR scheduler and snapshot builder (semantics of the
# arguments are defined in cb.sgdr_lr_scheduler / cb.snapshot_cb_builder).
sgdr_lr = SGDRScheduler(min_lr=1e-5,
                        max_lr=1e-2,
                        steps_per_epoch=steps_per_epoch,
                        swa_path=MODEL_DIR + "/swa_weights/model_swa_{}.hdf5",
                        tb_log_dir=logdir,
                        lr_decay=0.8,
                        cycle_length=5,
                        mult_factor=1.5)

snapshot = SnapshotCallbackBuilder(sgdr_lr=sgdr_lr,
                                   nb_epochs=epochs,
                                   nb_snapshots=1,
                                   init_lr=1e-3)
snapshot_callbacks = snapshot.get_callbacks()

In [None]:
# MODEL
model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=n_classes, activation='sigmoid',
                input_shape=(img_shape[0], img_shape[1], 3))

# segmentation_models losses can be combined with '+' and scaled by a number.
# Named `sm_dice_loss` so it does not overwrite the custom `dice_loss` function
# defined in the loss cell.
sm_dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
total_loss = sm_dice_loss + (1 * focal_loss)
# Reported as a metric only; despite its name it is Dice + focal, not Dice alone.
# (A ready-made equivalent is sm.losses.binary_focal_dice_loss.)
dice_loss_metrics = total_loss

metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5),
           sm.metrics.Precision(threshold=0.5), sm.metrics.Recall(threshold=0.5),
           dice_loss_metrics]

model.compile(optimizer='adam', loss=bce_dice_loss, metrics=metrics)

model.summary()

# Resume from a previous run if its final model file exists.
# NOTE(review): this is the file written by model.save() after fit() (last-epoch
# weights), not the best-val_loss ModelCheckpoint file (enetRoadSegV1.h5) --
# confirm which one should be resumed.
save_model_path = f"{MODEL_DIR}/enetRoadSegF1.hdf5"

if os.path.exists(save_model_path):
    model.load_weights(save_model_path)


In [None]:
# training
# The tf.data datasets are passed to fit() directly: fit_generator is deprecated,
# and the compat.v1 one-shot iterators are unnecessary in TF2. Both datasets repeat
# forever, so steps_per_epoch / validation_steps define what an epoch is.
history = model.fit(train_ds,
                    validation_data=val_ds,
                    validation_steps=validation_steps,
                    epochs=epochs,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[mcp_save, tbCallBack, tbi_callback, earlystopping, alpha_update_clb]
                    + snapshot_callbacks)

# Saves the model as it stands at the end of fit(); the best-val_loss weights are
# in the ModelCheckpoint file instead.
final_model_path = f"{MODEL_DIR}/enetRoadSegF1.hdf5"
model.save(final_model_path)
# History.history holds the per-epoch loss/metric values (the object's repr does not).
print(history.history)