<a href="https://colab.research.google.com/github/harshyadav1508/GAN/blob/main/WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#
%cd "/content/drive/MyDrive/Colab Notebooks/GAN/WGAN"
!pwd

/content/drive/MyDrive/Colab Notebooks/GAN/WGAN
/content/drive/MyDrive/Colab Notebooks/GAN/WGAN


In [3]:
import pandas as pd
from keras import backend
from keras.constraints import Constraint
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization
from keras.initializers import RandomNormal
import numpy as np
from numpy import ones, mean, expand_dims, reshape
from numpy.random import randn, randint
from matplotlib import pyplot

# Clip model weights to a given hypercube

In [4]:
class ClipConstraint(Constraint):

  def __init__(self, clip_value):
    self.clip_value=clip_value

  def __call__(self,weights):
    return backend.clip(weights,-self.clip_value,self.clip_value)

  def get_config(self):
    return {'clip_value':self.clip_value}

# Calculate wasserstein loss

In [5]:
def wasserstein_loss(y_true, y_pred):
  return backend.mean(y_true*y_pred)

# Constant

In [6]:
init = RandomNormal(stddev=0.02)
const= ClipConstraint(0.01)
opt = RMSprop(lr=0.00005)
latent_dim=50

  super(RMSprop, self).__init__(name, **kwargs)


# Define the standalone critic model

In [7]:
def define_critic(in_shape=(28,28,1)):

  model=Sequential([
      Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape, activation=LeakyReLU(alpha=0.2)),
      BatchNormalization(momentum=0.8),
      Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, activation=LeakyReLU(alpha=0.2)),
      BatchNormalization(momentum=0.8),
      Flatten(),
      Dense(1)
  ])

  model.compile(loss=wasserstein_loss, optimizer=opt)

  return model


# Define the standalone generator model

In [8]:
from keras.backend import tanh
def define_generator(latent_dim):

  n_nodes=7*7*128
  model=Sequential(
      [
          Dense(n_nodes, kernel_initializer=init,input_dim=latent_dim,activation=LeakyReLU(alpha=0.2)),
          Reshape((7,7,128)),
          Conv2DTranspose(128,(4,4),padding='same', strides=(2,2), kernel_initializer=init, activation=LeakyReLU(alpha=0.2)),
          BatchNormalization(momentum=0.8),
          Conv2DTranspose(128,(4,4),padding='same', strides=(2,2), kernel_initializer=init, activation=LeakyReLU(alpha=0.2)),
          BatchNormalization(momentum=0.8),
          Conv2D(1,(7,7),padding='same', kernel_initializer=init, activation='tanh')
      ]
  )

  model.summary()
  return model

In [9]:
critic = define_generator(100)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 6272)              633472    
                                                                 
 reshape (Reshape)           (None, 7, 7, 128)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 128)      262272    
 nspose)                                                         
                                                                 
 batch_normalization (BatchN  (None, 14, 14, 128)      512       
 ormalization)                                                   
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 128)      262272    
 ranspose)                                                       
                                                        

# Combined generator and critic model,

In [10]:
def define_gan(generator, critic):
	# make weights in the critic not trainable
	for layer in critic.layers:
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(generator)
	# add the critic
	model.add(critic)
	# compile model
	opt = RMSprop(lr=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

In [11]:
def load_real_samples():
  (trainX,trainy),(_,_)=load_data()
  select_ix=trainy==6
  X=trainX[select_ix]
  X=expand_dims(X,-1)
  X=X.astype('float')
  X=(X-127.5)/127.5

  return X


# Select real samples

In [12]:
def generate_real_samples(dataset,n_samples):
  ix=randint(0,dataset.shape[0],n_samples)
  X=dataset[ix]
  y=-np.ones((n_samples,1))

  return X,y


# generate points in latent space as input for the generator

In [13]:
def generate_latent_points(latent_dim,n_samples):
  x_input=randn(latent_dim*n_samples)
  x_input=x_input.reshape(n_samples,latent_dim)
  return x_input

# use the generator to generate n fake examples, with class labels

In [14]:
def generate_fake_samples(generator, latent_dim, n_samples):
  x_input=generate_latent_points(latent_dim,n_samples)
  X=generator.predict(x_input)
  y=np.ones((n_samples,1))

  return X,y


In [15]:
def summarize_performance(step, g_model, latent_dim, n_samples=100):

	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	X = (X + 1) / 2.0
	# plot images
	for i in range(10 * 10):
		pyplot.axis('off')
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

In [16]:
def plot_history(d1_hist, d2_hist, g_hist):
	# plot history
	pyplot.plot(d1_hist, label='crit_real')
	pyplot.plot(d2_hist, label='crit_fake')
	pyplot.plot(g_hist, label='gen')
	pyplot.legend()
	pyplot.savefig('plot_line_plot_loss.png')
	pyplot.close()

In [17]:
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5):
  bat_per_epo = int(dataset.shape[0] / n_batch)
  n_steps = bat_per_epo * n_epochs
  half_batch = int(n_batch / 2)
  c1_hist, c2_hist, g_hist = list(), list(), list()

  for i in range(n_steps):
    c1_tmp, c2_tmp = list(), list()
    
    for _ in range(n_critic):
      X_real, y_real = generate_real_samples(dataset, half_batch)
      c_loss1 = c_model.train_on_batch(X_real, y_real)
      c1_tmp.append(c_loss1)

      X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
      c_loss2 = c_model.train_on_batch(X_fake, y_fake)
      c2_tmp.append(c_loss2)

    c1_hist.append(mean(c1_tmp))
    c2_hist.append(mean(c2_tmp))

    X_gan = generate_latent_points(latent_dim, n_batch)
    y_gan = -ones((n_batch, 1))

    g_loss = gan_model.train_on_batch(X_gan, y_gan)
    g_hist.append(g_loss)

		
    print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))

    if (i+1) % bat_per_epo == 0:
      summarize_performance(i, g_model, latent_dim)

    plot_history(c1_hist, c2_hist, g_hist)
		




In [18]:
latent_dim = 50
# create the critic
critic = define_critic()
# create the generator
generator = define_generator(latent_dim)
# create the gan
gan_model = define_gan(generator, critic)
# load image data
dataset = load_real_samples()
print(dataset.shape)
# train model
train(generator, critic, gan_model, dataset, latent_dim)

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_2 (Dense)             (None, 6272)              319872    
                                                                 
 reshape_1 (Reshape)         (None, 7, 7, 128)         0         
                                                                 
 conv2d_transpose_2 (Conv2DT  (None, 14, 14, 128)      262272    
 ranspose)                                                       
                                                                 
 batch_normalization_4 (Batc  (None, 14, 14, 128)      512       
 hNormalization)                                                 
                                                                 
 conv2d_transpose_3 (Conv2DT  (None, 28, 28, 128)      262272    
 ranspose)                                                       
                                                      



>Saved: generated_plot_0092.png and model_0092.h5
>93, c1=-83.007, c2=12.747 g=-10.492
>94, c1=-84.919, c2=12.333 g=-9.842
>95, c1=-84.097, c2=12.147 g=-11.870
>96, c1=-86.215, c2=11.983 g=-11.099
>97, c1=-85.831, c2=12.633 g=-10.590
>98, c1=-86.804, c2=13.630 g=-11.401
>99, c1=-87.170, c2=13.072 g=-12.428
>100, c1=-87.588, c2=14.414 g=-12.834
>101, c1=-88.012, c2=13.422 g=-12.558
>102, c1=-87.557, c2=12.743 g=-11.384
>103, c1=-88.689, c2=13.180 g=-13.483
>104, c1=-89.437, c2=14.549 g=-13.601
>105, c1=-89.739, c2=16.474 g=-12.104
>106, c1=-91.666, c2=17.829 g=-12.148
>107, c1=-91.083, c2=16.226 g=-12.947
>108, c1=-91.671, c2=15.088 g=-11.954
>109, c1=-92.056, c2=15.501 g=-13.765
>110, c1=-92.152, c2=16.405 g=-12.916
>111, c1=-92.228, c2=18.584 g=-14.479
>112, c1=-93.245, c2=16.405 g=-13.306
>113, c1=-94.215, c2=14.559 g=-14.902
>114, c1=-95.975, c2=17.832 g=-15.432
>115, c1=-93.959, c2=16.114 g=-14.095
>116, c1=-95.735, c2=16.324 g=-14.116
>117, c1=-96.147, c2=18.900 g=-13.663
>118, c1



>Saved: generated_plot_0184.png and model_0184.h5
>185, c1=-96.418, c2=67.932 g=-60.775
>186, c1=-98.259, c2=69.395 g=-60.626
>187, c1=-96.832, c2=67.866 g=-57.227
>188, c1=-93.858, c2=69.724 g=-57.396
>189, c1=-94.265, c2=68.496 g=-57.852
>190, c1=-96.486, c2=68.552 g=-58.496
>191, c1=-95.518, c2=68.545 g=-60.486
>192, c1=-94.184, c2=68.445 g=-60.112
>193, c1=-91.979, c2=69.465 g=-58.996
>194, c1=-93.039, c2=69.277 g=-62.410
>195, c1=-92.847, c2=68.365 g=-61.668
>196, c1=-94.096, c2=68.866 g=-61.030
>197, c1=-93.548, c2=71.001 g=-60.825
>198, c1=-93.004, c2=70.735 g=-59.618
>199, c1=-91.773, c2=69.812 g=-62.668
>200, c1=-93.025, c2=70.493 g=-61.400
>201, c1=-91.904, c2=71.702 g=-63.200
>202, c1=-94.151, c2=71.996 g=-62.078
>203, c1=-90.713, c2=72.704 g=-63.255
>204, c1=-92.887, c2=71.694 g=-63.515
>205, c1=-92.620, c2=72.467 g=-62.440
>206, c1=-90.043, c2=72.238 g=-62.056
>207, c1=-91.264, c2=71.223 g=-60.375
>208, c1=-91.839, c2=72.707 g=-60.349
>209, c1=-89.411, c2=72.502 g=-61.829




>Saved: generated_plot_0276.png and model_0276.h5
>277, c1=-82.319, c2=75.750 g=-68.582
>278, c1=-83.550, c2=75.892 g=-68.089
>279, c1=-83.713, c2=76.529 g=-68.989
>280, c1=-83.063, c2=77.237 g=-68.997
>281, c1=-82.938, c2=75.447 g=-68.329
>282, c1=-83.513, c2=75.804 g=-68.625
>283, c1=-81.499, c2=74.241 g=-66.627
>284, c1=-82.436, c2=75.650 g=-68.347
>285, c1=-83.025, c2=75.161 g=-66.005
>286, c1=-83.616, c2=75.232 g=-67.897
>287, c1=-82.400, c2=76.157 g=-68.126
>288, c1=-84.306, c2=76.629 g=-69.386
>289, c1=-85.239, c2=77.068 g=-69.558
>290, c1=-85.062, c2=77.525 g=-69.155
>291, c1=-85.688, c2=77.983 g=-70.694
>292, c1=-85.481, c2=78.250 g=-70.452
>293, c1=-86.583, c2=77.795 g=-69.423
>294, c1=-86.498, c2=77.936 g=-70.797
>295, c1=-86.916, c2=78.120 g=-71.550
>296, c1=-84.487, c2=77.961 g=-70.130
>297, c1=-83.525, c2=77.932 g=-70.480
>298, c1=-84.992, c2=78.778 g=-71.246
>299, c1=-84.637, c2=78.228 g=-71.331
>300, c1=-86.097, c2=77.590 g=-71.000
>301, c1=-85.525, c2=77.475 g=-70.811




>Saved: generated_plot_0368.png and model_0368.h5
>369, c1=-83.847, c2=75.323 g=-69.434
>370, c1=-82.929, c2=75.812 g=-68.271
>371, c1=-83.691, c2=74.717 g=-69.557
>372, c1=-84.154, c2=75.623 g=-69.606
>373, c1=-84.377, c2=76.114 g=-69.649
>374, c1=-83.439, c2=75.970 g=-69.416
>375, c1=-85.014, c2=75.284 g=-69.428
>376, c1=-84.206, c2=76.416 g=-68.384
>377, c1=-83.365, c2=76.206 g=-68.140
>378, c1=-84.059, c2=76.577 g=-70.317
>379, c1=-83.318, c2=75.730 g=-69.946
>380, c1=-84.621, c2=75.059 g=-68.794
>381, c1=-84.180, c2=75.997 g=-68.379
>382, c1=-84.731, c2=75.594 g=-68.942
>383, c1=-83.670, c2=74.435 g=-68.404
>384, c1=-84.120, c2=75.649 g=-70.497
>385, c1=-83.229, c2=74.990 g=-69.002
>386, c1=-84.526, c2=75.171 g=-67.570
>387, c1=-83.971, c2=75.860 g=-67.944
>388, c1=-84.364, c2=74.967 g=-67.818
>389, c1=-83.510, c2=75.337 g=-67.950
>390, c1=-84.387, c2=74.756 g=-67.053
>391, c1=-82.413, c2=74.231 g=-68.711
>392, c1=-83.943, c2=75.406 g=-66.457
>393, c1=-84.549, c2=74.469 g=-67.555




>Saved: generated_plot_0460.png and model_0460.h5
>461, c1=-84.709, c2=74.834 g=-68.913
>462, c1=-86.341, c2=74.493 g=-69.695
>463, c1=-86.461, c2=76.062 g=-67.446
>464, c1=-88.207, c2=75.612 g=-67.700
>465, c1=-85.935, c2=75.409 g=-68.624
>466, c1=-85.225, c2=74.566 g=-67.457
>467, c1=-86.182, c2=75.020 g=-68.970
>468, c1=-85.904, c2=75.442 g=-68.257
>469, c1=-86.840, c2=74.584 g=-68.613
>470, c1=-86.976, c2=76.709 g=-69.092
>471, c1=-85.811, c2=74.583 g=-66.504
>472, c1=-85.824, c2=74.297 g=-68.207
>473, c1=-84.097, c2=74.632 g=-68.245
>474, c1=-85.665, c2=74.953 g=-69.336
>475, c1=-85.645, c2=75.541 g=-69.556
>476, c1=-85.243, c2=76.094 g=-69.736
>477, c1=-87.353, c2=75.637 g=-69.429
>478, c1=-86.813, c2=76.143 g=-68.549
>479, c1=-86.072, c2=76.003 g=-70.292
>480, c1=-87.227, c2=75.674 g=-70.240
>481, c1=-86.864, c2=75.642 g=-67.262
>482, c1=-85.795, c2=74.812 g=-69.014
>483, c1=-84.580, c2=73.584 g=-67.888
>484, c1=-84.955, c2=72.283 g=-66.569
>485, c1=-86.272, c2=72.758 g=-64.990




>Saved: generated_plot_0552.png and model_0552.h5
>553, c1=-86.132, c2=73.182 g=-67.948
>554, c1=-87.319, c2=73.512 g=-67.473
>555, c1=-86.492, c2=74.101 g=-67.288
>556, c1=-86.775, c2=73.896 g=-65.689
>557, c1=-85.549, c2=72.810 g=-66.583
>558, c1=-86.607, c2=73.140 g=-66.827
>559, c1=-85.342, c2=73.678 g=-66.092
>560, c1=-85.917, c2=73.733 g=-67.410
>561, c1=-86.170, c2=73.259 g=-64.690
>562, c1=-86.692, c2=73.257 g=-63.663
>563, c1=-85.047, c2=72.053 g=-67.135
>564, c1=-86.494, c2=73.314 g=-65.286
>565, c1=-86.695, c2=73.420 g=-67.921
>566, c1=-87.262, c2=74.242 g=-65.845
>567, c1=-87.802, c2=73.646 g=-66.842
>568, c1=-87.043, c2=73.648 g=-67.727
>569, c1=-87.300, c2=74.100 g=-68.530
>570, c1=-88.499, c2=74.709 g=-67.447
>571, c1=-87.853, c2=74.978 g=-68.802
>572, c1=-89.216, c2=73.365 g=-66.261
>573, c1=-87.568, c2=73.943 g=-67.883
>574, c1=-89.193, c2=73.967 g=-66.930
>575, c1=-86.850, c2=74.888 g=-66.944
>576, c1=-88.298, c2=73.011 g=-65.178
>577, c1=-87.219, c2=74.559 g=-65.771




>Saved: generated_plot_0644.png and model_0644.h5
>645, c1=-86.200, c2=69.152 g=-60.262
>646, c1=-86.117, c2=70.292 g=-61.530
>647, c1=-85.950, c2=72.001 g=-64.914
>648, c1=-87.047, c2=71.534 g=-64.084
>649, c1=-86.287, c2=71.136 g=-64.937
>650, c1=-86.622, c2=70.994 g=-63.963
>651, c1=-85.483, c2=70.592 g=-62.685
>652, c1=-85.032, c2=71.422 g=-65.122
>653, c1=-86.407, c2=70.208 g=-64.708
>654, c1=-85.673, c2=71.087 g=-65.206
>655, c1=-87.031, c2=71.458 g=-63.871
>656, c1=-87.259, c2=71.749 g=-65.334
>657, c1=-87.640, c2=71.382 g=-64.147
>658, c1=-87.520, c2=71.907 g=-63.945
>659, c1=-86.573, c2=70.814 g=-65.020
>660, c1=-88.597, c2=72.504 g=-66.828
>661, c1=-87.657, c2=70.821 g=-62.879
>662, c1=-88.428, c2=72.556 g=-64.447
>663, c1=-87.436, c2=70.706 g=-64.413
>664, c1=-88.176, c2=71.545 g=-64.359
>665, c1=-87.582, c2=72.034 g=-64.377
>666, c1=-87.233, c2=72.409 g=-67.313
>667, c1=-86.832, c2=72.292 g=-66.750
>668, c1=-88.049, c2=73.034 g=-67.099
>669, c1=-87.594, c2=71.684 g=-64.623




>Saved: generated_plot_0736.png and model_0736.h5
>737, c1=-86.296, c2=69.922 g=-59.252
>738, c1=-88.453, c2=69.062 g=-60.532
>739, c1=-86.423, c2=69.618 g=-61.593
>740, c1=-89.184, c2=69.972 g=-58.554
>741, c1=-87.711, c2=68.966 g=-63.894
>742, c1=-87.762, c2=66.296 g=-61.317
>743, c1=-87.279, c2=68.467 g=-59.609
>744, c1=-86.276, c2=67.725 g=-62.211
>745, c1=-85.778, c2=68.292 g=-61.834
>746, c1=-87.570, c2=66.719 g=-59.526
>747, c1=-85.571, c2=66.728 g=-62.454
>748, c1=-86.804, c2=66.090 g=-61.032
>749, c1=-84.539, c2=69.116 g=-61.834
>750, c1=-84.802, c2=66.666 g=-59.634
>751, c1=-85.695, c2=66.779 g=-58.413
>752, c1=-88.428, c2=66.062 g=-58.858
>753, c1=-85.540, c2=66.414 g=-57.022
>754, c1=-87.458, c2=67.846 g=-59.839
>755, c1=-85.115, c2=66.525 g=-55.929
>756, c1=-84.445, c2=66.918 g=-57.610
>757, c1=-85.263, c2=66.264 g=-58.153
>758, c1=-85.017, c2=65.398 g=-56.063
>759, c1=-83.883, c2=66.637 g=-57.899
>760, c1=-87.036, c2=66.536 g=-61.654
>761, c1=-84.943, c2=67.404 g=-59.789




>Saved: generated_plot_0828.png and model_0828.h5
>829, c1=-88.683, c2=67.764 g=-59.457
>830, c1=-86.959, c2=69.013 g=-59.712
>831, c1=-87.044, c2=67.512 g=-59.296
>832, c1=-87.205, c2=66.363 g=-61.678
>833, c1=-88.914, c2=67.759 g=-60.424
>834, c1=-86.003, c2=66.130 g=-60.699
>835, c1=-90.351, c2=66.529 g=-58.817
>836, c1=-89.606, c2=66.763 g=-58.946
>837, c1=-88.647, c2=67.565 g=-58.235
>838, c1=-88.472, c2=69.063 g=-59.327
>839, c1=-89.038, c2=65.885 g=-58.319
>840, c1=-88.775, c2=66.359 g=-59.513
>841, c1=-90.507, c2=65.640 g=-63.141
>842, c1=-90.163, c2=65.936 g=-58.851
>843, c1=-90.456, c2=66.704 g=-58.829
>844, c1=-87.673, c2=67.777 g=-61.302
>845, c1=-88.887, c2=67.002 g=-55.796
>846, c1=-89.579, c2=65.933 g=-57.276
>847, c1=-88.105, c2=66.290 g=-56.628
>848, c1=-86.093, c2=66.833 g=-56.540
>849, c1=-88.828, c2=67.032 g=-58.182
>850, c1=-86.732, c2=65.163 g=-56.183
>851, c1=-86.079, c2=67.157 g=-59.899
>852, c1=-88.281, c2=65.388 g=-60.531
>853, c1=-87.618, c2=67.248 g=-58.686




>Saved: generated_plot_0920.png and model_0920.h5
