# Adversarial Autoencoder

ref. 

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
# from keras.layers import merge
from keras.layers import Lambda, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K

Using TensorFlow backend.


# Encoder

In [3]:
def build_encoder(img_shape, latent_dim):
    """Build the encoder that maps an image to a sampled latent vector.

    The network flattens the input, applies two Dense(512) + LeakyReLU(0.2)
    stages, and then predicts two ``latent_dim``-sized vectors (``mu`` and
    ``log_var``).  The returned latent code is drawn as
    ``mu + eps * exp(log_var / 2)`` with ``eps`` sampled from a standard
    normal of the same shape as ``mu``.

    Args:
        img_shape: Shape of a single input image (without the batch axis).
        latent_dim: Size of the latent code.

    Returns:
        A Keras ``Model`` taking an image and returning the sampled latent code.
    """
    image_in = Input(shape=img_shape)

    # Two identical fully connected stages on the flattened image.
    hidden = Flatten()(image_in)
    for _ in range(2):
        hidden = Dense(512)(hidden)
        hidden = LeakyReLU(alpha=0.2)(hidden)

    # Both heads read the same hidden representation.
    mu = Dense(latent_dim)(hidden)
    log_var = Dense(latent_dim)(hidden)

    # NOTE: the original code relied on the deprecated keras.layers.merge();
    # a Lambda layer applied to [mu, log_var] performs the same sampling.
    sampler = Lambda(
        lambda pair: pair[0] + K.random_normal(K.shape(pair[0])) * K.exp(pair[1] / 2),
        output_shape=lambda shapes: shapes[0],
    )
    latent_repr = sampler([mu, log_var])

    return Model(image_in, latent_repr)

# Decoder (generator)

In [4]:
def build_decoder(img_shape, latent_dim):
    """Build the decoder that turns a latent vector back into an image.

    Two Dense(512) + LeakyReLU(0.2) stages are followed by a tanh-activated
    Dense layer with ``np.prod(img_shape)`` units, reshaped to ``img_shape``.
    The layer summary is printed as a side effect.

    Args:
        img_shape: Shape of a single output image (without the batch axis).
        latent_dim: Size of the latent input vector.

    Returns:
        A Keras ``Model`` mapping a latent vector to an image.
    """
    stack = Sequential([
        Dense(512, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Dense(512),
        LeakyReLU(alpha=0.2),
        # tanh keeps every output pixel inside [-1, 1].
        Dense(np.prod(img_shape), activation='tanh'),
        Reshape(img_shape),
    ])
    stack.summary()

    # Wrap the stack in a functional Model with an explicit latent input.
    z = Input(shape=(latent_dim,))
    return Model(z, stack(z))

# Discriminator

In [5]:
def build_discriminator(latent_dim):
    """Build the discriminator that scores latent vectors.

    The network is Dense(512) -> LeakyReLU(0.2) -> Dense(256) ->
    LeakyReLU(0.2) -> Dense(1, sigmoid).  The layer summary is printed as a
    side effect.

    Args:
        latent_dim: Size of the latent vector being judged.

    Returns:
        A Keras ``Model`` mapping a latent vector to a single sigmoid output.
    """
    critic = Sequential([
        Dense(512, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation="sigmoid"),
    ])
    critic.summary()

    # Wrap the stack in a functional Model with an explicit latent input.
    encoded_repr = Input(shape=(latent_dim, ))
    return Model(encoded_repr, critic(encoded_repr))

# Save images

In [6]:
def sample_images(epoch, decoder, latent_dim):
    """Decode random latent vectors and save them as a 5x5 image grid.

    The figure is written to ``images/mnist_<epoch>.png``; the ``images``
    directory is created if it does not exist yet.

    Args:
        epoch: Integer used to name the output file.
        decoder: Object with a ``predict`` method mapping an array of shape
            ``(n, latent_dim)`` to images indexable as ``[n, :, :, 0]``.
        latent_dim: Size of each sampled latent vector.
    """
    import os

    r, c = 5, 5

    z = np.random.normal(size=(r * c, latent_dim))
    gen_imgs = decoder.predict(z)

    # Map from [-1, 1] to [0, 1] for display; this assumes the decoder ends
    # in tanh, as the one produced by build_decoder does.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs("images", exist_ok=True)

    fig, axs = plt.subplots(r, c)
    for ax, image in zip(axs.flat, gen_imgs):
        ax.imshow(image[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("images/mnist_%d.png" % epoch)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# Save model

In [7]:
def save(model, model_name):
    """Write a model's architecture (JSON) and weights (HDF5) to ``saved_model/``.

    Files written:
        ``saved_model/<model_name>.json``          -- result of ``model.to_json()``
        ``saved_model/<model_name>_weights.hdf5``  -- via ``model.save_weights()``

    The ``saved_model`` directory is created when missing.

    Args:
        model: Object exposing ``to_json()`` and ``save_weights(path)``.
        model_name: Base name used for both output files.
    """
    import os

    model_path = "saved_model/%s.json" % model_name
    weights_path = "saved_model/%s_weights.hdf5" % model_name

    # open() and save_weights() do not create missing parent directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # Context manager guarantees the file is flushed and closed.
    with open(model_path, 'w') as arch_file:
        arch_file.write(model.to_json())
    model.save_weights(weights_path)

def save_model(G, D):
    """Save the generator and the discriminator under their fixed AAE names.

    Args:
        G: Generator model, stored as ``aae_generator``.
        D: Discriminator model, stored as ``aae_discriminator``.
    """
    for network, label in ((G, "aae_generator"), (D, "aae_discriminator")):
        save(network, label)

# Train the adversarial autoencoder

In [8]:
def train(encoder, decoder, adversarial_autoencoder, D, latent_dim, epochs, batch_size=128, sample_interval=50):
    """Alternate discriminator and adversarial-autoencoder updates on MNIST.

    Every iteration (named "epoch" here, although it consumes a single random
    mini-batch) performs:
      1. two discriminator updates: samples from a standard normal labelled 1
         and encoder outputs labelled 0;
      2. one update of the combined model with targets
         ``[input images, ones]``.
    A progress line is printed per iteration and a sample grid is saved
    whenever ``epoch % sample_interval == 0``.

    Args:
        encoder: Model mapping images to latent codes (used via ``predict``).
        decoder: Model passed to ``sample_images`` for visualisation.
        adversarial_autoencoder: Combined model trained with ``train_on_batch``.
        D: Discriminator trained with ``train_on_batch``; its result is
            indexed as ``[loss, accuracy]`` when logging.
        latent_dim: Size of the latent code.
        epochs: Number of training iterations.
        batch_size: Number of images per iteration.
        sample_interval: Period, in iterations, between saved sample grids.
    """
    # Only the training images are used; labels and the test split are dropped.
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale to [-1, 1] (assumes 0-255 pixel values) and append a channel axis.
    X_train = np.expand_dims((X_train.astype(np.float32) - 127.5) / 127.5, axis=3)
    n_train = X_train.shape[0]

    # Adversarial ground truths, reused for every batch.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # ---- Discriminator step -------------------------------------
        # Random mini-batch, sampled with replacement.
        batch_idx = np.random.randint(0, n_train, batch_size)
        imgs = X_train[batch_idx]

        encoded_batch = encoder.predict(imgs)
        prior_batch = np.random.normal(size=(batch_size, latent_dim))

        loss_on_prior = D.train_on_batch(prior_batch, valid)
        loss_on_encoded = D.train_on_batch(encoded_batch, fake)
        d_loss = 0.5 * np.add(loss_on_prior, loss_on_encoded)

        # ---- Adversarial autoencoder step -----------------------------
        # Reconstruct the inputs while pushing D's output towards "valid".
        g_loss = adversarial_autoencoder.train_on_batch(imgs, [imgs, valid])

        # ---- Progress log -----------------------------------------------
        report = (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1])
        print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % report)

        # Periodically dump decoded samples to disk.
        if epoch % sample_interval == 0:
            sample_images(epoch, decoder, latent_dim)

# main()

In [9]:
# Input geometry: 28x28 images with a single channel.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Size of the latent code shared by encoder, decoder and discriminator.
latent_dim = 10

In [10]:
# Adam configured positionally; presumably (learning rate, beta_1) = (0.0002, 0.5)
# per the Keras Adam signature -- confirm for the installed Keras version.
# This single optimizer instance is passed to both compile() calls in later cells.
optimizer = Adam(0.0002, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [11]:
# Latent-space discriminator, compiled on its own so it can be trained directly
# with D.train_on_batch; accuracy is tracked alongside the loss.
D = build_discriminator(latent_dim)
D.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 512)               5632      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
Total params: 137,217
Trainable params: 137,217
Non-trainable params: 0
_________________________________________________________________
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [12]:
# Build the encoder / decoder
# (build_decoder prints the decoder's layer summary as a side effect).
encoder = build_encoder(img_shape, latent_dim)
decoder = build_decoder(img_shape, latent_dim)

# Symbolic image input for the combined model assembled in the following cells.
img = Input(shape=img_shape)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_8 (Dense)              (None, 512)               5632      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 512)               262656    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 784)               402192    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
Total params: 670,480
Trainable params: 670,480
Non-trainable params: 0
________________________________________________

In [13]:
# The generator takes the image, encodes it and reconstructs it
# from the encoding: img -> encoder -> encoded_repr -> decoder -> reconstructed_img.
encoded_repr = encoder(img)
reconstructed_img = decoder(encoded_repr)

In [14]:
# For the adversarial_autoencoder model we will only train the generator
# (encoder + decoder), so D is frozen before the combined model is compiled.
# NOTE(review): D itself was compiled in an earlier cell; Keras is expected to
# apply `trainable` at compile time, so D.train_on_batch should still update D.
# Confirm for the Keras version in use.
D.trainable = False

In [15]:
# The discriminator determines validity of the encoding
# (it is applied to the latent code, not to the reconstructed image).
validity = D(encoded_repr)

In [16]:
# Combined model (generator stacked with the discriminator): one image input,
# two outputs -- the reconstruction and the discriminator's verdict on the
# latent code.  The total loss is 0.999 * mse + 0.001 * binary_crossentropy.
adversarial_autoencoder = Model(inputs=img, outputs=[reconstructed_img, validity])
adversarial_autoencoder.compile(
    optimizer=optimizer,
    loss=['mse', 'binary_crossentropy'],
    loss_weights=[0.999, 0.001],
)

In [18]:
# Number of training iterations for this run (an alternative setting is 20000).
epochs = 1000
# epochs = 20000
train(
    encoder=encoder,
    decoder=decoder,
    adversarial_autoencoder=adversarial_autoencoder,
    D=D,
    latent_dim=latent_dim,
    epochs=epochs,
    batch_size=32,
    sample_interval=200,
)

0 [D loss: 0.351715, acc: 81.25%] [G loss: 0.176907, mse: 0.171583]
1 [D loss: 0.263372, acc: 87.50%] [G loss: 0.166061, mse: 0.154519]
2 [D loss: 0.272404, acc: 89.06%] [G loss: 0.157300, mse: 0.151216]
3 [D loss: 0.435980, acc: 76.56%] [G loss: 0.166682, mse: 0.160853]
4 [D loss: 0.322865, acc: 82.81%] [G loss: 0.176703, mse: 0.169069]
5 [D loss: 0.430299, acc: 79.69%] [G loss: 0.170225, mse: 0.165230]
6 [D loss: 0.192809, acc: 89.06%] [G loss: 0.178563, mse: 0.157949]
7 [D loss: 0.246403, acc: 89.06%] [G loss: 0.191749, mse: 0.176326]
8 [D loss: 0.316771, acc: 82.81%] [G loss: 0.156234, mse: 0.148636]
9 [D loss: 0.504943, acc: 78.12%] [G loss: 0.175928, mse: 0.171551]
10 [D loss: 0.270260, acc: 87.50%] [G loss: 0.157584, mse: 0.146232]
11 [D loss: 0.338696, acc: 76.56%] [G loss: 0.164910, mse: 0.155430]
12 [D loss: 0.149349, acc: 95.31%] [G loss: 0.161164, mse: 0.155375]
13 [D loss: 0.247623, acc: 87.50%] [G loss: 0.177295, mse: 0.173077]
14 [D loss: 0.294870, acc: 82.81%] [G loss: 

125 [D loss: 0.078893, acc: 100.00%] [G loss: 0.124468, mse: 0.119391]
126 [D loss: 0.067672, acc: 100.00%] [G loss: 0.113078, mse: 0.107785]
127 [D loss: 0.090189, acc: 98.44%] [G loss: 0.127000, mse: 0.122012]
128 [D loss: 0.086529, acc: 100.00%] [G loss: 0.146677, mse: 0.142333]
129 [D loss: 0.064016, acc: 100.00%] [G loss: 0.120246, mse: 0.114912]
130 [D loss: 0.074298, acc: 100.00%] [G loss: 0.130564, mse: 0.125363]
131 [D loss: 0.080659, acc: 100.00%] [G loss: 0.126577, mse: 0.121532]
132 [D loss: 0.079828, acc: 100.00%] [G loss: 0.126292, mse: 0.121870]
133 [D loss: 0.079487, acc: 100.00%] [G loss: 0.123485, mse: 0.118601]
134 [D loss: 0.070987, acc: 100.00%] [G loss: 0.123697, mse: 0.118457]
135 [D loss: 0.075759, acc: 100.00%] [G loss: 0.129174, mse: 0.124573]
136 [D loss: 0.061345, acc: 100.00%] [G loss: 0.120010, mse: 0.114424]
137 [D loss: 0.087516, acc: 98.44%] [G loss: 0.119403, mse: 0.114338]
138 [D loss: 0.081143, acc: 100.00%] [G loss: 0.125573, mse: 0.120844]
139 [D l

240 [D loss: 0.018518, acc: 100.00%] [G loss: 0.115086, mse: 0.109027]
241 [D loss: 0.025122, acc: 100.00%] [G loss: 0.104766, mse: 0.099108]
242 [D loss: 0.034832, acc: 100.00%] [G loss: 0.128552, mse: 0.123130]
243 [D loss: 0.030869, acc: 100.00%] [G loss: 0.117769, mse: 0.112382]
244 [D loss: 0.022173, acc: 100.00%] [G loss: 0.114354, mse: 0.108657]
245 [D loss: 0.032696, acc: 100.00%] [G loss: 0.116319, mse: 0.109877]
246 [D loss: 0.026905, acc: 100.00%] [G loss: 0.119171, mse: 0.113402]
247 [D loss: 0.032313, acc: 100.00%] [G loss: 0.115802, mse: 0.110127]
248 [D loss: 0.025061, acc: 100.00%] [G loss: 0.120485, mse: 0.114594]
249 [D loss: 0.026181, acc: 100.00%] [G loss: 0.120425, mse: 0.114294]
250 [D loss: 0.020481, acc: 100.00%] [G loss: 0.105706, mse: 0.099417]
251 [D loss: 0.028986, acc: 100.00%] [G loss: 0.121073, mse: 0.115080]
252 [D loss: 0.020552, acc: 100.00%] [G loss: 0.121307, mse: 0.115152]
253 [D loss: 0.035622, acc: 100.00%] [G loss: 0.142085, mse: 0.136524]
254 [D

358 [D loss: 0.013547, acc: 100.00%] [G loss: 0.104706, mse: 0.097720]
359 [D loss: 0.012419, acc: 100.00%] [G loss: 0.112736, mse: 0.106112]
360 [D loss: 0.010971, acc: 100.00%] [G loss: 0.095877, mse: 0.089163]
361 [D loss: 0.016836, acc: 100.00%] [G loss: 0.094811, mse: 0.088040]
362 [D loss: 0.017313, acc: 100.00%] [G loss: 0.112725, mse: 0.106657]
363 [D loss: 0.025845, acc: 98.44%] [G loss: 0.105281, mse: 0.098952]
364 [D loss: 0.013006, acc: 100.00%] [G loss: 0.111063, mse: 0.104764]
365 [D loss: 0.022767, acc: 100.00%] [G loss: 0.096979, mse: 0.090652]
366 [D loss: 0.018398, acc: 100.00%] [G loss: 0.121998, mse: 0.115709]
367 [D loss: 0.012137, acc: 100.00%] [G loss: 0.121859, mse: 0.115275]
368 [D loss: 0.018903, acc: 100.00%] [G loss: 0.120657, mse: 0.114266]
369 [D loss: 0.017797, acc: 100.00%] [G loss: 0.112377, mse: 0.106114]
370 [D loss: 0.011167, acc: 100.00%] [G loss: 0.094846, mse: 0.087808]
371 [D loss: 0.012801, acc: 100.00%] [G loss: 0.105578, mse: 0.098810]
372 [D 

482 [D loss: 0.014453, acc: 100.00%] [G loss: 0.106115, mse: 0.099078]
483 [D loss: 0.017408, acc: 100.00%] [G loss: 0.123042, mse: 0.115921]
484 [D loss: 0.024188, acc: 100.00%] [G loss: 0.105952, mse: 0.098349]
485 [D loss: 0.021645, acc: 100.00%] [G loss: 0.120169, mse: 0.113380]
486 [D loss: 0.010673, acc: 100.00%] [G loss: 0.117437, mse: 0.110914]
487 [D loss: 0.021769, acc: 100.00%] [G loss: 0.102592, mse: 0.096176]
488 [D loss: 0.019578, acc: 100.00%] [G loss: 0.114288, mse: 0.107272]
489 [D loss: 0.014874, acc: 100.00%] [G loss: 0.115966, mse: 0.109071]
490 [D loss: 0.029983, acc: 100.00%] [G loss: 0.099159, mse: 0.092082]
491 [D loss: 0.011759, acc: 100.00%] [G loss: 0.108632, mse: 0.102065]
492 [D loss: 0.023598, acc: 100.00%] [G loss: 0.111347, mse: 0.104702]
493 [D loss: 0.011421, acc: 100.00%] [G loss: 0.099322, mse: 0.091462]
494 [D loss: 0.045625, acc: 98.44%] [G loss: 0.116977, mse: 0.111262]
495 [D loss: 0.029454, acc: 100.00%] [G loss: 0.114118, mse: 0.106620]
496 [D 

601 [D loss: 0.026721, acc: 100.00%] [G loss: 0.118015, mse: 0.111801]
602 [D loss: 0.045063, acc: 98.44%] [G loss: 0.104354, mse: 0.098065]
603 [D loss: 0.068724, acc: 96.88%] [G loss: 0.103370, mse: 0.096646]
604 [D loss: 0.036800, acc: 100.00%] [G loss: 0.103441, mse: 0.096869]
605 [D loss: 0.047354, acc: 98.44%] [G loss: 0.108461, mse: 0.101604]
606 [D loss: 0.060362, acc: 96.88%] [G loss: 0.098541, mse: 0.092734]
607 [D loss: 0.024737, acc: 100.00%] [G loss: 0.110494, mse: 0.104628]
608 [D loss: 0.038110, acc: 98.44%] [G loss: 0.100483, mse: 0.092559]
609 [D loss: 0.068513, acc: 96.88%] [G loss: 0.102059, mse: 0.095283]
610 [D loss: 0.052487, acc: 98.44%] [G loss: 0.122259, mse: 0.115723]
611 [D loss: 0.032153, acc: 100.00%] [G loss: 0.088070, mse: 0.081539]
612 [D loss: 0.044761, acc: 98.44%] [G loss: 0.119534, mse: 0.112782]
613 [D loss: 0.051422, acc: 96.88%] [G loss: 0.104644, mse: 0.096974]
614 [D loss: 0.070089, acc: 96.88%] [G loss: 0.111162, mse: 0.103965]
615 [D loss: 0.0

719 [D loss: 0.128089, acc: 96.88%] [G loss: 0.098214, mse: 0.091940]
720 [D loss: 0.072628, acc: 95.31%] [G loss: 0.106792, mse: 0.100327]
721 [D loss: 0.112002, acc: 95.31%] [G loss: 0.116255, mse: 0.110443]
722 [D loss: 0.110819, acc: 95.31%] [G loss: 0.108948, mse: 0.103447]
723 [D loss: 0.073229, acc: 96.88%] [G loss: 0.102462, mse: 0.097477]
724 [D loss: 0.074840, acc: 98.44%] [G loss: 0.112307, mse: 0.106448]
725 [D loss: 0.074387, acc: 96.88%] [G loss: 0.094471, mse: 0.088433]
726 [D loss: 0.043125, acc: 98.44%] [G loss: 0.108081, mse: 0.102212]
727 [D loss: 0.126157, acc: 95.31%] [G loss: 0.089015, mse: 0.082935]
728 [D loss: 0.152542, acc: 93.75%] [G loss: 0.113341, mse: 0.107769]
729 [D loss: 0.047233, acc: 96.88%] [G loss: 0.092518, mse: 0.085587]
730 [D loss: 0.094723, acc: 98.44%] [G loss: 0.098709, mse: 0.092194]
731 [D loss: 0.152384, acc: 95.31%] [G loss: 0.111962, mse: 0.105137]
732 [D loss: 0.167645, acc: 93.75%] [G loss: 0.110764, mse: 0.104151]
733 [D loss: 0.08655

840 [D loss: 0.144183, acc: 95.31%] [G loss: 0.100755, mse: 0.095000]
841 [D loss: 0.048777, acc: 100.00%] [G loss: 0.112902, mse: 0.107919]
842 [D loss: 0.289642, acc: 90.62%] [G loss: 0.119040, mse: 0.113810]
843 [D loss: 0.224495, acc: 90.62%] [G loss: 0.101544, mse: 0.096336]
844 [D loss: 0.168311, acc: 90.62%] [G loss: 0.099713, mse: 0.094521]
845 [D loss: 0.158252, acc: 93.75%] [G loss: 0.110824, mse: 0.104816]
846 [D loss: 0.228684, acc: 90.62%] [G loss: 0.101344, mse: 0.096345]
847 [D loss: 0.085127, acc: 96.88%] [G loss: 0.089214, mse: 0.084049]
848 [D loss: 0.153904, acc: 95.31%] [G loss: 0.106541, mse: 0.101485]
849 [D loss: 0.165429, acc: 93.75%] [G loss: 0.104652, mse: 0.099284]
850 [D loss: 0.159009, acc: 95.31%] [G loss: 0.095175, mse: 0.089749]
851 [D loss: 0.346986, acc: 85.94%] [G loss: 0.104720, mse: 0.099970]
852 [D loss: 0.124363, acc: 95.31%] [G loss: 0.119714, mse: 0.114628]
853 [D loss: 0.162540, acc: 92.19%] [G loss: 0.097108, mse: 0.091529]
854 [D loss: 0.2632

965 [D loss: 0.287981, acc: 84.38%] [G loss: 0.104563, mse: 0.100536]
966 [D loss: 0.114932, acc: 96.88%] [G loss: 0.094930, mse: 0.090177]
967 [D loss: 0.232710, acc: 87.50%] [G loss: 0.097663, mse: 0.094404]
968 [D loss: 0.214479, acc: 87.50%] [G loss: 0.086445, mse: 0.082429]
969 [D loss: 0.388241, acc: 84.38%] [G loss: 0.111653, mse: 0.107958]
970 [D loss: 0.139281, acc: 93.75%] [G loss: 0.083510, mse: 0.079179]
971 [D loss: 0.245045, acc: 89.06%] [G loss: 0.091596, mse: 0.088020]
972 [D loss: 0.277751, acc: 87.50%] [G loss: 0.103205, mse: 0.098628]
973 [D loss: 0.133986, acc: 96.88%] [G loss: 0.118131, mse: 0.113874]
974 [D loss: 0.358937, acc: 87.50%] [G loss: 0.100741, mse: 0.096808]
975 [D loss: 0.145501, acc: 95.31%] [G loss: 0.109642, mse: 0.104319]
976 [D loss: 0.197487, acc: 89.06%] [G loss: 0.097715, mse: 0.093530]
977 [D loss: 0.224348, acc: 89.06%] [G loss: 0.107030, mse: 0.102822]
978 [D loss: 0.331534, acc: 84.38%] [G loss: 0.104776, mse: 0.101221]
979 [D loss: 0.14344