In [1]:
import os
import numpy as np
from tensorflow.keras.models import Model

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout,Conv2DTranspose,concatenate,Cropping2D, ReLU, BatchNormalization
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras import backend as keras
import tensorflow as tf

import matplotlib.pyplot as plt

In [2]:
# Default side length (in pixels) of the square inputs: used as the default
# resize target of the training loaders and as the model's default input shape.
size = 640

In [3]:
def load_image(img_path, size=(size, size)):
    """Load an image file as a standardized single-channel float32 tensor.

    The file is decoded with ``tf.image.decode_png``, reduced to its first
    channel, optionally resized, and then normalized per image to zero mean
    and unit standard deviation.

    Args:
        img_path: Path (string tensor) of the encoded image file.
        size: ``(height, width)`` handed to ``tf.image.resize``; a falsy
            value (e.g. ``None``) skips resizing.

    Returns:
        float32 tensor of shape ``(H, W, 1)`` holding ``(img - mean) / std``.
    """
    raw = tf.io.read_file(img_path)
    img = tf.image.decode_png(raw)[:, :, :1]  # keep the first channel only
    # Cast before taking statistics: when resizing is skipped the decoded
    # tensor is still an integer type, which is not a valid input for the
    # mean/std computation below.
    img = tf.cast(img, 'float32')
    if size:
        img = tf.image.resize(img, size)
    mean_ = tf.math.reduce_mean(img)
    std_ = tf.math.reduce_std(img)
    # A constant image has std == 0; clamp the divisor so the result is
    # all zeros instead of NaN/Inf.
    return (img - mean_) / tf.maximum(std_, 1e-7)

def load_label(img_path, size=(size, size)):
    """Load a mask file as a binary single-channel int64 tensor.

    Args:
        img_path: Path (string tensor) of the encoded mask file.
        size: ``(height, width)`` handed to ``tf.image.resize``; a falsy
            value skips resizing.

    Returns:
        int64 tensor that is 1 where the (optionally resized) first channel
        is greater than 10 and 0 elsewhere.
    """
    mask = tf.image.decode_png(tf.io.read_file(img_path))
    mask = mask[:, :, :1]  # first channel only
    resized = tf.image.resize(mask, size) if size else mask
    return tf.cast(resized > 10, 'int64')


def load_val_image(img_path, size=(1280, 1280)):
    """Load a validation image as a single-channel float32 tensor.

    Unlike per-image standardization, the pixels are shifted and scaled by
    the fixed constants 128 and 27.

    Args:
        img_path: Path (string tensor) of the encoded image file.
        size: ``(height, width)`` handed to ``tf.image.resize``; a falsy
            value skips resizing.

    Returns:
        float32 tensor holding ``(pixels - 128) / 27`` for the first channel.
    """
    encoded = tf.io.read_file(img_path)
    first_channel = tf.image.decode_png(encoded)[:, :, :1]
    pixels = tf.image.resize(first_channel, size) if size else first_channel
    pixels = tf.cast(pixels, 'float32')
    return (pixels - 128.) / 27.
    

def load_val_label(img_path, size=None):
    """Load a validation mask as a binary int64 tensor.

    All decoded channels are kept (no channel slicing is applied here).

    Args:
        img_path: Path (string tensor) of the encoded mask file.
        size: Optional ``(height, width)`` handed to ``tf.image.resize``;
            a falsy value (the default) leaves the mask at its stored size.

    Returns:
        int64 tensor that is 1 where the pixel value exceeds 10, else 0.
    """
    decoded = tf.image.decode_png(tf.io.read_file(img_path))
    if not size:
        return tf.cast(decoded > 10, 'int64')
    return tf.cast(tf.image.resize(decoded, size) > 10, 'int64')


def _get_dataset(file_pattern, load_file):
    """Build a dataset that applies ``load_file`` to every matching file.

    Args:
        file_pattern: Glob pattern passed to ``tf.data.Dataset.list_files``.
        load_file: Callable mapping a file path tensor to a loaded sample.

    Returns:
        A ``tf.data.Dataset`` of loaded samples; files are listed with
        ``shuffle=False``, i.e. in a deterministic order.
    """
    paths = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    return paths.map(load_file)

def sample_division(image, label, threshold=10):
    """Classify a sample by how much foreground its label contains.

    Args:
        image: Unused; present so the function accepts (image, label) pairs.
        label: Label tensor whose elements are summed.
        threshold: Sum that must be exceeded for the sample to count as 1.

    Returns:
        int32 scalar tensor: 1 if ``sum(label) > threshold``, otherwise 0.
    """
    is_positive = tf.reduce_sum(label) > threshold
    return tf.cast(is_positive, 'int32')

In [4]:
def get_test_ds(file_pattern, load_file):
    """Build parallel datasets of file paths and the samples loaded from them.

    Args:
        file_pattern: Glob pattern passed to ``tf.data.Dataset.list_files``.
        load_file: Callable mapping a file path tensor to a loaded sample.

    Returns:
        Tuple ``(paths, samples)``: the dataset of file paths (listed with
        ``shuffle=False``) and the dataset obtained by mapping ``load_file``
        over it.
    """
    paths = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    return paths, paths.map(load_file)

In [5]:
def get_dataset(image_dir, label_dir, batch_size, shuffle_size=10):
    """Build a shuffled, batched dataset of (image, label) pairs.

    Images are the ``*.jpg`` files under ``image_dir`` and labels the
    ``*.png`` files under ``label_dir``. Both listings are unshuffled and
    are zipped element by element.

    Args:
        image_dir: Prefix for the image glob; the pattern is built by plain
            string concatenation, so it must end with a path separator.
        label_dir: Prefix for the label glob (same convention).
        batch_size: Number of pairs per batch.
        shuffle_size: Buffer size used when shuffling the pairs.

    Returns:
        A prefetched ``tf.data.Dataset`` yielding batches of
        ``(image, label)``.
    """
    images = _get_dataset(image_dir+'*.jpg', load_image)
    masks = _get_dataset(label_dir+'*.png', load_label)
    pairs = tf.data.Dataset.zip((images, masks))
    pairs = pairs.shuffle(buffer_size=shuffle_size)
    pairs = pairs.batch(batch_size)
    return pairs.prefetch(tf.data.experimental.AUTOTUNE)


def _encoder_conv(x, filters, kernel_size, padding, dilation_rate):
    """Apply Conv2D -> BatchNormalization -> ReLU (the encoder's basic unit)."""
    x = Conv2D(filters, kernel_size, padding=padding, dilation_rate=dilation_rate)(x)
    x = BatchNormalization()(x)
    return ReLU()(x)


def _decoder_conv(x, filters, padding, dilation_rate=1):
    """Apply a 3x3 ReLU-activated Conv2D followed by BatchNormalization."""
    x = Conv2D(filters, 3, activation='relu', padding=padding, dilation_rate=dilation_rate)(x)
    return BatchNormalization()(x)


def _upsample(x, filters):
    """Double the spatial resolution with a 3x3 stride-2 transposed convolution."""
    return Conv2DTranspose(filters, 3, activation='relu', padding='same', strides=(2, 2))(x)


def unet(pretrained_weights=None, input_size=(size, size, 1), padding='same', base=24):
    """Build a U-Net style segmentation model with dilated convolutions.

    The encoder has five stages of two Conv-BN-ReLU units each (``base`` to
    ``base*16`` filters), separated by 2x2 max pooling. The decoder has four
    stages that upsample with a transposed convolution, concatenate the
    matching encoder feature map and apply Conv(ReLU)-BN units. The last
    layer is a 1x1 convolution with sigmoid activation, producing one output
    channel.

    Args:
        pretrained_weights: Optional path to a weights file; when given, it
            is loaded into the model with ``model.load_weights``.
        input_size: Shape ``(height, width, channels)`` of the model input.
        padding: Padding mode for the Conv2D layers. NOTE(review): the skip
            concatenations are not cropped, so values other than ``'same'``
            will likely produce mismatched shapes - confirm before use.
        base: Number of filters in the first stage; deeper stages use
            multiples of it.

    Returns:
        An uncompiled ``tensorflow.keras.models.Model``.
    """
    inputs = Input(input_size)

    # ---- Encoder -------------------------------------------------------
    conv1 = _encoder_conv(inputs, base*1, 5, padding, 1)
    conv1 = _encoder_conv(conv1, base*1, 3, padding, 1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _encoder_conv(pool1, base*2, 3, padding, 3)
    conv2 = _encoder_conv(conv2, base*2, 3, padding, 3)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    pool2 = Dropout(0.5)(pool2)

    conv3 = _encoder_conv(pool2, base*4, 3, padding, 2)
    conv3 = _encoder_conv(conv3, base*4, 3, padding, 2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    pool3 = Dropout(0.5)(pool3)

    conv4 = _encoder_conv(pool3, base*8, 3, padding, 2)
    conv4 = _encoder_conv(conv4, base*8, 3, padding, 2)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # ---- Bottleneck ----------------------------------------------------
    conv5 = _encoder_conv(pool4, base*16, 3, padding, 2)
    conv5 = _encoder_conv(conv5, base*16, 3, padding, 1)
    drop5 = Dropout(0.5)(conv5)

    # ---- Decoder -------------------------------------------------------
    # Each stage's first BatchNormalization now follows the stage's first
    # convolution. Previously it was applied to the concatenated tensor,
    # which silently dropped that convolution from the graph. Weights saved
    # from the previous graph therefore do not match this layer set.
    # Dropout layers that were created on the decoder outputs but never
    # connected to the next stage have been removed; the data flow between
    # stages is unchanged.
    up6 = _upsample(drop5, base*16)
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = _decoder_conv(merge6, base*8, padding, 1)
    conv6 = _decoder_conv(conv6, base*8, padding, 2)

    up7 = _upsample(conv6, base*4)
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = _decoder_conv(merge7, base*4, padding, 1)
    conv7 = _decoder_conv(conv7, base*4, padding, 2)

    up8 = _upsample(conv7, base*2)
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = _decoder_conv(merge8, base*2, padding, 1)
    conv8 = _decoder_conv(conv8, base*2, padding, 1)

    up9 = _upsample(conv8, base*1)
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = _decoder_conv(merge9, base*1, padding, 1)
    conv9 = _decoder_conv(conv9, base*1, padding, 1)
    conv9 = _decoder_conv(conv9, base*1, padding, 1)
    # Per-pixel foreground probability.
    conv9 = Conv2D(1, 1, activation='sigmoid', padding=padding,)(conv9)

    model = Model(inputs=inputs, outputs=conv9)

    # Honour the previously ignored argument.
    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

In [6]:
# Number of (image, mask) pairs per batch.
batch_size = 4

# Training split: input images and their segmentation masks.
image_dir = 'my_data/train_img/'
label_dir = 'my_data/train_mask/'

# Validation split, same layout as the training split.
val_image_dir = 'my_data/valid_img/'
val_label_dir = 'my_data/valid_mask/'

In [7]:
# Both splits go through the same loading/batching pipeline.
train_dataset = get_dataset(
    image_dir=image_dir,
    label_dir=label_dir,
    batch_size=batch_size,
)
valid_dataset = get_dataset(
    image_dir=val_image_dir,
    label_dir=val_label_dir,
    batch_size=batch_size,
)

In [8]:
# Build the segmentation network with unet()'s default arguments.
model = unet()

In [9]:
# Peak learning rate; the schedule below expresses its stages relative to it.
ilearning_rate = 1e-2

def scheduler(epoch, lr):
    """Return the learning rate to use for ``epoch``.

    Schedule:
        * epochs 0-1:   warm-up at ``ilearning_rate * 1e-2``
        * epochs 2-9:   ``ilearning_rate``
        * epochs 10-49: ``ilearning_rate * 0.1``
        * epoch 50+:    exponential decay, multiplying the current rate by
          ``exp(-0.1)`` every epoch

    Args:
        epoch: Zero-based epoch index.
        lr: Learning rate currently in use (only read in the decay phase).

    Returns:
        The new learning rate as a plain Python ``float``.
    """
    if epoch < 2:
        return ilearning_rate*1e-2
    elif epoch < 10:
        return ilearning_rate
    elif epoch < 50:
        return ilearning_rate * 0.1
    else:
        # NumPy instead of tf.math.exp: the schedule must return a float,
        # and a TensorFlow op here would return a tensor instead.
        return float(lr * np.exp(-0.1))

# Wrap the schedule function in a Keras LearningRateScheduler callback.
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

In [10]:
# Configure training: SGD with momentum, mean-absolute-error loss, accuracy metric.
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# versions no longer accept.
model.compile(
    optimizer=SGD(learning_rate=ilearning_rate, momentum=0.95),
    loss=tf.keras.losses.mean_absolute_error,
    metrics=['accuracy'],
)

In [2]:
# Train for 100 epochs with the learning-rate schedule, validating each epoch.
history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=100,
    callbacks=[callback],
    verbose=1,
)
# Final evaluation on the validation split (displayed as the cell output).
model.evaluate(valid_dataset, verbose=1)

In [None]:
# model.save('pth/seg_unet_09801.h5')