In [None]:
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, UpSampling2D, Activation, Input, Concatenate, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import Sequential
import numpy as np
from matplotlib import pyplot
from tensorflow.keras import mixed_precision
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.models import Model
import tensorflow as tf
import os
import glob
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from keras.layers.core import Lambda

In [None]:
# Choose the batch size per device. On a GPU, also switch Keras to the
# 'mixed_float16' policy (float16 compute, float32 variables).
if not tf.config.list_physical_devices('GPU'):
    print('The model will run on a CPU')
    batch_size = 8
else:
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
    batch_size = 16
    print(f'Compute dtype: {policy.compute_dtype}')
    print(f'Variable dtype: {policy.variable_dtype}')

In [None]:
# Network geometry: the U-Net below uses unpadded 3x3 convolutions, so a
# 572x572 input produces a 388x388 output mask.
im_width, im_height = 572, 572
output_mask_width, output_mask_height = 388, 388

image_chanels = 3  # RGB input images
n_classes = 1      # one sigmoid output channel (binary segmentation)
seed = 42          # shared by the image and mask generators so they stay paired

data_path = 'yourdatapath'

images_path = f'{data_path}/train/images/img/'
masks_path = f'{data_path}/train/masks/img/'

val_images_path = f'{data_path}/val/images/img/'
val_masks_path = f'{data_path}/val/masks/img/'

In [None]:
def _list_png_files(directory):
    """Return the sorted full paths of the visible ``.png`` files in ``directory``.

    Hidden files (names starting with '.', e.g. macOS ``._name.png`` resource
    forks) are skipped for images and masks alike. The original cell only
    filtered them out of the mask lists, which could leave the image and mask
    lists with different lengths and misaligned pairs.
    """
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".png") and not fname.startswith(".")
    )


input_img_paths = _list_png_files(images_path)
target_img_paths = _list_png_files(masks_path)

val_input_img_paths = _list_png_files(val_images_path)
val_target_img_paths = _list_png_files(val_masks_path)

In [None]:
# Sanity check: each split should list as many masks as images.
print(f'Train images: {len(input_img_paths)}')
print(f'Train masks: {len(target_img_paths)}')

print(f'Val images: {len(val_input_img_paths)}')
print(f'Val masks: {len(val_target_img_paths)}')

In [None]:
# Training images and masks get identical augmentation; validation is only rescaled.
data_gen_args = dict(rescale=1. / 255,
                     rotation_range=10,
                     shear_range=0.2,
                     width_shift_range=0.2,
                     height_shift_range=0.2,
                     zoom_range=0.2,
                     horizontal_flip=True)

val_data_gen_args = dict(rescale=1. / 255)

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

val_image_datagen = ImageDataGenerator(**val_data_gen_args)
val_mask_datagen = ImageDataGenerator(**val_data_gen_args)


def _flow(datagen, subdir, color_mode, target_size):
    """Stream unlabeled batches from ``<data_path>/<subdir>``.

    flow_from_directory expects the files inside a sub-folder of ``subdir``
    (here ``img/``). The directory is built from ``data_path``; the original
    cell hard-coded 'yourdatapath/...', so editing ``data_path`` had no effect.
    Every generator shares ``seed``, which keeps the image and mask streams'
    shuffling and random augmentations in lockstep for the ``zip`` below.
    """
    return datagen.flow_from_directory(
        data_path + '/' + subdir,
        class_mode=None,
        color_mode=color_mode,
        batch_size=batch_size,
        target_size=target_size,
        seed=seed)


# NOTE(review): masks are resized (not center-cropped) to the output size. With
# the unpadded convolutions in `unet`, the output geometrically corresponds to
# the central region of the input -- confirm resizing is the intended target.
image_generator = _flow(image_datagen, 'train/images/', 'rgb',
                        (im_height, im_width))
mask_generator = _flow(mask_datagen, 'train/masks/', 'grayscale',
                       (output_mask_height, output_mask_width))

val_image_generator = _flow(val_image_datagen, 'val/images/', 'rgb',
                            (im_height, im_width))
val_mask_generator = _flow(val_mask_datagen, 'val/masks/', 'grayscale',
                           (output_mask_height, output_mask_width))

# Each step yields an (image_batch, mask_batch) pair.
train_generator = zip(image_generator, mask_generator)
val_generator = zip(val_image_generator, val_mask_generator)

In [None]:
def unet(input_shape=(im_height, im_width, image_chanels), f=64, steps=4, n_classes=n_classes):
    """Build a U-Net (Ronneberger et al., 2015) with unpadded 3x3 convolutions.

    Every 3x3 convolution is 'valid' (no padding), so the output map is smaller
    than the input: with the defaults a 572x572 input gives a 388x388 output.

    Args:
        input_shape: (height, width, channels) of the input; height and width
            must be static because crop sizes are computed when building.
        f: filters at the first level; doubled at each level down.
        steps: number of down/up-sampling levels; the bottleneck uses
            ``f * 2**steps`` filters.
        n_classes: channels of the final 1x1 sigmoid convolution.

    Returns:
        An uncompiled tf.keras ``Model``. The output layer is forced to
        float32 so it stays full precision under a mixed-precision policy.
    """

    def downstream(x, f):
        # Two valid 3x3 convs. `skip` feeds the matching up level; the pooled
        # tensor continues down the contracting path.
        x = Conv2D(f, 3, activation='relu')(x)
        skip = Conv2D(f, 3, activation='relu')(x)
        x = MaxPooling2D(2, strides=2, padding='same')(skip)
        return skip, x

    def crop_merge(x, d):
        # Center-crop the skip tensor `d` to x's spatial size, then concatenate
        # along channels. Any odd leftover pixel goes to the bottom/right, so
        # the crop is always exact. The old symmetric slice was one pixel too
        # large when the size difference was odd, breaking the concatenation.
        # Cropping2D is also serializable, unlike a closure-capturing Lambda.
        _, xh, xw, _ = K.int_shape(x)
        _, dh, dw, _ = K.int_shape(d)
        top, left = (dh - xh) // 2, (dw - xw) // 2
        d = tf.keras.layers.Cropping2D(
            cropping=((top, dh - xh - top), (left, dw - xw - left)))(d)
        return Concatenate()([d, x])

    def upstream(x, f, d):
        # 2x upsampling + 2x2 conv, merge with the cropped skip, two valid convs.
        x = UpSampling2D()(x)
        x = Conv2D(f, 2, padding='same')(x)
        x = crop_merge(x, d)
        x = Conv2D(f, 3, activation='relu')(x)
        x = Conv2D(f, 3, activation='relu')(x)
        return x

    inputs = Input(input_shape)
    x = inputs

    skips = []
    for i in range(steps + 1):
        d, x = downstream(x, f * 2**i)
        skips.append(d)
    # The deepest level is the bottleneck: its pre-pool features start the
    # expanding path, and its pooled output is not used.
    x = skips.pop()

    for i in range(steps - 1, -1, -1):
        x = upstream(x, f * 2**i, skips[i])

    output = Conv2D(n_classes, (1, 1), activation='sigmoid', dtype='float32')(x)
    return Model(inputs, output)

In [None]:
# Build the network with the default 64 base filters and 4 pooling levels.
model = unet(f=64, steps=4)

In [None]:
def iou_score(y_pred, y_true, smooth=1.):
    """Per-sample soft IoU (Jaccard index) between two batches of masks.

    Each sample is flattened with ``K.batch_flatten``, so intersection and
    union are summed over all of that sample's pixels and channels. The
    original summed over ``axis=-1`` only. For single-channel masks like this
    notebook's (grayscale masks, ``n_classes=1``) that axis has size 1, giving
    a per-pixel ratio rather than an IoU.

    The formula is symmetric in its two arguments, so the unconventional
    ``(y_pred, y_true)`` parameter order does not change the result.

    Args:
        y_pred: predicted masks, batch first.
        y_true: ground-truth masks with the same shape as ``y_pred``.
        smooth: added to numerator and denominator so empty masks give 1.

    Returns:
        A tensor with one IoU value per sample, shape (batch,).
    """
    y_true = K.batch_flatten(y_true)
    y_pred = K.batch_flatten(y_pred)
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    union = K.sum(y_true, -1) + K.sum(y_pred, -1) - intersection
    return (intersection + smooth) / (union + smooth)

def dice_coef(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient pooled over the whole batch.

    Both tensors are flattened, so one scalar is computed from every sample
    and pixel together:
        (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)
    ``smooth`` keeps the ratio defined (equal to 1) when both masks are empty.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    total = K.sum(truth) + K.sum(guess)
    return (2. * overlap + smooth) / (total + smooth)

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus ``dice_coef``, so lower is better."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
# Pixel-wise binary cross-entropy drives training; the Dice coefficient is
# tracked as a metric (logged as `dice_coef` / `val_dice_coef`).
model.compile(
    optimizer=Adam(),
    loss='binary_crossentropy',
    metrics=[dice_coef],
)

In [None]:
# Layer-by-layer architecture with output shapes and parameter counts.
model.summary()

In [None]:
# Optionally resume from an earlier run's best checkpoint. On a fresh run the
# file does not exist yet, and an unconditional load_weights would raise,
# breaking Restart & Run All.
if os.path.exists('bestresnetunet.h5'):
    model.load_weights('bestresnetunet.h5')
else:
    print('No checkpoint found at bestresnetunet.h5; starting from random initialization')

In [None]:
# All three callbacks watch validation Dice, where higher is better, so each
# gets mode='max' explicitly. In mode='auto', Keras infers the direction from
# the metric name and picks 'max' only for accuracy-style names. It would
# treat 'val_dice_coef' like a loss ('min'), so EarlyStopping would count
# rising Dice as "no improvement" and stop after 10 improving epochs.
callbacks = [
    EarlyStopping(patience=10, monitor='val_dice_coef', mode='max'),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.0001, monitor='val_dice_coef', mode='max', verbose=1),
    ModelCheckpoint('bestresnetunet.h5', save_best_only=True, save_weights_only=True, monitor='val_dice_coef', mode='max', verbose=1)
]

In [None]:
# The zipped generators yield (image_batch, mask_batch) pairs forever, so the
# number of steps per epoch must be given explicitly.
# - len() of a Keras directory iterator is its batch count per pass, rounded
#   up. It covers every file the iterator actually streams, and it cannot be 0
#   when there are fewer images than batch_size.
# - batch_size is not passed: for generator inputs Keras documents that it
#   must not be specified, because the generators already produce batches.
results = model.fit(train_generator,
                    steps_per_epoch=len(image_generator),
                    validation_data=val_generator,
                    validation_steps=len(val_image_generator),
                    epochs=500,
                    callbacks=callbacks)

In [None]:
# Restore the best weights saved by the ModelCheckpoint callback
# (save_best_only on val_dice_coef); the last epoch's weights may be worse.
model.load_weights('bestresnetunet.h5')

In [None]:
# Load one image/mask pair, preprocessed like the training generators:
# RGB scaled to [0, 1], with a leading batch axis.
im = cv2.imread('yourimagepath/images/image.png')
if im is None:  # cv2.imread signals a missing/unreadable file by returning None
    raise FileNotFoundError('Could not read yourimagepath/images/image.png')
# OpenCV decodes to BGR; the model was trained on RGB (color_mode='rgb').
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# cv2.resize takes dsize as (width, height).
im = cv2.resize(im, (im_width, im_height), interpolation=cv2.INTER_AREA)
img = np.asarray(im, dtype=np.float32) / 255
img = np.expand_dims(img, axis=0)

im_mask = cv2.imread('yourimagepath/masks/truemask.png')
if im_mask is None:
    raise FileNotFoundError('Could not read yourimagepath/masks/truemask.png')
im_mask = cv2.cvtColor(im_mask, cv2.COLOR_BGR2GRAY)
# Training masks were resized to the network's output size, so do the same
# here; this makes img_mask directly comparable with the prediction.
# Nearest-neighbour keeps the label values crisp.
im_mask = cv2.resize(im_mask, (output_mask_width, output_mask_height),
                     interpolation=cv2.INTER_NEAREST)
# Add the channel axis (H, W) -> (H, W, 1), then the batch axis.
img_mask = np.asarray(im_mask, dtype=np.float32)[..., np.newaxis] / 255
img_mask = np.expand_dims(img_mask, axis=0)

In [None]:
# Per-pixel probabilities for the single test image, then a hard mask at a 0.75 cut-off.
preds_val = model.predict(img, verbose=1)
preds_val_t = np.greater(preds_val, 0.75).astype(np.float32)

In [None]:
def plot_sample(X, y, preds, binary_preds, ix=None):
    """Show one sample in four panels: input, true mask, probabilities, binary prediction.

    Args:
        X: batch of input images indexed as ``X[ix, ..., channel]``; channel 0
            is shown with the 'jet' colormap.
        y: batch of ground-truth masks.
        preds: batch of predicted probabilities, drawn on a fixed 0-1 scale.
        binary_preds: batch of thresholded predictions.
        ix: index of the sample to show; a random valid index when None.

    If the true mask is non-empty, its 0.5 contour is drawn on panels 0, 2 and
    3. The contour is stretched to each panel's image size, so it lines up
    even when the mask and the image have different resolutions (e.g. a
    572 px input with a 388 px mask).
    """
    if ix is None:
        import random
        # randrange excludes len(X); random.randint(0, len(X)) could return an
        # out-of-range index.
        ix = random.randrange(len(X))

    true_mask = y[ix].squeeze()
    has_mask = true_mask.max() > 0

    def draw_outline(axis, image):
        # Contour the true mask in `image`'s pixel coordinates. With matching
        # sizes this equals the plain contour; otherwise the mask is stretched.
        if has_mask:
            height, width = image.shape[:2]
            axis.contour(true_mask, colors='k', levels=[0.5],
                         extent=(0, width - 1, 0, height - 1))

    _, ax = plt.subplots(1, 4, figsize=(20, 10))

    input_channel = X[ix, ..., 0]
    ax[0].imshow(input_channel, cmap='jet')
    draw_outline(ax[0], input_channel)
    ax[0].set_title('Jet')

    ax[1].imshow(true_mask)
    ax[1].set_title('Salt')

    prob_map = preds[ix].squeeze()
    ax[2].imshow(prob_map, vmin=0, vmax=1)
    draw_outline(ax[2], prob_map)
    ax[2].set_title('Salt Predicted')

    binary_map = binary_preds[ix].squeeze()
    ax[3].imshow(binary_map, vmin=0, vmax=1)
    draw_outline(ax[3], binary_map)
    ax[3].set_title('Salt Predicted binary')

In [None]:
# Visualize the single test sample (batch index 0).
plot_sample(X=img, y=img_mask, preds=preds_val, binary_preds=preds_val_t, ix=0)