# Bidirectional GAN

Ref.: DONAHUE, Jeff; KRÄHENBÜHL, Philipp; DARRELL, Trevor.  
      Adversarial feature learning. arXiv preprint arXiv:1605.09782, 2016.  
      https://arxiv.org/abs/1605.09782

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization
from keras.layers import concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

Using TensorFlow backend.


# Generator

In [3]:
def build_generator(latent_dim, img_shape):
    """Build the generator: a fully connected net from latent vector to image.

    The stack is two Dense(512) blocks, each followed by LeakyReLU(0.2) and
    BatchNormalization(momentum=0.8), then a tanh Dense layer with
    prod(img_shape) units that is reshaped to ``img_shape``. Because of the
    tanh, every output value lies in [-1, 1].

    Side effect: prints the layer summary of the inner Sequential stack.

    Args:
        latent_dim: length of the input latent vector.
        img_shape: shape tuple of one generated image.

    Returns:
        A Model mapping an Input of shape (latent_dim,) to a tensor of
        shape ``img_shape``.
    """
    n_outputs = np.prod(img_shape)

    net = Sequential([
        Dense(512, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(n_outputs, activation='tanh'),
        Reshape(img_shape),
    ])
    net.summary()

    latent_in = Input(shape=(latent_dim,))
    return Model(latent_in, net(latent_in))

# Discriminator

In [4]:
def build_discriminator(latent_dim, img_shape):
    """Build the discriminator that scores a (latent vector, image) pair.

    The image is flattened and concatenated with the latent vector, then
    passed through three identical blocks of Dense(1024) -> LeakyReLU(0.2)
    -> Dropout(0.5) and a final single-unit sigmoid layer.

    Args:
        latent_dim: length of the latent-vector input.
        img_shape: shape tuple of the image input.

    Returns:
        A Model taking ``[z, img]`` and returning one sigmoid score per
        sample.
    """
    latent_in = Input(shape=(latent_dim,))
    image_in = Input(shape=img_shape)

    # Joint input: latent code followed by the flattened pixels.
    hidden = concatenate([latent_in, Flatten()(image_in)])

    for _ in range(3):
        hidden = Dense(1024)(hidden)
        hidden = LeakyReLU(alpha=0.2)(hidden)
        hidden = Dropout(0.5)(hidden)

    validity = Dense(1, activation="sigmoid")(hidden)

    return Model([latent_in, image_in], validity)

In [5]:
# Encoder

In [6]:
def build_encoder(latent_dim, img_shape):
    """Build the encoder: a fully connected net from image to latent vector.

    The image is flattened, passed through two Dense(512) blocks (each
    followed by LeakyReLU(0.2) and BatchNormalization(momentum=0.8)) and
    projected by a final linear Dense layer with ``latent_dim`` units.

    Side effect: prints the layer summary of the inner Sequential stack.

    Args:
        latent_dim: length of the produced latent vector.
        img_shape: shape tuple of one input image.

    Returns:
        A Model mapping an Input of shape ``img_shape`` to a tensor of
        shape (latent_dim,).
    """
    stack = (
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(latent_dim),
    )

    net = Sequential()
    for layer in stack:
        net.add(layer)
    net.summary()

    image_in = Input(shape=img_shape)
    return Model(image_in, net(image_in))

In [7]:
# auxiliary function

In [25]:
def do_sample_interval(G, latent_dim, epoch):
    """Save a 5x5 grid of freshly generated images to ``images/mnist_<epoch>.png``.

    Draws r*c latent vectors from a standard normal distribution, runs them
    through ``G.predict`` and plots channel 0 of every result in grayscale.
    The ``images`` directory is created if it does not exist yet, so the
    first snapshot of a fresh checkout does not fail in ``savefig``.

    Args:
        G: object with a ``predict(z)`` method returning an array indexed as
            (sample, row, col, channel).
        latent_dim: length of each latent vector fed to ``G``.
        epoch: integer used in the output file name.

    Returns:
        None. The figure is written to disk and then closed.
    """
    import os  # local import: keeps the helper self-contained in its cell

    r, c = 5, 5
    out_path = "images/mnist_%d.png" % epoch

    # Make sure the target directory exists before matplotlib tries to write.
    out_dir = os.path.dirname(out_path)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # One latent vector per grid cell (previously a hard-coded 25).
    z = np.random.normal(size=(r * c, latent_dim))
    gen_imgs = G.predict(z)

    # Map values to [0, 1] for display (assumes G outputs lie in [-1, 1]).
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the sample order.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(out_path)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)

In [9]:
# train

In [23]:
def train(G, D, encoder, bigan_generator,
          latent_dim,
          epochs, batch_size=128, sample_interval=50):
    """Alternate discriminator and generator/encoder updates on MNIST.

    Note that an "epoch" here is a single training step: every iteration
    draws one random batch of ``batch_size`` images, not a full pass over
    the dataset.

    Args:
        G: generator; ``G.predict(z)`` must return a batch of images.
        D: compiled discriminator; ``D.train_on_batch([z, img], y)`` must
            return a sequence whose items 0 and 1 are loss and accuracy.
        encoder: ``encoder.predict(imgs)`` must return a batch of latent
            vectors.
        bigan_generator: compiled combined model;
            ``train_on_batch([z, imgs], [y0, y1])`` must return a sequence
            whose item 0 is the total loss.
        latent_dim: length of the sampled latent vectors.
        epochs: number of training steps to run.
        batch_size: number of samples per step.
        sample_interval: save a grid of generated images every this many
            steps. ``0``, ``None`` or a negative value disables the periodic
            snapshots instead of raising; the snapshot after the final step
            is always taken.

    Returns:
        None. Progress is printed once per step.
    """
    # Load the dataset; only the training images are used (labels and the
    # test split are discarded).
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale to [-1, 1] (assumes 8-bit pixel values in [0, 255]) and append
    # a trailing axis at position 3.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # A missing or non-positive interval turns periodic snapshots off and
    # keeps the modulo below from dividing by zero.
    periodic_sampling = bool(sample_interval) and sample_interval > 0

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Sample noise and generate img
        z = np.random.normal(size=(batch_size, latent_dim))
        imgs_ = G.predict(z)

        # Select a random batch of images and encode
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        z_ = encoder.predict(imgs)

        # Train the discriminator (img -> z is valid, z -> img is fake)
        d_loss_real = D.train_on_batch([z_, imgs], valid)
        d_loss_fake = D.train_on_batch([z, imgs_], fake)
        # Element-wise mean of the [loss, accuracy] pairs of both half-steps.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        # Train the generator with flipped targets
        # (z -> img is valid and img -> z is invalid)
        g_loss = bigan_generator.train_on_batch([z, imgs], [valid, fake])

        # Plot the progress
        print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0]))

        # If at save interval (or at the very last step) => save generated image samples
        is_last_step = epoch == epochs - 1
        if is_last_step or (periodic_sampling and epoch % sample_interval == 0):
            do_sample_interval(G, latent_dim, epoch)

# main()

In [11]:
# Image geometry: 28 x 28 pixels with a single channel.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Length of the latent vector z.
latent_dim = 100

In [12]:
# Optimizer shared by the discriminator and the combined model.
# NOTE(review): the two positional arguments are presumably the learning rate
# and beta_1 of the Keras Adam signature - confirm for the installed version.
optimizer = Adam(2e-4, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [13]:
# Discriminator: built and compiled on its own so it can be trained directly
# with D.train_on_batch; accuracy is tracked alongside the loss.
D = build_discriminator(latent_dim=latent_dim, img_shape=img_shape)
D.compile(
    optimizer=optimizer,
    loss=['binary_crossentropy'],
    metrics=['accuracy'],
)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [14]:
# Generator (latent vector -> image); building it prints its layer summary.
G = build_generator(latent_dim=latent_dim, img_shape=img_shape)

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_5 (Dense)              (None, 512)               51712     
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 512)               262656    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_7 (Dense)              (None, 784)              

In [15]:
# Encoder (image -> latent vector); building it prints its layer summary.
encoder = build_encoder(latent_dim=latent_dim, img_shape=img_shape)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 512)               2048      
_________________________________________________________________
dense_9 (Dense)              (None, 512)               262656    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 512)              

In [16]:
# The part of the BiGAN that trains the generator and encoder: the
# discriminator is frozen here so that the combined model compiled further
# down is not meant to update D's weights.
# NOTE(review): this relies on Keras fixing the trainable state at compile
# time, so that D (compiled earlier, while trainable) still learns through
# D.train_on_batch - confirm for the installed Keras version.
D.trainable = False

In [17]:
# Generate image from sampled noise.
# z is the symbolic latent input of the combined model; img_ is the
# generator's output for it.
z = Input(shape=(latent_dim, ))
img_ = G(z)

In [18]:
# Encode image.
# img is the symbolic image input of the combined model; z_ is the
# encoder's latent output for it.
img = Input(shape=img_shape)
z_ = encoder(img)

In [20]:
# Latent -> img is fake, and img -> latent is valid.
# Both pairs are scored by the same discriminator D. The names describe which
# pair is scored, not the training target: train() feeds the combined model
# the targets [valid, fake] for the outputs [fake, valid], i.e. with the
# labels flipped.
fake = D([z, img_])
valid = D([z_, img])

In [21]:
# Combined model: takes a latent vector and a real image and returns the
# discriminator scores for the generated pair and the encoded pair, in the
# order [fake, valid]. Used to train the generator (and encoder) to fool the
# discriminator; each output gets its own binary cross-entropy loss.
bigan_generator = Model(inputs=[z, img], outputs=[fake, valid])
bigan_generator.compile(
    optimizer=optimizer,
    loss=['binary_crossentropy', 'binary_crossentropy'],
)

## train

In [26]:
# Number of single-batch training steps for this run
# (a longer setting of 40000 was noted here as an alternative).
epochs = 4000
train(
    G, D, encoder, bigan_generator, latent_dim,
    epochs=epochs,
    batch_size=32,
    sample_interval=400,
)

0 [D loss: 0.266661, acc: 90.62%] [G loss: 5.525970]
1 [D loss: 0.159406, acc: 100.00%] [G loss: 5.670177]
2 [D loss: 0.110983, acc: 98.44%] [G loss: 6.560831]
3 [D loss: 0.070067, acc: 100.00%] [G loss: 7.139620]
4 [D loss: 0.056305, acc: 100.00%] [G loss: 8.750211]
5 [D loss: 0.043628, acc: 100.00%] [G loss: 9.320330]
6 [D loss: 0.033754, acc: 100.00%] [G loss: 10.177339]
7 [D loss: 0.022928, acc: 100.00%] [G loss: 10.352518]
8 [D loss: 0.012991, acc: 100.00%] [G loss: 10.128178]
9 [D loss: 0.015250, acc: 100.00%] [G loss: 11.003234]
10 [D loss: 0.023476, acc: 100.00%] [G loss: 11.932797]
11 [D loss: 0.009466, acc: 100.00%] [G loss: 12.238003]
12 [D loss: 0.008053, acc: 100.00%] [G loss: 12.639989]
13 [D loss: 0.018690, acc: 98.44%] [G loss: 13.858693]
14 [D loss: 0.009253, acc: 100.00%] [G loss: 13.544466]
15 [D loss: 0.011066, acc: 100.00%] [G loss: 13.547765]
16 [D loss: 0.011858, acc: 100.00%] [G loss: 13.203274]
17 [D loss: 0.003389, acc: 100.00%] [G loss: 13.329329]
18 [D loss:

148 [D loss: 0.192493, acc: 90.62%] [G loss: 10.716168]
149 [D loss: 0.335824, acc: 84.38%] [G loss: 8.934477]
150 [D loss: 0.347498, acc: 79.69%] [G loss: 9.117983]
151 [D loss: 0.400675, acc: 84.38%] [G loss: 10.401075]
152 [D loss: 0.547935, acc: 82.81%] [G loss: 10.115701]
153 [D loss: 0.266802, acc: 79.69%] [G loss: 8.113564]
154 [D loss: 0.147362, acc: 96.88%] [G loss: 7.416132]
155 [D loss: 0.517420, acc: 75.00%] [G loss: 11.461792]
156 [D loss: 0.139759, acc: 95.31%] [G loss: 8.939787]
157 [D loss: 0.143692, acc: 96.88%] [G loss: 7.900506]
158 [D loss: 0.407594, acc: 81.25%] [G loss: 9.772662]
159 [D loss: 0.255236, acc: 90.62%] [G loss: 10.083895]
160 [D loss: 0.134348, acc: 93.75%] [G loss: 7.129897]
161 [D loss: 0.795474, acc: 64.06%] [G loss: 13.799948]
162 [D loss: 0.452014, acc: 75.00%] [G loss: 9.877586]
163 [D loss: 0.173852, acc: 92.19%] [G loss: 6.716862]
164 [D loss: 0.893863, acc: 60.94%] [G loss: 13.561461]
165 [D loss: 0.393881, acc: 82.81%] [G loss: 10.650093]
16

296 [D loss: 0.541749, acc: 70.31%] [G loss: 6.204755]
297 [D loss: 0.231997, acc: 87.50%] [G loss: 4.724100]
298 [D loss: 0.442047, acc: 71.88%] [G loss: 6.234987]
299 [D loss: 0.212319, acc: 96.88%] [G loss: 4.316307]
300 [D loss: 0.424617, acc: 79.69%] [G loss: 6.014455]
301 [D loss: 0.306377, acc: 85.94%] [G loss: 5.439369]
302 [D loss: 0.386179, acc: 84.38%] [G loss: 5.444127]
303 [D loss: 0.491338, acc: 73.44%] [G loss: 5.815137]
304 [D loss: 0.218338, acc: 92.19%] [G loss: 4.782976]
305 [D loss: 0.811746, acc: 48.44%] [G loss: 5.988901]
306 [D loss: 0.451243, acc: 73.44%] [G loss: 4.309905]
307 [D loss: 0.279376, acc: 84.38%] [G loss: 4.237712]
308 [D loss: 0.497834, acc: 70.31%] [G loss: 5.571743]
309 [D loss: 0.283938, acc: 89.06%] [G loss: 6.079249]
310 [D loss: 0.190767, acc: 93.75%] [G loss: 4.653200]
311 [D loss: 0.280417, acc: 85.94%] [G loss: 4.972998]
312 [D loss: 0.347642, acc: 87.50%] [G loss: 4.327041]
313 [D loss: 0.227946, acc: 93.75%] [G loss: 4.803055]
314 [D los

445 [D loss: 0.334358, acc: 90.62%] [G loss: 4.970905]
446 [D loss: 0.668104, acc: 60.94%] [G loss: 4.842405]
447 [D loss: 0.284389, acc: 87.50%] [G loss: 4.486209]
448 [D loss: 0.236400, acc: 93.75%] [G loss: 4.531005]
449 [D loss: 0.618416, acc: 71.88%] [G loss: 4.921111]
450 [D loss: 0.292003, acc: 93.75%] [G loss: 4.437847]
451 [D loss: 0.456919, acc: 76.56%] [G loss: 4.689266]
452 [D loss: 0.257107, acc: 92.19%] [G loss: 4.580830]
453 [D loss: 0.304183, acc: 92.19%] [G loss: 4.769463]
454 [D loss: 0.293006, acc: 92.19%] [G loss: 4.421643]
455 [D loss: 0.461784, acc: 78.12%] [G loss: 4.234420]
456 [D loss: 0.142510, acc: 100.00%] [G loss: 4.290384]
457 [D loss: 0.470489, acc: 75.00%] [G loss: 5.108297]
458 [D loss: 0.224818, acc: 96.88%] [G loss: 3.720759]
459 [D loss: 0.751809, acc: 64.06%] [G loss: 4.101781]
460 [D loss: 0.167741, acc: 95.31%] [G loss: 4.347928]
461 [D loss: 0.408954, acc: 84.38%] [G loss: 5.134630]
462 [D loss: 0.352484, acc: 84.38%] [G loss: 4.617553]
463 [D lo

595 [D loss: 0.503360, acc: 75.00%] [G loss: 4.142297]
596 [D loss: 0.415188, acc: 85.94%] [G loss: 3.597692]
597 [D loss: 0.376872, acc: 87.50%] [G loss: 3.771755]
598 [D loss: 0.521124, acc: 73.44%] [G loss: 3.532707]
599 [D loss: 0.412000, acc: 79.69%] [G loss: 4.125612]
600 [D loss: 0.427861, acc: 79.69%] [G loss: 3.886374]
601 [D loss: 0.291729, acc: 90.62%] [G loss: 3.689352]
602 [D loss: 0.681352, acc: 68.75%] [G loss: 4.274494]
603 [D loss: 0.312471, acc: 85.94%] [G loss: 4.506380]
604 [D loss: 0.472190, acc: 76.56%] [G loss: 4.066715]
605 [D loss: 0.356869, acc: 85.94%] [G loss: 3.851837]
606 [D loss: 0.449578, acc: 75.00%] [G loss: 4.336165]
607 [D loss: 0.371600, acc: 84.38%] [G loss: 4.078756]
608 [D loss: 0.516854, acc: 75.00%] [G loss: 3.971596]
609 [D loss: 0.398160, acc: 78.12%] [G loss: 3.719134]
610 [D loss: 0.361815, acc: 82.81%] [G loss: 3.873234]
611 [D loss: 0.595503, acc: 75.00%] [G loss: 4.006230]
612 [D loss: 0.509030, acc: 75.00%] [G loss: 4.299947]
613 [D los

745 [D loss: 0.565713, acc: 73.44%] [G loss: 3.241807]
746 [D loss: 0.466891, acc: 75.00%] [G loss: 3.158059]
747 [D loss: 0.469696, acc: 84.38%] [G loss: 3.230169]
748 [D loss: 0.512032, acc: 73.44%] [G loss: 2.676220]
749 [D loss: 0.587824, acc: 59.38%] [G loss: 3.556659]
750 [D loss: 0.499714, acc: 71.88%] [G loss: 3.265871]
751 [D loss: 0.732827, acc: 59.38%] [G loss: 2.957512]
752 [D loss: 0.462601, acc: 76.56%] [G loss: 3.197873]
753 [D loss: 0.562937, acc: 65.62%] [G loss: 2.950189]
754 [D loss: 0.528406, acc: 71.88%] [G loss: 3.210142]
755 [D loss: 0.436181, acc: 79.69%] [G loss: 3.190323]
756 [D loss: 0.760725, acc: 59.38%] [G loss: 3.425975]
757 [D loss: 0.510606, acc: 68.75%] [G loss: 3.286462]
758 [D loss: 0.518351, acc: 78.12%] [G loss: 3.347046]
759 [D loss: 0.518854, acc: 76.56%] [G loss: 3.316479]
760 [D loss: 0.582839, acc: 73.44%] [G loss: 3.174687]
761 [D loss: 0.523247, acc: 71.88%] [G loss: 2.948381]
762 [D loss: 0.483944, acc: 76.56%] [G loss: 3.152718]
763 [D los

894 [D loss: 0.540031, acc: 65.62%] [G loss: 3.055088]
895 [D loss: 0.580447, acc: 70.31%] [G loss: 3.025200]
896 [D loss: 0.530810, acc: 81.25%] [G loss: 3.122701]
897 [D loss: 0.724169, acc: 60.94%] [G loss: 2.956584]
898 [D loss: 0.459320, acc: 79.69%] [G loss: 2.727070]
899 [D loss: 0.607559, acc: 68.75%] [G loss: 2.831520]
900 [D loss: 0.428687, acc: 81.25%] [G loss: 3.249316]
901 [D loss: 0.445381, acc: 79.69%] [G loss: 2.828544]
902 [D loss: 0.522684, acc: 75.00%] [G loss: 3.000340]
903 [D loss: 0.511234, acc: 73.44%] [G loss: 3.043961]
904 [D loss: 0.562486, acc: 68.75%] [G loss: 3.139113]
905 [D loss: 0.654504, acc: 60.94%] [G loss: 2.750283]
906 [D loss: 0.621668, acc: 62.50%] [G loss: 2.902838]
907 [D loss: 0.487829, acc: 75.00%] [G loss: 2.879560]
908 [D loss: 0.509939, acc: 75.00%] [G loss: 2.874039]
909 [D loss: 0.465119, acc: 81.25%] [G loss: 3.209242]
910 [D loss: 0.440158, acc: 84.38%] [G loss: 3.235978]
911 [D loss: 0.524613, acc: 71.88%] [G loss: 2.839190]
912 [D los

1043 [D loss: 0.545526, acc: 71.88%] [G loss: 2.582049]
1044 [D loss: 0.567107, acc: 71.88%] [G loss: 3.002445]
1045 [D loss: 0.537011, acc: 76.56%] [G loss: 2.691355]
1046 [D loss: 0.609517, acc: 64.06%] [G loss: 2.825548]
1047 [D loss: 0.551482, acc: 73.44%] [G loss: 2.701110]
1048 [D loss: 0.583902, acc: 65.62%] [G loss: 2.966018]
1049 [D loss: 0.693659, acc: 64.06%] [G loss: 2.861208]
1050 [D loss: 0.677051, acc: 60.94%] [G loss: 2.780712]
1051 [D loss: 0.528837, acc: 76.56%] [G loss: 2.534616]
1052 [D loss: 0.596487, acc: 67.19%] [G loss: 2.897629]
1053 [D loss: 0.633004, acc: 64.06%] [G loss: 2.638061]
1054 [D loss: 0.698757, acc: 62.50%] [G loss: 3.003698]
1055 [D loss: 0.534305, acc: 71.88%] [G loss: 3.012428]
1056 [D loss: 0.588122, acc: 71.88%] [G loss: 2.553753]
1057 [D loss: 0.506948, acc: 70.31%] [G loss: 2.799271]
1058 [D loss: 0.505193, acc: 75.00%] [G loss: 2.639468]
1059 [D loss: 0.605672, acc: 65.62%] [G loss: 2.635161]
1060 [D loss: 0.593327, acc: 67.19%] [G loss: 2.

1190 [D loss: 0.590061, acc: 70.31%] [G loss: 2.597369]
1191 [D loss: 0.616155, acc: 67.19%] [G loss: 2.664262]
1192 [D loss: 0.550600, acc: 68.75%] [G loss: 2.657630]
1193 [D loss: 0.590390, acc: 70.31%] [G loss: 3.025695]
1194 [D loss: 0.639629, acc: 62.50%] [G loss: 2.686191]
1195 [D loss: 0.605298, acc: 67.19%] [G loss: 2.896614]
1196 [D loss: 0.561170, acc: 70.31%] [G loss: 2.746134]
1197 [D loss: 0.621627, acc: 59.38%] [G loss: 2.842017]
1198 [D loss: 0.553435, acc: 71.88%] [G loss: 2.754693]
1199 [D loss: 0.730080, acc: 57.81%] [G loss: 2.392956]
1200 [D loss: 0.609166, acc: 68.75%] [G loss: 2.545686]
1201 [D loss: 0.614706, acc: 64.06%] [G loss: 2.739901]
1202 [D loss: 0.574839, acc: 65.62%] [G loss: 2.325250]
1203 [D loss: 0.535075, acc: 75.00%] [G loss: 2.822335]
1204 [D loss: 0.618724, acc: 62.50%] [G loss: 2.250563]
1205 [D loss: 0.473771, acc: 78.12%] [G loss: 2.527232]
1206 [D loss: 0.503397, acc: 71.88%] [G loss: 2.422122]
1207 [D loss: 0.629589, acc: 70.31%] [G loss: 2.

1337 [D loss: 0.525431, acc: 73.44%] [G loss: 2.759533]
1338 [D loss: 0.762407, acc: 48.44%] [G loss: 2.364902]
1339 [D loss: 0.712198, acc: 60.94%] [G loss: 2.553585]
1340 [D loss: 0.637168, acc: 68.75%] [G loss: 2.812963]
1341 [D loss: 0.539008, acc: 70.31%] [G loss: 2.685045]
1342 [D loss: 0.631792, acc: 54.69%] [G loss: 2.502464]
1343 [D loss: 0.678346, acc: 48.44%] [G loss: 2.427691]
1344 [D loss: 0.554602, acc: 68.75%] [G loss: 2.352800]
1345 [D loss: 0.566754, acc: 71.88%] [G loss: 2.665856]
1346 [D loss: 0.598624, acc: 62.50%] [G loss: 2.880019]
1347 [D loss: 0.622700, acc: 60.94%] [G loss: 2.600010]
1348 [D loss: 0.504611, acc: 79.69%] [G loss: 2.525847]
1349 [D loss: 0.642503, acc: 60.94%] [G loss: 2.372295]
1350 [D loss: 0.659768, acc: 67.19%] [G loss: 2.612552]
1351 [D loss: 0.492270, acc: 79.69%] [G loss: 2.618181]
1352 [D loss: 0.480834, acc: 81.25%] [G loss: 2.582860]
1353 [D loss: 0.588942, acc: 68.75%] [G loss: 3.100073]
1354 [D loss: 0.593110, acc: 67.19%] [G loss: 2.

1485 [D loss: 0.590946, acc: 67.19%] [G loss: 2.824907]
1486 [D loss: 0.520438, acc: 71.88%] [G loss: 2.821422]
1487 [D loss: 0.620692, acc: 70.31%] [G loss: 2.617889]
1488 [D loss: 0.647348, acc: 57.81%] [G loss: 2.544348]
1489 [D loss: 0.503395, acc: 73.44%] [G loss: 2.890251]
1490 [D loss: 0.512752, acc: 78.12%] [G loss: 2.832696]
1491 [D loss: 0.524353, acc: 68.75%] [G loss: 3.208224]
1492 [D loss: 0.732687, acc: 64.06%] [G loss: 2.823287]
1493 [D loss: 0.487255, acc: 76.56%] [G loss: 2.552429]
1494 [D loss: 0.694256, acc: 62.50%] [G loss: 2.826921]
1495 [D loss: 0.638325, acc: 67.19%] [G loss: 2.554887]
1496 [D loss: 0.638508, acc: 65.62%] [G loss: 2.826968]
1497 [D loss: 0.699567, acc: 57.81%] [G loss: 2.530822]
1498 [D loss: 0.551792, acc: 71.88%] [G loss: 2.976685]
1499 [D loss: 0.602603, acc: 67.19%] [G loss: 2.334286]
1500 [D loss: 0.589691, acc: 67.19%] [G loss: 2.960347]
1501 [D loss: 0.671550, acc: 62.50%] [G loss: 2.550368]
1502 [D loss: 0.676243, acc: 57.81%] [G loss: 2.

KeyboardInterrupt: 