In [2]:
from keras import backend as K

import warnings
warnings.filterwarnings("ignore")

In [3]:
# Dice coefficient (metric) and Dice loss

def dice_coef(y_true, y_pred, smooth=1):
    """Batch-averaged Dice coefficient, usable as a Keras metric.

    For every sample the score is
        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth),
    with the sums taken over axes 1, 2 and 3. The per-sample scores are then
    averaged over axis 0.

    Args:
        y_true: ground-truth tensor; assumed rank 4 (batch, H, W, C) -- TODO confirm.
        y_pred: predicted tensor, same shape as ``y_true``.
        smooth: additive term applied to numerator and denominator.

    Returns:
        The mean of the per-sample Dice scores over axis 0.
    """
    reduce_axes = [1, 2, 3]
    overlap = K.sum(y_true * y_pred, axis=reduce_axes)
    total_mass = K.sum(y_true, axis=reduce_axes) + K.sum(y_pred, axis=reduce_axes)
    per_sample = (2. * overlap + smooth) / (total_mass + smooth)
    return K.mean(per_sample, axis=0)

def dice_loss(y_true, y_pred, smooth=1):
    """Per-sample Dice loss, ``1 - Dice``, intended as a Keras loss function.

    For every sample:
        1 - (2 * sum(y_true * y_pred) + smooth) / (sum(y_true + y_pred) + smooth)
    with sums over axes 1, 2 and 3. The batch reduction is left to the caller
    (presumably Keras' default loss reduction).

    Args:
        y_true: ground-truth tensor; assumed rank 4 (batch, H, W, C) -- TODO confirm.
        y_pred: predicted tensor, same shape as ``y_true``.
        smooth: additive term applied to numerator and denominator, matching
            ``dice_coef``. It keeps the loss finite when both masks are empty.
            Without it the ratio is 0/0 and the loss becomes NaN.

    Returns:
        One loss value per sample (the axis-0 dimension is kept).
    """
    # Use the Keras backend (as dice_coef does); `tf` is not imported in this notebook.
    overlap = K.sum(y_true * y_pred, axis=[1, 2, 3])
    total_mass = K.sum(y_true + y_pred, axis=[1, 2, 3])
    return 1 - (2. * overlap + smooth) / (total_mass + smooth)

In [4]:
# Creating the model

def apply_unet(pretrained_W = None, input_size = (256,256,1), learning_rate = 1e-4):
    """Build, compile and optionally load weights into a U-Net-style model.

    Encoder: four stages of two 3x3 ReLU convolutions (64, 128, 256, 512 filters),
    each followed by 2x2 max-pooling, then a 1024-filter bottleneck.
    Decoder: four stages of 2x upsampling, a convolution, concatenation with
    the matching encoder feature map, and two further convolutions.
    Head: a 2-filter 3x3 ReLU convolution, then a 1x1 sigmoid convolution
    (one output channel).

    NOTE(review): Input, Conv2D, MaxPool2D, UpSampling2D, concatenate, Model and
    Adam are not imported in the visible cells. Presumably an earlier cell imports
    them from the same Keras package as ``backend``; confirm.

    Args:
        pretrained_W: optional weights file passed to ``model.load_weights``;
            skipped when falsy.
        input_size: shape of one input sample, channels last (skip connections
            are concatenated on axis 3). Height and width should be divisible
            by 16 so that the four pool/upsample pairs line up.
        learning_rate: Adam learning rate (default 1e-4, the previously
            hard-coded value).

    Returns:
        The compiled model (Dice loss; Dice coefficient and accuracy metrics).
    """
    # Layer creation order and kernel sizes are kept exactly as before so that
    # previously saved weight files still load.
    Inputs = Input(input_size)

    # ---- Encoder ----
    Conv1 = Conv2D(64,3, activation='relu',padding='same')(Inputs)
    Conv1 = Conv2D(64,3, activation='relu',padding='same')(Conv1)
    Pool1 = MaxPool2D(pool_size=(2,2))(Conv1)

    Conv2 = Conv2D(128,3, activation='relu',padding='same')(Pool1)
    Conv2 = Conv2D(128,3, activation='relu',padding='same')(Conv2)
    Pool2 = MaxPool2D(pool_size=(2,2))(Conv2)

    Conv3 = Conv2D(256,3, activation='relu',padding='same')(Pool2)
    Conv3 = Conv2D(256,3, activation='relu',padding='same')(Conv3)
    Pool3 = MaxPool2D(pool_size=(2,2))(Conv3)

    Conv4 = Conv2D(512,3, activation='relu',padding='same')(Pool3)
    Conv4 = Conv2D(512,3, activation='relu',padding='same')(Conv4)
    Pool4 = MaxPool2D(pool_size=(2,2))(Conv4)

    # ---- Bottleneck ----
    Conv5 = Conv2D(1024,3, activation='relu',padding='same')(Pool4)
    Conv5 = Conv2D(1024,3, activation='relu',padding='same')(Conv5)

    # ---- Decoder ----
    Up6 = UpSampling2D(size=(2,2))(Conv5)
    Conv6 = Conv2D(512,2, activation='relu',padding='same')(Up6)
    merge7 = concatenate([Conv4,Conv6], axis=3)

    # NOTE(review): these two convolutions use 2x2 kernels, unlike the 3x3
    # kernels in the other decoder stages. Left unchanged because altering the
    # kernel size would change weight shapes and break existing weight files.
    Conv7 = Conv2D(512,2, activation='relu',padding='same')(merge7)
    Conv7 = Conv2D(512,2, activation='relu',padding='same')(Conv7)
    Up8 = UpSampling2D(size=(2,2))(Conv7)

    Conv8 = Conv2D(256,3, activation='relu', padding='same')(Up8)
    merge9 = concatenate([Conv3,Conv8], axis=3)

    Conv10 = Conv2D(256,3, activation='relu', padding='same')(merge9)
    Conv10 = Conv2D(256,3, activation='relu', padding='same')(Conv10)
    Up11 = UpSampling2D(size=(2,2))(Conv10)

    Conv12 = Conv2D(128,3, activation='relu', padding='same')(Up11)
    merge13 = concatenate([Conv2,Conv12], axis=3)

    Conv14 = Conv2D(128,3, activation='relu', padding='same')(merge13)
    Conv14 = Conv2D(128,3, activation='relu', padding='same')(Conv14)
    Up15 = UpSampling2D(size=(2,2))(Conv14)

    Conv16 = Conv2D(64,3, activation='relu', padding='same')(Up15)
    merge17 = concatenate([Conv1,Conv16], axis=3)

    Conv18 = Conv2D(64,3, activation='relu', padding='same')(merge17)
    Conv18 = Conv2D(64,3, activation='relu', padding='same')(Conv18)

    # ---- Output head ----
    Conv19 = Conv2D(2,3,activation='relu', padding='same')(Conv18)
    Conv20 = Conv2D(1,1, activation='sigmoid')(Conv19)

    # `inputs=` / `outputs=` are the supported keyword names. The legacy
    # `input=` / `output=` are rejected by tf.keras and Keras 3.
    model = Model(inputs = Inputs, outputs = Conv20)

    model.compile(optimizer = Adam(learning_rate), loss = [dice_loss], metrics= [dice_coef,'accuracy'])

    model.summary()

    if pretrained_W :
        model.load_weights(pretrained_W)

    return model
