In [1]:
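## Building blocks - encoder: 4x4 Conv2D, stride 2, optional BatchNorm, LeakyReLU(0.2); decoder: 4x4 Conv2DTranspose, stride 2, BatchNorm, optional Dropout(0.5), ReLU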
from keras import layers, Input, Model

class EncodeBlock(layers.Layer):
    """Downsampling block: Conv2D(4x4, stride 2, "same") -> [BatchNorm] -> LeakyReLU(0.2).

    With stride 2 and "same" padding the output spatial size is ceil(H/2) x ceil(W/2);
    the channel dimension becomes ``n_filters``.

    Args:
        n_filters: number of output channels of the convolution.
        use_bn: if True, apply BatchNormalization between the convolution and the activation.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``); also required so
            that ``from_config`` can re-create the layer from a saved config.
    """

    def __init__(self, n_filters, use_bn=True, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can serialize the constructor arguments.
        self.n_filters = n_filters
        self.use_bn = use_bn
        # Bias disabled: with use_bn=True the BatchNormalization offset makes it redundant.
        # NOTE(review): with use_bn=False this leaves the block with no learnable offset;
        # consider use_bias=not use_bn (would change the weight set).
        self.conv = layers.Conv2D(n_filters, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.lrelu = layers.LeakyReLU(0.2)

    def call(self, x, training=None):
        x = self.conv(x)
        if self.use_bn:
            # Explicit flag: batch statistics while training, moving averages at inference.
            # None falls back to the training value of the enclosing call.
            x = self.batchnorm(x, training=training)
        return self.lrelu(x)

    def get_config(self):
        """Return the constructor arguments so Keras can rebuild this layer on load."""
        config = super().get_config()
        config.update({"n_filters": self.n_filters, "use_bn": self.use_bn})
        return config

class DecodeBlock(layers.Layer):
    """Upsampling block: Conv2DTranspose(4x4, stride 2, "same") -> BatchNorm -> [Dropout(0.5)] -> ReLU.

    With stride 2 and "same" padding the output spatial size is 2H x 2W;
    the channel dimension becomes ``f``.

    Args:
        f: number of output channels of the transposed convolution.
        dropout: if True, apply Dropout(0.5) after batch normalization.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``); also required so
            that ``from_config`` can re-create the layer from a saved config.
    """

    def __init__(self, f, dropout=True, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can serialize the constructor arguments.
        self.f = f
        self.dropout = dropout
        # Bias disabled: the following BatchNormalization supplies the offset.
        self.Transconv = layers.Conv2DTranspose(f, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        # Created once here rather than inside call(): a layer built inside call() is a new,
        # untracked object on every invocation (and may create state inside traced steps).
        self.dropout_layer = layers.Dropout(0.5)
        self.relu = layers.ReLU()

    def call(self, x, training=None):
        x = self.Transconv(x)
        # None falls back to the training value of the enclosing call.
        x = self.batchnorm(x, training=training)
        if self.dropout:
            # Dropout is only active when training is true.
            x = self.dropout_layer(x, training=training)
        return self.relu(x)

    def get_config(self):
        """Return the constructor arguments so Keras can rebuild this layer on load."""
        config = super().get_config()
        config.update({"f": self.f, "dropout": self.dropout})
        return config

In [7]:
## Nested-skip (UNet++-style) generator: encoder nodes x00..x40 plus intermediate
## decoder nodes x{i}{j}, where i is the depth level and j the column of the node.
## Each EncodeBlock halves H and W; each DecodeBlock doubles them.
## NOTE(review): no convolution follows any Concatenate, so earlier nodes are
## re-concatenated verbatim (x00 appears 8 times inside x04, which has 1024 channels
## at half the input resolution). If this is meant to follow UNet++ (Zhou et al.),
## that design applies a conv block after every concatenation - confirm intent.
INPUT_SHAPE = (512, 512, 1)  # (height, width, channels)
in_h, in_w, _ = INPUT_SHAPE
# Five stride-2 encoder stages: H and W must divide by 32 for the skip shapes to match.
assert in_h % 32 == 0 and in_w % 32 == 0, "input height/width must be multiples of 32"

inp = Input(INPUT_SHAPE)
## down
x00 = EncodeBlock(64)(inp)
x10 = EncodeBlock(128)(x00)
x20 = EncodeBlock(256)(x10)
x30 = EncodeBlock(512)(x20)

## bottle
x40 = EncodeBlock(512)(x30)
print(x40.shape)
assert tuple(x40.shape) == (None, in_h // 32, in_w // 32, 512)

## middle-1: upsample level i+1 and join it with level i
x10up = DecodeBlock(64)(x10)
x01 = layers.Concatenate()([x00, x10up])
x20up = DecodeBlock(128)(x20)
x11 = layers.Concatenate()([x10, x20up])
x30up = DecodeBlock(256)(x30)
x21 = layers.Concatenate()([x20, x30up])

## middle-2
x11up = DecodeBlock(64)(x11)
x02 = layers.Concatenate()([x00, x01, x11up])
x21up = DecodeBlock(128)(x21)
x12 = layers.Concatenate()([x10, x11, x21up])

## middle-3
x12up = DecodeBlock(64)(x12)
x03 = layers.Concatenate()([x00, x01, x02, x12up])

## up: main decoder path from the bottleneck, joined with every earlier node per level
x40up = DecodeBlock(512)(x40)
x31 = layers.Concatenate()([x30, x40up])
x31up = DecodeBlock(256)(x31)
x22 = layers.Concatenate()([x20, x21, x31up])
x22up = DecodeBlock(128)(x22)
x13 = layers.Concatenate()([x10, x11, x12, x22up])
x13up = DecodeBlock(64)(x13)
print(x13up.shape)
assert tuple(x13up.shape) == (None, in_h // 2, in_w // 2, 64)
x04 = layers.Concatenate()([x00, x01, x02, x03, x13up])
## Final upsampling back to input resolution; linear output (no activation).
## NOTE(review): use_bias=False is not followed by BatchNormalization here, so the output
## has no learnable offset - consider use_bias=True (changes the weight set).
x = layers.Conv2DTranspose(1, 4, 2, "same", use_bias=False)(x04)
print(x.shape)
assert tuple(x.shape) == (None, in_h, in_w, 1)

## Wrap the functional graph into a trainable model.
unetpp = Model(inputs=inp, outputs=x, name="unet_plus_plus")

(None, 16, 16, 512)
(None, 256, 256, 64)
(None, 512, 512, 1)
