# Using a U-Net-style architecture for semantic segmentation
I will train a U-Net-style architecture on the labeled city dataset (TODO: add the dataset source/link, which is missing here).

In [1]:
# get colab status
try:
  import google.colab
  IN_COLAB = True
  %tensorflow_version 2.x
except:
  IN_COLAB = False

TensorFlow 2.x selected.


In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.keras.layers import (Dense, Conv2D, Flatten, Dropout, 
MaxPooling2D, BatchNormalization, Conv2DTranspose, concatenate, Input)
from tensorflow.keras.optimizers import RMSprop
from tensorflow.python.keras.utils.data_utils import Sequence # to fix 'imagedatagenerator has no shape' error
import os

#data load


Modeling

In [0]:
# Shape of one network input, passed to Input(shape=...) below.
# Read as (height, width, channels) under Keras' default channels-last layout.
img_shape = (256, 256, 3)

In [0]:
#use function for the blocks of the network
def encoder_builder(input_, filters,
                    activ='relu', kernel=(3,3),
                    drop=.5, pad='same', kern_init='he_uniform'):
  """Build one contracting (down-sampling) stage of the network.

  Layer order: Conv2D -> BatchNormalization -> Dropout -> Conv2D -> Dropout,
  followed by a 2x2 max-pool with stride 2.

  Args:
    input_: tensor the stage is applied to.
    filters: number of filters used by both convolutions.
    activ: activation passed to both convolutions.
    kernel: kernel size of both convolutions.
    drop: rate used by both Dropout layers.
    pad: padding mode of both convolutions.
    kern_init: kernel initializer of both convolutions.

  Returns:
    A tuple (encoder, pooled): the stage output before pooling (used as the
    skip connection by the caller) and the same tensor after max-pooling.
  """
  # Both convolutions in the stage share one configuration.
  conv_cfg = dict(filters=filters, activation=activ, kernel_size=kernel,
                  padding=pad, kernel_initializer=kern_init)

  first_conv = Conv2D(**conv_cfg)(input_)
  normalized = BatchNormalization()(first_conv)
  regularized = Dropout(drop)(normalized)
  second_conv = Conv2D(**conv_cfg)(regularized)

  # Pre-pool features are returned alongside the down-sampled ones.
  encoder = Dropout(drop)(second_conv)
  pooled = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(encoder)
  return encoder, pooled

def decoder_builder(input_, skip, filters, 
                    activ='relu', kernel=(2,2),
                    drop=.5, pad='same', kern_init='he_uniform',
                    ):
  """Build one expanding (up-sampling) stage of the network.

  Layer order: Conv2DTranspose (stride 2) -> concatenate with `skip` ->
  BatchNormalization -> Dropout -> Conv2D -> BatchNormalization -> Dropout ->
  Conv2D.

  Args:
    input_: tensor coming from the previous (coarser) stage.
    skip: encoder tensor to merge in; it must have the same spatial size as
      the up-sampled `input_` so the two can be joined along the channel axis.
    filters: number of filters for the transposed conv and both convolutions.
    activ: activation passed to the transposed conv and both convolutions.
    kernel: kernel size of the transposed conv and both convolutions.
    drop: rate used by both Dropout layers.
    pad: padding mode of the transposed conv and both convolutions.
    kern_init: kernel initializer of the transposed conv and both convolutions.

  Returns:
    The output tensor of the final Conv2D in the stage.
  """
  kwargs = {'filters': filters, 'activation': activ, 'kernel_size': kernel, 
       'padding': pad, 'kernel_initializer': kern_init} 
  # Stride-2 transposed convolution doubles the spatial resolution.
  x = Conv2DTranspose(**kwargs, strides=(2,2))(input_)
  # Merge the skip connection along the channel (last) axis. The previous
  # axis=1 stacked the tensors along image height under the channels-last
  # layout, so the output no longer had the input's spatial size.
  x = concatenate([skip, x], axis=-1)
  x = BatchNormalization()(x)
  x = Dropout(drop)(x)
  x = Conv2D(**kwargs)(x)
  x = BatchNormalization()(x)
  x = Dropout(drop)(x)
  x = Conv2D(**kwargs)(x)
  return x

In [0]:
# Contracting path: each call returns (pre-pool features, 2x2 max-pooled features).
# Trailing comments give the pooled tensor's shape for the 256x256x3 input
# ('same' padding keeps the spatial size through the convolutions; pooling halves it).
input_layer = Input(shape=img_shape)
encoder1, pooled1 = encoder_builder(input_layer, filters=32)  # pooled1: 128x128x32
encoder2, pooled2 = encoder_builder(pooled1, filters=64)  # pooled2: 64x64x64
encoder3, pooled3 = encoder_builder(pooled2, filters=128)  # pooled3: 32x32x128
encoder4, pooled4 = encoder_builder(pooled3, filters=256)  # pooled4: 16x16x256
encoder5, pooled5 = encoder_builder(pooled4, filters=512)  # pooled5: 8x8x512
middle, _ = encoder_builder(pooled5, filters=1024)  # bottleneck (un-pooled): 8x8x1024; pooled output is discarded
# Expanding path: each stage up-samples and merges the matching encoder output as its skip connection.
decoder512 = decoder_builder(middle, skip=encoder5, filters=512)
decoder256 = decoder_builder(decoder512, skip=encoder4, filters=256)
decoder128 = decoder_builder(decoder256, skip=encoder3, filters=128)
decoder64 = decoder_builder(decoder128, skip=encoder2, filters=64)
decoder32 = decoder_builder(decoder64, skip=encoder1, filters=32)
# 1x1 convolution to a single sigmoid-activated output channel.
out_layer = Conv2D(1, (1, 1), activation='sigmoid')(decoder32)

In [0]:
# Wrap the input and output tensors into a Keras Model.
# NOTE(review): `models` is imported from the private `tensorflow.python.keras` path while
# the layers come from `tensorflow.keras` — consider using `tf.keras.models` so the two are not mixed.
unet = models.Model(inputs=[input_layer], outputs=[out_layer])

In [43]:
# Training configuration: Adam optimizer, binary cross-entropy loss, accuracy metric.
compile_settings = {
    'optimizer': 'adam',
    'loss': 'binary_crossentropy',
    'metrics': ['accuracy'],
}
unet.compile(**compile_settings)

# Print the layer-by-layer table (output shapes, parameter counts, connections).
unet.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_12 (InputLayer)           [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_140 (Conv2D)             (None, 256, 256, 32) 896         input_12[0][0]                   
__________________________________________________________________________________________________
batch_normalization_88 (BatchNo (None, 256, 256, 32) 128         conv2d_140[0][0]                 
__________________________________________________________________________________________________
dropout_136 (Dropout)           (None, 256, 256, 32) 0           batch_normalization_88[0][0]     
____________________________________________________________________________________________

In [44]:
# Show the kernel's current working directory as the cell output.
os.getcwd()

'/content'