In [3]:
# Let's make this notebook compatible for Python 2 and 3
from __future__ import division, print_function

# Import libraries
import pandas as pd
import numpy as np
import os
import math
import itertools
import progressbar

# for visualization
import matplotlib.pyplot as plt

# to import module from parent directory
import sys
sys.path.append('..')

# Dataset API from sklearn
from sklearn import datasets

In [4]:
from deep_learning.optimizers import Adam
from deep_learning.loss_functions import CrossEntropy, SquareLoss
from deep_learning.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization 
from deep_learning.neural_network import NeuralNetwork

In [None]:
class Autoencoder():
    """An Autoencoder with deep fully-connected neural nets.

    Training Data: MNIST Handwritten Digits (28x28 images)
    """
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 128 # The dimension of the data embedding

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        loss_function = SquareLoss

        self.encoder = self.build_encoder(optimizer, loss_function)
        self.decoder = self.build_decoder(optimizer, loss_function)

        self.autoencoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.autoencoder.layers.extend(self.encoder.layers)
        self.autoencoder.layers.extend(self.decoder.layers)

        print ()
        self.autoencoder.summary(name="Variational Autoencoder")

    def build_encoder(self, optimizer, loss_function):

        encoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        encoder.add(Dense(512, input_shape=(self.img_dim,)))
        encoder.add(Activation('leaky_relu'))
        encoder.add(BatchNormalization(momentum=0.8))
        encoder.add(Dense(256))
        encoder.add(Activation('leaky_relu'))
        encoder.add(BatchNormalization(momentum=0.8))
        encoder.add(Dense(self.latent_dim))

        return encoder

    def build_decoder(self, optimizer, loss_function):

        decoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        decoder.add(Dense(256, input_shape=(self.latent_dim,)))
        decoder.add(Activation('leaky_relu'))
        decoder.add(BatchNormalization(momentum=0.8))
        decoder.add(Dense(512))
        decoder.add(Activation('leaky_relu'))
        decoder.add(BatchNormalization(momentum=0.8))
        decoder.add(Dense(self.img_dim))
        decoder.add(Activation('tanh'))

        return decoder

    def train(self, n_epochs, batch_size=128, save_interval=50):

        mnist = datasets.fetch_mldata('MNIST original')

        X = mnist.data
        y = mnist.target

        # Rescale [-1, 1]
        X = (X.astype(np.float32) - 127.5) / 127.5

        for epoch in range(n_epochs):

            # Select a random half batch of images
            idx = np.random.randint(0, X.shape[0], batch_size)
            imgs = X[idx]

            # Train the Autoencoder
            loss, _ = self.autoencoder.train_on_batch(imgs, imgs)

            # Display the progress
            print ("%d [D loss: %f]" % (epoch, loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch, X)

    def save_imgs(self, epoch, X):
        r, c = 5, 5 # Grid size
        # Select a random half batch of images
        idx = np.random.randint(0, X.shape[0], r*c)
        imgs = X[idx]
        # Generate images and reshape to image shape
        gen_imgs = self.autoencoder.predict(imgs).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Autoencoder")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("ae_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    ae = Autoencoder()
    ae.train(n_epochs=200000, batch_size=64, save_interval=400)



+-------------------------+
| Variational Autoencoder |
+-------------------------+
Input Shape: (784,)
+------------------------+------------+--------------+
| Layer Type             | Parameters | Output Shape |
+------------------------+------------+--------------+
| Dense                  | 401920     | (512,)       |
| Activation (LeakyReLU) | 0          | (512,)       |
| BatchNormalization     | 1024       | (512,)       |
| Dense                  | 131328     | (256,)       |
| Activation (LeakyReLU) | 0          | (256,)       |
| BatchNormalization     | 512        | (256,)       |
| Dense                  | 32896      | (128,)       |
| Dense                  | 33024      | (256,)       |
| Activation (LeakyReLU) | 0          | (256,)       |
| BatchNormalization     | 512        | (256,)       |
| Dense                  | 131584     | (512,)       |
| Activation (LeakyReLU) | 0          | (512,)       |
| BatchNormalization     | 1024       | (512,)       |
| Dense        

310 [D loss: 0.330981]
311 [D loss: 0.314899]
312 [D loss: 0.330760]
313 [D loss: 0.324923]
314 [D loss: 0.332062]
315 [D loss: 0.314893]
316 [D loss: 0.321148]
317 [D loss: 0.318560]
318 [D loss: 0.322350]
319 [D loss: 0.314662]
320 [D loss: 0.308637]
321 [D loss: 0.323957]
322 [D loss: 0.317947]
323 [D loss: 0.317813]
324 [D loss: 0.309265]
325 [D loss: 0.320596]
326 [D loss: 0.319668]
327 [D loss: 0.310394]
328 [D loss: 0.317836]
329 [D loss: 0.314485]
330 [D loss: 0.312452]
331 [D loss: 0.313244]
332 [D loss: 0.326952]
333 [D loss: 0.308411]
334 [D loss: 0.312797]
335 [D loss: 0.310896]
336 [D loss: 0.316527]
337 [D loss: 0.315982]
338 [D loss: 0.318197]
339 [D loss: 0.317702]
340 [D loss: 0.304638]
341 [D loss: 0.313145]
342 [D loss: 0.319059]
343 [D loss: 0.315048]
344 [D loss: 0.307954]
345 [D loss: 0.305742]
346 [D loss: 0.319870]
347 [D loss: 0.322255]
348 [D loss: 0.310462]
349 [D loss: 0.320925]
350 [D loss: 0.307745]
351 [D loss: 0.307101]
352 [D loss: 0.325724]
353 [D loss

668 [D loss: 0.266248]
669 [D loss: 0.270703]
670 [D loss: 0.265600]
671 [D loss: 0.272191]
672 [D loss: 0.318701]
673 [D loss: 0.262886]
674 [D loss: 0.261353]
675 [D loss: 0.265153]
676 [D loss: 0.264942]
677 [D loss: 0.266125]
678 [D loss: 0.266397]
679 [D loss: 0.262858]
680 [D loss: 0.259541]
681 [D loss: 0.270350]
682 [D loss: 0.267079]
683 [D loss: 0.267645]
684 [D loss: 0.284718]
685 [D loss: 0.252535]
686 [D loss: 0.266327]
687 [D loss: 0.266572]
688 [D loss: 0.257200]
689 [D loss: 0.275976]
690 [D loss: 0.266376]
691 [D loss: 0.269848]
692 [D loss: 0.259641]
693 [D loss: 0.265617]
694 [D loss: 0.263667]
695 [D loss: 0.287400]
696 [D loss: 0.266186]
697 [D loss: 0.272100]
698 [D loss: 0.265068]
699 [D loss: 0.268523]
700 [D loss: 0.258919]
701 [D loss: 0.253631]
702 [D loss: 0.267419]
703 [D loss: 0.270785]
704 [D loss: 0.262081]
705 [D loss: 0.262285]
706 [D loss: 0.278156]
707 [D loss: 0.250528]
708 [D loss: 0.259005]
709 [D loss: 0.265866]
710 [D loss: 0.263519]
711 [D loss

1026 [D loss: 0.225626]
1027 [D loss: 0.221823]
1028 [D loss: 0.218433]
1029 [D loss: 0.220307]
1030 [D loss: 0.211908]
1031 [D loss: 0.209995]
1032 [D loss: 0.217651]
1033 [D loss: 0.210062]
1034 [D loss: 0.215570]
1035 [D loss: 0.213503]
1036 [D loss: 0.224599]
1037 [D loss: 0.215067]
1038 [D loss: 0.214599]
1039 [D loss: 0.226404]
1040 [D loss: 0.215161]
1041 [D loss: 0.212181]
1042 [D loss: 0.212074]
1043 [D loss: 0.212778]
1044 [D loss: 0.215298]
1045 [D loss: 0.214949]
1046 [D loss: 0.219698]
1047 [D loss: 0.221258]
1048 [D loss: 0.239405]
1049 [D loss: 0.211642]
1050 [D loss: 0.211435]
1051 [D loss: 0.201970]
1052 [D loss: 0.209836]
1053 [D loss: 0.219132]
1054 [D loss: 0.208751]
1055 [D loss: 0.201959]
1056 [D loss: 0.202528]
1057 [D loss: 0.214745]
1058 [D loss: 0.225659]
1059 [D loss: 0.204885]
1060 [D loss: 0.212761]
1061 [D loss: 0.210271]
1062 [D loss: 0.213623]
1063 [D loss: 0.201632]
1064 [D loss: 0.206849]
1065 [D loss: 0.222771]
1066 [D loss: 0.217429]
1067 [D loss: 0.

1370 [D loss: 0.173201]
1371 [D loss: 0.183789]
1372 [D loss: 0.191526]
1373 [D loss: 0.170926]
1374 [D loss: 0.170532]
1375 [D loss: 0.176957]
1376 [D loss: 0.168597]
1377 [D loss: 0.179179]
1378 [D loss: 0.171248]
1379 [D loss: 0.178018]
1380 [D loss: 0.169137]
1381 [D loss: 0.174729]
1382 [D loss: 0.187421]
1383 [D loss: 0.172448]
1384 [D loss: 0.176062]
1385 [D loss: 0.177595]
1386 [D loss: 0.179332]
1387 [D loss: 0.184663]
1388 [D loss: 0.184386]
1389 [D loss: 0.184330]
1390 [D loss: 0.177455]
1391 [D loss: 0.170531]
1392 [D loss: 0.167825]
1393 [D loss: 0.182198]
1394 [D loss: 0.165150]
1395 [D loss: 0.166549]
1396 [D loss: 0.168568]
1397 [D loss: 0.182131]
1398 [D loss: 0.168667]
1399 [D loss: 0.167456]
1400 [D loss: 0.174079]
1401 [D loss: 0.179087]
1402 [D loss: 0.176983]
1403 [D loss: 0.170827]
1404 [D loss: 0.167206]
1405 [D loss: 0.182979]
1406 [D loss: 0.165452]
1407 [D loss: 0.171900]
1408 [D loss: 0.177457]
1409 [D loss: 0.171263]
1410 [D loss: 0.200337]
1411 [D loss: 0.

1712 [D loss: 0.138219]
1713 [D loss: 0.142521]
1714 [D loss: 0.140823]
1715 [D loss: 0.151038]
1716 [D loss: 0.137495]
1717 [D loss: 0.136055]
1718 [D loss: 0.139768]
1719 [D loss: 0.142040]
1720 [D loss: 0.162698]
1721 [D loss: 0.140292]
1722 [D loss: 0.165212]
1723 [D loss: 0.138037]
1724 [D loss: 0.135628]
1725 [D loss: 0.142089]
1726 [D loss: 0.135447]
1727 [D loss: 0.148231]
1728 [D loss: 0.147539]
1729 [D loss: 0.143690]
1730 [D loss: 0.138327]
1731 [D loss: 0.134403]
1732 [D loss: 0.137175]
1733 [D loss: 0.143948]
1734 [D loss: 0.130857]
1735 [D loss: 0.139430]
1736 [D loss: 0.159192]
1737 [D loss: 0.146272]
1738 [D loss: 0.137325]
1739 [D loss: 0.142312]
1740 [D loss: 0.153664]
1741 [D loss: 0.150267]
1742 [D loss: 0.182753]
1743 [D loss: 0.144195]
1744 [D loss: 0.132311]
1745 [D loss: 0.135225]
1746 [D loss: 0.146960]
1747 [D loss: 0.144899]
1748 [D loss: 0.135320]
1749 [D loss: 0.147367]
1750 [D loss: 0.141773]
1751 [D loss: 0.154700]
1752 [D loss: 0.145195]
1753 [D loss: 0.

2055 [D loss: 0.113704]
2056 [D loss: 0.131793]
2057 [D loss: 0.121421]
2058 [D loss: 0.117584]
2059 [D loss: 0.130651]
2060 [D loss: 0.130534]
2061 [D loss: 0.126701]
2062 [D loss: 0.118520]
2063 [D loss: 0.119789]
2064 [D loss: 0.120922]
2065 [D loss: 0.111520]
2066 [D loss: 0.113853]
2067 [D loss: 0.129886]
2068 [D loss: 0.111952]
2069 [D loss: 0.122714]
2070 [D loss: 0.112466]
2071 [D loss: 0.112289]
2072 [D loss: 0.126785]
2073 [D loss: 0.117215]
2074 [D loss: 0.126566]
2075 [D loss: 0.115612]
2076 [D loss: 0.119097]
2077 [D loss: 0.114297]
2078 [D loss: 0.116598]
2079 [D loss: 0.116808]
2080 [D loss: 0.115695]
2081 [D loss: 0.132155]
2082 [D loss: 0.125988]
2083 [D loss: 0.124806]
2084 [D loss: 0.116805]
2085 [D loss: 0.122999]
2086 [D loss: 0.113583]
2087 [D loss: 0.130351]
2088 [D loss: 0.119271]
2089 [D loss: 0.112186]
2090 [D loss: 0.121044]
2091 [D loss: 0.116867]
2092 [D loss: 0.135394]
2093 [D loss: 0.113575]
2094 [D loss: 0.123381]
2095 [D loss: 0.121722]
2096 [D loss: 0.

2397 [D loss: 0.130790]
2398 [D loss: 0.101592]
2399 [D loss: 0.107966]
2400 [D loss: 0.098641]
2401 [D loss: 0.099400]
2402 [D loss: 0.097981]
2403 [D loss: 0.107094]
2404 [D loss: 0.112917]
2405 [D loss: 0.091395]
2406 [D loss: 0.099571]
2407 [D loss: 0.099022]
2408 [D loss: 0.114232]
2409 [D loss: 0.106024]
2410 [D loss: 0.101333]
2411 [D loss: 0.097362]
2412 [D loss: 0.099492]
2413 [D loss: 0.112058]
2414 [D loss: 0.095442]
2415 [D loss: 0.099899]
2416 [D loss: 0.099182]
2417 [D loss: 0.093898]
2418 [D loss: 0.100472]
2419 [D loss: 0.099622]
2420 [D loss: 0.092038]
2421 [D loss: 0.100227]
2422 [D loss: 0.095083]
2423 [D loss: 0.109372]
2424 [D loss: 0.089432]
2425 [D loss: 0.097257]
2426 [D loss: 0.100823]
2427 [D loss: 0.106844]
2428 [D loss: 0.097385]
2429 [D loss: 0.099787]
2430 [D loss: 0.102010]
2431 [D loss: 0.107122]
2432 [D loss: 0.108288]
2433 [D loss: 0.094870]
2434 [D loss: 0.100347]
2435 [D loss: 0.156498]
2436 [D loss: 0.097250]
2437 [D loss: 0.091604]
2438 [D loss: 0.

2739 [D loss: 0.087376]
2740 [D loss: 0.087921]
2741 [D loss: 0.085179]
2742 [D loss: 0.093956]
2743 [D loss: 0.095700]
2744 [D loss: 0.084828]
2745 [D loss: 0.092068]
2746 [D loss: 0.108109]
2747 [D loss: 0.091764]
2748 [D loss: 0.089959]
2749 [D loss: 0.094359]
2750 [D loss: 0.089328]
2751 [D loss: 0.094382]
2752 [D loss: 0.093653]
2753 [D loss: 0.078456]
2754 [D loss: 0.086518]
2755 [D loss: 0.085249]
2756 [D loss: 0.083401]
2757 [D loss: 0.100475]
2758 [D loss: 0.096003]
2759 [D loss: 0.082706]
2760 [D loss: 0.129122]
2761 [D loss: 0.083108]
2762 [D loss: 0.079806]
2763 [D loss: 0.085056]
2764 [D loss: 0.096065]
2765 [D loss: 0.092121]
2766 [D loss: 0.095020]
2767 [D loss: 0.082395]
2768 [D loss: 0.082772]
2769 [D loss: 0.089564]
2770 [D loss: 0.089297]
2771 [D loss: 0.085790]
2772 [D loss: 0.102006]
2773 [D loss: 0.085879]
2774 [D loss: 0.090611]
2775 [D loss: 0.138562]
2776 [D loss: 0.098809]
2777 [D loss: 0.085164]
2778 [D loss: 0.106259]
2779 [D loss: 0.098825]
2780 [D loss: 0.

3082 [D loss: 0.075813]
3083 [D loss: 0.076121]
3084 [D loss: 0.081430]
3085 [D loss: 0.086103]
3086 [D loss: 0.083156]
3087 [D loss: 0.073696]
3088 [D loss: 0.077320]
3089 [D loss: 0.071791]
3090 [D loss: 0.087245]
3091 [D loss: 0.094174]
3092 [D loss: 0.085636]
3093 [D loss: 0.078124]
3094 [D loss: 0.071504]
3095 [D loss: 0.087078]
3096 [D loss: 0.089222]
3097 [D loss: 0.074176]
3098 [D loss: 0.075070]
3099 [D loss: 0.077217]
3100 [D loss: 0.072647]
3101 [D loss: 0.083948]
3102 [D loss: 0.075948]
3103 [D loss: 0.074642]
3104 [D loss: 0.087164]
3105 [D loss: 0.071938]
3106 [D loss: 0.073475]
3107 [D loss: 0.077331]
3108 [D loss: 0.078782]
3109 [D loss: 0.074335]
3110 [D loss: 0.083174]
3111 [D loss: 0.089078]
3112 [D loss: 0.080669]
3113 [D loss: 0.086186]
3114 [D loss: 0.073823]
3115 [D loss: 0.074039]
3116 [D loss: 0.071380]
3117 [D loss: 0.074647]
3118 [D loss: 0.077565]
3119 [D loss: 0.076694]
3120 [D loss: 0.076804]
3121 [D loss: 0.070010]
3122 [D loss: 0.076723]
3123 [D loss: 0.

3425 [D loss: 0.067857]
3426 [D loss: 0.068776]
3427 [D loss: 0.089280]
3428 [D loss: 0.073010]
3429 [D loss: 0.070535]
3430 [D loss: 0.072737]
3431 [D loss: 0.076834]
3432 [D loss: 0.082865]
3433 [D loss: 0.076740]
3434 [D loss: 0.066410]
3435 [D loss: 0.074508]
3436 [D loss: 0.075676]
3437 [D loss: 0.076536]
3438 [D loss: 0.064797]
3439 [D loss: 0.073579]
3440 [D loss: 0.077106]
3441 [D loss: 0.067659]
3442 [D loss: 0.088067]
3443 [D loss: 0.068826]
3444 [D loss: 0.075469]
3445 [D loss: 0.071770]
3446 [D loss: 0.065723]
3447 [D loss: 0.067162]
3448 [D loss: 0.070173]
3449 [D loss: 0.069123]
3450 [D loss: 0.066201]
3451 [D loss: 0.069046]
3452 [D loss: 0.068539]
3453 [D loss: 0.067551]
3454 [D loss: 0.072046]
3455 [D loss: 0.075789]
3456 [D loss: 0.081814]
3457 [D loss: 0.068574]
3458 [D loss: 0.074137]
3459 [D loss: 0.063670]
3460 [D loss: 0.066202]
3461 [D loss: 0.071024]
3462 [D loss: 0.070990]
3463 [D loss: 0.070266]
3464 [D loss: 0.066096]
3465 [D loss: 0.073608]
3466 [D loss: 0.

3769 [D loss: 0.061372]
3770 [D loss: 0.069626]
3771 [D loss: 0.070463]
3772 [D loss: 0.075045]
3773 [D loss: 0.068924]
3774 [D loss: 0.059648]
3775 [D loss: 0.059809]
3776 [D loss: 0.067094]
3777 [D loss: 0.069309]
3778 [D loss: 0.071714]
3779 [D loss: 0.061105]
3780 [D loss: 0.069299]
3781 [D loss: 0.069466]
3782 [D loss: 0.063277]
3783 [D loss: 0.069674]
3784 [D loss: 0.066097]
3785 [D loss: 0.061351]
3786 [D loss: 0.055794]
3787 [D loss: 0.067854]
3788 [D loss: 0.080144]
3789 [D loss: 0.065080]
3790 [D loss: 0.064857]
3791 [D loss: 0.075425]
3792 [D loss: 0.063513]
3793 [D loss: 0.068430]
3794 [D loss: 0.063267]
3795 [D loss: 0.080427]
3796 [D loss: 0.058870]
3797 [D loss: 0.070705]
3798 [D loss: 0.060380]
3799 [D loss: 0.061332]
3800 [D loss: 0.061034]
3801 [D loss: 0.073192]
3802 [D loss: 0.058497]
3803 [D loss: 0.067975]
3804 [D loss: 0.070438]
3805 [D loss: 0.065465]
3806 [D loss: 0.063098]
3807 [D loss: 0.060914]
3808 [D loss: 0.074345]
3809 [D loss: 0.064475]
3810 [D loss: 0.

4112 [D loss: 0.059605]
4113 [D loss: 0.056897]
4114 [D loss: 0.068356]
4115 [D loss: 0.062765]
4116 [D loss: 0.064354]
4117 [D loss: 0.058903]
4118 [D loss: 0.066395]
4119 [D loss: 0.059398]
4120 [D loss: 0.066744]
4121 [D loss: 0.075409]
4122 [D loss: 0.058267]
4123 [D loss: 0.057947]
4124 [D loss: 0.056233]
4125 [D loss: 0.170255]
4126 [D loss: 0.061189]
4127 [D loss: 0.059532]
4128 [D loss: 0.064129]
4129 [D loss: 0.060612]
4130 [D loss: 0.073809]
4131 [D loss: 0.062019]
4132 [D loss: 0.053660]
4133 [D loss: 0.055453]
4134 [D loss: 0.058052]
4135 [D loss: 0.058014]
4136 [D loss: 0.052880]
4137 [D loss: 0.083998]
4138 [D loss: 0.054033]
4139 [D loss: 0.052508]
4140 [D loss: 0.069026]
4141 [D loss: 0.066959]
4142 [D loss: 0.072583]
4143 [D loss: 0.057379]
4144 [D loss: 0.052328]
4145 [D loss: 0.055534]
4146 [D loss: 0.063395]
4147 [D loss: 0.057041]
4148 [D loss: 0.059654]
4149 [D loss: 0.055940]
4150 [D loss: 0.054483]
4151 [D loss: 0.053250]
4152 [D loss: 0.060648]
4153 [D loss: 0.

4455 [D loss: 0.050221]
4456 [D loss: 0.057728]
4457 [D loss: 0.051914]
4458 [D loss: 0.070288]
4459 [D loss: 0.054628]
4460 [D loss: 0.049156]
4461 [D loss: 0.050154]
4462 [D loss: 0.054322]
4463 [D loss: 0.056412]
4464 [D loss: 0.054834]
4465 [D loss: 0.064015]
4466 [D loss: 0.055726]
4467 [D loss: 0.049288]
4468 [D loss: 0.051707]
4469 [D loss: 0.050139]
4470 [D loss: 0.053871]
4471 [D loss: 0.049876]
4472 [D loss: 0.050911]
4473 [D loss: 0.049165]
4474 [D loss: 0.056839]
4475 [D loss: 0.065285]
4476 [D loss: 0.055012]
4477 [D loss: 0.054598]
4478 [D loss: 0.051834]
4479 [D loss: 0.052420]
4480 [D loss: 0.055180]
4481 [D loss: 0.054118]
4482 [D loss: 0.055770]
4483 [D loss: 0.059617]
4484 [D loss: 0.056394]
4485 [D loss: 0.058260]
4486 [D loss: 0.055937]
4487 [D loss: 0.060226]
4488 [D loss: 0.064578]
4489 [D loss: 0.063196]
4490 [D loss: 0.052043]
4491 [D loss: 0.051104]
4492 [D loss: 0.051533]
4493 [D loss: 0.049383]
4494 [D loss: 0.067831]
4495 [D loss: 0.058461]
4496 [D loss: 0.