In [None]:
# Display Kaggle GPU info (IPython shell escape; only runs inside a notebook kernel)
! nvidia-smi

import numpy as np
import tensorflow as tf
import keras

print("TensorFlow version: " + tf.__version__)
print("Keras version: " + keras.__version__)
# The models below declare inputs as (height, width, channels), so force channels-last ordering.
keras.backend.set_image_data_format("channels_last")

In [None]:
# Define global variables
# BATCH_SIZE: number of samples per training batch
BATCH_SIZE = 32

# DATASET_NAME: dataset name (BickleyDiary, DIBCO, PLM).
# NOTE: only printed and used to name output files; the training data path is set in the training cell.
DATASET_NAME = "DIBCO"

# NETWORK_MODEL: network architecture (UNet, UNet1Plus_w_DeepSupv, UNet1Plus_wo_DeepSupv, UNet2Plus_w_DeepSupv, UNet2Plus_wo_DeepSupv, UNet3Plus_w_DeepSupv, UNet3Plus_wo_DeepSupv, UNet4Plus_w_DeepSupv, UNet4Plus_wo_DeepSupv)
NETWORK_MODEL = "UNet4Plus_wo_DeepSupv"

# LOSS_FUNCTION: training loss. Only "BCE_Dice_mIoU" is implemented in this notebook.
LOSS_FUNCTION = "BCE_Dice_mIoU"

# NUM_EPOCHS: maximum number of training epochs (EarlyStopping may stop training sooner)
NUM_EPOCHS = 500

# NUM_FILTERS: number of convolution filters in the first network level
NUM_FILTERS = 32

# TILE_SIZE: side length (pixels) of the square image tiles fed to the network
TILE_SIZE = 128

# All models downsample by up to 16x (four 2x poolings; UNet3+/UNet4+ also pool by 16 directly),
# so the tile size must divide evenly or upsampled features will not line up with their skips.
assert TILE_SIZE % 16 == 0, "TILE_SIZE must be a multiple of 16"

In [None]:
# Recipe by Brian Beck: a class that emulates a C-style switch statement.
class switch(object):
    """Emulate a C-style ``switch`` statement with fall-through.

    Usage::

        for case in switch(value):
            if case("a", "b"):
                ...
                break
            if case():  # default
                ...

    After one case matches, every later ``case(...)`` call also returns True
    (fall-through) until the loop body executes ``break``.
    """

    def __init__(self, value):
        self.value = value
        self.fall = False

    def __iter__(self):
        """Yield the match method once, then stop."""
        # Ending the generator with a plain return is the correct way to stop.
        # The former `raise StopIteration` is turned into RuntimeError since
        # Python 3.7 (PEP 479), which crashed any loop that reached the default
        # case without `break`.
        yield self.match

    def match(self, *args):
        """Return True if the current case suite should be entered."""
        if self.fall or not args:
            return True
        if self.value in args:
            self.fall = True
            return True
        return False

In [None]:
# Create network models
from keras.callbacks import *
from keras.layers import *
from keras.losses import *
from keras.models import *
from keras.optimizers import *
from keras.regularizers import *
from keras import backend as K


def conv_block(input_tensor, num_filters):
    """
    Two stacked 3x3 convolutions ("same" padding, ReLU, He-normal init).

    :param input_tensor: feature map to transform
    :param num_filters: output channels of both convolutions
    :return: feature map produced by the second convolution
    """
    features = input_tensor
    for _ in range(2):
        layer = Conv2D(num_filters, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_normal")
        features = layer(features)
    return features


def conv(input_tensor, num_filters):
    """
    Single 3x3 convolution ("same" padding, ReLU, He-normal init).

    :param input_tensor: feature map to transform
    :param num_filters: output channels
    :return: convolved feature map
    """
    layer = Conv2D(num_filters, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_normal")
    return layer(input_tensor)


def up_conv(input_tensor, num_filters, up_size):
    """
    Bilinear upsampling by ``up_size`` followed by a 3x3 ReLU convolution.

    :param input_tensor: feature map to enlarge
    :param num_filters: output channels of the convolution
    :param up_size: integer upsampling factor
    :return: upsampled, convolved feature map
    """
    upsampled = UpSampling2D(up_size, interpolation="bilinear")(input_tensor)
    projector = Conv2D(num_filters, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_normal")
    return projector(upsampled)


def down_conv(input_tensor, num_filters, down_size):
    """
    Max-pooling by ``down_size`` followed by a 3x3 ReLU convolution.

    :param input_tensor: feature map to shrink
    :param num_filters: output channels of the convolution
    :param down_size: integer pooling window (and stride)
    :return: pooled, convolved feature map
    """
    pooled = MaxPooling2D(down_size)(input_tensor)
    projector = Conv2D(num_filters, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_normal")
    return projector(pooled)


def UNet(num_classes, input_height, input_width, num_filters):
    """
    U-Net
    Paper : https://arxiv.org/abs/1505.04597
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    e1 = conv_block(inputs, filters[0])

    e2 = MaxPooling2D()(e1)
    e2 = conv_block(e2, filters[1])

    e3 = MaxPooling2D()(e2)
    e3 = conv_block(e3, filters[2])

    e4 = MaxPooling2D()(e3)
    e4 = conv_block(e4, filters[3])

    e5 = MaxPooling2D()(e4)
    e5 = conv_block(e5, filters[4])

    d4 = up_conv(e5, filters[3], 2)
    d4 = Concatenate()([e4, d4])
    d4 = conv_block(d4, filters[3])

    d3 = up_conv(d4, filters[2], 2)
    d3 = Concatenate()([e3, d3])
    d3 = conv_block(d3, filters[2])

    d2 = up_conv(d3, filters[1], 2)
    d2 = Concatenate()([e2, d2])
    d2 = conv_block(d2, filters[1])

    d1 = up_conv(d2, filters[0], 2)
    d1 = Concatenate()([e1, d1])
    d1 = conv_block(d1, filters[0])

    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(d1)

    model = Model(inputs=inputs, outputs=outputs, name="UNet")

    return model


def UNet1Plus_w_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net+ (Deep Layer Aggregation) with deep supervision
    Paper : http://openaccess.thecvf.com/content_cvpr_2018/papers/Yu_Deep_Layer_Aggregation_CVPR_2018_paper.pdf
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    x0_0 = conv_block(inputs, filters[0])

    x1_0 = MaxPooling2D()(x0_0)
    x1_0 = conv_block(x1_0, filters[1])

    x0_1 = up_conv(x1_0, filters[0], 2)
    x0_1 = Concatenate()([x0_0, x0_1])
    x0_1 = conv_block(x0_1, filters[0])

    x2_0 = MaxPooling2D()(x1_0)
    x2_0 = conv_block(x2_0, filters[2])

    x1_1 = up_conv(x2_0, filters[1], 2)
    x1_1 = Concatenate()([x1_0, x1_1])
    x1_1 = conv_block(x1_1, filters[1])

    x0_2 = up_conv(x1_1, filters[0], 2)
    x0_2 = Concatenate()([x0_1, x0_2])
    x0_2 = conv_block(x0_2, filters[0])

    x3_0 = MaxPooling2D()(x2_0)
    x3_0 = conv_block(x3_0, filters[3])

    x2_1 = up_conv(x3_0, filters[2], 2)
    x2_1 = Concatenate()([x2_0, x2_1])
    x2_1 = conv_block(x2_1, filters[2])

    x1_2 = up_conv(x2_1, filters[1], 2)
    x1_2 = Concatenate()([x1_1, x1_2])
    x1_2 = conv_block(x1_2, filters[1])

    x0_3 = up_conv(x1_2, filters[0], 2)
    x0_3 = Concatenate()([x0_2, x0_3])
    x0_3 = conv_block(x0_3, filters[0])

    x4_0 = MaxPooling2D()(x3_0)
    x4_0 = conv_block(x4_0, filters[4])

    x3_1 = up_conv(x4_0, filters[3], 2)
    x3_1 = Concatenate()([x3_0, x3_1])
    x3_1 = conv_block(x3_1, filters[3])

    x2_2 = up_conv(x3_1, filters[2], 2)
    x2_2 = Concatenate()([x2_1, x2_2])
    x2_2 = conv_block(x2_2, filters[2])

    x1_3 = up_conv(x2_2, filters[1], 2)
    x1_3 = Concatenate()([x1_2, x1_3])
    x1_3 = conv_block(x1_3, filters[1])

    x0_4 = up_conv(x1_3, filters[0], 2)
    x0_4 = Concatenate()([x0_3, x0_4])
    x0_4 = conv_block(x0_4, filters[0])

    outputs = Average()([x0_1, x0_2, x0_3, x0_4])
    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(outputs)

    model = Model(inputs=inputs, outputs=outputs, name="UNet1Plus_w_DeepSupv")

    return model


def UNet1Plus_wo_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net+ (Deep Layer Aggregation) without deep supervision
    Paper : http://openaccess.thecvf.com/content_cvpr_2018/papers/Yu_Deep_Layer_Aggregation_CVPR_2018_paper.pdf
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    x0_0 = conv_block(inputs, filters[0])

    x1_0 = MaxPooling2D()(x0_0)
    x1_0 = conv_block(x1_0, filters[1])

    x0_1 = up_conv(x1_0, filters[0], 2)
    x0_1 = Concatenate()([x0_0, x0_1])
    x0_1 = conv_block(x0_1, filters[0])

    x2_0 = MaxPooling2D()(x1_0)
    x2_0 = conv_block(x2_0, filters[2])

    x1_1 = up_conv(x2_0, filters[1], 2)
    x1_1 = Concatenate()([x1_0, x1_1])
    x1_1 = conv_block(x1_1, filters[1])

    x0_2 = up_conv(x1_1, filters[0], 2)
    x0_2 = Concatenate()([x0_1, x0_2])
    x0_2 = conv_block(x0_2, filters[0])

    x3_0 = MaxPooling2D()(x2_0)
    x3_0 = conv_block(x3_0, filters[3])

    x2_1 = up_conv(x3_0, filters[2], 2)
    x2_1 = Concatenate()([x2_0, x2_1])
    x2_1 = conv_block(x2_1, filters[2])

    x1_2 = up_conv(x2_1, filters[1], 2)
    x1_2 = Concatenate()([x1_1, x1_2])
    x1_2 = conv_block(x1_2, filters[1])

    x0_3 = up_conv(x1_2, filters[0], 2)
    x0_3 = Concatenate()([x0_2, x0_3])
    x0_3 = conv_block(x0_3, filters[0])

    x4_0 = MaxPooling2D()(x3_0)
    x4_0 = conv_block(x4_0, filters[4])

    x3_1 = up_conv(x4_0, filters[3], 2)
    x3_1 = Concatenate()([x3_0, x3_1])
    x3_1 = conv_block(x3_1, filters[3])

    x2_2 = up_conv(x3_1, filters[2], 2)
    x2_2 = Concatenate()([x2_1, x2_2])
    x2_2 = conv_block(x2_2, filters[2])

    x1_3 = up_conv(x2_2, filters[1], 2)
    x1_3 = Concatenate()([x1_2, x1_3])
    x1_3 = conv_block(x1_3, filters[1])

    x0_4 = up_conv(x1_3, filters[0], 2)
    x0_4 = Concatenate()([x0_3, x0_4])
    x0_4 = conv_block(x0_4, filters[0])

    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(x0_4)

    model = Model(inputs=inputs, outputs=outputs, name="UNet1Plus_wo_DeepSupv")

    return model


def UNet2Plus_w_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net++ with deep supervision
    Paper : https://arxiv.org/abs/1807.10165
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    x0_0 = conv_block(inputs, filters[0])

    x1_0 = MaxPooling2D()(x0_0)
    x1_0 = conv_block(x1_0, filters[1])

    x0_1 = up_conv(x1_0, filters[0], 2)
    x0_1 = Concatenate()([x0_0, x0_1])
    x0_1 = conv_block(x0_1, filters[0])

    x2_0 = MaxPooling2D()(x1_0)
    x2_0 = conv_block(x2_0, filters[2])

    x1_1 = up_conv(x2_0, filters[1], 2)
    x1_1 = Concatenate()([x1_0, x1_1])
    x1_1 = conv_block(x1_1, filters[1])

    x0_2 = up_conv(x1_1, filters[0], 2)
    x0_2 = Concatenate()([x0_0, x0_1, x0_2])
    x0_2 = conv_block(x0_2, filters[0])

    x3_0 = MaxPooling2D()(x2_0)
    x3_0 = conv_block(x3_0, filters[3])

    x2_1 = up_conv(x3_0, filters[2], 2)
    x2_1 = Concatenate()([x2_0, x2_1])
    x2_1 = conv_block(x2_1, filters[2])

    x1_2 = up_conv(x2_1, filters[1], 2)
    x1_2 = Concatenate()([x1_0, x1_1, x1_2])
    x1_2 = conv_block(x1_2, filters[1])

    x0_3 = up_conv(x1_2, filters[0], 2)
    x0_3 = Concatenate()([x0_0, x0_1, x0_2, x0_3])
    x0_3 = conv_block(x0_3, filters[0])

    x4_0 = MaxPooling2D()(x3_0)
    x4_0 = conv_block(x4_0, filters[4])

    x3_1 = up_conv(x4_0, filters[3], 2)
    x3_1 = Concatenate()([x3_0, x3_1])
    x3_1 = conv_block(x3_1, filters[3])

    x2_2 = up_conv(x3_1, filters[2], 2)
    x2_2 = Concatenate()([x2_0, x2_1, x2_2])
    x2_2 = conv_block(x2_2, filters[2])

    x1_3 = up_conv(x2_2, filters[1], 2)
    x1_3 = Concatenate()([x1_0, x1_1, x1_2, x1_3])
    x1_3 = conv_block(x1_3, filters[1])

    x0_4 = up_conv(x1_3, filters[0], 2)
    x0_4 = Concatenate()([x0_0, x0_1, x0_2, x0_3, x0_4])
    x0_4 = conv_block(x0_4, filters[0])

    outputs = Average()([x0_1, x0_2, x0_3, x0_4])
    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(outputs)

    model = Model(inputs=inputs, outputs=outputs, name="UNet2Plus_w_DeepSupv")

    return model


def UNet2Plus_wo_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net++ without deep supervision
    Paper : https://arxiv.org/abs/1807.10165
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    x0_0 = conv_block(inputs, filters[0])

    x1_0 = MaxPooling2D()(x0_0)
    x1_0 = conv_block(x1_0, filters[1])

    x0_1 = up_conv(x1_0, filters[0], 2)
    x0_1 = Concatenate()([x0_0, x0_1])
    x0_1 = conv_block(x0_1, filters[0])

    x2_0 = MaxPooling2D()(x1_0)
    x2_0 = conv_block(x2_0, filters[2])

    x1_1 = up_conv(x2_0, filters[1], 2)
    x1_1 = Concatenate()([x1_0, x1_1])
    x1_1 = conv_block(x1_1, filters[1])

    x0_2 = up_conv(x1_1, filters[0], 2)
    x0_2 = Concatenate()([x0_0, x0_1, x0_2])
    x0_2 = conv_block(x0_2, filters[0])

    x3_0 = MaxPooling2D()(x2_0)
    x3_0 = conv_block(x3_0, filters[3])

    x2_1 = up_conv(x3_0, filters[2], 2)
    x2_1 = Concatenate()([x2_0, x2_1])
    x2_1 = conv_block(x2_1, filters[2])

    x1_2 = up_conv(x2_1, filters[1], 2)
    x1_2 = Concatenate()([x1_0, x1_1, x1_2])
    x1_2 = conv_block(x1_2, filters[1])

    x0_3 = up_conv(x1_2, filters[0], 2)
    x0_3 = Concatenate()([x0_0, x0_1, x0_2, x0_3])
    x0_3 = conv_block(x0_3, filters[0])

    x4_0 = MaxPooling2D()(x3_0)
    x4_0 = conv_block(x4_0, filters[4])

    x3_1 = up_conv(x4_0, filters[3], 2)
    x3_1 = Concatenate()([x3_0, x3_1])
    x3_1 = conv_block(x3_1, filters[3])

    x2_2 = up_conv(x3_1, filters[2], 2)
    x2_2 = Concatenate()([x2_0, x2_1, x2_2])
    x2_2 = conv_block(x2_2, filters[2])

    x1_3 = up_conv(x2_2, filters[1], 2)
    x1_3 = Concatenate()([x1_0, x1_1, x1_2, x1_3])
    x1_3 = conv_block(x1_3, filters[1])

    x0_4 = up_conv(x1_3, filters[0], 2)
    x0_4 = Concatenate()([x0_0, x0_1, x0_2, x0_3, x0_4])
    x0_4 = conv_block(x0_4, filters[0])

    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(x0_4)

    model = Model(inputs=inputs, outputs=outputs, name="UNet2Plus_wo_DeepSupv")

    return model


def UNet3Plus_w_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net3+ with deep supervision
    Paper : https://arxiv.org/abs/2004.08790
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    e1 = conv_block(inputs, filters[0])

    e2 = MaxPooling2D()(e1)
    e2 = conv_block(e2, filters[1])

    e3 = MaxPooling2D()(e2)
    e3 = conv_block(e3, filters[2])

    e4 = MaxPooling2D()(e3)
    e4 = conv_block(e4, filters[3])

    e5 = MaxPooling2D()(e4)
    e5 = conv_block(e5, filters[4])

    e1_d_d4 = down_conv(e1, filters[0], 8)
    e2_d_d4 = down_conv(e2, filters[0], 4)
    e3_d_d4 = down_conv(e3, filters[0], 2)
    e4_d4 = conv(e4, filters[0])
    e5_u_d4 = up_conv(e5, filters[0], 2)
    d4 = Concatenate()([e1_d_d4, e2_d_d4, e3_d_d4, e4_d4, e5_u_d4])
    d4 = conv_block(d4, filters[0] * 5)

    e1_d_d3 = down_conv(e1, filters[0], 4)
    e2_d_d3 = down_conv(e2, filters[0], 2)
    e3_d3 = conv(e3, filters[0])
    e5_u_d3 = up_conv(e5, filters[0], 4)
    d4_u_d3 = up_conv(d4, filters[0], 2)
    d3 = Concatenate()([e1_d_d3, e2_d_d3, e3_d3, e5_u_d3, d4_u_d3])
    d3 = conv_block(d3, filters[0] * 5)

    e1_d_d2 = down_conv(e1, filters[0], 2)
    e2_d2 = conv(e2, filters[0])
    e5_u_d2 = up_conv(e5, filters[0], 8)
    d4_u_d2 = up_conv(d4, filters[0], 4)
    d3_u_d2 = up_conv(d3, filters[0], 2)
    d2 = Concatenate()([e1_d_d2, e2_d2, e5_u_d2, d4_u_d2, d3_u_d2])
    d2 = conv_block(d2, filters[0] * 5)

    e1_d1 = conv(e1, filters[0])
    e5_u_d1 = up_conv(e5, filters[0], 16)
    d4_u_d1 = up_conv(d4, filters[0], 8)
    d3_u_d1 = up_conv(d3, filters[0], 4)
    d2_u_d1 = up_conv(d2, filters[0], 2)
    d1 = Concatenate()([e1_d1, e5_u_d1, d4_u_d1, d3_u_d1, d2_u_d1])
    d1 = conv_block(d1, filters[0] * 5)

    outputs = Average()([d1, up_conv(d2, filters[0] * 5, 2), up_conv(d3, filters[0] * 5, 4), up_conv(d4, filters[0] * 5, 8)])
    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(outputs)

    model = Model(inputs=inputs, outputs=outputs, name="UNet3Plus_w_DeepSupv")

    return model


def UNet3Plus_wo_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net3+ without deep supervision
    Paper : https://arxiv.org/abs/2004.08790
    """
    inputs = Input(shape=(input_height, input_width, 1))

    filters = [num_filters, num_filters * 2, num_filters * 4, num_filters * 8, num_filters * 16]

    e1 = conv_block(inputs, filters[0])

    e2 = MaxPooling2D()(e1)
    e2 = conv_block(e2, filters[1])

    e3 = MaxPooling2D()(e2)
    e3 = conv_block(e3, filters[2])

    e4 = MaxPooling2D()(e3)
    e4 = conv_block(e4, filters[3])

    e5 = MaxPooling2D()(e4)
    e5 = conv_block(e5, filters[4])

    e1_d_d4 = down_conv(e1, filters[0], 8)
    e2_d_d4 = down_conv(e2, filters[0], 4)
    e3_d_d4 = down_conv(e3, filters[0], 2)
    e4_d4 = conv(e4, filters[0])
    e5_u_d4 = up_conv(e5, filters[0], 2)
    d4 = Concatenate()([e1_d_d4, e2_d_d4, e3_d_d4, e4_d4, e5_u_d4])
    d4 = conv_block(d4, filters[0] * 5)

    e1_d_d3 = down_conv(e1, filters[0], 4)
    e2_d_d3 = down_conv(e2, filters[0], 2)
    e3_d3 = conv(e3, filters[0])
    e5_u_d3 = up_conv(e5, filters[0], 4)
    d4_u_d3 = up_conv(d4, filters[0], 2)
    d3 = Concatenate()([e1_d_d3, e2_d_d3, e3_d3, e5_u_d3, d4_u_d3])
    d3 = conv_block(d3, filters[0] * 5)

    e1_d_d2 = down_conv(e1, filters[0], 2)
    e2_d2 = conv(e2, filters[0])
    e5_u_d2 = up_conv(e5, filters[0], 8)
    d4_u_d2 = up_conv(d4, filters[0], 4)
    d3_u_d2 = up_conv(d3, filters[0], 2)
    d2 = Concatenate()([e1_d_d2, e2_d2, e5_u_d2, d4_u_d2, d3_u_d2])
    d2 = conv_block(d2, filters[0] * 5)

    e1_d1 = conv(e1, filters[0])
    e5_u_d1 = up_conv(e5, filters[0], 16)
    d4_u_d1 = up_conv(d4, filters[0], 8)
    d3_u_d1 = up_conv(d3, filters[0], 4)
    d2_u_d1 = up_conv(d2, filters[0], 2)
    d1 = Concatenate()([e1_d1, e5_u_d1, d4_u_d1, d3_u_d1, d2_u_d1])
    d1 = conv_block(d1, filters[0] * 5)

    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(d1)

    model = Model(inputs=inputs, outputs=outputs, name="UNet3Plus_wo_DeepSupv")

    return model


def UNet4Plus_w_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net4+ with deep supervision
    """
    inputs = Input(shape=(input_height, input_width, 1))

    e1 = conv_block(inputs, num_filters)

    e2 = down_conv(e1, num_filters, 2)
    e2 = conv_block(e2, num_filters)

    e1_d_e3 = down_conv(e1, num_filters, 4)
    e2_d_e3 = down_conv(e2, num_filters, 2)
    e3 = Concatenate()([e1_d_e3, e2_d_e3])
    e3 = conv_block(e3, num_filters * 2)

    e1_d_e4 = down_conv(e1, num_filters, 8)
    e2_d_e4 = down_conv(e2, num_filters, 4)
    e3_d_e4 = down_conv(e3, num_filters, 2)
    e4 = Concatenate()([e1_d_e4, e2_d_e4, e3_d_e4])
    e4 = conv_block(e4, num_filters * 3)

    e1_d_e5 = down_conv(e1, num_filters, 16)
    e2_d_e5 = down_conv(e2, num_filters, 8)
    e3_d_e5 = down_conv(e3, num_filters, 4)
    e4_d_e5 = down_conv(e4, num_filters, 2)
    e5 = Concatenate()([e1_d_e5, e2_d_e5, e3_d_e5, e4_d_e5])
    e5 = conv_block(e5, num_filters * 4)

    e1_d_d4 = down_conv(e1, num_filters, 8)
    e2_d_d4 = down_conv(e2, num_filters, 4)
    e3_d_d4 = down_conv(e3, num_filters, 2)
    e4_d4 = conv(e4, num_filters)
    e5_u_d4 = up_conv(e5, num_filters, 2)
    d4 = Concatenate()([e1_d_d4, e2_d_d4, e3_d_d4, e4_d4, e5_u_d4])
    d4 = conv_block(d4, num_filters * 5)

    e1_d_d3 = down_conv(e1, num_filters, 4)
    e2_d_d3 = down_conv(e2, num_filters, 2)
    e3_d3 = conv(e3, num_filters)
    e4_u_d3 = up_conv(e4, num_filters, 2)
    e5_u_d3 = up_conv(e5, num_filters, 4)
    d4_u_d3 = up_conv(d4, num_filters, 2)
    d3 = Concatenate()([e1_d_d3, e2_d_d3, e3_d3, e4_u_d3, e5_u_d3, d4_u_d3])
    d3 = conv_block(d3, num_filters * 6)

    e1_d_d2 = down_conv(e1, num_filters, 2)
    e2_d2 = conv(e2, num_filters)
    e3_u_d2 = up_conv(e3, num_filters, 2)
    e4_u_d2 = up_conv(e4, num_filters, 4)
    e5_u_d2 = up_conv(e5, num_filters, 8)
    d4_u_d2 = up_conv(d4, num_filters, 4)
    d3_u_d2 = up_conv(d3, num_filters, 2)
    d2 = Concatenate()([e1_d_d2, e2_d2, e3_u_d2, e4_u_d2, e5_u_d2, d4_u_d2, d3_u_d2])
    d2 = conv_block(d2, num_filters * 7)

    e1_d1 = conv(e1, num_filters)
    e2_u_d1 = up_conv(e2, num_filters, 2)
    e3_u_d1 = up_conv(e3, num_filters, 4)
    e4_u_d1 = up_conv(e4, num_filters, 8)
    e5_u_d1 = up_conv(e5, num_filters, 16)
    d4_u_d1 = up_conv(d4, num_filters, 8)
    d3_u_d1 = up_conv(d3, num_filters, 4)
    d2_u_d1 = up_conv(d2, num_filters, 2)
    d1 = Concatenate()([e1_d1, e2_u_d1, e3_u_d1, e4_u_d1, e5_u_d1, d4_u_d1, d3_u_d1, d2_u_d1])
    d1 = conv_block(d1, num_filters * 8)

    outputs = Concatenate()([d1,
                             UpSampling2D(size=2, interpolation="bilinear")(d2),
                             UpSampling2D(size=4, interpolation="bilinear")(d3),
                             UpSampling2D(size=8, interpolation="bilinear")(d4)])
    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(outputs)

    model = Model(inputs=inputs, outputs=outputs, name="UNet4Plus_w_DeepSupv")

    return model


def UNet4Plus_wo_DeepSupv(num_classes, input_height, input_width, num_filters):
    """
    U-Net4+ without deep supervision
    """
    inputs = Input(shape=(input_height, input_width, 1))

    e1 = conv_block(inputs, num_filters)

    e2 = down_conv(e1, num_filters, 2)
    e2 = conv_block(e2, num_filters)

    e1_d_e3 = down_conv(e1, num_filters, 4)
    e2_d_e3 = down_conv(e2, num_filters, 2)
    e3 = Concatenate()([e1_d_e3, e2_d_e3])
    e3 = conv_block(e3, num_filters * 2)

    e1_d_e4 = down_conv(e1, num_filters, 8)
    e2_d_e4 = down_conv(e2, num_filters, 4)
    e3_d_e4 = down_conv(e3, num_filters, 2)
    e4 = Concatenate()([e1_d_e4, e2_d_e4, e3_d_e4])
    e4 = conv_block(e4, num_filters * 3)

    e1_d_e5 = down_conv(e1, num_filters, 16)
    e2_d_e5 = down_conv(e2, num_filters, 8)
    e3_d_e5 = down_conv(e3, num_filters, 4)
    e4_d_e5 = down_conv(e4, num_filters, 2)
    e5 = Concatenate()([e1_d_e5, e2_d_e5, e3_d_e5, e4_d_e5])
    e5 = conv_block(e5, num_filters * 4)

    e1_d_d4 = down_conv(e1, num_filters, 8)
    e2_d_d4 = down_conv(e2, num_filters, 4)
    e3_d_d4 = down_conv(e3, num_filters, 2)
    e4_d4 = conv(e4, num_filters)
    e5_u_d4 = up_conv(e5, num_filters, 2)
    d4 = Concatenate()([e1_d_d4, e2_d_d4, e3_d_d4, e4_d4, e5_u_d4])
    d4 = conv_block(d4, num_filters * 5)

    e1_d_d3 = down_conv(e1, num_filters, 4)
    e2_d_d3 = down_conv(e2, num_filters, 2)
    e3_d3 = conv(e3, num_filters)
    e4_u_d3 = up_conv(e4, num_filters, 2)
    e5_u_d3 = up_conv(e5, num_filters, 4)
    d4_u_d3 = up_conv(d4, num_filters, 2)
    d3 = Concatenate()([e1_d_d3, e2_d_d3, e3_d3, e4_u_d3, e5_u_d3, d4_u_d3])
    d3 = conv_block(d3, num_filters * 6)

    e1_d_d2 = down_conv(e1, num_filters, 2)
    e2_d2 = conv(e2, num_filters)
    e3_u_d2 = up_conv(e3, num_filters, 2)
    e4_u_d2 = up_conv(e4, num_filters, 4)
    e5_u_d2 = up_conv(e5, num_filters, 8)
    d4_u_d2 = up_conv(d4, num_filters, 4)
    d3_u_d2 = up_conv(d3, num_filters, 2)
    d2 = Concatenate()([e1_d_d2, e2_d2, e3_u_d2, e4_u_d2, e5_u_d2, d4_u_d2, d3_u_d2])
    d2 = conv_block(d2, num_filters * 7)

    e1_d1 = conv(e1, num_filters)
    e2_u_d1 = up_conv(e2, num_filters, 2)
    e3_u_d1 = up_conv(e3, num_filters, 4)
    e4_u_d1 = up_conv(e4, num_filters, 8)
    e5_u_d1 = up_conv(e5, num_filters, 16)
    d4_u_d1 = up_conv(d4, num_filters, 8)
    d3_u_d1 = up_conv(d3, num_filters, 4)
    d2_u_d1 = up_conv(d2, num_filters, 2)
    d1 = Concatenate()([e1_d1, e2_u_d1, e3_u_d1, e4_u_d1, e5_u_d1, d4_u_d1, d3_u_d1, d2_u_d1])
    d1 = conv_block(d1, num_filters * 8)

    outputs = Conv2D(num_classes, kernel_size=3, padding="same", activation="sigmoid", kernel_initializer="he_normal")(d1)

    model = Model(inputs=inputs, outputs=outputs, name="UNet4Plus_wo_DeepSupv")

    return model

In [None]:
# Define loss functions
from skimage.morphology import label


def iou_metric(y_true_in, y_pred_in, print_table=False):
    """
    Object-level mean precision over IoU thresholds (Kaggle DSB-2018 style) for one mask.

    Both inputs are thresholded at 0.5 and split into connected components with
    ``label``; label 0 is background and labels 1..max are objects. For each
    threshold t in 0.50, 0.55, ..., 0.95, a true object is a true positive when a
    predicted object overlaps it with IoU > t. Precision at t is TP / (TP + FP + FN),
    and the result is the mean precision over all thresholds.

    Fixes relative to the original Kaggle kernel code:
    * Tables are sized from the largest label rather than from ``np.unique``, so row
      and column 0 always mean background. Previously a mask without any background
      pixels (one component covering the whole tile) had its only object discarded
      as if it were background.
    * When neither mask contains an object, the (empty) prediction is perfect and
      scores 1. It used to score 0, so the loss penalised correct empty predictions.

    :param y_true_in: ground-truth mask, values in [0, 1]
    :param y_pred_in: predicted mask/probabilities, same shape as ``y_true_in``
    :param print_table: print per-threshold TP/FP/FN/precision when True
    :return: mean precision over thresholds
    """
    y_true = np.asarray(label(y_true_in > 0.5)).ravel().astype(np.int64)
    y_pred = np.asarray(label(y_pred_in > 0.5)).ravel().astype(np.int64)

    # Number of labels including background 0
    true_objects = int(y_true.max()) + 1
    pred_objects = int(y_pred.max()) + 1

    # intersection[i, j] = number of pixels labelled i in y_true and j in y_pred
    intersection = np.bincount(y_true * pred_objects + y_pred,
                               minlength=true_objects * pred_objects).reshape(true_objects, pred_objects)

    # Compute areas (needed for finding the union between all objects)
    area_true = np.bincount(y_true, minlength=true_objects)[:, np.newaxis]
    area_pred = np.bincount(y_pred, minlength=pred_objects)[np.newaxis, :]

    # Compute union (float, so the zero guard below is not truncated to 0)
    union = (area_true + area_pred - intersection).astype(np.float64)

    # Exclude background from the analysis
    intersection = intersection[1:, 1:]
    union = union[1:, 1:]
    union[union == 0] = 1e-9

    # Compute the intersection over union
    iou = intersection / union

    # Precision helper function
    def precision_at(threshold, iou):
        matches = iou > threshold
        true_positives = np.sum(matches, axis=1) == 1  # true objects matched by a prediction
        false_positives = np.sum(matches, axis=0) == 0  # predicted objects matching nothing (extra objects)
        false_negatives = np.sum(matches, axis=1) == 0  # true objects left unmatched (missed objects)
        tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
        return tp, fp, fn

    # Loop over IoU thresholds
    prec = []
    if print_table:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t, iou)
        if (tp + fp + fn) > 0:
            p = tp / (tp + fp + fn)
        else:
            # Only reachable when neither mask has any object: the empty prediction is correct.
            p = 1.0
        if print_table:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
        prec.append(p)

    if print_table:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))
    return np.mean(prec)


def iou_metric_batch(y_true, y_pred):
    """
    Average ``iou_metric`` over the leading (batch) axis.

    :param y_true: batch of ground-truth masks
    :param y_pred: batch of predictions, same leading size as y_true
    :return: 0-d float32 NumPy array holding the batch mean
    """
    scores = [iou_metric(y_true[idx], y_pred[idx]) for idx in range(y_true.shape[0])]
    return np.array(np.mean(scores), dtype=np.float32)


def mean_iou_metric(y_true, y_pred):
    """
    Graph-compatible wrapper that runs ``iou_metric_batch`` through ``tf.compat.v1.py_func``.

    :return: float32 tensor with the batch-mean IoU score
    NOTE(review): py_func ops are not differentiable in TensorFlow, so gradients
    presumably do not flow through this value; confirm before relying on it for training.
    """
    return tf.compat.v1.py_func(iou_metric_batch, [y_true, y_pred], tf.float32)


def mean_iou_metric_loss(y_true, y_pred):
    """
    Loss term equal to ``1 - mean_iou_metric``.

    ``mean_iou_metric`` comes from a ``py_func``, so TensorFlow cannot infer its
    static shape. The NumPy side (``iou_metric_batch``) returns a 0-d array, so the
    shape is declared as a scalar. The former ``(None,)`` claimed a rank-1 tensor,
    which contradicts the runtime value and would be rejected for eager tensors.

    :return: scalar tensor ``1 - mean IoU``
    """
    loss = 1 - mean_iou_metric(y_true, y_pred)
    loss.set_shape(())
    return loss


def dice_coef(y_true, y_pred, smooth=1e-6):
    """
    Soft Dice coefficient over the whole batch.

    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         = (2*sum(|A*B|) + smooth) / (sum(A^2) + sum(B^2) + smooth)
    ref: https://arxiv.org/pdf/1606.04797v1.pdf

    The sums run over every element, so the result is a scalar. It still broadcasts
    against per-pixel loss terms. The previous version reduced only over
    ``axis=-1``. The models here have a single output channel (num_classes=1), so
    that computed an independent Dice ratio per pixel instead of an overlap measure.

    :param y_true: ground-truth mask tensor
    :param y_pred: predicted probabilities, same shape as y_true
    :param smooth: small constant that keeps the ratio defined when both masks are empty
    :return: scalar Dice coefficient
    """
    intersection = K.sum(K.abs(y_true * y_pred))
    denominator = K.sum(K.square(y_true)) + K.sum(K.square(y_pred))
    return (2.0 * intersection + smooth) / (denominator + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus ``dice_coef(y_true, y_pred)``."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient


def _BCE_Dice_mIoU_Loss(y_true, y_pred):
    """
    Combined segmentation loss:
    0.4 * binary crossentropy + 0.2 * Dice loss + 0.4 * (1 - mean IoU).

    The IoU term follows the Kaggle Data Science Bowl 2018 (KDSB18) object-level
    IoU scoring.
    NOTE(review): the IoU term is computed through a py_func, which TensorFlow
    treats as non-differentiable. It presumably shifts the reported loss without
    contributing gradients; confirm.

    :param y_true: the ground truth
    :param y_pred: the predicted
    :return: the weighted loss value
    """
    bce_term = binary_crossentropy(y_true, y_pred)
    dice_term = dice_coef_loss(y_true, y_pred)
    iou_term = mean_iou_metric_loss(y_true, y_pred)
    return 0.4 * bce_term + 0.2 * dice_term + 0.4 * iou_term


def BCE_Dice_mIoU_Loss():
    """Return the combined BCE + Dice + mIoU loss callable, for use in ``model.compile``."""
    loss_fn = _BCE_Dice_mIoU_Loss
    return loss_fn

In [None]:
# Train deep neural network model
import os
import pickle
import random

from datetime import datetime
from keras.preprocessing.image import *

import cv2
import numpy as np

def _flow_tiles(directory, subdir, seed):
    """Stream grayscale TILE_SIZE x TILE_SIZE tiles from directory/subdir, rescaled to [0, 1]."""
    datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    return datagen.flow_from_directory(
        directory,
        target_size=(TILE_SIZE, TILE_SIZE),
        color_mode="grayscale",
        classes=[subdir],
        class_mode=None,
        batch_size=BATCH_SIZE,
        shuffle=True,
        seed=seed,
    )


for _ in range(1):  # a single run; enlarge the range to repeat training with fresh random seeds
    K.clear_session()

    # Resolve the configured names up front. An unknown name now fails immediately
    # with a clear error instead of printing a warning and crashing later on an
    # undefined or uncompiled `model`.
    model_builders = {
        "UNet": UNet,
        "UNet1Plus_w_DeepSupv": UNet1Plus_w_DeepSupv,
        "UNet1Plus_wo_DeepSupv": UNet1Plus_wo_DeepSupv,
        "UNet2Plus_w_DeepSupv": UNet2Plus_w_DeepSupv,
        "UNet2Plus_wo_DeepSupv": UNet2Plus_wo_DeepSupv,
        "UNet3Plus_w_DeepSupv": UNet3Plus_w_DeepSupv,
        "UNet3Plus_wo_DeepSupv": UNet3Plus_wo_DeepSupv,
        "UNet4Plus_w_DeepSupv": UNet4Plus_w_DeepSupv,
        "UNet4Plus_wo_DeepSupv": UNet4Plus_wo_DeepSupv,
    }
    loss_factories = {
        "BCE_Dice_mIoU": BCE_Dice_mIoU_Loss,
    }
    if NETWORK_MODEL not in model_builders:
        raise ValueError("Oops! Network Model Should Be Something Else! Got {!r}, expected one of {}".format(
            NETWORK_MODEL, sorted(model_builders)))
    if LOSS_FUNCTION not in loss_factories:
        raise ValueError("Oops! Loss Function Should Be Something Else! Got {!r}, expected one of {}".format(
            LOSS_FUNCTION, sorted(loss_factories)))

    data_root = "../input/dibco-training-set-128x128/dibco_train_0.8_val_0.2_seed_8281"
    print("Training {} dataset, using {} model with {} loss ...".format(DATASET_NAME, NETWORK_MODEL, LOSS_FUNCTION))

    # "dibco_train_0.8_val_0.2_seed_8281" -> "train_0.8_val_0.2".
    # Parsed from the directory name instead of a fixed 17-character slice.
    split_dir_name = os.path.basename(os.path.normpath(data_root))
    split_start = split_dir_name.index("train_")
    split_end = split_dir_name.find("_seed", split_start)
    if split_end == -1:
        split_end = len(split_dir_name)
    train_val_split_str = split_dir_name[split_start:split_end]

    data_train_dir = os.path.join(data_root, "train")
    data_val_dir = os.path.join(data_root, "val")

    random.seed()
    seed1 = round(random.random() * 10000)
    seed2 = round(random.random() * 10000)
    print("Random seed1: {}, and random seed2: {}".format(seed1, seed2))

    # Images and masks share a seed so both iterators shuffle identically. This
    # pairs each image batch with its mask batch, presuming images/ and labels/
    # hold matching file names (not verified here).
    train_img_generator = _flow_tiles(data_train_dir, "images", seed1)
    train_msk_generator = _flow_tiles(data_train_dir, "labels", seed1)
    val_img_generator = _flow_tiles(data_val_dir, "images", seed2)
    val_msk_generator = _flow_tiles(data_val_dir, "labels", seed2)

    train_generator = zip(train_img_generator, train_msk_generator)
    val_generator = zip(val_img_generator, val_msk_generator)

    model = model_builders[NETWORK_MODEL](num_classes=1, input_height=TILE_SIZE, input_width=TILE_SIZE, num_filters=NUM_FILTERS)
    model.compile(optimizer="Adam", loss=loss_factories[LOSS_FUNCTION](), metrics=["accuracy"])

    # model.load_weights("...")
    # model.summary()

    run_name = "-".join([DATASET_NAME.lower(), NETWORK_MODEL.lower(), LOSS_FUNCTION.lower(), train_val_split_str])
    model_weights_root = "./weights-" + run_name + "-" + str(datetime.timestamp(datetime.now()))
    os.makedirs(model_weights_root, exist_ok=True)

    # Shared prefix of the checkpoint and history file names
    file_prefix = os.path.join(model_weights_root,
                               run_name + "-ps_{0}x{0}-ch_{1}-bs_{2}".format(TILE_SIZE, NUM_FILTERS, BATCH_SIZE))

    # The {val_loss}/{val_accuracy} placeholders are filled in by ModelCheckpoint,
    # so they are appended literally rather than passed through str.format.
    check_point = ModelCheckpoint(file_prefix + "-val_loss_{val_loss}-val_accuracy_{val_accuracy}.hdf5",
                                  monitor="val_loss",
                                  verbose=1,
                                  save_best_only=True,
                                  save_weights_only=True,
                                  mode="auto")

    reduce_lr = ReduceLROnPlateau(monitor="val_loss",
                                  factor=0.5,
                                  patience=10,
                                  verbose=1,
                                  mode="auto")

    early_stop = EarlyStopping(monitor="val_loss",
                               patience=15,
                               verbose=1,
                               mode="auto")

    print("Now start training the network model...")
    # NOTE(review): fit_generator is deprecated in TF 2.x Keras, where model.fit
    # accepts generators directly; switch once the installed Keras version is confirmed.
    history = model.fit_generator(train_generator,
                                  epochs=NUM_EPOCHS,
                                  verbose=1,
                                  steps_per_epoch=len(train_img_generator),
                                  validation_data=val_generator,
                                  validation_steps=len(val_img_generator),
                                  callbacks=[check_point, reduce_lr, early_stop])

    history_path = file_prefix + "-" + str(datetime.timestamp(datetime.now())) + ".history"
    with open(history_path, "wb") as history_file:
        pickle.dump(history.history, history_file)

print("Finished!")