In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

In [2]:
def z_sampling(argv):
    """Draw a latent sample using the VAE reparameterization trick.

    Computes ``z = mean + exp(0.5 * log_var) * epsilon`` with
    ``epsilon ~ N(0, 1)`` drawn independently per element. Sampling this way
    keeps ``z`` differentiable with respect to ``mean`` and ``log_var``.

    Args:
        argv: A two-element sequence ``[mean, log_var]`` of tensors with the
            same shape, as passed by ``layers.Lambda`` when it is called on
            a list of inputs.

    Returns:
        A tensor with the same shape as ``mean``.
    """
    mean, log_var = argv
    # tf.random.normal defaults to mean=0.0, stddev=1.0. It replaces the
    # TF1-only tf.random_normal alias, which does not exist in TensorFlow 2.
    epsilon = tf.random.normal(shape=tf.shape(mean))
    # exp(0.5 * log_var) == sqrt(var), i.e. the standard deviation.
    return mean + tf.exp(0.5 * log_var) * epsilon

def _conv3x3_relu(n_filters, name, use_bias):
    """Return a 3x3, 'same'-padded, ReLU-activated Conv2D layer called `name`."""
    return layers.Conv2D(n_filters, 3, padding='same', activation='relu',
                         use_bias=use_bias, name=name)


def VAE(latent_dim=16, use_bias=True, filters=(64, 128, 256, 512, 1024),
        output_activation='relu'):
    """Build a fully-convolutional VAE with a spatial (per-location) latent code.

    Encoder:
        One 3x3 ReLU conv per entry of `filters`, named ``Encoder_1`` ...
        ``Encoder_<len(filters)>``. Each is followed by a 2x2 max-pool,
        except the last (bottleneck) level.

    Latent:
        Two 1x1 convs produce the latent mean and log-variance. A sample
        ``z`` is drawn from them with `z_sampling`.

    Decoder:
        Mirrors the encoder. Each level applies a bilinear 2x upsample and a
        3x3 ReLU conv, named ``Decoder_<len(filters)-1>`` ... ``Decoder_1``.
        A final 1x1 ``Prediction`` conv maps back to one channel.

    With the default `filters`, the layer names, creation order and shapes
    are the same as the original hard-coded five-level network.

    Args:
        latent_dim: Number of latent channels at each spatial location.
        use_bias: Whether the 3x3 encoder/decoder convs and the prediction
            conv use a bias term.
        filters: Channel widths of the encoder levels, shallowest first.
            The last entry is the bottleneck width. The decoder reuses the
            other entries in reverse order. Must be non-empty.
        output_activation: Activation of the final ``Prediction`` conv.
            Defaults to 'relu' (the original behaviour), which clamps
            outputs at zero. Use 'sigmoid' for targets scaled to [0, 1], or
            None for a linear output.

    Returns:
        A ``keras.Model`` mapping a (batch, H, W, 1) image to
        ``[mean, log_var, reconstruction]``. ``mean`` and ``log_var`` have
        shape (batch, H / 2**(len(filters)-1), W / 2**(len(filters)-1),
        latent_dim).

    Raises:
        ValueError: If `filters` is empty.

    Note:
        MaxPool2D uses 'valid' padding by default and floors odd sizes. The
        reconstruction therefore only matches the input's H and W when both
        are divisible by 2**(len(filters)-1), which is 16 with the defaults.
    """
    filters = tuple(filters)
    if not filters:
        raise ValueError("filters must contain at least one channel width")
    n_levels = len(filters)

    # Spatial dims are None so the network accepts any H x W (see Note above).
    inputs = layers.Input(shape=(None, None, 1), name="Input")

    # Encoder: every level except the bottleneck halves H and W.
    x = inputs
    for level, n_filters in enumerate(filters, start=1):
        x = _conv3x3_relu(n_filters, "Encoder_%d" % level, use_bias)(x)
        if level < n_levels:
            x = layers.MaxPool2D()(x)

    # Spatial latent parameters via 1x1 convs.
    # NOTE(review): these always use a bias, regardless of `use_bias`;
    # confirm that is intended.
    mean = layers.Conv2D(latent_dim, 1, name="z_mean")(x)
    log_var = layers.Conv2D(latent_dim, 1, name="z_log_var")(x)

    z = layers.Lambda(z_sampling)([mean, log_var])

    # Decoder: walk the non-bottleneck encoder levels in reverse,
    # doubling H and W at each level.
    x = z
    for level in range(n_levels - 1, 0, -1):
        x = layers.UpSampling2D(interpolation='bilinear')(x)
        x = _conv3x3_relu(filters[level - 1], "Decoder_%d" % level, use_bias)(x)

    # Project back to a single output channel.
    x = layers.Conv2D(1, 1, padding='same', activation=output_activation,
                      use_bias=use_bias, name="Prediction")(x)

    return models.Model(inputs=inputs, outputs=[mean, log_var, x])


In [3]:
# Build a VAE whose latent code has 128 channels per spatial location,
# then print the per-layer table (output shapes and parameter counts).
model = VAE(use_bias=True, latent_dim=128)
model.summary()

Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input (InputLayer)              (None, None, None, 1 0                                            
__________________________________________________________________________________________________
Encoder_1 (Conv2D)              (None, None, None, 6 640         Input[0][0]                      
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, None, None, 6 0           Encoder_1[0][0]                  
__________________________________________________________________________________________________
Encoder_2 (Conv2D)              (None, None, None, 1 73856       max_pooling2d[0][0]              
_____________________________________