# About 
This notebook demonstrates one way to set up a U-Net using the ImageNet-pretrained ResNet50 encoder from Keras. The `101 x 101` input images are first scaled to `224 x 224` by doubling their size (to `202 x 202`) and zero-padding the remaining border. All other details of training, augmentation, and prediction are kept simple; feel free to extend this notebook with additional features.

In [None]:
import cv2
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import glob
import gc
gc.enable() 

import matplotlib.pyplot as plt 
import os 
import time 

# This stops pandas from spitting 
# out warnings at us. 
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split

from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates
from skimage.io import imread
from skimage.transform import resize

import tensorflow as tf 

from keras.applications.resnet50 import ResNet50
from keras.preprocessing.image import load_img, ImageDataGenerator
from keras import Model
from keras.callbacks import Callback, EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.models import load_model
from keras.optimizers import Adam, SGD
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, SpatialDropout2D 
from keras.layers import UpSampling2D, Concatenate, Dropout, Lambda, BatchNormalization, Add, ZeroPadding2D
from keras.layers import concatenate
from keras.losses import binary_crossentropy
from keras import backend as K 

from tqdm import tqdm_notebook

%matplotlib inline

In [None]:
# Root directory of the competition data.
BASE_DIRECTORY = '../input/'

# Size of the images as stored on disk.
IMAGE_WIDTH = 101
IMAGE_HEIGHT = 101
IMAGE_CHANNELS = 3

# Input size of the U-Net built later (images are upscaled and padded to this).
RESIZED_WIDTH = 224
RESIZED_HEIGHT = 224

In [None]:
def upscale(input_image, resized_shape=(224,224,3)):
    '''Resize an image to 202x202 and centre it on a zero canvas.

    The resized image is written into rows/cols 11:213 of a zero array of
    shape `resized_shape`; everything outside that window stays 0.

    Args:
        input_image: 2-D (grayscale) or 3-D (height, width, channels) array.
        resized_shape: shape of the returned canvas. The paste window is fixed
            at 11:213, so this assumes a 224x224 canvas.

    Returns:
        Float array of shape `resized_shape`.
    '''
    canvas = np.zeros(shape=resized_shape)

    shape = input_image.shape
    if len(shape) == 3:
        n_channels = shape[-1]
    else:
        # Grayscale input: must be exactly 2-D (unpacking rejects other ranks).
        _rows, _cols = shape
        n_channels = 1

    target_shape = (202, 202, 3) if n_channels > 1 else (202, 202, 1)
    canvas[11:213, 11:213, :] = resize(input_image, target_shape)

    return canvas

def downscale(input_image):
    '''Invert `upscale`: crop the 202x202 window at rows/cols 11:213 and
    resize it back to 101x101, keeping the channel count.

    Args:
        input_image: 3-D array (height, width, channels), e.g. 224x224xC.

    Returns:
        Array of shape (101, 101, channels).
    '''
    _rows, _cols, n_channels = input_image.shape
    window = input_image[11:213, 11:213, :]
    return resize(window, (101, 101, n_channels))

In [None]:
def prepare_training_sample(path_to_train='../input/train/', bad_ids=None, sample_size=100):
    '''Load training images and masks, upscaled to 224x224 with `upscale`.

    Args:
        path_to_train: directory prefix (with trailing slash) that contains
            `images/` and `masks/` sub-directories of PNG files sharing ids.
        bad_ids: optional iterable of image ids to exclude from the sample.
        sample_size: maximum number of samples to load; clipped to the number
            of available (non-excluded) images.

    Returns:
        (x, y): arrays of shape (n, 224, 224, 3) and (n, 224, 224, 1) holding
        the images and masks with pixel values divided by 255.
    '''
    # Get list of images; masks are looked up by the same id.
    image_files = glob.glob(path_to_train + 'images/*.png')
    extract_id = lambda x: x.split('.png')[0].split('/')[-1]
    image_ids = [extract_id(file) for file in image_files]

    if bad_ids is not None:
        # Rebind image_ids itself so the exclusion actually takes effect.
        excluded = set(bad_ids)
        image_ids = [image_id for image_id in image_ids if image_id not in excluded]

    sample_size = min(sample_size, len(image_ids))

    x = np.zeros(shape=(sample_size, 224, 224, 3))
    y = np.zeros(shape=(sample_size, 224, 224, 1))

    for index, image_id in enumerate(image_ids[:sample_size]):
        x[index,:,:,:] = upscale(np.array(load_img(path_to_train + 'images/' + image_id + '.png')) / 255)
        y[index,:,:,:] = upscale(np.array(load_img(path_to_train + 'masks/'  + image_id + '.png', grayscale=True)) / 255, resized_shape=(224,224,1))

    return x, y

In [None]:
x, y = prepare_training_sample(sample_size=4001) # Load everything (there are 4000 images)

# Salt coverage per image, used to stratify the train/validation split.
# `upscale` pastes each mask into rows/cols 11:213 of the 224x224 canvas, so the
# fraction is taken over that window (a plain pixel sum is a count, not a fraction).
salt_fraction = np.mean(y[:, 11:213, 11:213, 0], axis=(1, 2))
# Ten coverage classes over [0, 1]: coverage 0 -> class 1, coverage 1 -> class 10.
salt_fraction = np.digitize(salt_fraction, np.linspace(0,1,10))

In [None]:
# Fixed seed so the stratified train/validation split is reproducible.
SPLIT_RANDOM_STATE = 42
x_train, x_valid, y_train, y_valid = train_test_split(
    x, y, test_size=0.1, stratify=salt_fraction, random_state=SPLIT_RANDOM_STATE)
# The split returns new arrays; drop the full ones to free memory before training.
del x, y
gc.collect()

In [None]:
# The metric function from: https://www.kaggle.com/shaojiaxin/u-net-with-simple-resnet-blocks-v2-new-loss
def get_iou_vector(A, B):
    '''Mean thresholded-IoU score over a batch of masks.

    For each item, pixels > 0 in `A` (truth) and `B` (prediction) are compared.
    IoU uses a 1e-10 smoothing term, so two empty masks get IoU 1. The item
    score is the fraction of thresholds np.arange(0.5, 1, 0.05) that the IoU
    exceeds; the batch score is the mean of the item scores.
    '''
    thresholds = np.arange(0.5, 1, 0.05)
    per_item = []
    for i in range(A.shape[0]):
        truth = A[i] > 0
        guess = B[i] > 0
        n_inter = np.count_nonzero(np.logical_and(truth, guess))
        n_union = np.count_nonzero(np.logical_or(truth, guess))
        iou = (n_inter + 1e-10) / (n_union + 1e-10)
        per_item.append(np.mean(iou > thresholds))
    return np.mean(per_item)

def my_iou_metric(label, pred):
    '''Keras metric: `get_iou_vector` with predictions binarized at 0.5.'''
    is_positive = pred > 0.5
    return tf.py_func(get_iou_vector, [label, is_positive], tf.float64)

def my_iou_metric_2(label, pred):
    '''Variant of `my_iou_metric` that binarizes predictions at 0
    (presumably for models that output logits rather than probabilities).'''
    is_positive = pred > 0
    return tf.py_func(get_iou_vector, [label, is_positive], tf.float64)

In [None]:
def _decoder_stage(x, skip, filters):
    '''One decoder stage: upsample `x` by 2, concatenate `skip` in front of it,
    then apply two 3x3 ReLU convolutions, each followed by batch normalization.'''
    x = UpSampling2D()(x)
    x = Concatenate()([skip, x])
    for _ in range(2):
        x = Conv2D(filters, (3,3), strides=(1,1), padding='same', activation='relu')(x)
        x = BatchNormalization()(x)
    return x


def _encoder_skip_outputs(resnet_base):
    '''Return, for each spatial resolution found among the encoder's Activation
    layers, the output of the last such layer, ordered from largest to smallest.

    `resnet_base.layers` is in topological order, so the last Activation at a
    resolution is the deepest one there (the output of that stage).
    '''
    last_by_height = {}
    for layer in resnet_base.layers:
        if type(layer).__name__ == 'Activation':
            last_by_height[K.int_shape(layer.output)[1]] = layer.output
    return [last_by_height[h] for h in sorted(last_by_height, reverse=True)]


def get_unet_resnet(input_shape, trainable=False):
    '''Build a U-Net whose encoder is an ImageNet-pretrained ResNet50.

    Args:
        input_shape: (height, width, channels) of the input; the notebook uses
            (224, 224, 3). Feature maps are assumed square.
        trainable: whether the ResNet50 encoder layers are trainable.

    Returns:
        A keras `Model` mapping images to a one-channel sigmoid mask with the
        same height and width as the input.

    Raises:
        ValueError: if the encoder does not expose exactly five resolutions.
    '''
    input_layer = Input(input_shape)
    resnet_base = ResNet50(input_shape=input_shape, include_top=False,
                           input_tensor=input_layer, weights='imagenet')

    for l in resnet_base.layers:
        l.trainable = trainable

    # Skip connections are found by resolution rather than by auto-generated
    # names like "activation_1": Keras keeps incrementing those counters, so the
    # names change when a model is built again in the same session.
    skips = _encoder_skip_outputs(resnet_base)
    if len(skips) != 5:
        raise ValueError('Expected encoder activations at 5 resolutions, found %d' % len(skips))
    conv1, conv2, conv3, conv4, conv5 = skips

    middle_layer = Conv2D(256, (3,3), strides=(1,1), padding='same', activation='relu')(conv5)
    middle_layer = Conv2D(256, (3,3), strides=(1,1), padding='same', activation='relu')(middle_layer)

    upconv_5 = _decoder_stage(middle_layer, conv4, 256)
    upconv_4 = _decoder_stage(upconv_5, conv3, 128)

    # conv2 can be smaller than the upsampled conv3 features (the original code
    # padded it by one pixel); pad its top and right so they can be concatenated.
    pad = 2 * K.int_shape(conv3)[1] - K.int_shape(conv2)[1]
    if pad > 0:
        conv2 = ZeroPadding2D(((pad,0),(0,pad)))(conv2)

    upconv_3 = _decoder_stage(upconv_4, conv2, 64)
    upconv_2 = _decoder_stage(upconv_3, conv1, 32)
    upconv_1 = _decoder_stage(upconv_2, input_layer, 16)
    output_layer = Conv2D(1, (3,3), padding='same', activation='sigmoid')(upconv_1)

    model = Model(input_layer, output_layer)
    return model

In [None]:
# Credit: Github user kylemcdonald
# https://github.com/keras-team/keras/issues/1625
class TimedStopping(Callback):
    '''Stop training when enough wall-clock time has passed.

    Elapsed time is checked at the end of each epoch, so training can overrun
    `seconds` by up to one epoch.

    # Arguments
        seconds: maximum time in seconds, measured from the start of `fit`.
            `None` disables the limit.
        verbose: verbosity mode; if truthy, print a message when stopping.
    '''
    def __init__(self, seconds=None, verbose=0):
        # Initialise the Callback base itself (super(Callback, self) skipped it).
        super(TimedStopping, self).__init__()

        self.start_time = 0
        self.seconds = seconds
        self.verbose = verbose

    def on_train_begin(self, logs=None):
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        # No limit configured: never stop (comparing against None would raise).
        if self.seconds is None:
            return
        if time.time() - self.start_time > self.seconds:
            self.model.stop_training = True
            if self.verbose:
                print('Stopping after %s seconds.' % self.seconds)

In [None]:
# Phase 1: train the decoder on top of the frozen, pretrained ResNet50 encoder.
epochs = 100
batch_size = 32

# Callbacks are reused by the fine-tuning cell below; keras.h5 holds the best model.
early_stopping = EarlyStopping(patience=15, verbose=1)
model_checkpoint = ModelCheckpoint('keras.h5', save_best_only=True, verbose=1)
reduce_lr = ReduceLROnPlateau(factor=1e-1, patience=5, min_lr=1e-6, verbose=1)
timed_stopping = TimedStopping(seconds=1 * 3600)  # one-hour budget for this phase

model = get_unet_resnet((RESIZED_HEIGHT, RESIZED_WIDTH, IMAGE_CHANNELS), trainable=False)
optimizer = Adam(lr=1e-2)
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=[my_iou_metric])

phase_callbacks = [early_stopping, model_checkpoint, reduce_lr, timed_stopping]
history = model.fit(
    x_train,
    y_train,
    validation_data=[x_valid, y_valid],
    epochs=epochs,
    batch_size=batch_size,
    callbacks=phase_callbacks,
    shuffle=True,
)

In [None]:
# Phase 2: unfreeze the ResNet50 encoder and fine-tune the whole network.
epochs = 100
batch_size = 32
timed_stopping = TimedStopping(seconds = 4 * 3600)

# Resume from the best phase-1 weights that `model_checkpoint` saved to keras.h5,
# not the last-epoch weights, which can be several epochs past the best
# validation loss when the previous fit ended via early stopping.
model.load_weights('keras.h5')

for layer in model.layers:
    layer.trainable = True

# Fresh optimizer with a smaller step: 1e-2 is large for updating pretrained
# encoder weights. FINE_TUNE_LR is a starting point to tune.
FINE_TUNE_LR = 1e-3
optimizer = Adam(lr=FINE_TUNE_LR)

model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=[my_iou_metric])
history = model.fit(
    x_train, y_train,
    validation_data=[x_valid, y_valid],
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[early_stopping, model_checkpoint, reduce_lr, timed_stopping],
    shuffle=True)

In [None]:
# `history` holds only the most recent `fit` call, i.e. the fine-tuning phase.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(history.history['loss'], label='Train')
ax.plot(history.history['val_loss'], label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Binary cross-entropy loss')
ax.set_title('Loss during the final training phase')
ax.legend(frameon=False);

### Free Up Memory
The next cell deletes the training and validation arrays, so anything that uses them (for example, evaluating on the validation set) must be added above this cell.

In [None]:
# Drop the training/validation arrays; the cells below only need the saved model.
del x_train, x_valid
del y_train, y_valid
gc.collect()

In [None]:
# Reload the best checkpoint. The model was compiled with the custom metric
# `my_iou_metric`, so Keras needs that function to deserialize it
# (`mean_iou` is not defined anywhere in this notebook).
model = load_model('keras.h5', custom_objects={'my_iou_metric': my_iou_metric})

In [None]:
def rle_encode(im):
    '''Run-length encode a binary (0/1) mask.

    Pixels are read in column-major (Fortran) order. The result is a string of
    space-separated "start length" pairs where each start is the 1-based index
    of the first pixel of a run of ones; an all-zero mask encodes to ''.
    '''
    padded = np.concatenate([[0], im.flatten(order = 'F'), [0]])
    # 1-based positions where the value changes: alternating run starts and run ends.
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts, stops = changes[::2], changes[1::2]
    return ' '.join('{} {}'.format(start, length) for start, length in zip(starts, stops - starts))

In [None]:
def batch_predict(testing_ids, testing_path, batch_size=1000, threshold=0.5, model=None):
    ''' Predict batches of images to save memory.

    Each image "<testing_path>test/images/<id>.png" is scaled by 1/255 and
    upscaled with `upscale`. It is then predicted with horizontal-flip
    test-time augmentation (mean of the prediction and the flipped-back
    prediction of the flipped image). The averaged probabilities are downscaled
    to 101x101 *before* thresholding, so interpolation during downscaling does
    not erode the binary mask.

    Args:
        testing_ids: sequence of image ids (file names without ".png").
        testing_path: directory prefix (with trailing slash) containing "test/images/".
        batch_size: number of images loaded and predicted at once.
        threshold: probability above which a pixel is set in the mask.
        model: trained Keras model; required.

    Returns:
        DataFrame with columns "id" and "rle_mask", one row per id, in input order.

    Raises:
        ValueError: if `model` is None.
    '''
    if model is None:
        raise ValueError('batch_predict requires a trained model')
    if len(testing_ids) == 0:
        return pd.DataFrame({'id': [], 'rle_mask': []})

    batch_starts = range(0, len(testing_ids), batch_size)
    data_store = []
    for start in tqdm_notebook(batch_starts, total=len(batch_starts)):
        ids = testing_ids[start:start + batch_size]
        x_batch = np.stack([
            upscale(np.array(load_img("{}test/images/{}.png".format(testing_path, image_id))) / 255)
            for image_id in ids
        ])
        # Horizontal-flip TTA: average with the prediction of the flipped batch, flipped back.
        probs = 0.5 * (model.predict(x_batch) + model.predict(x_batch[:,:,::-1,:])[:,:,::-1,:])
        masks = np.array([downscale(p) > threshold for p in probs], dtype=np.int8)
        data_store.append(pd.DataFrame({'id': ids, 'rle_mask': [rle_encode(m) for m in masks]}))

    return pd.concat(data_store, ignore_index=True)

In [None]:
# Ids of all test images: each file name without its ".png" suffix.
path_to_test = '../input/'

def extract_id(path):
    return path.split('.png')[0].split('/')[-1]

testing_ids = list(map(extract_id, glob.glob(path_to_test + 'test/images/*.png')))

In [None]:
# Predict the test set 100 images at a time and write the submission file.
pred = batch_predict(
    testing_ids,
    path_to_test,
    batch_size=100,
    threshold=0.5,
    model=model,
)
pred.to_csv('submission.csv', index=False)