In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import adam

Using TensorFlow backend.


In [2]:
def load_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    
    # convert shape of x_train from (60000, 28, 28) to (60000, 784) 
    # 784 columns per row
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)
(X_train, y_train,X_test, y_test)=load_data()
print(X_train.shape)

(60000, 784)


In [3]:
def adam_optimizer():
    return adam(lr=0.0002, beta_1=0.5)

In [4]:
def create_generator():
    generator=Sequential()
    generator.add(Dense(units=256,input_dim=100))
    generator.add(LeakyReLU(0.2))
    
    generator.add(Dense(units=512))
    generator.add(LeakyReLU(0.2))
    
    generator.add(Dense(units=1024))
    generator.add(LeakyReLU(0.2))
    
    generator.add(Dense(units=784, activation='tanh'))
    
    generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
    return generator
g=create_generator()
g.summary()

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dens

In [5]:
def create_discriminator():
    discriminator=Sequential()
    discriminator.add(Dense(units=1024,input_dim=784))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
       
    
    discriminator.add(Dense(units=512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
       
    discriminator.add(Dense(units=256))
    discriminator.add(LeakyReLU(0.2))
    
    discriminator.add(Dense(units=1, activation='sigmoid'))
    
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
    return discriminator
d =create_discriminator()
d.summary()

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_5 (Dense)              (None, 1024)              803840    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_6 (Dense)              (None, 512)               524800    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
___________________________

In [9]:
def create_gan(discriminator, generator):
    discriminator.trainable=False
    gan_input = Input(shape=(100,))
    x = generator(gan_input)
    gan_output= discriminator(x)
    gan= Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan
gan = create_gan(d,g)
gan.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 784)               1486352   
_________________________________________________________________
sequential_2 (Sequential)    (None, 1)                 1460225   
Total params: 2,946,577
Trainable params: 1,486,352
Non-trainable params: 1,460,225
_________________________________________________________________


In [10]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(100,28,28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image %d.png' %epoch)

In [None]:
def training(epochs=1, batch_size=128):
    #Loading the data
    (X_train, y_train, X_test, y_test) = load_data()
    batch_count = X_train.shape[0] / batch_size
    
    # create gan 
    generator = create_generator()
    discriminator = create_discriminator()
    gan = create_gan(discriminator, generator)
    
    for epoch in range(1, epochs+1):
        print("epoch {}".format(epoch))
        
        for _ in tqdm(range(batch_size)):
            
            # generate noise input of the generator
            noise = np.random.normal(0,1, [batch_size, 100])
            
            # generate fake images from noise
            generated_images = generator.predict(noise)
            
            # create a batch of data containing real and fake images for MNIST
            image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]
            
            X = np.concatenate([image_batch, generated_images])
            
            y_dis=np.zeros(2*batch_size)
            y_dis[:batch_size]=0.9
            
            #Pre train discriminator on  fake and real data
            discriminator.trainable=True
            discriminator.train_on_batch(X, y_dis)
            
            #Tricking the noised input of the Generator as real data
            noise= np.random.normal(0,1, [batch_size, 100])
            y_gen = np.ones(batch_size)
            
             # During the training of gan, 
            # the weights of discriminator should be fixed. 
            #We can enforce that by setting the trainable flag
            discriminator.trainable=False
            
            #training  the GAN by alternating the training of the Discriminator 
            #and training the chained GAN model with Discriminator’s weights freezed.
            gan.train_on_batch(noise, y_gen)
            
            if epoch == 1 or epoch % 20 == 0:
                plot_generated_images(epoch, generator)
                
training(200,128)

  0%|          | 0/128 [00:00<?, ?it/s]

epoch 1
Instructions for updating:
Use tf.cast instead.


  """
100%|██████████| 128/128 [03:58<00:00,  1.52s/it]
  2%|▏         | 2/128 [00:00<00:07, 15.93it/s]

epoch 2


100%|██████████| 128/128 [00:08<00:00, 15.95it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.09it/s]

epoch 3


100%|██████████| 128/128 [00:08<00:00, 13.90it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.19it/s]

epoch 4


100%|██████████| 128/128 [00:09<00:00, 14.38it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.62it/s]

epoch 5


100%|██████████| 128/128 [00:08<00:00, 15.12it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.62it/s]

epoch 6


100%|██████████| 128/128 [00:08<00:00, 15.31it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.05it/s]

epoch 7


100%|██████████| 128/128 [00:08<00:00, 14.95it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.09it/s]

epoch 8


100%|██████████| 128/128 [00:08<00:00, 15.22it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.62it/s]

epoch 9


100%|██████████| 128/128 [00:09<00:00, 13.30it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.28it/s]

epoch 10


100%|██████████| 128/128 [00:09<00:00, 13.52it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.60it/s]

epoch 11


100%|██████████| 128/128 [00:08<00:00, 13.42it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.85it/s]

epoch 12


100%|██████████| 128/128 [00:11<00:00, 11.64it/s]
  2%|▏         | 2/128 [00:00<00:09, 12.89it/s]

epoch 13


100%|██████████| 128/128 [00:10<00:00, 12.82it/s]
  2%|▏         | 2/128 [00:00<00:10, 12.30it/s]

epoch 14


100%|██████████| 128/128 [00:10<00:00, 12.25it/s]
  2%|▏         | 2/128 [00:00<00:10, 11.92it/s]

epoch 15


100%|██████████| 128/128 [00:09<00:00, 14.86it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.17it/s]

epoch 16


100%|██████████| 128/128 [00:09<00:00, 13.26it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.17it/s]

epoch 17


100%|██████████| 128/128 [00:09<00:00, 13.94it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.23it/s]

epoch 18


100%|██████████| 128/128 [00:09<00:00, 13.45it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.33it/s]

epoch 19


100%|██████████| 128/128 [00:09<00:00, 14.84it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

epoch 20


100%|██████████| 128/128 [04:25<00:00,  1.69s/it]
  2%|▏         | 2/128 [00:00<00:07, 16.51it/s]

epoch 21


100%|██████████| 128/128 [00:08<00:00, 16.41it/s]
  2%|▏         | 2/128 [00:00<00:07, 16.64it/s]

epoch 22


100%|██████████| 128/128 [00:08<00:00, 14.68it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.53it/s]

epoch 23


100%|██████████| 128/128 [00:09<00:00, 15.15it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.48it/s]

epoch 24


100%|██████████| 128/128 [00:08<00:00, 12.49it/s]
  2%|▏         | 2/128 [00:00<00:12, 10.48it/s]

epoch 25


100%|██████████| 128/128 [00:10<00:00, 12.67it/s]
  2%|▏         | 2/128 [00:00<00:10, 12.22it/s]

epoch 26


100%|██████████| 128/128 [00:09<00:00, 14.73it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.91it/s]

epoch 27


100%|██████████| 128/128 [00:09<00:00, 13.69it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.70it/s]

epoch 28


100%|██████████| 128/128 [00:11<00:00, 13.13it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.92it/s]

epoch 29


100%|██████████| 128/128 [00:09<00:00, 13.98it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.73it/s]

epoch 30


100%|██████████| 128/128 [00:09<00:00, 15.57it/s]
  2%|▏         | 2/128 [00:00<00:07, 16.11it/s]

epoch 31


100%|██████████| 128/128 [00:09<00:00, 11.08it/s]
  2%|▏         | 2/128 [00:00<00:12, 10.16it/s]

epoch 32


100%|██████████| 128/128 [00:10<00:00, 12.42it/s]
  2%|▏         | 2/128 [00:00<00:10, 11.88it/s]

epoch 33


100%|██████████| 128/128 [00:10<00:00, 13.63it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.57it/s]

epoch 34


100%|██████████| 128/128 [00:09<00:00, 13.76it/s]
  2%|▏         | 2/128 [00:00<00:10, 12.56it/s]

epoch 35


100%|██████████| 128/128 [00:10<00:00, 13.26it/s]
  2%|▏         | 2/128 [00:00<00:09, 12.93it/s]

epoch 36


100%|██████████| 128/128 [00:09<00:00, 16.07it/s]
  2%|▏         | 2/128 [00:00<00:07, 16.03it/s]

epoch 37


100%|██████████| 128/128 [00:08<00:00, 13.70it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.51it/s]

epoch 38


100%|██████████| 128/128 [00:08<00:00, 13.25it/s]
  2%|▏         | 2/128 [00:00<00:10, 11.90it/s]

epoch 39


100%|██████████| 128/128 [00:08<00:00, 14.35it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

epoch 40


100%|██████████| 128/128 [04:41<00:00,  1.67s/it]
  2%|▏         | 2/128 [00:00<00:10, 12.17it/s]

epoch 41


100%|██████████| 128/128 [00:09<00:00, 13.68it/s]
  2%|▏         | 2/128 [00:00<00:09, 13.00it/s]

epoch 42


100%|██████████| 128/128 [00:08<00:00, 14.69it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.98it/s]

epoch 43


100%|██████████| 128/128 [00:08<00:00, 15.40it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.32it/s]

epoch 44


100%|██████████| 128/128 [00:08<00:00, 15.63it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.43it/s]

epoch 45


100%|██████████| 128/128 [00:08<00:00, 15.45it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.47it/s]

epoch 46


100%|██████████| 128/128 [00:08<00:00, 15.44it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.39it/s]

epoch 47


100%|██████████| 128/128 [00:08<00:00, 15.15it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.62it/s]

epoch 48


100%|██████████| 128/128 [00:08<00:00, 15.74it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.55it/s]

epoch 49


100%|██████████| 128/128 [00:08<00:00, 15.41it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.12it/s]

epoch 50


100%|██████████| 128/128 [00:08<00:00, 15.54it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.56it/s]

epoch 51


100%|██████████| 128/128 [00:08<00:00, 15.34it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.08it/s]

epoch 52


100%|██████████| 128/128 [00:08<00:00, 15.29it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.44it/s]

epoch 53


100%|██████████| 128/128 [00:08<00:00, 15.22it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.03it/s]

epoch 54


100%|██████████| 128/128 [00:08<00:00, 15.53it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.99it/s]

epoch 55


100%|██████████| 128/128 [00:08<00:00, 15.60it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.96it/s]

epoch 56


100%|██████████| 128/128 [00:08<00:00, 15.45it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.79it/s]

epoch 57


100%|██████████| 128/128 [00:08<00:00, 15.55it/s]
  2%|▏         | 2/128 [00:00<00:08, 15.00it/s]

epoch 58


100%|██████████| 128/128 [00:08<00:00, 15.57it/s]
  2%|▏         | 2/128 [00:00<00:08, 14.72it/s]

epoch 59


100%|██████████| 128/128 [00:08<00:00, 15.47it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

epoch 60


100%|██████████| 128/128 [08:27<00:00, 12.87s/it] 
  0%|          | 0/128 [00:00<?, ?it/s]

epoch 61


 30%|██▉       | 38/128 [01:09<05:35,  3.72s/it]