In [None]:
# coding: utf-8
import os

import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from keras_unet.models import custom_unet
import keras.backend as K
import matplotlib.pyplot as plt

# Reproducibility: SEED is passed to the data generators (shuffle order) and
# to TensorFlow's global RNG so TF-level randomness (e.g. weight
# initialisation, dropout) is repeatable across runs.
SEED = 909
tf.random.set_seed(SEED)

# Batch size per split; each is paired with the matching NUM_* below to
# derive steps per epoch.
BATCH_SIZE_TRAIN = 4
BATCH_SIZE_VAL = 4
BATCH_SIZE_TEST = 4

# Spatial size the data generators resize every image/mask to (target_size).
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512
IMG_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

# Number of samples in each split, used only to compute steps per epoch.
# NOTE(review): hard-coded — confirm these match the file counts in
# train-output/, val-output/ and test-output/.
NUM_TRAIN = 6651
NUM_VAL = 932
NUM_TEST = 1950

NUM_EPOCHS = 20

In [None]:
# U-Net taking single-channel (grayscale) input and producing one sigmoid
# mask channel. The spatial input size comes from IMG_SIZE so the model and
# the data generators (which resize to IMG_SIZE) always agree.
model = custom_unet(
    input_shape=IMG_SIZE + (1,),
    use_batch_norm=True,
    num_classes=1,
    filters=64,
    dropout=0.25,
    output_activation='sigmoid')

In [None]:
def dice_loss(targets, inputs, smooth=1e-6):
    """Soft Dice loss over the whole batch (all tensors are flattened to 1-D).

    Args:
        targets: ground-truth mask tensor (Keras passes ``y_true`` first).
        inputs: predicted mask tensor (``y_pred``); presumably sigmoid
            probabilities in [0, 1] given the model's output activation.
        smooth: small constant added to numerator and denominator so the
            ratio is defined (and equals 1) when both masks are empty.

    Returns:
        Scalar tensor ``1 - dice``; 0 means perfect overlap.
    """
    y_true_flat = K.flatten(targets)
    y_pred_flat = K.flatten(inputs)

    overlap = K.sum(y_true_flat * y_pred_flat)
    total_mass = K.sum(y_true_flat) + K.sum(y_pred_flat)
    return 1 - (2 * overlap + smooth) / (total_mass + smooth)

In [None]:
# Steps per epoch for each split. Floor division, so a trailing partial
# batch is not counted as a step.
EPOCH_STEP_TRAIN = NUM_TRAIN // BATCH_SIZE_TRAIN
EPOCH_STEP_VAL = NUM_VAL // BATCH_SIZE_VAL
EPOCH_STEP_TEST = NUM_TEST // BATCH_SIZE_TEST

# Dice loss is optimised and also reported as a metric, alongside
# thresholded pixel-wise precision and recall.
segmentation_metrics = [
    dice_loss,
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
]
model.compile(
    optimizer='adam',
    loss=dice_loss,
    metrics=segmentation_metrics,
    run_eagerly=False,
)

In [None]:
def create_img_generator(img_path, mask_path, batch_size):
    """Build a lazy iterator of ``(image_batch, mask_batch)`` pairs.

    Images and masks are read from two separate directories with identical
    settings: grayscale, resized to ``IMG_SIZE``, scaled by 1/255, no class
    labels, and the same ``SEED``. Sharing the seed keeps both shuffle
    orders in lockstep.
    NOTE(review): correct pairing also assumes matching file names/order in
    the two directory trees, and Keras' flow_from_directory reads images from
    subdirectories of the given path — confirm the on-disk layout.

    Args:
        img_path: directory containing the input images.
        mask_path: directory containing the corresponding masks.
        batch_size: number of samples per yielded batch.

    Returns:
        A ``zip`` object pairing the image and mask batches.
    """
    rescale_args = dict(rescale=1./255)
    image_gen = ImageDataGenerator(**rescale_args)
    label_gen = ImageDataGenerator(**rescale_args)

    image_flow = image_gen.flow_from_directory(
        img_path,
        target_size=IMG_SIZE,
        class_mode=None,
        color_mode='grayscale',
        batch_size=batch_size,
        seed=SEED,
    )
    label_flow = label_gen.flow_from_directory(
        mask_path,
        target_size=IMG_SIZE,
        class_mode=None,
        color_mode='grayscale',
        batch_size=batch_size,
        seed=SEED,
    )
    return zip(image_flow, label_flow)

In [None]:
# Each split directory holds parallel 'images' and 'masks' folders.
train_img_path = os.path.join('train-output', 'images')
train_mask_path = os.path.join('train-output', 'masks')

val_img_path = os.path.join('val-output', 'images')
val_mask_path = os.path.join('val-output', 'masks')

train_generator = create_img_generator(train_img_path, train_mask_path, BATCH_SIZE_TRAIN)
# Use the validation batch size so the generator agrees with EPOCH_STEP_VAL,
# which is computed from BATCH_SIZE_VAL.
val_generator = create_img_generator(val_img_path, val_mask_path, BATCH_SIZE_VAL)

In [None]:
# Train the model. Keep the returned History so per-epoch training and
# validation loss/metric curves (history.history) can be inspected or plotted.
history = model.fit(train_generator,
    steps_per_epoch=EPOCH_STEP_TRAIN,
    validation_data=val_generator,
    validation_steps=EPOCH_STEP_VAL,
    epochs=NUM_EPOCHS)

In [None]:
# Persist the trained model to an .h5 file.
# NOTE(review): the loss/metric is the custom `dice_loss` function, so
# reloading will likely need custom_objects={'dice_loss': dice_loss}.
MODEL_PATH = 'ctg-segmentation-model.h5'
model.save(MODEL_PATH)

In [None]:
def display(display_list):
    """Plot images side by side in one row, titled Input / True Mask / Predicted Mask.

    Args:
        display_list: sequence of up to three image arrays; more than three
            raises IndexError because only three titles exist.

    NOTE(review): this name shadows IPython's built-in ``display`` in the
    notebook namespace.
    """
    panel_titles = ['Input', 'True Mask', 'Predicted Mask']
    plt.figure(figsize=(15,15))
    n_panels = len(display_list)
    for idx, panel in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(panel_titles[idx])
        as_image = tf.keras.preprocessing.image.array_to_img(panel)
        plt.imshow(as_image, cmap='gray')
    plt.show()

def show_prediction(datagen, num=1, threshold=0.5):
    """Show input, ground-truth mask and thresholded prediction for `num` batches.

    Only the first sample of each drawn batch is shown. Uses the
    notebook-global ``model``.

    Args:
        datagen: iterator yielding ``(image_batch, mask_batch)`` pairs, such
            as the one returned by ``create_img_generator``.
        num: number of batches to draw (one figure each).
        threshold: cut-off applied to the model's sigmoid output to produce a
            binary mask; the default 0.5 preserves the previous behaviour.
    """
    for _ in range(num):
        image, mask = next(datagen)
        # Predict only the sample being displayed rather than the whole batch.
        pred_mask = model.predict(image[:1])[0] > threshold
        display([image[0], mask[0], pred_mask])

In [None]:
test_img_path = os.path.join('test-output', 'images')
test_mask_path = os.path.join('test-output', 'masks')

# Use the test batch size so this generator agrees with EPOCH_STEP_TEST,
# which is computed from BATCH_SIZE_TEST.
test_generator = create_img_generator(test_img_path, test_mask_path, BATCH_SIZE_TEST)

# Visual sanity check on three test batches (first sample of each).
show_prediction(test_generator, 3)