In [2]:
import os

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras.models import Model, Sequential
from keras.layers import Dense, Flatten, Activation, Input, Dropout, Reshape, merge, GaussianNoise, LeakyReLU, Lambda
from keras.layers import Conv2D, MaxPool2D, ZeroPadding2D, UpSampling2D
from keras.layers.embeddings import Embedding
from keras import optimizers, losses
from keras.utils.np_utils import to_categorical

import keras.backend as K

from keras.datasets import mnist

In [None]:
# Removed the incomplete scratch line `x = K.tf.`: it was a SyntaxError in a cell that was
# never executed, and it made "Restart Kernel -> Run All" fail.

In [33]:
class AdversialAutoEncoder():
    """Adversarial autoencoder (AAE) for 28x28x1 MNIST digits.

    An encoder maps an image to a ``latent_dim``-dimensional code and a decoder
    reconstructs the image from that code. A discriminator is trained to tell encoder
    codes apart from samples of a standard normal prior. The encoder/decoder pair is
    trained on a weighted sum of reconstruction MSE and an adversarial term that pushes
    the encoder's codes to look like prior samples.

    Attributes:
        discriminator: compiled Keras model, latent code -> probability of being a prior sample.
        encoder: Keras model, image -> sampled latent code (stochastic).
        decoder: Keras model, latent code -> image with tanh outputs in [-1, 1].
        adverserial_autoencoder: compiled Keras model, image -> [reconstruction, validity].
    """

    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 10         # dimensionality of the encoded representation

        # One optimizer instance per compiled model. Sharing one instance would make both
        # models advance the same optimizer state (e.g. Adam's iteration counter used
        # for bias correction).
        d_optimizer = optimizers.Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999)
        g_optimizer = optimizers.Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999)

        # Build and compile the discriminator while it is still trainable, so its own
        # train_on_batch calls update its weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss = losses.binary_crossentropy, optimizer = d_optimizer, metrics = ['accuracy'])
        self.discriminator.summary()

        # Build the encoder/decoder models
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

        img = Input(self.img_shape)

        # The "generator" path: encode the image and reconstruct it from its encoding.
        encoded_repr = self.encoder(img)
        reconstructed_img = self.decoder(encoded_repr)

        # Freeze the discriminator for the combined model only. Keras reads `trainable`
        # when a model is compiled. The discriminator compiled above therefore keeps
        # training on its own, while the combined model compiled below updates only the
        # encoder/decoder. This deliberate mismatch is what produces Keras'
        # "Discrepancy between trainable weights and collected trainable weights" warning.
        self.discriminator.trainable = False

        # The discriminator scores whether the encoding looks like a prior sample.
        validity = self.discriminator(encoded_repr)

        self.adverserial_autoencoder = Model(inputs = img, outputs = [reconstructed_img, validity])
        # Reconstruction dominates the total loss; the adversarial term acts as a light
        # regularizer on the code distribution.
        self.adverserial_autoencoder.compile(loss = [losses.MSE, losses.binary_crossentropy], loss_weights = [0.999, 0.001], optimizer = g_optimizer)

        self.adverserial_autoencoder.summary()

    def build_encoder(self):
        """Build the encoder: image -> sampled latent code of size ``latent_dim``.

        Two 512-unit LeakyReLU dense layers feed two heads, ``mu`` and ``log_var``.
        The output is ``mu + eps * exp(log_var / 2)``, with ``eps`` drawn from a standard
        normal on every forward pass (reparameterization), so the encoder is stochastic.
        """
        img = Input(shape = self.img_shape)
        output = Flatten()(img)
        output = Dense(units = 512)(output)
        output = LeakyReLU(alpha = 0.2)(output)
        output = Dense(units = 512)(output)
        output = LeakyReLU(alpha = 0.2)(output)
        mu = Dense(self.latent_dim)(output)
        log_var = Dense(self.latent_dim)(output)

        # A Lambda layer replaces the legacy `merge(..., mode=callable)` API. Keras 2
        # deprecates that API and later releases remove it.
        latent_repr = Lambda(
            lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
            output_shape = lambda shapes: shapes[0],
        )([mu, log_var])

        return Model(inputs = img, outputs = latent_repr)

    def build_decoder(self):
        """Build the decoder: latent code -> image of shape ``img_shape`` in [-1, 1] (tanh)."""
        encoded_repr = Input(shape = (self.latent_dim, ))
        output = Dense(512)(encoded_repr)
        output = LeakyReLU(alpha = 0.2)(output)
        output = Dense(512)(output)
        output = LeakyReLU(alpha = 0.2)(output)
        output = Dense(np.prod(self.img_shape), activation = 'tanh')(output)
        output = Reshape(self.img_shape)(output)

        return Model(encoded_repr, output)

    def build_discriminator(self):
        """Build the discriminator: latent code -> sigmoid probability that it is a prior sample."""
        encoded_repr = Input(shape = (self.latent_dim, ))
        output = Dense(512)(encoded_repr)
        output = LeakyReLU(alpha = 0.2)(output)
        output = Dense(256)(output)
        output = LeakyReLU(alpha = 0.2)(output)
        validity = Dense(1, activation = 'sigmoid')(output)

        return Model(encoded_repr, validity)

    def train(self, epochs, batch_size = 128, sample_interval = 50):
        """Train the AAE on the MNIST training split.

        Args:
            epochs: number of training iterations. Each "epoch" here is ONE randomly
                sampled mini-batch, not a full pass over the dataset.
            batch_size: images per mini-batch.
            sample_interval: every ``sample_interval`` iterations (starting at
                iteration 0), decode prior samples and save an image grid via
                ``sample_images``.
        """
        # Load the dataset (labels and the test split are not used).
        (X_train, _), _ = mnist.load_data()

        # Rescale pixels from [0, 255] to [-1, 1] to match the decoder's tanh output.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis = -1)   # add the channel axis

        # Adversarial ground truth: prior samples are labelled valid (1), encoder codes fake (0).
        fake = np.zeros(shape = (batch_size, 1))
        valid = np.ones(shape = (batch_size, 1))

        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images (indices drawn with replacement).
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # "Fake" codes come from the encoder. "Real" codes are samples from the
            # standard normal prior that the encoder's code distribution should match.
            latent_fake = self.encoder.predict(imgs)
            latent_real = np.random.normal(size = (batch_size, self.latent_dim))

            d_loss_real = self.discriminator.train_on_batch(latent_real, valid)
            d_loss_fake = self.discriminator.train_on_batch(latent_fake, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)   # averaged [loss, accuracy]

            # ---------------------
            #  Train Generator (encoder + decoder)
            # ---------------------
            # Targets: reconstruct the input, and make the frozen discriminator call the
            # encoder's codes valid.
            g_loss = self.adverserial_autoencoder.train_on_batch(imgs, [imgs, valid])

            # g_loss[0] is the weighted total loss, g_loss[1] the reconstruction (MSE) term.
            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, image_dir = "images"):
        """Decode a 5x5 grid of standard-normal prior samples and save it as a PNG.

        The figure is written to ``<image_dir>/mnist_<epoch>.png``; ``image_dir`` is
        created if it does not exist.
        """
        r, c = 5, 5

        z = np.random.normal(size = (r * c, self.latent_dim))
        gen_imgs = self.decoder.predict(z)

        # The decoder emits tanh values in [-1, 1]; map them to [0, 1] for imshow.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap = 'gray')
                axs[i,j].axis('off')
                cnt += 1
        os.makedirs(image_dir, exist_ok = True)
        fig.savefig(os.path.join(image_dir, "mnist_%d.png" % epoch))
        plt.close(fig)   # close this figure explicitly to avoid accumulating open figures

    def save_model(self, save_dir = "saved_model"):
        """Save architecture (JSON) and weights (HDF5) of the encoder, decoder and discriminator.

        Files are written as ``<save_dir>/aae_<part>.json`` and
        ``<save_dir>/aae_<part>_weights.hdf5``; ``save_dir`` is created if missing.
        """
        def save(model, model_name):
            model_path = os.path.join(save_dir, "%s.json" % model_name)
            weights_path = os.path.join(save_dir, "%s_weights.hdf5" % model_name)
            # The context manager guarantees the architecture file is closed.
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

        os.makedirs(save_dir, exist_ok = True)
        # This class has no `generator` attribute; an AAE's generator is the
        # encoder/decoder pair, so both are saved.
        save(self.encoder, "aae_encoder")
        save(self.decoder, "aae_decoder")
        save(self.discriminator, "aae_discriminator")

In [36]:
# Build and compile all sub-models. The constructor prints the discriminator summary and
# then the combined adversarial-autoencoder summary (shown in the output below).
aae = AdversialAutoEncoder()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_29 (InputLayer)        (None, 10)                0         
_________________________________________________________________
dense_72 (Dense)             (None, 512)               5632      
_________________________________________________________________
leaky_re_lu_43 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_73 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_44 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_74 (Dense)             (None, 1)                 257       
Total params: 137,217
Trainable params: 137,217
Non-trainable params: 0
_________________________________________________________________


  name=name)


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_32 (InputLayer)           (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
model_27 (Model)                (None, 10)           674836      input_32[0][0]                   
__________________________________________________________________________________________________
model_28 (Model)                (None, 28, 28, 1)    670480      model_27[1][0]                   
__________________________________________________________________________________________________
model_26 (Model)                (None, 1)            137217      model_27[1][0]                   
Total params: 1,482,533
Trainable params: 1,345,316
Non-trainable params: 137,217
___________________________

In [37]:
# 20000 training iterations, each on one random mini-batch of 32 images ("epochs" in
# train() are mini-batch steps). A decoded sample grid is saved every 200 iterations,
# and one log line is printed per iteration.
aae.train(epochs=20000, batch_size=32, sample_interval=200)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.913092, acc: 45.31%] [G loss: 0.946523, mse: 0.947043]
1 [D loss: 0.782248, acc: 50.00%] [G loss: 0.920929, mse: 0.921292]
2 [D loss: 0.710221, acc: 48.44%] [G loss: 0.893771, mse: 0.894013]
3 [D loss: 0.615654, acc: 78.12%] [G loss: 0.847014, mse: 0.846966]
4 [D loss: 0.459754, acc: 92.19%] [G loss: 0.763195, mse: 0.762303]
5 [D loss: 0.328656, acc: 95.31%] [G loss: 0.646367, mse: 0.643964]
6 [D loss: 0.300163, acc: 87.50%] [G loss: 0.518102, mse: 0.513914]
7 [D loss: 0.270516, acc: 90.62%] [G loss: 0.419596, mse: 0.414062]
8 [D loss: 0.237106, acc: 96.88%] [G loss: 0.351373, mse: 0.345035]
9 [D loss: 0.247129, acc: 93.75%] [G loss: 0.342412, mse: 0.335817]
10 [D loss: 0.231882, acc: 96.88%] [G loss: 0.315184, mse: 0.308712]
11 [D loss: 0.232225, acc: 95.31%] [G loss: 0.297824, mse: 0.291458]
12 [D loss: 0.221944, acc: 96.88%] [G loss: 0.274119, mse: 0.267772]
13 [D loss: 0.202385, acc: 96.88%] [G loss: 0.266603, mse: 0.260596]
14 [D loss: 0.192060, acc: 100.00%] [G loss:

KeyboardInterrupt: 