In [1]:
"""script to evaluate model on test images"""

'script to evaluate model on test images'

In [2]:
# !ipython nbconvert --to=python enet_eval.ipynb

In [3]:
import glob
import os
from tqdm import tqdm
import numpy as np
import functools
import random
import cv2
import tensorflow as tf
# import tensorflow_addons as tfa
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras import losses
import tensorflow.keras.backend as K
from tensorflow.python.keras.callbacks import LambdaCallback
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, Input, Conv2DTranspose, \
    Add, Activation, BatchNormalization, Concatenate
from tensorflow.python.keras.models import Model

import segmentation_models as sm

sm.set_framework('tf.keras')

from cb.tbi_cb import TensorBoardImage
from cb.snapshot_cb_builder import SnapshotCallbackBuilder
from cb.sgdr_lr_scheduler import SGDRScheduler


Segmentation Models: using `tf.keras` framework.


In [4]:
import matplotlib.image as mpimg
# This is needed to display the images.
%matplotlib inline
import matplotlib.pyplot as plt

def visualize(**images):
    """Show the given images side by side in a single row.

    Every keyword argument becomes one panel. The keyword name, with
    underscores replaced by spaces and title-cased, is used as the panel title.
    """
    total = len(images)
    plt.figure(figsize=(64, 20))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, total, position)
        # Hide the axis ticks; only the picture itself is of interest.
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()

In [5]:
# Directory with the test images; main() globs "*.png" files from here.
IMG_DIR = "/home/ubuntu/valdata/road_boundary_train/road_satellite/dataset/testing/images"
# Directory with the label masks; main() looks each mask up by the image's file name.
LABEL_DIR = "/home/ubuntu/valdata/road_boundary_train/road_satellite/dataset/testing/masks"
# Root folder for saved results; infer_and_evaluate() writes into its images/, pred/ and labels/ sub-folders.
OUTPUT_PATH = "/home/ubuntu/valdata/road_boundary_train/road_satellite/enet/test_results"

# Directory that holds the saved model weights file loaded by the model cell.
MODEL_DIR = "/home/ubuntu/valdata/road_boundary_train/road_satellite/enet"

OUTPUT_SHAPE = (512, 512, 1)  # passed as input_shape to the pooling layers in segmentation_boundary_loss
INPUT_SHAPE = (512, 512, 3)

img_shape = (512, 512, 3)  # first two entries are used for the model's input_shape
n_classes = 1  # selects the binary focal loss when the model cell builds total_loss
BACKBONE = 'efficientnetb4'  # encoder name passed to sm.Unet

In [6]:
# LOSS functions
# Weight of the boundary term in combined_loss. Kept as a backend variable so
# update_alpha_value can change it with K.set_value between epochs.
alpha = K.variable(value=0.1)
# NOTE(review): presumably meant to keep the variable out of optimizer updates — confirm.
alpha._trainable = False

def update_alpha_value(epoch):
    """Adjust the global ``alpha`` loss weight for the given epoch.

    Epoch 0 resets alpha to 0.1. After epoch 10, alpha is increased by 0.2
    per call while the result stays below 0.5; otherwise it is reset to 0.1.
    The value in effect is printed whenever it is (re)set.

    :param epoch: index of the epoch that is about to start
    """
    if epoch == 0:
        K.set_value(alpha, 0.1)
        print(f"Setting alpha to = {K.get_value(alpha)}")
    if not epoch > 10:
        return
    candidate = K.get_value(alpha) + 0.2
    # Wrap back to the starting weight once the step would reach 0.5.
    K.set_value(alpha, candidate if candidate < 0.5 else 0.1)
    print(f"Setting alpha to = {K.get_value(alpha)}")


# Keras callback that calls update_alpha_value at the beginning of every epoch.
alpha_update_clb = LambdaCallback(on_epoch_begin=lambda epoch, log: update_alpha_value(epoch))


def segmentation_boundary_loss(y_true, y_pred):
    """
    Boundary loss computed on the fly from binary segmentation masks.

    A boundary map is derived from each mask by max-pooling the inverted mask
    and subtracting the inverted mask again. This is done with a 3x3 window
    (the boundary) and a 5x5 window (an extended boundary used as tolerance).
    Precision and recall between the boundaries are combined into an F1 score
    and the loss is ``1 - F1``.

    The small epsilon is added to the denominators, so an image without any
    boundary pixels yields a finite value instead of 0/0 = NaN.

    :param y_true: ground-truth mask tensor (assumed to be in the layout the
        2D pooling layer accepts, e.g. batch x H x W x 1 — TODO confirm)
    :param y_pred: predicted mask tensor of the same shape
    :return: scalar loss tensor
    """
    eps = 1e-7

    def _boundary(mask, window):
        # Max-pooling the inverted mask grows the background by `window`;
        # removing the original background leaves a band along the mask edge.
        inverted = 1 - mask
        pooled = layers.MaxPooling2D((window, window), strides=(1, 1), padding='same',
                                     input_shape=OUTPUT_SHAPE)(inverted)
        return pooled - inverted

    y_pred_bd = _boundary(y_pred, 3)
    y_true_bd = _boundary(y_true, 3)
    y_pred_bd_ext = _boundary(y_pred, 5)
    y_true_bd_ext = _boundary(y_true, 5)

    # Predicted boundary pixels that fall inside the extended true boundary.
    P = K.sum(y_pred_bd * y_true_bd_ext) / (K.sum(y_pred_bd) + eps)
    # True boundary pixels that fall inside the extended predicted boundary.
    R = K.sum(y_true_bd * y_pred_bd_ext) / (K.sum(y_true_bd) + eps)
    F1_Score = 2 * P * R / (P + R + eps)
    return K.mean(1 - F1_Score)


def binary_focal_loss(y_true, y_pred):
    """
    Focal loss for binary targets.

      FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)

    where p_t is the predicted probability for label 1 and one minus it for
    label 0. Uses alpha = 0.25 and gamma = 2.

    Reference: https://arxiv.org/pdf/1708.02002.pdf
    """
    focal_alpha = 0.25
    gamma = 2
    eps = K.epsilon()

    # Predictions at positive pixels (1 elsewhere) and at negative pixels
    # (0 elsewhere), clipped so the logarithms stay finite.
    p_pos = K.clip(tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred)), eps, 1. - eps)
    p_neg = K.clip(tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred)), eps, 1. - eps)

    positive_term = K.mean(focal_alpha * K.pow(1. - p_pos, gamma) * K.log(p_pos))
    negative_term = K.mean((1 - focal_alpha) * K.pow(p_neg, gamma) * K.log(1. - p_neg))
    return -positive_term - negative_term


def combined_loss(y_true, y_pred):
    """Blend a region loss and the boundary loss using the global ``alpha``.

    The region part is binary cross-entropy plus the log-cosh Dice loss and is
    weighted by ``1 - alpha``; the boundary part is weighted by ``alpha``.
    """
    # binary_focal_loss(y_true, y_pred) can be used here instead of the cross-entropy.
    region_term = losses.binary_crossentropy(y_true, y_pred) + log_cosh_dce_loss(y_true, y_pred)
    boundary_term = segmentation_boundary_loss(y_true, y_pred)
    return (1 - alpha) * region_term + alpha * boundary_term


def dice_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient of the two masks."""
    return 1 - dice_coeff(y_true, y_pred)


def log_cosh_dce_loss(y_true, y_pred):
    """
    Log-cosh of the Dice loss, as suggested in https://arxiv.org/pdf/2006.14822.pdf

    ``dice_loss`` is looked up when the loss runs, so it uses whatever that
    global name is bound to at call time.
    """
    dice = dice_loss(y_true, y_pred)
    return tf.math.log(tf.math.cosh(dice))


def dice_coeff(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of both tensors."""
    smooth = 1.
    # Collapse both tensors to 1-D so the sums run over every element.
    flat_true = tf.reshape(y_true, [-1])
    flat_pred = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(flat_true * flat_pred)
    total = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred)
    return (2. * overlap + smooth) / (total + smooth)


def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus the log-cosh Dice loss."""
    return losses.binary_crossentropy(y_true, y_pred) + log_cosh_dce_loss(y_true, y_pred)

In [None]:
# MODEL
model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=1, activation='sigmoid', input_shape=(img_shape[0], img_shape[1], 3))

# Segmentation-models losses can be combined with '+' and scaled by an int or float factor.
# The library Dice loss gets its own name on purpose: binding it to `dice_loss` would
# replace the dice_loss() function that log_cosh_dce_loss() — and therefore the
# bce_dice_loss passed to compile() below — calls.
sm_dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
total_loss = sm_dice_loss + (1 * focal_loss)
# The combined library loss is only reported as an extra metric.
dice_loss_metrics = total_loss

# total_loss could also be taken directly from the library:
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss

metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5), sm.metrics.Precision(threshold=0.5),
          sm.metrics.Recall(threshold=0.5), dice_loss_metrics]

model.compile(optimizer='adam', loss=bce_dice_loss, metrics=metrics)

# model.summary()

save_model_path = f"{MODEL_DIR}/enetRoadSegV1.h5"

if os.path.exists(save_model_path):
    print("model initialized from saved weights")
    model.load_weights(save_model_path)
else:
    # Without the trained weights only the ImageNet encoder is initialised,
    # so any evaluation numbers produced below are not meaningful.
    print(f"WARNING: no saved weights found at {save_model_path}; evaluating an untrained model")


In [None]:
# Utils

def get_filename(file):
    """Return the final path component (the file name) of ``file``.

    Uses ``os.path.basename`` so that the platform's path separator is
    handled, not only ``/``.

    :param file: path to a file
    :return: the file name including its extension
    """
    return os.path.basename(file)

def read_img(filepath):
    """Read an image from disk without changing its channels or bit depth.

    :param filepath: path of the image file
    :return: the decoded image array
    :raises FileNotFoundError: if the file is missing or cannot be decoded
        (``cv2.imread`` signals both cases by returning ``None``)
    """
    image = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
    if image is None:
        # Fail here with the offending path instead of letting a later
        # OpenCV call crash on a None input.
        raise FileNotFoundError(f"Could not read image file: {filepath}")
    return image

def save_image(image, path, size=(1500, 1500)):
    """Resize ``image`` and write it to ``path``.

    :param image: image array to save
    :param path: destination file path; its parent directory is created if needed
    :param size: output size as ``(width, height)``, 1500x1500 by default
    :raises IOError: if OpenCV reports that the file could not be written
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        # cv2.imwrite does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
    resized_img = cv2.resize(image, size, interpolation=cv2.INTER_AREA)
    if not cv2.imwrite(path, resized_img):
        raise IOError(f"Failed to write image to {path}")

# def save_mask(filename, image):
#     resized_img = cv2.resize(image,(1500,1500),interpolation=cv2.INTER_AREA)
#     cv2.imwrite(f"{OUTPUT_PATH}/masks/{filename}", resized_img)
    

In [None]:
# IMAGE OPERATIONS

def crop_img_label(img, mask):
    """Resize image and mask to 1536x1536 and cut both into a 3x3 grid of 512x512 tiles.

    Tiles are returned in row-major order (left to right, top to bottom).

    :param img: input image
    :param mask: label mask belonging to ``img``
    :return: ``(image_tiles, mask_tiles)``, two lists with nine tiles each
    """
    image = cv2.resize(img, (1536, 1536), interpolation=cv2.INTER_AREA)
    label = cv2.resize(mask, (1536, 1536), interpolation=cv2.INTER_AREA)

    # One (row slice, column slice) pair per tile.
    windows = [
        (slice(r * 512, (r + 1) * 512), slice(c * 512, (c + 1) * 512))
        for r in range(3)
        for c in range(3)
    ]
    img_arr = [image[rows, cols] for rows, cols in windows]
    label_arr = [label[rows, cols] for rows, cols in windows]
    return img_arr, label_arr

def merge_imgs(images):
    """Stitch nine tiles, given in row-major order, back into one 3x3 mosaic.

    :param images: sequence whose first nine entries are equally sized tiles
    :return: the merged image
    """
    rows = []
    for r in range(3):
        # Join the three tiles of this row side by side.
        row_tiles = tuple(images[3 * r + c] for c in range(3))
        rows.append(np.concatenate(row_tiles, axis=1))
    # Stack the three rows on top of each other.
    return np.concatenate(tuple(rows), axis=0)

def convert_images_dtype(images):
    """
    Convert an image, or a batch of images, to float32.

    The conversion is done by ``tf.image.convert_image_dtype``, which rescales
    integer inputs such as uint8 into the 0-1 range.

    :param images: a single image (3 dimensions) or a batch of images (4 dimensions)
    :return: float32 tensor with the same shape as ``images``
    :raises ValueError: if ``images`` has neither 3 nor 4 dimensions
    """
    images_shape = images.shape
    if len(images_shape) not in (3, 4):
        raise ValueError("'image' must have either 3 or 4 dimensions, "
                         "received `{}`.".format(images_shape))

    # convert_image_dtype works on the whole input at once, so a batch needs
    # no per-image mapping and the result is already float32.
    return tf.image.convert_image_dtype(images, dtype=tf.float32)


def smoothen_detection(mask, dilate_iter=2):
    """
    Smooth a mask with a dilate / erode / dilate sequence followed by a median blur.

    :param mask: mask image to smooth
    :param dilate_iter: iterations for each dilation; the erosion in between
        runs twice as many iterations
    :return: the smoothed mask
    """
    kernel = np.ones((3, 3), np.uint8)
    steps = (
        (cv2.dilate, dilate_iter),
        (cv2.erode, dilate_iter*2),
        (cv2.dilate, dilate_iter),
    )
    result = mask
    for morph_op, repeats in steps:
        result = morph_op(result, kernel, iterations=repeats)
    return cv2.medianBlur(result, 3)


def remove_small_connected_objects(img, min_size):
    """
    Drop small connected patches from an inferred mask.

    :param img: mask image passed to ``cv2.connectedComponentsWithStats``
    :param min_size: minimum number of pixels a patch needs in order to be kept
    :return: uint8 mask in which the kept patches are set to 255 and the rest is 0
    """
    _, label_map, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=8)
    # Row 0 of stats is skipped (label 0); the last column holds each patch's pixel count.
    areas = stats[1:, -1]
    cleaned = np.zeros((label_map.shape[0], label_map.shape[1]), np.uint8)

    for label_id, area in enumerate(areas, start=1):
        if area >= min_size:
            cleaned[label_map == label_id] = 255

    return cleaned


def expand_dims(image_, axis=0):
    """
    Insert a new axis of length one into ``image_``.

    :param image_: input array
    :param axis: position of the new axis (0 by default, i.e. a leading batch axis)
    :return: the array with the extra axis
    """
    return np.expand_dims(image_, axis=axis)




In [None]:
# Evaluation Metrics Utils

def calculate_precision(gt, pred):
    """Precision of a predicted mask: TP / (TP + FP).

    Any non-zero pixel counts as positive, so 0/1 and 0/255 masks give the
    same result.

    :param gt: ground-truth mask
    :param pred: predicted mask of the same shape
    :return: the precision, or ``None`` if nothing was predicted positive
    """
    # True positives: positive in both masks.
    true_positives = np.count_nonzero(np.logical_and(gt, pred))
    # Every predicted positive is either a true or a false positive.
    predicted_positives = np.count_nonzero(pred)
    if predicted_positives == 0:
        return None
    return true_positives / predicted_positives


def calculate_recall(gt, pred):
    """Recall of a predicted mask: TP / (TP + FN).

    Any non-zero pixel counts as positive, so 0/1 and 0/255 masks give the
    same result.

    :param gt: ground-truth mask
    :param pred: predicted mask of the same shape
    :return: the recall, or ``None`` if the ground truth has no positive pixel
    """
    # True positives: positive in both masks.
    true_positives = np.count_nonzero(np.logical_and(gt, pred))
    # Every ground-truth positive is either a true positive or a false negative.
    actual_positives = np.count_nonzero(gt)
    if actual_positives == 0:
        return None
    return true_positives / actual_positives


def calculate_iou(gt, pred):
    """Intersection over union of two masks.

    :param gt: ground-truth mask
    :param pred: predicted mask of the same shape
    :return: the IoU, or ``None`` when both masks are empty
    """
    overlap = np.logical_and(gt, pred)
    combined = np.logical_or(gt, pred)
    union_area = np.sum(combined)
    if union_area == 0:
        # Nothing is marked in either mask, so the score is undefined.
        return None
    return np.sum(overlap) / union_area


def calculate_F1(precision, recall):
    """F1 score: the harmonic mean of precision and recall.

    :param precision: precision value, or ``None`` if it is undefined
    :param recall: recall value, or ``None`` if it is undefined
    :return: the F1 score, or ``None`` when either input is ``None`` or both are zero
    """
    # calculate_precision / calculate_recall return None for empty masks.
    if precision is None or recall is None:
        return None
    summation = precision + recall
    if summation == 0:
        return None
    return (2 * precision * recall) / summation


def cal_mean(acc_list):
    """Mean of the values in ``acc_list``, ignoring ``None`` entries.

    :param acc_list: list of numbers and/or ``None``
    :return: the mean of the non-``None`` values, or ``None`` if there are none
    """
    valid = [val for val in acc_list if val is not None]
    if not valid:
        # Nothing to average; avoid a division by zero.
        return None
    return np.sum(valid) / len(valid)

In [None]:
def infer_and_evaluate(img, mask, file_name):
    """
    Run inference on one image and compute its accuracy metrics.

    The image is split into nine 512x512 tiles, each tile is predicted
    separately, thresholded at 0.5 and cleaned of small patches. The tiles are
    then merged and resized to the size of ``mask`` so that prediction and
    ground truth can be compared pixel by pixel. The image, the prediction
    and the mask are saved under ``OUTPUT_PATH``.

    :param img: RGB input image
    :param mask: ground-truth mask; non-zero pixels are treated as positive
    :param file_name: file name used for the saved outputs and the log line
    :return: ``(iou, recall, precision, f1_score)``; an entry is ``None`` when
        the corresponding metric is undefined for this image
    """
    tiles, _ = crop_img_label(img, mask)

    pred_arr = []
    for tile in tiles:
        inp_img = np.multiply(tile, 1 / 255.0)
        batched_img = expand_dims(inp_img)

        pred = model.predict(batched_img)
        pred = np.reshape(pred, (512, 512))

        # Binarise at 0.5, then drop connected patches smaller than 40 pixels.
        binary_tile = np.where(pred > 0.5, 255, 0).astype('uint8')
        pred_arr.append(remove_small_connected_objects(binary_tile, 40))

    merged_pred = merge_imgs(pred_arr)
    # Bring the prediction to the mask's own size; cv2.resize takes (width, height).
    mask_height, mask_width = mask.shape[:2]
    merged_pred = cv2.resize(merged_pred, (mask_width, mask_height), interpolation=cv2.INTER_AREA)

    ## uncomment below lines to visualise inference data
#     visualize(
#             img = img,
#             pred = merged_pred,
#             mask = mask
#         )

    save_image(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), f"{OUTPUT_PATH}/images/{file_name}")
    save_image(merged_pred, f"{OUTPUT_PATH}/pred/{file_name}")
    save_image(mask, f"{OUTPUT_PATH}/labels/{file_name}")

    # Convert ground truth and prediction to 0/1 masks.
    gt = np.where(mask > 0, 1, 0).astype('uint8')
    pred = np.where(merged_pred > 0, 1, 0).astype('uint8')

    iou = calculate_iou(gt, pred)
    recall = calculate_recall(gt, pred)
    precision = calculate_precision(gt, pred)
    # Precision / recall are None for empty masks; F1 is undefined in that case.
    if precision is None or recall is None:
        f1_score = None
    else:
        f1_score = calculate_F1(precision, recall)

    print(f"for {file_name} - iou: {iou}, recall: {recall}, precision: {precision}, f1_score: {f1_score}")

    return iou, recall, precision, f1_score

In [None]:
def main():
    """Evaluate the model on every PNG in ``IMG_DIR`` and print the mean metrics.

    Images are processed in sorted order so runs are reproducible. An image
    whose label file is missing from ``LABEL_DIR`` is skipped with a message.
    """
    images = sorted(glob.glob(f"{IMG_DIR}/*.png"))
    if not images:
        print(f"No .png images found in {IMG_DIR}; nothing to evaluate.")
        return

    # The result folders must exist before images can be written into them.
    for sub_dir in ("images", "pred", "labels"):
        os.makedirs(os.path.join(OUTPUT_PATH, sub_dir), exist_ok=True)

    iou_list, recall_list, precision_list, f1_list = [], [], [], []
    for img_path in tqdm(images):
        file_name = get_filename(img_path)
        label_path = f"{LABEL_DIR}/{file_name}"
        if not os.path.exists(label_path):
            print(f"Skipping {file_name}: no label found at {label_path}")
            continue

        bgr_img = read_img(img_path)
        # OpenCV loads colour images as BGR; the rest of the pipeline works on RGB.
        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        label = read_img(label_path)

        iou, recall, precision, f1_score = infer_and_evaluate(rgb_img, label, file_name)
        iou_list.append(iou)
        recall_list.append(recall)
        precision_list.append(precision)
        f1_list.append(f1_score)

    mean_iou = cal_mean(iou_list)
    mean_recall = cal_mean(recall_list)
    mean_precision = cal_mean(precision_list)
    mean_f1_score = cal_mean(f1_list)
    print("*******Results*******")
    print(f"Mean - iou: {mean_iou}, recall: {mean_recall}, precision: {mean_precision}, f1_score: {mean_f1_score}")

In [None]:
# Run the evaluation when this cell is executed (or when the exported script is run directly).
if __name__ == '__main__':
    main()