In [3]:
import numpy as np
import time
from keras.datasets import mnist

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Reshape, Input, Lambda
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, Cropping2D
from keras.layers import LeakyReLU, Dropout, UpSampling2D
from keras.layers import BatchNormalization
from keras.optimizers import Adam, RMSprop
from keras.models import Model
from keras.backend import int_shape, tf
import matplotlib.pyplot as plt

In [4]:
class PG_GAN(object):
    """Builds and caches the networks of a DCGAN-style GAN.

    Each builder method constructs its Keras model on first call, stores it on
    the instance, and returns the cached object on every later call, so the
    discriminator and generator weights are shared by all compiled models.

    Args:
        img_rows: height of the images the discriminator accepts and the
            generator produces.
        img_cols: width of those images.
        channel: number of image channels.
        latent_dim: length of the noise vector fed to the generator.
    """

    def __init__(self, img_rows=28, img_cols=28, channel=1, latent_dim=100):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.latent_dim = latent_dim
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model (G followed by a frozen D)
        self.DM = None  # discriminator model (D alone, compiled for training)

    def discriminator(self):
        """Build (once) and return the discriminator ``Model``.

        Four same-padded conv blocks (strides 2, 2, 2, 1) followed by a layer
        named ``'Flatten'`` and a single sigmoid unit.
        """
        def d_block(x, n_features, kernel, stride):
            # Conv -> LeakyReLU -> light dropout.
            x = Conv2D(n_features, kernel, strides=stride, padding='same')(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Dropout(.1)(x)
            return x

        if self.D is not None:
            return self.D

        depth = 64
        input_shape = (self.img_rows, self.img_cols, self.channel)

        x = Input(shape=input_shape)
        y = d_block(x, depth * 1, 5, 2)
        y = d_block(y, depth * 2, 5, 2)
        y = d_block(y, depth * 4, 5, 2)
        y = d_block(y, depth * 8, 5, 1)

        # generator() looks this layer up by name to size its first feature map.
        y = Flatten(name='Flatten')(y)
        y = Dense(1, activation='sigmoid')(y)
        self.D = Model(x, y)
        self.D.summary()
        return self.D

    def generator(self):
        """Build (once) and return the generator ``Model``.

        Maps a ``(latent_dim,)`` noise vector to an image of shape
        ``(img_rows, img_cols, channel)`` through a final tanh activation.
        The first feature map has the spatial size of the discriminator's last
        conv output. Three blocks then each double height and width, and any
        excess beyond the target size is cropped away.

        Raises:
            ValueError: if the upsampled output is smaller than the target
                image (the crop would be negative).
        """
        def g_block(x, n_features, kernel, stride):
            # 2x bilinear upsampling -> transposed conv -> LeakyReLU -> BN.
            x = UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
            x = Conv2DTranspose(n_features, kernel, strides=stride, padding='same')(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = BatchNormalization(momentum=0.9)(x)
            return x

        if self.G is not None:
            return self.G

        depth = 64
        # Spatial size of the tensor entering the discriminator's Flatten layer.
        flat_in = self.discriminator().get_layer('Flatten').input_shape
        start_rows, start_cols = flat_in[1], flat_in[2]

        x = Input(shape=(self.latent_dim,))
        y = Dense(start_rows * start_cols * depth * 8)(x)
        y = Activation('relu')(y)
        y = Reshape((start_rows, start_cols, depth * 8))(y)
        y = BatchNormalization(momentum=0.9)(y)

        y = g_block(y, depth * 4, 5, 1)
        y = g_block(y, depth * 2, 5, 1)
        y = g_block(y, depth, 5, 1)

        # One filter per image channel (previously hard-coded to 1).
        y = Conv2DTranspose(self.channel, 5, padding='same')(y)
        y = Activation('tanh', name='Tanh')(y)

        # Upsampling can overshoot the target size (28 -> 4 -> 32 for MNIST).
        # Crop rows and columns independently; an odd excess puts the extra
        # pixel on the bottom/right side.
        out_shape = int_shape(y)
        extra_rows = out_shape[1] - self.img_rows
        extra_cols = out_shape[2] - self.img_cols
        if extra_rows < 0 or extra_cols < 0:
            raise ValueError(
                "generator output %dx%d is smaller than target image %dx%d"
                % (out_shape[1], out_shape[2], self.img_rows, self.img_cols))
        top = extra_rows // 2
        left = extra_cols // 2
        y = Cropping2D(cropping=((top, extra_rows - top), (left, extra_cols - left)),
                       name='Crop2D')(y)
        self.G = Model(x, y)
        self.G.summary()
        return self.G

    def discriminator_model(self):
        """Return the discriminator compiled on its own for the real-vs-fake
        step (binary cross-entropy, RMSprop, accuracy metric)."""
        if self.DM is not None:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,
                        metrics=['accuracy'])
        return self.DM

    def adversarial_model(self):
        """Return generator + discriminator stacked and compiled for the
        generator's training step.

        The discriminator is frozen (``trainable = False``) before this model
        is compiled, so ``train_on_batch`` on it updates only the generator.
        Previously the discriminator was also trained here, toward labelling
        fakes as real, which works against its own training step.
        """
        if self.AM is not None:
            return self.AM
        # Keras fixes a model's trainable weights when it is compiled.
        # Compile the standalone discriminator model first, while D is still
        # trainable, so that model keeps updating D after D is frozen below.
        self.discriminator_model()
        discriminator = self.discriminator()
        discriminator.trainable = False

        optimizer = RMSprop(lr=0.0001, decay=3e-8)
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(discriminator)
        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,
                        metrics=['accuracy'])
        return self.AM


In [7]:
class MNIST_DCGAN(object):
    """Trains a ``PG_GAN`` on the MNIST training images and plots samples.

    Real images are rescaled from uint8 [0, 255] to float32 [-1, 1] so they
    share the value range of the generator's tanh output.
    """

    LATENT_DIM = 100  # length of the noise vectors fed to the generator

    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1

        (x_train, _), (_, _) = mnist.load_data()
        # Map uint8 pixels [0, 255] to [-1, 1] to match the generator's tanh
        # output. Without this the discriminator can tell real from fake by
        # value range alone, and the generator receives no useful gradient.
        self.x_train = x_train.reshape(
            -1, self.img_rows, self.img_cols, self.channel
        ).astype(np.float32) / 127.5 - 1.0

        self.PGGAN = PG_GAN(img_rows=self.img_rows, img_cols=self.img_cols,
                            channel=self.channel)
        self.discriminator = self.PGGAN.discriminator_model()
        self.adversarial = self.PGGAN.adversarial_model()
        self.generator = self.PGGAN.generator()

    def _noise(self, n):
        """Draw ``n`` latent vectors uniformly from [-1, 1)."""
        return np.random.uniform(-1.0, 1.0, size=[n, self.LATENT_DIM])

    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        """Alternate one discriminator step and one adversarial step per iteration.

        Args:
            train_steps: number of iterations.
            batch_size: real images (and fakes) per discriminator step.
            save_interval: if > 0, every ``save_interval`` steps save a grid
                generated from a fixed set of 16 noise vectors.
        """
        noise_input = self._noise(16) if save_interval > 0 else None
        # Discriminator batch is [real; fake], labelled 1 then 0.
        d_labels = np.concatenate([np.ones([batch_size, 1]),
                                   np.zeros([batch_size, 1])])
        # The adversarial step asks the (frozen) discriminator to call fakes real.
        a_labels = np.ones([batch_size, 1])

        for i in range(train_steps):
            idx = np.random.randint(0, self.x_train.shape[0], size=batch_size)
            images_train = self.x_train[idx, :, :, :]
            images_fake = self.generator.predict(self._noise(batch_size))
            x = np.concatenate((images_train, images_fake))
            d_loss = self.discriminator.train_on_batch(x, d_labels)

            a_loss = self.adversarial.train_on_batch(self._noise(batch_size), a_labels)
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)
            if save_interval > 0 and (i + 1) % save_interval == 0:
                self.plot_images(save2file=True, samples=noise_input.shape[0],
                                 noise=noise_input, step=(i + 1))

    def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):
        """Plot generated (``fake=True``) or real images on a square-ish grid.

        If ``noise`` is given, it is used for generation and the file is named
        ``mnist_<step>.png``. Otherwise ``mnist.png`` is used. The figure is
        saved and closed when ``save2file`` is true, else shown.
        """
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = self._noise(samples)
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            idx = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[idx, :, :, :]

        n_images = images.shape[0]
        # Grid large enough for every image (4x4 for the default 16 samples).
        n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
        n_rows = (n_images + n_cols - 1) // n_cols

        fig = plt.figure(figsize=(10, 10))
        for k in range(n_images):
            plt.subplot(n_rows, n_cols, k + 1)
            image = np.reshape(images[k, :, :, :], [self.img_rows, self.img_cols])
            # Real and generated images both live in [-1, 1]; fix the colour
            # scale so every panel is rendered on the same grayscale range.
            plt.imshow(image, cmap='gray', vmin=-1.0, vmax=1.0)
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            plt.close(fig)
        else:
            plt.show()

In [8]:
# Run configuration: tune here rather than inside the calls below.
SEED = 0             # seeds NumPy's global RNG (batch sampling, noise vectors)
TRAIN_STEPS = 2000
BATCH_SIZE = 256
SAVE_INTERVAL = 500  # save a fixed-noise sample grid every N steps

# NOTE(review): Keras weight initialisation and dropout presumably draw from
# the backend (TensorFlow) RNG, which is not seeded here, so runs are only
# partially reproducible.
np.random.seed(SEED)

mnist_dcgan = MNIST_DCGAN()
mnist_dcgan.train(train_steps=TRAIN_STEPS, batch_size=BATCH_SIZE,
                  save_interval=SAVE_INTERVAL)
mnist_dcgan.plot_images(fake=True)
mnist_dcgan.plot_images(fake=False, save2file=True)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        1664      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 7, 7, 128)         0         
__________

KeyboardInterrupt: 