# VAEGAN Implementation

In [1]:
from keras.models import Sequential, Model
from keras.layers import *
from tensorflow.keras.optimizers import *
import os
from keras import metrics, backend as K
from PIL import Image
import numpy as np
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Lambda
from keras.layers import Activation
from matplotlib import pyplot
from keras import backend

def generate_latent_points(latent_dim, n_samples):
	z_input = randn(latent_dim * n_samples)
	z_input = z_input.reshape(n_samples, latent_dim)
	return z_input

def generate_real_samples(dataset, n_samples):
	# split into images and labels
	images, labels = dataset
	# choose random instances
	ix = randint(0, images.shape[0], n_samples)
	# select images and labels
	X, labels = images[ix], labels[ix]
	# generate class labels
	y = ones((n_samples, 1))
	return [X, labels], y

def generate_fake_samples(generator, latent_dim, n_samples):
	z_input = generate_latent_points(latent_dim, n_samples)
	images = generator.predict(z_input)
	y = zeros((n_samples, 1))
	return images, y

def load_real_samples():
	(trainX, trainy), (_, _) = load_data()
	X = expand_dims(trainX, axis=-1)
	X = X.astype('float32')
	X = (X - 127.5) / 127.5
	print(X.shape, trainy.shape)
	return [X, trainy]

def encoder(kernel, filter, rows, columns, channel):
    X = Input(shape=(rows, columns, channel))
    model = Conv2D(filters=filter, kernel_size=kernel, strides=2, padding='same')(X)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = LeakyReLU(alpha=0.2)(model)

    model = Conv2D(filters=filter*2, kernel_size=kernel, strides=2, padding='same')(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = LeakyReLU(alpha=0.2)(model)

    model = Conv2D(filters=filter*4, kernel_size=kernel, strides=2, padding='same')(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = LeakyReLU(alpha=0.2)(model)

    model = Conv2D(filters=filter*8, kernel_size=kernel, strides=2, padding='same')(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = LeakyReLU(alpha=0.2)(model)

    model = Flatten()(model)

    mean = Dense(512)(model)
    logsigma = Dense(512, activation='tanh')(model)
    latent = Lambda(sampling, output_shape=(512,))([mean, logsigma])
    meansigma = Model([X], [mean, logsigma, latent])
    meansigma.compile(optimizer = SGDop, loss='mse')
    return meansigma


def decgen(kernel, filter, rows, columns, channel):
    X = Input(shape=(512,))

    model = Dense(filter*8*rows*columns)(X)
    model = Reshape((rows, columns, filter * 8))(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = Activation('relu')(model)

    model = Conv2DTranspose(filters=filter*4, kernel_size=kernel, strides=2, padding='same')(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = Activation('relu')(model)

    model = Conv2DTranspose(filters=filter*2, kernel_size=kernel, strides=2, padding='same')(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = Activation('relu')(model)

    model = Conv2DTranspose(filters=filter, kernel_size=kernel, strides=2, padding='same')(model)
    model = BatchNormalization(epsilon=1e-5)(model)
    model = Activation('relu')(model)

    model = Conv2DTranspose(filters=channel, kernel_size=kernel, strides=2, padding='same')(model)
    model = Activation('tanh')(model)

    model = Flatten()(model)
    model = Dense(rows*columns)(model)
    model = Reshape((rows, columns, 1))(model)
    model = Activation('tanh')(model)

    model = Model(X, model)
    model.compile(optimizer = SGDop, loss='mse')
    return model


def sampling(args):
    mean, logsigma = args
    epsilon = K.random_normal(shape=(K.shape(mean)[0], 512), mean=0., stddev=1.0)
    return mean + K.exp(logsigma / 2) * epsilon

def custom_activation(output):
	logexpsum = backend.sum(backend.exp(output), axis=-1, keepdims=True)
	result = logexpsum / (logexpsum + 1.0)
	return result

def define_gan(dec_model, d_model):
	d_model.trainable = False
	gan_output = d_model(dec_model.output)
	model = Model(dec_model.input, gan_output)
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt)
	return model

def select_supervised_samples(dataset, n_samples=100, n_classes=10):
	X, y = dataset
	X_list, y_list = list(), list()
	n_per_class = int(n_samples / n_classes)
	for i in range(n_classes):
		# get all images for this class
		X_with_class = X[y == i]
		# choose random instances
		ix = randint(0, len(X_with_class), n_per_class)
		# add to list
		[X_list.append(X_with_class[j]) for j in ix]
		[y_list.append(i) for j in ix]
	return asarray(X_list), asarray(y_list)
 
def summarize_performance(step, vae_model, c_model, latent_dim, dataset, n_samples = 100):
  X, y = dataset
  ix = randint(0, len(X), n_samples)
  print(ix)
  X = vae_model.predict(X[ix, :, :, :])
  for i in range(100):
    pyplot.subplot(10, 10, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(X[i, :, :, 0], cmap = 'gray_r')
  filename1 = 'generated_plot_%04d_vae.png' % (step+1)
  pyplot.savefig(filename1)
  pyplot.close()
  X, y = dataset
  _, acc = c_model.evaluate(X, y, verbose=0)
  print('Classifier Accuracy: %.3f%%' % (acc * 100))
  filename2 = 'g_model_%04d_vae.h5' % (step+1)
  vae_model.save(filename2)
  filename3 = 'c_model_%04d_vae.h5' % (step+1)
  c_model.save(filename3)
  print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))


def train(dec_model, enc_model, vae_model, d_model, c_model, gan_model, dataset, latent_dim, n_epochs = 20, n_batch = 100):
  X_sup, y_sup = select_supervised_samples(dataset)
  print(X_sup.shape, y_sup.shape)
  bat_per_epo = int(dataset[0].shape[0] / n_batch)
  n_steps = bat_per_epo * n_epochs
  half_batch = int(n_batch / 2)
  print('n_epochs=%d, n_batch=%d, 1/2=%d, b/e=%d, steps=%d' % (n_epochs, n_batch, half_batch, bat_per_epo, n_steps))
  for i in range(n_steps):
    [Xsup_real, ysup_real], _ = generate_real_samples([X_sup, y_sup], half_batch)
    c_loss, c_acc = c_model.train_on_batch(Xsup_real, ysup_real)
    [X_real, _], y_real = generate_real_samples(dataset, half_batch)
    d_loss1 = d_model.train_on_batch(X_real, y_real)
    X_fake, y_fake = generate_fake_samples(dec_model, latent_dim, half_batch)
    d_loss2 = d_model.train_on_batch(X_fake, y_fake)
    X_gan, y_gan = enc_model.predict(X_sup)[0], ones((n_batch, 1))
    v_loss = vae_model.train_on_batch(X_sup, y_sup)
    g_loss = gan_model.train_on_batch(X_gan, y_gan)
    print('>%d, c[%.3f,%.0f], d[%.3f,%.3f], g[%.3f], v[%.3f]' % (i+1, c_loss, c_acc*100, d_loss1, d_loss2, g_loss, v_loss))
    if (i+1) % (bat_per_epo * 1) == 0:
      summarize_performance(i, vae_model, c_model, latent_dim, dataset)


def define_discriminator(in_shape=(28,28,1), n_classes=10):
	in_image = Input(shape=in_shape)
	fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(in_image)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Flatten()(fe)
	fe = Dropout(0.4)(fe)
	fe = Dense(n_classes)(fe)
	c_out_layer = Activation('softmax')(fe)
	c_model = Model(in_image, c_out_layer)
	c_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5), metrics=['accuracy'])
	d_out_layer = Lambda(custom_activation)(fe)
	d_model = Model(in_image, d_out_layer)
	d_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
	return d_model, c_model

def define_VAE(enc_model, dec_model, kernel, filter, rows, columns, channel):
  X = Input(shape=(rows, columns, channel))
  mean, logsigma, Z = enc_model(X)
  output = dec_model(Z)
  dec = dec_model(mean + logsigma)
  VAE = Model(X, output)
  kl = - 0.5 * backend.sum(1 + logsigma - backend.square(mean) - backend.exp(logsigma), axis=-1)
  crossent = 64 * metrics.mse(backend.flatten(X), backend.flatten(output))
  VAEloss = backend.mean(crossent + kl)
  VAE.add_loss(VAEloss)
  VAE.compile(optimizer=SGDop)
  return VAE

In [None]:
rows = 28
columns = 28
channel = 1
kernel = 5
filter = 32
SGDop = SGD(learning_rate = 0.0003)
ADAMop = Adam(learning_rate = 0.0002)

latent_dim = 512
d_model, c_model = define_discriminator()
enc_model = encoder(kernel, filter, rows, columns, channel)
dec_model = decgen(kernel, filter, rows, columns, channel)
vae_model = define_VAE(enc_model, dec_model, kernel, filter, rows, columns, channel)
gan_model = define_gan(dec_model, d_model)
dataset = load_real_samples()
train(dec_model, enc_model, vae_model, d_model, c_model, gan_model, dataset, latent_dim)

  super(Adam, self).__init__(name, **kwargs)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(60000, 28, 28, 1) (60000,)
(100, 28, 28, 1) (100,)
n_epochs=20, n_batch=100, 1/2=50, b/e=600, steps=12000
>1, c[2.308,6], d[0.093,2.399], g[0.092], v[288.767]
>2, c[2.285,14], d[0.087,2.406], g[0.077], v[277.528]
>3, c[2.301,6], d[0.084,2.410], g[0.080], v[253.656]
>4, c[2.271,18], d[0.080,2.411], g[0.085], v[234.552]
>5, c[2.243,22], d[0.077,2.406], g[0.091], v[220.356]
>6, c[2.247,20], d[0.074,2.394], g[0.101], v[208.755]
>7, c[2.307,8], d[0.072,2.372], g[0.113], v[199.011]
>8, c[2.230,16], d[0.070,2.340], g[0.131], v[190.396]
>9, c[2.208,8], d[0.068,2.293], g[0.154], v[183.133]
>10, c[2.204,16], d[0.064,2.232], g[0.189], v[176.285]
>11, c[2.263,10], d[0.062,2.155], g[0.248], v[169.866]
>12, c[2.196,16], d[0.061,2.053], g[0.343], v[163.858]
>13, c[2.180,24], d[0.060,1.956], g[0.484], v[158.636]
>14, c[2.084,30], d[0.062,1.952], g[0.090], v[137.933]
>15, c[1.957,32], d[0.058,2.337], g[0.118], 