## Setup

In [20]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

## Create a sampling layer

In [21]:

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


## Build the encoder

In [22]:
def build_encoder(filters, latent_dim=512, input_shape=(512,512,3)):


    input_image = layers.Input(shape=input_shape, name='encoder_input')

    x = layers.Conv2D(filters[0], 3, (2, 2), padding='same', activation='leaky_relu')(input_image)

    for n_filters in filters[1:]:
        x = layers.Conv2D(n_filters, 3, (2, 2), padding='same', activation='leaky_relu')(x)

    x = layers.Flatten()(x)

    x = layers.Dense(latent_dim)(x)

    z_mean = layers.Dense(latent_dim, name='z_mean')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var', kernel_initializer='zeros')(x)


    # use the reparameterization trick and get the output from the sample() function
    z = Sampling()([z_mean, z_log_var])


    return keras.Model(input_image, [z_mean, z_log_var, z], name='encoder')

encoder_filters = [3, 8, 16, 32, 64, 64, 128, 128]
encoder = build_encoder(encoder_filters, latent_dim=512)
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_input (InputLayer)     [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_16 (Conv2D)             (None, 256, 256, 3)  84          ['encoder_input[0][0]']          
                                                                                                  
 conv2d_17 (Conv2D)             (None, 128, 128, 8)  224         ['conv2d_16[0][0]']              
                                                                                                  
 conv2d_18 (Conv2D)             (None, 64, 64, 16)   1168        ['conv2d_17[0][0]']        

In [23]:
def build_decoder(filters, latent_dim=512, input_shape=(512,512,3)):
    latent_input = layers.Input(shape=(latent_dim,), name='z_sampling')
    #512/2^n where n is number of conv layers
    x_dim = int(input_shape[0] / (2 ** len(filters)))
    #dense_filters = latent_dim * 8
    #x = layers.Dense(dense_filters, activation='relu')(latent_input)
    x = layers.Reshape((x_dim, x_dim, int(latent_dim / x_dim**2)))(latent_input)

    for n_filters in filters:
        x = layers.Conv2D(n_filters, 3, (1, 1), padding='same', activation='tanh')(x)
        x = layers.UpSampling2D()(x)

    ##x = layers.Conv2DTranspose(3, (1, 1), (1, 1), padding='same', activation='sigmoid')(x)

    return keras.Model(latent_input, x, name='decoder')


decoder_filters = [128, 128, 64, 64, 32, 16, 8, 3]
decoder = build_decoder(decoder_filters, latent_dim=512)
decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 z_sampling (InputLayer)     [(None, 512)]             0         
                                                                 
 reshape_1 (Reshape)         (None, 2, 2, 128)         0         
                                                                 
 conv2d_24 (Conv2D)          (None, 2, 2, 128)         147584    
                                                                 
 up_sampling2d_8 (UpSampling  (None, 4, 4, 128)        0         
 2D)                                                             
                                                                 
 conv2d_25 (Conv2D)          (None, 4, 4, 128)         147584    
                                                                 
 up_sampling2d_9 (UpSampling  (None, 8, 8, 128)        0         
 2D)                                                       

## Define the VAE as a `Model` with a custom `train_step`

In [24]:
class VAE(keras.Model):

    def call(self, inputs, training=None, mask=None):
        z_mean, z_log_var, z = encoder(inputs)
        return decoder(z)

    def __init__(self, encoder, decoder, beta=1.0, loss='mse', **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        self.rec_loss = loss
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

        self.val_total_loss_tracker = keras.metrics.Mean(name="val_total_loss")
        self.val_reconstruction_loss_tracker = keras.metrics.Mean(
            name="val_reconstruction_loss"
        )
        self.val_kl_loss_tracker = keras.metrics.Mean(name="val_kl_loss")


    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = 0
            if self.rec_loss == 'mse':
                reconstruction_loss = tf.reduce_mean(
                    tf.reduce_sum(
                        keras.losses.mse(data, reconstruction), axis=(1,2)
                    )
                )
            elif self.rec_loss == 'bce':
                reconstruction_loss = tf.reduce_mean(
                    tf.reduce_sum(
                        keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                    )
                )
            else:
                print("wrong loss {loss}".format(loss=self.rec_loss))

            kl_loss = -self.beta * 0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            val_reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            val_kl_loss = -self.beta * 0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            val_kl_loss = tf.reduce_mean(tf.reduce_sum(val_kl_loss, axis=1))
            val_total_loss = val_reconstruction_loss + val_kl_loss

        self.val_total_loss_tracker.update_state(val_total_loss)
        self.val_reconstruction_loss_tracker.update_state(val_reconstruction_loss)
        self.val_kl_loss_tracker.update_state(val_kl_loss)
        return {
            "total_loss": self.val_total_loss_tracker.result(),
            "reconstruction_loss": self.val_reconstruction_loss_tracker.result(),
            "kl_loss": self.val_kl_loss_tracker.result(),
        }


## Train the VAE

In [25]:
%cd "C:\Datasets\"

C:\Datasets


In [26]:
directory = "flickrfaces\splits"

In [27]:
from keras.preprocessing.image import ImageDataGenerator
import os

In [28]:
def preprocess(img):
    img = img.astype(np.float32) / 255.0
    img = (img - 0.5) * 2
    return img

In [29]:
from glob import glob
import random

def make_dataset(path, batch_size, split="train"):

    def parse_image(filename):
        image = tf.io.read_file(filename)
        image = tf.image.decode_png(image, channels=3)
        #image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, [X_DIM, Y_DIM])
        image = image / 255.0
        image = (image - 0.5) * 2
        return image

    def configure_for_performance(ds, cache=True):
        if cache:
            if isinstance(cache, str):
                ds = ds.cache(cache)
            else:
                ds = ds.cache()


        ds = ds.shuffle(buffer_size=1000)
        ds = ds.repeat()
        ds = ds.batch(batch_size)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

    cache = split + ".tfcache"
    filenames = glob(path + '/*/*')
    random.shuffle(filenames)

    filenames_ds = tf.data.Dataset.from_tensor_slices(filenames)
    ds = filenames_ds.map(parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = configure_for_performance(ds, cache=False)

    return ds

In [30]:
BATCH_SIZE = 32
X_DIM, Y_DIM = 512, 512
train_ds = make_dataset(os.path.join(directory, "train"), BATCH_SIZE, split="train")
val_ds = make_dataset(os.path.join(directory, "val"), BATCH_SIZE, split="val")

In [31]:
class VisualizeIOCallback(tf.keras.callbacks.Callback):
    def __init__(self, log_dir):
        super().__init__()
        self.original_batch = tf.data.Dataset.as_numpy_iterator(train_ds).next()
        self.file_writer = tf.summary.create_file_writer(log_dir)
        original_batch = self.original_batch / 2 + 0.5
        with self.file_writer.as_default():
            images = np.reshape(original_batch[0:10], (-1, 512, 512, 3))
            tf.summary.image("10 training input examples", images, max_outputs=10, step=0)


    def on_epoch_end(self, epoch, logs=None):
        z_mean, z_log_var, z = encoder.predict(self.original_batch)
        reconstructed = decoder.predict(z)
        reconstructed = reconstructed / 2 + 0.5

        # Using the file writer, log the reshaped image.
        with self.file_writer.as_default():
            images = np.reshape(reconstructed[0:10], (-1, 512, 512, 3))
            tf.summary.image("10 training output examples", images, max_outputs=10, step=epoch)


In [32]:
from datetime import datetime
# Checkpoint
checkpoint_path = "D:/Notebooks/Advanced_DL/Checkpoint/"

#if not os.path.exists(checkpoint_path):
#    os.makedirs(checkpoint_path)

callbacks = []

callbacks.append(keras.callbacks.ReduceLROnPlateau(monitor='loss',
                                                   min_delta=50,
                                                   patience=5))

now_str = datetime.now().strftime("%Y%m%d-%H%M%S")

cp_callback = keras.callbacks.ModelCheckpoint(os.path.join(checkpoint_path, now_str + '.hdf5'),
                              monitor='loss',
                              save_best_only=True,
                              save_weights_only=True,
                              mode='auto',
                              verbose=1)
callbacks.append(cp_callback)

log_dir = "D:/Notebooks/Advanced_DL/logs/" + now_str

visualization_callback = VisualizeIOCallback(log_dir +"/images")
callbacks.append(visualization_callback)

callbacks.append(keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0))

# Early Stopping
EARLY_STOP = False
if EARLY_STOP:
    es_callback = keras.callbacks.EarlyStopping(monitor='val_total_loss',
                                                   mode='auto',
                                                   patience=10,
                                                   verbose=1)
    callbacks.append(es_callback)

In [33]:
vae = VAE(encoder, decoder, beta=1)
#vae(train_generator.next())
latest_checkpoint = "20221011-200545"
#vae.load_weights("D:/Notebooks/Advanced_DL/Checkpoint/" + latest_checkpoint + ".hdf5")
#opt = keras.optimizers.Adam(learning_rate=1e-3)
opt = keras.optimizers.RMSprop(learning_rate=1e-3)

vae.compile(optimizer=opt, loss=None)

EPOCHS = 200
num_images_train = len(glob(os.path.join(directory, "train") + "/*/*"))
num_images_val = len(glob(os.path.join(directory, "val") + "/*/*"))
print(num_images_train, num_images_val)

56000 7000


In [34]:
import math

results = vae.fit(train_ds,
                  steps_per_epoch=math.ceil(num_images_train/BATCH_SIZE),
                  validation_steps=math.ceil(num_images_val/BATCH_SIZE),
                  epochs=EPOCHS,
                  callbacks = callbacks,
                  validation_data=val_ds,
                  validation_freq=10,
                  workers=8,
                  use_multiprocessing=True,
                  )

Epoch 1/200
Epoch 1: loss improved from inf to 34674.54297, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 2/200
Epoch 2: loss improved from 34674.54297 to 25970.13281, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 3/200
Epoch 3: loss improved from 25970.13281 to 23853.27539, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 4/200
Epoch 4: loss improved from 23853.27539 to 21879.22266, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 5/200
Epoch 5: loss improved from 21879.22266 to 20616.00000, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 6/200
Epoch 6: loss improved from 20616.00000 to 19825.58398, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 7/200
Epoch 7: loss improved from 19825.58398 to 19216.92969, saving model to D:/Notebooks/Advanced_DL/Checkpoint\20221012-133849.hdf5
Epoch 8/200
Epoch 8: 

In [35]:
vae.evaluate(val_ds)

ValueError: When providing an infinite dataset, you must specify the number of steps to run (if you did not intend to create an infinite dataset, make sure to not call `repeat()` on the dataset).