In [1]:
import keras
import numpy as np
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [2]:
def define_gan(g_model, dis_model):
    model = keras.models.Sequential()
    dis_model.trainable = False
    model.add(g_model)
    model.add(dis_model)
    opt = keras.optimizers.adam(learning_rate= 0.0002,
                                beta_1= 0.5)
    
    model.compile(loss= 'binary_crossentropy',
                  optimizer= opt)
    return model

In [3]:
def define_generator(latent_dim):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(units= 128 * 7 * 7,
                                 input_dim= latent_dim))
    model.add(keras.layers.Reshape((7, 7, 128)))
    
    model.add(keras.layers.Conv2DTranspose(filters= 128,
                                           kernel_size= (4,4),
                                           strides= (2,2),
                                           padding= 'same'))
    model.add(keras.layers.LeakyReLU(0.2))
    model.add(keras.layers.Conv2DTranspose(filters= 128,
                                           kernel_size= (4,4),
                                           strides= (2,2),
                                           padding= 'same'))
    model.add(keras.layers.LeakyReLU(0.2))
    model.add(keras.layers.Conv2D(filters= 1,
                                  kernel_size= (7,7),
                                  activation= 'sigmoid',
                                  padding= 'same'))
    return model

In [4]:
def define_discriminator(input_shape= (28,28,1)):
    model = keras.models.Sequential()
    model.add(keras.layers.Conv2D(filters= 64,
                                  strides= (2,2),
                                  kernel_size= (3, 3),
                                  padding= 'same',
                                  input_shape= input_shape))
    model.add(keras.layers.Dropout(0.4))
    model.add(keras.layers.LeakyReLU(0.2))
    model.add(keras.layers.Conv2D(filters= 64,
                                  strides= (2,2),
                                  kernel_size= (3, 3),
                                  padding= 'same',
                                  input_shape= input_shape))
    model.add(keras.layers.Dropout(0.4))
    model.add(keras.layers.LeakyReLU(0.2))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(units= 1,
                                 activation= 'sigmoid'))
    opt = keras.optimizers.adam(learning_rate= 0.0002, beta_1= 0.5)
    model.compile(loss= 'binary_crossentropy', optimizer= opt)
    return model

In [5]:
def load_mnist_data():
    (X_train, y_train), (_, _) = keras.datasets.mnist.load_data()
    indices = y_train == 8
    X_train = X_train[indices]    
    X_train = np.expand_dims(X_train, axis= -1).astype('float32') / 255.0
    return X_train

In [6]:
def generate_real_samples(dataset, n_samples):
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    X = dataset[ix]
    y = np.ones((n_samples, 1))
    return X, y

In [7]:
def generate_latent_points(latent_dim, n_samples):
    x_input = np.random.randn(latent_dim * n_samples)
    x_input = x_input.reshape((n_samples, latent_dim))
    return x_input

In [8]:
def generate_fake_samples(g_model, latent_dim, n_samples):
    X_input = generate_latent_points(latent_dim= latent_dim,
                               n_samples= n_samples)
    X = g_model.predict(X_input)
    y = np.zeros((n_samples, 1))
    return X, y

In [9]:
def summarize_model(epoch, g_model, d_model, latent_dim, dataset, n_samples= 100):
    X_real, y_real = generate_real_samples(dataset= dataset, n_samples= n_samples)
    X_fake, y_fake = generate_fake_samples(g_model= g_model,
                                           latent_dim= latent_dim,
                                           n_samples= n_samples)
    
    acc_real = d_model.evaluate(X_real, y_real, verbose= 0)
    acc_fake = d_model.evaluate(X_fake, y_fake, verbose= 0)
    print(f'Epoch: {epoch + 1}, Accuracy on real data: {acc_real}, Accuracy on generated data: {acc_fake}')
    save_plot(X_fake, epoch= epoch, n=10)
    model_name = f'./New/generator_model_{epoch + 1}.h5'
    g_model.save(model_name)

In [10]:
def save_plot(examples, epoch, n=5):
    for i in range(n * n):
        plt.subplot(n, n, 1+i)
        plt.axis('off')
        plt.imshow(examples[i, :, :, 0], cmap= 'gray')
    filename = f'./New/generated_plot_epoch{epoch + 1}.png'
    plt.savefig(filename)
    plt.close()

In [11]:
def train_gan(gan_model, g_model, d_model, dataset, latent_dim, epochs= 100, batch_size= 128):
    half_batch = int(batch_size / 2)
    batch_per_epoch = int(dataset.shape[0]/batch_size)
    for i in range(epochs):
        for j in range(batch_per_epoch):
            # Generating real and fake examples
            X_real, y_real = generate_real_samples(dataset= dataset, n_samples= half_batch)
            X_fake, y_fake = generate_fake_samples(g_model= g_model,
                                                   latent_dim= latent_dim,
                                                   n_samples= half_batch)
            # Stacking the training datas
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            # Training the discriminator mode
            d_loss = d_model.train_on_batch(X, y)
            
            # Generating image from latent space
            x_input = generate_latent_points(latent_dim= latent_dim,
                                             n_samples= batch_size)
            
            X_gan = generate_latent_points(latent_dim= latent_dim,
                                           n_samples= batch_size)
            
            y_gan = np.ones((batch_size, 1))
            
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            print(f'Epoch: {i + 1}, batch: {j}/{batch_per_epoch},dloss: {d_loss}, gloss: {g_loss}')
            
        # Saving the model every once in a while
        summarize_model(epoch= i,
                        g_model= g_model,
                        d_model= d_model,
                        dataset= dataset,
                        latent_dim= latent_dim)

In [12]:
latent_dim = 50
dataset = load_mnist_data()
g_model = define_generator(latent_dim= latent_dim)
d_model = define_discriminator()
gan_model = define_gan(g_model= g_model, dis_model= d_model)

# Training the GAN for MNIST!!
train_gan(gan_model= gan_model,
          g_model= g_model,
          d_model= d_model,
          dataset= dataset,
          latent_dim= latent_dim)

  'Discrepancy between trainable weights and collected trainable'


Epoch: 1, batch: 0/45,dloss: 0.6926295757293701, gloss: 0.6922855377197266


  'Discrepancy between trainable weights and collected trainable'


Epoch: 1, batch: 1/45,dloss: 0.6828494071960449, gloss: 0.7103743553161621
Epoch: 1, batch: 2/45,dloss: 0.673771858215332, gloss: 0.7182149887084961
Epoch: 1, batch: 3/45,dloss: 0.6695533394813538, gloss: 0.7322429418563843
Epoch: 1, batch: 4/45,dloss: 0.6620096564292908, gloss: 0.7466903924942017
Epoch: 1, batch: 5/45,dloss: 0.6532983183860779, gloss: 0.7609587907791138
Epoch: 1, batch: 6/45,dloss: 0.6450750827789307, gloss: 0.7778518199920654
Epoch: 1, batch: 7/45,dloss: 0.6403759717941284, gloss: 0.7905789613723755
Epoch: 1, batch: 8/45,dloss: 0.6308712363243103, gloss: 0.803397536277771
Epoch: 1, batch: 9/45,dloss: 0.6240161657333374, gloss: 0.8127363920211792
Epoch: 1, batch: 10/45,dloss: 0.6138575673103333, gloss: 0.8215730786323547
Epoch: 1, batch: 11/45,dloss: 0.6109062433242798, gloss: 0.8262985944747925
Epoch: 1, batch: 12/45,dloss: 0.6023463010787964, gloss: 0.8327080607414246
Epoch: 1, batch: 13/45,dloss: 0.6024649739265442, gloss: 0.8365345597267151
Epoch: 1, batch: 14/45,

Epoch: 3, batch: 17/45,dloss: 0.588718056678772, gloss: 0.785611629486084
Epoch: 3, batch: 18/45,dloss: 0.5911518335342407, gloss: 0.7781992554664612
Epoch: 3, batch: 19/45,dloss: 0.5948381423950195, gloss: 0.82975172996521
Epoch: 3, batch: 20/45,dloss: 0.567048192024231, gloss: 0.7949709892272949
Epoch: 3, batch: 21/45,dloss: 0.5813242197036743, gloss: 0.7879652976989746
Epoch: 3, batch: 22/45,dloss: 0.6071549654006958, gloss: 0.7955690622329712
Epoch: 3, batch: 23/45,dloss: 0.5795159339904785, gloss: 0.8136553764343262
Epoch: 3, batch: 24/45,dloss: 0.5957659482955933, gloss: 0.8213846683502197
Epoch: 3, batch: 25/45,dloss: 0.5884613394737244, gloss: 0.7865116000175476
Epoch: 3, batch: 26/45,dloss: 0.5807105302810669, gloss: 0.7700284123420715
Epoch: 3, batch: 27/45,dloss: 0.5755963325500488, gloss: 0.8022135496139526
Epoch: 3, batch: 28/45,dloss: 0.5682511329650879, gloss: 0.8049841523170471
Epoch: 3, batch: 29/45,dloss: 0.5879316329956055, gloss: 0.8196221590042114
Epoch: 3, batch: 

Epoch: 5, batch: 33/45,dloss: 0.6579248309135437, gloss: 0.795072078704834
Epoch: 5, batch: 34/45,dloss: 0.6475034952163696, gloss: 0.7872962951660156
Epoch: 5, batch: 35/45,dloss: 0.6412064433097839, gloss: 0.7727903127670288
Epoch: 5, batch: 36/45,dloss: 0.6629592180252075, gloss: 0.7595198154449463
Epoch: 5, batch: 37/45,dloss: 0.6523784399032593, gloss: 0.7736433744430542
Epoch: 5, batch: 38/45,dloss: 0.6491248607635498, gloss: 0.7646116018295288
Epoch: 5, batch: 39/45,dloss: 0.66050785779953, gloss: 0.7613397836685181
Epoch: 5, batch: 40/45,dloss: 0.6399863958358765, gloss: 0.7874056696891785
Epoch: 5, batch: 41/45,dloss: 0.6383528113365173, gloss: 0.7782950401306152
Epoch: 5, batch: 42/45,dloss: 0.6399199366569519, gloss: 0.7853345274925232
Epoch: 5, batch: 43/45,dloss: 0.6568453311920166, gloss: 0.7773042917251587
Epoch: 5, batch: 44/45,dloss: 0.6437255144119263, gloss: 0.7494907379150391
Epoch: 5, Accuracy on real data: 0.6495729470252991, Accuracy on generated data: 0.63159245

KeyboardInterrupt: 