In [1]:
# Colab-only: mount Google Drive under /content/gdrive so the notebook can
# read and write files stored on the drive.
from google.colab import drive
# force_remount=True requests a fresh mount even when one is already present.
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
# Work from the notebook folder on the mounted drive: the relative paths used
# later in this notebook (the "gan" image directory and "generator.h5") are
# resolved against this directory. Then list its contents.
%cd /content/gdrive/My Drive/Colab Notebooks
!ls

/content/gdrive/My Drive/Colab Notebooks
clr_callback.py			  DCGAN_MNIST.ipynb  mnist.ipynb
CyclicalLearningRate_MNIST.ipynb  gan		     __pycache__


In [3]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.utils import to_categorical, plot_model
from keras.layers import Activation,Input, Dense, Conv2D, MaxPooling2D, Flatten, Dropout, Conv2DTranspose, Reshape, BatchNormalization, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
from keras import backend as K
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import os
import argparse

Using TensorFlow backend.


In [0]:
def plot_images(generator,
                noise_input,
                show=False,
                step=0,
                model_name="gan"):
    """Generate fake images and save them as a square grid

    Runs the generator on `noise_input`, draws every generated image in a
    grid of subplots and writes the figure to `<model_name>/<step>.png`
    (step zero-padded to 5 digits). The output directory is created if it
    does not exist.

    # Arguments
        generator (Model): The Generator Model for fake images generation
        noise_input (ndarray): Array of z-vectors, one row per image
        show (bool): Whether to show plot or not. When False the figure
            is closed after saving.
        step (int): Appended to filename of the save images
        model_name (string): Model name, used as the output directory
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    images = generator.predict(noise_input)
    num_images = images.shape[0]
    image_size = images.shape[1]
    # Side of the smallest square grid that can hold every image. Rounding
    # the square root *up* keeps plt.subplot() valid when the number of
    # images is not a perfect square (rounding down would request a subplot
    # index beyond grid_side**2 and raise).
    grid_side = int(math.ceil(math.sqrt(num_images)))
    fig = plt.figure(figsize=(28, 28))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        # Drop the channel axis; the reshape requires each image to hold
        # exactly image_size * image_size values (square, single channel).
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        # Close only the figure created here so repeated calls do not
        # accumulate open figures, without touching unrelated ones.
        plt.close(fig)

In [0]:
def build_generator(inputs, image_size):
    """Build the generator Model

    A Dense layer projects the input to a feature map of side
    `image_size // 4`, which is then passed through a stack of
    BatchNormalization -> ReLU -> Conv2DTranspose layers (128, 64, 32 and 1
    filters). The two widest layers use stride 2, the remaining two use
    stride 1. A final sigmoid activation produces the output.

    # Arguments
        inputs (Layer): Input layer of the generator (the z-vector)
        image_size (int): Target side length of the generated image

    # Returns
        Model: Generator Model mapping `inputs` to the sigmoid output
    """
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]
    # Layers with more filters than this use stride 2; the rest use stride 1.
    stride_threshold = layer_filters[-2]
    base_size = image_size // 4

    x = Dense(base_size * base_size * layer_filters[0])(inputs)
    x = Reshape((base_size, base_size, layer_filters[0]))(x)

    for n_filters in layer_filters:
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(n_filters,
                            kernel_size,
                            strides=2 if n_filters > stride_threshold else 1,
                            padding='same')(x)

    x = Activation('sigmoid')(x)
    return Model(inputs=inputs, outputs=x)

In [0]:
def build_discriminator(inputs):
    """Build the discriminator Model

    Stacks four LeakyReLU(0.2) -> Conv2D blocks (32, 64, 128 and 256
    filters, 5x5 kernels, 'same' padding). Every block uses stride 2 except
    the last one, which uses stride 1. The result is flattened and reduced
    to a single sigmoid unit.

    # Arguments
        inputs (Layer): Input layer of the discriminator (the image)

    # Returns
        Model: Discriminator Model mapping `inputs` to one sigmoid output
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
    last_filters = layer_filters[-1]

    x = inputs
    for n_filters in layer_filters:
        # Only the final (widest) convolution keeps the spatial stride at 1.
        strides = 1 if n_filters == last_filters else 2
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=n_filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same')(x)

    x = Flatten()(x)
    x = Dense(1)(x)
    x = Activation('sigmoid')(x)
    return Model(inputs=inputs, outputs=x)

In [0]:
def build_and_train_model(latent_size=100,
                          batch_size=100,
                          train_steps=40000,
                          lr=2e-4,
                          decay=6e-8):
    """Build the GAN networks on MNIST and train them

    Loads the MNIST training images, builds and compiles the discriminator,
    builds the generator, stacks both into the adversarial model (with the
    discriminator frozen) and hands everything to `train`.

    # Arguments
        latent_size (int): Dimension of the generator's z-vector
        batch_size (int): Number of real (and of fake) images per step
        train_steps (int): Number of training steps to run
        lr (float): RMSprop learning rate of the discriminator; the
            adversarial model uses half of it
        decay (float): RMSprop decay of the discriminator; the adversarial
            model uses half of it
    """
    # Only the training images are used: labels and the test split are
    # discarded right away instead of being reshaped and normalised.
    (x_train, _), _ = mnist.load_data()
    image_size = x_train.shape[1]
    # Append a channel axis and divide by 255. Casting to float32 first
    # avoids materialising a float64 copy of the whole dataset.
    x_train = x_train.reshape(x_train.shape + (1,)).astype('float32') / 255.

    # build discriminator
    input_shape = (image_size, image_size, 1)
    inputs = Input(shape=input_shape)
    discriminator = build_discriminator(inputs)
    optimizer = RMSprop(lr=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])

    # build generator
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape)
    generator = build_generator(inputs, image_size)

    # build adversarial model
    optimizer = RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    # NOTE(review): the flag is set after the discriminator was compiled and
    # before the adversarial model is compiled; this ordering appears to be
    # what keeps the standalone discriminator trainable while freezing it
    # inside the adversarial model - keep the order when editing.
    discriminator.trainable = False
    adversarial = Model(inputs=inputs,
                        outputs=discriminator(generator(inputs)))
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps)
    train(models, x_train, params)

In [0]:
def train(models, x_train, params):
    """Train the discriminator and adversarial networks

    Each step first trains the discriminator on one batch made of randomly
    sampled real images (label 1.0) and freshly generated fake images
    (label 0.0), then trains the adversarial model on a new batch of noise
    with every label set to 1.0. Losses and accuracies are printed every
    step. Every 500 steps, and always on the last step, a grid of images
    generated from a fixed set of 16 noise vectors is saved via
    `plot_images`; the final grid is also shown. After the loop the
    generator is saved to "generator.h5".

    # Arguments
        models (tuple): (generator, discriminator, adversarial) Models
        x_train (ndarray): Training images, indexed along the first axis
        params (tuple): (batch_size, latent_size, train_steps)
    """
    generator, discriminator, adversarial = models
    batch_size, latent_size, train_steps = params
    save_interval = 500
    # Fixed noise so the saved grids are comparable across training steps.
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
    train_size = x_train.shape[0]

    # The label arrays never change, so build them once outside the loop.
    # Discriminator batch: first half real (1.0), second half fake (0.0).
    discriminator_labels = np.ones([2 * batch_size, 1])
    discriminator_labels[batch_size:, :] = 0.0
    # Adversarial batch: every generated image is labelled 1.0.
    adversarial_labels = np.ones([batch_size, 1])

    for i in range(train_steps):
        # discriminator step on a mixed real/fake batch
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        noises = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        fake_images = generator.predict(noises)
        x = np.concatenate((real_images, fake_images))
        loss, acc = discriminator.train_on_batch(x, discriminator_labels)
        # The label reads "Epochs", but the number is the 0-based step index.
        log = "Epochs %d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)

        # adversarial step on a fresh batch of noise
        noises = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        loss, acc = adversarial.train_on_batch(noises, adversarial_labels)
        log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
        print(log)

        step = i + 1
        is_last_step = step == train_steps
        # Plot periodically, and always on the last step so the final result
        # is saved and shown even when train_steps is not a multiple of
        # save_interval.
        if step % save_interval == 0 or is_last_step:
            plot_images(generator,
                        noise_input=noise_input,
                        show=is_last_step,
                        step=step)
    generator.save("generator.h5")

In [0]:
# Run the whole pipeline: build the networks and train them. Image grids go to
# the "gan" directory and the generator is saved to "generator.h5", both
# relative to the current working directory.
build_and_train_model()