# In[1]:
from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Flatten, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,Cropping2D,ZeroPadding2D
from keras.layers import Input,  UpSampling2D,BatchNormalization
from keras.callbacks import ModelCheckpoint,CSVLogger
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import tensorflow as tf
from numba import cuda

def _conv_block(tensor, filters):
    """Apply two 3x3 'same'-padded ReLU convolutions, each followed by batch norm.

    Args:
        tensor: Input Keras tensor in channels_last layout.
        filters: Number of filters used by both convolutions.

    Returns:
        A tuple ``(conv, normed)``. ``conv`` is the raw output of the second
        convolution, used as the U-Net skip connection. ``normed`` is that
        output after batch normalization and is fed forward to the next stage.
    """
    x = Conv2D(filters, (3, 3), padding='same', activation='relu', data_format="channels_last")(tensor)
    # With channels_last the channel axis is the last one; normalize per channel.
    x = BatchNormalization(axis=-1)(x)
    conv = Conv2D(filters, (3, 3), padding='same', activation='relu', data_format="channels_last")(x)
    normed = BatchNormalization(axis=-1)(conv)
    return conv, normed


def _up_block(tensor, skip, filters):
    """Upsample ``tensor`` 2x, merge ``skip`` along channels, then run a conv block.

    Args:
        tensor: Decoder input (Keras tensor, channels_last).
        skip: Encoder tensor whose spatial size equals twice that of ``tensor``.
        filters: Filters for the transposed conv and for the following conv block.

    Returns:
        The batch-normalized output of the conv block.
    """
    up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same', data_format="channels_last")(tensor)
    # Skip connections are stacked as extra channels (last axis for channels_last),
    # not along a spatial axis.
    merged = concatenate([up, skip], axis=-1)
    _, normed = _conv_block(merged, filters)
    return normed


def unet(input_size=(128, 128, 2)):
    """Build and compile a U-Net that outputs a one-channel sigmoid map.

    The encoder has four conv blocks (64, 128, 256 and 512 filters), each
    followed by 2x2 max pooling. A 1024-filter bottleneck follows. The decoder
    mirrors the encoder, using transposed convolutions and channel-wise skip
    concatenation.

    Args:
        input_size: Input shape ``(height, width, channels)`` in channels_last
            layout. Defaults to ``(128, 128, 2)``. Height and width should be
            divisible by 16 (four 2x pooling steps). Otherwise the upsampled
            tensors will not match the skip connections' spatial size.

    Returns:
        A Keras ``Model``. Its output has the same height and width as the
        input, with one channel. It is compiled with binary cross-entropy loss,
        the Adam optimizer and an accuracy metric.
    """
    inputs = Input(input_size)

    # Encoder: keep the pre-BN output of each block's second conv as the skip tensor.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        conv, x = _conv_block(x, filters)
        skips.append(conv)
        x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)

    # Bottleneck.
    _, x = _conv_block(x, 1024)

    # Decoder: consume the skip connections deepest-first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = _up_block(x, skip, filters)

    # The 1x1 sigmoid head must consume the decoder output. Attaching it to an
    # encoder tensor would leave the whole decoder disconnected from the model.
    outputs = Conv2D(1, (1, 1), activation='sigmoid', data_format="channels_last")(x)

    model = Model(inputs=[inputs], outputs=[outputs])

    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model