# Boundary-Seeking Generative Adversarial Networks
### BGAN

Ref. HJELM, R. Devon et al. Boundary-seeking generative adversarial networks. 
     arXiv preprint arXiv:1702.08431, 2017.
     https://arxiv.org/abs/1702.08431

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K

Using TensorFlow backend.


# Generator

In [3]:
def build_generator(latent_dim, img_shape):

    model = Sequential()

    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))

    model.summary()

    noise = Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)

# Discriminator

In [4]:
def build_discriminator(img_shape):

    model = Sequential()

    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)

# helper functions

In [5]:
def boundary_loss(y_true, y_pred):
    """
    Boundary seeking loss.
    Reference: https://wiseodd.github.io/techblog/2017/03/07/boundary-seeking-gan/
    """
    return 0.5 * K.mean((K.log(y_pred) - K.log(1 - y_pred))**2)

In [6]:
def sample_images(epoch, latent_dim, G):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = G.predict(noise)
    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close()

# training the model

In [7]:
def train(G, D, combined, latent_dim, epochs, batch_size=128, sample_interval=50):

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        # Generate a batch of new images
        gen_imgs = G.predict(noise)

        # Train the discriminator
        d_loss_real = D.train_on_batch(imgs, valid)
        d_loss_fake = D.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = combined.train_on_batch(noise, valid)

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            sample_images(epoch, latent_dim, G)

# main()

In [8]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

In [9]:
# create optimizer
optimizer = Adam(0.0002, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [10]:
# Build and compile the discriminator
D = build_discriminator(img_shape)
D.compile(loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'])

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
________________________________________________

In [11]:
# Build the generator
G = build_generator(latent_dim, img_shape)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 1024)             

In [12]:
# The generator takes noise as input and generated imgs
z = Input(shape=(latent_dim,))
img = G(z)

In [13]:
# For the combined model we will only train the generator
D.trainable = False

In [14]:
# The valid takes generated images as input and determines validity
valid = D(img)

In [15]:
# The combined model  (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid)
combined.compile(loss=boundary_loss, optimizer=optimizer)

## run

there is some bug in the code, because G loss goes to inf in iteration 143

In [16]:
epochs=3000
# epochs=30000
train(G, D, combined, latent_dim, epochs=epochs, batch_size=32, sample_interval=200)




  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.495093, acc.: 64.06%] [G loss: 0.118298]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.308667, acc.: 98.44%] [G loss: 0.129122]
2 [D loss: 0.281581, acc.: 96.88%] [G loss: 0.207400]
3 [D loss: 0.273255, acc.: 96.88%] [G loss: 0.433630]
4 [D loss: 0.273254, acc.: 87.50%] [G loss: 0.627098]
5 [D loss: 0.230863, acc.: 96.88%] [G loss: 0.885139]
6 [D loss: 0.163481, acc.: 100.00%] [G loss: 1.453845]
7 [D loss: 0.137503, acc.: 100.00%] [G loss: 1.909683]
8 [D loss: 0.103650, acc.: 100.00%] [G loss: 2.438941]
9 [D loss: 0.097468, acc.: 100.00%] [G loss: 2.603986]
10 [D loss: 0.074304, acc.: 100.00%] [G loss: 3.110146]
11 [D loss: 0.062242, acc.: 100.00%] [G loss: 3.676673]
12 [D loss: 0.053192, acc.: 100.00%] [G loss: 4.028837]
13 [D loss: 0.052831, acc.: 100.00%] [G loss: 3.959708]
14 [D loss: 0.043390, acc.: 100.00%] [G loss: 4.470638]
15 [D loss: 0.042539, acc.: 100.00%] [G loss: 4.747778]
16 [D loss: 0.038933, acc.: 100.00%] [G loss: 5.010965]
17 [D loss: 0.029929, acc.: 100.00%] [G loss: 5.738226]
18 [D loss: 0.033720, acc.: 100.00%] [G loss: 5.467397]
19 [D 

146 [D loss: nan, acc.: 0.00%] [G loss: nan]
147 [D loss: nan, acc.: 0.00%] [G loss: nan]
148 [D loss: nan, acc.: 0.00%] [G loss: nan]
149 [D loss: nan, acc.: 0.00%] [G loss: nan]
150 [D loss: nan, acc.: 0.00%] [G loss: nan]
151 [D loss: nan, acc.: 0.00%] [G loss: nan]
152 [D loss: nan, acc.: 0.00%] [G loss: nan]
153 [D loss: nan, acc.: 0.00%] [G loss: nan]
154 [D loss: nan, acc.: 0.00%] [G loss: nan]
155 [D loss: nan, acc.: 0.00%] [G loss: nan]
156 [D loss: nan, acc.: 0.00%] [G loss: nan]
157 [D loss: nan, acc.: 0.00%] [G loss: nan]
158 [D loss: nan, acc.: 0.00%] [G loss: nan]
159 [D loss: nan, acc.: 0.00%] [G loss: nan]
160 [D loss: nan, acc.: 0.00%] [G loss: nan]
161 [D loss: nan, acc.: 0.00%] [G loss: nan]
162 [D loss: nan, acc.: 0.00%] [G loss: nan]
163 [D loss: nan, acc.: 0.00%] [G loss: nan]
164 [D loss: nan, acc.: 0.00%] [G loss: nan]
165 [D loss: nan, acc.: 0.00%] [G loss: nan]
166 [D loss: nan, acc.: 0.00%] [G loss: nan]
167 [D loss: nan, acc.: 0.00%] [G loss: nan]
168 [D los

  dv = (np.float64(self.norm.vmax) -
  np.float64(self.norm.vmin))
  a_min = np.float64(newmin)
  a_max = np.float64(newmax)
  data = np.array(a, copy=False, subok=subok)


201 [D loss: nan, acc.: 0.00%] [G loss: nan]
202 [D loss: nan, acc.: 0.00%] [G loss: nan]
203 [D loss: nan, acc.: 0.00%] [G loss: nan]
204 [D loss: nan, acc.: 0.00%] [G loss: nan]
205 [D loss: nan, acc.: 0.00%] [G loss: nan]
206 [D loss: nan, acc.: 0.00%] [G loss: nan]
207 [D loss: nan, acc.: 0.00%] [G loss: nan]
208 [D loss: nan, acc.: 0.00%] [G loss: nan]
209 [D loss: nan, acc.: 0.00%] [G loss: nan]
210 [D loss: nan, acc.: 0.00%] [G loss: nan]
211 [D loss: nan, acc.: 0.00%] [G loss: nan]
212 [D loss: nan, acc.: 0.00%] [G loss: nan]
213 [D loss: nan, acc.: 0.00%] [G loss: nan]
214 [D loss: nan, acc.: 0.00%] [G loss: nan]
215 [D loss: nan, acc.: 0.00%] [G loss: nan]
216 [D loss: nan, acc.: 0.00%] [G loss: nan]
217 [D loss: nan, acc.: 0.00%] [G loss: nan]
218 [D loss: nan, acc.: 0.00%] [G loss: nan]
219 [D loss: nan, acc.: 0.00%] [G loss: nan]
220 [D loss: nan, acc.: 0.00%] [G loss: nan]
221 [D loss: nan, acc.: 0.00%] [G loss: nan]
222 [D loss: nan, acc.: 0.00%] [G loss: nan]
223 [D los

KeyboardInterrupt: 