In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import pickle
import cv2

In [None]:
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, Cropping2D
from keras.layers import concatenate
from keras.optimizers import Adam
from keras.layers.normalization import BatchNormalization
from keras import backend as K

In [None]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.utils import multi_gpu_model

# Load Dataset

In [None]:
# Root directory of the CamVid dataset. Every path below is derived from it,
# so relocating the dataset only requires changing this one value.
data_path = "./CamVid/"

# Image and annotation directories for each split.
train_path = data_path + "train/"
train_label_path = data_path + "trainannot/"

valid_path = data_path + "val/"
valid_label_path = data_path + "valannot/"

test_path = data_path + "test/"
test_label_path = data_path + "testannot/"

# Pickle files for each split; the loading cell unpacks each one into an (X, y) pair.
train_file = data_path + "train.p"
valid_file = data_path + "val.p"
test_file = data_path + "test.p"

In [None]:
def _read_split(path):
    """Unpickle one dataset file and return the object stored in it."""
    # NOTE: pickle.load can execute arbitrary code; only open files you trust.
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Each pickle is unpacked into an (inputs, labels) pair.
X_train, y_train = _read_split(train_file)
X_val, y_val = _read_split(valid_file)
X_test, y_test = _read_split(test_file)

In [None]:
# Side-by-side preview of the first training image and its annotation.
plt.subplot(121)
plt.imshow(X_train[0])
plt.subplot(122)
# The label array is multiplied by 127 before display.
# NOTE(review): imshow only accepts (M, N), (M, N, 3) or (M, N, 4) arrays, while a
# later cell reads 12 channels from y_train — confirm the label layout.
plt.imshow(y_train[0]*127);

In [None]:
# Display the array shapes of the training inputs and labels.
X_train.shape, y_train.shape

In [None]:
from matplotlib.gridspec import GridSpec
from random import randint

# Pick one training sample at random and draw 12 of its label channels,
# one per panel, on a 4x3 grid.
# randint is inclusive on BOTH ends, so the upper bound must be len - 1;
# passing len(y_train) could yield an out-of-range index.
ranidx = randint(0, len(y_train) - 1)
n_rows, n_cols = 4, 3
gs = GridSpec(n_rows, n_cols)
plt.figure(dpi=200)
for i in range(n_rows * n_cols):
    plt.subplot(gs[i])
    # Channel values are multiplied by 255 and drawn with a grey colormap.
    plt.imshow(y_train[ranidx][:, :, i] * 255, cmap='gray')

# Build Model (U-Net)

In [None]:
# Class count; same value as the nb_classes passed to unet() when the model is built.
num_classes = 12
# Per-sample shape of the training inputs (leading axis dropped).
input_shape = X_train.shape[1:]  # (360,480,3) per the original note
# Additive term used in both the numerator and denominator of dice_coef.
smooth = 1.0
# When truthy, unet() wraps the model with multi_gpu_model.
parallel = True

In [None]:
# Write the model to 'model_UNET.hdf5', keeping only the best value of the monitored loss.
model_checkpoint = ModelCheckpoint(
    filepath='model_UNET.hdf5',
    monitor='loss',
    save_best_only=True,
)
# Stop training early based on the same monitored quantity.
model_earlystopping = EarlyStopping(monitor='loss')

In [None]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Both tensors are flattened first, so the overlap is measured globally
    rather than per sample or per channel. The module-level ``smooth`` term
    is added to the numerator and the denominator.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    total = K.sum(truth) + K.sum(guess)

    return (2. * overlap + smooth) / (total + smooth)

def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated Dice coefficient of the two tensors."""
    score = dice_coef(y_true, y_pred)
    return -score

In [None]:
def encoder(x, layer_id, filters=64):
    """Apply one down-sampling stage: two 3x3 convs, batch norm, 2x2 max-pool.

    Layer names are suffixed with ``'encoder' + str(layer_id)``.

    Returns a pair ``(pooled, features)``: the max-pooled tensor and the
    tensor as it was just before pooling.
    """
    suffix = 'encoder' + str(layer_id)
    conv_kwargs = dict(activation='relu', padding='same')

    features = Conv2D(filters, (3,3), name='conv1_' + suffix, **conv_kwargs)(x)
    features = Conv2D(filters, (3,3), name='conv2_' + suffix, **conv_kwargs)(features)
    features = BatchNormalization(name='BN_' + suffix)(features)
    pooled = MaxPooling2D(pool_size=(2,2), name='pool_' + suffix)(features)

    return pooled, features

In [None]:
def decoder(xp, x, layer_id, filters=32, cropfilters=((0,0),(0,0))):
    """Apply one up-sampling stage and merge it with a second tensor.

    ``x`` is upsampled by a stride-2 transposed convolution, ``xp`` is cropped
    by ``cropfilters``, the two are concatenated along the last axis, and the
    result goes through two 3x3 ReLU convolutions.

    Layer names are suffixed with ``'decoder' + str(layer_id)``.
    """
    suffix = 'decoder' + str(layer_id)
    conv_kwargs = dict(activation='relu', padding='same')

    upsampled = Conv2DTranspose(
        filters, (2,2), strides=(2,2), padding='same', name='dconv_' + suffix
    )(x)
    skip = Cropping2D(cropping=cropfilters)(xp)
    merged = concatenate([skip, upsampled], axis=-1, name='concat' + suffix)
    out = Conv2D(filters, (3,3), name='conv1_' + suffix, **conv_kwargs)(merged)
    out = Conv2D(filters, (3,3), name='conv2_' + suffix, **conv_kwargs)(out)

    return out

In [None]:
def unet(nb_classes=32, input_shape=(480,480,3)):
    """Build and compile a U-Net with four encoder and four decoder stages.

    The encoder stages use 64/128/256/512 filters, followed by a single
    1024-filter 3x3 convolution, then decoder stages with 512/256/128/64
    filters. A final 1x1 softmax convolution emits ``nb_classes`` channels.
    The model is compiled with the 'adam' optimizer, ``dice_coef_loss`` as
    the loss and ``dice_coef`` as a metric. A model summary is printed.

    Reads the module-level ``parallel`` flag: when truthy the model is
    wrapped with ``multi_gpu_model(..., gpus=2)`` before compiling.

    Returns the compiled model (the multi-GPU wrapper when ``parallel`` is set).
    """
    inputs = Input(input_shape)

    # Contracting path: keep each pre-pooling tensor for the matching decoder stage.
    x, x1 = encoder(inputs, layer_id=1, filters=64)
    x, x2 = encoder(x, layer_id=2, filters=128)
    x, x3 = encoder(x, layer_id=3, filters=256)
    x, x4 = encoder(x, layer_id=4, filters=512)

    x = Conv2D(1024, (3,3), activation='relu', padding='same', name='conv_layer5')(x)

    # Expanding path: pair each stage with the encoder tensor of the same level.
    x = decoder(x4, x, layer_id=4, filters=512)
    x = decoder(x3, x, layer_id=3, filters=256)
    x = decoder(x2, x, layer_id=2, filters=128)
    x = decoder(x1, x, layer_id=1, filters=64)

    x = Conv2D(nb_classes, (1,1), activation='softmax', padding='same', name='output')(x)

    model = Model(inputs=inputs, outputs=x)
    model.summary()

    if parallel:
        # NOTE(review): the Keras docs advise saving a multi-GPU model through its
        # template model; only the wrapper is returned here, so any checkpoint
        # callback attached to it saves the wrapper — confirm that is acceptable.
        net = multi_gpu_model(model, gpus=2)
    else:
        net = model

    # Single compile path for both the plain and the multi-GPU model.
    net.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])
    return net

In [None]:
# Build the network using the class count from the configuration cell.
# NOTE(review): the spatial size is hard-coded to 480x480 instead of using the
# data-derived `input_shape` (annotated as (360,480,3) where it is defined) —
# confirm which one matches the pickled images.
model = unet(nb_classes=num_classes, input_shape=(480,480,3))

In [None]:
# NOTE(review): the original call passed `bin_y_train`, which is never defined in
# this notebook and raises NameError on a fresh kernel. `y_train` is used instead
# because the label-preview cell reads 12 channels from it, matching the 12-channel
# model output — confirm it is the intended training target.
model.fit(
    X_train, y_train,
    batch_size=2,
    epochs=20,
    verbose=1,
    shuffle=True,
    callbacks=[model_checkpoint, model_earlystopping],
)

In [None]:
# Run inference over the test inputs in batches of four, with progress output.
pred = model.predict(X_test, batch_size=4, verbose=1)

In [None]:
# The network's last layer emits one softmax channel per class, and imshow cannot
# render a 12-channel array; collapse to the most probable class per pixel first.
plt.imshow(np.argmax(pred[140], axis=-1));

In [None]:
# Display the shape of the prediction array.
pred.shape