In [1]:
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Input, Reshape
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras import utils

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow.keras.datasets import mnist

In [2]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)

In [3]:
# build generator

def build_generator():
    
    noise_shape = (100,)   #1D array of size 100 (latent vector/ noise)
    alpha = 0.2
    momentum = 0.8
    
    model = Sequential()
    
    model.add(Dense(256, input_shape=noise_shape))
    model.add(LeakyReLU(alpha = alpha))
    model.add(BatchNormalization(momentum = momentum))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha = alpha))
    model.add(BatchNormalization(momentum = momentum))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha = alpha))
    model.add(BatchNormalization(momentum = momentum))
    
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))
    
    noise = Input(shape = noise_shape)
    img = model(noise)
    
    return Model(noise, img)
    

In [4]:
# Build discriminator

def build_discriminator():
    
    alpha = 0.2
    model = Sequential()
    
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha = alpha))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha = alpha))
    model.add(Dense(1, activation = 'sigmoid'))
    model.summary()
    
    img = Input(shape=img_shape)
    validity = model(img)
    
    return Model(img,validity)       # validity is whether image is real or fake
    


In [5]:
# training part

def train(epochs, batch_size = 128, save_interval = 500):
    
    (X_train, _), (_, _) = mnist.load_data()
    
    X_train = (X_train.astype(np.float32)) / 255
    
    X_train = np.expand_dims(X_train, axis = 3)   # adding 3rd axis so its 28 x 28 x 1
    
    half_batch = int(batch_size/2)                # mixing half real and half fake
    
    for epoch in range(epochs):
        
        # Select sample of real images
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]
        
        noise = np.random.normal(0,1,(half_batch, 100))      # 64 rows and 100 columns
        
        # Generate images
        gen_imgs = generator.predict(noise)                 # generator produces fake images
        
        
        # Train the discriminator on real and fake images, seperately
        # Research shows seperate training is more effective
        d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
        d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
        
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # Train generator within epoch
        
        noise = np.random.normal(0,1,(half_batch, 100))
        
        
        valid_y = np.array([1] * half_batch)              # Creates an array of 1s in column format, size = batch_size
        
        # We fool the discriminator by labelling fakes as real
        g_loss = combined.train_on_batch(noise, valid_y)
        
        
        print('%d [D loss: %f, acc.:%.2f%%] [G loss:%f]' %(epoch, d_loss[0], 100*d_loss[1], g_loss))
        
        if epoch % save_interval == 0:
            save_imgs(epoch)

def save_imgs(epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r*c, 100))
    gen_imgs = generator.predict(noise)
    
    # rescale images 0-1
    gen_imgs = 0.5 * gen_imgs + 0.5
    
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap = 'gray')
            axs[i,j].axis('off')
            cnt+=1
    fig.savefig("./mnistimages/mnist_%d.png" % epoch)
    plt.close()



In [7]:
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer = 'adam', metrics=['accuracy'])


generator = build_generator()
generator.compile(loss='binary_crossentropy', optimizer = 'adam')

z = Input(shape=(100,))
img = generator(z)

# Freeze discrimnator while we train generator
discriminator.trainable = False

valid = discriminator(img)

combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer = 'adam')

train(epochs =1000, batch_size =32, save_interval=10)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
________________________________________________

131 [D loss: 0.064826, acc.:96.88%] [G loss:23.188292]
132 [D loss: 0.071553, acc.:96.88%] [G loss:22.458523]
133 [D loss: 0.069980, acc.:96.88%] [G loss:21.644001]
134 [D loss: 0.335399, acc.:96.88%] [G loss:19.932758]
135 [D loss: 0.013467, acc.:100.00%] [G loss:18.088310]
136 [D loss: 0.027143, acc.:100.00%] [G loss:15.794668]
137 [D loss: 0.054897, acc.:96.88%] [G loss:14.218191]
138 [D loss: 0.006074, acc.:100.00%] [G loss:16.210281]
139 [D loss: 0.004360, acc.:100.00%] [G loss:15.932207]
140 [D loss: 0.009859, acc.:100.00%] [G loss:15.223002]
141 [D loss: 0.298764, acc.:96.88%] [G loss:15.594488]
142 [D loss: 0.004166, acc.:100.00%] [G loss:15.759548]
143 [D loss: 0.194297, acc.:96.88%] [G loss:18.913671]
144 [D loss: 0.051100, acc.:100.00%] [G loss:19.781002]
145 [D loss: 0.094354, acc.:96.88%] [G loss:20.810982]
146 [D loss: 0.107867, acc.:93.75%] [G loss:19.709187]
147 [D loss: 0.024908, acc.:100.00%] [G loss:21.935736]
148 [D loss: 0.037826, acc.:100.00%] [G loss:18.790073]
1

281 [D loss: 0.031254, acc.:100.00%] [G loss:30.280186]
282 [D loss: 0.205404, acc.:90.62%] [G loss:28.080942]
283 [D loss: 0.168057, acc.:93.75%] [G loss:26.076992]
284 [D loss: 0.037027, acc.:100.00%] [G loss:25.783264]
285 [D loss: 0.053829, acc.:96.88%] [G loss:26.027897]
286 [D loss: 0.028124, acc.:100.00%] [G loss:25.917963]
287 [D loss: 0.172474, acc.:93.75%] [G loss:25.766357]
288 [D loss: 0.039629, acc.:100.00%] [G loss:30.464237]
289 [D loss: 0.099055, acc.:96.88%] [G loss:34.235023]
290 [D loss: 0.058133, acc.:96.88%] [G loss:28.138783]
291 [D loss: 0.046382, acc.:100.00%] [G loss:26.955765]
292 [D loss: 0.033044, acc.:100.00%] [G loss:21.475203]
293 [D loss: 0.119962, acc.:96.88%] [G loss:20.085720]
294 [D loss: 0.012586, acc.:100.00%] [G loss:19.257965]
295 [D loss: 0.028164, acc.:100.00%] [G loss:21.242985]
296 [D loss: 0.018720, acc.:100.00%] [G loss:20.036598]
297 [D loss: 0.132663, acc.:96.88%] [G loss:24.886215]
298 [D loss: 0.069787, acc.:96.88%] [G loss:23.642035]
2

429 [D loss: 0.008436, acc.:100.00%] [G loss:31.081802]
430 [D loss: 0.005768, acc.:100.00%] [G loss:29.469057]
431 [D loss: 0.006571, acc.:100.00%] [G loss:23.225374]
432 [D loss: 0.012885, acc.:100.00%] [G loss:22.569136]
433 [D loss: 0.011073, acc.:100.00%] [G loss:24.133787]
434 [D loss: 0.041592, acc.:96.88%] [G loss:26.657228]
435 [D loss: 0.010650, acc.:100.00%] [G loss:24.277203]
436 [D loss: 0.009292, acc.:100.00%] [G loss:24.436632]
437 [D loss: 0.010393, acc.:100.00%] [G loss:23.005756]
438 [D loss: 0.014595, acc.:100.00%] [G loss:24.726898]
439 [D loss: 0.191681, acc.:90.62%] [G loss:19.977100]
440 [D loss: 0.069081, acc.:96.88%] [G loss:25.001085]
441 [D loss: 0.063752, acc.:96.88%] [G loss:31.255970]
442 [D loss: 0.066773, acc.:96.88%] [G loss:34.374916]
443 [D loss: 0.066937, acc.:96.88%] [G loss:29.800293]
444 [D loss: 0.134677, acc.:93.75%] [G loss:31.049273]
445 [D loss: 0.007963, acc.:100.00%] [G loss:32.697762]
446 [D loss: 0.009094, acc.:100.00%] [G loss:29.470936]

577 [D loss: 0.047424, acc.:100.00%] [G loss:44.542892]
578 [D loss: 0.150201, acc.:93.75%] [G loss:33.907349]
579 [D loss: 0.005238, acc.:100.00%] [G loss:30.854643]
580 [D loss: 0.005931, acc.:100.00%] [G loss:28.174984]
581 [D loss: 0.084802, acc.:93.75%] [G loss:23.330893]
582 [D loss: 0.002381, acc.:100.00%] [G loss:29.644543]
583 [D loss: 0.047973, acc.:96.88%] [G loss:34.824352]
584 [D loss: 0.009855, acc.:100.00%] [G loss:44.967491]
585 [D loss: 0.062949, acc.:96.88%] [G loss:41.918243]
586 [D loss: 0.054194, acc.:100.00%] [G loss:40.060101]
587 [D loss: 0.025499, acc.:100.00%] [G loss:40.245472]
588 [D loss: 0.006000, acc.:100.00%] [G loss:31.951349]
589 [D loss: 0.005434, acc.:100.00%] [G loss:27.515993]
590 [D loss: 0.002439, acc.:100.00%] [G loss:25.945278]
591 [D loss: 0.001489, acc.:100.00%] [G loss:22.729153]
592 [D loss: 0.002695, acc.:100.00%] [G loss:18.370872]
593 [D loss: 0.047662, acc.:96.88%] [G loss:22.847160]
594 [D loss: 0.002934, acc.:100.00%] [G loss:27.43660

725 [D loss: 0.001677, acc.:100.00%] [G loss:40.816788]
726 [D loss: 0.002770, acc.:100.00%] [G loss:45.471779]
727 [D loss: 0.056313, acc.:96.88%] [G loss:49.748123]
728 [D loss: 0.022254, acc.:100.00%] [G loss:47.025314]
729 [D loss: 0.012913, acc.:100.00%] [G loss:47.669914]
730 [D loss: 0.017491, acc.:100.00%] [G loss:46.197639]
731 [D loss: 0.004955, acc.:100.00%] [G loss:46.743187]
732 [D loss: 0.004093, acc.:100.00%] [G loss:48.048183]
733 [D loss: 0.006486, acc.:100.00%] [G loss:44.191292]
734 [D loss: 0.004343, acc.:100.00%] [G loss:40.070103]
735 [D loss: 0.002052, acc.:100.00%] [G loss:37.029579]
736 [D loss: 0.004790, acc.:100.00%] [G loss:34.994434]
737 [D loss: 0.002493, acc.:100.00%] [G loss:27.951366]
738 [D loss: 0.028516, acc.:96.88%] [G loss:34.897148]
739 [D loss: 0.002350, acc.:100.00%] [G loss:37.089630]
740 [D loss: 0.002533, acc.:100.00%] [G loss:32.974701]
741 [D loss: 0.003981, acc.:100.00%] [G loss:37.026169]
742 [D loss: 0.007888, acc.:100.00%] [G loss:38.40

875 [D loss: 0.015406, acc.:100.00%] [G loss:37.065601]
876 [D loss: 0.010111, acc.:100.00%] [G loss:38.295582]
877 [D loss: 0.118205, acc.:96.88%] [G loss:42.822800]
878 [D loss: 0.009485, acc.:100.00%] [G loss:47.810284]
879 [D loss: 0.033239, acc.:100.00%] [G loss:47.751495]
880 [D loss: 0.038728, acc.:100.00%] [G loss:46.278282]
881 [D loss: 0.010499, acc.:100.00%] [G loss:42.643433]
882 [D loss: 0.005152, acc.:100.00%] [G loss:35.951401]
883 [D loss: 0.002703, acc.:100.00%] [G loss:33.332527]
884 [D loss: 0.004001, acc.:100.00%] [G loss:34.345425]
885 [D loss: 0.131378, acc.:96.88%] [G loss:32.398575]
886 [D loss: 0.001354, acc.:100.00%] [G loss:35.516247]
887 [D loss: 0.005246, acc.:100.00%] [G loss:43.820618]
888 [D loss: 0.054378, acc.:96.88%] [G loss:39.383774]
889 [D loss: 0.009274, acc.:100.00%] [G loss:38.394859]
890 [D loss: 0.015631, acc.:100.00%] [G loss:32.083595]
891 [D loss: 0.016231, acc.:100.00%] [G loss:34.650970]
892 [D loss: 0.022136, acc.:100.00%] [G loss:29.220

In [None]:
generator.save('generator_model_test.1')