In [1]:
import tensorflow as tf 
from tensorflow import keras
from keras import layers
import tensorflow_datasets as tfds 

In [7]:
def conv2d_block(inputs, n_filters, filter_size=3):
  '''Two stacked ReLU convolutions that preserve the spatial size.

  Args:
    inputs (tensor) -- feature map to convolve
    n_filters (int) -- number of filters in each convolution
    filter_size (int or tuple) -- kernel size of each convolution

  Returns:
    x (tensor) -- output of the second convolution; with the default stride of 1
      and "same" padding it keeps the height/width of `inputs`
  '''
  x = inputs
  for _ in range(2):
    conv = layers.Conv2D(
        filters=n_filters,
        kernel_size=filter_size,
        activation="relu",
        padding="same",
        kernel_initializer="he_normal",
    )
    x = conv(x)
  return x


def encoder_block(inputs, n_filters, filter_size, pool_size=(2, 2), dropout_rate=0.3):
  '''One contracting step of the U-Net encoder.

  Args:
    inputs (tensor) -- feature map to encode
    n_filters (int) -- number of filters for both convolutions
    filter_size (int or tuple) -- convolution kernel size
    pool_size (tuple) -- max-pooling window
    dropout_rate (float) -- dropout rate applied after pooling

  Returns:
    (f, p) -- `f` is the conv output before pooling (used by the decoder as a
      skip connection); `p` is the pooled, dropout-regularized tensor that feeds
      the next encoder stage
  '''
  skip_features = conv2d_block(inputs, n_filters, filter_size)
  downsampled = layers.MaxPooling2D(pool_size=pool_size)(skip_features)
  downsampled = layers.Dropout(dropout_rate)(downsampled)
  return skip_features, downsampled


def encoder(inputs):
  '''Contracting path: four encoder blocks with 64, 128, 256 and 512 filters.

  Every block uses 3x3 kernels, 2x2 pooling and a 0.3 dropout rate.

  Args:
    inputs (tensor) -- input image batch

  Returns:
    (p4, (f1, f2, f3, f4)) -- the output of the last pooling/dropout step, and a
      tuple of each block's pre-pooling features (shallowest first) for the
      decoder's skip connections
  '''
  skip_connections = []
  x = inputs
  for n_filters in (64, 128, 256, 512):
    skip, x = encoder_block(x, n_filters=n_filters, filter_size=(3, 3),
                            pool_size=(2, 2), dropout_rate=0.3)
    skip_connections.append(skip)
  return x, tuple(skip_connections)

# Bottleneck

In [8]:
def bottleneck(inputs):
  '''Bridge between encoder and decoder: a 1024-filter block of two 3x3 convolutions.

  Args:
    inputs (tensor) -- output of the encoder's last pooling/dropout step

  Returns:
    (tensor) -- bottleneck features, same height/width as `inputs`
  '''
  return conv2d_block(inputs, 1024, filter_size=(3, 3))

# Decoder

In [34]:
def decoder_block(inputs, conv_output, n_filters, kernel_size, strides, dropout=0.3):
  '''One expanding step of the U-Net decoder.

  Upsamples `inputs` with a transposed convolution (always "same" padding),
  concatenates the matching encoder features (skip connection), applies
  dropout, then refines the result with a two-convolution `conv2d_block`.

  Args:
    inputs (tensor) -- batch of input features from the previous (deeper) stage
    conv_output (tensor) -- skip-connection features from the matching encoder
      block; its height/width must equal those of the upsampled `inputs` or the
      concatenation fails
    n_filters (int) -- number of filters for the transposed conv and the conv block
    kernel_size (int or tuple) -- kernel size for the transposed conv and the conv block
    strides (int or tuple) -- strides of the transposed convolution (upsampling factor)
    dropout (float) -- dropout rate applied after the concatenation

  Returns:
    c (tensor) -- output features of the decoder block
  '''
  u = layers.Conv2DTranspose(filters=n_filters, kernel_size=kernel_size, strides=strides, padding="same")(inputs)
  c = layers.concatenate([u, conv_output])
  c = layers.Dropout(dropout)(c)
  c = conv2d_block(c, n_filters, kernel_size)
  return c


def decoder(inputs, convs, output_channels):
  '''Expanding path: four decoder blocks followed by a per-pixel softmax classifier.

  Args:
    inputs (tensor) -- bottleneck features
    convs (tuple) -- exactly four encoder skip tensors (f1, f2, f3, f4), shallowest first
    output_channels (int) -- number of classes

  Returns:
    outputs (tensor) -- per-pixel class probabilities with `output_channels` channels
  '''
  f1, f2, f3, f4 = convs
  x = inputs
  # Walk back up the encoder: deepest skip first, halving the filter count each step.
  for skip, n_filters in ((f4, 512), (f3, 256), (f2, 128), (f1, 64)):
    x = decoder_block(x, skip, n_filters=n_filters, kernel_size=(3, 3),
                      strides=(2, 2), dropout=0.3)
  # 1x1 convolution maps the features to one softmax channel per class.
  return layers.Conv2D(output_channels, (1, 1), activation="softmax")(x)

# Final Model

In [35]:
OUTPUT_CHANNELS = 3

def unet(input_shape=(128, 128, 3), output_channels=None):
  '''
  Defines the UNet by connecting the encoder, bottleneck and decoder.

  Args:
    input_shape (tuple) -- (height, width, channels) of one input image. The
      encoder halves height and width four times and the decoder doubles them
      back, so every known spatial dimension must be divisible by 16.
    output_channels (int or None) -- number of output classes; when None,
      OUTPUT_CHANNELS is used (read at call time).

  Returns:
    model (tf.keras.Model) -- the assembled (uncompiled) UNet

  Raises:
    ValueError -- if a known spatial dimension is not divisible by 16
  '''
  if output_channels is None:
    output_channels = OUTPUT_CHANNELS

  # The encoder applies four 2x2 poolings, so height/width must survive division
  # by 2**4. Otherwise the decoder's skip concatenations fail with an opaque
  # shape-mismatch error deep inside the graph.
  downsample_factor = 2 ** 4
  for dim in tuple(input_shape)[:2]:
    if dim is not None and dim % downsample_factor != 0:
      raise ValueError(
          f"spatial input dims must be divisible by {downsample_factor}, got {tuple(input_shape)}")

  #specify the input shape
  inputs = layers.Input(shape=tuple(input_shape))

  #feed the inputs to the encoder.
  encoder_output, convs = encoder(inputs)

  #feed the encoder output to the bottleneck
  bottle_neck = bottleneck(encoder_output)

  #feed the conv and bottleneck outputs to the decoder
  outputs = decoder(bottle_neck, convs, output_channels=output_channels)

  #create the model.
  model = tf.keras.Model(inputs=inputs, outputs=outputs)

  return model

In [36]:
#instantiate model
model = unet()

# Sanity check before inspecting the architecture: the decoder should restore
# the input's height/width, and the last layer should emit one channel per class.
_input_hw = tuple(model.inputs[0].shape)[1:3]
_output_shape = tuple(model.outputs[0].shape)
assert _output_shape[1:3] == _input_hw, (_input_hw, _output_shape)
assert _output_shape[-1] == OUTPUT_CHANNELS, _output_shape

#see the resulting architecture
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_12 (InputLayer)           [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_1213 (Conv2D)            (None, 128, 128, 64) 1792        input_12[0][0]                   
__________________________________________________________________________________________________
conv2d_1214 (Conv2D)            (None, 128, 128, 64) 36928       conv2d_1213[0][0]                
__________________________________________________________________________________________________
max_pooling2d_44 (MaxPooling2D) (None, 64, 64, 64)   0           conv2d_1214[0][0]                
____________________________________________________________________________________________