In [None]:
#default_exp segmentation_model_training

In [None]:
#export
from IPython.display import Image, display
import glob
import os
import pandas as pd
import json
import numpy as np
import tensorflow as tf
import tensorflow
import matplotlib.pyplot as plt
from toolz import compose
from tensorflow.keras import losses, metrics, layers, models
from deeplearning_image_pixelwise import data, config
import attr


# Turn on memory growth for every GPU that TensorFlow lists, one device at a time.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)
    

# Half-precision setup, applied only when config.float_dtype is 'float16':
# the Keras backend float type and epsilon are changed, and the global
# 'mixed_float16' policy is installed.
if config.float_dtype == 'float16':
    tf.keras.backend.set_floatx('float16')
    tf.keras.backend.set_epsilon(1e-4)
    # NOTE(review): this goes through the `experimental` mixed-precision namespace —
    # confirm it is still available in the TensorFlow version in use.
    policy = tensorflow.keras.mixed_precision.experimental.Policy('mixed_float16')
    tensorflow.keras.mixed_precision.experimental.set_policy(policy)

In [None]:
#export
# Paths, class count and batch size are taken from the shared project config.
DATA_DIR = config.DATA_DIR
TRAIN_DIR = config.TRAIN_DIR
TRAIN_MASK_DIR = config.TRAIN_MASK_DIR
VAL_DIR = config.VAL_DIR
VAL_MASK_DIR = config.VAL_MASK_DIR
N_CLASSES = config.N_CLASSES
BATCH_SIZE = config.BATCH_SIZE
# The height used to be read from config.IMG_WIDTH as well. Read the dedicated
# setting when the config provides one and fall back to the width otherwise,
# so configs that only define IMG_WIDTH keep producing square inputs.
IMG_HEIGHT = getattr(config, 'IMG_HEIGHT', config.IMG_WIDTH)
IMG_WIDTH = config.IMG_WIDTH
# Number of epochs passed to `fit` in the training cell.
EPOCHS = 50

# Shuffle-buffer size used when the training pipeline is assembled.
buffer_size = 128

In [None]:
#export
# Model / optimisation hyperparameters.
# BASE_N_FILTERS, DROPOUT_RATE, ACTIVATION and INITIALIZER are the default
# arguments of build_segmentation_model; LEARNING_RATE is handed to Adam.
BASE_N_FILTERS = 8
DROPOUT_RATE = 0.5
ACTIVATION = 'relu'
INITIALIZER = 'glorot_normal'
LEARNING_RATE = 1e-4
# NOTE(review): this rebinds BATCH_SIZE, which an earlier cell set from
# config.BATCH_SIZE — the config value is not used from here on. Confirm this is intended.
BATCH_SIZE = 8
# When True, the datasets are truncated with `take(...)` before training.
USE_DEV_SUBSET=False

In [None]:
# NOTE(review): `%cd ..` moves one directory up every time this cell runs, so
# re-running it without restarting the kernel changes the working directory again.
%matplotlib inline
%cd ..

In [None]:
#export
# Each dataset is built from a directory pair; judging by the validation call
# the first argument is presumably the image directory and the second the mask
# directory. The training set previously received TRAIN_MASK_DIR for both.
val_dataset = data.load_dataset(VAL_DIR, VAL_MASK_DIR)
train_dataset = data.load_dataset(TRAIN_DIR, TRAIN_MASK_DIR)

In [None]:
#export

def unet_forward_block(input_, n_filters, dropout_rate, activation, initializer):
    """Apply one encoder stage of the U-Net to `input_`.

    Layer sequence: Conv2D -> BatchNormalization -> Dropout -> Conv2D ->
    BatchNormalization, then a 2x2 max pooling. Both convolutions are 3x3,
    'same'-padded, with `n_filters` filters and the given activation and
    kernel initializer.

    Returns:
        A pair `(features, downsampled)`: the output of the second batch
        normalization and that same output after max pooling.
    """
    def same_padded_conv():
        # A fresh layer on every call: the two convolutions must not share weights.
        return layers.Conv2D(
            n_filters,
            (3, 3),
            activation=activation,
            kernel_initializer=initializer,
            padding='same',
        )

    features = same_padded_conv()(input_)
    features = layers.BatchNormalization()(features)
    features = layers.Dropout(dropout_rate)(features)
    features = same_padded_conv()(features)
    features = layers.BatchNormalization()(features)
    downsampled = layers.MaxPooling2D((2, 2))(features)
    return features, downsampled
    
    
def unet_skip_connect_block(current, skip_connected, n_filters, dropout_rate, activation, initializer):
    """Apply one decoder stage of the U-Net.

    `current` is upsampled by a stride-2 transposed convolution with
    `n_filters` filters and concatenated with `skip_connected`. The result goes
    through Conv2D -> BatchNormalization -> Dropout -> Conv2D ->
    BatchNormalization, where both convolutions are 3x3, 'same'-padded and use
    `2 * n_filters` filters.

    Returns:
        The output of the final batch normalization.
    """
    def same_padded_conv():
        # A fresh layer on every call: the two convolutions must not share weights.
        return layers.Conv2D(
            2 * n_filters,
            (3, 3),
            activation=activation,
            kernel_initializer=initializer,
            padding='same',
        )

    upsampled = layers.Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(current)
    merged = layers.concatenate([upsampled, skip_connected])
    merged = same_padded_conv()(merged)
    merged = layers.BatchNormalization()(merged)
    merged = layers.Dropout(dropout_rate)(merged)
    merged = same_padded_conv()(merged)
    merged = layers.BatchNormalization()(merged)
    return merged

# Evaluation metric - IoU

Mean Intersection over Union (mean IoU) is commonly used for evaluating segmentation models: it computes the IoU score of each class and averages the scores over classes (like the 'macro' averaging scheme in scikit-learn).
As a result, the metric weighs every class equally and is not dominated by classes with many pixels, which is what happens with accuracy.

`MeanIoU` from `tf.keras.metrics` can't handle logits (it operates on class labels), so this custom metric had to be written.

In [None]:
#export

def tf_casted_sum(tensor, dtype=tf.uint8):
    """Cast `tensor` to `dtype` and return the sum of all its elements.

    NOTE(review): the sum is taken in `dtype`; with the default uint8 the
    result presumably cannot exceed 255 without wrapping — confirm, and pass a
    wider dtype when counting many elements.
    """
    converted = tf.cast(tensor, dtype)
    return tf.math.reduce_sum(converted)


def iou(masks, masks_logits_pred, category):
    """Intersection over Union of a single class, computed over the whole input.

    Args:
        masks: ground-truth labels. The tensor is indexed as `[:, :, :, 0]`,
            so it must be 4-D with the class index stored in channel 0.
        masks_logits_pred: per-class scores; the predicted class of each
            position is the argmax over the last axis.
        category: index of the class to score.

    Returns:
        A scalar float32 tensor: intersection / union of the positions labelled
        `category` and the positions predicted as `category`, or 1.0 when the
        class occurs in neither.
    """
    masks_pred = tf.cast(tf.math.argmax(masks_logits_pred, axis=-1), tf.int32)
    # Drop the channel axis once so labels and predictions have the same rank.
    positive = (masks == category)[:, :, :, 0]
    positive_pred = masks_pred == category
    # Count in float32: the helper's default uint8 accumulator is far too
    # narrow for pixel counts, and float32 keeps the ratio below in float32,
    # matching the constant returned by the other branch of tf.cond.
    intersection = tf_casted_sum(tf.math.logical_and(positive_pred, positive), tf.float32)
    union = (
        tf_casted_sum(positive, tf.float32)
        + tf_casted_sum(positive_pred, tf.float32)
        - intersection
    )
    return tf.cond(union > 0, lambda: intersection / union, lambda: tf.ones((), dtype=tf.float32))


def mean_iou(y_true, y_pred, ignored_indices=(N_CLASSES - 1,)):
    """
    Return the mean Intersection over Union (IoU) over the evaluated labels.

    Args:
        y_true: ground-truth class indices (not one-hot); forwarded to `iou`,
            which reads them from channel 0 of a 4-D tensor
        y_pred: per-class scores (e.g. logits); the size of the last axis is
            the number of labels and `iou` takes the argmax over it
        ignored_indices: labels left out of the average
    Returns:
        the scalar IoU value (mean over all labels that are not ignored)
    Raises:
        ValueError: if every label is ignored, so there is nothing to average
    """
    # get number of labels to calculate IoU for
    num_labels = y_pred.shape[-1]
    # Labels that actually take part in the average. Counting them directly
    # keeps the denominator right even when `ignored_indices` holds duplicates
    # or indices outside range(num_labels).
    evaluated_labels = [label for label in range(num_labels) if label not in ignored_indices]
    if not evaluated_labels:
        raise ValueError('mean_iou: every label is ignored, nothing to average')
    per_label_iou = [iou(y_true, y_pred, label) for label in evaluated_labels]
    # divide total IoU by the number of evaluated labels to get mean IoU
    return tf.math.add_n(per_label_iou) / len(evaluated_labels)

TensorFlow had some odd issues with dtypes, so a custom accuracy metric was needed when using float16.

In [None]:
#export

def accuracy(y_true, y_pred_logits):
    """Fraction of positions whose predicted class equals the label.

    Args:
        y_true: ground-truth class indices; like in `iou`, expected as a 4-D
            tensor with the class index in channel 0.
        y_pred_logits: per-class scores; the predicted class is the argmax
            over the last axis.

    Returns:
        A scalar tensor of dtype `config.float_dtype`.
    """
    # Select channel 0 of the label tensor. The previous `[:, :, 0]` indexed
    # the third axis instead and left a size-1 axis behind, so the comparison
    # below broadcast one label column across every prediction in its row.
    y_true = tf.cast(y_true, tf.int32)[:, :, :, 0]
    # int32 (as in `iou`) so class indices above 255 are not truncated.
    y_pred = tf.cast(tf.math.argmax(y_pred_logits, axis=-1), tf.int32)
    return tf.math.reduce_mean(tf.cast(y_pred == y_true, config.float_dtype))

# Callbacks

In [None]:
#export

def get_default_callbacks():
    """Create the callbacks passed to `fit`.

    Returns:
        A list with, in order: a ModelCheckpoint that embeds the epoch number
        and `mean_iou` in the file name, a TensorBoard callback with
        `update_freq=100`, and a TensorBoard callback with
        `update_freq='epoch'`. Both TensorBoard callbacks log to '.logs'.
    """
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath='weights.{epoch:02d}-{mean_iou:.3f}.hdf5',
        monitor='mean_iou',
    )
    # profile_batch is set to a huge value because otherwise TensorBoard
    # complains about setting the trace.
    shared_tensorboard_options = dict(
        log_dir='.logs',
        histogram_freq=0,
        write_images=False,
        profile_batch=100000000,
    )
    tensorboard_every_100 = tf.keras.callbacks.TensorBoard(
        update_freq=100, **shared_tensorboard_options
    )
    tensorboard_every_epoch = tf.keras.callbacks.TensorBoard(
        update_freq='epoch', **shared_tensorboard_options
    )
    return [checkpoint, tensorboard_every_100, tensorboard_every_epoch]

In [None]:
#export

def build_segmentation_model(
        input_shape,
        n_classes,
        base_n_filters=BASE_N_FILTERS,
        dropout_rate=DROPOUT_RATE,
        activation=ACTIVATION,
        initializer=INITIALIZER
    ):
    """Assemble the (uncompiled) U-Net segmentation model.

    The encoder is five `unet_forward_block` stages with 2, 4, 8, 8 and 16
    times `base_n_filters` filters. The decoder is four
    `unet_skip_connect_block` stages called with 8, 4, 2 and 1 times
    `base_n_filters`, each fed the matching encoder output as skip connection.
    A final 1x1 convolution produces `n_classes` channels, cast to float32.

    Returns:
        A Keras `Model` mapping an input of shape `input_shape` to the
        float32 output of the 1x1 convolution (no activation is applied).
    """
    block_options = (dropout_rate, activation, initializer)
    inputs = layers.Input(input_shape)

    # Encoder: remember every stage's pre-pooling output for the skip connections.
    stage_outputs = []
    downsampled = inputs
    for width in (2, 4, 8, 8, 16):
        stage_out, downsampled = unet_forward_block(
            downsampled, width * base_n_filters, *block_options
        )
        stage_outputs.append(stage_out)

    # The deepest stage's pre-pooling output starts the decoder; its pooled
    # output is not used.
    decoded = stage_outputs.pop()

    # Decoder: consume the remaining encoder outputs from deepest to shallowest.
    for width in (8, 4, 2, 1):
        decoded = unet_skip_connect_block(
            decoded, stage_outputs.pop(), width * base_n_filters, *block_options
        )

    logits = layers.Conv2D(n_classes, (1, 1))(decoded)
    # for some reason SparseCategoricalCrossEntropy fails if the output is fp16
    logits = tf.cast(logits, tf.float32)
    return models.Model(inputs=[inputs], outputs=[logits])


def setup_segmentation_model(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        n_classes=N_CLASSES,
        loss=None,
        optimizer=None,
        metrics=None
    ):
    """Build the segmentation model and compile it.

    Args:
        input_shape: shape handed to `build_segmentation_model`.
        n_classes: number of output channels of the model.
        loss: loss passed to `compile`; when None, a new
            `SparseCategoricalCrossentropy(from_logits=True)` is created.
        optimizer: optimizer passed to `compile`; when None, a new
            `Adam(LEARNING_RATE)` is created.
        metrics: metrics passed to `compile`; when None,
            `[accuracy, mean_iou]` is used.

    Returns:
        The compiled Keras model.
    """
    # The defaults are created per call. As default-argument values they were
    # instantiated once at definition time, so every model set up without an
    # explicit optimizer ended up sharing the same optimizer object.
    if loss is None:
        loss = losses.SparseCategoricalCrossentropy(from_logits=True)
    if optimizer is None:
        optimizer = tensorflow.keras.optimizers.Adam(LEARNING_RATE)
    if metrics is None:
        metrics = [accuracy, mean_iou]
    segmentation_model = build_segmentation_model(input_shape, n_classes)
    segmentation_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return segmentation_model

In [None]:
#export

# Build and compile the model with every default argument of setup_segmentation_model.
segmentation_model = setup_segmentation_model()

In [None]:
# Print the layer-by-layer summary of the compiled model.
segmentation_model.summary()

In [None]:
# Draw the layer graph of the model.
tf.keras.utils.plot_model(segmentation_model)

In [None]:
#export

# Optional development mode: keep only the first 2**14 training elements and
# the first 1024 validation elements (presumably to shorten experiments).
if USE_DEV_SUBSET:
    train_dataset = train_dataset.take(2 ** 14)
    val_dataset = val_dataset.take(1024)

# Model training


In [None]:
#export

# Train only when the exported module is executed as a script.
if __name__ == '__main__':
    segmentation_model.fit(
        # Shuffle individual examples BEFORE batching so batches are re-drawn
        # as the data repeats; shuffling after batching only permuted whole
        # batches whose contents never changed. The buffer is scaled by
        # BATCH_SIZE so it still holds as many examples as the former buffer
        # of `buffer_size` batches did.
        train_dataset.shuffle(buffer_size * BATCH_SIZE).batch(BATCH_SIZE).repeat(),
        validation_data=val_dataset.batch(BATCH_SIZE),
        epochs=EPOCHS,
        callbacks=get_default_callbacks(),
        # The training dataset repeats indefinitely, so the length of an epoch
        # is fixed here as a number of batches.
        steps_per_epoch=1000
    )