In [1]:
import numpy as np
import time
from itertools import product
from keras.datasets import mnist

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Reshape, Input, Lambda
from keras.layers import Conv2D, Conv2DTranspose, Cropping2D
from keras.layers import LeakyReLU, Dropout, UpSampling2D, AveragePooling2D
from keras.layers import BatchNormalization
from keras.optimizers import Adam, RMSprop
from keras.models import Model
from keras.backend import int_shape, tf
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [18]:
class PG_GAN(object):
    def __init__(self, img_rows=28, img_cols=28, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model
    
    def discriminator(self):
        
        if self.D:
            return self.D
        
        depth = 64
        input_shape = (28, 28, 1)
        
        x = Input(shape=input_shape)
        y = AveragePooling2D(pool_size = 2)(x)
       # y = Conv2D(depth, 4, strides=1, padding = 'same')(y)
       # y = LeakyReLU(alpha=0.2)(y)
       # y = Dropout(.2)(y)
        y = AveragePooling2D(pool_size = 2)(y)
        #y = Conv2D(depth*2, 4, strides=1, padding = 'same')(y)
        #y = LeakyReLU(alpha=0.2)(y)
        #y = Dropout(.2)(y)
        y = AveragePooling2D(pool_size = 2)(y)
        y = Conv2D(depth*4, 4, strides=1, padding = 'same')(y)
        y = LeakyReLU(alpha=0.2)(y)
        y = Dropout(.2)(y)
        
        y = Flatten()(y)
        y = Dense(1,activation = 'sigmoid')(y)
        self.D = Model(x,y)
        self.D.summary()
        return self.D
    
    def generator(self):
        
        if self.G:
            return self.G
        depth = 64
        dim = 4#self.D.layers[-2].input_shape[1]
        x = Input(shape=(100,))
        y = Dense(depth*4*dim*dim,)(x)
        y = Reshape((dim, dim, depth*4))(y)
        y = UpSampling2D(2,interpolation='bilinear')(y)
        y = Conv2D(depth*2, 4, strides=1, padding = 'same')(y)
        y = LeakyReLU(alpha=0.2)(y)
        y = UpSampling2D(2,interpolation='bilinear')(y)
        #y = Conv2D(depth, 4, strides=1, padding = 'same')(y)
        #y = LeakyReLU(alpha=0.2)(y)
        y = UpSampling2D(2,interpolation='bilinear')(y)
        #y = Conv2D(1, 4, strides=1, padding = 'same')(y)
        #y = LeakyReLU(alpha=0.2)(y)

        y = Activation('tanh')(y)
        crop = int((int_shape(y)[1]-self.D.input_shape[1])/2)
        y = Cropping2D(cropping=((crop,crop),(crop,crop)))(y)
        self.G = Model(x,y)
        self.G.summary()
        return self.G

    def discriminator_model(self):
        if self.DM:
            return self.DM
        optimizer = Adam(lr=0.0002,beta_1=0.5, decay=0)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
                    metrics=['binary_accuracy'])
        return self.DM

    def adversarial_model(self):
        if self.AM:
            return self.AM
        optimizer = Adam(lr=0.0002,beta_1=0.5, decay=0)
        self.AM = Sequential()
        self.AM.add(self.generator())
        discriminator = self.discriminator()
        discriminator.trainable=False
        self.AM.add(discriminator)
        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
                metrics=['binary_accuracy'])
        discriminator.trainable=True
        self.AM.summary()
        return self.AM
    
#     def replace_layer(self, model, layer_id, new_layer):

#         layers = [l for l in model.layers]

#         x = layers[0].output
#         for i in range(1, len(layers)):
#             if i == layer_id:
#                 x = new_layer(x)
#             else:
#                 x = layers[i](x)

#         new_model = Model(inputs=layers[0].input, outputs=x)
#         return new_model
    
    def insert_layer(self, model, layer_id, new_layer):

        layers = [l for l in model.layers]

        x = layers[0].output
        for i in range(1, len(layers)):
            if i == layer_id:
                x = new_layer(x)
            x = layers[i](x)

        new_model = Model(inputs=layers[0].input, outputs=x)
        return new_model
    
    def increase_res(self):
        generator = self.generator()
        discriminator = self.discriminator()
        discriminator.layers.pop(0)
        newDInput = Input(shape=(7,7,16))
        newD = Conv2D(49, 4, strides=2, padding = 'same')(newDInput)
        newD = LeakyReLU(alpha=0.2)(newD)
        newDOutput = discriminator(newD)
        discriminator = Model(inputs=newDInput,outputs=newDOutput)
        self.D = discriminator
        self.D.summary()
        
        generator =  self.insert_layer(generator,len(generator.layers)-2,Conv2D(16, 4, strides=1, padding='same'))
        generator =  self.insert_layer(generator,len(generator.layers)-2,LeakyReLU(alpha=0.2))
        generator =  self.insert_layer(generator,len(generator.layers)-2,BatchNormalization(momentum=0.9))
        #crop = int((int_shape(generator.layers[-2].output)[1]-discriminator.input_shape[1])/2)
        #print(discriminator.input_shape,crop)
        generator =  self.replace_layer(generator,len(generator.layers)-1,Cropping2D(cropping=((1,0),(1,0))))
        self.G = generator
        self.G.summary()
        
        optimizer = Adam(lr=0.001,beta_1=0,beta_2=0.99, decay=0)
        self.DM = Sequential()
        self.DM.add(self.D)
        self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
            metrics=['accuracy'])
        
        optimizer = Adam(lr=0.001,beta_1=0,beta_2=0.99, decay=0)
        self.AM = Sequential()
        self.AM.add(self.G)
        discriminator.trainable=False
        self.AM.add(discriminator)
        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
            metrics=['accuracy'])
        self.AM.summary()
        discriminator.trainable=True
        


In [19]:
class MNIST_DCGAN(object):
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1

        (self.x_train, _), (_, _) = mnist.load_data()
        self.x_train = self.x_train.reshape(-1, self.img_rows,\
        	self.img_cols, 1).astype(np.float32)/255*2-1

        self.PGGAN = PG_GAN()
        self.discriminator =  self.PGGAN.discriminator_model()
        self.adversarial = self.PGGAN.adversarial_model()
        self.generator = self.PGGAN.generator()
        self.res_scale=7
    def increase_res(self):
        self.PGGAN.increase_res()
        self.discriminator =  self.PGGAN.discriminator_model()
        self.adversarial = self.PGGAN.adversarial_model()
        self.generator = self.PGGAN.generator()
        self.res_scale=4
    
    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None
        if save_interval>0:
            noise_input = np.random.normal(loc=0., scale=1., size=[16, 100])
        for i in range(train_steps):
            images_train = self.x_train[np.random.randint(0,
                self.x_train.shape[0], size=batch_size), :, :, :]
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.generator.predict(noise)
            x = np.concatenate((images_train, images_fake))
            y = np.ones([2*batch_size, 1])
            y[batch_size:, :] = 0
            d_loss = self.discriminator.train_on_batch(x, y)

            y = np.ones([batch_size, 1])
            noise = np.random.normal(loc=0., scale=1., size=[batch_size, 100])
            a_loss = self.adversarial.train_on_batch(noise, y)
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            if i%50==0:
                print(log_mesg)
            if save_interval>0:
                if (i+1)%save_interval==0:
                    self.plot_images(save2file=True, samples=noise_input.shape[0],\
                        noise=noise_input, step=(i+1))

    def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]
        plt.figure(figsize=(10,10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :,0]
            #image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image)
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()
            
    def add_output_layer(self, model, new_layer):
        x = new_layer(model.layers[-1].output)
        return Model(inputs = model.inputs, outputs = x)
        
    

In [20]:
mnist_dcgan = MNIST_DCGAN()
mnist_dcgan.train(train_steps=2000, batch_size=256, save_interval=500)
mnist_dcgan.plot_images(fake=True)
mnist_dcgan.plot_images(fake=False, save2file=True)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_12 (InputLayer)        (None, 28, 28, 1)         0         
_________________________________________________________________
average_pooling2d_13 (Averag (None, 14, 14, 1)         0         
_________________________________________________________________
average_pooling2d_14 (Averag (None, 7, 7, 1)           0         
_________________________________________________________________
average_pooling2d_15 (Averag (None, 3, 3, 1)           0         
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 3, 3, 256)         4352      
_________________________________________________________________
leaky_re_lu_31 (LeakyReLU)   (None, 3, 3, 256)         0         
_________________________________________________________________
dropout_16 (Dropout)         (None, 3, 3, 256)         0         
__________

ValueError: number of input channels does not match corresponding dimension of filter, 128 != 1