In [1]:
# imports
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras.datasets import mnist

In [None]:
# Data preparation.
# Load MNIST, scale the pixel values by 1/255 and flatten every 28x28 image
# into a 784-element float32 row (the input layout the discriminator uses).
# Labels are reshaped to column vectors; the GAN training below never reads them.
(xtr, ytr), (xte, yte) = mnist.load_data()
xtr_p, xte_p = [(images / 255).reshape(-1, 784).astype(np.float32) for images in (xtr, xte)]
ytr_p, yte_p = [labels.reshape(-1, 1) for labels in (ytr, yte)]

# Preview the first ten training digits in a 2x5 grid.
plt.figure(num=1, figsize=(3, 3), dpi=100)
for idx, digit in enumerate(xtr[:10]):
    plt.subplot(2, 5, idx + 1)
    plt.imshow(digit, cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# generator network
def create_generator(learning_rate=0.0002):
  """Build and compile the convolutional generator.

  The model takes a (28, 28, 1) input (the notebook feeds it Gaussian noise)
  and returns a flat 784-vector, i.e. one 28x28 image. Each of the five
  stages is Conv2D (valid padding) -> BatchNormalization -> ReLU clipped at
  1.0 -> 2x2 pooling with stride 1, so every output value lies in [0, 1],
  the same range as the scaled MNIST pixels.

  Shape bookkeeping (height x width):
    28 -conv2-> 27 -pool-> 26 -conv3-> 24 -pool-> 23 -conv4-> 20 -pool-> 19
    -conv5-> 15 -pool-> 14 -conv(7,6)-> 8x9 -pool-> 7x8, with 14 channels,
  and 7 * 8 * 14 = 784 features after Flatten.

  Args:
    learning_rate: Adam learning rate used when compiling (default 0.0002).

  Returns:
    The compiled ``keras.Sequential`` generator. In this notebook it is only
    trained through the combined GAN model.
  """
  # (filters, kernel_size, pooling layer class) for each stage, in order.
  stages = (
      (3, (2, 2), keras.layers.AvgPool2D),
      (6, (3, 3), keras.layers.MaxPool2D),
      (9, (4, 4), keras.layers.AvgPool2D),
      (12, (5, 5), keras.layers.MaxPool2D),
      (14, (7, 6), keras.layers.AvgPool2D),
  )
  gen = keras.Sequential()
  gen.add(keras.Input(shape=(28, 28, 1)))
  for filters, kernel_size, pool_layer in stages:
    gen.add(keras.layers.Conv2D(filters=filters, kernel_size=kernel_size))
    gen.add(keras.layers.BatchNormalization())
    gen.add(keras.layers.ReLU(max_value=1.0))
    gen.add(pool_layer(pool_size=(2, 2), strides=(1, 1)))
  gen.add(keras.layers.Flatten())
  # Pass a loss *instance*. Depending on the Keras version, handing compile()
  # the BinaryCrossentropy class itself makes Keras call the class constructor
  # as if it were the loss function.
  gen.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
              loss=keras.losses.BinaryCrossentropy())
  return gen
generator = create_generator()
generator.summary()

In [None]:
# discriminator network
def create_discriminator(learning_rate=0.0002):
  """Build and compile the fully-connected discriminator.

  The model takes a flat 784-vector (one flattened 28x28 image, real or
  generated) and outputs a single sigmoid score, trained with binary
  cross-entropy: 1 = real, 0 = generated in the training loop.

  Args:
    learning_rate: Adam learning rate used when compiling (default 0.0002).

  Returns:
    The compiled ``keras.Sequential`` discriminator.
  """
  dis = keras.Sequential()
  dis.add(keras.Input(shape=(784,)))
  # Funnel from 512 down to 8 units, halving each time. Every hidden Dense
  # layer is followed by BatchNormalization and a ReLU clipped at 1.0.
  for units in (512, 256, 128, 64, 32, 16, 8):
    dis.add(keras.layers.Dense(units=units))
    dis.add(keras.layers.BatchNormalization())
    dis.add(keras.layers.ReLU(max_value=1.0))
  # Single-unit head: the logit is batch-normalized, then squashed by a sigmoid.
  dis.add(keras.layers.Dense(units=1))
  dis.add(keras.layers.BatchNormalization())
  dis.add(keras.layers.Activation(keras.activations.sigmoid))
  # Pass a loss *instance*, not the BinaryCrossentropy class itself.
  dis.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
              loss=keras.losses.BinaryCrossentropy())
  return dis
discriminator = create_discriminator()
discriminator.summary()

In [None]:
# gan network
def create_gan(generator, discriminator, learning_rate=0.0005):
  """Chain generator -> discriminator into one model used to train the generator.

  Side effect: sets ``discriminator.trainable = False`` so that training the
  combined model only updates the generator's weights. The training loop
  switches the flag back on around its own discriminator update.

  Args:
    generator: model mapping a (28, 28, 1) input to a flat image vector.
    discriminator: model mapping that flat vector to a real/fake score.
    learning_rate: Adam learning rate for the combined model (default 0.0005).

  Returns:
    The compiled ``keras.Model`` taking a (28, 28, 1) input and returning the
    discriminator's score for the generated image.
  """
  # Freeze the discriminator before building and compiling the combined
  # model. Keras captures which weights are trainable when this model is
  # compiled and trained.
  discriminator.trainable = False
  gan_input = keras.Input(shape=(28, 28, 1))
  gan_output = discriminator(generator(gan_input))
  gan = keras.Model(inputs=gan_input, outputs=gan_output)
  # Pass a loss *instance*, not the BinaryCrossentropy class itself.
  gan.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
              loss=keras.losses.BinaryCrossentropy())
  return gan
gan = create_gan(generator, discriminator)
gan.summary()

In [None]:
# save_plot function
def save_plot(EPOCH, EXM, generator):
    """Generate EXM images from Gaussian noise and save them as one PNG grid.

    Args:
      EPOCH: number used only in the output file name
        ``gan_generated_img_<EPOCH>.png``.
      EXM: number of images to generate and plot (must be >= 1). They are laid
        out on the smallest square grid that fits them (10x10 for 100).
      generator: model whose ``predict`` maps an (EXM, 28, 28, 1) noise batch
        to outputs reshapeable to (EXM, 28, 28).

    The figure is closed after saving, so nothing is left open between calls.
    """
    noise_tensor = np.random.normal(loc=0.0, scale=1.0, size=[EXM, 28, 28, 1])
    # verbose=0: this runs repeatedly during training; skip the progress bar.
    generated_img = generator.predict(noise_tensor, verbose=0).reshape(EXM, 28, 28)
    side = int(np.ceil(np.sqrt(EXM)))
    fig, axes = plt.subplots(side, side, figsize=(5, 5), dpi=100, squeeze=False)
    # Hide every cell first so unused slots in a non-square grid stay blank.
    for ax in axes.flat:
        ax.axis('off')
    for ax, img in zip(axes.flat, generated_img):
        ax.imshow(img, interpolation='nearest', cmap='gray')
    fig.tight_layout()
    fig.savefig('gan_generated_img_%d.png' % EPOCH)
    # Close explicitly. Reusing one open figure (as with a fixed figure number)
    # stacks new images on the old axes and leaks memory.
    plt.close(fig)

In [None]:
# model training
# Alternate one discriminator update and one generator (GAN) update per batch;
# save a sample grid and report mean losses once per epoch.
BATCH = 60
EPOCH = 100
EXM = 100  # number of sample images saved after each epoch
BATCH_COUNT = xtr_p.shape[0] // BATCH
for e in range(1, EPOCH + 1):
  d_losses, g_losses = [], []
  for b in range(BATCH_COUNT):
    # Discriminator step: real images labelled 1, generated images labelled 0.
    noise_tensor = np.random.normal(loc=0, scale=1, size=[BATCH, 28, 28, 1])
    # predict_on_batch: a single forward pass, without predict()'s per-call
    # overhead and progress-bar output.
    xtr_fake_batch = generator.predict_on_batch(noise_tensor)
    xtr_real_batch = xtr_p[np.random.randint(low=0, high=xtr_p.shape[0], size=BATCH)]
    Xtr = np.concatenate((xtr_real_batch, xtr_fake_batch))
    y_dis = np.zeros((2 * BATCH, 1))
    y_dis[:BATCH, 0] = 1.0
    discriminator.trainable = True
    d_losses.append(discriminator.train_on_batch(Xtr, y_dis))
    discriminator.trainable = False
    # Generator step: use fresh noise, labelled "real", so the generator
    # learns to fool the frozen discriminator.
    noise_tensor = np.random.normal(loc=0, scale=1, size=[BATCH, 28, 28, 1])
    y_gan = np.ones((BATCH, 1))
    g_losses.append(gan.train_on_batch(noise_tensor, y_gan))
  # Model evaluation: once per epoch. It previously ran inside the batch loop,
  # regenerating and saving a figure on every batch.
  print('epoch %d/%d  d_loss=%.4f  g_loss=%.4f'
        % (e, EPOCH, np.mean(d_losses), np.mean(g_losses)))
  save_plot(e, EXM, generator)