In [None]:
import tensorflow as tf
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import Input
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from matplotlib import pyplot

In [None]:
def wasserstein_loss(y_true, y_pred):
	"""Wasserstein loss: the mean over the batch of label * critic score.

	In this notebook real samples are labelled -1 and fake samples +1, so
	minimizing this raises critic scores on real images and lowers them on
	fake ones.
	"""
	signed_scores = y_true * y_pred
	return tf.reduce_mean(signed_scores)

In [None]:
# load images from the dataset
def load_real_samples(digit=5):
	"""Load MNIST training images of a single digit, rescaled to [-1, 1].

	Args:
		digit: class label to keep (default 5, the value previously hard-coded).
			Pass None to keep every training image.

	Returns:
		float32 array of shape (n, 28, 28, 1).
	"""
	(trainX, trainy), (_, _) = load_data()
	if digit is not None:
		trainX = trainX[trainy == digit]
	# add a trailing channel axis so images match the critic's (28, 28, 1) input
	X = expand_dims(trainX, axis=-1).astype('float32')
	# map [0, 255] pixel values onto [-1, 1], the range of the generator's tanh output
	return (X - 127.5) / 127.5

# select real samples
def generate_real_samples(dataset, n_samples):
	"""Draw `n_samples` random rows of `dataset` (with replacement).

	Returns the selected images together with an (n_samples, 1) float column
	of -1 labels, which is how this notebook marks "real" for the critic.

	NOTE(review): an identical definition appears again in the next cell; the
	later one shadows this one, so one of the two copies can be removed.
	"""
	picks = randint(0, dataset.shape[0], n_samples)
	labels = ones((n_samples, 1)) * -1
	return dataset[picks], labels

In [None]:
# select real samples
def generate_real_samples(dataset, n_samples):
	"""Sample real training images and pair them with "real" (-1) labels.

	NOTE(review): this repeats a definition from the previous cell verbatim;
	being later, it is the one in effect after Run All — consider deleting one.

	Args:
		dataset: array indexed along axis 0, one image per row.
		n_samples: number of rows to draw, uniformly with replacement.

	Returns:
		(images, labels) where labels is a float (n_samples, 1) array of -1.
	"""
	row_count = dataset.shape[0]
	chosen = randint(0, row_count, n_samples)
	return dataset[chosen], -1.0 * ones((n_samples, 1))

In [None]:
# clip model
class ClipConstraint(Constraint):
	"""Weight constraint that clips every value into [-clip_value, clip_value].

	Used as the critic's kernel_constraint to keep its convolution weights in a
	small box.
	"""
	def __init__(self, clip_value):
		self.clip_value = clip_value

	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)

	def get_config(self):
		# Without this, serializing a model that uses the constraint cannot
		# record clip_value, so it could not be rebuilt from its config.
		return {'clip_value': self.clip_value}

In [None]:
# critic model
def define_critic(in_shape=(28,28,1)):
	"""Build and compile the WGAN critic.

	Two stride-2 convolutions (each followed by batch norm and LeakyReLU) feed a
	single linear output unit that scores images. Convolution kernels are
	constrained by ClipConstraint(0.01).

	Args:
		in_shape: shape of one input image (height, width, channels).

	Returns:
		A Sequential model compiled with wasserstein_loss and RMSprop(5e-5).
	"""
	init = RandomNormal(stddev=0.02)
	const = ClipConstraint(0.01)
	model = Sequential()
	# Declare the input with an Input layer rather than the `input_shape=` layer
	# kwarg, which Keras 3 warns about in Sequential models.
	model.add(Input(shape=in_shape))
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	model.add(Flatten())
	# linear output: an unbounded critic score, not a probability
	model.add(Dense(1))
	opt = RMSprop(learning_rate=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

In [None]:
def define_generator(latent_dim):
	"""Build the (uncompiled) generator: latent vector -> 28x28x1 image in [-1, 1].

	A Dense layer projects the latent vector to a 7x7x128 feature map, two
	stride-2 transposed convolutions upsample it to 28x28, and a tanh Conv2D
	produces a single channel.

	Args:
		latent_dim: length of each input latent vector.

	Returns:
		An uncompiled Sequential model (it is trained through the combined GAN model).
	"""
	# NOTE(review): the critic initializes with stddev=0.02; confirm 0.03 here is intended.
	init = RandomNormal(stddev=0.03)
	# define model
	model = Sequential()
	# Explicit Input layer instead of the `input_dim=` kwarg, which Keras 3 warns about.
	model.add(Input(shape=(latent_dim,)))
	n_nodes = 128 * 7 * 7
	model.add(Dense(n_nodes, kernel_initializer=init))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Reshape((7, 7, 128)))
	# 7x7 -> 14x14
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# 14x14 -> 28x28
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
	return model

In [None]:
def define_gan(generator, critic):
	"""Stack generator -> critic into the model used to update the generator.

	Side effect: every layer of `critic` except BatchNormalization layers is set
	to trainable=False, so generator updates through this model do not change
	those critic weights. BatchNormalization layers are deliberately left trainable.

	NOTE(review): the critic has already been compiled by the time it is frozen
	here. Whether the critic's own train_on_batch still updates its conv/dense
	weights depends on how the installed Keras version treats trainable flags
	changed after compile — confirm on the Keras version in use.
	"""
	for layer in critic.layers:
		if isinstance(layer, BatchNormalization):
			continue
		layer.trainable = False
	model = Sequential()
	for part in (generator, critic):
		model.add(part)
	model.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.00005))
	return model

In [None]:
def generate_latent_points(latent_dim, n_samples):
	"""Return an (n_samples, latent_dim) array of standard-normal latent vectors."""
	return randn(n_samples, latent_dim)

# fake examples
def generate_fake_samples(generator, latent_dim, n_samples):
	"""Generate `n_samples` images from random latent points, labelled +1 (fake).

	Returns:
		(images, labels): the generator's predictions and a float (n_samples, 1)
		array of ones.
	"""
	x_input = generate_latent_points(latent_dim, n_samples)
	# verbose=0: this is called n_critic times per training step, and the
	# default progress bar prints a line per call, flooding the notebook output.
	X = generator.predict(x_input, verbose=0)
	y = ones((n_samples, 1))
	return X, y

In [None]:
# train the generator and critic
def _scalar_loss(loss):
	"""Return the loss from a train_on_batch result, which may be a scalar or [loss, ...]."""
	if isinstance(loss, (list, tuple)):
		return loss[0]
	return loss


def _train_critic(g_model, c_model, dataset, latent_dim, half_batch, n_critic):
	"""Run n_critic critic updates (real then fake half-batch each); return the mean losses."""
	real_losses, fake_losses = [], []
	for _ in range(n_critic):
		X_real, y_real = generate_real_samples(dataset, half_batch)
		real_losses.append(_scalar_loss(c_model.train_on_batch(X_real, y_real)))
		X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
		fake_losses.append(_scalar_loss(c_model.train_on_batch(X_fake, y_fake)))
	return mean(real_losses), mean(fake_losses)


def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5):
	"""Train the WGAN with n_critic critic updates per generator update.

	Each step runs n_critic critic rounds (a real half-batch labelled -1, then a
	fake half-batch labelled +1), then one generator update through gan_model on
	a full batch of latent points labelled -1. A sample grid is saved at the end
	of every epoch and the loss curves are saved when training finishes.

	Args:
		g_model: generator model.
		c_model: compiled critic model.
		gan_model: compiled generator -> critic model.
		dataset: array of real images, one per row.
		latent_dim: length of each latent vector.
		n_epochs: number of passes over the dataset.
		n_batch: generator batch size; each critic update uses n_batch // 2 samples.
		n_critic: critic updates per generator update.

	Raises:
		ValueError: if dataset holds fewer than n_batch samples (no full batch per epoch).
	"""
	# number of batches per training epoch
	bat_per_epo = dataset.shape[0] // n_batch
	if bat_per_epo == 0:
		raise ValueError('dataset has %d samples, fewer than n_batch=%d' % (dataset.shape[0], n_batch))
	# number of training iterations
	n_steps = bat_per_epo * n_epochs
	# size of half a batch of samples
	half_batch = n_batch // 2

	c1_hist, c2_hist, g_hist = [], [], []
	for i in range(n_steps):
		c1_loss, c2_loss = _train_critic(g_model, c_model, dataset, latent_dim, half_batch, n_critic)
		c1_hist.append(c1_loss)
		c2_hist.append(c2_loss)
		# update the generator: label fakes as "real" (-1) so it learns to raise their score
		X_gan = generate_latent_points(latent_dim, n_batch)
		y_gan = -ones((n_batch, 1))
		g_loss = _scalar_loss(gan_model.train_on_batch(X_gan, y_gan))
		g_hist.append(g_loss)
		print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))
		# evaluate the model performance every 'epoch'
		if (i+1) % bat_per_epo == 0:
			summarize_performance(i, g_model, latent_dim)
	# line plots of loss
	plot_history(c1_hist, c2_hist, g_hist)

In [None]:
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	"""Save a grid of `n_samples` generated images to plot_<step+1, 4 digits>.png.

	The grid is the smallest square that fits n_samples images (10x10 for the
	default of 100), so n_samples values other than 100 no longer index past the
	generated batch or leave the grid partly empty.
	"""
	import math

	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# generator output is tanh in [-1, 1]; rescale to [0, 1] for imshow
	X = (X + 1) / 2.0
	grid = max(1, math.ceil(math.sqrt(n_samples)))
	fig = pyplot.figure()
	for i in range(n_samples):
		ax = fig.add_subplot(grid, grid, 1 + i)
		ax.axis('off')
		ax.imshow(X[i, :, :, 0], cmap='gray_r')
	filename1 = 'plot_%04d.png' % (step+1)
	fig.savefig(filename1)
	pyplot.close(fig)

def plot_history(d1_hist, d2_hist, g_hist):
	"""Save a line plot of the training losses to line_plot_loss.png.

	Args:
		d1_hist: critic loss on real batches, one value per training step.
		d2_hist: critic loss on fake batches, one value per training step.
		g_hist: generator loss, one value per training step.
	"""
	fig, ax = pyplot.subplots()
	ax.plot(d1_hist, label='crit_real')
	ax.plot(d2_hist, label='crit_fake')
	ax.plot(g_hist, label='gen')
	ax.set(title='WGAN training losses', xlabel='training step', ylabel='Wasserstein loss')
	ax.legend()
	fig.savefig('line_plot_loss.png')
	pyplot.close(fig)

In [None]:
# Seed Python, NumPy and the backend before building the models so weight init,
# latent draws and batch sampling repeat across runs (GPU kernels may still be
# nondeterministic).
SEED = 1
N_EPOCHS = 10
tf.keras.utils.set_random_seed(SEED)

latent_dim = 50
critic = define_critic()
generator = define_generator(latent_dim)
gan_model = define_gan(generator, critic)
dataset = load_real_samples()
print(dataset.shape)
train(generator, critic, gan_model, dataset, latent_dim, n_epochs=N_EPOCHS)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


(5421, 28, 28, 1)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 341ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
>1, c1=-0.492, c2=-0.359 g=-0.350
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
>2, c1=-0.294, c2=-0.242 g=-0.190
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0

KeyboardInterrupt: 

In [None]:
import os
import imageio

# Directory holding the plot_*.png frames. summarize_performance writes them to
# the working directory — presumably /content/ on Colab; adjust if run elsewhere.
imgdir = '/content/'

# Zero-padded step numbers in the file names make lexical order chronological.
gif_files = sorted(
	name for name in os.listdir(imgdir)
	if name.startswith('plot_') and name.endswith('.png')
)
if not gif_files:
	raise FileNotFoundError('no plot_*.png frames found in %s' % imgdir)

# Newer imageio releases deprecate the top-level imread in favour of
# imageio.v2.imread; fall back to the top-level function where v2 is absent.
_imread = getattr(imageio, 'v2', imageio).imread
images = [_imread(os.path.join(imgdir, name)) for name in gif_files]

imageio.mimsave('/content/output.gif', images, format="GIF", fps=2)