In [3]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from typing import Tuple

In [2]:
tf.__version__

'2.4.0'

# Part I: Model

In [14]:
class Sampler(layers.Layer):
    """
    Implement the sampling layer in the Variational Auto-Encoder model
    
    Use (z_mean, z_log_var) to sample z
    
    z = mean + sigma * epsilon
    """
    def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
        """
        params:
            inputs: a tuple of z_mean, z_log_var
              z_mean: tensor, shape: (batch, latent_dim)
              z_log_var: tensor, shape: (batch, latent_dim)
        retursn:
            sampled hidden variable z: (batch, latent_dim)
        """
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
                

In [31]:
class Encoder(layers.Layer):
    """
    Implement the Encoder layer in the Variational Auto-Encoder model    
    """
    def __init__(self, intermediate_dim=64, latent_dim=32, name='encoder', **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation='relu')
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampler = Sampler()
        
    def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        params:
            inputs: input data, shape: (batch, original_dim)
        returns:            
            z_mean, z_log_var, z
              z_mean: shape: (batch, latent_dim)
              z_log_var: shape: (batch, latent_dim)
              z: shape: (batch, latent_dim)
        """
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampler((z_mean, z_log_var))
        return z_mean, z_log_var, z
    

In [32]:
class Decoder(layers.Layer):
    """
    Implement the Decoder layer in the Variational Auto-Encoder model
    """
    def __init__(self, original_dim, intermediate_dim=64, name='decoder', **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation='relu')
        self.dense_output = layers.Dense(original_dim, activation='sigmoid')
        
    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        params:
            inputs: encoded input, shape: (batch, latent_dim)
        returns:
            reconstructed input, shape: (batch, original_dim)
        """
        x = self.dense_proj(inputs)
        return self.dense_output(x)
    

In [33]:
class VariationalAutoEncoder(keras.Model):
    """
    Implement a Variational Auto-Encoder model
    """
    
    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32, name='autoencoder', **kwargs):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)
    
    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        params:
            inputs: original data, shape: (batch, original_dim)
        returns:
            reconstructed data, shape: (batch, original_dim)
        """
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        
        # add KL divergence regularization loss
        kl_loss = -0.5 * tf.reduce_mean(
            1 + z_log_var - tf.exp(z_log_var) - tf.square(z_mean)
        )
        self.add_loss(kl_loss)
        
        return reconstructed
        

# Part II: Training

## Training data

In [21]:
(x_train, _y_train), _ = tf.keras.datasets.mnist.load_data()

In [25]:
type(x_train), x_train.shape

(numpy.ndarray, (60000, 28, 28))

In [26]:
x_train = x_train.reshape(-1, 784).astype('float32') / 255

In [27]:
x_train.shape

(60000, 784)

In [28]:
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

In [35]:
for batch_idx, batch in enumerate(train_dataset):
    if batch_idx == 0:
        break

In [37]:
batch.shape

TensorShape([64, 784])

## Specifying modules

In [34]:
original_dim = 28 * 28
print(f'original_dim: {original_dim}')

vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

original_dim: 784


## Training loop

In [45]:
epochs = 2
for epoch_idx in range(epochs):
    for batch_idx, x_train_batch in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_train_batch)
            
            # compute reconstruction loss
            loss = mse_loss_fn(x_train_batch, reconstructed)
            # add KL divergence loss
            loss += sum(vae.losses)
        
        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))
        
        loss_metric(loss)
        
        if batch_idx % 100 == 0:
            print(f'step: {batch_idx}, mean loss {loss_metric.result()}')
            

step: 0, mean loss 0.3160361349582672
step: 100, mean loss 0.12490179389715195
step: 200, mean loss 0.09885703772306442
step: 300, mean loss 0.08890584856271744
step: 400, mean loss 0.08404117077589035
step: 500, mean loss 0.08076141029596329
step: 600, mean loss 0.0786130279302597
step: 700, mean loss 0.0770300105214119
step: 800, mean loss 0.07590256631374359
step: 900, mean loss 0.07487347722053528
step: 0, mean loss 0.07458935678005219
step: 100, mean loss 0.07393629103899002
step: 200, mean loss 0.07344485074281693
step: 300, mean loss 0.07296932488679886
step: 400, mean loss 0.072649747133255
step: 500, mean loss 0.07224930822849274
step: 600, mean loss 0.07195466756820679
step: 700, mean loss 0.07167019695043564
step: 800, mean loss 0.07143555581569672
step: 900, mean loss 0.07117217779159546
