<a href="https://colab.research.google.com/github/mahdiislam79/Image_segmentation_practice/blob/main/Making_a_U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
# u-net model

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam

In [32]:
def conv_block(prev, filters, dropout, return_skip=False):
  """U-Net encoder step: two 3x3 ReLU convolutions (dropout in between), then 2x2 max-pooling.

  Args:
    prev: input tensor of the step.
    filters: number of filters used by both 3x3 convolutions.
    dropout: dropout rate applied between the two convolutions.
    return_skip: when True, also return the pre-pooling activation. A U-Net
      decoder needs that tensor as its skip connection; the pooled tensor has
      half the spatial size and cannot be concatenated with the upsampled
      decoder feature map. Defaults to False, which keeps the original
      single-tensor return.

  Returns:
    The max-pooled tensor (pool size 2, stride 2). If ``return_skip`` is True,
    the tuple ``(skip, pooled)`` is returned instead, where ``skip`` is the
    output of the second convolution before pooling.
  """
  conv = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same')(prev)
  conv = Dropout(dropout)(conv)
  conv = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same')(conv)
  pool = MaxPooling2D(2, 2)(conv)

  if return_skip:
    return conv, pool
  return pool


In [33]:
def conv_transpose_block(prev, filters, dropout, spat_info):
  """U-Net decoder step: learned 2x upsampling, skip concatenation, two 3x3 ReLU convolutions.

  Args:
    prev: tensor coming from the deeper decoder/bottleneck level.
    filters: filter count of the transposed convolution and both 3x3 convolutions.
    dropout: dropout rate applied between the two 3x3 convolutions.
    spat_info: encoder tensor concatenated after upsampling; its spatial size
      must match the upsampled ``prev`` (presumably the pre-pool encoder output
      of the same level — confirm at the call site).

  Returns:
    Output tensor of the second 3x3 convolution.
  """
  # Shared settings for both 3x3 convolutions of this step.
  conv_opts = dict(activation='relu', kernel_initializer='he_uniform', padding='same')

  upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(prev)
  merged = concatenate([upsampled, spat_info])

  out = Conv2D(filters, (3, 3), **conv_opts)(merged)
  out = Dropout(dropout)(out)
  return Conv2D(filters, (3, 3), **conv_opts)(out)


In [34]:
# Scratch input placeholder of shape (256, 256, 1) — (height, width, channels)
# under Keras' default channels-last layout; `s` is simply an alias of it.
s = inputs = Input(shape=(256, 256, 1))

In [15]:
def unet_model_1(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, n_classes):
  """Build, compile and summarise a 4-level U-Net assembled from reusable blocks.

  Fix: the previous version fed the *pooled* encoder outputs into the skip
  connections. Each skip tensor therefore had half the spatial size of the
  upsampled decoder tensor, and ``concatenate`` failed with a shape mismatch
  (e.g. 32x32 vs 16x16 at the first decoder level for a 256x256 input). The
  encoder now keeps the pre-pool activation of every level for its skip.

  Args:
    IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS: input shape (height, width, channels).
      Four 2x poolings followed by four 2x upsamplings only line up when
      height and width are divisible by 16.
    n_classes: number of output channels of the final 1x1 sigmoid convolution.

  Returns:
    A compiled ``Model`` (Adam lr=1e-3, binary cross-entropy, accuracy). Its
    summary is printed as a side effect.
  """

  def _encoder_block(prev, filters, dropout):
    # Same layer stack as conv_block, but it returns (pre-pool tensor, pooled
    # tensor). The decoder needs the pre-pool tensor as its skip connection.
    conv = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same')(prev)
    conv = Dropout(dropout)(conv)
    conv = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same')(conv)
    pool = MaxPooling2D(2, 2)(conv)
    return conv, pool

  # Build the model
  inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
  s = inputs

  # Contraction path: keep the pre-pool tensors (c1..c4) for the skip connections.
  c1, p1 = _encoder_block(s, 16, 0.1)
  c2, p2 = _encoder_block(p1, 32, 0.1)
  c3, p3 = _encoder_block(p2, 64, 0.2)
  c4, p4 = _encoder_block(p3, 128, 0.2)

  # Bottleneck
  c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same')(p4)
  c5 = Dropout(0.3)(c5)
  c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same')(c5)

  # Expansive path: each step upsamples 2x and concatenates the encoder tensor
  # of the same resolution.
  c6 = conv_transpose_block(filters=128, prev=c5, dropout=0.2, spat_info=c4)
  c7 = conv_transpose_block(filters=64, prev=c6, dropout=0.2, spat_info=c3)
  c8 = conv_transpose_block(filters=32, prev=c7, dropout=0.1, spat_info=c2)
  c9 = conv_transpose_block(filters=16, prev=c8, dropout=0.1, spat_info=c1)

  outputs = Conv2D(n_classes, (1, 1), activation='sigmoid')(c9)

  model = Model(inputs=[inputs], outputs=[outputs])

  # Sigmoid + binary cross-entropy suits a binary (n_classes == 1) mask. For
  # mutually exclusive multi-class masks, switch the output activation to
  # 'softmax' and use a categorical cross-entropy loss instead.
  model.compile(optimizer=Adam(learning_rate=1e-3), loss='binary_crossentropy', metrics=['accuracy'])

  model.summary()

  return model

In [21]:
def unet_model(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, n_classes):
  """Build, compile and summarise a 4-level U-Net (16-32-64-128 filters, 256 at the bottleneck).

  Args:
    IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS: input shape (height, width, channels).
    n_classes: number of output channels of the final 1x1 sigmoid convolution.

  Returns:
    A compiled ``Model`` (Adam lr=1e-3, binary cross-entropy, accuracy). Its
    summary is printed as a side effect.
  """
  # Shared settings for every 3x3 convolution in the network.
  conv_opts = dict(activation='relu', kernel_initializer='he_uniform', padding='same')

  def double_conv(x, filters, rate):
    # 3x3 conv -> dropout -> 3x3 conv, all with the same filter count.
    x = Conv2D(filters, (3, 3), **conv_opts)(x)
    x = Dropout(rate)(x)
    return Conv2D(filters, (3, 3), **conv_opts)(x)

  inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
  # No in-graph rescaling (e.g. a Lambda dividing by 255): inputs are expected
  # to be normalised before they reach the model.
  x = inputs

  # Contraction path: (filters, dropout) per level. The pre-pool tensor of each
  # level is stored for the matching decoder level.
  skips = []
  for filters, rate in ((16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2)):
    x = double_conv(x, filters, rate)
    skips.append(x)
    x = MaxPooling2D(2, 2)(x)

  # Bottleneck
  x = double_conv(x, 256, 0.3)

  # Expansive path (Conv2DTranspose for learned upsampling; UpSampling2D would
  # also work). Skips are consumed deepest-first.
  decoder_levels = ((128, 0.2), (64, 0.2), (32, 0.1), (16, 0.1))
  for (filters, rate), skip in zip(decoder_levels, reversed(skips)):
    x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip])
    x = double_conv(x, filters, rate)

  outputs = Conv2D(n_classes, (1, 1), activation='sigmoid')(x)

  model = Model(inputs=[inputs], outputs=[outputs])

  # Sigmoid + binary cross-entropy suits a binary (n_classes == 1) mask; for
  # mutually exclusive multi-class masks use 'softmax' and a categorical loss.
  model.compile(optimizer=Adam(learning_rate=1e-3), loss='binary_crossentropy', metrics=['accuracy'])

  model.summary()

  return model



In [23]:
# 256x256 single-channel input with one sigmoid output channel; unet_model
# also prints the model summary.
my_unet_model = unet_model(
    IMG_HEIGHT=256,
    IMG_WIDTH=256,
    IMG_CHANNELS=1,
    n_classes=1,
)

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 256, 256, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_48 (Conv2D)             (None, 256, 256, 16  160         ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 dropout_23 (Dropout)           (None, 256, 256, 16  0           ['conv2d_48[0][0]']              
                                )                                                           