# WGAN - Wasserstein Generative Adversarial Network

# Import TensorFlow 2.x.

In [0]:
try:
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
tf.random.set_seed(7)

import tensorflow.keras.layers as layers
import tensorflow.keras.models as models

import numpy as np
np.random.seed(7)

import matplotlib.pyplot as plot

print(tf.__version__)

# Clip model weights to a given hypercube.

In [0]:
# `tf` is only a local alias for tensorflow, so the import must name the real package.
from tensorflow.keras.constraints import Constraint

In [0]:
class ClipConstraint(Constraint):
    """Keras weight constraint that clips every weight into [-clip_value, clip_value]."""

    def __init__(self, clip_value):
        # Half-width of the clipping hypercube.
        self._clip_value = clip_value

    def __call__(self, weights):
        """Return `weights` clipped element-wise to [-clip_value, clip_value]."""
        bound = self._clip_value
        return tf.keras.backend.clip(weights, -bound, bound)

    def get_config(self):
        """Return the constructor arguments as a serializable config dict."""
        config = dict(clip_value=self._clip_value)
        return config

# Compute Wasserstein loss.

In [0]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: the mean of label times critic score.

    This notebook labels real samples -1 and fake samples +1. Minimising the
    loss therefore pushes critic scores for real samples up and for fake
    samples down.
    """
    signed_scores = y_true * y_pred
    return tf.keras.backend.mean(signed_scores)

# Create the critic model.

In [0]:
def create_critic_model(input_shape=(28, 28, 1)):
    """Build and compile the WGAN critic.

    Architecture:
      * two 4x4, stride-2 convolutions with 64 filters each; their kernels are
        clipped to [-0.01, 0.01] by ClipConstraint;
      * each convolution is followed by batch normalisation and LeakyReLU(0.2);
      * then a flatten step and a single linear output unit (the critic score).

    Args:
        input_shape: Shape of one input image as (height, width, channels).

    Returns:
        A `tf.keras.models.Sequential` model compiled with `wasserstein_loss`
        and RMSprop (learning rate 5e-5).
    """
    initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
    constraint = ClipConstraint(0.01)

    model = tf.keras.models.Sequential()

    model.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same',
                            kernel_initializer=initializer,
                            kernel_constraint=constraint,
                            input_shape=input_shape))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same',
                            kernel_initializer=initializer,
                            kernel_constraint=constraint))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Flatten())
    # NOTE(review): only the convolution kernels are clipped. The Dense kernel
    # and the BatchNormalization parameters are unconstrained. Confirm this is
    # intended before relying on clipping to bound the critic.
    model.add(layers.Dense(1))

    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.00005)
    model.compile(loss=wasserstein_loss, optimizer=optimizer)
    return model

In [0]:
# Single-channel 28x28 images.
input_shape = (28, 28, 1)
critic_model = create_critic_model(input_shape=input_shape)

# Create the generator model.

In [0]:
def create_generator_model(latent_dimension=50):
    """Build the (uncompiled) generator that maps latent vectors to images.

    Pipeline:
      * a dense layer projects the latent vector to a 7x7x128 feature map;
      * two stride-2 transposed convolutions upsample it (7 -> 14 -> 28);
      * a final 7x7 convolution produces one tanh channel, so outputs lie
        in (-1, 1).

    Args:
        latent_dimension: Length of each latent input vector.

    Returns:
        An uncompiled `tf.keras.models.Sequential` model.
    """
    init = tf.keras.initializers.RandomNormal(stddev=0.02)
    projected_units = 7 * 7 * 128

    # Layers are listed in exactly the order they are added to the model.
    stack = [
        layers.Dense(projected_units, kernel_initializer=init,
                     input_dim=latent_dimension),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),

        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',
                               kernel_initializer=init),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),

        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',
                               kernel_initializer=init),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),

        layers.Conv2D(1, (7, 7), activation='tanh', padding='same',
                      kernel_initializer=init),
    ]
    return tf.keras.models.Sequential(stack)

### Define the size of the latent space.

In [0]:
# Length of the random vectors fed to the generator.
latent_dimension = 50
generator_model = create_generator_model(latent_dimension=latent_dimension)

# Create the GAN model.
* Stack the generator and the critic into a single model.
* Used only to update the generator; the critic's weights are frozen inside it.

In [0]:
def create_gan_model(generator, critic):
    """Stack `generator` and `critic` into one compiled model for generator updates.

    Side effect: sets `critic.trainable = False` before the stack is compiled.
    As a result, the critic's weights are not trained through this combined
    model. The critic was compiled separately beforehand; presumably its own
    training is unaffected, since Keras applies `trainable` changes on
    (re)compile. Verify against the Keras version in use.

    Returns:
        A Sequential(generator, critic) compiled with `wasserstein_loss` and
        RMSprop (learning rate 5e-5).
    """
    critic.trainable = False

    stacked = tf.keras.models.Sequential([generator, critic])
    stacked.compile(loss=wasserstein_loss,
                    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00005))
    return stacked

In [0]:
# The combined model is built by create_gan_model (there is no `define_gan`).
gan_model = create_gan_model(generator_model, critic_model)

# Load dataset.

### Load real samples.

In [0]:
# Use tf.keras like the rest of the notebook instead of requiring the standalone keras package.
from tensorflow.keras.datasets.mnist import load_data

In [0]:
def load_real_samples(digit=7):
    """Load the MNIST training images of one digit, scaled to [-1, 1].

    Args:
        digit: MNIST class label to keep (default 7).

    Returns:
        float32 NumPy array of shape (n, 28, 28, 1), where n is the number of
        training images labelled `digit`. Values lie in [-1, 1], matching
        the generator's tanh output range.
    """
    (train_images, train_labels), (_, _) = load_data()
    X = train_images[train_labels == digit]

    # Add a trailing channel axis and convert to float before scaling.
    # A NumPy array (not a tf.Tensor) is returned so callers can use
    # integer-array indexing such as dataset[random_indices].
    X = np.expand_dims(X, axis=-1).astype('float32')
    X = (X - 127.5) / 127.5
    return X

### Generate real samples.

In [0]:
def generate_real_samples(dataset, number_of_samples):
    """Draw random real samples and their critic labels.

    Args:
        dataset: Array of real images, indexed along axis 0 with an integer
            array, so it must support NumPy-style fancy indexing.
        number_of_samples: How many samples to draw (with replacement).

    Returns:
        (X, y): the selected samples and a (number_of_samples, 1) array of -1,
        which is the label this notebook uses for "real".
    """
    # Uses np.random / np.ones directly so this helper does not depend on the
    # bare `randint` / `ones` names imported by a later notebook cell.
    ix = np.random.randint(0, dataset.shape[0], number_of_samples)
    X = dataset[ix]
    y = -np.ones((number_of_samples, 1))
    return X, y

### Generate points in latent space as input for the generator.

In [0]:
def generate_latent_points(latent_dimension, number_of_samples):
    """Sample standard-normal latent vectors for the generator.

    Returns:
        Array of shape (number_of_samples, latent_dimension) drawn from the
        global NumPy RNG.
    """
    # Drawing the 2-D array directly consumes the RNG in the same row-major
    # order as drawing a flat vector and reshaping it, so values are identical.
    return np.random.randn(number_of_samples, latent_dimension)

### Generate fake samples with class labels using the generator.

In [0]:
def generate_fake_samples(generator, latent_dimension, number_of_samples):
    """Generate fake images with `generator` and label them for the critic.

    Args:
        generator: Model whose `predict` maps latent points to images.
        latent_dimension: Length of each latent vector.
        number_of_samples: How many fake samples to produce.

    Returns:
        (X, y): the generator's outputs for random latent points and a
        (number_of_samples, 1) array of +1, which is the label this notebook
        uses for "fake".
    """
    x_input = generate_latent_points(latent_dimension, number_of_samples)
    X = generator.predict(x_input)
    # np.ones instead of the bare `ones` that is only imported by a later cell.
    y = np.ones((number_of_samples, 1))
    return X, y

# Plot the critic and generator loss histories.

In [0]:
def plot_graph(critic_real_history, critic_fake_history, generator_history):
    """Plot the critic real/fake losses and the generator loss on one chart and show it."""
    # Imported locally so this helper does not depend on a module-level
    # `pyplot` name being defined.
    import matplotlib.pyplot as pyplot

    pyplot.plot(critic_real_history, label='critic real')
    pyplot.plot(critic_fake_history, label='critic fake')
    pyplot.plot(generator_history, label='generator')
    pyplot.legend()
    pyplot.show()

In [0]:
from numpy import mean
from numpy import ones
from numpy.random import randint

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dimension, number_of_samples=100):
    """Save a grid of generated images and the generator model for `step`.

    Files written: 'generated_plot_%04d.png' and 'model_%04d.h5', both
    numbered with step + 1.

    Args:
        step: Zero-based training step; used for file numbering.
        g_model: Generator model to sample from and save.
        latent_dimension: Length of the generator's latent input.
        number_of_samples: Number of images to generate and plot. The grid is
            the smallest square that fits them (10x10 for the default 100).
    """
    import math
    # Imported locally so this helper does not depend on a module-level
    # `pyplot` name being defined.
    import matplotlib.pyplot as pyplot

    X, _ = generate_fake_samples(g_model, latent_dimension, number_of_samples)
    # Rescale from the generator's tanh range [-1, 1] to [0, 1] for imshow.
    X = (X + 1) / 2.0

    # Derive the grid from the sample count. The old fixed 10x10 grid raised
    # IndexError when fewer than 100 samples were requested.
    n_plot = min(number_of_samples, len(X))
    grid = max(1, math.ceil(math.sqrt(n_plot)))
    for i in range(n_plot):
        pyplot.subplot(grid, grid, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X[i, :, :, 0], cmap='gray_r')

    filename1 = 'generated_plot_%04d.png' % (step + 1)
    pyplot.savefig(filename1)
    pyplot.close()

    filename2 = 'model_%04d.h5' % (step + 1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

# train the generator and critic
def train(g_model, c_model, gan_model, dataset, latent_dimension, n_epochs=10, n_batch=64, n_critic=5):
    """Train the WGAN, updating the critic `n_critic` times per generator update.

    Args:
        g_model: Generator model.
        c_model: Compiled critic model.
        gan_model: Compiled generator+critic stack used to update `g_model`.
        dataset: Real training images; indexed with an integer array.
        latent_dimension: Length of the generator's latent input.
        n_epochs: Number of epochs. One epoch is
            int(dataset.shape[0] / n_batch) training steps.
        n_batch: Generator batch size. Each critic update uses n_batch // 2
            real samples and n_batch // 2 fake samples.
        n_critic: Critic updates per generator update.

    Side effects:
        * prints per-step losses;
        * calls summarize_performance at the end of every epoch;
        * finally plots the loss histories with plot_graph.
    """
    # number of batches per training epoch
    bat_per_epo = int(dataset.shape[0] / n_batch)
    # total number of training steps (generator updates)
    n_steps = bat_per_epo * n_epochs
    # size of half a batch of samples
    half_batch = int(n_batch / 2)
    # per-step loss histories
    c1_hist, c2_hist, g_hist = list(), list(), list()
    # iterate over training steps (not epochs)
    for i in range(n_steps):
        # update the critic more than the generator
        c1_tmp, c2_tmp = list(), list()
        for _ in range(n_critic):
            # critic update on randomly selected real samples (label -1)
            X_real, y_real = generate_real_samples(dataset, half_batch)
            c_loss1 = c_model.train_on_batch(X_real, y_real)
            c1_tmp.append(c_loss1)
            # critic update on generated fake samples (label +1)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dimension, half_batch)
            c_loss2 = c_model.train_on_batch(X_fake, y_fake)
            c2_tmp.append(c_loss2)
        # store the mean critic losses for this step
        c1_hist.append(mean(c1_tmp))
        c2_hist.append(mean(c2_tmp))
        # latent points for the generator update
        X_gan = generate_latent_points(latent_dimension, n_batch)
        # label the fakes as "real" (-1) so the generator is pushed to fool the critic
        y_gan = -ones((n_batch, 1))
        # update the generator through the (frozen) critic
        g_loss = gan_model.train_on_batch(X_gan, y_gan)
        g_hist.append(g_loss)
        print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))
        # evaluate the model performance at the end of every epoch
        if (i+1) % bat_per_epo == 0:
            summarize_performance(i, g_model, latent_dimension)
    # line plots of loss (the plotting helper is named plot_graph)
    plot_graph(c1_hist, c2_hist, g_hist)


# load image data
dataset = load_real_samples()
print(dataset.shape)
# train model; the models were created above as generator_model / critic_model
train(generator_model, critic_model, gan_model, dataset, latent_dimension)