# Carvana U-Net

## Imports

In [None]:
from keras.optimizers import Adam
from keras.losses import binary_crossentropy
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
from keras.layers import AveragePooling2D
import tensorflow as tf
import keras.backend as K
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from models.u_net import UNet
import utils
%load_ext autoreload
%autoreload 1
%aimport utils

## Preparing Data

In [None]:
# Side length handed to the data generators and used for the model input.
input_size = 128

# Path templates; the "{}" slot is filled with an image id.
train_path = "input/train/{}.jpg"
train_mask_path = "input/train_masks/{}_mask.gif"

# The 'img' column holds file names; keep only the text before the first dot.
# (A trailing [:3000] slice was left commented out on this line originally,
# presumably to train on a subset -- re-add it if needed.)
df_train = pd.read_csv('input/train_masks.csv')
ids_train = df_train['img'].map(lambda filename: filename.partition('.')[0])

# 80/20 train/validation split with a fixed seed so the split is repeatable.
ids_train_split, ids_valid_split = train_test_split(
    ids_train,
    test_size=0.2,
    random_state=42,
)

print('Training on {} samples'.format(len(ids_train_split)))
print('Validating on {} samples'.format(len(ids_valid_split)))

def train_generator(batch_size):
    """Build the training batch generator for the training id split.

    Delegates to ``utils.train_generator`` with the notebook-level path
    templates, ``ids_train_split`` and ``input_size``.

    batch_size: forwarded unchanged as the last positional argument.

    NOTE(review): presumably yields (images, masks) batches in the form
    Keras' ``fit_generator`` expects -- confirm in ``utils``.
    """
    batches = utils.train_generator(
        train_path,
        train_mask_path,
        ids_train_split,
        input_size,
        batch_size,
    )
    return batches

def valid_generator(batch_size):
    """Build the validation batch generator for the validation id split.

    Delegates to ``utils.valid_generator`` with the notebook-level path
    templates, ``ids_valid_split`` and ``input_size``.

    batch_size: forwarded unchanged as the last positional argument.

    Items taken from the returned generator are unpacked elsewhere in this
    notebook as an (images, masks) pair.
    """
    batches = utils.valid_generator(
        train_path,
        train_mask_path,
        ids_valid_split,
        input_size,
        batch_size,
    )
    return batches

## Loss Functions

In [None]:
def dice_value(y_true, y_pred):
    """Smoothed Dice coefficient over all elements of both tensors.

    Both inputs are flattened, so the score is computed across the whole
    batch at once rather than per sample.

    Returns (2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1) as a backend
    tensor; the +1 smoothing term keeps the ratio defined when both sums
    are zero.
    """
    smooth = 1.
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

#def bce_dice_loss(y_true, y_pred):
#    return binary_crossentropy(y_true, y_pred) + (1 - dice_value(y_true, y_pred))

def weighted_bce_loss(y_true, y_pred, weights):
    """Weighted cross-entropy, averaged over the last axis.

    y_true, y_pred and weights are passed positionally, in that order, to
    ``tf.nn.weighted_cross_entropy_with_logits``; the element-wise result
    is then reduced with a mean over axis -1.

    NOTE(review): the TensorFlow op interprets its second argument as
    logits and its third as ``pos_weight`` (a multiplier on the
    positive-class term only). Confirm that the model output is not
    already passed through a sigmoid and that positive-only weighting is
    what is intended here.
    """
    per_element = tf.nn.weighted_cross_entropy_with_logits(y_true, y_pred, weights)
    return K.mean(per_element, axis=-1)

def weighted_dice_value(y_true, y_pred, weights):
    """Smoothed Dice coefficient with every element scaled by weights**2.

    All three tensors are flattened first, so the score is computed over
    the whole batch at once.

    Returns
        (2 * sum(w^2 * t * p) + 1) / (sum(w^2 * t) + sum(w^2 * p) + 1)
    as a backend tensor.
    """
    smooth = 1.
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = truth * pred

    # The squared weight is applied to the intersection and to both sums.
    flat_w = K.flatten(weights)
    sq_w = flat_w * flat_w

    numerator = 2. * K.sum(sq_w * overlap) + smooth
    denominator = K.sum(sq_w * truth) + K.sum(sq_w * pred) + smooth
    return numerator / denominator

def weighted_bce_dice_loss(y_true, y_pred):
    """Weighted BCE plus (1 - weighted Dice), with extra weight where the mask is mixed.

    An 11x11 average pool (stride 1, 'same' padding) is applied to y_true.
    Positions whose pooled value lies in [0.01, 0.99] -- i.e. whose
    neighbourhood is neither (almost) all zeros nor (almost) all ones --
    get a raw weight of 3; every other position gets 1. The weight map is
    then rescaled so that its total equals the total of the all-ones map.
    """
    local_mean = AveragePooling2D(pool_size=(11, 11), strides=1, padding='same')(y_true)

    # 1.0 where 0.01 <= local_mean <= 0.99, else 0.0.
    above_floor = K.cast(K.greater_equal(local_mean, 0.01), 'float32')
    below_ceiling = K.cast(K.less_equal(local_mean, 0.99), 'float32')
    mixed_region = above_floor * below_ceiling

    uniform = K.ones_like(local_mean)
    uniform_total = K.sum(uniform)

    boosted = uniform + mixed_region * 2
    boosted_total = K.sum(boosted)

    # Renormalise so the boosted map sums to the same total as the uniform one.
    weights = boosted / boosted_total * uniform_total

    bce_term = weighted_bce_loss(y_true, y_pred, weights)
    dice_term = 1 - weighted_dice_value(y_true, y_pred, weights)
    return bce_term + dice_term

## Create Model

In [None]:
# Square input of side input_size with 3 channels.
input_shape = (input_size, input_size, 3)
model = UNet(input_shape)

# Optimise the boundary-weighted BCE + Dice loss; report plain Dice as the metric
# (the callbacks and checkpoint file names refer to it as "dice_value").
model.compile(
    optimizer=Adam(1e-3),
    loss=weighted_bce_dice_loss,
    metrics=[dice_value],
)

## Fit Model

In [None]:
epochs = 10
batch_size = 16
run_number = 2

# The {epoch}/{dice_value}/{val_dice_value} placeholders are left unformatted
# here on purpose: ModelCheckpoint fills them in each time it saves.
weight_path = "weights/UNet-" + str(run_number) + "-{epoch:02d}-{dice_value:.4f}-{val_dice_value:.4f}.hdf5"

# Number of batches needed to cover each split once. np.ceil returns a float,
# but these are step counts, so convert them to real integers before handing
# them to Keras.
train_steps = int(np.ceil(float(len(ids_train_split)) / float(batch_size)))
valid_steps = int(np.ceil(float(len(ids_valid_split)) / float(batch_size)))

# All callbacks watch the validation Dice score, where higher is better.
callbacks = [
    # Stop when validation Dice has not improved by 1e-4 for 8 epochs.
    EarlyStopping(monitor='val_dice_value',
                  patience=8,
                  verbose=1,
                  min_delta=1e-4,
                  mode='max'),
    # Cut the learning rate by 10x after 4 epochs without improvement.
    # NOTE(review): newer Keras releases renamed `epsilon` to `min_delta`
    # for this callback -- switch the keyword when upgrading.
    ReduceLROnPlateau(monitor='val_dice_value',
                      factor=0.1,
                      patience=4,
                      verbose=1,
                      epsilon=1e-4,
                      mode='max'),
    # Save weights only, and only when validation Dice reaches a new best.
    ModelCheckpoint(monitor='val_dice_value',
                    filepath=weight_path,
                    save_best_only=True,
                    save_weights_only=True,
                    mode='max'),
    TensorBoard(log_dir='logs/U{:d}'.format(run_number), batch_size=batch_size),
]

# Optional warm start: uncomment to resume from saved weights and/or to
# override the optimizer's learning rate.
#model.load_weights('weights/best_weights.hdf5')
#K.set_value(model.optimizer.lr, 0.01)

model.fit_generator(generator=train_generator(batch_size),
                    steps_per_epoch=train_steps,
                    epochs=epochs,
                    verbose=2,
                    callbacks=callbacks,
                    validation_data=valid_generator(batch_size),
                    validation_steps=valid_steps)

## Validation

In [None]:
def np_dice_value(y_true, y_pred, smooth=1.):
    """NumPy version of the smoothed Dice coefficient.

    Args:
        y_true: ground-truth mask; any array-like (ndarray, nested list, ...).
        y_pred: predicted mask with the same number of elements as y_true.
        smooth: additive smoothing applied to numerator and denominator;
            the default of 1 keeps the score defined (and equal to 1) when
            both masks are entirely zero.

    Returns:
        (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth) as a scalar.
    """
    # np.asarray lets callers pass plain lists as well as ndarrays; ravel
    # flattens without caring about the original shape.
    y_true_f = np.asarray(y_true).ravel()
    y_pred_f = np.asarray(y_pred).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

### Prediction

In [None]:
# Restore the checkpoint to evaluate.
model.load_weights("weights/UNet-2-09-0.9880-0.9923.hdf5")

# Ask the validation generator for a single batch as large as the whole
# validation split, then materialise both halves as NumPy arrays.
full_valid_batch = next(valid_generator(len(ids_valid_split)))
val_imgs, val_masks = (np.array(part) for part in full_valid_batch)

val_pred_masks = model.predict(val_imgs, batch_size=32)

# One Dice score per (ground-truth mask, predicted mask) pair.
masks_val_dices = list(map(np_dice_value, val_masks, val_pred_masks))

### Histogram

In [None]:
# Distribution of per-sample Dice scores over 50 equal-width bins.
hist, bins = np.histogram(masks_val_dices, bins=50)

# Bars are drawn at the bin midpoints, 70% as wide as a bin so neighbouring
# bars do not touch.
center = 0.5 * (bins[1:] + bins[:-1])
width = (bins[1] - bins[0]) * 0.7

plt.bar(center, hist, align='center', width=width)
plt.show()

### Visualization

In [None]:
%matplotlib inline
# Position of the validation sample to inspect.
index = 10

# NOTE(review): assumes the generator returned samples in the same order as
# ids_valid_split, so the same index selects the matching image id -- confirm.
sample_id = ids_valid_split.values[index]
img_path = train_path.format(sample_id)

# Drop the trailing axis from both masks before display.
true_mask = val_masks[index].squeeze(axis=-1)
pred_mask = val_pred_masks[index].squeeze(axis=-1)

utils.show_mask(input_size, true_mask, pred_mask, img_path)