In [None]:
"""
A U-Net model (U-shaped) consists of an encoder (downsampler) and a decoder (upsampler), with a bottleneck in between.
Skip connections concatenate each encoder block's output with the upsampled features at the matching stage of the decoder.
"""

In [5]:
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from keras.layers import Input, ZeroPadding2D, Add, concatenate, \
Conv2D, MaxPooling2D, AveragePooling2D, Conv2DTranspose, BatchNormalization, \
Dropout, Flatten, Dense, Activation
        


In [1]:
# Encoder

def conv2d_block(input, n_filters, kernel_size = 3):
  """Apply two stacked (Conv2D -> ReLU) layers to a tensor.

  Args:
    input: Keras tensor to convolve.
    n_filters: number of filters in each of the two Conv2D layers.
    kernel_size: side length of the square convolution kernel (an int;
      it is expanded to (kernel_size, kernel_size)).

  Returns:
    The output tensor of the second ReLU activation. 'same' padding is
    used, so spatial size is preserved.
  """
  square_kernel = (kernel_size, kernel_size)
  tensor = input
  for _ in range(2):
    conv_layer = Conv2D(filters=n_filters, kernel_size=square_kernel,
                        kernel_initializer='he_normal', padding='same')
    tensor = conv_layer(tensor)
    tensor = Activation('relu')(tensor)
  return tensor


def encoder_block(inputs, n_filters=64, pool_size=(2,2), dropout=0.3):
  """One U-Net encoder stage: two convolutions, then max-pooling and dropout.

  Args:
    inputs: Keras tensor entering this stage.
    n_filters: number of filters for the two convolutions.
    pool_size: pooling window passed to MaxPooling2D.
    dropout: dropout rate applied after pooling.

  Returns:
    (f, p) where f is the pre-pooling feature map (used later as the skip
    connection) and p is the pooled, dropped-out tensor fed to the next stage.
  """
  f = conv2d_block(inputs, n_filters=n_filters)
  p = MaxPooling2D(pool_size=pool_size)(f)
  # Use the caller-supplied rate; previously 0.3 was hard-coded and the
  # `dropout` argument was silently ignored.
  p = Dropout(dropout)(p)

  return f, p


def encoder(inputs):
  """Run the four U-Net encoder stages (64, 128, 256, 512 filters).

  Args:
    inputs: Keras input tensor.

  Returns:
    (pooled, skips) where pooled is the output of the last stage's pooling
    and dropout, and skips is a 4-tuple (f1, f2, f3, f4) of pre-pooling
    feature maps, ordered from shallowest to deepest.
  """
  skip_features = []
  pooled = inputs
  for filters in (64, 128, 256, 512):
    features, pooled = encoder_block(pooled, n_filters=filters,
                                     pool_size=(2,2), dropout=0.3)
    skip_features.append(features)
  return pooled, tuple(skip_features)

In [8]:
# bottleneck
def bottleneck(inputs):
  """Bridge between encoder and decoder: a 1024-filter double convolution.

  Args:
    inputs: the encoder's final pooled tensor.

  Returns:
    The tensor produced by conv2d_block with 1024 filters.
  """
  return conv2d_block(inputs, n_filters=1024)

In [3]:
# Decoder

def decoder_block(inputs, conv_output, n_filters=64, kernel_size=3, strides=3, dropout=0.3):
  """One U-Net decoder stage: upsample, merge the skip connection, convolve.

  Args:
    inputs: tensor coming from the previous (deeper) decoder stage or bottleneck.
    conv_output: encoder feature map to concatenate as the skip connection.
    n_filters: filters for the transposed conv and the following convolutions.
    kernel_size: kernel size of the Conv2DTranspose only; the convolutions
      after the merge always use a 3x3 kernel.
    strides: upsampling stride of the Conv2DTranspose.
      NOTE(review): the encoder pools by 2x2, so the upsampled tensor only
      matches conv_output's spatial size when strides is 2. The default of 3
      would make concatenate fail; all callers in this notebook pass (2, 2).
    dropout: dropout rate applied after concatenation.

  Returns:
    Output tensor of the stage's conv2d_block.
  """
  upsampled = Conv2DTranspose(n_filters, kernel_size, strides = strides, padding = 'same')(inputs)
  merged = concatenate([upsampled, conv_output])
  merged = Dropout(dropout)(merged)
  return conv2d_block(merged, n_filters, kernel_size=3)


def decoder(inputs, convs, output_channels):
  """Run the four U-Net decoder stages and produce per-pixel class scores.

  Args:
    inputs: bottleneck output tensor.
    convs: 4-tuple (f1, f2, f3, f4) of encoder skip features, shallowest first.
    output_channels: number of output channels of the final 1x1 convolution.

  Returns:
    Tensor from a 1x1 Conv2D with softmax activation over output_channels.
  """
  f1, f2, f3, f4 = convs
  # Walk back up the "U": deepest skip first, halving the filter count each stage.
  stages = ((f4, 512), (f3, 256), (f2, 128), (f1, 64))
  upsampled = inputs
  for skip, filters in stages:
    upsampled = decoder_block(upsampled, skip, n_filters=filters,
                              kernel_size=(3,3), strides=(2,2), dropout=0.3)
  return Conv2D(output_channels, (1, 1), activation='softmax')(upsampled)

In [6]:
OUTPUT_CHANNELS = 3

def unet(input_shape=(128, 128, 3), output_channels=None):
  """Assemble the full U-Net (encoder -> bottleneck -> decoder) as a Keras Model.

  Args:
    input_shape: shape of one input image, (height, width, channels). It
      defaults to (128, 128, 3), which matches the previous hard-coded value.
      Height and width may be None; any known value must be a multiple of 16.
    output_channels: number of channels of the softmax output. If None,
      the module-level OUTPUT_CHANNELS is read at call time.

  Returns:
    A keras Model mapping the input image to per-pixel softmax scores.

  Raises:
    ValueError: if a known height/width is not a multiple of 16.
  """
  if output_channels is None:
    output_channels = OUTPUT_CHANNELS

  # The encoder halves H and W with 2x2 pooling four times (flooring odd
  # sizes), and each stride-2 transposed conv in the decoder doubles them.
  # The skip concatenations only line up when H and W are multiples of 2**4.
  required_multiple = 2 ** 4
  for dim in input_shape[:2]:
    if dim is not None and dim % required_multiple != 0:
      raise ValueError(
          f"U-Net input height/width must be multiples of {required_multiple}; "
          f"got input_shape={tuple(input_shape)}")

  inputs = Input(shape=tuple(input_shape))

  encoder_output, convs = encoder(inputs)
  bottle_neck = bottleneck(encoder_output)
  outputs = decoder(bottle_neck, convs, output_channels=output_channels)

  model = Model(inputs=inputs, outputs=outputs)
  return model

In [9]:
# Build the U-Net with unet()'s defaults and print the Keras model summary.
model = unet()
model.summary()