In [1]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
class DCGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(100,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train.astype(np.float32) / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            # Sample noise and generate a half batch of new images
            noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator (real classified as ones and generated as zeros)
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Sample generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (wants discriminator to mistake images as real)
            g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        #fig.suptitle("DCGAN: Generated digits", fontsize=12)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()

In [None]:
if __name__ == '__main__':
    dcgan = DCGAN()
    dcgan.train(epochs=4000, batch_size=32, save_interval=50)

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_8 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
zero_padding2d_2 (ZeroPaddin (None, 8, 8, 64)          0         
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 64)          256       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 64)         

  'Discrepancy between trainable weights and collected trainable'
  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 1.148359, acc.: 21.88%] [G loss: 0.990779]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.844449, acc.: 46.88%] [G loss: 0.914905]
2 [D loss: 0.660298, acc.: 56.25%] [G loss: 0.823153]
3 [D loss: 0.697182, acc.: 62.50%] [G loss: 1.441353]
4 [D loss: 0.641947, acc.: 59.38%] [G loss: 1.397679]
5 [D loss: 0.565258, acc.: 71.88%] [G loss: 1.211776]
6 [D loss: 0.513829, acc.: 71.88%] [G loss: 0.933647]
7 [D loss: 0.353127, acc.: 87.50%] [G loss: 1.159211]
8 [D loss: 0.540959, acc.: 84.38%] [G loss: 1.163827]
9 [D loss: 0.627752, acc.: 65.62%] [G loss: 0.989400]
10 [D loss: 0.575173, acc.: 68.75%] [G loss: 0.916221]
11 [D loss: 0.407147, acc.: 84.38%] [G loss: 0.952731]
12 [D loss: 0.690662, acc.: 59.38%] [G loss: 0.994638]
13 [D loss: 0.710530, acc.: 65.62%] [G loss: 0.804999]
14 [D loss: 1.020631, acc.: 53.12%] [G loss: 0.687528]
15 [D loss: 0.581943, acc.: 75.00%] [G loss: 0.755136]
16 [D loss: 1.022227, acc.: 28.12%] [G loss: 0.925201]
17 [D loss: 0.957538, acc.: 40.62%] [G loss: 1.116328]
18 [D loss: 0.606571, acc.: 62.50%] [G loss: 1.440040]
19 [D loss: 0.65917

151 [D loss: 0.752723, acc.: 56.25%] [G loss: 1.076088]
152 [D loss: 0.833663, acc.: 56.25%] [G loss: 1.379928]
153 [D loss: 0.763785, acc.: 59.38%] [G loss: 1.216852]
154 [D loss: 0.769362, acc.: 59.38%] [G loss: 0.949765]
155 [D loss: 0.969296, acc.: 40.62%] [G loss: 0.724516]
156 [D loss: 0.929952, acc.: 43.75%] [G loss: 0.915444]
157 [D loss: 0.791298, acc.: 53.12%] [G loss: 1.240051]
158 [D loss: 0.731949, acc.: 53.12%] [G loss: 1.220877]
159 [D loss: 0.978541, acc.: 34.38%] [G loss: 0.942580]
160 [D loss: 0.883829, acc.: 40.62%] [G loss: 0.945182]
161 [D loss: 0.927619, acc.: 50.00%] [G loss: 0.718518]
162 [D loss: 0.997613, acc.: 50.00%] [G loss: 1.040340]
163 [D loss: 0.898997, acc.: 50.00%] [G loss: 1.068262]
164 [D loss: 0.880921, acc.: 43.75%] [G loss: 1.013907]
165 [D loss: 0.828575, acc.: 56.25%] [G loss: 0.847318]
166 [D loss: 0.975275, acc.: 43.75%] [G loss: 0.864780]
167 [D loss: 0.826729, acc.: 46.88%] [G loss: 0.866853]
168 [D loss: 0.862551, acc.: 50.00%] [G loss: 0.

301 [D loss: 0.920223, acc.: 37.50%] [G loss: 1.046093]
302 [D loss: 0.909049, acc.: 34.38%] [G loss: 1.202247]
303 [D loss: 1.012792, acc.: 34.38%] [G loss: 0.932703]
304 [D loss: 0.791550, acc.: 46.88%] [G loss: 1.066933]
305 [D loss: 0.914955, acc.: 50.00%] [G loss: 0.908118]
306 [D loss: 0.786333, acc.: 53.12%] [G loss: 1.022485]
307 [D loss: 0.780288, acc.: 50.00%] [G loss: 1.056467]
308 [D loss: 0.838594, acc.: 50.00%] [G loss: 0.803846]
309 [D loss: 0.798588, acc.: 50.00%] [G loss: 0.782498]
310 [D loss: 0.880372, acc.: 46.88%] [G loss: 0.950231]
311 [D loss: 0.763916, acc.: 56.25%] [G loss: 0.696178]
312 [D loss: 0.728215, acc.: 56.25%] [G loss: 1.044043]
313 [D loss: 0.742991, acc.: 46.88%] [G loss: 1.124492]
314 [D loss: 0.784671, acc.: 53.12%] [G loss: 0.992454]
315 [D loss: 0.819511, acc.: 50.00%] [G loss: 0.905306]
316 [D loss: 0.828331, acc.: 50.00%] [G loss: 0.704213]
317 [D loss: 0.861271, acc.: 43.75%] [G loss: 0.998016]
318 [D loss: 1.033657, acc.: 25.00%] [G loss: 1.

451 [D loss: 0.564687, acc.: 68.75%] [G loss: 0.879614]
452 [D loss: 0.698155, acc.: 56.25%] [G loss: 0.833788]
453 [D loss: 0.926760, acc.: 28.12%] [G loss: 0.812065]
454 [D loss: 0.738143, acc.: 62.50%] [G loss: 0.838694]
455 [D loss: 0.865236, acc.: 40.62%] [G loss: 0.906785]
456 [D loss: 0.911875, acc.: 37.50%] [G loss: 1.026815]
457 [D loss: 0.771083, acc.: 53.12%] [G loss: 1.147197]
458 [D loss: 0.872380, acc.: 31.25%] [G loss: 1.189011]
459 [D loss: 0.602719, acc.: 62.50%] [G loss: 1.280451]
460 [D loss: 0.819551, acc.: 43.75%] [G loss: 1.043511]
461 [D loss: 0.928092, acc.: 37.50%] [G loss: 0.829206]
462 [D loss: 0.594969, acc.: 65.62%] [G loss: 0.761659]
463 [D loss: 0.815826, acc.: 37.50%] [G loss: 0.915486]
464 [D loss: 0.812736, acc.: 46.88%] [G loss: 1.037505]
465 [D loss: 0.811303, acc.: 43.75%] [G loss: 1.040223]
466 [D loss: 0.827406, acc.: 43.75%] [G loss: 0.973311]
467 [D loss: 0.809948, acc.: 50.00%] [G loss: 1.063476]
468 [D loss: 0.854889, acc.: 56.25%] [G loss: 0.

601 [D loss: 0.844074, acc.: 40.62%] [G loss: 1.040664]
602 [D loss: 0.678228, acc.: 59.38%] [G loss: 0.993440]
603 [D loss: 0.843323, acc.: 46.88%] [G loss: 0.883795]
604 [D loss: 0.835433, acc.: 46.88%] [G loss: 0.983314]
605 [D loss: 0.810059, acc.: 43.75%] [G loss: 0.883666]
606 [D loss: 0.684153, acc.: 56.25%] [G loss: 0.798817]
607 [D loss: 0.712030, acc.: 50.00%] [G loss: 0.941178]
608 [D loss: 0.709936, acc.: 56.25%] [G loss: 0.998260]
609 [D loss: 0.817102, acc.: 46.88%] [G loss: 0.982737]
610 [D loss: 0.890321, acc.: 43.75%] [G loss: 1.041978]
611 [D loss: 0.738179, acc.: 46.88%] [G loss: 0.989702]
612 [D loss: 0.821666, acc.: 53.12%] [G loss: 0.902773]
613 [D loss: 0.808482, acc.: 43.75%] [G loss: 0.923298]
614 [D loss: 0.814465, acc.: 50.00%] [G loss: 1.038082]
615 [D loss: 0.757267, acc.: 53.12%] [G loss: 0.879972]
616 [D loss: 0.774989, acc.: 37.50%] [G loss: 0.918246]
617 [D loss: 0.741929, acc.: 46.88%] [G loss: 0.892204]
618 [D loss: 0.766255, acc.: 50.00%] [G loss: 0.

751 [D loss: 0.669143, acc.: 59.38%] [G loss: 1.077587]
752 [D loss: 0.748564, acc.: 50.00%] [G loss: 1.021429]
753 [D loss: 0.794448, acc.: 50.00%] [G loss: 1.001646]
754 [D loss: 0.829202, acc.: 40.62%] [G loss: 0.952267]
755 [D loss: 0.734504, acc.: 59.38%] [G loss: 1.014950]
756 [D loss: 0.750173, acc.: 46.88%] [G loss: 0.896929]
757 [D loss: 0.757075, acc.: 56.25%] [G loss: 0.820738]
758 [D loss: 0.766146, acc.: 40.62%] [G loss: 0.789742]
759 [D loss: 0.736167, acc.: 62.50%] [G loss: 0.893753]
760 [D loss: 0.779810, acc.: 43.75%] [G loss: 0.880179]
761 [D loss: 1.009141, acc.: 37.50%] [G loss: 0.858683]
762 [D loss: 0.789930, acc.: 50.00%] [G loss: 1.020809]
763 [D loss: 0.762673, acc.: 53.12%] [G loss: 0.990550]
764 [D loss: 0.833675, acc.: 37.50%] [G loss: 0.973320]
765 [D loss: 0.810632, acc.: 40.62%] [G loss: 0.783085]
766 [D loss: 0.726945, acc.: 53.12%] [G loss: 0.856160]
767 [D loss: 0.700914, acc.: 53.12%] [G loss: 0.901650]
768 [D loss: 0.615228, acc.: 65.62%] [G loss: 0.

901 [D loss: 0.805773, acc.: 46.88%] [G loss: 0.880099]
902 [D loss: 0.841963, acc.: 40.62%] [G loss: 0.915545]
903 [D loss: 0.834787, acc.: 40.62%] [G loss: 1.063383]
904 [D loss: 0.648931, acc.: 68.75%] [G loss: 0.939006]
905 [D loss: 0.708981, acc.: 62.50%] [G loss: 1.058537]
906 [D loss: 0.793892, acc.: 40.62%] [G loss: 0.937756]
907 [D loss: 0.739876, acc.: 56.25%] [G loss: 1.054374]
908 [D loss: 0.749382, acc.: 53.12%] [G loss: 1.004531]
909 [D loss: 0.798939, acc.: 40.62%] [G loss: 0.912733]
910 [D loss: 0.772642, acc.: 50.00%] [G loss: 0.879463]
911 [D loss: 0.814748, acc.: 43.75%] [G loss: 1.050684]
912 [D loss: 0.792562, acc.: 43.75%] [G loss: 0.984359]
913 [D loss: 0.613628, acc.: 78.12%] [G loss: 1.018261]
914 [D loss: 0.747097, acc.: 53.12%] [G loss: 0.831286]
915 [D loss: 0.728940, acc.: 46.88%] [G loss: 1.044839]
916 [D loss: 0.840440, acc.: 31.25%] [G loss: 1.073790]
917 [D loss: 0.718337, acc.: 50.00%] [G loss: 1.023030]
918 [D loss: 0.800343, acc.: 53.12%] [G loss: 0.

1051 [D loss: 0.758702, acc.: 59.38%] [G loss: 0.960357]
1052 [D loss: 0.626934, acc.: 71.88%] [G loss: 0.838841]
1053 [D loss: 0.698399, acc.: 56.25%] [G loss: 0.975618]
1054 [D loss: 0.718616, acc.: 50.00%] [G loss: 0.756481]
1055 [D loss: 0.749007, acc.: 56.25%] [G loss: 0.902100]
1056 [D loss: 0.740789, acc.: 68.75%] [G loss: 1.092706]
1057 [D loss: 0.732739, acc.: 65.62%] [G loss: 1.011984]
1058 [D loss: 0.800597, acc.: 43.75%] [G loss: 0.839484]
1059 [D loss: 0.571110, acc.: 68.75%] [G loss: 0.844804]
1060 [D loss: 0.624616, acc.: 68.75%] [G loss: 1.025043]
1061 [D loss: 0.826424, acc.: 40.62%] [G loss: 0.840274]
1062 [D loss: 0.755385, acc.: 46.88%] [G loss: 0.853069]
1063 [D loss: 0.662757, acc.: 68.75%] [G loss: 0.995262]
1064 [D loss: 0.688546, acc.: 62.50%] [G loss: 0.959428]
1065 [D loss: 0.781752, acc.: 43.75%] [G loss: 1.110161]
1066 [D loss: 0.561707, acc.: 81.25%] [G loss: 0.837591]
1067 [D loss: 0.815264, acc.: 37.50%] [G loss: 0.847651]
1068 [D loss: 0.760349, acc.: 4

1196 [D loss: 0.663200, acc.: 62.50%] [G loss: 1.015478]
1197 [D loss: 0.693101, acc.: 59.38%] [G loss: 0.900995]
1198 [D loss: 0.823031, acc.: 37.50%] [G loss: 0.831017]
1199 [D loss: 0.728226, acc.: 59.38%] [G loss: 0.896684]
1200 [D loss: 0.827910, acc.: 46.88%] [G loss: 0.888012]
1201 [D loss: 0.802978, acc.: 50.00%] [G loss: 0.855495]
1202 [D loss: 0.609157, acc.: 59.38%] [G loss: 0.963837]
1203 [D loss: 0.742217, acc.: 59.38%] [G loss: 0.958601]
1204 [D loss: 0.695911, acc.: 65.62%] [G loss: 0.878500]
1205 [D loss: 0.644594, acc.: 62.50%] [G loss: 0.906100]
1206 [D loss: 0.771676, acc.: 46.88%] [G loss: 0.752431]
1207 [D loss: 0.733099, acc.: 53.12%] [G loss: 1.042437]
1208 [D loss: 0.651114, acc.: 65.62%] [G loss: 0.966604]
1209 [D loss: 0.742991, acc.: 50.00%] [G loss: 0.985792]
1210 [D loss: 0.728574, acc.: 53.12%] [G loss: 0.891964]
1211 [D loss: 0.761366, acc.: 46.88%] [G loss: 0.956938]
1212 [D loss: 0.806498, acc.: 43.75%] [G loss: 0.918888]
1213 [D loss: 0.812444, acc.: 4

In [15]:
from IPython.display import display, Image, HTML
def Display_Images(images, header=None, width="100%"): # to match Image syntax
    if type(width)==type(1): width = "{}px".format(width)
    html = ["<table style='width:{}'><tr>".format(width)]
    if header is not None:
        html += ["<th>{}</th>".format(h) for h in header] + ["</tr><tr>"]

    cols=1
    for image in images:
        print(image)
        html.append("<td><img src='{}' /></td>".format(image))
        cols+=1
        if (cols>3):
            html.append("</tr><tr>")
            cols=1
    html.append("</tr></table>")
    display(HTML(''.join(html)))

In [16]:
import glob
file_map=[]
for image_path in glob.glob('./' + "/images/*.png"):
    file_map.append(image_path)

In [17]:
len(file_map)

80

In [20]:
indices=[0,10,20,30,40,50,60,70,79]
Display_Images([file_map[i] for i in indices],width="100%")

.//images/mnist_0.png
.//images/mnist_1400.png
.//images/mnist_1850.png
.//images/mnist_2300.png
.//images/mnist_2750.png
.//images/mnist_3200.png
.//images/mnist_3650.png
.//images/mnist_500.png
.//images/mnist_950.png


In [19]:
display()