# U-Net Model for Segmentation

We will convert this notebook into a script, since it contains only the model definition and no data. The model still needs to be compiled and trained; after training, we will record the loss and metrics alongside our predictions.

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model

In [4]:
def double_conv(x, c_out):
    """Apply two (3x3 Conv2D -> BatchNormalization -> ReLU) stages to ``x``.

    Args:
        x: Input Keras tensor.
        c_out: Number of filters (output channels) used by BOTH convolutions.

    Returns:
        A tensor with ``c_out`` channels. Spatial size is unchanged, because
        both convolutions use stride 1 with ``padding="same"``.
    """
    for _ in range(2):
        # The bias is omitted because the BatchNormalization that follows
        # learns its own per-channel offset, which would make a conv bias
        # redundant.
        x = layers.Conv2D(c_out, 3, padding="same", use_bias=False)(x)
        # Normalize each channel's activations to stabilize training.
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x

In [None]:
def UNetSmallTF(input_shape=(256,256,1), num_classes=3, base=32):
    """Build a small 3-level U-Net that outputs per-pixel class logits.

    Encoder: three ``double_conv`` stages with ``base``, ``base*2`` and
    ``base*4`` filters. Each stage is followed by 2x2 max pooling. A
    ``double_conv`` bottleneck with ``base*8`` filters sits at the bottom.

    Decoder: three stages mirror the encoder. Each upsamples 2x with a
    learnable ``Conv2DTranspose`` and concatenates the matching encoder output
    (skip connection) to reintroduce fine spatial detail. It then applies a
    ``double_conv``.

    A final 1x1 convolution acts as a per-pixel linear classifier. It has no
    activation, so the output is raw logits with ``num_classes`` channels.

    Args:
        input_shape: Shape of one input image, excluding the batch dimension.
            The default ``(256, 256, 1)`` is a 256x256 single-channel image.
        num_classes: Number of output channels (classes) in the logits.
        base: Filter count of the first encoder stage. Deeper stages use
            multiples of it.

    Returns:
        A ``tf.keras.Model`` named ``"UNetSmallTF"``.

    Note:
        Each pooling floors odd sizes, while each transpose conv exactly
        doubles. Spatial dimensions that are not multiples of 8 therefore
        produce mismatched shapes at the skip-connection concatenation.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: keep each stage's output as a skip connection, then halve the
    # spatial resolution.
    skip_tensors = []
    features = inputs
    for depth in range(3):
        features = double_conv(features, base * 2 ** depth)
        skip_tensors.append(features)
        features = layers.MaxPool2D(2)(features)

    # Bottleneck: the deepest, widest stage of the U.
    features = double_conv(features, base * 8)

    # Decoder: walk back up the U, pairing each stage with its encoder skip.
    for depth in reversed(range(3)):
        filters = base * 2 ** depth
        features = layers.Conv2DTranspose(filters, 2, strides=2, padding="same")(features)
        features = layers.Concatenate()([features, skip_tensors[depth]])
        features = double_conv(features, filters)

    # 1x1 convolution: an independent linear classifier at every pixel.
    logits = layers.Conv2D(num_classes, 1, padding="same")(features)
    return Model(inputs, logits, name="UNetSmallTF")