# Prototype of Segmenter based on tf.data API
---

In [None]:
pwd

In [None]:
# import the necessary packages
import tempfile
import platform
import imageio
import time
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator

In [None]:
import tensorflow as tf
# tensorflow_io can handle TIFF images (not ready for TF 2.0)
# import tensorflow_io.image as image_io
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.callbacks import LambdaCallback
# eager execution is default for TF 2.0
#tf.enable_eager_execution()
# sentinel passed to map(num_parallel_calls=...) and prefetch(...) in the dataset cells below
AUTOTUNE = tf.data.experimental.AUTOTUNE
# pandas displays floats with 4 decimals (used for the IoU tables)
pd.options.display.float_format = '{:.4f}'.format
# report the library versions in use and whether TensorFlow can see a GPU
print("Python version     : " + platform.python_version())
print("Tensorflow version : " + tf.version.VERSION)
print("Keras version      : " + tf.keras.__version__)
print("Numpy version      : " + np.__version__)
print("Pandas version     : " + pd.__version__)
print("Imageio version    : " + imageio.__version__)
print("GPU available      : " + str(tf.test.is_gpu_available()))

## ISPRS Project Constants
---
To train on the ISPRS dataset, run only the cells of this section and skip the INRIA cells below.

In [None]:
# ISPRS (Potsdam): folders holding the aerial TIFF tiles and their masks
train_img_dir, train_mask_dir = '../train_dir/potsdam_aerials/', '../train_dir/potsdam_masks/'
val_img_dir, val_mask_dir = '../validation_dir/potsdam_aerials/', '../validation_dir/potsdam_masks/'
# the same four locations as Path objects, used to enumerate the files
train_src, train_ref = Path(train_img_dir), Path(train_mask_dir)
val_src, val_ref = Path(val_img_dir), Path(val_mask_dir)

In [None]:
def _sorted_tif_paths(folder):
    """Return the sorted paths (as str) of the '.tif' files located directly in folder."""
    return sorted(str(entry) for entry in folder.iterdir() if entry.is_file() and entry.suffix == '.tif')

train_images = _sorted_tif_paths(train_src)
train_masks  = _sorted_tif_paths(train_ref)
val_images   = _sorted_tif_paths(val_src)
val_masks    = _sorted_tif_paths(val_ref)

# aerials and masks are paired by sorted position, so both lists must have the same length
assert len(train_images) == len(train_masks), " (Error) The train Aerial image count does not match Mask counts!"
assert len(val_images)   == len(val_masks), " (Error) The Validation Aerial image count does not match Mask counts!"
len_train_images, len_val_images = len(train_images), len(val_images)
len_train_images, len_val_images

In [None]:
# class colors in [R, G, B] format; the position in the list is the class id
color_list = [
    [1, 0, 0],
    [0, 0, 1],
    [1, 1, 1],
    [0, 1, 1],
    [0, 1, 0],
    [1, 1, 0],
]
# class names, in the same order as color_list
color_names = ['Background', 'Building', 'Roads', 'Vegetation', 'Tree', 'Car']
# Patch extraction parameters, applied jointly to the aerial and the mask
# image so that both stay aligned
TILE_SIZE = 6000
PATCH_SIZE = 256
PATCH_STRIDE = 256
PATCH_RATE = 1
# 4-element [1, rows, cols, 1] arguments passed to tf.image.extract_patches
SIZES = [1] + [PATCH_SIZE] * 2 + [1]
STRIDES = [1] + [PATCH_STRIDE] * 2 + [1]
RATES = [1] + [PATCH_RATE] * 2 + [1]
PADDING = 'VALID'

def compute_tile_patch_number(rates=(1, 2, 3)):
    """Count the patches extracted from one tile, summed over dilation rates.

    The first training tile (train_images[0]) is read from disk and used as
    the reference; all tiles are expected to share its size.

    Args:
        rates: iterable of integer dilation rates. For each rate r the patches
            are extracted with rates=[1, r, r, 1] and the module-level SIZES,
            STRIDES and PADDING.

    Returns:
        The number of patches (rows * columns of the patch grid) summed over
        all the requested rates.
    """
    with tf.device('/CPU:0'):
        # the batch axis is added once, outside the loop
        aerial = tf.expand_dims(tf.constant(imageio.imread(train_images[0])), axis=0)
        total = 0
        for r in rates:
            patches = tf.image.extract_patches(images=aerial, sizes=SIZES, strides=STRIDES, rates=[1, r, r, 1], padding=PADDING)
            total += patches.shape[1] * patches.shape[2]
        return total
    
# old formula to compute patch number
# PATCH_NUMBER= ((TILE_SIZE - PATCH_SIZE)//PATCH_STRIDE + 1)**2

# choose this line if you want a dataset with 3 different scales
# TRAIN_PATCH_NUMBER = compute_tile_patch_number(rates=[1,2,3])
# choose this line if you want only one scale
TRAIN_PATCH_NUMBER = compute_tile_patch_number(rates=[1])
# NOTE: compute_tile_patch_number always measures train_images[0], so the
# validation count assumes validation tiles have the training tile size
VALID_PATCH_NUMBER = compute_tile_patch_number(rates=[1])
TRAIN_PATCH_NUMBER, VALID_PATCH_NUMBER

In [None]:
# model input shape and training hyper-parameters
MODEL_SHAPE   = (256, 256, 3)
PATCH_RESIZE  = (128, 128)
NUM_CLASSES   = 6
BATCH_SIZE    = 50
EPOCHS        = 50
LEARNING_RATE = 0.0001
# dataset sizes counted in patches: patches per tile times number of tiles
NUM_TRAIN_EXAMPLES = TRAIN_PATCH_NUMBER * len_train_images
NUM_VAL_EXAMPLES   = VALID_PATCH_NUMBER * len_val_images
# batches needed to go once through every patch (the last batch may be partial)
TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH = (
    int(np.ceil(example_count / float(BATCH_SIZE)))
    for example_count in (NUM_TRAIN_EXAMPLES, NUM_VAL_EXAMPLES)
)
TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH

## INRIA Project Constants
---
(Reminder) Run either the ISPRS constants or the INRIA constants, not both.

In [None]:
# INRIA: folders holding the aerial TIFF tiles and their ground-truth masks
inria_train_img_dir  = '/datasets/InriaAerial/AerialImageDataset/train/images/'
inria_train_mask_dir = '/datasets/InriaAerial/AerialImageDataset/train/gt/'
inria_val_img_dir    = '/datasets/InriaAerial/AerialImageDataset/valid/images/'
inria_val_mask_dir   = '/datasets/InriaAerial/AerialImageDataset/valid/gt/'
# the dataset cells further down read train_img_dir / train_mask_dir /
# val_img_dir / val_mask_dir, so point those names at the INRIA folders too;
# otherwise an INRIA-only run fails with a NameError (or silently reuses the
# ISPRS folders if the ISPRS cells were run before)
train_img_dir, train_mask_dir = inria_train_img_dir, inria_train_mask_dir
val_img_dir, val_mask_dir     = inria_val_img_dir, inria_val_mask_dir
# let's build the lists containing the paths of the aerial and mask files
train_src = Path(inria_train_img_dir)
train_ref = Path(inria_train_mask_dir)
val_src   = Path(inria_val_img_dir)
val_ref   = Path(inria_val_mask_dir)
train_images     = sorted([str(x) for x in train_src.iterdir() if x.is_file() and x.suffix == '.tif'])
train_masks      = sorted([str(x) for x in train_ref.iterdir() if x.is_file() and x.suffix == '.tif'])
val_images       = sorted([str(x) for x in val_src.iterdir()   if x.is_file() and x.suffix == '.tif'])
val_masks        = sorted([str(x) for x in val_ref.iterdir()   if x.is_file() and x.suffix == '.tif'])

# aerials and masks are paired by sorted position, so both lists must have the same length
assert len(train_images) == len(train_masks), " (Error) The train Aerial image count does not match Mask counts!"
assert len(val_images)   == len(val_masks), " (Error) The Validation Aerial image count does not match Mask counts!"
len_train_images = len(train_images)
len_val_images   = len(val_images)
len_train_images, len_val_images

In [None]:
# class colors in [R, G, B] format; the position in the list is the class id
color_list = [
    [0, 0, 0],
    [1, 1, 1],
]
# class names, in the same order as color_list
color_names = ['Background', 'Building']
# patch extraction parameters
TILE_SIZE = 5000
PATCH_SIZE = 256
PATCH_STRIDE = 256
PATCH_RATE = 1
# 4-element [1, rows, cols, 1] arguments passed to tf.image.extract_patches
SIZES = [1] + [PATCH_SIZE] * 2 + [1]
STRIDES = [1] + [PATCH_STRIDE] * 2 + [1]
RATES = [1] + [PATCH_RATE] * 2 + [1]
PADDING = 'VALID'

def compute_tile_patch_number(rates=(1, 2, 3)):
    """Count the patches extracted from one tile, summed over dilation rates.

    The first training tile (train_images[0]) is read from disk and used as
    the reference; all tiles are expected to share its size.

    Args:
        rates: iterable of integer dilation rates. For each rate r the patches
            are extracted with rates=[1, r, r, 1] and the module-level SIZES,
            STRIDES and PADDING.

    Returns:
        The number of patches (rows * columns of the patch grid) summed over
        all the requested rates.
    """
    with tf.device('/CPU:0'):
        # the batch axis is added once, outside the loop
        aerial = tf.expand_dims(tf.constant(imageio.imread(train_images[0])), axis=0)
        total = 0
        for r in rates:
            patches = tf.image.extract_patches(images=aerial, sizes=SIZES, strides=STRIDES, rates=[1, r, r, 1], padding=PADDING)
            total += patches.shape[1] * patches.shape[2]
        return total

# only patches at one scale are generated for INRIA
TRAIN_PATCH_NUMBER = compute_tile_patch_number(rates=[1])
# NOTE: compute_tile_patch_number always measures train_images[0], so the
# validation count assumes validation tiles have the training tile size
VALID_PATCH_NUMBER = compute_tile_patch_number(rates=[1])
TRAIN_PATCH_NUMBER, VALID_PATCH_NUMBER

In [None]:
# model input shape and training hyper-parameters
MODEL_SHAPE   = (256, 256, 3)
PATCH_RESIZE  = (128, 128)
NUM_CLASSES   = 2
BATCH_SIZE    = 50
EPOCHS        = 50
LEARNING_RATE = 0.0001
# dataset sizes counted in patches: patches per tile times number of tiles
NUM_TRAIN_EXAMPLES = TRAIN_PATCH_NUMBER * len_train_images
NUM_VAL_EXAMPLES   = VALID_PATCH_NUMBER * len_val_images
# batches needed to go once through every patch (the last batch may be partial)
TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH = (
    int(np.ceil(example_count / float(BATCH_SIZE)))
    for example_count in (NUM_TRAIN_EXAMPLES, NUM_VAL_EXAMPLES)
)
TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH

## Utility Functions
---

In [None]:
# mask to label conversion functions
def mask2label(mask, colors=color_list, sparse=False):
    """Convert a color mask into per-pixel class labels.

    Args:
        mask: tensor whose last axis holds the color channels, expressed in
            the same value range as `colors`.
        colors: ordered list of class colors; the position in the list is the
            class id. The default is the `color_list` that exists when this
            cell is run, so re-run the cell after changing `color_list`.
        sparse: when True return class indices, otherwise a one-hot encoding.

    Returns:
        sparse=True: a tf.int32 tensor (usable as tf.gather indices) without
        the channel axis, holding the index of the matching color. Pixels
        matching no color get 0.
        sparse=False: a tf.int8 tensor with one trailing channel per color.

    Raises:
        ValueError: if `colors` is empty.
    """
    if len(colors) == 0:
        raise ValueError("mask2label needs at least one color")
    # one boolean map per class: True where every channel equals the class color
    matches = [tf.reduce_all(tf.equal(mask, color), axis=-1) for color in colors]
    if sparse:
        # the maps are combined once, after all of them have been built
        return tf.add_n([i * tf.cast(match, dtype=tf.int32) for i, match in enumerate(matches)])
    # A sparse label can also be derived from the one-hot one with
    # tf.argmax(label, axis=-1, output_type=tf.int32); both approaches were
    # checked to give the same result.
    return tf.stack([tf.cast(match, dtype=tf.int8) for match in matches], axis=-1)

def label2mask(sparse_label, colors=color_list):
    """Turn a sparse label map back into a color mask.

    Args:
        sparse_label: integer tensor of class indices (sparse encoding, not
            one-hot).
        colors: ordered list of class colors indexed by class id. It is now
            actually used for the lookup; the global `color_list` was always
            used before, whatever the caller passed.

    Returns:
        A tf.float32 tensor with the color of each pixel's class appended as
        the last axis.
    """
    # TODO, add assertion to check sparse encoding (check last dimension = 1, not = number of classes)
    mask = tf.gather(colors, sparse_label)
    return tf.dtypes.cast(mask, tf.float32)
    

In [None]:
def my_show_pair(img_list, title=None, interpolation=None, **kwargs):
    """Display the first two images of img_list side by side, axes hidden.

    `title`, when given, is put on both images; `interpolation` and any extra
    keyword arguments are forwarded to imshow.
    """
    # one row, two columns
    _, axes = plt.subplots(1, 2, figsize=(20, 10))
    for axis, image in zip(axes, img_list):
        axis.imshow(image, interpolation=interpolation, **kwargs)
        axis.axis('off')
        if title:
            axis.set_title(title)

# utility function to display image pairs in a grid
def my_show_grid(img_list, nrow=1, ncol=2, title=None, interpolation=None, **kwargs):
    """Display paired images in an nrow x ncol grid, axes hidden.

    Args:
        img_list: tuple of equally long image sequences, e.g. (aerials, masks).
            The sequences are interleaved (a0, m0, a1, m1, ...) so that the
            members of a pair end up next to each other.
        nrow, ncol: grid shape; only the first nrow * ncol images are shown.
        title: optional title put on every cell.
        interpolation, **kwargs: forwarded to imshow.
    """
    # interleave the sequences and drop the dimensions of size 1
    flat_images = [np.squeeze(item) for group in zip(*img_list) for item in group]
    # squeeze=False always returns a 2-D array of axes, so flatten() also works
    # for a 1x1 grid (where subplots would otherwise return a bare Axes)
    _, axes = plt.subplots(nrow, ncol, squeeze=False, figsize=(20, 8), gridspec_kw={'wspace': 0.05, 'hspace': 0.05})
    for ax, img in zip(axes.flatten(), flat_images):
        ax.imshow(img, interpolation=interpolation, **kwargs)
        ax.axis('off')
        if title:
            ax.set_title(title)

In [None]:
def timepipeline(ds, batch_size = 1, steps=1000):
    """Pull `steps` batches from ds and print the measured throughput.

    `batch_size` is only used to convert batches per second into images per
    second; it does not change how ds is iterated.
    """
    t0 = time.time()
    batches = iter(ds)
    for _ in tqdm(range(steps)):
        next(batches)
    print()
    elapsed = time.time() - t0
    print("batche size: {}".format(batch_size))
    print("{} batches: {} s".format(steps, elapsed))
    print("{:0.5f} Batches/s".format(steps/elapsed))
    print("{:0.5f} Images/s".format(batch_size*steps/elapsed))

## Patch Extraction:
---

In [None]:
# now we are ready to make our final dataset
def make_patches(aerial, mask):
    """Cut an aerial tile and its mask into aligned patches.

    Both images go through the same extract_patches call (module-level SIZES,
    STRIDES, PADDING, rate 1) so that patch i of the aerial matches patch i of
    the mask.

    Returns:
        A tensor of shape [patch_number, 2, PATCH_SIZE, PATCH_SIZE, 3] where
        index 0 of the second axis is the aerial patch and index 1 its mask.
    """
    with tf.device('/CPU:0'):
        # the pair is treated as a batch of two images
        flat = tf.image.extract_patches(images=[aerial, mask], sizes=SIZES, strides=STRIDES, rates=[1,1,1,1], padding=PADDING)
        # flat is [2, rows, cols, PATCH_SIZE*PATCH_SIZE*3]: each patch is flattened
        grid = tf.shape(flat)
        n_patches = grid[1] * grid[2]
        # bring back the spatial and RGB axes of every patch
        stacked = tf.reshape(flat, [2, n_patches, PATCH_SIZE, PATCH_SIZE, 3])
        # swap the two leading axes to get a list of (aerial, mask) pairs
        return tf.transpose(stacked, perm=[1,0,2,3,4])

# This variant of the function helps create a dataset with patches at 3 scales (zooming effect)
def make_patches_scaled(aerial, mask):
    """Cut an aerial tile and its mask into aligned patches at rates 1, 2 and 3.

    Returns:
        A tensor of shape [patch_number, 2, PATCH_SIZE, PATCH_SIZE, 3] holding
        the rate-1 patches first, then the rate-2 ones, then the rate-3 ones;
        index 0 of the second axis is the aerial patch and index 1 its mask.
    """
    with tf.device('/CPU:0'):
        per_rate = []
        for rate in (1, 2, 3):
            # the pair is treated as a batch of two images
            flat = tf.image.extract_patches(images=[aerial, mask], sizes=SIZES, strides=STRIDES, rates=[1,rate,rate,1], padding=PADDING)
            # flat is [2, rows, cols, PATCH_SIZE*PATCH_SIZE*3]: each patch is flattened
            n_patches = tf.shape(flat)[1] * tf.shape(flat)[2]
            # bring back the spatial and RGB axes of every patch
            per_rate.append(tf.reshape(flat, [2, n_patches, PATCH_SIZE, PATCH_SIZE, 3]))
        patches = tf.concat(per_rate, axis=1)
        # swap the two leading axes to get a list of (aerial, mask) pairs
        return tf.transpose(patches, perm=[1,0,2,3,4])

In [None]:
# data augmentation methods
def make_augmentation(pair_batch):
    """Randomly flip/rotate an (aerial, mask) pair and split it into (img, label).

    Args:
        pair_batch: the aerial patch at index 0 and its mask at index 1
            (presumably one element of make_patches' output; verify against
            the caller).

    Returns:
        (img, mask): the float32 aerial patch and the sparse label map built
        by mask2label(..., sparse=True).
    """
    with tf.device('/CPU:0'):
        # Spatial transformation, applied to both images at once to keep them aligned
        flip_it = tf.random.uniform([]) > 0.5
        quarter_turns = tf.random.uniform([], maxval=4, dtype=tf.int32)
        flipped = tf.cond(flip_it, lambda: tf.image.flip_left_right(pair_batch), lambda: pair_batch)
        rotated = tf.image.rot90(flipped, quarter_turns)
        # rescale image from uint8 [0..255] to float32 [0..1]
        scaled = tf.image.convert_image_dtype(rotated, tf.float32)
        # temp solution, to be removed, used only if patches bigger than the model input size are generated
        # Beware, tf.image.resize change dtype from uint8 to float32 while keeping [0..255] range
        # scaled = tf.image.resize(scaled, PATCH_RESIZE, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        img, mask = scaled[0], scaled[1]
        # TODO: does per_image_standardization help get a better model ???
        # img = tf.image.per_image_standardization(img)

        # color transformation
        # TODO, remove comment after finding best values for color transformation
        # img = tf.image.random_hue(img, 0.08)
        # img = tf.image.random_saturation(img, 0.6, 1.6)
        # img = tf.image.random_brightness(img, 0.05)
        # img = tf.image.random_contrast(img, 0.7, 1.3)
        # img = tf.clip_by_value(img, 0, 1)

        label = mask2label(mask, sparse=True)
        # if only a particular class is of interest, enable following code
        # example for buildings (class #1)
        # label = label[:,:,1]
        # label = tf.expand_dims(label, axis=-1)
        return img, label

# no augmentation for validation set
def no_augmentation(pair_batch):
    """Split an (aerial, mask) pair into (img, label) without any augmentation.

    Returns:
        (img, mask): the float32 aerial patch (pair_batch[0]) and the sparse
        label map built from pair_batch[1] by mask2label(..., sparse=True).
    """
    with tf.device('/CPU:0'):
        # rescale image from uint8 [0..255] to float32 [0..1]
        scaled = tf.image.convert_image_dtype(pair_batch, tf.float32)
        # temp solution, to be removed, used only if patches bigger than the model input size are generated
        # Beware, tf.image.resize change dtype from uint8 to float32 while keeping [0..255] range
        # scaled = tf.image.resize(scaled, PATCH_RESIZE, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        img, mask = scaled[0], scaled[1]
        # TODO: does per_image_standardization help get a better model ???
        # img = tf.image.per_image_standardization(img)
        return img, mask2label(mask, sparse=True)

## Metrics Definition:
---

In [None]:
def plot_keras_history(keras_history):
    """Plot train/validation accuracy, mean IoU and loss per epoch, side by side.

    `keras_history.history` must contain the keys 'Accuracy', 'mean_iou',
    'loss' and their 'val_' counterparts.
    """
    # (subplot position, train key, validation key, panel title)
    panels = (
        (131, 'Accuracy', 'val_Accuracy', 'Accuracy per epoch'),
        (132, 'mean_iou', 'val_mean_iou', 'Mean IoU per epoch'),
        (133, 'loss', 'val_loss', 'Loss per epoch'),
    )
    plt.figure(figsize=(20,5))
    for position, train_key, val_key, heading in panels:
        plt.subplot(position)
        plt.plot(keras_history.history[train_key], label='train')
        plt.plot(keras_history.history[val_key], label='validation')
        plt.grid(axis='y')
        plt.legend()
        plt.title(heading)
    plt.show()

In [None]:
def plot_cm(cm):
    """Draw the confusion matrix cm as an annotated heat map.

    Rows are labelled as true classes and columns as predicted classes, both
    with color_names; every cell shows its value with two decimals.
    """
    fig, ax = plt.subplots(figsize=(16,8))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(heatmap, ax=ax)
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    # one tick per class, labelled with the class names
    ax.set(xticks=np.arange(n_cols),
           yticks=np.arange(n_rows),
           xticklabels=color_names, yticklabels=color_names,
           title='Normalized Confusion matrix',
           ylabel='True label',
           xlabel='Predicted label')
    # rotate the x tick labels so that they do not overlap
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
    fmt = '.2f'
    # cells darker than half of the maximum get white text to stay readable
    thresh = tf.reduce_max(cm) / 2.
    for i, j in np.ndindex(n_rows, n_cols):
        text_color = "white" if cm[i, j] > thresh else "black"
        ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center", color=text_color)
    fig.tight_layout()
    plt.show()

In [None]:
# computing IoU per class
with tf.device('/CPU:0'):
    def compute_batch_iou(model, patch_pairs):
        """Compute the per-class IoU of model on one batch.

        Args:
            model: Keras model whose output has one score per class on the
                last axis.
            patch_pairs: (images, sparse labels) batch.

        Returns:
            A one-row DataFrame with one IoU column per entry of color_names;
            classes absent from both labels and predictions get 0.

        NOTE(review): the surrounding tf.device block is entered while this
        cell defines the function, not on each call; confirm whether CPU
        placement at call time was intended.
        """
        pred = model.predict(patch_pairs[0])
        pred = tf.argmax(pred, axis=-1, output_type=tf.int32)
        label = patch_pairs[1]
        # num_classes pins the matrix to one row/column per class name; without
        # it the size follows the largest class id seen in the batch and the
        # Series below fails when the last classes are missing
        total_cm = tf.math.confusion_matrix(tf.reshape(label, [-1]), tf.reshape(pred, [-1]), num_classes=len(color_names))
        sum_over_row = tf.reduce_sum(total_cm, axis=0)
        sum_over_col = tf.reduce_sum(total_cm, axis=1)
        true_positives = tf.linalg.diag_part(total_cm)
        # union = predicted + actual - intersection
        denominator = sum_over_row + sum_over_col - true_positives
        iou = tf.math.divide_no_nan(tf.cast(true_positives, tf.float32), tf.cast(denominator, tf.float32))
        return pd.Series(data=iou, index=color_names).to_frame().T


In [None]:
with tf.device('/CPU:0'):
    def compute_dataset_cm(model, ds, steps, normalize=False):
        """Accumulate the confusion matrix of model over `steps` batches of ds.

        Args:
            model: Keras model whose output has one score per class on the
                last axis.
            ds: dataset yielding (images, sparse labels) batches.
            steps: number of batches to evaluate.
            normalize: when True every row is divided by its sum; rows of
                classes that never occur in the labels stay at 0.

        Returns:
            A NUM_CLASSES x NUM_CLASSES matrix (rows: true class, columns:
            predicted class): a numpy array of counts, or a tensor of row
            fractions when normalize is True.
        """
        total_cm = np.zeros((NUM_CLASSES,NUM_CLASSES))
        for patch_pairs in tqdm(ds.take(steps), total=steps):
            pred = model.predict(patch_pairs[0])
            pred = tf.argmax(pred, axis=-1, output_type=tf.int32)
            label = patch_pairs[1]
            # num_classes keeps every batch matrix at the accumulator's shape,
            # even when a batch does not contain the last classes
            batch_cm = tf.math.confusion_matrix(tf.reshape(label, [-1]), tf.reshape(pred, [-1]), num_classes=NUM_CLASSES)
            total_cm += batch_cm.numpy()
        # normalizing the confusion matrix
        if normalize:
            total_cm = tf.math.divide_no_nan(total_cm, tf.reduce_sum(total_cm, axis=1, keepdims=True))
        return total_cm

In [None]:
with tf.device('/CPU:0'):
    def compute_dataset_iou(model, ds, steps):
        """Compute the per-class IoU of model over `steps` batches of ds.

        The row-normalized confusion matrix is plotted with plot_cm as a side
        effect.

        Returns:
            A one-row DataFrame with one IoU column per entry of color_names;
            classes absent from both labels and predictions get 0.
        """
        total_cm = compute_dataset_cm(model, ds, steps)
        # divide_no_nan leaves the rows of never-seen classes at 0 instead of NaN
        normalized_cm = tf.math.divide_no_nan(total_cm, tf.reduce_sum(total_cm, axis=1, keepdims=True))
        sum_over_row = tf.reduce_sum(total_cm, axis=0)
        sum_over_col = tf.reduce_sum(total_cm, axis=1)
        true_positives = tf.linalg.diag_part(total_cm)
        # union = predicted + actual - intersection
        denominator = sum_over_row + sum_over_col - true_positives
        iou = tf.math.divide_no_nan(tf.cast(true_positives, tf.float32), tf.cast(denominator, tf.float32))
        plot_cm(normalized_cm)
        return pd.Series(data=iou, index=color_names).to_frame().T

## Tile prediction functions (instant sliding window):
---

In [None]:
with tf.device('/CPU:0'):
    def get_extract_pred_scatter(img,model):
        """Predict a whole tile patch by patch and stitch the scores back together.

        Args:
            img: the tile as an (H, W, C) tensor.
            model: Keras model applied to the extracted patches.

        Returns:
            An (H, W, NUM_CLASSES) tensor of class scores. Scores of
            overlapping patches are summed; pixels covered by no patch stay 0.
        """
        H,W,C = img.shape
        # number of patches along each axis, then in the whole tile
        patch_rows = (H - PATCH_SIZE)//PATCH_STRIDE + 1
        patch_cols = (W - PATCH_SIZE)//PATCH_STRIDE + 1
        tile_PATCH_NUMBER = patch_rows*patch_cols
        # the indices trick: a (row, col) coordinate map is cut exactly like
        # the image, so every predicted pixel knows where it belongs
        col_ids, row_ids = tf.meshgrid(tf.range(W), tf.range(H))
        indices = tf.stack([row_ids, col_ids], axis=-1)

        def _to_patches(tensor, depth):
            # extract the flattened patches, then restore [patch, row, col, depth]
            flat = tf.image.extract_patches(images=tf.expand_dims(tensor, axis=0), sizes=SIZES, strides=STRIDES, rates=RATES, padding=PADDING)
            return tf.reshape(tf.squeeze(flat), [tile_PATCH_NUMBER, PATCH_SIZE, PATCH_SIZE, depth])

        img_patches = _to_patches(img, C)
        ind_patches = _to_patches(indices, 2)
        # Now predict
        pred_patches = model.predict(img_patches, batch_size=BATCH_SIZE)
        # scatter every predicted pixel to its tile coordinate
        return tf.scatter_nd(indices=ind_patches, updates=pred_patches, shape=(H,W,NUM_CLASSES))

    def get_tile_prediction(tile_path, model, from_disk=True):
        """Predict the color mask of a whole tile.

        Args:
            tile_path: path of the tile, or the image itself when from_disk
                is False.
            model: Keras model used for the prediction.
            from_disk: whether tile_path must be read with imageio first.

        Returns:
            The predicted mask as a float32 color image (see label2mask).
        """
        img = imageio.imread(tile_path) if from_disk else tile_path
        img = tf.image.convert_image_dtype(img,tf.float32)
        scores = get_extract_pred_scatter(img,model)
        # most likely class per pixel, then its color
        class_ids = tf.argmax(scores, axis=-1, output_type=tf.int32)
        return label2mask(class_ids)

    def get_tile_tta_pred(tile_path, model, from_disk=True):
        """Predict the color mask of a tile with test time augmentation.

        The tile is predicted under its 4 rotations by 90 degrees, with and
        without a left-right flip. Each prediction is mapped back to the
        original orientation and the 8 score maps are summed before taking the
        most likely class.

        Args:
            tile_path: path of the tile, or the image itself when from_disk
                is False.
            model: Keras model used for the prediction.
            from_disk: whether tile_path must be read with imageio first.

        Returns:
            The predicted mask as a float32 color image (see label2mask).
        """
        img = imageio.imread(tile_path) if from_disk else tile_path
        img = tf.image.convert_image_dtype(img,tf.float32)
        H, W, _ = img.shape
        pred_tile = tf.zeros(shape=(H,W,NUM_CLASSES))
        for mirrored in (False, True):
            if mirrored:
                img = tf.image.flip_left_right(img)
            for i in tqdm(tf.range(4)):
                rotated = tf.image.rot90(img,k=i)
                scores = get_extract_pred_scatter(rotated,model)
                # undo the rotation first, then the flip
                restored = tf.image.rot90(scores,k=-i)
                if mirrored:
                    restored = tf.image.flip_left_right(restored)
                pred_tile += restored
        class_ids = tf.argmax(pred_tile, axis=-1, output_type=tf.int32)
        return label2mask(class_ids)

    def get_tile_cm(y_true, y_pred, normalize=False):
        """Compute the confusion matrix between two color masks.

        Args:
            y_true: ground-truth color mask.
            y_pred: predicted color mask.
            normalize: when True every row is divided by its sum; rows of
                classes absent from y_true stay at 0.

        Returns:
            A float32 tensor with one row (true class) and one column
            (predicted class) per entry of color_list.
        """
        y_true = mask2label(y_true, sparse=True)
        y_pred = mask2label(y_pred, sparse=True)
        # num_classes fixes the size to the number of known colors; without it
        # the matrix shrinks when the last classes do not occur in the tile
        tile_cm = tf.math.confusion_matrix(tf.reshape(y_true, [-1]), tf.reshape(y_pred, [-1]), num_classes=len(color_list))
        tile_cm = tf.cast(tile_cm, tf.float32)
        if normalize:
            tile_cm = tf.math.divide_no_nan(tile_cm, tf.reduce_sum(tile_cm, axis=1, keepdims=True))
        return tile_cm

    def get_tile_iou(y_true, y_pred):
        """Compute the per-class IoU between two color masks.

        Returns:
            A one-row DataFrame with one IoU column per entry of color_names;
            a class with an empty union gets 0.
        """
        tile_cm = get_tile_cm(y_true, y_pred)
        sum_over_row = tf.reduce_sum(tile_cm, axis=0)
        sum_over_col = tf.reduce_sum(tile_cm, axis=1)
        true_positives = tf.linalg.diag_part(tile_cm)
        # union = predicted + actual - intersection
        denominator = sum_over_row + sum_over_col - true_positives
        iou = tf.math.divide_no_nan(tf.cast(true_positives, tf.float32), tf.cast(denominator, tf.float32))
        return pd.Series(data=iou, index=color_names).to_frame().T

## get_dataset wrapper function:
---

In [None]:
with tf.device('/CPU:0'):
    def get_image(srcpath, refpath):
        """Read an aerial image and its mask from disk (py_function body).

        Args:
            srcpath, refpath: string tensors holding the two file paths.

        Returns:
            (src, ref): the aerial image and the mask. A 2-D (grayscale) mask
            is expanded to a 3-channel uint8 black/white image so that the
            rest of the pipeline can treat every mask as RGB.
        """
        srcpath = srcpath.numpy().decode("utf-8")
        refpath = refpath.numpy().decode("utf-8")
        src = imageio.imread(srcpath)
        ref = imageio.imread(refpath)
        if ref.ndim == 2:
            # binary segmentation: the integer division maps the gray level
            # 255 to class 1 (and 0 to class 0), then each class gets its RGB color
            black_and_white = [[0, 0, 0], [255, 255, 255]]
            class_ids = tf.cast(ref // 255, tf.int32)
            ref = tf.cast(tf.gather(black_and_white, class_ids), tf.uint8)
        return src,ref
    def tf_get_image(srcpath, refpath):
        """Wrap get_image in tf.py_function; returns two uint8 tensors (aerial, mask)."""
        aerial, mask = tf.py_function(func=get_image, inp=[srcpath, refpath], Tout=[tf.uint8,tf.uint8])
        return aerial, mask

In [None]:
with tf.device('/CPU:0'):
    def get_dataset(img_dir, mask_dir, batch_size, shuffle=True, repeat=False, augment=True, scaled=True, cachefile=None, shuffle_buffer=600):
        """Build the batched tf.data pipeline of (image, label) patches.

        Args:
            img_dir: folder containing the aerial '.tif' tiles.
            mask_dir: folder containing the mask '.tif' tiles; files are paired
                with the aerials by sorted position.
            batch_size: number of patches per batch.
            shuffle: shuffle the tiles before reading and the patches before
                batching.
            repeat: repeat the dataset indefinitely.
            augment: map make_augmentation (random flip/rotation) instead of
                no_augmentation.
            scaled: extract patches with make_patches_scaled (rates 1, 2, 3)
                instead of make_patches.
            cachefile: name of a cache file created under /tmp; existing files
                starting with that name are deleted first. When falsy an empty
                file name is passed to cache(), i.e. the patches are cached in
                memory.
            shuffle_buffer: size of the patch shuffle buffer (default 600, the
                value that used to be hard-coded).

        Returns:
            A tf.data.Dataset yielding batches of (img, mask) as produced by
            make_augmentation / no_augmentation.

        Raises:
            AssertionError: if the two folders hold a different number of
                '.tif' files.
        """
        images_path  = Path(img_dir)
        masks_path   = Path(mask_dir)
        images_files = sorted([str(x) for x in images_path.iterdir() if x.is_file() and x.suffix == '.tif'])
        masks_files  = sorted([str(x) for x in masks_path.iterdir()  if x.is_file() and x.suffix == '.tif'])

        # checking the equality of the two list lengths
        assert len(images_files) == len(masks_files), " (Error) The Aerial image count does not match Mask counts!"
        len_images = len(images_files)

        if cachefile:
            cache_file_name = '/tmp/' + cachefile
            # start with a fresh cache file
            for f in list(Path('/tmp').glob(cachefile + '*')):
                print('Removing old cache file: {}'.format(f))
                f.unlink()
        else:
            cache_file_name = ''

        # First strategy to get images files dataset
        images_ds = tf.data.Dataset.from_tensor_slices(images_files)
        masks_ds = tf.data.Dataset.from_tensor_slices(masks_files)

        # second strategy
        # maybe this strategy does not guarantee the order of the lists
        #images_ds = tf.data.Dataset.list_files(str(images_path/'*.tif'),shuffle=False)
        #masks_ds  = tf.data.Dataset.list_files(str(masks_path/'*.tif'),shuffle=False)

        # zipping the aerial and mask dataset
        patch_pair_ds = tf.data.Dataset.zip((images_ds, masks_ds))
        if shuffle:
            # tile-level shuffle: the buffer holds every file pair
            patch_pair_ds = patch_pair_ds.shuffle(len_images)
        patch_pair_ds = patch_pair_ds.map(tf_get_image, num_parallel_calls= AUTOTUNE).prefetch(AUTOTUNE)

        # start making patches
        if scaled:
            patch_pair_ds = patch_pair_ds.map(make_patches_scaled, num_parallel_calls= AUTOTUNE)
        else:
            patch_pair_ds = patch_pair_ds.map(make_patches, num_parallel_calls= AUTOTUNE)
        # one element per patch pair; the cache sits before the augmentation
        # so that the random transformations are redrawn on every pass
        patch_pair_ds = patch_pair_ds.apply(tf.data.Dataset.unbatch).cache(filename=cache_file_name)
        if augment:
            patch_pair_ds = patch_pair_ds.map(make_augmentation, num_parallel_calls= AUTOTUNE).prefetch(AUTOTUNE)
        else:
            patch_pair_ds = patch_pair_ds.map(no_augmentation, num_parallel_calls= AUTOTUNE).prefetch(AUTOTUNE)
        # repeat is not necessary as keras fit method will use epochs parameter for repetition
        if repeat:
            patch_pair_ds = patch_pair_ds.repeat()
        if shuffle:
            # patch-level shuffle
            patch_pair_ds = patch_pair_ds.shuffle(shuffle_buffer)
        patch_pair_ds = patch_pair_ds.batch(batch_size)
        patch_pair_ds = patch_pair_ds.prefetch(AUTOTUNE)
        return patch_pair_ds


## Train, Validation Datasets:
---

In [None]:
# training pipeline: shuffled, augmented, repeated, single-scale patches
train_ds = get_dataset(
    img_dir=train_img_dir,
    mask_dir=train_mask_dir,
    batch_size=BATCH_SIZE,
    shuffle=True,
    repeat=True,
    augment=True,
    scaled=False,
    cachefile=None,
)
# validation pipeline: same patches, but neither shuffled nor augmented
val_ds = get_dataset(
    img_dir=val_img_dir,
    mask_dir=val_mask_dir,
    batch_size=BATCH_SIZE,
    shuffle=False,
    repeat=True,
    augment=False,
    scaled=False,
    cachefile=None,
)
print("train patch number = {}, valid patch number = {}".format(TRAIN_PATCH_NUMBER, VALID_PATCH_NUMBER))
print("train steps = {}, valid steps = {}".format(TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH))

## Performance of get_dataset wrapper function:
---

In [None]:
# rebuild the training pipeline so that the timing below starts from a fresh dataset
train_ds = get_dataset(img_dir=train_img_dir, mask_dir=train_mask_dir, batch_size=BATCH_SIZE, shuffle=True, repeat=True, augment=True, scaled=False, cachefile=None)

In [None]:
# throughput of the training pipeline over timepipeline's default of 1000 batches
timepipeline(train_ds, batch_size=BATCH_SIZE)

In [None]:
# rebuild the validation pipeline (no shuffling, no augmentation) before timing it
val_ds = get_dataset(img_dir=val_img_dir, mask_dir=val_mask_dir, batch_size=BATCH_SIZE, shuffle=False, repeat=True, augment=False, scaled=False, cachefile=None)

In [None]:
# throughput of the validation pipeline over timepipeline's default of 1000 batches
timepipeline(val_ds, batch_size=BATCH_SIZE)

In [None]:
# iterator used by the next cell to pull batches one at a time for visual inspection
train_iterator = iter(train_ds)

In [None]:
# getting the next batch of (aerial images, associated label maps)
patch_pairs = next(train_iterator)
# display images: the 4 x 10 grid has 40 cells, i.e. the first 20 image/label
# pairs of the batch, each image followed by its label map
my_show_grid(patch_pairs, nrow=4, ncol=10)
# plt.tight_layout()

In [None]:
# Experimenting with options
# aggregator = tf.data.experimental.StatsAggregator()

# options = tf.data.Options()
# options.experimental_stats.aggregator = aggregator
# options.experimental_stats.latency_all_edges = True
# pair_ds = pair_ds.with_options(options)

## The Keras Functional API:
---

In [None]:
def conv_block(input_tensor, num_filters, dilation=1, residual=False):
    """Apply two (3x3 conv + ReLU, batch norm) stages to input_tensor.

    Args:
        input_tensor: Keras tensor to process.
        num_filters: number of filters of both convolutions.
        dilation: dilation rate of both convolutions.
        residual: when True the block input is concatenated (not added) to
            the block output along the channel axis.
    """
    x = input_tensor
    for _ in range(2):
        x = layers.Conv2D(num_filters, (3,3), activation='relu', padding='same', dilation_rate=dilation, kernel_initializer='he_normal')(x)
        x = layers.BatchNormalization()(x)
    if not residual:
        return x
    return layers.Concatenate()([x, input_tensor])

def bottleneck_block(input_tensor, num_filters, mode='parallel'):
    """Build the bottleneck of the network from dilated convolutions.

    Args:
        input_tensor: Keras tensor to process.
        num_filters: number of filters of every convolution.
        mode: 'serial' chains four (conv + batch norm) stages with dilation
            rates 1, 2, 4, 8 and concatenates the four intermediate outputs;
            'parallel' applies four conv_block with those rates to the same
            input and concatenates them; any other value gives a single
            conv_block with the input concatenated to its output.
    """
    dilation_rates = (1, 2, 4, 8)
    if mode == 'serial':
        stages = []
        x = input_tensor
        for rate in dilation_rates:
            # each stage is fed with the output of the previous one
            x = layers.Conv2D(num_filters, (3,3), activation='relu', padding='same', dilation_rate=rate, kernel_initializer='he_normal')(x)
            x = layers.BatchNormalization()(x)
            stages.append(x)
        return layers.Concatenate()(stages)
    if mode == 'parallel':
        # every branch is fed with the block input
        branches = [conv_block(input_tensor, num_filters, dilation=rate) for rate in dilation_rates]
        return layers.Concatenate()(branches)
    return conv_block(input_tensor, num_filters, dilation=1, residual=True)

def encoder_block(input_tensor, num_filters, residual=True):
    """Run one encoder level: a conv_block followed by 2x2 max pooling (stride 2).

    Returns:
        A tuple (pooled, features): the pooled tensor passed to the next level and the
        un-pooled conv_block output kept for the skip connection.
    """
    features = conv_block(input_tensor, num_filters, residual=residual)
    pooled = layers.MaxPooling2D((2,2), strides=(2,2))(features)
    return pooled, features

def decoder_block(input_tensor, concat_tensor, num_filters, residual=True):
    """Run one decoder level: upsample, merge with the skip tensor, then convolve.

    Args:
        input_tensor: tensor coming from the level below.
        concat_tensor: encoder tensor of the same level (skip connection).
        num_filters: number of filters of the transposed and regular convolutions.
        residual: when true, the merged tensor is concatenated to the convolution output.

    Returns:
        The output tensor of the decoder level.
    """
    upsampled = layers.Conv2DTranspose(num_filters, (2,2), strides=(2,2), padding='same', kernel_initializer='he_normal')(input_tensor)
    merged = layers.concatenate([concat_tensor, upsampled], axis=-1)
    # conv_block with its default dilation of 1 is the same conv/batch-norm pair (plus the
    # optional concatenation with its input) that this level needs
    return conv_block(merged, num_filters, dilation=1, residual=residual)

def encoder_proc(x, filters=32, n_block=4, residual=True):
    """Chain n_block encoder levels, doubling the filter count at every level.

    Returns:
        A tuple (x, skip): the pooled output of the last level and the list of un-pooled
        tensors of every level, ordered from the first level to the last.
    """
    skip = []
    current = x
    for level in range(n_block):
        width = filters * 2**level
        current, before_pool = encoder_block(current, width, residual=residual)
        skip.append(before_pool)
    return current, skip

def decoder_proc(x, skip, filters=32, n_block=4, residual=True):
    """Chain n_block decoder levels, walking the skip tensors from the deepest level to the first.

    Level i uses skip[i] and filters * 2**i filters, mirroring the encoder.
    """
    current = x
    for level in range(n_block - 1, -1, -1):
        current = decoder_block(current, skip[level], filters * 2**level, residual=residual)
    return current

def get_unet_model(input_shape, filters=32, n_block=4, n_class=6, mode='parallel', residual=True):
    """Assemble the encoder / bottleneck / decoder segmentation model.

    Args:
        input_shape: shape passed to layers.Input (without the batch dimension).
        filters: number of filters of the first encoder level; level i uses filters * 2**i
            and the bottleneck uses filters * 2**n_block.
        n_block: number of encoder levels (and of mirrored decoder levels).
        n_class: number of output channels of the final 1x1 convolution.
        mode: bottleneck variant, forwarded to bottleneck_block.
        residual: forwarded to the encoder and decoder blocks.

    Returns:
        An uncompiled Keras Model mapping the input to an n_class-channel softmax output.
    """
    # The head always emits n_class channels, so softmax is used whatever n_class is. The former
    # sigmoid/softmax selection was overwritten before use and has been removed as dead code.
    # NOTE(review): each level halves then doubles the spatial size, so the input height/width
    # presumably need to be multiples of 2**n_block for the skip concatenations — confirm.
    final_activation = 'softmax'
    inputs = layers.Input(input_shape)
    enc, skip = encoder_proc(inputs, filters, n_block, residual=residual)
    center = bottleneck_block(enc, num_filters=filters * 2**n_block, mode=mode)
    dec = decoder_proc(center, skip, filters, n_block, residual=residual)
    classify = layers.Conv2D(n_class, (1, 1), activation=final_activation)(dec)
    return models.Model(inputs=inputs, outputs=classify)

In [None]:
# Learning Rate Finder
class LearningRateFinder:
    """Learning-rate range test for a compiled Keras model.

    The model is trained briefly while the learning rate is multiplied by a constant factor
    after every batch. A bias-corrected exponential moving average of the loss is recorded per
    batch, and training is interrupted as soon as that smoothed loss exceeds ``stopFactor``
    times the best smoothed loss seen so far. The model weights and learning rate are put back
    to their initial values when the search ends.
    """

    def __init__(self, model, stopFactor=10, beta=0.98):
        """Store the model, the stop factor and the smoothing coefficient of the loss average."""
        self.model = model
        self.stopFactor = stopFactor
        self.beta = beta
        # lrs, losses, lrMult, avgLoss, bestLoss, batchNum and weightsFile are created here
        self.reset()

    def reset(self):
        """Re-initialize the recorded curves and the running statistics."""
        # learning rates tried and the smoothed loss observed for each of them
        self.lrs = []
        self.losses = []
        # per-batch learning-rate multiplier, running loss average, best smoothed loss so far,
        # number of batches seen and path of the temporary weights file
        self.lrMult = 1
        self.avgLoss = 0
        self.bestLoss = 1e9
        self.batchNum = 0
        self.weightsFile = None

    def on_batch_end(self, batch, logs):
        """Record the current learning rate and smoothed loss, then raise the learning rate.

        Args:
            batch: batch index supplied by Keras (unused).
            logs: Keras batch logs; only ``logs["loss"]`` is read.
        """
        # log the learning rate that was used for this batch
        lr = K.get_value(self.model.optimizer.lr)
        self.lrs.append(lr)

        # exponential moving average of the loss, divided by (1 - beta**t) to remove the bias
        # caused by starting the average at 0
        batch_loss = logs["loss"]
        self.batchNum += 1
        self.avgLoss = (self.beta * self.avgLoss) + ((1 - self.beta) * batch_loss)
        smooth = self.avgLoss / (1 - (self.beta ** self.batchNum))
        self.losses.append(smooth)

        # stop as soon as the smoothed loss has grown past stopFactor * best loss
        stopLoss = self.stopFactor * self.bestLoss
        if self.batchNum > 1 and smooth > stopLoss:
            self.model.stop_training = True
            return

        # keep track of the best (lowest) smoothed loss
        if self.batchNum == 1 or smooth < self.bestLoss:
            self.bestLoss = smooth

        # increase the learning rate for the next batch
        lr *= self.lrMult
        K.set_value(self.model.optimizer.lr, lr)

    def find(self, trainData, startLR, endLR, epochs=1, stepsPerEpoch=None, verbose=1):
        """Run the range test between startLR and endLR.

        Args:
            trainData: training data passed to ``model.fit``.
            startLR: learning rate used for the first batch.
            endLR: learning rate that would be reached after epochs * stepsPerEpoch batches.
            epochs: number of epochs to train for.
            stepsPerEpoch: number of batches per epoch; required.
            verbose: verbosity passed to ``model.fit``.

        Raises:
            ValueError: if stepsPerEpoch is missing or if the number of batch updates is not
                positive (the multiplier cannot be derived in either case).
        """
        import os  # function-scope import: only needed to close the descriptor returned by mkstemp

        if stepsPerEpoch is None or stepsPerEpoch <= 0 or epochs <= 0:
            raise ValueError("LearningRateFinder.find needs positive epochs and stepsPerEpoch values "
                             "(got epochs={!r}, stepsPerEpoch={!r})".format(epochs, stepsPerEpoch))

        # reset our class-specific variables
        self.reset()

        # total number of batch updates, and the constant factor that takes the learning rate
        # from startLR to endLR in that many multiplications
        numBatchUpdates = epochs * stepsPerEpoch
        self.lrMult = (endLR / startLR) ** (1.0 / numBatchUpdates)

        # save the current weights to a temporary path so they can be restored afterwards;
        # mkstemp returns an OPEN descriptor which must be closed explicitly
        fd, self.weightsFile = tempfile.mkstemp()
        os.close(fd)
        self.model.save_weights(self.weightsFile)

        # remember the original learning rate, then start from startLR
        origLR = K.get_value(self.model.optimizer.lr)
        K.set_value(self.model.optimizer.lr, startLR)

        # the callback raises the learning rate at the end of every batch
        callback = LambdaCallback(on_batch_end=lambda batch, logs: self.on_batch_end(batch, logs))
        try:
            self.model.fit(trainData, steps_per_epoch=stepsPerEpoch, epochs=epochs, verbose=verbose, callbacks=[callback])
        finally:
            # restore the original weights and learning rate even if training was interrupted
            self.model.load_weights(self.weightsFile)
            K.set_value(self.model.optimizer.lr, origLR)

    @staticmethod
    def _trim(values, skipBegin, skipEnd):
        """Return values without its first skipBegin and last skipEnd items (skipEnd may be 0)."""
        end = max(len(values) - skipEnd, 0)
        return values[skipBegin:end]

    def plot_loss(self, skipBegin=2, skipEnd=1, title=""):
        """Plot the smoothed loss against the learning rate on a logarithmic x axis.

        Args:
            skipBegin: number of initial points left out of the plot.
            skipEnd: number of final points left out of the plot; 0 keeps the whole tail.
            title: optional plot title.
        """
        lrs = self._trim(self.lrs, skipBegin, skipEnd)
        losses = self._trim(self.losses, skipBegin, skipEnd)

        # plot the learning rate vs. loss
        fig, ax = plt.subplots(figsize=(15,6))
        ax.plot(lrs, losses)
        ax.set_xscale("log")
        ax.xaxis.set_major_locator(LogLocator(base=10., numticks=15, subs=range(10)))
        ax.set_xlabel("Learning Rate (Log Scale)")
        ax.set_ylabel("Loss")
        ax.grid(True)
        # if the title is not empty, add it to the plot
        if title != "":
            ax.set_title(title)

In [None]:
# One Cycle Policy training
# inspired by https://github.com/shivam-agarwal-17/keras-one-cycle-policy/blob/master/one_cycle_lr/one_cycle_scheduler.py
class ParamScheduler:
    """Base class producing one value per call to step(), interpolated between start and end.

    Subclasses implement func(start_val, end_val, pct), where pct is the fraction of the
    schedule already covered (0 on the first step, 1 once num_iter steps have followed).
    """

    def __init__(self, start, end, num_iter):
        self.start = start
        self.end = end
        self.num_iter = num_iter
        # index of the last step taken; -1 means "not started"
        self.idx = -1

    def func(self, start_val, end_val, pct):
        """Return the scheduled value at fraction pct; must be provided by subclasses."""
        raise NotImplementedError

    def step(self):
        """Advance the schedule by one iteration and return the value for that iteration."""
        self.idx += 1
        if self.num_iter <= 0:
            # zero-length schedule (e.g. a warm-up phase of 0 iterations): go straight to the
            # end value instead of dividing by zero
            pct = 1.0
        else:
            pct = self.idx / self.num_iter
        return self.func(self.start, self.end, pct)

    def reset(self):
        """Rewind the schedule to its not-started state."""
        self.idx = -1

    def is_complete(self):
        """Return True once the step index has reached num_iter."""
        return self.idx >= self.num_iter


class LinearScheduler(ParamScheduler):
    """Straight-line interpolation from start to end."""

    def func(self, start_val, end_val, pct):
        return start_val + pct * (end_val - start_val)


class CosineScheduler(ParamScheduler):
    """Half-cosine interpolation from start (pct = 0) to end (pct = 1)."""

    def func(self, start_val, end_val, pct):
        # cos(pi * pct) + 1 goes from 2 down to 0 as pct goes from 0 to 1
        cos_out = np.cos(np.pi * pct) + 1
        return end_val + (start_val - end_val)/2 * cos_out


class OneCycleScheduler(Callback):
    """Keras callback applying a two-phase learning-rate / momentum schedule ("one cycle").

    Phase 1 covers the first ``pct_start`` fraction of all iterations: the learning rate goes
    from ``max_lr / start_div`` to ``max_lr`` while the momentum goes from ``momentums[0]`` to
    ``momentums[1]``. Phase 2 covers the remaining iterations: the learning rate goes from
    ``max_lr`` to ``max_lr / end_div`` and the momentum goes back to ``momentums[0]``.
    The momentum is only written when the optimizer exposes a ``momentum`` attribute.

    Every value applied is appended to ``self.logs['lr']`` / ``self.logs['momentum']``; the
    logs are created once in ``__init__`` and therefore accumulate across successive fits.
    """

    def __init__(self, max_lr, momentums=(0.95,0.80), start_div=25., pct_start=0.3, verbose=True, sched=CosineScheduler, end_div=25e3):
        """
        Args:
            max_lr: peak learning rate reached at the end of phase 1.
            momentums: (value at both ends of the cycle, value at the learning-rate peak).
            start_div: max_lr is divided by this to get the initial learning rate.
            pct_start: fraction of the iterations spent in phase 1.
            verbose: print the current values at the end of every epoch.
            sched: ParamScheduler subclass used for all four schedules.
            end_div: max_lr is divided by this to get the final learning rate.
        """
        # let the Keras base class create its own attributes
        super().__init__()
        self.max_lr = max_lr
        self.momentums = momentums
        self.start_div = start_div
        self.pct_start = pct_start
        self.verbose = verbose
        self.sched = sched
        self.end_div = end_div
        self.logs = {}

    def on_train_begin(self, logs=None):
        """Build the schedules from the fit parameters and apply the first values."""
        self.num_epochs = self.params['epochs']
        self.steps_per_epoch = self.params['steps']
        if self.steps_per_epoch is None:
            raise ValueError("OneCycleScheduler needs a known number of steps per epoch; "
                             "pass steps_per_epoch to model.fit")
        start_lr   = self.max_lr / self.start_div
        end_lr     = self.max_lr / self.end_div
        num_iter   = self.num_epochs * self.steps_per_epoch
        num_iter_1 = int(self.pct_start*num_iter)
        num_iter_2 = num_iter - num_iter_1
        self.lr_scheds = (self.sched(start_lr, self.max_lr, num_iter_1), self.sched(self.max_lr, end_lr, num_iter_2))
        self.momentum_scheds = (self.sched(self.momentums[0], self.momentums[1], num_iter_1), self.sched(self.momentums[1], self.momentums[0], num_iter_2))
        # index of the active phase in lr_scheds / momentum_scheds
        self.sched_idx = 0
        self.optimizer_params_step()

    def optimizer_params_step(self):
        """Take one step of the active schedules, log the values and write them to the optimizer."""
        next_lr = self.lr_scheds[self.sched_idx].step()
        next_momentum = self.momentum_scheds[self.sched_idx].step()

        # add to logs
        self.logs.setdefault('lr', []).append(next_lr)
        self.logs.setdefault('momentum', []).append(next_momentum)

        # update optimizer params
        K.set_value(self.model.optimizer.lr, next_lr)
        if hasattr(self.model.optimizer, 'momentum'):
            K.set_value(self.model.optimizer.momentum, next_momentum)

    def on_batch_end(self, batch, logs=None):
        """Apply the next scheduled values; move to the next phase when the active one is done."""
        if self.sched_idx >= len(self.lr_scheds):
            # both phases are exhausted: nothing left to schedule
            self.model.stop_training=True
            return
        self.optimizer_params_step()
        if self.lr_scheds[self.sched_idx].is_complete():
            self.sched_idx += 1

    def on_epoch_end(self, epoch, logs=None):
        """Optionally print the last applied values; request a stop past the planned epoch count."""
        if self.verbose:
            if hasattr(self.model.optimizer, 'momentum'):
                print(" - OneCycleScheduler, lr: {:.7f}, momentum: {:.7f}".format(self.logs['lr'][-1], self.logs['momentum'][-1]))
            else:
                print(" - OneCycleScheduler, lr: {:.7f}".format(self.logs['lr'][-1]))

        if epoch >= self.num_epochs:
            self.model.stop_training=True
            return

    @staticmethod
    def _plot_panel(position, values, ylabel, log_scale=False):
        """Draw one curve against the iteration index in the subplot designated by position."""
        plt.subplot(position)
        plt.plot(values)
        if log_scale:
            plt.yscale("log")
        plt.ylabel(ylabel)
        plt.xlabel('iteration')
        plt.grid(True, linestyle="--")

    def plot_lr(self, show_momentums=True):
        """Plot the logged learning rates (linear and log scale) and, optionally, the momentums.

        The momentum panel is only drawn when show_momentums is true and the optimizer has a
        ``momentum`` attribute.
        """
        plt.figure(figsize=(20,5))
        if hasattr(self.model.optimizer, 'momentum') and show_momentums:
            self._plot_panel(131, self.logs['lr'], 'learning rate')
            self._plot_panel(132, self.logs['lr'], 'learning rate (log)', log_scale=True)
            self._plot_panel(133, self.logs['momentum'], 'momentum')
        else:
            self._plot_panel(121, self.logs['lr'], 'learning rate')
            self._plot_panel(122, self.logs['lr'], 'learning rate (Log scale)', log_scale=True)
        

In [None]:
## fixing MeanIoU
class CategoricalMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that takes per-class scores as predictions.

    The prediction tensor is reduced with an argmax over its last axis before being handed to
    the stock MeanIoU update, which works on class indices.
    """

    def __init__(self, name='mean_iou', **kwargs):
        super().__init__(name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        predicted_labels = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true=y_true, y_pred=predicted_labels, sample_weight=sample_weight)


In [None]:
def jaccard_distance(y_true, y_pred, smooth=1):
    """Jaccard distance for semantic segmentation.

    Also known as the intersection-over-union loss.

    This loss is useful when you have unbalanced numbers of pixels within an image
    because it gives all classes equal weight. However, it is not the de facto
    standard for image segmentation.

    For example, assume you are trying to predict if
    each pixel is cat, dog, or background.
    You have 80% background pixels, 10% dog, and 10% cat.
    If the model predicts 100% background
    should it be 80% right (as with categorical cross entropy)
    or 30% (with this loss)?

    Jaccard = (|X & Y|) / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The smoothing term is added to both the numerator and the denominator, which keeps the
    ratio defined when a class is absent from both tensors.

    # Arguments
        y_true: The ground truth tensor of class indices; it is one-hot encoded here with
            NUM_CLASSES classes.
        y_pred: The predicted tensor of per-class scores.
        smooth: Smoothing factor. Default is 1.

    # Returns
        The mean over classes of (1 - Jaccard), as a scalar tensor.

    # References
        - [What is a good evaluation measure for semantic segmentation?](
           http://www.bmva.org/bmvc/2013/Papers/paper0032/paper0032.pdf)

    """
    y_true = tf.one_hot(tf.cast(y_true, tf.int32), NUM_CLASSES)

    # The sums run over axes (0, 1, 2) and keep the last axis, i.e. one Jaccard score per class.
    # An earlier variant summed over the last axis only, which yields one score per pixel.
    # NOTE(review): assumes y_pred is (batch, height, width, classes) — confirm.
    reduce_axes = (0, 1, 2)
    intersection = K.sum(K.abs(y_true * y_pred), axis=reduce_axes)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=reduce_axes)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return tf.reduce_mean(1 - jac)

In [None]:
def binary_focal_loss(gamma=2., alpha=.25):
    """
    Build the binary focal loss.
      FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
      where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    The positive term is weighted by alpha, the negative term by (1 - alpha), and both are
    summed over all elements.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
     model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred:  A tensor resulting from a sigmoid
        :return: Output tensor.
        """
        eps = K.epsilon()
        is_positive = tf.equal(y_true, 1)
        is_negative = tf.equal(y_true, 0)

        # predictions on positive labels (1 elsewhere) and on negative labels (0 elsewhere),
        # clipped to prevent NaN's and Inf's in the logarithms
        pt_1 = K.clip(tf.where(is_positive, y_pred, tf.ones_like(y_pred)), eps, 1. - eps)
        pt_0 = K.clip(tf.where(is_negative, y_pred, tf.zeros_like(y_pred)), eps, 1. - eps)

        positive_term = K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))
        negative_term = K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))
        return -positive_term - negative_term

    return binary_focal_loss_fixed

def categorical_focal_loss(gamma=2., alpha=.25):
    """
    Softmax version of focal loss.
           m
      FL = ∑  -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
          c=1
      where m = number of classes, c = class and o = observation
    Parameters:
      alpha -- the same as weighing factor in balanced cross entropy
      gamma -- focusing parameter for modulating factor (1-p)
    Default value:
      gamma -- 2.0 as mentioned in the paper
      alpha -- 0.25 as mentioned in the paper
    References:
        Official paper: https://arxiv.org/pdf/1708.02002.pdf
        https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy
    Usage:
     model.compile(loss=[categorical_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    def categorical_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of class indices; it is one-hot encoded here with NUM_CLASSES classes
        :param y_pred: A tensor resulting from a softmax (per-class scores on the last axis)
        :return: Output tensor holding one focal-loss value per observation.
        """
        # No re-normalisation of y_pred: the last layer is a softmax, so the class scores of
        # each observation already sum to 1.
        y_true = tf.one_hot(tf.cast(y_true, tf.int32), NUM_CLASSES)

        # Clip the prediction value to prevent NaN's and Inf's
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate Cross Entropy
        cross_entropy = -y_true * K.log(y_pred)

        # Calculate Focal Loss
        loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

        # Sum over the class axis, as in the formula above. The class axis is the last one;
        # axis=1 is only the class axis for 2-D (batch, classes) inputs and would sum over
        # image rows for segmentation tensors.
        return K.sum(loss, axis=-1)

    return categorical_focal_loss_fixed

In [None]:
model = get_unet_model(MODEL_SHAPE, filters=32, n_block=4, n_class=6, mode='serial', residual=True)
# snapshot of the freshly initialised weights, used to bring a model back to its untrained state
reset_weights = model.get_weights()


def resetModel(m):
    """Overwrite the weights of m with the snapshot taken right after the model was built."""
    return m.set_weights(reset_weights)
# usage: resetModel(model)

In [None]:
# print the layer-by-layer summary of the model built in the previous cell
model.summary()
# old way to declare a UNET model
# inputs = layers.Input(shape=PATCH_SHAPE)
# # 512

# encoder0_pool, encoder0 = encoder_block(inputs, 32)
# # 256

# encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
# # 128

# encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
# # 64

# encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
# # 32

# # encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)
# # 16

# center = bottleneck_block(encoder3_pool, 512, mode='parallel')
# # center

# # decoder4 = decoder_block(center, encoder4, 512)
# # 32

# decoder3 = decoder_block(center, encoder3, 256)
# # 64

# decoder2 = decoder_block(decoder3, encoder2, 128)
# # 128

# decoder1 = decoder_block(decoder2, encoder1, 64)
# # 256

# decoder0 = decoder_block(decoder1, encoder0, 32)
# # 512

# outputs = layers.Conv2D(NUM_CLASSES, (1,1), activation='softmax')(decoder0)
# model = models.Model(inputs=[inputs], outputs=[outputs])

In [None]:
# name used for the checkpoint and TensorBoard directories of this run
experiment_name = "unet_input256_noScale_skip4_Residual_serialCenter_sgd_cross_entropy_bs50_ocpCB"
# `learning_rate` is the documented argument name; `lr` is only kept as a deprecated alias.
# NOTE(review): momentum is left at its default, so nesterov=True presumably has no effect and
# the momentum schedule of OneCycleScheduler may be ignored by this optimizer — confirm.
sgd_optim = tf.keras.optimizers.SGD(learning_rate=0.1, nesterov=True)
adam_optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='Accuracy')
# stock MeanIoU, which receives the raw model output as predictions; kept for reference,
# the compile call below uses the argmax-based variant instead
iou_metric = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
fixed_iou  = CategoricalMeanIoU(num_classes=NUM_CLASSES)
model.compile(optimizer=sgd_optim, loss='sparse_categorical_crossentropy', metrics=[acc_metric,fixed_iou])
# alternative losses defined earlier in this notebook:
# model.compile(optimizer=sgd_optim, loss=jaccard_distance, metrics=[acc_metric,fixed_iou])
# model.compile(optimizer=sgd_optim, loss=categorical_focal_loss(gamma=2., alpha=.25), metrics=[acc_metric,fixed_iou])

In [None]:
# take the timestamp once so that the checkpoint and TensorBoard directories always share the
# same suffix (two separate datetime.now() calls can straddle a minute boundary)
run_stamp = datetime.now().strftime("-%Y%m%d-%Hh%Mmn")
checkpoint_path = "../checkpoint_dir/" + experiment_name + run_stamp + "/best-cp-" + experiment_name
tb_log_dir      = "../log_dir/" + experiment_name + run_stamp
# callbacks
# keep only the weights of the epoch with the highest validation mean IoU
checkpoint_cb  = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, monitor='val_mean_iou', mode='max', save_weights_only=True, save_best_only=True, verbose=1)
# stop after 10 epochs without improvement of the validation mean IoU
earlystop_cb   = tf.keras.callbacks.EarlyStopping(monitor='val_mean_iou', mode='max', patience=10, verbose=1)
# halve the learning rate after 5 epochs without improvement of the validation accuracy
reduce_cb      = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_Accuracy', min_delta=0.0001, factor=0.5, patience=5, mode='max', verbose=1)
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir)

## Classic Training Strategy:
---

In [None]:
# Baseline run: 100 epochs with checkpointing, learning-rate reduction on plateau and
# TensorBoard logging. earlystop_cb is not part of the active callback list; the commented
# lines are alternative callback lists.
keras_history = model.fit(train_ds,
                          steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
                          epochs=100,
                          validation_data=val_ds,
                          validation_steps=VAL_STEPS_PER_EPOCH,
#                           callbacks=[checkpoint_cb, reduce_cb, earlystop_cb, tensorboard_cb],
                          callbacks=[checkpoint_cb, reduce_cb, tensorboard_cb],
#                           callbacks=[checkpoint_cb, reduce_cb],
                         verbose=1)

In [None]:
# hand the History returned by model.fit to the plotting helper (defined outside this section)
plot_keras_history(keras_history)

In [None]:
# same call as the previous cell — presumably kept to hold the plot of another run
plot_keras_history(keras_history)

In [None]:
# same call as the two previous cells — presumably kept to hold the plot of another run
plot_keras_history(keras_history)

## Saving and loading the model:
---

In [None]:
# if one wants to save the model weights manually (the Keras checkpoint callback is preferred);
# the path looks like a placeholder to be replaced before running
model.save_weights('../saved_models/you_directory_name/model_name')

In [None]:
# if one wants to load the weights from a specific directory (the model must be built beforehand)
model.load_weights('../saved_models/you_directory_name/model_name')

In [None]:
# if one wants to save the full model (architecture + weights): NOT WORKING because of the custom Mean IoU metric
model.save('../saved_models/saved_full_models/you_directory_name/model_name')

In [None]:
# rebuild the same architecture from scratch and load previously saved weights into it
model_new = get_unet_model(MODEL_SHAPE, filters=32, n_block=4, n_class=6, mode='serial', residual=True)
# NOTE(review): sgd_optim, acc_metric and fixed_iou are the very objects already compiled into
# `model`; sharing optimizer/metric instances between two models may mix their state — confirm.
model_new.compile(optimizer=sgd_optim, loss='sparse_categorical_crossentropy', metrics=[acc_metric,fixed_iou])
model_new.load_weights('../saved_models/isprs_saved_weights/best-cp-unet_input256_noScale_skip4_Residual_serialCenter_sgd_jaccardloss_bs50_ocpCB')

In [None]:
# check that the rebuilt model reaches the same metric values once the saved weights are loaded
model_new.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

## Metrics Computation:
---

First, load the best model reached during training

In [None]:
# reference point: performance of the weights currently in memory, before loading a checkpoint
model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

In [None]:
# NOTE(review): this checkpoint name refers to a "jaccardloss" run while experiment_name above
# says "cross_entropy" — confirm this is the checkpoint that is meant to be evaluated.
model.load_weights('../checkpoint_dir_part2/unet_input256_noScale_skip4_Residual_serialCenter_sgd_jaccardloss_bs50_ocpCB-20191129-10h43mn/best-cp-unet_input256_noScale_skip4_Residual_serialCenter_sgd_jaccardloss_bs50_ocpCB')
# alternative location of the same weights:
# model.load_weights('../saved_models/isprs_saved_weights/best-cp-unet_input256_noScale_skip4_Residual_serialCenter_sgd_jaccardloss_bs50_ocpCB')

In [None]:
# performance after loading the best weights (compare with the evaluation made before loading)
model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

In [None]:
# we can compute the IoU for a given set of patch pairs
# (compute_batch_iou and patch_pairs are defined outside this section)
compute_batch_iou(model, patch_pairs)

In [None]:
# or over the whole dataset: outputs the confusion matrix and the per-class IoU
# (compute_dataset_iou is defined outside this section)
compute_dataset_iou(model, val_ds, VAL_STEPS_PER_EPOCH)

In [None]:
# we can also output only the normalized confusion matrix (per-class accuracy on the diagonal);
# compute_dataset_cm is defined outside this section
cm = compute_dataset_cm(model, val_ds, VAL_STEPS_PER_EPOCH, normalize=True)

In [None]:
# display the confusion matrix computed in the previous cell (plot_cm is defined outside this section)
plot_cm(cm)

In [None]:
# output the confusion matrix with pretty printing
# NOTE: this replaces the pandas float format set at the top of the notebook for every
# DataFrame displayed afterwards
pd.options.display.float_format = '{:3.2e}'.format
data = cm.numpy().astype(float)
# rows and columns are labelled with color_names (defined outside this section)
df = pd.DataFrame(data, columns=color_names, index= color_names)
# extra column holding the sum of each row
df['Total'] = df.sum(axis=1)
df

## Learning Rate Finder:
---

In [None]:
# learning-rate range test over one epoch, sweeping from 1e-6 to 10
lrf = LearningRateFinder(model)
lrf.find(train_ds, 1e-6, 1e+1, epochs=1, stepsPerEpoch=TRAIN_STEPS_PER_EPOCH)
lrf.plot_loss()

In [None]:
# same sweep spread over two epochs, i.e. a smaller learning-rate increase per batch
lrf.find(train_ds, 1e-6, 1e+1, epochs=2, stepsPerEpoch=TRAIN_STEPS_PER_EPOCH)
lrf.plot_loss()

In [None]:
# sweep from 1e-6 to 1 spread over three epochs
lrf.find(train_ds, 1e-6, 1, epochs=3, stepsPerEpoch=TRAIN_STEPS_PER_EPOCH)
lrf.plot_loss()

In [None]:
# put back the weights captured right after the model was built
resetModel(model)

## One Cycle Policy Training:
---

In [None]:
# One-cycle run: the scheduler sets the learning rate after every batch, so no
# ReduceLROnPlateau callback is used here.
ocp_cb = OneCycleScheduler(max_lr=5e-2)
# take the timestamp once so that the checkpoint and TensorBoard directories always share the
# same suffix (two separate datetime.now() calls can straddle a minute boundary)
run_stamp = datetime.now().strftime("-%Y%m%d-%Hh%Mmn")
checkpoint_path = "../checkpoint_dir/" + experiment_name + run_stamp + "/best-cp-" + experiment_name
tb_log_dir      = "../log_dir/" + experiment_name + run_stamp

# keep only the weights of the epoch with the highest validation mean IoU
checkpoint_cb  = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, monitor='val_mean_iou', mode='max', save_weights_only=True, save_best_only=True, verbose=1)
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir)

keras_history = model.fit(train_ds,
                   steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
                   epochs=100,
                   validation_data=val_ds,
                   validation_steps=VAL_STEPS_PER_EPOCH,
                   callbacks=[ocp_cb, checkpoint_cb, tensorboard_cb],
                   verbose=1)

In [None]:
# plot the learning rates (and momentums, when the optimizer has one) logged by the scheduler
ocp_cb.plot_lr()

In [None]:
# hand the History of the one-cycle run to the plotting helper (defined outside this section)
plot_keras_history(keras_history)

In [None]:
# evaluate the weights currently in memory (those left by the preceding fit when cells run in order)
model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

In [None]:
# keep a copy of the current weights in memory before a checkpoint overwrites them below
current_weight = model.get_weights()

In [None]:
# load the best checkpoint of an earlier run (its directory name carries the date 20191111)
# and evaluate it on the validation data
model.load_weights('../checkpoint_dir/skip4_resunet_300_jl_sgd_ocp-20191111-20h47mn/best-cp-skip4_resunet_300_jl_sgd_ocp')
model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

In [None]:
# IoU over the validation data for the weights loaded above (helper defined outside this section)
compute_dataset_iou(model, val_ds, VAL_STEPS_PER_EPOCH)

In [None]:
# same call as the previous cell — presumably kept to hold the output of another run
compute_dataset_iou(model, val_ds, VAL_STEPS_PER_EPOCH)

In [None]:
# export the weights currently held by the model
model.save_weights('../saved_models/saved_weight_models/base_256input_sgd_ocp')

In [None]:
# go back to the initial weights and evaluate them
resetModel(model)
model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

## Trying custom metric for each class:
---

Instead of computing the mean Intersection over Union during training, we can also compute a per-class IoU (as demonstrated by the code below). However, this solution is not suitable, for two reasons:

- first, we have 6 classes, so the per-class values will clutter the log during Keras training
- second, for the validation set, a proper implementation should accumulate the IoU over the whole dataset, whereas this implementation is called after each batch

In [None]:
from keras import backend as K
@tf.function
def iou(y_true, y_pred, label: int):
    """
    Return the Intersection over Union (IoU) for a given label.
    Args:
        y_true: the expected class indices (compared to the label directly, no argmax)
        y_pred: the predicted per-class scores; reduced with an argmax over the last axis
        label: the label to return the IoU for
    Returns:
        the IoU for the given label, or 1 when the label appears in neither tensor
    """
    # binary masks of the positions equal to the label, in the truth and in the prediction
    truth_mask = K.cast(K.equal(y_true, label), K.floatx())
    predicted_labels = K.argmax(y_pred, axis=-1)
    pred_mask = K.cast(K.equal(predicted_labels, label), K.floatx())
    # |intersection| (AND) and |union| (OR) of the two masks
    overlap = K.sum(truth_mask * pred_mask)
    union = K.sum(truth_mask) + K.sum(pred_mask) - overlap
    # an empty union would divide by zero: return 1 in that case
    return K.switch(K.equal(union, 0), 1.0, overlap / union)


def build_iou_for(label: int, name: str=None):
    """
    Build an Intersection over Union (IoU) metric for a label.
    Args:
        label: the label to build the IoU metric for
        name: an optional name for debugging the built method
    Returns:
        a keras metric to evaluate IoU for the given label

    Note:
        label and name support list inputs for multiple labels; a list of functions is then
        returned. Names are only paired with labels when name is a list too, otherwise each
        function is named after its label.
    """
    # handle recursive inputs (e.g. a list of labels and names)
    if isinstance(label, list):
        if isinstance(name, list):
            return [build_iou_for(one_label, one_name) for (one_label, one_name) in zip(label, name)]
        return [build_iou_for(one_label) for one_label in label]

    # build the method for returning the IoU of the given label
    def label_iou(y_true, y_pred):
        return iou(y_true, y_pred, label)

    # if no name is provided, use the label
    if name is None:
        name = label
    # change the name of the method for debugging
    label_iou.__name__ = 'iou_{}'.format(name)
    # A formatted string placed at the top of a function body is a plain expression, not a
    # docstring, so the documentation is attached explicitly (and formatted only once).
    label_iou.__doc__ = (
        "Return the Intersection over Union (IoU) score for label {0}.\n"
        "\n"
        "Args:\n"
        "    y_true: the expected y values\n"
        "    y_pred: the predicted y values as a one-hot or softmax output\n"
        "Returns:\n"
        "    the scalar IoU value for the given label ({0})\n"
    ).format(label)

    return label_iou


In [None]:
# one IoU metric function per class index 0..5
# NOTE(review): the names are only applied when color_names (defined outside this section) is a
# list; for any other type build_iou_for names the functions after the labels — confirm.
class_iou = build_iou_for(label=[0,1,2,3,4,5], name=color_names)

In [None]:
# compile with the per-class IoU functions appended to the usual metrics
# NOTE(review): model_weight is not defined anywhere in this part of the notebook — confirm where it comes from.
model_weight.compile(optimizer=adam_optim, loss='sparse_categorical_crossentropy', metrics=[acc_metric,fixed_iou, *class_iou])

In [None]:
# evaluate with the metrics compiled in the previous cell and keep the returned values in H
H = model_weight.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

## Future test

In [None]:
"""An implementation of the Intersection over Union (IoU) metric for Keras."""
from keras import backend as K


def iou(y_true, y_pred, label: int):
    """
    Return the Intersection over Union (IoU) for a given label.
    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output
        label: the label to return the IoU for
    Returns:
        the IoU for the given label, or 1 when the label appears in neither tensor
    """
    # reduce both tensors to class indices with an argmax, then build the binary masks of the
    # positions equal to the label
    truth_mask = K.cast(K.equal(K.argmax(y_true), label), K.floatx())
    pred_mask = K.cast(K.equal(K.argmax(y_pred), label), K.floatx())
    # |intersection| (AND) and |union| (OR) of the two masks
    overlap = K.sum(truth_mask * pred_mask)
    union = K.sum(truth_mask) + K.sum(pred_mask) - overlap
    # an empty union would divide by zero: return 1 in that case
    return K.switch(K.equal(union, 0), 1.0, overlap / union)


def build_iou_for(label: int, name: str=None):
    """
    Build an Intersection over Union (IoU) metric for a label.
    Args:
        label: the label to build the IoU metric for
        name: an optional name for debugging the built method
    Returns:
        a keras metric to evaluate IoU for the given label

    Note:
        label and name support list inputs for multiple labels; a list of functions is then
        returned. Names are only paired with labels when name is a list too, otherwise each
        function is named after its label.
    """
    # handle recursive inputs (e.g. a list of labels and names)
    if isinstance(label, list):
        if isinstance(name, list):
            return [build_iou_for(one_label, one_name) for (one_label, one_name) in zip(label, name)]
        return [build_iou_for(one_label) for one_label in label]

    # build the method for returning the IoU of the given label
    def label_iou(y_true, y_pred):
        return iou(y_true, y_pred, label)

    # if no name is provided, use the label
    if name is None:
        name = label
    # change the name of the method for debugging
    label_iou.__name__ = 'iou_{}'.format(name)
    # A formatted string placed at the top of a function body is a plain expression, not a
    # docstring, so the documentation is attached explicitly (and formatted only once).
    label_iou.__doc__ = (
        "Return the Intersection over Union (IoU) score for label {0}.\n"
        "\n"
        "Args:\n"
        "    y_true: the expected y values as a one-hot\n"
        "    y_pred: the predicted y values as a one-hot or softmax output\n"
        "Returns:\n"
        "    the scalar IoU value for the given label ({0})\n"
    ).format(label)

    return label_iou
        

def mean_iou(y_true, y_pred):
    """
    Return the Intersection over Union (IoU) score.
    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output
    Returns:
        the scalar IoU value (mean over all labels)
    """
    # get number of labels to calculate IoU for (size of the last axis)
    num_labels = K.int_shape(y_pred)[-1]
    # accumulate in a plain float: K.variable(0) would allocate a new backend
    # variable every time the metric is called, which is wasteful and is
    # restricted inside traced (tf.function) code
    total_iou = 0.0
    # iterate over labels to calculate IoU for
    for label in range(num_labels):
        total_iou = total_iou + iou(y_true, y_pred, label)
    # divide total IoU by number of labels to get mean IoU
    return total_iou / num_labels


# public names exported by this metrics code
__all__ = ['build_iou_for', 'mean_iou']

## (POC) Image Reconstruction from patches w/o looping:
---
In this section we study how to predict over an entire tile: we extract patches from the tile, run the model on them, and then reconstruct the full prediction mask with the `tf.scatter_nd` method.

In [None]:
# read the validation tile at index 2 and its mask from disk
img = imageio.imread(val_images[2])
# convert to float32 (convert_image_dtype rescales integer inputs to [0, 1])
img = tf.image.convert_image_dtype(img,tf.float32)
msk = imageio.imread(val_masks[2])
msk = tf.image.convert_image_dtype(msk,tf.float32)

In [None]:
# my_show_pair is defined elsewhere in the notebook; presumably it displays the tile next to its mask
my_show_pair((img, msk))

In [None]:
# all-ones tensor shaped like img; it is patched together with img as a batch of two
one_tensor = tf.ones_like(img)
# patches = tf.image.extract_patches(images=tf.expand_dims(img, axis=0), sizes=SIZES, strides=STRIDES, rates=RATES, padding=PADDING)
# NOTE(review): the dilation rate is hard-coded to 3 here, whereas the next cell uses RATES - confirm this is intentional
patches = tf.image.extract_patches(images=[img,one_tensor], sizes=SIZES, strides=STRIDES, rates=[1, 3, 3, 1], padding=PADDING)
# axes 1 and 2 of the result are the rows and columns of the patch grid
nrow, ncol = patches.shape[1:3]
# unflatten every patch: (2, nrow*ncol, PATCH_SIZE, PATCH_SIZE, 3)
patches = tf.reshape(patches, [2, nrow*ncol,PATCH_SIZE,PATCH_SIZE,3])

In [None]:
# first patch of the image (batch entry 0) next to the first patch of the all-ones tensor (batch entry 1)
my_show_pair((patches[0,0,:], patches[1,0,:]))

In [None]:
# the index patching trick to reconstruct image from the extracted patches
# the all-ones tensor goes through the same patching as img, so that after scattering
# it counts how many patches cover each pixel
one_tensor = tf.ones_like(img)
patches = tf.image.extract_patches(images=[img,one_tensor], sizes=SIZES, strides=STRIDES, rates=RATES, padding=PADDING)
# patches[k] has shape (rows, cols, PATCH_SIZE*PATCH_SIZE*3); reshaping with -1 lets
# TensorFlow infer rows*cols instead of hard-coding a 17x17 grid, and avoids an
# axis-less squeeze that would also drop a grid dimension of size 1
z = tf.reshape(patches[0], [-1, PATCH_SIZE, PATCH_SIZE, 3]) # e.g. TensorShape([289, 512, 512, 3])
one_patches = tf.reshape(patches[1], [-1, PATCH_SIZE, PATCH_SIZE, 3]) # same shape as z


In [None]:
# inspect the shape of the image entry of the patch batch
patches[0].shape

In [None]:
# build a (row, col) coordinate for every pixel of the tile; sizes come from img
# so the cell is not tied to 6000x6000 tiles
x = tf.range(img.shape[1])
y = tf.range(img.shape[0])
x, y = tf.meshgrid(x, y)
indices = tf.stack([y, x], axis=-1)
# patch the coordinate grid exactly like the image, so each patch pixel knows where it came from
indices_patches = tf.image.extract_patches(images=tf.expand_dims(indices, axis=0), sizes=SIZES, strides=STRIDES, rates=RATES, padding=PADDING)
# -1 infers the number of patches instead of hard-coding 17*17
indices_patches = tf.reshape(indices_patches, [-1, PATCH_SIZE, PATCH_SIZE, 2]) # e.g. TensorShape([289, 512, 512, 2])
# scatter_nd sums values written to the same index, so overlapping patches accumulate
reconstructed = tf.scatter_nd(indices=indices_patches, updates=z, shape=tf.shape(img))
# the same scatter applied to the all-ones patches gives the per-pixel coverage count
overlap =  tf.scatter_nd(indices=indices_patches, updates=one_patches, shape=tf.shape(img))
# dividing the sum by the coverage count averages the overlaps
# (pixels covered by no patch would give 0/0 = NaN)
final = tf.math.truediv(reconstructed, overlap)

In [None]:
# sanity check: True only if every reconstructed value equals the original image exactly
tf.reduce_all(final == img)

In [None]:
# run the model on all patches of the tile, BATCH_SIZE patches at a time
preds = model.predict(z, batch_size=BATCH_SIZE)

In [None]:
# scatter the per-patch predictions back to tile coordinates; duplicate indices are summed,
# so overlapping patches add up their class scores (6 output channels are hard-coded here)
reconstructed_mask = tf.scatter_nd(indices=indices_patches, updates=preds, shape=(6000,6000,6))
# pick the channel with the highest accumulated score for every pixel
reconstructed_mask = tf.argmax(reconstructed_mask, axis=-1, output_type=tf.int32)
# label2mask is defined elsewhere; presumably it maps class indices to mask colors - confirm
reconstructed_mask = label2mask(reconstructed_mask)

In [None]:
# display the reconstructed mask on a large 20x20 figure
fig, ax = plt.subplots(figsize=(20,20))
plt.imshow(reconstructed_mask)

## Predicting on new Tiles:
---

In [None]:
# alternative inputs kept for reference:
# img_path = '/datasets/obliquetest/cal004image0000152.jpg'
# img_path = 'RI-39-GEOTIFF-195-20150316174501569000c47.tif'
# select the validation tile (and matching mask) to predict on
idx = 2
img_path = val_images[idx]
msk_path = val_masks[idx]
# read both files and convert them to float32 tensors
img = imageio.imread(img_path)
img = tf.image.convert_image_dtype(img,tf.float32)
msk = imageio.imread(msk_path)
msk = tf.image.convert_image_dtype(msk,tf.float32)

In [None]:
# plt.subplots(figsize=(20,20))
# plt.imshow(img)
# show the tile and its mask together; the interpolation keyword is handed to my_show_pair (defined elsewhere)
my_show_pair((img,msk),interpolation='gaussian')

In [None]:
# get_tile_prediction is defined elsewhere; presumably it returns the predicted mask for the tile at img_path
y_pred = get_tile_prediction(img_path, model)

In [None]:
# aerial tile next to the prediction
my_show_pair((img,y_pred),interpolation='gaussian')

In [None]:
# reference mask next to the prediction
my_show_pair((msk,y_pred),interpolation='gaussian')

In [None]:
# get_tile_iou is defined elsewhere; presumably it scores the prediction against the reference mask
get_tile_iou(msk, y_pred)

In [None]:
# confusion matrix between reference mask and prediction (normalize=True is passed through to get_tile_cm)
cm = get_tile_cm(msk, y_pred, normalize=True)
# plot_cm is defined elsewhere; presumably it draws the matrix
plot_cm(cm)

In [None]:
# Test Time Augmentation, used to cover the right and bottom edges of the tile
# (the red bands in the plain prediction); get_tile_tta_pred is defined elsewhere
y_pred2 = get_tile_tta_pred(img_path, model)

In [None]:
# reference mask next to the test-time-augmented prediction
my_show_pair((msk,y_pred2), interpolation='gaussian')

In [None]:
# same IoU helper as before, now applied to the test-time-augmented prediction
get_tile_iou(msk, y_pred2)

In [None]:
# confusion matrix for the test-time-augmented prediction
cm2 = get_tile_cm(msk, y_pred2, normalize=True)
plot_cm(cm2)

In [None]:
# we can even overlay the prediction mask on the aerial image (though the result is cluttered)
plt.subplots(figsize=(20,20))
plt.imshow(img)
# the second imshow draws the prediction on the same axes at 25% opacity
plt.imshow(y_pred2, alpha=0.25)

In [None]:
# repeat the load / TTA prediction / display / IoU steps for validation tile 3
idx = 3
img_path = val_images[idx]
msk_path = val_masks[idx]
img = imageio.imread(img_path)
img = tf.image.convert_image_dtype(img,tf.float32)
msk = imageio.imread(msk_path)
msk = tf.image.convert_image_dtype(msk,tf.float32)
y_pred2 = get_tile_tta_pred(img_path, model)
my_show_pair((msk,y_pred2), interpolation='gaussian')
# last expression of the cell, so its value is what the notebook displays
get_tile_iou(msk, y_pred2)

In [None]:
# repeat the load / TTA prediction / display / IoU steps for validation tile 4
idx = 4
img_path = val_images[idx]
msk_path = val_masks[idx]
img = imageio.imread(img_path)
img = tf.image.convert_image_dtype(img,tf.float32)
msk = imageio.imread(msk_path)
msk = tf.image.convert_image_dtype(msk,tf.float32)
y_pred2 = get_tile_tta_pred(img_path, model)
my_show_pair((msk,y_pred2), interpolation='gaussian')
# last expression of the cell, so its value is what the notebook displays
get_tile_iou(msk, y_pred2)

## INRIA Aerial Dataset
---

In [None]:
# directories holding the INRIA tiff tiles and their ground-truth masks
inria_train_img_dir  = '/datasets/InriaAerial/AerialImageDataset/train/images/'
inria_train_mask_dir = '/datasets/InriaAerial/AerialImageDataset/train/gt/'
inria_val_img_dir    = '/datasets/InriaAerial/AerialImageDataset/valid/images/'
inria_val_mask_dir   = '/datasets/InriaAerial/AerialImageDataset/valid/gt/'
# Path objects for the four directories (these rebind the names used by the ISPRS cells)
train_src, train_ref = Path(inria_train_img_dir), Path(inria_train_mask_dir)
val_src, val_ref     = Path(inria_val_img_dir), Path(inria_val_mask_dir)
# sorted lists of the '.tif' files found directly inside each directory
train_images = sorted(str(entry) for entry in train_src.iterdir() if entry.is_file() and entry.suffix == '.tif')
train_masks  = sorted(str(entry) for entry in train_ref.iterdir() if entry.is_file() and entry.suffix == '.tif')
val_images   = sorted(str(entry) for entry in val_src.iterdir() if entry.is_file() and entry.suffix == '.tif')
val_masks    = sorted(str(entry) for entry in val_ref.iterdir() if entry.is_file() and entry.suffix == '.tif')

# every aerial image needs exactly one mask: compare the list lengths
assert len(train_images) == len(train_masks), " (Error) The train Aerial image count does not match Mask counts!"
assert len(val_images)   == len(val_masks), " (Error) The Validation Aerial image count does not match Mask counts!"
len_train_images, len_val_images = len(train_images), len(val_images)
# displayed as the cell output
len_train_images, len_val_images

In [None]:
# ordered lists of colors and their corresponding class names, in [R, G, B] format
color_list  = [ [0, 0, 0], [1, 1, 1]]
color_names = ['Background', 'Building']
# patch extraction parameters
TILE_SIZE   = 5000 
PATCH_SIZE  = 256
PATCH_STRIDE= 256
PATCH_RATE  = 1
# the three lists below follow the [batch, rows, cols, channels] layout expected by tf.image.extract_patches
SIZES       = [1, PATCH_SIZE, PATCH_SIZE, 1] 
STRIDES     = [1, PATCH_STRIDE, PATCH_STRIDE, 1] 
RATES       = [1, PATCH_RATE, PATCH_RATE, 1] 
# 'VALID' keeps only patches that lie fully inside the tile
PADDING     = 'VALID'

def compute_tile_patch_number(rates=(1, 2, 3)):
    """
    Count the patches extracted from a single tile, summed over dilation rates.

    The first training tile (train_images[0]) is read and patched with the
    module-level SIZES, STRIDES and PADDING settings; all tiles are expected
    to share that tile's size.
    Args:
        rates: iterable of dilation rates; each rate r is applied as [1, r, r, 1]
    Returns:
        the total number of patches (grid rows * grid columns, summed over rates)
    """
    with tf.device('/CPU:0'):
        aerial = tf.constant(imageio.imread(train_images[0]))
        # extract_patches needs a batch axis; it does not depend on the rate, so add it once
        batch = tf.expand_dims(aerial, axis=0)
        total  = 0
        for r in rates:
            patches = tf.image.extract_patches(images=batch, sizes=SIZES, strides=STRIDES, rates=[1,r,r,1], padding=PADDING)
            # axes 1 and 2 of the result are the rows and columns of the patch grid
            total += patches.shape[1] * patches.shape[2]
        return total

# patches per tile at rate 1; both values come from the same call, which reads train_images[0]
TRAIN_PATCH_NUMBER = compute_tile_patch_number(rates=[1])
VALID_PATCH_NUMBER = compute_tile_patch_number(rates=[1])
# displayed as the cell output
TRAIN_PATCH_NUMBER, VALID_PATCH_NUMBER

In [None]:
# model input shape and training hyper-parameters for the INRIA experiments
MODEL_SHAPE  = (256, 256, 3)
PATCH_RESIZE = (128, 128)
NUM_CLASSES  = 2
BATCH_SIZE   = 50
EPOCHS       = 50
# total number of patches = patches per tile * number of tiles
NUM_TRAIN_EXAMPLES = TRAIN_PATCH_NUMBER * len_train_images 
NUM_VAL_EXAMPLES   = VALID_PATCH_NUMBER * len_val_images 
LEARNING_RATE   = 0.0001
# number of batches needed to see every patch once (rounded up)
TRAIN_STEPS_PER_EPOCH = int(np.ceil(NUM_TRAIN_EXAMPLES / float(BATCH_SIZE)))
VAL_STEPS_PER_EPOCH   = int(np.ceil(NUM_VAL_EXAMPLES / float(BATCH_SIZE)))
# displayed as the cell output
TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH

In [None]:
# get_dataset is defined elsewhere in the notebook; per the keyword arguments, the training set is
# requested shuffled, repeated and augmented with cache file 'inria_train_cache', while the
# validation set is requested repeated, without shuffling, augmentation or cache file
train_ds = get_dataset(img_dir=inria_train_img_dir, mask_dir=inria_train_mask_dir, batch_size=BATCH_SIZE, shuffle=True, repeat=True, augment=True, cachefile='inria_train_cache')
val_ds   = get_dataset(img_dir=inria_val_img_dir, mask_dir=inria_val_mask_dir,batch_size=BATCH_SIZE, shuffle=False, repeat=True, augment=False, cachefile=None)
# report the per-tile patch counts and the per-epoch step counts
print("train patch number = {}, valid patch number = {}".format(TRAIN_PATCH_NUMBER, VALID_PATCH_NUMBER))
print("train steps = {}, valid steps = {}".format(TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH))

In [None]:
# timepipeline is defined elsewhere; presumably it measures input-pipeline throughput over 4000 steps
timepipeline(train_ds, batch_size=BATCH_SIZE, steps=4000)

In [None]:
# same timing helper on the validation dataset, over 1000 steps
timepipeline(val_ds, batch_size=BATCH_SIZE, steps=1000)

In [None]:
# Python iterator over the training dataset, consumed by the next cell
train_iterator = iter(train_ds)

In [None]:
# fetch the next element of the training dataset (presumably one batch of patches and masks)
patch_pairs = next(train_iterator)
# display images on a 4 x 10 grid (my_show_grid is defined elsewhere)
my_show_grid(patch_pairs, nrow=4, ncol=10)

In [None]:
# build the INRIA U-Net (get_unet_model is defined elsewhere in the notebook)
inria_model = get_unet_model(MODEL_SHAPE, filters=32, n_block=4, n_class=2, mode=None, residual=True)
# snapshot of the weights right after construction
reset_weights = inria_model.get_weights()


def resetModel(m):
    """Load the construction-time weight snapshot into model m."""
    return m.set_weights(reset_weights)
# usage: resetModel(model)

In [None]:
# print the layer-by-layer summary of the model
inria_model.summary()

In [None]:
# experiment_name is used as the checkpoint / TensorBoard directory name in the training cells
experiment_name = "inria_unet_input256_noScale_skip4_Residual_residualCenter_sgd_cross_entropy_bs50_ocpCB"
# `learning_rate` is the supported keyword; `lr` is only a deprecated alias
# NOTE(review): nesterov=True with the default momentum of 0 - confirm whether a momentum value was intended
sgd_optim = tf.keras.optimizers.SGD(learning_rate=0.1, nesterov=True)
# built for experimentation; not passed to compile() below
adam_optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='Accuracy')
# built for comparison; not passed to compile() below
iou_metric = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
# CategoricalMeanIoU is defined elsewhere in the notebook
fixed_iou  = CategoricalMeanIoU(num_classes=NUM_CLASSES)
inria_model.compile(optimizer=sgd_optim, loss='sparse_categorical_crossentropy', metrics=[acc_metric,fixed_iou])
# alternative losses tried during experiments:
# inria_model.compile(optimizer=sgd_optim, loss=jaccard_distance, metrics=[acc_metric,fixed_iou])
# inria_model.compile(optimizer=sgd_optim, loss=categorical_focal_loss(gamma=2., alpha=.25), metrics=[acc_metric,fixed_iou])

In [None]:
# LearningRateFinder is defined elsewhere; 1e-7 and 1 are presumably the start and end
# learning rates of the sweep - confirm against its definition
lrf = LearningRateFinder(inria_model)
lrf.find(train_ds, 1e-7, 1, epochs=1, stepsPerEpoch=TRAIN_STEPS_PER_EPOCH)
lrf.plot_loss()

In [None]:
# second run of the finder with bounds 1e-6 and 1e+1, over two epochs
lrf = LearningRateFinder(inria_model)
lrf.find(train_ds, 1e-6, 1e+1, epochs=2, stepsPerEpoch=TRAIN_STEPS_PER_EPOCH)
lrf.plot_loss()

In [None]:
# classical training
# format the timestamp once so the checkpoint and TensorBoard directories always get the
# same suffix (two separate datetime.now() calls can straddle a minute boundary)
run_stamp       = datetime.now().strftime("-%Y%m%d-%Hh%Mmn")
checkpoint_path = "../inria_checkpoint_dir/" + experiment_name + run_stamp + "/best-cp-" + experiment_name
tb_log_dir      = "../inria_log_dir/" + experiment_name + run_stamp
# callbacks
# NOTE(review): monitoring 'val_mean_iou' assumes the IoU metric given to compile() is
# reported under the name 'mean_iou' - confirm against the metric definition
checkpoint_cb  = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, monitor='val_mean_iou', mode='max', save_weights_only=True, save_best_only=True, verbose=1)
earlystop_cb   = tf.keras.callbacks.EarlyStopping(monitor='val_mean_iou', mode='max', patience=10, verbose=1)
# halve the learning rate when the validation accuracy stops improving for 5 epochs
reduce_cb      = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_Accuracy', min_delta=0.0001, factor=0.5, patience=5, mode='max', verbose=1)
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir)

# earlystop_cb is built above but not passed to fit(); add it to the list to enable early stopping
keras_history = inria_model.fit(train_ds,
                          steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
                          epochs=100,
                          validation_data=val_ds,
                          validation_steps=VAL_STEPS_PER_EPOCH,
                          callbacks=[checkpoint_cb, reduce_cb, tensorboard_cb],
                          verbose=1)

In [None]:
# One cycle policy training
# format the timestamp once so the checkpoint and TensorBoard directories always get the
# same suffix (two separate datetime.now() calls can straddle a minute boundary)
run_stamp       = datetime.now().strftime("-%Y%m%d-%Hh%Mmn")
checkpoint_path = "../inria_checkpoint_dir/" + experiment_name + run_stamp + "/best-cp-" + experiment_name
tb_log_dir      = "../inria_log_dir/" + experiment_name + run_stamp
# callbacks
# OneCycleScheduler is defined elsewhere in the notebook; it is configured with a peak learning rate of 1e-1
ocp_cb = OneCycleScheduler(max_lr=1e-1)
# NOTE(review): monitoring 'val_mean_iou' assumes the IoU metric given to compile() is
# reported under the name 'mean_iou' - confirm against the metric definition
checkpoint_cb  = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, monitor='val_mean_iou', mode='max', save_weights_only=True, save_best_only=True, verbose=1)
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir)

keras_history = inria_model.fit(train_ds,
                   steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
                   epochs=70,
                   validation_data=val_ds,
                   validation_steps=VAL_STEPS_PER_EPOCH,
                   callbacks=[ocp_cb, checkpoint_cb, tensorboard_cb],
                   verbose=1)

## Model Evaluation:
---

In [None]:
# evaluate the model with its current weights on VAL_STEPS_PER_EPOCH validation batches
inria_model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

In [None]:
# load the weights saved under the 'best-cp-' checkpoint of the run stamped 20191204-21h36mn
inria_model.load_weights('../inria_checkpoint_dir/inria_unet_input256_noScale_skip4_Residual_residualCenter_sgd_cross_entropy_bs50_ocpCB-20191204-21h36mn/best-cp-inria_unet_input256_noScale_skip4_Residual_residualCenter_sgd_cross_entropy_bs50_ocpCB')

In [None]:
# evaluate again, now with the weights loaded from the checkpoint
inria_model.evaluate(val_ds, steps=VAL_STEPS_PER_EPOCH)

In [None]:
# compute_dataset_iou is defined elsewhere; presumably it computes the IoU over VAL_STEPS_PER_EPOCH validation batches
compute_dataset_iou(inria_model, val_ds, VAL_STEPS_PER_EPOCH)

## Predicting on new Tiles:
---

In [None]:
# alternative inputs kept for reference:
# img_path = '/datasets/obliquetest/cal004image0000152.jpg'
# img_path = 'RI-39-GEOTIFF-195-20150316174501569000c47.tif'
# select the validation tile (and matching mask) to predict on
index = 7
img_path = val_images[index]
msk_path = val_masks[index]
# the image is kept as read from disk (no float conversion in this section)
img = imageio.imread(img_path)
# img = tf.image.convert_image_dtype(img,tf.float32)
msk = imageio.imread(msk_path)
# msk = tf.image.convert_image_dtype(msk,tf.float32)
# integer division turns 255 into class index 1 and 0 into 0
# (assumes the mask only holds the values 0 and 255 - TODO confirm)
msk = msk // 255
msk = tf.cast(msk, tf.int32)
# look up an RGB color for every class index: 0 -> black, 1 -> white
msk = tf.gather([[0, 0, 0], [255, 255, 255]], msk)
# go through uint8 so that convert_image_dtype rescales the colors to [0, 1]
msk = tf.cast(msk, tf.uint8)
msk = tf.image.convert_image_dtype(msk,tf.float32)

In [None]:
# display the aerial tile on a large 20x20 figure
plt.subplots(figsize=(20,20))
plt.imshow(img)

In [None]:
# plt.subplots(figsize=(20,20))
# plt.imshow(img)
# tile and colorized mask together (my_show_pair is defined elsewhere)
my_show_pair((img,msk))

In [None]:
# prediction for the selected tile with the INRIA model (get_tile_prediction is defined elsewhere)
y_pred = get_tile_prediction(img_path, inria_model)

In [None]:
# reference mask next to the prediction
my_show_pair((msk,y_pred))

In [None]:
# get_tile_iou is defined elsewhere; presumably it scores the prediction against the reference mask
get_tile_iou(msk, y_pred)

In [None]:
# confusion matrix between reference mask and prediction, then its plot
cm = get_tile_cm(msk, y_pred, normalize=True)
plot_cm(cm)

In [None]:
# test-time-augmented prediction for the same tile (get_tile_tta_pred is defined elsewhere)
y_pred2 = get_tile_tta_pred(img_path, inria_model)

In [None]:
# reference mask next to the test-time-augmented prediction
my_show_pair((msk,y_pred2))

In [None]:
# same IoU helper, applied to the test-time-augmented prediction
get_tile_iou(msk, y_pred2)

In [None]:
# confusion matrix for the test-time-augmented prediction, then its plot
cm2 = get_tile_cm(msk, y_pred2, normalize=True)
plot_cm(cm2)

## Marseille Tile:
---


In [None]:
# read the Marseille tile from disk and show its array shape
tile = imageio.imread('/datasets/msebai_projects/marseille_15cm.tif')
tile.shape

In [None]:
# display the whole tile with gaussian interpolation
plt.subplots(figsize=(20,20))
plt.imshow(tile, interpolation='gaussian')

In [None]:
# crop a 5000 x 5000 window (rows 20000-24999, columns 15000-19999) and display it
img = tile[20000:25000, 15000:20000,:]
plt.subplots(figsize=(20,20))
plt.imshow(img)

In [None]:
# from_disk=False is passed because img is an in-memory array rather than a file path;
# note this call uses `model`, not `inria_model`
y_pred_tile = get_tile_tta_pred(img, model, from_disk=False)
plt.subplots(figsize=(20,20))
plt.imshow(y_pred_tile)

In [None]:
# cropped tile next to its prediction
my_show_pair((img,y_pred_tile),interpolation='gaussian')

## Toulouse Tile:
---


In [None]:
# NOTE(review): this section is titled "Toulouse Tile" but the file sits under a MARSEILLE directory - confirm the path
tile = imageio.imread('~/datasets_link/MARSEILLE/cal000image0000185.jpg')
tile.shape

In [None]:
# display the whole tile with gaussian interpolation
plt.subplots(figsize=(20,20))
plt.imshow(tile, interpolation='gaussian')

In [None]:
# optional crop, currently disabled:
# img = tile[10000:16000, 10000:16000,:]
# use the full tile as the prediction input
img = tile
plt.subplots(figsize=(20,20))
plt.imshow(img,interpolation='gaussian')

In [None]:
# test-time-augmented prediction on the in-memory tile using `model`, then display it
y_pred_tile = get_tile_tta_pred(img, model, from_disk=False)
plt.subplots(figsize=(20,20))
plt.imshow(y_pred_tile)

In [None]:
# tile next to its prediction
my_show_pair((img,y_pred_tile),interpolation='gaussian')

In [None]:
# same tile, this time predicted with inria_model; y_pred_tile is overwritten
y_pred_tile = get_tile_tta_pred(img, inria_model, from_disk=False)
my_show_pair((img,y_pred_tile),interpolation='gaussian')