# Variational AutoEncoder

**Author:** [fchollet](https://twitter.com/fchollet)<br>
**Date created:** 2020/05/03<br>
**Last modified:** 2020/05/03<br>
**Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.

## Setup

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

## Create a sampling layer

In [2]:

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


## Build the encoder

In [3]:
def build_encoder(filters, latent_dim=512, input_shape=(512,512,3)):


    input_image = layers.Input(shape=input_shape, name='encoder_input')

    x = layers.Conv2D(filters[0], 3, (2, 2), padding='same', activation='leaky_relu')(input_image)

    for n_filters in filters[1:]:
        x = layers.Conv2D(n_filters, 3, (2, 2), padding='same', activation='leaky_relu')(x)


    x = layers.Flatten()(x)

    x = layers.Dense(latent_dim)(x)

    z_mean = layers.Dense(latent_dim, name='z_mean')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var', kernel_initializer='zeros')(x)


    # use the reparameterization trick and get the output from the sample() function
    z = Sampling()([z_mean, z_log_var])


    return keras.Model(input_image, [z_mean, z_log_var, z], name='encoder')

encoder_filters = [3, 8, 16, 32, 64, 64, 128]
encoder = build_encoder(encoder_filters, latent_dim=512)
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_input (InputLayer)     [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 3)  84          ['encoder_input[0][0]']          
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 8)  224         ['conv2d[0][0]']                 
                                                                                                  
 conv2d_2 (Conv2D)              (None, 64, 64, 16)   1168        ['conv2d_1[0][0]']         

In [4]:
def build_decoder(filters, latent_dim=512, input_shape=(512,512,3)):
    latent_input = layers.Input(shape=(latent_dim,), name='z_sampling')
    #512/2^n where n is number of conv layers
    x_dim = int(input_shape[0] / (2 ** len(filters)))
    #dense_filters = latent_dim * 8
    #x = layers.Dense(dense_filters, activation='relu')(latent_input)
    x = layers.Reshape((x_dim, x_dim, int(latent_dim / x_dim**2)))(latent_input)

    for n_filters in filters:
        x = layers.Conv2D(n_filters, 3, (1, 1), padding='same', activation='tanh')(x)
        x = layers.UpSampling2D()(x)

    ##x = layers.Conv2DTranspose(3, (1, 1), (1, 1), padding='same', activation='sigmoid')(x)

    return keras.Model(latent_input, x, name='decoder')


decoder_filters = [128, 64, 64, 32, 16, 8, 3]
decoder = build_decoder(decoder_filters, latent_dim=512)
decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 z_sampling (InputLayer)     [(None, 512)]             0         
                                                                 
 reshape (Reshape)           (None, 4, 4, 32)          0         
                                                                 
 conv2d_7 (Conv2D)           (None, 4, 4, 128)         36992     
                                                                 
 up_sampling2d (UpSampling2D  (None, 8, 8, 128)        0         
 )                                                               
                                                                 
 conv2d_8 (Conv2D)           (None, 8, 8, 64)          73792     
                                                                 
 up_sampling2d_1 (UpSampling  (None, 16, 16, 64)       0         
 2D)                                                       

## Define the VAE as a `Model` with a custom `train_step`

In [5]:
class VAE(keras.Model):

    def call(self, inputs, training=None, mask=None):
        z_mean, z_log_var, z = encoder(inputs)
        return decoder(z)

    def __init__(self, encoder, decoder, beta=1.0, loss='mse', **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        self.rec_loss = loss
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

        self.val_total_loss_tracker = keras.metrics.Mean(name="val_total_loss")
        self.val_reconstruction_loss_tracker = keras.metrics.Mean(
            name="val_reconstruction_loss"
        )
        self.val_kl_loss_tracker = keras.metrics.Mean(name="val_kl_loss")


    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = 0
            if self.rec_loss == 'mse':
                reconstruction_loss = tf.reduce_mean(
                    tf.reduce_sum(
                        keras.losses.mse(data, reconstruction), axis=(1,2)
                    )
                )
            elif self.rec_loss == 'bce':
                reconstruction_loss = tf.reduce_mean(
                    tf.reduce_sum(
                        keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                    )
                )
            else:
                print("wrong loss {loss}".format(loss=self.rec_loss))

            kl_loss = -self.beta * 0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            val_reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            val_kl_loss = -self.beta * 0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            val_kl_loss = tf.reduce_mean(tf.reduce_sum(val_kl_loss, axis=1))
            val_total_loss = val_reconstruction_loss + val_kl_loss

        self.val_total_loss_tracker.update_state(val_total_loss)
        self.val_reconstruction_loss_tracker.update_state(val_reconstruction_loss)
        self.val_kl_loss_tracker.update_state(val_kl_loss)
        return {
            "total_loss": self.val_total_loss_tracker.result(),
            "reconstruction_loss": self.val_reconstruction_loss_tracker.result(),
            "kl_loss": self.val_kl_loss_tracker.result(),
        }


## Train the VAE

In [6]:
%cd "C:\Datasets\"

C:\Datasets


In [7]:
directory = "flickrfaces\splits"

In [8]:
from keras.preprocessing.image import ImageDataGenerator
import os

In [9]:
def preprocess(img):
    img = img.astype(np.float32) / 255.0
    img = (img - 0.5) * 2
    return img

In [10]:
datagen = ImageDataGenerator(
    preprocessing_function=preprocess,
    #rescale=1./255,
    #shear_range=0.2,
    #zoom_range=0.2,
    horizontal_flip=False, #True
    )

In [11]:
X_DIM = 512
Y_DIM = 512

BATCH_SIZE = 128


train_generator = datagen.flow_from_directory(
        os.path.join(directory, 'train'),
        target_size=(X_DIM, Y_DIM),
        class_mode=None,
        batch_size=BATCH_SIZE,
        shuffle=True,
        seed=None,
        )

validation_generator = datagen.flow_from_directory(
        os.path.join(directory, 'val'),
        target_size=(X_DIM, Y_DIM),
        class_mode=None,
        batch_size=BATCH_SIZE,
        shuffle=False,
        seed=None,
        )

test_generator = datagen.flow_from_directory(
        os.path.join(directory, 'test'),
        target_size=(X_DIM, Y_DIM),
        class_mode=None,
        batch_size=1,
        shuffle=False,
        seed=None,
        )

Found 56000 images belonging to 1 classes.
Found 7000 images belonging to 1 classes.
Found 7048 images belonging to 2 classes.


In [12]:
import matplotlib.pyplot as plt
class VisualizeIOCallback(tf.keras.callbacks.Callback):
    def __init__(self, log_dir):
        super().__init__()
        self.original_batch = validation_generator.next()
        self.file_writer = tf.summary.create_file_writer(log_dir)
        original_batch = self.original_batch / 2 + 0.5
        with self.file_writer.as_default():
            images = np.reshape(original_batch[0:8], (-1, 512, 512, 3))
            tf.summary.image("8 training input examples", images, max_outputs=8, step=0)


    def on_epoch_end(self, epoch, logs=None):
        z_mean, z_log_var, z = encoder.predict(self.original_batch)
        reconstructed = decoder.predict(z)
        reconstructed = reconstructed / 2 + 0.5

        # Using the file writer, log the reshaped image.
        with self.file_writer.as_default():
            images = np.reshape(reconstructed[0:8], (-1, 512, 512, 3))
            tf.summary.image("8 training output examples", images, max_outputs=8, step=epoch)


In [13]:
from datetime import datetime
# Checkpoint
checkpoint_path = "D:/Notebooks/Advanced_DL/Checkpoint/"

#if not os.path.exists(checkpoint_path):
#    os.makedirs(checkpoint_path)

callbacks = []

callbacks.append(keras.callbacks.ReduceLROnPlateau(monitor='loss',
                                                   min_delta=50,
                                                   patience=5))

now_str = datetime.now().strftime("%Y%m%d-%H%M%S")

cp_callback = keras.callbacks.ModelCheckpoint(os.path.join(checkpoint_path, now_str + '.hdf5'),
                              monitor='loss',
                              save_best_only=True,
                              save_weights_only=True,
                              mode='auto',
                              verbose=1)
callbacks.append(cp_callback)

log_dir = "D:/Notebooks/Advanced_DL/logs/" + now_str

visualization_callback = VisualizeIOCallback(log_dir +"/images")
callbacks.append(visualization_callback)

callbacks.append(keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0))

# Early Stopping
EARLY_STOP = False
if EARLY_STOP:
    es_callback = keras.callbacks.EarlyStopping(monitor='val_total_loss',
                                                   mode='auto',
                                                   patience=10,
                                                   verbose=1)
    callbacks.append(es_callback)

In [14]:
vae = VAE(encoder, decoder, beta=1)
vae(train_generator.next())
latest_checkpoint = "20221011-200545"
vae.load_weights("D:/Notebooks/Advanced_DL/Checkpoint/" + latest_checkpoint + ".hdf5")
#opt = keras.optimizers.Adam(learning_rate=1e-3)
opt = keras.optimizers.RMSprop(learning_rate=1e-3)

vae.compile(optimizer=opt, loss=None)

results = vae.fit(train_generator,
                  epochs=200,
                  callbacks = callbacks,
                  #validation_data=validation_generator,
                  initial_epoch=41,
                  workers=8,
                  )

Epoch 42/200
Epoch 42: loss improved from inf to 10820.15430, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 43/200
Epoch 43: loss improved from 10820.15430 to 10689.64453, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 44/200
Epoch 44: loss improved from 10689.64453 to 10654.62402, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 45/200
Epoch 45: loss improved from 10654.62402 to 10607.38672, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 46/200
Epoch 46: loss improved from 10607.38672 to 10578.91016, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 47/200
Epoch 47: loss improved from 10578.91016 to 10531.75195, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 48/200
Epoch 48: loss improved from 10531.75195 to 10497.97266, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221011-200545.hdf5
Epoch 4

ResourceExhaustedError: Graph execution error:

Detected at node 'Mean' defined at (most recent call last):
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\traitlets\config\application.py", line 846, in launch_instance
      app.start()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\kernelapp.py", line 712, in start
      self.io_loop.start()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\tornado\platform\asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\asyncio\base_events.py", line 601, in run_forever
      self._run_once()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\asyncio\base_events.py", line 1905, in _run_once
      handle._run()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\kernelbase.py", line 406, in dispatch_shell
      await result
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\ipykernel\zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 2881, in run_cell
      result = self._run_cell(
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 2936, in _run_cell
      return runner(coro)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3135, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3338, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3398, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\loren\AppData\Local\Temp\ipykernel_10396\541913300.py", line 9, in <cell line: 9>
      results = vae.fit(train_generator,
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\loren\AppData\Local\Temp\ipykernel_10396\3340400612.py", line 42, in train_step
      keras.losses.mse(data, reconstruction), axis=(1,2)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\losses.py", line 1486, in mean_squared_error
      return backend.mean(tf.math.squared_difference(y_pred, y_true), axis=-1)
    File "C:\Users\loren\anaconda3\envs\tensorflow-gpu\lib\site-packages\keras\backend.py", line 2915, in mean
      return tf.reduce_mean(x, axis, keepdims)
Node: 'Mean'
OOM when allocating tensor with shape[33554432] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node Mean}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_3382]

In [16]:
vae.evaluate(validation_generator)



[-1086551.875, -1087593.625, 1041.9169921875]