
# **Source code**

J. Brownlee, "How to Develop a Wasserstein Generative Adversarial Network (WGAN) From Scratch," Machine Learning Mastery, 2019

https://machinelearningmastery.com/how-to-code-a-wasserstein-generative-adversarial-network-wgan-from-scratch/

# **Purpose**

Let's use a WGAN to generate MNIST (handwritten digit) images.

# **1. Module import**
Import the modules needed to run the code.

If you need additional modules when you modify the code, add them here.

In [None]:
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from matplotlib import pyplot

# **2. Weight clipping**

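Define weight clipping, one of the distinctive features of WGAN. After each update, every constrained weight is clipped back into the hypercube [-c, c].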

In [None]:
# clip model weights to a given hypercube
class ClipConstraint(Constraint):
	# set clip value when initialized
	def __init__(self, clip_value):
		self.clip_value = clip_value
 
	# clip model weights to hypercube
	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)
 
	# get the config
	def get_config(self):
		return {'clip_value': self.clip_value}
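As a rough illustration of what `ClipConstraint` does to a weight tensor, the NumPy sketch below clips a small array into [-c, c] with c = 0.01, the value the critic uses later. `clip_to_hypercube` is a hypothetical helper that mirrors `backend.clip`; it is not part of Keras, and the weight values are made up.

```python
import numpy as np

def clip_to_hypercube(weights, clip_value):
    """Hypothetical NumPy counterpart of ClipConstraint.__call__:
    every entry is forced into [-clip_value, clip_value]."""
    return np.clip(weights, -clip_value, clip_value)

# made-up kernel values after an imaginary optimizer step
w = np.array([-0.05, -0.004, 0.0, 0.008, 0.3])
clipped = clip_to_hypercube(w, 0.01)
# entries already inside [-0.01, 0.01] are unchanged; the rest are pinned to the boundary
print(clipped)
```

In Keras, a layer's `kernel_constraint` is applied as a projection after each gradient update, so the critic's constrained kernels never leave this hypercube during training.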

# **3. Wasserstein loss**

The authors of the original paper implemented this loss in PyTorch.

In Keras, the Wasserstein loss can be implemented as in the code below.

To use a custom loss instead of one of the losses built into Keras, define it as below and pass it as the `loss` argument when compiling the model.

In [None]:
# calculate wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)
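A quick NumPy check of the sign convention used throughout this notebook (real = -1, fake = +1). With these labels, `wasserstein_loss` on a real batch equals minus the mean critic score, and on a fake batch it equals the mean critic score, so minimizing both pushes real scores up and fake scores down. The scores below are made up, and `wasserstein_loss_np` is a hypothetical NumPy copy of the Keras loss.

```python
import numpy as np

def wasserstein_loss_np(y_true, y_pred):
    # NumPy version of wasserstein_loss: mean of label * critic score
    return np.mean(y_true * y_pred)

# made-up critic scores for two real and two fake images
real_scores = np.array([[2.0], [4.0]])
fake_scores = np.array([[0.5], [1.5]])

loss_real = wasserstein_loss_np(-np.ones((2, 1)), real_scores)  # -mean(real scores)
loss_fake = wasserstein_loss_np(np.ones((2, 1)), fake_scores)   # +mean(fake scores)
print(loss_real, loss_fake)

# The generator is trained with label -1 on generated images, i.e. it minimizes
# -mean(critic(G(z))), which raises the critic's score of its samples.
loss_gen = wasserstein_loss_np(-np.ones((2, 1)), fake_scores)
print(loss_gen)
```

This is also how to read `c1` and `c2` in the training log: `c1` is minus the average critic score on real images and `c2` is the average critic score on fake images, so `c1 + c2` is the negative of the score gap that the critic tries to widen.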

# **4. Critic model**

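Let's build the network that will serve as the critic.

Weight clipping (c = 0.01) is applied to the kernels of the two convolution layers; the BatchNormalization parameters and the final Dense layer are not clipped in this implementation.

As mentioned in the lecture, RMSProp is used as the optimizer.

The learning rate is set to a relatively small value, 0.00005, the default used in the WGAN paper.

When compiling the model, the Wasserstein loss defined above is used as the loss.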

In [None]:
# define the standalone critic model
def define_critic(in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# weight constraint
	const = ClipConstraint(0.01)
	# define model
	model = Sequential()
	# downsample to 14x14
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# downsample to 7x7
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# scoring, linear activation
	model.add(Flatten())
	model.add(Dense(1))
	# compile model
	opt = RMSprop(learning_rate=0.00005)  # 'lr' is deprecated/rejected in recent Keras versions
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model
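The comments "downsample to 14x14" and "downsample to 7x7" follow from how Keras sizes a stride-2 convolution with `padding='same'`: the output side is ceil(input / stride), regardless of the 4x4 kernel. The sketch below traces the critic's spatial size and the length of the vector fed into `Dense(1)`; `same_conv_out` is a hypothetical helper, not a Keras function.

```python
import math

def same_conv_out(size, stride):
    # Keras padding='same': output = ceil(input / stride)
    return math.ceil(size / stride)

size = 28
for _ in range(2):  # two Conv2D layers with strides=(2, 2)
    size = same_conv_out(size, 2)
    print(size)

flat_len = size * size * 64  # 64 filters in the last Conv2D, then Flatten()
print(flat_len)
```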

# **5. Generator model**

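Let's build the generator as a simple network of one fully connected layer and three convolution layers (two transposed convolutions for upsampling and one final convolution).

Every layer except the output layer uses the LeakyReLU activation (the output layer uses tanh), and batch normalization is applied to the two transposed convolution layers.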

In [None]:
# define the standalone generator model
def define_generator(latent_dim):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# define model
	model = Sequential()
	# foundation for 7x7 image
	n_nodes = 128 * 7 * 7
	model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Reshape((7, 7, 128)))
	# upsample to 14x14
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# upsample to 28x28
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# output 28x28x1
	model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
	return model
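For the generator, `Conv2DTranspose` with `padding='same'` multiplies the spatial size by the stride, so the 7x7x128 block produced by Dense + Reshape grows to 14x14 and then 28x28; the final 7x7 `Conv2D` with `padding='same'` keeps 28x28 and reduces to one channel. The sketch below does only this shape arithmetic, plus the parameter count of the first Dense layer for `latent_dim = 50`; `transposed_same_out` is a hypothetical helper.

```python
def transposed_same_out(size, stride):
    # Conv2DTranspose with padding='same': output = input * stride
    return size * stride

latent_dim = 50
n_nodes = 128 * 7 * 7                          # units of the first Dense layer
dense_params = latent_dim * n_nodes + n_nodes  # weights + biases
print(n_nodes, dense_params)

size = 7                                       # after Reshape((7, 7, 128))
for _ in range(2):                             # two Conv2DTranspose layers, strides=(2, 2)
    size = transposed_same_out(size, 2)
print(size)  # the final Conv2D(1, (7,7), padding='same') keeps this size
```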

# **6. GAN model**

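After GAN training, the part we actually use is the generator.

Because the generator learns only through the critic's feedback, we define a combined model that feeds the generator's output into the critic, as below. Inside this combined model the critic's layers (except BatchNormalization) are marked non-trainable, so training it updates the generator's weights while leaving the critic's convolution and Dense weights unchanged.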

In [None]:
# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic):
	# make weights in the critic not trainable
	for layer in critic.layers:
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(generator)
	# add the critic
	model.add(critic)
	# compile model
	opt = RMSprop(learning_rate=0.00005)  # 'lr' is deprecated/rejected in recent Keras versions
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

# **7. Data preparation**

Load MNIST, a widely used public dataset.

Pick any digit you like from 0 to 9 and load only that class (the code below uses 7).

Pixel values are rescaled from [0, 255] to [-1, +1], matching the tanh output range of the generator.

Randomly draw real and fake samples and label them -1 and +1, respectively.

In [None]:
# load images
def load_real_samples():
	# load dataset
	(trainX, trainy), (_, _) = load_data()
	# select all of the examples for a given class
	selected_ix = trainy == 7
	X = trainX[selected_ix]
	# expand to 3d, e.g. add channels
	X = expand_dims(X, axis=-1)
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	return X
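The rescaling in `load_real_samples` and its inverse in `summarize_performance` can be checked on the endpoints of the pixel range. This NumPy sketch uses a made-up three-value "image" (minimum, midpoint and maximum of [0, 255]) rather than MNIST data.

```python
import numpy as np

pixels = np.array([0, 127.5, 255], dtype='float32')  # min, midpoint and max of [0, 255]
scaled = (pixels - 127.5) / 127.5   # [0, 255] -> [-1, 1], the tanh output range
restored = (scaled + 1) / 2.0       # [-1, 1] -> [0, 1], as done before plotting
print(scaled, restored)
```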

In [None]:
# select real samples
def generate_real_samples(dataset, n_samples):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# select images
	X = dataset[ix]
	# generate class labels, -1 for 'real'
	y = -ones((n_samples, 1))
	return X, y

In [None]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
	# generate points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	x_input = x_input.reshape(n_samples, latent_dim)
	return x_input
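Drawing `latent_dim * n_samples` standard-normal values and reshaping them is equivalent to sampling an `(n_samples, latent_dim)` array directly. The sketch below shows this with NumPy's legacy global RNG (seeded only for reproducibility) and, as an alternative the notebook does not use, the `numpy.random.default_rng` Generator API. The seed and sizes are arbitrary.

```python
import numpy as np
from numpy.random import randn

latent_dim, n_samples = 50, 4

np.random.seed(0)
a = randn(latent_dim * n_samples).reshape(n_samples, latent_dim)  # as in generate_latent_points
np.random.seed(0)
b = randn(n_samples, latent_dim)                                   # direct 2-D draw
print(np.array_equal(a, b))

# Alternative with the Generator API (not used by this notebook)
rng = np.random.default_rng(0)
c = rng.standard_normal((n_samples, latent_dim))
print(c.shape)
```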

In [None]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
	# generate points in latent space
	x_input = generate_latent_points(latent_dim, n_samples)
	# predict outputs
	X = generator.predict(x_input)
	# create class labels with 1.0 for 'fake'
	y = ones((n_samples, 1))
	return X, y

In [None]:
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(10 * 10):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

In [None]:
# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist):
	# plot history
	pyplot.plot(d1_hist, label='crit_real')
	pyplot.plot(d2_hist, label='crit_fake')
	pyplot.plot(g_hist, label='gen')
	pyplot.legend()
	pyplot.savefig('plot_line_plot_loss.png')
	pyplot.close()

In [None]:
# train the generator and critic
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5):
	# calculate the number of batches per training epoch
	bat_per_epo = int(dataset.shape[0] / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# calculate the size of half a batch of samples
	half_batch = int(n_batch / 2)
	# lists for keeping track of loss
	c1_hist, c2_hist, g_hist = list(), list(), list()
	# manually enumerate training steps (one generator update per step)
	for i in range(n_steps):
		# update the critic more than the generator
		c1_tmp, c2_tmp = list(), list()
		for _ in range(n_critic):
			# get randomly selected 'real' samples
			X_real, y_real = generate_real_samples(dataset, half_batch)
			# update critic model weights
			c_loss1 = c_model.train_on_batch(X_real, y_real)
			c1_tmp.append(c_loss1)
			# generate 'fake' examples
			X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
			# update critic model weights
			c_loss2 = c_model.train_on_batch(X_fake, y_fake)
			c2_tmp.append(c_loss2)
		# store critic loss
		c1_hist.append(mean(c1_tmp))
		c2_hist.append(mean(c2_tmp))
		# prepare points in latent space as input for the generator
		X_gan = generate_latent_points(latent_dim, n_batch)
		# create inverted labels for the fake samples
		y_gan = -ones((n_batch, 1))
		# update the generator via the critic's error
		g_loss = gan_model.train_on_batch(X_gan, y_gan)
		g_hist.append(g_loss)
		# summarize loss on this batch
		print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))
		# evaluate the model performance every 'epoch'
		if (i+1) % bat_per_epo == 0:
			summarize_performance(i, g_model, latent_dim)
	# line plots of loss
	plot_history(c1_hist, c2_hist, g_hist)
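Plugging this notebook's numbers into `train` gives a feel for the schedule: 6,265 images of the digit 7 (the count printed by `print(dataset.shape)`), `n_batch=64`, `n_epochs=10` and `n_critic=5`. The sketch only repeats the arithmetic at the top of `train`.

```python
n_images, n_batch, n_epochs, n_critic = 6265, 64, 10, 5

bat_per_epo = int(n_images / n_batch)    # full batches per epoch (remainder dropped)
n_steps = bat_per_epo * n_epochs         # total generator updates
half_batch = int(n_batch / 2)            # real (or fake) images per critic update
critic_updates = n_steps * n_critic * 2  # critic train_on_batch calls (real + fake)
print(bat_per_epo, n_steps, half_batch, critic_updates)
```

With these values `summarize_performance` runs every 97 steps, so it saves 10 plot/model pairs, the first named `generated_plot_0097.png` and `model_0097.h5`.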

In [None]:
# size of the latent space
latent_dim = 50

In [None]:
# create the critic
critic = define_critic()

In [None]:
# create the generator
generator = define_generator(latent_dim)

In [None]:
# create the gan
gan_model = define_gan(generator, critic)

In [None]:
# load image data
dataset = load_real_samples()
print(dataset.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(6265, 28, 28, 1)


In [None]:
# train model
train(generator, critic, gan_model, dataset, latent_dim)

>1, c1=-2.093, c2=0.039 g=-0.451
>2, c1=-6.552, c2=0.126 g=-1.705
>3, c1=-9.801, c2=0.195 g=-2.771
>4, c1=-12.063, c2=0.262 g=-3.654
>5, c1=-14.278, c2=0.325 g=-4.973
>6, c1=-16.481, c2=0.387 g=-5.933
>7, c1=-18.076, c2=0.483 g=-6.891
>8, c1=-18.622, c2=0.538 g=-7.695
>9, c1=-20.449, c2=0.624 g=-8.787
>10, c1=-21.319, c2=0.695 g=-10.021
>11, c1=-21.959, c2=0.795 g=-10.813
>12, c1=-24.063, c2=0.876 g=-11.631
>13, c1=-24.211, c2=0.960 g=-12.271
>14, c1=-25.077, c2=1.055 g=-13.310
>15, c1=-25.981, c2=1.128 g=-14.318
>16, c1=-27.447, c2=1.212 g=-14.925
>17, c1=-27.227, c2=1.265 g=-15.344
>18, c1=-28.518, c2=1.335 g=-16.321
>19, c1=-28.274, c2=1.398 g=-16.753
>20, c1=-29.125, c2=1.451 g=-16.674
>21, c1=-29.791, c2=1.470 g=-18.304
>22, c1=-30.993, c2=1.457 g=-17.895
>23, c1=-31.110, c2=1.466 g=-18.449
>24, c1=-31.919, c2=1.445 g=-19.369
>25, c1=-31.630, c2=1.370 g=-20.192
>26, c1=-32.683, c2=1.414 g=-20.781
>27, c1=-32.873, c2=1.227 g=-21.320
>28, c1=-34.091, c2=1.160 g=-22.332
>29, c1=-33.8