<a href="https://colab.research.google.com/github/matician255/AlexNet-model/blob/main/U-net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPool2D, BatchNormalization, Concatenate, Activation, Conv2DTranspose, Input
from tensorflow.keras.models import Model

In [6]:
# convolutional block
def conv_block(inputs, num_filters):
  """Apply two Conv2D(3x3) -> BatchNormalization -> ReLU stages in sequence.

  Args:
    inputs: Keras tensor to transform.
    num_filters: Number of filters used by both 3x3 convolutions.

  Returns:
    The output tensor of the second ReLU activation.
  """
  x = inputs
  # Each stage consumes the previous stage's output. The original code fed
  # `inputs` into the second Conv2D, silently discarding the first stage.
  for _ in range(2):
    # padding='same' with stride 1 preserves height/width, so encoder outputs
    # keep the size the decoder's skip concatenation expects.
    x = Conv2D(num_filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
  return x

# encoder block
def encoder_block(inputs, num_filters):
  """Run one U-Net encoder stage.

  Args:
    inputs: Keras tensor entering this stage.
    num_filters: Filter count passed to conv_block.

  Returns:
    A pair ``(features, pooled)``. ``features`` is the conv_block output,
    which build_unet keeps for the decoder's skip connection. ``pooled`` is
    that output downsampled by a 2x2 max pool and fed to the next stage.
  """
  features = conv_block(inputs, num_filters)
  pooled = MaxPool2D(pool_size=(2, 2))(features)
  return features, pooled

def decoder_block(inputs, num_filters, skip_features):
  """Run one U-Net decoder stage.

  Args:
    inputs: Keras tensor from the previous (coarser) stage.
    num_filters: Filter count for the transposed conv and the conv_block.
    skip_features: Encoder tensor concatenated after upsampling.

  Returns:
    The conv_block output computed on ``[upsampled, skip_features]``
    concatenated along the default (last) axis.
  """
  # A strided 2x2 transposed convolution doubles height and width.
  up_layer = Conv2DTranspose(
      num_filters, kernel_size=(2, 2), strides=2, padding='same')
  upsampled = up_layer(inputs)
  merged = Concatenate()([upsampled, skip_features])
  return conv_block(merged, num_filters)



In [None]:
# Mount Google Drive only when running inside Colab. Nothing in the model code
# depends on the mount, so outside Colab we skip it (drive stays None) instead
# of failing with ModuleNotFoundError.
try:
  from google.colab import drive
except ImportError:
  drive = None
else:
  drive.mount('/content/drive')

In [7]:
# complete UNET architecture
def build_unet(input_shape):
  """Build a U-Net segmentation model.

  Architecture:
  - Four encoder stages (64, 128, 256, 512 filters), each ending in a 2x2 max
    pool that halves height and width.
  - A 1024-filter conv_block bridge.
  - Four decoder stages that upsample 2x and concatenate the matching encoder
    features.
  - A 1x1 sigmoid convolution that produces a single-channel output map.

  Args:
    input_shape: Shape of one sample, excluding the batch dimension, e.g.
      ``(512, 512, 3)``. Height and width should be divisible by 16.
      Otherwise the floor in the four 2x2 poolings makes the upsampled
      tensors disagree in size with the skip features, and Concatenate fails.

  Returns:
    An uncompiled Keras ``Model`` named 'U-NET'.
  """
  inputs = Input(input_shape)

  # Encoder: s* are full-resolution features kept for skip connections,
  # p* are the pooled tensors passed down to the next stage.
  s1, p1 = encoder_block(inputs, 64)
  s2, p2 = encoder_block(p1, 128)
  s3, p3 = encoder_block(p2, 256)
  s4, p4 = encoder_block(p3, 512)

  # Bridge at the lowest resolution.
  b = conv_block(p4, 1024)

  # Decoder: skips are consumed in reverse order of creation.
  d1 = decoder_block(b, 512, s4)
  d2 = decoder_block(d1, 256, s3)
  d3 = decoder_block(d2, 128, s2)
  d4 = decoder_block(d3, 64, s1)

  # Output: one channel squashed to [0, 1] by the sigmoid.
  outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(d4)

  model = Model(inputs, outputs, name='U-NET')
  return model

if __name__ == '__main__':
  # Build the model for 512x512 inputs with 3 channels and print its layer
  # summary once. The original printed the identical summary twice.
  input_shape = (512, 512, 3)
  model = build_unet(input_shape)
  model.summary()


