In [None]:
import tensorflow.keras as keras

def conv_block(x, filters, kernel_size=(3, 3), padding='same', strides=1):
    """Apply two ``Conv2D(ReLU) -> BatchNormalization`` stages to ``x``.

    Args:
        x: Input Keras tensor.
        filters: Number of filters used by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        padding: Padding mode used by both convolutions.
        strides: Stride of the FIRST convolution only. The second
            convolution always uses stride 1, so ``strides=2`` reduces the
            spatial size by 2 (not 4).

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    x = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation='relu')(x)
    x = keras.layers.BatchNormalization()(x)
    # Only the first conv may downsample; striding both would compound the
    # reduction to strides**2 overall.
    x = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=1, activation='relu')(x)
    x = keras.layers.BatchNormalization()(x)
    return x

def unet(input_size, n_classes=1):
    """Build a U-Net segmentation model.

    The encoder has four stages (64, 128, 256, 512 filters), each a
    ``conv_block`` followed by 2x2 max pooling, feeding a 1024-filter
    bottleneck. The decoder mirrors it: each stage upsamples with a 2x2
    stride-2 transposed convolution and concatenates the matching encoder
    output (skip connection) before another ``conv_block``.

    Args:
        input_size: Shape passed to ``keras.layers.Input``. The channel-axis
            concatenation assumes channels-last ``(height, width, channels)``.
            Known spatial dims must be divisible by 16 (four 2x pooling
            steps) so that upsampled maps line up with their skips.
        n_classes: Number of output channels. ``1`` produces a single-channel
            sigmoid mask; larger values produce a per-pixel softmax over
            ``n_classes`` channels.

    Returns:
        An uncompiled ``keras.models.Model``.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    inputs = keras.layers.Input(input_size)

    # Fail early with a clear message instead of an opaque concatenate
    # shape mismatch deep in the decoder.
    for dim in inputs.shape[1:3]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"Spatial dimensions of input_size must be divisible by 16, got {input_size}"
            )

    encoder_filters = (64, 128, 256, 512)

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = conv_block(x, filters)
        skips.append(x)
        x = keras.layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = conv_block(x, 1024)

    # Decoder: walk the encoder stages in reverse, deepest first.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = keras.layers.concatenate([x, skip], axis=-1)
        x = conv_block(x, filters)

    # Binary segmentation uses an independent sigmoid. Multi-class uses a
    # softmax across the class channels.
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    outputs = keras.layers.Conv2D(n_classes, (1, 1), activation=activation)(x)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])

    return model