### Setup

In [1]:
import numpy as np
import tensorflow as tf
import keras

import os
import datetime
import time as timer
import sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

import models as M


### Constants

In [2]:
EPOCH = 30
BATCH = 128

DATADIR = "/Users/mghifary/Work/Code/AI/data"
MODELDIR = "/Users/mghifary/Work/Code/AI/models"

### Load dataset

In [3]:
data_path = os.path.join(DATADIR, "mnist.npz")
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data(data_path)
x_train = x_train.astype("float32") / 255.
x_test = x_test.astype("float32") / 255.

mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1)


ds_train = tf.data.Dataset.from_tensor_slices(mnist_digits)
ds_train = ds_train.shuffle(buffer_size=1024).batch(BATCH)


### Train VAE

In [4]:
# Setup summary writers for tensorboard
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/vae-mnist/' + current_time + '/train'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

In [5]:
latent_dim = 2
encoder = M.create_conv_encoder(input_shape=(28, 28, 1), latent_dim=latent_dim)
decoder = M.create_conv_decoder(latent_dim=latent_dim)
model = M.VAE(encoder, decoder)

optimizer = keras.optimizers.legacy.Adam(3e-4)
modelpath = os.path.join(MODELDIR, "vae-conv_mnist.h5")

model.build(input_shape=(None, 28, 28, 1))

In [6]:
model.encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 28, 28, 1)]          0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 14, 14, 32)           320       ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 7, 7, 64)             18496     ['conv2d[0][0]']              
                                                                                                  
 flatten (Flatten)           (None, 3136)                 0         ['conv2d_1[0][0]']            
                                                                                            

In [7]:
model.decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_1 (Dense)             (None, 3136)              9408      
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose (Conv2DTr  (None, 14, 14, 64)        36928     
 anspose)                                                        
                                                                 
 conv2d_transpose_1 (Conv2D  (None, 28, 28, 32)        18464     
 Transpose)                                                      
                                                                 
 conv2d_transpose_2 (Conv2D  (None, 28, 28, 1)         289 

In [8]:
model.summary()

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Functional)        [(None, 2),               69076     
                              (None, 2),                         
                              (None, 2)]                         
                                                                 
 decoder (Functional)        (None, 28, 28, 1)         65089     
                                                                 
Total params: 134171 (524.11 KB)
Trainable params: 134165 (524.08 KB)
Non-trainable params: 6 (24.00 Byte)
_________________________________________________________________


In [9]:
# Training on batch function
@tf.function
def train_vae_on_batch(model, optimizer, inputs):
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = model.encoder(inputs, training=True)
        reconstruction = model.decoder(z, training=True)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(inputs, reconstruction), axis=(1 , 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        
    grads = tape.gradient(total_loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    
    model.total_loss_tracker.update_state(total_loss)
    model.reconstruction_loss_tracker.update_state(reconstruction_loss)
    model.kl_loss_tracker.update_state(kl_loss)

In [10]:
# Do training
for epoch in range(EPOCH):
    # Mini-batch training
    train_duration = 0.0
    for step, x in enumerate(ds_train):
        start_t = timer.time()
        train_vae_on_batch(model, optimizer, x)
        elapsed_t = timer.time() - start_t

        train_duration += elapsed_t

    tloss = model.total_loss_tracker.result()
    rloss = model.reconstruction_loss_tracker.result()
    klloss = model.kl_loss_tracker.result()

    
    # Store log
    with train_summary_writer.as_default():
        tf.summary.scalar('total_loss', tloss, step=epoch)
        tf.summary.scalar('reconstruction_loss', rloss, step=epoch)
        tf.summary.scalar('kl_loss', klloss, step=epoch)

    print(f"Epoch {epoch+1} - Training [total_loss: {tloss:.5f}, reconstruction_loss: {rloss:.5f}, kl_loss: {klloss:.5f}] ({train_duration:.3f} secs)")

    # Save model
    model.save_weights(modelpath, overwrite=True, save_format=None, options=None)

Epoch 1 - Training [total_loss: 234.45265, reconstruction_loss: 231.09221, kl_loss: 3.36059] (16.185 secs)
Epoch 2 - Training [total_loss: 213.26534, reconstruction_loss: 209.73053, kl_loss: 3.53488] (15.843 secs)
Epoch 3 - Training [total_loss: 201.00110, reconstruction_loss: 196.85825, kl_loss: 4.14256] (17.604 secs)
Epoch 4 - Training [total_loss: 193.28387, reconstruction_loss: 188.80579, kl_loss: 4.47765] (17.851 secs)
Epoch 5 - Training [total_loss: 187.90463, reconstruction_loss: 183.20607, kl_loss: 4.69811] (17.625 secs)
Epoch 6 - Training [total_loss: 183.85678, reconstruction_loss: 178.99518, kl_loss: 4.86134] (17.759 secs)
Epoch 7 - Training [total_loss: 180.66057, reconstruction_loss: 175.67039, kl_loss: 4.98987] (17.846 secs)
Epoch 8 - Training [total_loss: 178.08731, reconstruction_loss: 172.99196, kl_loss: 5.09488] (18.143 secs)
Epoch 9 - Training [total_loss: 175.95769, reconstruction_loss: 170.77371, kl_loss: 5.18366] (18.614 secs)
Epoch 10 - Training [total_loss: 174.