# Generative Adversarial Network using Keras

In [1]:
from keras.datasets import mnist
from keras.layers import Conv2D,Dense,Dropout,Flatten,Input,Reshape,UpSampling2D
import matplotlib.pyplot as plt
import numpy as np
from keras.models import Model, Sequential
from keras.layers.advanced_activations import LeakyReLU
from tqdm import tqdm_notebook as tqdm,trange
from keras.optimizers import Adam
import time
import sys
import os

Using TensorFlow backend.


In [2]:
# Seed NumPy's global RNG so the noise vectors and training-batch sampling
# drawn with np.random are repeatable across runs.
# NOTE(review): this does not seed the Keras/TensorFlow backend, so weight
# initialisation may still differ between runs — confirm if full determinism
# is required.
np.random.seed(seed=1000)

In [3]:
# Load MNIST as (train, test) splits; only x_train is used below for GAN training.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

## Preparing the input

In [4]:
# Length of the random noise (latent) vector fed to the generator.
input_dim_generator = 100

In [5]:
# Scale pixels (presumably uint8 in [0, 255] from mnist.load_data) to [-1, 1]
# to match the generator's tanh output, and add a trailing channel axis for
# the Conv2D layers. The integer-dtype guard makes this cell idempotent:
# re-running it must not rescale data that has already been normalised.
if np.issubdtype(x_train.dtype, np.integer):
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = x_train.reshape(-1, 28, 28, 1)

## Creating model

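![GAN architecture: the generator maps noise to images, the discriminator classifies images as real or fake](images/GAN.png)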

In [6]:
def create_model(learning_rate=0.0002, beta_1=0.5):
    """Build and compile the generator, the discriminator and the stacked GAN.

    The generator maps a noise vector of length ``input_dim_generator`` to a
    28x28x1 image with tanh output. The discriminator maps a 28x28x1 image to
    a single sigmoid probability. ``gan_model`` chains generator ->
    discriminator and is used to train the generator.

    Args:
        learning_rate: Adam learning rate for every compiled model.
        beta_1: Adam ``beta_1`` for every compiled model.

    Returns:
        Tuple ``(generator, discriminator, gan_model)``.
    """
    def make_optimizer():
        # A fresh instance per compiled model so optimizer state (moment
        # estimates, iteration count) is never shared between models.
        return Adam(lr=learning_rate, beta_1=beta_1)

    # Generator network: noise -> 7x7x128 -> upsample x2 twice -> 28x28x1.
    generator = Sequential()
    generator.add(Dense(128 * 7 * 7, input_shape=[input_dim_generator]))
    generator.add(LeakyReLU(0.2))
    generator.add(Reshape((7, 7, 128)))
    generator.add(UpSampling2D(size=(2, 2)))
    generator.add(Conv2D(64, kernel_size=(5, 5), padding='same'))
    generator.add(LeakyReLU(0.2))
    generator.add(UpSampling2D(size=(2, 2)))
    # tanh keeps pixel values in [-1, 1], the range x_train is scaled to.
    generator.add(Conv2D(1, kernel_size=(5, 5), padding='same', activation='tanh'))
    generator.compile(loss="binary_crossentropy", optimizer=make_optimizer())
    generator.summary()

    # Discriminator network: 28x28x1 image -> probability that it is real.
    discriminator = Sequential()
    discriminator.add(Conv2D(64, (5, 5), input_shape=(28, 28, 1)))
    discriminator.add(LeakyReLU())
    discriminator.add(Conv2D(128, (5, 5)))
    discriminator.add(LeakyReLU())
    discriminator.add(Dropout(0.3))
    discriminator.add(Flatten())
    discriminator.add(Dense(1, activation='sigmoid'))
    # Use the configured Adam (not the 'adam' string, which means default
    # hyper-parameters) so the discriminator trains at the intended rate.
    discriminator.compile(loss='binary_crossentropy', optimizer=make_optimizer())
    discriminator.summary()

    # Freeze the discriminator *before* compiling the stacked model. Keras
    # fixes a model's trainable weights at compile time, so gan_model will
    # only update the generator. The discriminator was compiled above while
    # still trainable, so its own train_on_batch calls keep updating it.
    discriminator.trainable = False
    gan_input = Input(shape=[input_dim_generator])
    middle_output = generator(gan_input)
    gan_output = discriminator(middle_output)
    gan_model = Model(inputs=gan_input, outputs=gan_output)
    gan_model.compile(loss="binary_crossentropy", optimizer=make_optimizer())
    gan_model.summary()

    return generator, discriminator, gan_model

In [7]:
def plot_loss(d_losses, g_losses, out_dir="generated_images"):
    """Plot discriminator and generator loss curves and save them as a PNG.

    Args:
        d_losses: Sequence of discriminator losses (``train`` records one per epoch).
        g_losses: Sequence of generator / stacked-GAN losses, same length.
        out_dir: Directory to write ``loss_epoch.png`` into; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    plt.figure(figsize=(10, 8))
    plt.plot(d_losses, label="Discriminative Loss")
    plt.plot(g_losses, label="Generative Loss")
    plt.title("GAN training loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    # The labels above are only rendered once a legend is drawn.
    plt.legend()
    plt.savefig(os.path.join(out_dir, "loss_epoch.png"))

In [8]:
def plot_generated_images(generator, epoch_num, num_examples=100, grid=(10, 10),
                          out_dir="generated_images"):
    """Sample images from ``generator``, tile them on a grid and save a PNG.

    Args:
        generator: Keras model mapping noise of shape
            ``(n, input_dim_generator)`` to single-channel images.
        epoch_num: Epoch number, used in the output file name.
        num_examples: Number of images to sample and draw.
        grid: ``(rows, cols)`` of the subplot grid.
        out_dir: Directory to write ``epoch_<n>.png`` into; created if missing.

    Raises:
        ValueError: If ``num_examples`` does not fit in ``grid``.
    """
    rows, cols = grid
    if num_examples > rows * cols:
        raise ValueError("num_examples=%d does not fit in a %dx%d grid"
                         % (num_examples, rows, cols))
    os.makedirs(out_dir, exist_ok=True)

    noise = np.random.normal(0, 1, size=[num_examples, input_dim_generator])
    generated_images = generator.predict(noise)
    # Drop the trailing channel axis so imshow receives 2-D grayscale arrays.
    generated_images = generated_images.reshape(generated_images.shape[:3])

    plt.figure(figsize=(cols, rows))
    for i, image in enumerate(generated_images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(image, interpolation='nearest', cmap='gray_r')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "epoch_%d.png" % epoch_num))

In [9]:
def train(epochs=10, batch_size=128, plot_interval=20):
    """Train the GAN on the module-level ``x_train`` array.

    Each epoch runs ``len(x_train) // batch_size`` steps. A step first trains
    the discriminator on ``batch_size`` randomly sampled real images (label 1)
    plus ``batch_size`` generated images (label 0). It then trains the
    generator through the stacked model with target label 1, so the generator
    learns to make images the discriminator calls real.

    Args:
        epochs: Number of epochs to run (numbered 1..epochs).
        batch_size: Number of real (and of generated) images per step.
        plot_interval: Save a sample grid every this many epochs; the final
            epoch is always saved.

    Raises:
        ValueError: If ``batch_size`` exceeds the number of training images,
            which would leave zero steps per epoch.
    """
    batch_count = x_train.shape[0] // batch_size
    if batch_count == 0:
        raise ValueError("batch_size=%d exceeds the %d training images"
                         % (batch_size, x_train.shape[0]))

    d_losses = []
    g_losses = []
    generator, discriminator, gan_model = create_model()

    # Targets are identical every step, so build them once.
    y_dis = np.zeros(2 * batch_size)
    y_dis[:batch_size] = 1  # first half of x_dis holds the real images
    y_gen = np.ones(batch_size)

    for epoch_num in range(1, epochs + 1):
        epoch_d_losses = []
        epoch_g_losses = []
        for _ in tqdm(range(batch_count), desc="Epoch %d" % epoch_num):
            noise_batch = np.random.normal(0, 1, size=[batch_size, input_dim_generator])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
            generated_image_batch = generator.predict(noise_batch)
            x_dis = np.concatenate([image_batch, generated_image_batch])

            # Keras reads `trainable` when a model is compiled, so these
            # toggles are only a safeguard. The stacked model must have been
            # compiled with the discriminator frozen for the generator step
            # to leave the discriminator's weights untouched.
            discriminator.trainable = True
            epoch_d_losses.append(discriminator.train_on_batch(x_dis, y_dis))

            discriminator.trainable = False
            epoch_g_losses.append(gan_model.train_on_batch(noise_batch, y_gen))

        # Record the epoch mean rather than only the last (noisy) batch loss.
        d_losses.append(np.mean(epoch_d_losses))
        g_losses.append(np.mean(epoch_g_losses))
        if epoch_num % plot_interval == 0 or epoch_num == epochs:
            plot_generated_images(generator, epoch_num)

    plot_loss(d_losses, g_losses)
            

In [10]:
# Run the full training loop; sample grids and the loss plot are saved to disk.
train(epochs=100)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 6272)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        204864    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 64)        0         
__________

## Loss

Discriminator and generator loss per training epoch.

![GAN loss per epoch](images/loss_epoch.png)