In [4]:
import os

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Dropout, Activation, Input, Reshape
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, ZeroPadding2D, MaxPooling2D, UpSampling2D
from keras.layers.embeddings import Embedding
from keras import optimizers, losses

import keras.backend as K

from keras.datasets import mnist

In [21]:
class BGAN:
    """Boundary-seeking GAN with fully connected networks, trained on MNIST.

    Three Keras models are built in ``__init__``:

    * ``self.discriminator`` -- image -> probability of being real, compiled
      with binary cross-entropy.
    * ``self.generator`` -- latent vector -> image in ``[-1, 1]`` (tanh).
    * ``self.combined`` -- generator followed by the discriminator, compiled
      with :meth:`boundary_loss`; used to update the generator.
    """

    def __init__(self):
        img_rows = 28
        img_cols = 28
        img_channel = 1
        self.img_shape = (img_rows, img_cols, img_channel)
        self.latent_dim = 100

        optimizer = optimizers.Adam(lr=0.0002, beta_1=0.5, beta_2=0.999)

        # Build and compile the discriminator while it is still trainable.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(
            loss=losses.binary_crossentropy,
            optimizer=optimizer,
            metrics=['accuracy'],
        )

        # Build the generator.
        self.generator = self.build_generator()

        z = Input(shape=(self.latent_dim, ))
        img = self.generator(z)

        # For the combined model we only want to train the generator.
        # NOTE(review): flipping `trainable` after the discriminator has been
        # compiled is the intended pattern here and is presumably what emits
        # the "Discrepancy between trainable weights ..." warning seen when
        # training; the discriminator keeps learning through its own
        # train_on_batch calls (its accuracy rises in the training log).
        self.discriminator.trainable = False

        valid = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator.
        self.combined = Model(inputs=z, outputs=valid)
        self.combined.compile(loss=self.boundary_loss, optimizer=optimizer)

    def build_generator(self):
        """Build the generator: latent vector -> image of ``self.img_shape``.

        Three Dense blocks (256, 512, 1024 units), each followed by
        LeakyReLU(0.2) and BatchNormalization(momentum=0.8), then a tanh
        Dense layer reshaped to the image shape.

        Returns:
            An uncompiled ``keras.models.Model``.
        """
        z = Input(shape=(self.latent_dim, ))
        output = z
        for units in (256, 512, 1024):
            output = Dense(units)(output)
            output = LeakyReLU(alpha=0.2)(output)
            output = BatchNormalization(momentum=0.8)(output)
        # tanh keeps pixels in [-1, 1], matching the rescaling done in train().
        output = Dense(np.prod(self.img_shape), activation='tanh')(output)
        output = Reshape(target_shape=self.img_shape)(output)

        return Model(z, output)

    def build_discriminator(self):
        """Build the discriminator: image -> sigmoid score in ``[0, 1]``.

        Returns:
            An uncompiled ``keras.models.Model``.
        """
        img = Input(shape=self.img_shape)
        output = Flatten()(img)
        for units in (512, 256):
            output = Dense(units)(output)
            output = LeakyReLU(alpha=0.2)(output)
        output = Dense(1, activation='sigmoid')(output)

        return Model(img, output)

    def boundary_loss(self, y_true, y_pred):
        """
        Boundary seeking loss.
        Reference: https://wiseodd.github.io/techblog/2017/03/07/boundary-seeking-gan/

        Computes ``0.5 * mean((log(D) - log(1 - D)) ** 2)`` where ``D`` is
        ``y_pred``, the discriminator output on generated images.

        Args:
            y_true: Unused; present only to satisfy the Keras loss signature.
            y_pred: Discriminator probabilities.

        Returns:
            Scalar loss tensor.
        """
        # A saturated sigmoid yields exactly 0 or 1, making one of the logs
        # -inf and the loss inf/NaN. Clip to [eps, 1 - eps] to stay finite.
        eps = K.epsilon()
        y_pred = K.clip(y_pred, eps, 1.0 - eps)
        return 0.5 * K.mean((K.log(y_pred) - K.log(1.0 - y_pred)) ** 2)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates on MNIST.

        Args:
            epochs: Number of training iterations. Each iteration uses one
                randomly drawn batch, not a full pass over the data.
            batch_size: Number of real (and generated) images per iteration.
            sample_interval: Save a grid of generated samples every this many
                iterations. A falsy value (0 / None) disables sampling.
        """
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale pixels from [0, 255] to [-1, 1] and add the channel axis.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=-1)

        # Discriminator targets: 0 for generated images, 1 for real ones.
        fake = np.zeros((batch_size, 1))
        valid = np.ones((batch_size, 1))

        for epoch in range(epochs):
            # Sample a random batch of real images (with replacement).
            idx = np.random.randint(0, len(X_train), size=batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(size=(batch_size, self.latent_dim))
            fake_imgs = self.generator.predict(noise)

            # Train the discriminator on real and generated batches separately.
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator through the combined model.
            g_loss = self.combined.train_on_batch(noise, valid)

            # Report progress.
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Guard against sample_interval == 0 (modulo by zero).
            if sample_interval and epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated images to ``images/mnist_<epoch>.png``.

        Args:
            epoch: Iteration number, used in the output file name.
        """
        r, c = 5, 5

        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale images from [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        # savefig does not create missing directories; make sure it exists.
        if not os.path.isdir("images"):
            os.makedirs("images")

        fig, axes = plt.subplots(nrows=r, ncols=c)
        # axes.flat walks the grid row by row, matching the sample order.
        for cnt, ax in enumerate(axes.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig("images/mnist_%d.png" % epoch)
        # Close this specific figure so figures do not pile up during training.
        plt.close(fig)

In [22]:
# Instantiate the BGAN: builds and compiles the discriminator, the generator,
# and the combined generator->discriminator model (see BGAN.__init__).
bgan = BGAN()

In [23]:
# Run 30000 single-batch training iterations with 32 images per batch,
# writing a grid of generated samples every 200 iterations.
bgan.train(
    epochs=30000,
    batch_size=32,
    sample_interval=200,
)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.551685, acc.: 70.31%] [G loss: 0.227170]
1 [D loss: 0.479946, acc.: 59.38%] [G loss: 0.072888]
2 [D loss: 0.336067, acc.: 79.69%] [G loss: 0.138141]
3 [D loss: 0.258302, acc.: 90.62%] [G loss: 0.397719]
4 [D loss: 0.243799, acc.: 95.31%] [G loss: 0.582942]
5 [D loss: 0.190765, acc.: 98.44%] [G loss: 0.848536]
6 [D loss: 0.158867, acc.: 100.00%] [G loss: 0.913776]
7 [D loss: 0.146574, acc.: 100.00%] [G loss: 1.236303]
8 [D loss: 0.140402, acc.: 100.00%] [G loss: 1.282410]
9 [D loss: 0.119493, acc.: 100.00%] [G loss: 1.613551]
10 [D loss: 0.107253, acc.: 100.00%] [G loss: 1.832155]
11 [D loss: 0.097450, acc.: 100.00%] [G loss: 1.926739]
12 [D loss: 0.082466, acc.: 100.00%] [G loss: 2.344854]
13 [D loss: 0.089764, acc.: 100.00%] [G loss: 2.208390]
14 [D loss: 0.066746, acc.: 100.00%] [G loss: 2.381065]
15 [D loss: 0.064365, acc.: 100.00%] [G loss: 2.511520]
16 [D loss: 0.071220, acc.: 100.00%] [G loss: 2.741194]
17 [D loss: 0.069003, acc.: 100.00%] [G loss: 2.840228]
18 [D lo

KeyboardInterrupt: 