In [2]:
#########################
# First, configure GPU  #
#########################

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
print(gpus)
if gpus:
  # Restrict TensorFlow to only allocate 1GB of memory on the first GPU
  try:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Virtual devices must be set before GPUs have been initialized
    print(e)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
1 Physical GPUs, 1 Logical GPUs


In [3]:
# tensorflow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# check version
print('Using TensorFlow v%s' % tf.__version__)
acc_str = 'accuracy' if tf.__version__[:2] == '2.' else 'acc'

# helpers
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')


# need certainty to explain some of the results
import random as python_random
python_random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

Using TensorFlow v2.4.1


In [4]:
# tensorboard
# %load_ext tensorboard
import datetime
import os
logs_base_dir = "./constrained_tensorboard_logs"
os.makedirs(logs_base_dir, exist_ok=True)
log_dir = os.path.join(logs_base_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch=0)

In [5]:
# load dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# normalise images
train_images = train_images / 255.0
test_images = test_images / 255.0

# create a dataset (iterable) from the data using a specified batch size
batch_size = 128
dataset = tf.data.Dataset.from_tensor_slices(train_images)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

# print info
print("Number of training data: %d" % len(train_labels))
print("Number of test data: %d" % len(test_labels))
print("Image pixels: %s" % str(train_images[0].shape))
print("Number of classes: %d" % (np.max(train_labels) + 1))

Number of training data: 60000
Number of test data: 10000
Image pixels: (28, 28)
Number of classes: 10


In [6]:
# sampling z with (z_mean, z_log_var)
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


# latent dimension
latent_dim = 2

# build the encoder (convolutional layers)
image_input = keras.Input(shape=(28, 28, 1))
# x = layers.Flatten()(image_input)
# x = layers.Dense(128, activation='relu')(x)
# x = layers.Dense(16, activation="relu")(x)
x = layers.Conv2D(8, kernel_size=(4, 4), activation="relu", padding="same")(image_input)
x = layers.MaxPool2D(pool_size=(2, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(8, kernel_size=(3, 3), activation="relu", padding="same")(x)
x = layers.MaxPool2D(pool_size=(2, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Flatten()(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z_output = Sampling()([z_mean, z_log_var])
encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output], name="encoder")
encoder_VAE.summary()


# build the decoder (dense layers)
z_input = keras.Input(shape=(latent_dim,))
x = layers.Dense(16, activation="relu")(z_input)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(28 * 28, activation="sigmoid")(x)
image_output = layers.Reshape((28, 28))(x)
decoder_VAE = keras.Model(z_input, image_output, name="decoder")
decoder_VAE.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 8)    136         input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 14, 14, 8)    0           conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 14, 14, 8)    32          max_pooling2d[0][0]              
____________________________________________________________________________________________

In [7]:
# BVAE class
class BVAE(keras.Model):
    # constructor
    # remove beta, add in warmup, d, gamma, lr_lambda, lr_w
    def __init__(self, encoder, decoder, KLD_aim, **kwargs): 
        super(BVAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # self.warmup = warmup
        # self.d = d
        self.KLD_aim = KLD_aim
        # self.gamma = gamma
        # self.lr_init_lambda = lr_init_lambda
        # self.lr_w = lr_w


    # customise train_step() to implement constrained optimisation
    # 
    # The goal is to optimise the dual of the problem: min reconstr_loss
    # subject to KLD <= some value.
    # 
    # The dual is: (max over lambda)(min over w) Lagrangian.
    # Lagrangian = reconstr_loss + lambda*h(w),
    # h(w) = [sum over training examples of ReLU(KLD - KLD_aim)]
    # def train_step(self, x):
    #     if isinstance(x, tuple):
    #         x = x[0]
        
    #     Lambda = 0

    #     for i in range(warmup): # will this actually iterate? Perhaps need to look into writing custom training loop.
    #         ####################################
    #         # SGD step for reconstruction loss #
    #         ####################################
    #         with tf.GradientTape() as tape:
    #             # encoding
    #             z_mean, z_log_var, z = self.encoder(x)
    #             # decoding
    #             x_prime = self.decoder(z)
    #             # reconstruction error by binary crossentropy loss
    #             reconstruction_loss = (
    #                 tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * 28 * 28
    #             )
    #             # KL divergence
    #             kld = -0.5 * tf.reduce_mean(
    #                 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    #             )
    #             # constraint h
    #             h = tf.nn.relu(kld - KLD_aim)
    #             # Lagrangian
    #             lagrangian = reconstruction_loss + Lambda * h
    #         # apply gradient
    #         grads = tape.gradient(lagrangian, self.trainable_weights)
    #         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    #         # return loss for metrics log
    #         return {
    #             "loss": loss,
    #             "reconstruction_loss": reconstruction_loss,
    #             "kld": kld,
    #         }
        
    #     l = 1
    #     t = 1
    #     t1 = 1
    #     lr_lambda = self.lr_init_lambda

In [8]:
# Hyperparameters
KLD_aim = 1.0
# Change these to adjust learning rates
# lr_lambda = 0.01
# lr_w = 0.001

# N.B. Look this up for learning rate decay:
# tf.keras.optimizers.schedules.InverseTimeDecay(
#     initial_learning_rate, decay_steps, decay_rate, staircase=False, name=None
# )

# build the BVAE
vae_model = BVAE(encoder=encoder_VAE, decoder=decoder_VAE, KLD_aim=KLD_aim)

# compile the VAE
vae_model.compile(optimizer=keras.optimizers.Adam())

In [9]:
# Lagrangian loss functions & derivative wrt lambda
def lagrangian(reconstr_loss, kld, Lambda):
    # constraint h
    h = tf.nn.relu(kld - KLD_aim)
    # Lagrangian
    l = reconstr_loss + Lambda * h
    return tf.reduce_mean(l)

def dL_dlambda(kld):
    # constraint h
    h = tf.nn.relu(kld - KLD_aim)
    # derivative = h
    return h

In [10]:
###############################
# Train steps for custom loop #
###############################

# Warmup training step (this is just train_w_step with lambda = 0)
@tf.function
def warmup_step(x):
    if isinstance(x, tuple):
        x = x[0]
    with tf.GradientTape() as tape:
        # encoding
        z_mean, z_log_var, z = vae_model.encoder(x)
        # decoding
        x_prime = vae_model.decoder(z)
        # reconstruction error by binary crossentropy loss
        reconstruction_loss = (
            tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * 28 * 28
        )
        loss = reconstruction_loss # optimise for reconstruction loss only
    # apply gradient
    grads = tape.gradient(loss, vae_model.trainable_weights)
    vae_model.optimizer.apply_gradients(zip(grads, vae_model.trainable_weights))

    # metrics log
    logits = {
        "loss": loss,
    }
    return logits

# Reconstruction training step (updates model params)
@tf.function
def train_w_step(x, Lambda):
    if isinstance(x, tuple):
        x = x[0]
    with tf.GradientTape() as tape:
        # encoding
        z_mean, z_log_var, z = vae_model.encoder(x)
        # decoding
        x_prime = vae_model.decoder(z)
        # reconstruction error by binary crossentropy loss
        reconstruction_loss = (
            tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * 28 * 28
        )
        # KL divergence
        kld = -0.5 * tf.reduce_mean(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        )
        # loss = lagrangian
        loss = lagrangian(reconstruction_loss, kld, Lambda)
    # apply gradient
    grads = tape.gradient(loss, vae_model.trainable_weights)
    vae_model.optimizer.apply_gradients(zip(grads, vae_model.trainable_weights))

    # metrics log
    logits = {
        "loss": loss,
        "reconstruction_loss": reconstruction_loss,
        "kl_loss": kld,
        "lambda": Lambda,
    }
    return logits

# Constraint training step (updates lambda)
# @tf.function
def train_lambda_step(x, lr, Lambda):
    if isinstance(x, tuple):
        x = x[0]
    
    # with tf.GradientTape() as tape:
    # encoding
    z_mean, z_log_var, z = vae_model.encoder(x)
    # decoding
    # x_prime = vae_model.decoder(z)
    # reconstruction error by binary crossentropy loss
    # reconstruction_loss = (
    #     tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * 28 * 28
    # )
    # KL divergence
    kld = -0.5 * tf.reduce_mean(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    )
    # loss = - lagrangian (SGA)
    # loss = - lagrangian(reconstruction_loss, kld, Lambda)
    # calculate and apply gradient
    # Note lambda is the trainable param here, not model params
    # TODO: See if it's possible to include lambda in the VAE model
    # this might make the function compilable
    # the other option is to hard-code the SGA rather than using an optimizer

    # grad = tape.gradient(target=loss, sources=[Lambda])
    # opt = tf.keras.optimizers.Adam()
    # opt.apply_gradients(zip(grad, [Lambda]))

    Lambda = Lambda + lr * dL_dlambda(kld)

    # metrics log
    logits = {
        # "loss": loss,
        # "reconstruction_loss": reconstruction_loss,
        "kl_loss": kld,
        "lambda": Lambda,
    }
    return logits

In [62]:
# train the VAE
# vae_model.fit(train_images, train_images, epochs=50, batch_size=128, callbacks=[tensorboard_callback])

# training loop
warmup_iters = 100
SGD_steps = 1
increment_SGD_steps = 1
epochs = 10
Lambda = tf.Variable(0.0)

for step, train_image_batch in enumerate(dataset):
    # perform warmup
    logits = warmup_step(train_image_batch)
    # check if it's time to end warmup
    # TODO: this is clumsy, try to think of another way
    # could pre-process dataset to contain only (warmup_iters) batches
    if step >= warmup_iters:
        break

# log after warmup
print("\nTraining logs at end of warmup:")
for metric, value in logits.items():
    print(metric, value.numpy())

# alternating steps
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch + 1}")

    # iterate over batches
    for step, train_image_batch in enumerate(dataset):
        # update lambda
        logits = train_lambda_step(train_image_batch, Lambda)

        for i in range(SGD_steps):
            logits = train_w_step(train_image_batch, Lambda)
        
        # increment no. of steps
        SGD_steps += increment_SGD_steps

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"\nTraining logs at step {step}:"
            )
            for metric, value in logits.items():
                print(metric, value.numpy())
            print("Seen: %d samples" % ((step + 1) * batch_size))
    


Training logs at end of warmup:
loss 179.24568

Start of epoch 1


ValueError: in user code:

    <ipython-input-61-60e4cba731df>:90 train_lambda_step  *
        opt.apply_gradients(zip(grad, [Lambda]))
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:604 apply_gradients  **
        self._create_all_weights(var_list)
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:781 _create_all_weights
        _ = self.iterations
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:788 __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:921 iterations
        self._iterations = self.add_weight(
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:1122 add_weight
        variable = self._add_variable_with_custom_getter(
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\training\tracking\base.py:805 _add_variable_with_custom_getter
        new_variable = getter(
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\engine\base_layer_utils.py:130 make_variable
        return tf_variables.VariableV1(
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\ops\variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\ops\variables.py:206 _variable_v1_call
        return previous_getter(
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\ops\variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    C:\Users\James\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py:730 invalid_creator_scope
        raise ValueError(

    ValueError: tf.function-decorated function tried to create variables on non-first call.


In [None]:
# scatter plot of encodings in the latent space
def scatter_plot_encodings_latent(encodings, labels):
    plt.figure(dpi=100)
    scat = plt.scatter(encodings[:, 0], encodings[:, 1], c=labels, s=.5, cmap='Paired')
    plt.gca().add_artist(plt.legend(*scat.legend_elements(), 
                         title='Image labels', bbox_to_anchor=(1.5, 1.)))
    plt.xlabel('Feature X')
    plt.ylabel('Feature Y')
    plt.gca().set_aspect(1)
    plt.show()
    
# histogram plot of encodings in the latent space
def hist_plot_encodings_latent(encodings, labels, digit, dim, ax):
    # extract
    encodings_digit = encodings[labels == digit, dim]
    # histogram
    ax.hist(encodings_digit, bins=60, density=True, color=['g', 'b'][dim], alpha=.5)
    # mean and std dev
    mean = np.mean(encodings_digit)
    std = np.std(encodings_digit)
    ax.axvline(mean, c='r')
    ax.set_xlabel('Digit %d, Feature %s\n~${\cal N}(\mu=%.1f, \sigma=%.1f)$' % 
                  (digit, ['X', 'Y'][dim], mean, std), c='k')
    
# generate images from the latent space
def generate_images_latent(decoder, x0, x1, dx, y0, y1, dy):
    # uniformly sample the latent space
    nx = round((x1 - x0) / dx) + 1
    ny = round((y1 - y0) / dy) + 1
    grid_x = np.linspace(x0, x1, nx)
    grid_y = np.linspace(y1, y0, ny)
    latent = np.array(np.meshgrid(grid_x, grid_y)).reshape(2, nx * ny).T

    # decode images
    decodings = decoder.predict(latent)
    
    # display a (nx, ny) 2D manifold of digits
    figure = np.zeros((28 * ny, 28 * nx))
    for iy in np.arange(ny):
        for ix in np.arange(nx):
            figure[iy * 28 : (iy + 1) * 28, ix * 28 : (ix + 1) * 28] = decodings[iy * nx + ix]
            
    # plot figure
    plt.figure(dpi=100, figsize=(nx / 3, ny / 3))
    plt.xticks(np.arange(28 // 2, nx * 28 + 28 // 2, 28), np.round(grid_x, 1), rotation=90)
    plt.yticks(np.arange(28 // 2, ny * 28 + 28 // 2, 28), np.round(grid_y, 1))
    plt.xlabel('Feature X')
    plt.ylabel('Feature Y')
    plt.imshow(figure, cmap="Greys_r")
    plt.grid(False)
    plt.show()
    
# encode images by BVAE
train_encodings_BVAE = encoder_VAE.predict(train_images)

# scatter plot of encodings by BVAE
scatter_plot_encodings_latent(train_encodings_BVAE[2], train_labels)

# histogram plot of encodings by BVAE
fig, axes = plt.subplots(5, 4, dpi=100, figsize=(15, 12), sharex=True)
plt.subplots_adjust(hspace=.4)
for digit in range(10):
    hist_plot_encodings_latent(train_encodings_BVAE[2], train_labels, digit, 0, 
                               axes[digit // 2, digit % 2 * 2 + 0])
    hist_plot_encodings_latent(train_encodings_BVAE[2], train_labels, digit, 1, 
                               axes[digit // 2, digit % 2 * 2 + 1])
plt.show()

# generate images by BVAE
generate_images_latent(decoder_VAE, x0=-2, x1=2, dx=.1, y0=-2, y1=2, dy=.1)