# Conditional Autoencoder

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [4]:
np.zeros(shape=(10, ))

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [55]:
dropout_rate = 0.3
learn_rate_init = 0.0001

def apply_bn_and_dropout(x):
    return layers.Dropout(dropout_rate)(layers.BatchNormalization()(x))

def make_encoder(input_dim: int, latent_dim: int, num_classes: int):
    input_img = keras.Input(shape=(input_dim, input_dim, 1))
    flatt_img = layers.Flatten()(input_img)
    input_lbl = keras.Input(shape=(num_classes,), dtype='float32')

    x = layers.concatenate([flatt_img, input_lbl])
    x = layers.Dense(256, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    x = layers.Dense(128, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = keras.Model([input_img, input_lbl], [z_mean, z_log_var, z], name="encoder")
    encoder.summary()
    return encoder

def make_decoder(output_dim: int, latent_dim: int, num_classes: int):
    latent_inputs = keras.Input(shape=(latent_dim,))
    input_lbl = keras.Input(shape=(num_classes,), dtype='float32')
    x = layers.concatenate([latent_inputs, input_lbl])
    x = layers.Dense(128, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    x = layers.Dense(256, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    x = layers.Dense(28 * 28, activation='sigmoid')(x)
    decoder_outputs = layers.Reshape((output_dim, output_dim, 1))(x)
    decoder = keras.Model([latent_inputs, input_lbl], decoder_outputs, name="decoder")
    decoder.summary()
    return decoder

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

class CVAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        print(data, data[0], len(data[0]))
        img, lbl = data[0]

        print(img.shape)

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(img, lbl)
            reconstruction = self.decoder(z, lbl)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(img, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


In [56]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

y_train_cat = keras.utils.to_categorical(y_train).astype(np.float32)
y_test_cat = keras.utils.to_categorical(y_test).astype(np.float32)
num_calsses = y_train_cat.shape[1]

mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
class_lbls_digits = np.concatenate([y_train_cat, y_test_cat], axis=0)

encoder = make_encoder(28, 2, num_calsses)
decoder = make_decoder(28, 2, num_calsses)
vae = CVAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam(learn_rate_init))


Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_57 (InputLayer)          [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 flatten_15 (Flatten)           (None, 784)          0           ['input_57[0][0]']               
                                                                                                  
 input_58 (InputLayer)          [(None, 10)]         0           []                               
                                                                                                  
 concatenate_27 (Concatenate)   (None, 794)          0           ['flatten_15[0][0]',             
                                                                  'input_58[0][0]']         

In [57]:
def plot_label_clusters(vae, data, labels, epoch: int = 1):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.title(f'Epoch: {epoch}')
    plt.xlim(-1000, 1000)
    plt.ylim(-1000, 1000)
    plt.savefig(f'figs/cvae/{epoch}.png')

    plt.close()
    #plt.show()

def on_epoch_end(epoch, logs):
    plot_label_clusters(vae, x_train, y_train, epoch)
    clear_output()


save_fig = keras.callbacks.LambdaCallback(on_epoch_end=on_epoch_end)
tb       = keras.callbacks.TensorBoard(log_dir='logs_cvae')

In [58]:
vae.fit([mnist_digits, class_lbls_digits], epochs=50, batch_size=500, callbacks=[save_fig, tb])


Epoch 1/50
((<tf.Tensor 'IteratorGetNext:0' shape=(500, 28, 28, 1) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(500, 10) dtype=float32>),) (<tf.Tensor 'IteratorGetNext:0' shape=(500, 28, 28, 1) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(500, 10) dtype=float32>) 2
(500, 28, 28, 1)


ValueError: in user code:

    File "/home/denys/anaconda3/envs/py3D/lib/python3.7/site-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/home/denys/anaconda3/envs/py3D/lib/python3.7/site-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/denys/anaconda3/envs/py3D/lib/python3.7/site-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "<ipython-input-55-74eb05596939>", line 74, in train_step
        z_mean, z_log_var, z = self.encoder(img, lbl)
    File "/home/denys/anaconda3/envs/py3D/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/home/denys/anaconda3/envs/py3D/lib/python3.7/site-packages/keras/engine/input_spec.py", line 200, in assert_input_compatibility
        raise ValueError(f'Layer "{layer_name}" expects {len(input_spec)} input(s),'

    ValueError: Layer "encoder" expects 2 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(500, 28, 28, 1) dtype=float32>]
