# Generating anime faces with Variational Autoencoders (VAE)

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, LeakyReLU, Dropout
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint 
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
from glob import glob

# Data Preparation

In [None]:
# Root folder for model checkpoints and folder holding the input JPEG images.
WEIGHTS_FOLDER = '/kaggle/working/weights/'
DATA_FOLDER = '/kaggle/input/animefacedataset/images//'

# Create each checkpoint sub-folder on its own. Guarding both calls behind a
# single "does the parent exist?" test would leave AE/ or VAE/ missing whenever
# the parent folder is already present (for example after a partial run).
for sub_folder in ("AE", "VAE"):
    os.makedirs(os.path.join(WEIGHTS_FOLDER, sub_folder), exist_ok=True)

# Collect the path of every JPEG directly inside DATA_FOLDER.
filenames = np.array(glob(os.path.join(DATA_FOLDER, '*.jpg')))
NUM_IMAGES = len(filenames)
print("Total number of images : " + str(NUM_IMAGES))

In [None]:
def build_decoder(test=False, out_size=(64, 64)):
    """Build the function that turns an image file path into a tensor.

    Args:
        test: When true, the returned function yields only the image.
            Otherwise it yields an ``(image, image)`` pair, i.e. the same
            tensor as both model input and reconstruction target.
        out_size: ``(height, width)`` every decoded image is resized to.

    Returns:
        A callable taking a file path and returning the decoded image
        (or image pair), scaled to the ``[0, 1]`` range as ``float32``.
    """
    def decoder(path):
        file_bytes = tf.io.read_file(path)
        img = tf.image.decode_jpeg(file_bytes, channels=3)
        # Honour the requested size instead of a hard-coded 64x64.
        img = tf.image.resize(img, out_size)
        img = tf.cast(img, tf.float32) / 255.0
        return img

    def decoder_train(path):
        # Decode once and reuse the tensor for input and target; decoding
        # twice would read and decompress every file a second time.
        img = decoder(path)
        return img, img

    return decoder if test else decoder_train

def build_dataset(paths, test=False, shuffle=1, batch_size=1):
    """Create a batched ``tf.data`` pipeline that loads images from disk.

    Args:
        paths: Image file paths to load.
        test: Forwarded to ``build_decoder``; selects image-only elements
            instead of ``(image, image)`` pairs.
        shuffle: Shuffle buffer size. Values of 1 or less leave the element
            order unchanged.
        batch_size: Number of elements per batch.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    AUTO = tf.data.experimental.AUTOTUNE
    decoder = build_decoder(test)

    dset = tf.data.Dataset.from_tensor_slices(paths)
    dset = dset.map(decoder, num_parallel_calls=AUTO)

    # A one-element buffer cannot reorder anything, so only add the shuffle
    # stage when it can actually have an effect.
    if shuffle > 1:
        dset = dset.shuffle(shuffle)
    dset = dset.batch(batch_size)
    # Prepare upcoming batches while the current one is being consumed.
    return dset.prefetch(AUTO)

In [None]:
# Image dimension passed to the model builders as the input shape.
INPUT_DIM = (64, 64, 3)
# Batch size.
BATCH_SIZE = 128
# Dimension of the latent vector (z).
Z_DIM = 100

In [None]:
# Hold out 20% of the image paths for validation. Only the paths are needed,
# so the array is passed once instead of twice with half the result discarded.
train_paths, valid_paths = train_test_split(filenames, test_size=0.2, shuffle=True)

# Use the shared BATCH_SIZE constant rather than repeating the literal.
train_dataset = build_dataset(train_paths, batch_size=BATCH_SIZE)
valid_dataset = build_dataset(valid_paths, batch_size=BATCH_SIZE)

# Autoencoder Model Definition

In [None]:
class ConvAutoencoder:
    """Builder for a plain convolutional autoencoder."""

    @staticmethod
    def build(input_dim, latentDim=Z_DIM):
        """Build the encoder, the decoder and the combined autoencoder.

        The encoder stacks four stride-2 convolutions (32, 64, 64, 64
        filters), each followed by LeakyReLU, then flattens and projects to
        ``latentDim`` units. The decoder projects back to the flattened
        convolution volume, reshapes it, and applies four stride-2 transposed
        convolutions (64, 64, 32 and finally 3 filters) ending in a sigmoid.
        A summary of each of the three models is printed.

        Args:
            input_dim: Shape of one input image, without the batch axis.
            latentDim: Number of units in the latent vector.

        Returns:
            Tuple ``(encoder, decoder, autoencoder)`` of Keras models.
        """
        inputs = Input(shape=input_dim)
        x = inputs

        for f in [32, 64, 64, 64]:
            x = Conv2D(f, (3, 3), strides=2, padding="same")(x)
            x = LeakyReLU()(x)

        # Remember the conv output shape so the decoder can rebuild it.
        volumeSize = K.int_shape(x)
        x = Flatten()(x)
        latent = Dense(latentDim)(x)
        encoder = Model(inputs, latent, name="encoder")

        # summary() prints the table itself and returns None, so wrapping it
        # in print() would only add a stray "None" line.
        encoder.summary()

        latentInputs = Input(shape=(latentDim,))
        # Dense expects a plain integer unit count; np.prod yields a NumPy scalar.
        x = Dense(int(np.prod(volumeSize[1:])))(latentInputs)
        x = Reshape((volumeSize[1], volumeSize[2], volumeSize[3]))(x)

        for f in [64, 64, 32]:
            x = Conv2DTranspose(f, (3, 3), strides=2, padding="same")(x)
            x = LeakyReLU()(x)

        # Final transposed convolution maps back to three output channels.
        x = Conv2DTranspose(3, (3, 3), strides=2, padding="same")(x)
        outputs = Activation("sigmoid")(x)

        decoder = Model(latentInputs, outputs, name="decoder")

        decoder.summary()

        autoencoder = Model(inputs, decoder(encoder(inputs)), name="autoencoder")

        autoencoder.summary()
        return (encoder, decoder, autoencoder)

In [None]:
# Instantiate the three models of the plain autoencoder for INPUT_DIM images.
encoder, decoder, autoencoder = ConvAutoencoder.build(input_dim=INPUT_DIM)

In [None]:
LEARNING_RATE = 0.0005
N_EPOCHS = 10

# `lr` is a deprecated alias that newer Keras releases reject; the supported
# keyword is `learning_rate`.
optimizer = Adam(learning_rate=LEARNING_RATE)

def r_loss(y_true, y_pred):
    """Reconstruction loss: squared error averaged over axes 1, 2 and 3."""
    squared_error = K.square(y_true - y_pred)
    return K.mean(squared_error, axis=[1, 2, 3])

autoencoder.compile(loss=r_loss, optimizer=optimizer)

# Settings shared by both autoencoder checkpoints; they differ only in the
# target file and in whether every epoch or only the best one is kept.
_ae_checkpoint_common = dict(
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    verbose=1,
)

# Written only when the monitored validation loss improves.
checkpoint_ae_best = ModelCheckpoint(
    os.path.join(WEIGHTS_FOLDER, 'AE/ae_best_weights.h5'),
    save_best_only=True,
    **_ae_checkpoint_common,
)

# Written regardless of whether the monitored value improved.
checkpoint_ae_last = ModelCheckpoint(
    os.path.join(WEIGHTS_FOLDER, 'AE/ae_last_weights.h5'),
    save_best_only=False,
    **_ae_checkpoint_common,
)

In [None]:
# Train with the epoch count declared in N_EPOCHS rather than a second,
# separately maintained literal.
autoencoder.fit(train_dataset,
                epochs=N_EPOCHS,
                callbacks=[checkpoint_ae_best, checkpoint_ae_last],
                validation_data=valid_dataset)

# Inference using trained Autoencoder

In [None]:
# Image-only dataset (no targets) over the validation paths.
test_dataset = build_dataset(valid_paths, test=True)
# Derive the path from WEIGHTS_FOLDER so it stays in sync with the location
# the checkpoint callback wrote to.
autoencoder.load_weights(os.path.join(WEIGHTS_FOLDER, 'AE/ae_last_weights.h5'))

In [None]:
# Materialise the first 20 elements of the test dataset.
data = list(test_dataset.take(20))

fig = plt.figure(figsize=(30, 10))
# Every second element is drawn next to its reconstruction, filling two
# adjacent cells of the 2x10 grid per pair.
for n in range(0, 20, 2):
    original = data[n]
    image = autoencoder.predict(original)

    panels = (
        (n + 1, original, 'original image'),
        (n + 2, image, 'reconstruct'),
    )
    for slot, picture, caption in panels:
        plt.subplot(2, 10, slot)
        plt.imshow(np.squeeze(picture))
        plt.title(caption)

plt.show()

# Variational Autoencoder Model Definition

In [None]:
class VariableAutoencoder:
    """Builder for a convolutional variational autoencoder (VAE)."""

    @staticmethod
    def build(input_shape=INPUT_DIM, latent_dim=Z_DIM):
        """Build the VAE together with its encoder and decoder.

        The encoder applies four stride-2 convolutions and two Dense heads
        producing the latent mean and log-variance; its output is a sample
        drawn with the reparameterisation trick. The decoder mirrors the
        encoder with four stride-2 transposed convolutions and a sigmoid.
        The combined loss (scaled binary cross-entropy plus KL divergence) is
        attached with ``add_loss``, so the model is compiled without a loss.

        Args:
            input_shape: Shape of one input image, without the batch axis.
            latent_dim: Number of units in the latent vector.

        Returns:
            Tuple ``(variable_autoencoder, encoder, decoder)`` of Keras models.
        """
        # Encoder
        input_encoder = Input(shape=(input_shape))
        x = Conv2D(32, kernel_size=(3, 3), strides=2, padding='same', name='encoder_cov2d_1')(input_encoder)
        x = LeakyReLU()(x)
        x = Conv2D(64, kernel_size=(3, 3), strides=2, padding='same', name='encoder_cov2d_2')(x)
        x = LeakyReLU()(x)
        x = Conv2D(64, kernel_size=(3, 3), strides=2, padding='same', name='encoder_conv2d_3')(x)
        x = LeakyReLU()(x)
        # NOTE(review): unlike the three convolutions above, this one is not
        # followed by LeakyReLU - confirm that is intended.
        x = Conv2D(64, kernel_size=(3, 3), strides=2, padding='same', name='encoder_conv2d_4')(x)
        # Remember the conv output shape so the decoder can rebuild it.
        volumeSize = K.int_shape(x)
        x = Flatten()(x)

        latent_mu = Dense(latent_dim, name='latent_mean')(x)
        latent_log_var = Dense(latent_dim, name='latent_log_var')(x)

        def sampling(args=None):
            """Draw z = mean + exp(0.5 * log_var) * epsilon, epsilon ~ N(0, 1)."""
            z_mean, z_log_var = args
            batch = K.shape(z_mean)[0]
            # Take the width from z_mean itself so the noise always matches
            # the latent size this model was built with.
            dim = K.int_shape(z_mean)[1]

            epsilon = K.random_normal(shape=(batch, dim))
            return z_mean + K.exp(0.5 * z_log_var) * epsilon

        latent_sample = Lambda(sampling)([latent_mu, latent_log_var])
        encoder = Model(input_encoder, latent_sample, name='encoder')

        # Decoder
        latent_input = Input(shape=(latent_dim,), name='decoder_input')
        # Dense expects a plain integer unit count; np.prod yields a NumPy scalar.
        x = Dense(int(np.prod(volumeSize[1:])))(latent_input)
        x = Reshape((volumeSize[1], volumeSize[2], volumeSize[3]))(x)
        x = Conv2DTranspose(64, kernel_size=(3, 3), strides=2, padding='same', name='conv2d_1')(x)
        x = LeakyReLU()(x)
        x = Conv2DTranspose(64, kernel_size=(3, 3), strides=2, padding='same', name='conv2d_2')(x)
        x = LeakyReLU()(x)
        x = Conv2DTranspose(32, kernel_size=(3, 3), strides=2, padding='same', name='conv2d_3')(x)
        x = LeakyReLU()(x)
        x = Conv2DTranspose(3, kernel_size=(3, 3), strides=2, padding='same', name='conv2d_4')(x)
        output_decoder = Activation('sigmoid')(x)

        decoder = Model(latent_input, output_decoder, name='decoder')

        output_vae = decoder(encoder(input_encoder))
        variable_autoencoder = Model(input_encoder, output_vae, name='variable_autoencoder')

        # Scale the cross-entropy by the number of pixels of the actual input
        # (height * width) instead of a fixed 64 * 64, so the balance with
        # the KL term stays the same for other image sizes.
        n_pixels = input_shape[0] * input_shape[1]
        reconstruction_loss = binary_crossentropy(input_encoder, output_vae) * n_pixels
        reconstruction_loss = K.mean(reconstruction_loss)

        # KL divergence term, summed over the latent axis:
        # -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).
        kl_loss = 1 + latent_log_var - K.square(latent_mu) - K.exp(latent_log_var)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5

        vae_loss = K.mean(reconstruction_loss + kl_loss)

        variable_autoencoder.add_loss(vae_loss)
        variable_autoencoder.add_metric(reconstruction_loss, name='reconstruction_loss')
        variable_autoencoder.add_metric(kl_loss, name='kl_divergence_loss')

        return variable_autoencoder, encoder, decoder

In [None]:
# Build the VAE for INPUT_DIM images and keep handles to all three models.
var_autoencoder, var_encoder, var_decoder = VariableAutoencoder.build(input_shape=INPUT_DIM)
# The loss is attached inside build() through add_loss, so only the optimizer
# is specified here.
var_autoencoder.compile(optimizer='adam')

var_autoencoder.summary()

# Train Variational Autoencoder

In [None]:
VAE_N_EPOCHS = 10

# Settings shared by both VAE checkpoints; they differ only in the target
# file and in whether every epoch or only the best one is kept.
_vae_checkpoint_common = dict(
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    verbose=1,
)

# Written only when the monitored validation loss improves.
checkpoint_vae_best = ModelCheckpoint(
    os.path.join(WEIGHTS_FOLDER, 'VAE/vae_best_model.h5'),
    save_best_only=True,
    **_vae_checkpoint_common,
)

# Written regardless of whether the monitored value improved.
checkpoint_vae_last = ModelCheckpoint(
    os.path.join(WEIGHTS_FOLDER, 'VAE/vae_last_model.h5'),
    save_best_only=False,
    **_vae_checkpoint_common,
)

var_autoencoder.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=VAE_N_EPOCHS,
    callbacks=[checkpoint_vae_best, checkpoint_vae_last],
)

# Encoding-Decoding using trained Variational Autoencoder

In [None]:
# Image-only dataset (no targets) over the validation paths.
test_dataset = build_dataset(paths=valid_paths, test=True)
# Restore the weights stored by the "last" VAE checkpoint.
vae_last_checkpoint = os.path.join(WEIGHTS_FOLDER, 'VAE/vae_last_model.h5')
var_autoencoder.load_weights(vae_last_checkpoint)

In [None]:
# Materialise the first 20 elements of the test dataset.
data = list(test_dataset.take(20))

fig = plt.figure(figsize=(30, 10))
# Every second element is drawn next to its VAE reconstruction, filling two
# adjacent cells of the 2x10 grid per pair.
for n in range(0, 20, 2):
    original = data[n]
    image = var_autoencoder.predict(original)

    panels = (
        (n + 1, original, 'original image'),
        (n + 2, image, 'reconstruct'),
    )
    for slot, picture, caption in panels:
        plt.subplot(2, 10, slot)
        plt.imshow(np.squeeze(picture))
        plt.title(caption)

plt.show()

# Generate new anime faces using Variational Autoencoder

In [None]:
def vae_generate_images(n_to_show=20):
    """Decode random latent vectors with the VAE decoder and plot the results.

    Args:
        n_to_show: Number of images to generate and display.
    """
    n_cols = 10
    # Keep the six-row layout for up to 60 images and add rows beyond that;
    # a fixed 6x10 grid has no cell for a 61st subplot.
    n_rows = max(6, -(-n_to_show // n_cols))

    # One standard-normal latent vector per requested image.
    random_codes = np.random.normal(size=(n_to_show, Z_DIM))
    new_faces = var_decoder.predict(random_codes)

    fig = plt.figure(figsize=(30, 15))

    for i, face in enumerate(new_faces):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.imshow(face)
        ax.axis('off')
    plt.show()

In [None]:
# Generate and display 30 new faces.
vae_generate_images(n_to_show=30)