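# U-Net

A Keras implementation of the U-Net encoder–decoder. Four down-sampling encoder blocks lead into a bridge, which feeds four up-sampling decoder blocks. Each decoder block is joined to its matching encoder block by a skip connection.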

![UNET](https://doimages.nyc3.cdn.digitaloceanspaces.com/010AI-ML/content/images/2021/05/image-26.png)


In [3]:
# !which python


In [4]:
from tensorflow.keras.layers import (
    Conv2D, MaxPool2D, UpSampling2D, Concatenate, Conv2DTranspose,
    BatchNormalization, Dropout, Lambda, Activation
)
from tensorflow.keras.models import Model
from tensorflow.keras import Input

In [5]:
def conv_block(inputs, num_filters):
    """Apply two Conv2D(3x3, "same") -> BatchNormalization -> ReLU stages.

    Args:
        inputs: Keras tensor to transform.
        num_filters: number of filters used by both convolutions.

    Returns:
        Keras tensor with the same spatial size as ``inputs`` and
        ``num_filters`` channels.
    """
    features = inputs
    # Two identical stages; layers are created in the same order as
    # writing them out by hand.
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

In [6]:
# Sanity check: "same" padding keeps the 256x256 spatial size, and the
# block maps 3 input channels to num_filters=32 output channels.
x = Input((256, 256, 3))
y = conv_block(x, 32)
assert tuple(y.shape) == (None, 256, 256, 32), y.shape
print(y.shape)

(None, 256, 256, 32)


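### **Encoder block**

`conv_block` followed by 2×2 max-pooling. The pre-pooling features are returned alongside the pooled output, so the matching decoder block can use them as a skip connection.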

In [7]:
def encoder_block(inputs, num_filters):
    """Run ``conv_block`` and then halve the spatial size with 2x2 max-pooling.

    Args:
        inputs: Keras tensor entering this encoder stage.
        num_filters: filters passed to ``conv_block``.

    Returns:
        ``(skip, pooled)``. ``skip`` is the pre-pooling feature map, kept for
        the decoder. ``pooled`` is the 2x2 max-pooled map that feeds the
        next stage.
    """
    features = conv_block(inputs, num_filters)
    downsampled = MaxPool2D(pool_size=(2, 2))(features)
    return features, downsampled

In [8]:
# Sanity check: the skip output keeps the input resolution, and the
# pooled output is half the resolution in each spatial dimension.
x = Input((256, 256, 3))
s, p = encoder_block(x, 32)
assert tuple(s.shape) == (None, 256, 256, 32), s.shape
assert tuple(p.shape) == (None, 128, 128, 32), p.shape
print(s.shape, p.shape)

(None, 256, 256, 32) (None, 128, 128, 32)


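### **Decoder block**

A stride-2 transposed convolution doubles the spatial size. The encoder's skip features are then concatenated, and `conv_block` refines the result.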

In [9]:
def decoder_block(inputs, skip, num_filters):
    """Upsample ``inputs`` 2x, concatenate ``skip``, and refine with ``conv_block``.

    Args:
        inputs: feature map from the deeper stage.
        skip: encoder feature map. Its spatial size must be twice that of
            ``inputs`` so the concatenation (along the last axis) lines up.
        num_filters: filters for the transposed conv and for ``conv_block``.

    Returns:
        Keras tensor at the skip's resolution with ``num_filters`` channels.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    merged = Concatenate()([upsampled, skip])
    return conv_block(merged, num_filters)

In [10]:
# Sanity check: the stride-2 transposed conv doubles 256 -> 512, so the
# skip tensor must already be 512x512 for the concatenation to succeed.
x = Input((256, 256, 3))
s = Input((512, 512, 3))
y = decoder_block(x, s, 32)
assert tuple(y.shape) == (None, 512, 512, 32), y.shape
print(y.shape)

(None, 512, 512, 32)


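### **U-Net**

The network has four encoder blocks (64 → 512 filters), a 1024-filter bridge, four decoder blocks that consume the encoder skips deepest-first, and a final 1×1 convolution.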

In [11]:
def build_unet(input_shape, num_classes=1):
    """Build the U-Net model.

    Layout: four encoder stages (64, 128, 256, 512 filters), each halving the
    spatial size, then a 1024-filter bridge, then four decoder stages that
    mirror the encoder and concatenate the matching skip features, then a
    1x1 convolution head.

    Args:
        input_shape: ``(height, width, channels)`` passed to ``Input``.
            Height and width must be ``None`` or divisible by 16, because the
            four 2x2 poolings must divide evenly. Otherwise the decoder's
            skip concatenations see mismatched spatial sizes.
        num_classes: number of output channels. ``1`` (the default) gives a
            single-channel sigmoid output, matching the original model.
            Values greater than 1 give a softmax over ``num_classes``
            channels per pixel.

    Returns:
        A ``Model`` named ``"UNET"``.

    Raises:
        ValueError: if ``num_classes < 1``, or if a known height/width is
            not divisible by 16.
    """
    # Four 2x2 max-poolings -> total downsampling factor of 2**4.
    downsample_factor = 16

    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    for dim in tuple(input_shape)[:2]:
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                "Height and width in input_shape must be divisible by "
                f"{downsample_factor}, got {tuple(input_shape)}"
            )

    inputs = Input(input_shape)

    # Encoder: s* are the pre-pooling skip features, p* feed the next stage.
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    # Bridge
    b1 = conv_block(p4, 1024)

    # Decoder: consume skips deepest-first so spatial sizes line up.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    if num_classes == 1:
        outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
    else:
        outputs = Conv2D(num_classes, 1, padding="same", activation="softmax")(d4)

    model = Model(inputs, outputs, name="UNET")
    return model

### **Build the model**

Instantiate the network for 256×256, 3-channel inputs and inspect its layers.

In [12]:
# 256 is divisible by 16, as required by the four pooling stages.
input_shape = (256, 256, 3)
model = build_unet(input_shape)
# The 1x1 head keeps the input's spatial size, with one output channel.
assert tuple(model.outputs[0].shape) == (None, 256, 256, 1), model.outputs[0].shape

In [13]:
# Show the Keras summary of the model built above.
model.summary()