# Code explanation

[This blog](http://nnormandin.com/science/2017/07/01/cvae.html) explains quite a few details of typical Keras models. Note that older Keras versions handled the merging of layers differently, for example in a variational autoencoder (see e.g. [here](https://github.com/keras-team/keras/issues/3921)).

However, I don't get why there is a `sample_z` function. The purpose of an adversarial autoencoder is that it would not need differentiable probability densities in the latent layer, or that's what I thought. The latent representation should be compared to samples from a normal distribution by the discriminator.

Ah, that is actually in the original paper! The authors distinguish three different autoencoders. (1) The deterministic autoencoder (that's if you skip the layer containing random variables altogether). (2) An autoencoder that uses a Gaussian posterior. In this case we can indeed use the same reparameterization trick as in Kingma and Welling. (3) A general autoencoder with a "universal approximator posterior" where we add noise to the input of the encoder.

In the deterministic case the network has to match q(z) to p(z) by exploiting only the stochasticity in the data distribution. However, the authors found that, after an extensive sweep over hyperparameters, all three types obtained similar test likelihoods. All their reported results were subsequently obtained with a deterministic autoencoder.

In [1]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D
from keras.layers import Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np

Using TensorFlow backend.


In the case of a non-deterministic autoencoder we have a layer with random variables from which we sample using the reparameterization trick. See my website on [inference](https://www.annevanrossum.com/blog/2018/01/30/inference-in-deep-learning/) and other [variance reduction methods](https://www.annevanrossum.com/blog/2018/05/26/random-gradients/).

In [22]:
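def sample_z(args):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1)
    mu, log_var = args
    batch = K.shape(mu)[0]
    dim = K.int_shape(mu)[1]  # latent dimension, read from mu instead of a global
    eps = K.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    # exp(log_var / 2) is the standard deviation sigma
    return mu + K.exp(log_var / 2) * eps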
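To convince myself of what `sample_z` computes, here is a minimal NumPy sketch of the same formula, outside of Keras. The function name `sample_z_numpy` and the example values for `mu` and `log_var` are made up for illustration; the only thing taken from the notebook is the expression `mu + exp(log_var / 2) * eps`.

```python
import numpy as np

def sample_z_numpy(mu, log_var, rng):
    """NumPy analogue of sample_z: z = mu + sigma * eps with eps ~ N(0, 1)."""
    eps = rng.standard_normal(size=mu.shape)
    return mu + np.exp(log_var / 2.0) * eps

rng = np.random.default_rng(0)
# Hypothetical encoder outputs: mean 3 and variance 0.25 (so sigma = 0.5)
mu = np.full((100000, 2), 3.0)
log_var = np.full((100000, 2), np.log(0.25))

z = sample_z_numpy(mu, log_var, rng)
print(z.mean(axis=0), z.std(axis=0))  # should be close to 3 and 0.5
```

The randomness sits entirely in `eps`, so `z` is a deterministic (and differentiable) function of `mu` and `log_var`, which is what allows gradients to flow through the sampling layer.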

The encoder, discriminator, and decoder have densely connected layers of size 512 or 256. I have not experimented much with the number of nodes. Leaky rectifiers are used as activation functions. A rectifier is a function of the form $f(x) = \max(0,x)$; in other words, it makes sure the values don't go below zero, without bounding them from above. The leaky rectifier is defined through $f(x) = x$ for $x > 0$ and $f(x) = \alpha x$ otherwise. This makes it less likely for units to get "stuck" when all their inputs become negative.
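The two activation functions are easy to compare in plain NumPy. This is only an illustrative sketch (the Keras `LeakyReLU` layer is what the models below actually use); the helper names `relu` and `leaky_relu` are my own, and `alpha=0.2` mirrors the value used in the layers below.

```python
import numpy as np

def relu(x):
    """Plain rectifier: max(0, x)."""
    return np.maximum(0.0, x)

def leaky_relu(x, alpha=0.2):
    """Leaky rectifier: x for x > 0, alpha * x otherwise."""
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(relu(x))        # negative inputs are clipped to zero
print(leaky_relu(x))  # negative inputs are scaled by alpha instead
```

For negative inputs the leaky version still has slope `alpha`, so a small gradient keeps flowing even when a unit's pre-activation is below zero.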

In [23]:
def build_encoder(latent_dim, img_shape, deterministic=True):
    # deterministic=True skips the sampling layer; False uses a Gaussian posterior via sample_z
    img = Input(shape=img_shape)
    h = Flatten()(img)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    if deterministic:
        latent_repr = Dense(latent_dim)(h)
    else:
        mu = Dense(latent_dim)(h)
        log_var = Dense(latent_dim)(h)
        latent_repr = Lambda(sample_z)([mu, log_var])
    return Model(img, latent_repr)

In [24]:
def build_discriminator(latent_dim):
    model = Sequential()
    model.add(Dense(512, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation="sigmoid"))
    encoded_repr = Input(shape=(latent_dim, ))
    validity = model(encoded_repr)
    return Model(encoded_repr, validity)

In [25]:
def build_decoder(latent_dim, img_shape):
    model = Sequential()
    model.add(Dense(512, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))
    z = Input(shape=(latent_dim,))
    img = model(z)
    return Model(z, img)

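The inputs are 28x28 images. The optimizer is Adam. The discriminator is trained with a binary cross-entropy loss; the autoencoder combines a mean squared error reconstruction loss with a binary cross-entropy loss on the discriminator output.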
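As a reminder of what binary cross-entropy measures, here is a small NumPy sketch. It is not Keras' implementation; the function name and the clipping constant `eps` are my own choices for illustration.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y log(p) + (1 - y) log(1 - p)] over the batch."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

labels = np.ones(4)  # four samples labelled "valid"
undecided = binary_crossentropy(labels, np.full(4, 0.5))   # discriminator has no idea
confident = binary_crossentropy(labels, np.full(4, 0.99))  # discriminator is nearly sure
print(undecided, confident)
```

A discriminator that always outputs 0.5 gets a loss of ln 2 (about 0.693), which is a handy reference value when reading the training log.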

In [26]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
# Results can be found in just_2_rv
#latent_dim = 2
latent_dim = 8

optimizer = Adam(0.0002, 0.5)

# Build and compile the discriminator
discriminator = build_discriminator(latent_dim)
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])        

In [27]:
# Build the encoder / decoder
encoder = build_encoder(latent_dim, img_shape)
decoder = build_decoder(latent_dim, img_shape)

In [28]:
# The generator takes the image, encodes it and reconstructs it
# from the encoding
img = Input(shape=img_shape)
encoded_repr = encoder(img)
reconstructed_img = decoder(encoded_repr)

# For the adversarial_autoencoder model we will only train the generator
# It will say something like: 
#   UserWarning: Discrepancy between trainable weights and collected trainable weights, 
#   did you set `model.trainable` without calling `model.compile` after ?
# We only set trainable to false for the discriminator when it is part of the autoencoder...
discriminator.trainable = False

# The discriminator determines validity of the encoding
validity = discriminator(encoded_repr)

# The adversarial_autoencoder model  (stacked generator and discriminator)
adversarial_autoencoder = Model(img, [reconstructed_img, validity])
adversarial_autoencoder.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[0.999, 0.001], optimizer=optimizer)
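With `loss_weights=[0.999, 0.001]` the total loss reported for the adversarial autoencoder is a weighted sum of the reconstruction loss and the adversarial loss. The sketch below spells out that arithmetic and inverts it for the first line of the training log in this notebook (`G loss: 0.937058, mse: 0.937302`). The helper names are mine, and because the logged values are rounded to six decimals, the recovered cross-entropy is only a rough estimate.

```python
LOSS_WEIGHTS = (0.999, 0.001)  # (mse weight, binary cross-entropy weight)

def combined_loss(mse, bce, weights=LOSS_WEIGHTS):
    """Total loss as a weighted sum of the two output losses."""
    return weights[0] * mse + weights[1] * bce

def implied_bce(total, mse, weights=LOSS_WEIGHTS):
    """Solve the weighted sum for the cross-entropy term."""
    return (total - weights[0] * mse) / weights[1]

# Values copied from the log line for iteration 0
bce_0 = implied_bce(total=0.937058, mse=0.937302)
print(bce_0)
```

The result is roughly 0.69, close to ln 2, which would be consistent with an untrained discriminator that outputs about 0.5 for every encoding. It also shows how little the adversarial term contributes to the total at these weights.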

In [29]:
discriminator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_12 (InputLayer)        (None, 8)                 0         
_________________________________________________________________
sequential_7 (Sequential)    (None, 1)                 136193    
Total params: 272,386
Trainable params: 136,193
Non-trainable params: 136,193
_________________________________________________________________


  'Discrepancy between trainable weights and collected trainable'


In [30]:
epochs=5000
batch_size=128
sample_interval=100 
    
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()

# Rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)

# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

In [None]:
def sample_prior(batch_size, latent_dim):
    return np.random.normal(size=(batch_size, latent_dim))

In [31]:
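import os

def sample_images(latent_dim, decoder, epoch):
    r, c = 5, 5

    # Decode r*c samples from the prior
    z = sample_prior(r*c, latent_dim)
    gen_imgs = decoder.predict(z)

    # Rescale from the tanh range [-1, 1] to [0, 1] for plotting
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    # savefig fails if the output directory does not exist
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close()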
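The `0.5 * gen_imgs + 0.5` step in `sample_images` undoes the rescaling applied to `X_train`: pixels go from [0, 255] to [-1, 1] for the `tanh` output of the decoder, and back to [0, 1] for plotting. A small NumPy round trip, with helper names that are my own:

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1], as done for X_train."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_display_range(x):
    """Map values in [-1, 1] to [0, 1], as done before imshow."""
    return 0.5 * x + 0.5

pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
shown = to_display_range(scaled)
print(scaled, shown)
```

Composing the two maps gives `pixels / 255`, so the plotted images are on the usual [0, 1] grayscale range.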

# Training

Each epoch a batch is chosen at random from the images. The typical batch size is 128 items out of 60,000 images. The indices are drawn with replacement, so the chance that a given image is picked twice in the same batch is minimal (but not zero).
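How small is that chance? The sketch below treats it as a birthday problem, under the assumption that the indices are drawn uniformly and independently (which is what `np.random.randint` does). The function name is mine.

```python
import numpy as np

def prob_any_duplicate(batch_size, n_images):
    """Probability that at least one index occurs twice in a batch
    drawn uniformly with replacement from n_images items."""
    i = np.arange(batch_size)
    p_all_distinct = np.prod(1.0 - i / n_images)
    return 1.0 - p_all_distinct

p_batch = prob_any_duplicate(128, 60000)
p_pair = 1.0 / 60000  # chance that two particular draws coincide
print(p_batch, p_pair)
```

For any particular pair of draws the chance of a collision is tiny, but over all pairs in a batch of 128 the chance of at least one duplicate comes out at roughly 13%. If duplicates are a concern, `np.random.choice(X_train.shape[0], batch_size, replace=False)` draws without replacement.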

The "real" latent samples shown to the discriminator are drawn from the prior: every component has the same N(0,1) distribution, mu=0, sigma=1. They are returned as a batch_size x latent_dim matrix, for example 128x10 if we use 10 latent variables (128x8 with the `latent_dim = 8` set above).

The discriminator doesn't know that there is an order to the "real" and "fake" samples: we simply train it first on all the real ones and then on all the fake ones. I don't know if it matters for the training, but we might try to build a single data structure in which the two are shuffled together.
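A randomized data structure of that kind could look roughly like the sketch below. This is an untested idea, not what the training loop in this notebook does: the function name `mixed_discriminator_batch` is hypothetical, and the random arrays only stand in for `sample_prior(...)` and `encoder.predict(imgs)`.

```python
import numpy as np

def mixed_discriminator_batch(latent_real, latent_fake, rng):
    """Concatenate real and fake latent samples and shuffle them together.
    Labels: 1 for samples from the prior, 0 for encoder outputs."""
    x = np.concatenate([latent_real, latent_fake], axis=0)
    y = np.concatenate([np.ones((len(latent_real), 1)),
                        np.zeros((len(latent_fake), 1))], axis=0)
    perm = rng.permutation(len(x))  # same permutation for samples and labels
    return x[perm], y[perm]

rng = np.random.default_rng(0)
latent_real = rng.standard_normal((128, 8))        # stand-in for sample_prior(128, 8)
latent_fake = rng.standard_normal((128, 8)) + 5.0  # stand-in for encoder.predict(imgs)

x, y = mixed_discriminator_batch(latent_real, latent_fake, rng)
print(x.shape, y.shape, y[:10].ravel())
```

The discriminator update would then be a single `discriminator.train_on_batch(x, y)` call on a batch of 256 instead of two calls on 128 each. Whether this changes the training behaviour is something I have not checked.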



In [32]:
for epoch in range(epochs):

    # ---------------------
    #  Train Discriminator
    # ---------------------

    # Select a random batch of images
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]

    latent_fake = encoder.predict(imgs)
    
    # Here we generate the "TRUE" samples
    latent_real = sample_prior(batch_size, latent_dim)
                      
    # Train the discriminator
    d_loss_real = discriminator.train_on_batch(latent_real, valid)
    d_loss_fake = discriminator.train_on_batch(latent_fake, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  Train Generator
    # ---------------------

    # Train the generator
    g_loss = adversarial_autoencoder.train_on_batch(imgs, [imgs, valid])

    # Plot the progress (every 10th epoch)
    if epoch % 10 == 0:
        print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))
    
    # Save generated images (every sample interval, e.g. every 100th epoch)
    if epoch % sample_interval == 0:
        sample_images(latent_dim, decoder, epoch)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.733090, acc: 29.30%] [G loss: 0.937058, mse: 0.937302]
10 [D loss: 0.239192, acc: 95.70%] [G loss: 0.298136, mse: 0.291328]
20 [D loss: 0.144058, acc: 100.00%] [G loss: 0.264847, mse: 0.259248]
30 [D loss: 0.092582, acc: 100.00%] [G loss: 0.242276, mse: 0.236328]
40 [D loss: 0.074854, acc: 99.61%] [G loss: 0.222630, mse: 0.215970]
50 [D loss: 0.062436, acc: 99.22%] [G loss: 0.190846, mse: 0.183588]
60 [D loss: 0.095282, acc: 96.09%] [G loss: 0.191600, mse: 0.183074]
70 [D loss: 0.148439, acc: 93.36%] [G loss: 0.187342, mse: 0.179397]
80 [D loss: 0.282353, acc: 87.11%] [G loss: 0.162460, mse: 0.153044]
90 [D loss: 0.272774, acc: 86.33%] [G loss: 0.158903, mse: 0.151126]
100 [D loss: 0.414059, acc: 76.95%] [G loss: 0.155098, mse: 0.149622]
110 [D loss: 0.311753, acc: 81.64%] [G loss: 0.147203, mse: 0.141764]
120 [D loss: 0.284345, acc: 86.33%] [G loss: 0.146561, mse: 0.141146]
130 [D loss: 0.231917, acc: 93.36%] [G loss: 0.137570, mse: 0.132817]
140 [D loss: 0.170791, acc: 9

1170 [D loss: 0.381582, acc: 82.81%] [G loss: 0.095675, mse: 0.093037]
1180 [D loss: 0.448240, acc: 81.64%] [G loss: 0.098654, mse: 0.096297]
1190 [D loss: 0.414639, acc: 82.42%] [G loss: 0.106431, mse: 0.104133]
1200 [D loss: 0.366467, acc: 82.42%] [G loss: 0.100071, mse: 0.097311]
1210 [D loss: 0.464068, acc: 82.42%] [G loss: 0.095387, mse: 0.092945]
1220 [D loss: 0.390966, acc: 83.98%] [G loss: 0.096915, mse: 0.094313]
1230 [D loss: 0.323465, acc: 84.38%] [G loss: 0.099023, mse: 0.096262]
1240 [D loss: 0.415529, acc: 81.64%] [G loss: 0.096245, mse: 0.093837]
1250 [D loss: 0.441138, acc: 80.47%] [G loss: 0.103170, mse: 0.101083]
1260 [D loss: 0.394912, acc: 84.77%] [G loss: 0.086904, mse: 0.084378]
1270 [D loss: 0.306187, acc: 88.28%] [G loss: 0.090949, mse: 0.088351]
1280 [D loss: 0.460602, acc: 81.64%] [G loss: 0.098537, mse: 0.096169]
1290 [D loss: 0.426463, acc: 81.64%] [G loss: 0.100142, mse: 0.097736]
1300 [D loss: 0.401576, acc: 82.81%] [G loss: 0.090307, mse: 0.087849]
1310 [

2330 [D loss: 0.429680, acc: 85.16%] [G loss: 0.088079, mse: 0.086561]
2340 [D loss: 0.548557, acc: 75.39%] [G loss: 0.087695, mse: 0.086288]
2350 [D loss: 0.386990, acc: 85.16%] [G loss: 0.078010, mse: 0.076290]
2360 [D loss: 0.509312, acc: 74.22%] [G loss: 0.083256, mse: 0.081769]
2370 [D loss: 0.515790, acc: 75.39%] [G loss: 0.090319, mse: 0.088793]
2380 [D loss: 0.422168, acc: 82.42%] [G loss: 0.081264, mse: 0.079667]
2390 [D loss: 0.506212, acc: 76.95%] [G loss: 0.089530, mse: 0.088181]
2400 [D loss: 0.510352, acc: 75.39%] [G loss: 0.086726, mse: 0.085118]
2410 [D loss: 0.472916, acc: 79.69%] [G loss: 0.092049, mse: 0.090543]
2420 [D loss: 0.511076, acc: 75.39%] [G loss: 0.092242, mse: 0.090883]
2430 [D loss: 0.535683, acc: 76.17%] [G loss: 0.084046, mse: 0.082635]
2440 [D loss: 0.507388, acc: 77.34%] [G loss: 0.083620, mse: 0.082116]
2450 [D loss: 0.438597, acc: 79.69%] [G loss: 0.082667, mse: 0.080906]
2460 [D loss: 0.458207, acc: 82.42%] [G loss: 0.082713, mse: 0.081152]
2470 [

3490 [D loss: 0.520496, acc: 77.73%] [G loss: 0.078522, mse: 0.077175]
3500 [D loss: 0.472371, acc: 78.52%] [G loss: 0.081553, mse: 0.080117]
3510 [D loss: 0.424223, acc: 79.69%] [G loss: 0.076715, mse: 0.075104]
3520 [D loss: 0.500713, acc: 76.17%] [G loss: 0.090787, mse: 0.089459]
3530 [D loss: 0.465267, acc: 78.91%] [G loss: 0.077123, mse: 0.075528]
3540 [D loss: 0.490207, acc: 77.34%] [G loss: 0.080588, mse: 0.079202]
3550 [D loss: 0.485929, acc: 78.12%] [G loss: 0.077616, mse: 0.076295]
3560 [D loss: 0.495565, acc: 77.34%] [G loss: 0.079649, mse: 0.078177]
3570 [D loss: 0.501225, acc: 76.95%] [G loss: 0.082814, mse: 0.081527]
3580 [D loss: 0.467843, acc: 78.52%] [G loss: 0.082501, mse: 0.081169]
3590 [D loss: 0.483830, acc: 76.56%] [G loss: 0.082574, mse: 0.081120]
3600 [D loss: 0.518891, acc: 75.00%] [G loss: 0.083353, mse: 0.082085]
3610 [D loss: 0.518961, acc: 73.83%] [G loss: 0.079603, mse: 0.078159]
3620 [D loss: 0.539621, acc: 73.83%] [G loss: 0.084420, mse: 0.083170]
3630 [

4650 [D loss: 0.456974, acc: 77.73%] [G loss: 0.076247, mse: 0.074730]
4660 [D loss: 0.447029, acc: 78.12%] [G loss: 0.078506, mse: 0.077068]
4670 [D loss: 0.473796, acc: 79.30%] [G loss: 0.080637, mse: 0.079220]
4680 [D loss: 0.537035, acc: 74.22%] [G loss: 0.077635, mse: 0.076269]
4690 [D loss: 0.528634, acc: 76.17%] [G loss: 0.077473, mse: 0.076141]
4700 [D loss: 0.401224, acc: 83.98%] [G loss: 0.078919, mse: 0.077340]
4710 [D loss: 0.462022, acc: 77.34%] [G loss: 0.072833, mse: 0.071464]
4720 [D loss: 0.477221, acc: 80.08%] [G loss: 0.075664, mse: 0.074093]
4730 [D loss: 0.485937, acc: 75.78%] [G loss: 0.080506, mse: 0.079072]
4740 [D loss: 0.511461, acc: 76.56%] [G loss: 0.076304, mse: 0.074912]
4750 [D loss: 0.454260, acc: 77.73%] [G loss: 0.073992, mse: 0.072575]
4760 [D loss: 0.504159, acc: 75.39%] [G loss: 0.076472, mse: 0.075035]
4770 [D loss: 0.502164, acc: 76.17%] [G loss: 0.076245, mse: 0.074880]
4780 [D loss: 0.475251, acc: 77.34%] [G loss: 0.071712, mse: 0.070346]
4790 [