In [1]:
# math
import numpy as np

# ml
import tensorflow as tf
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Lambda, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model

# visualization
import matplotlib.pyplot as plt

# aux
import sys
import os
import tqdm

# device check
from tensorflow.python.client import device_lib
print('Devices:', device_lib.list_local_devices())

%matplotlib inline

# GPU check: query the default GPU device name once and reuse the answer
# (an empty string means TensorFlow found no GPU).
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print('Default GPU Device: {}'.format(gpu_name))
else:
    print('No GPU found.')

print('Modules imported.')

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


Devices: [name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 6147287220424757222
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 1687760896
locality {
  bus_id: 1
}
incarnation: 17736817575887873542
physical_device_desc: "device: 0, name: GeForce GTX 660, pci bus id: 0000:01:00.0, compute capability: 3.0"
]
Default GPU Device: /device:GPU:0
Modules imported.


# Defining the graph

In [2]:
# Input dimensions of a single MNIST sample: height, width, colour channels.
img_rows, img_cols, channels = 28, 28, 1
# Kept as a list; it is passed as the `shape` of a Keras Input.
img_shape = [img_rows, img_cols, channels]

In [3]:
# Fetch the MNIST train/test split.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Map the training pixels to [-1, 1] (127.5 is half of 255) and append a channel axis
# at position 3.
# NOTE(review): X_test is left untouched here -- confirm that is intended.
X_train = np.expand_dims((X_train.astype(np.float32) - 127.5) / 127.5, axis=3)

In [4]:
def squash(vectors, axis=-1):
    """
    Capsule "squash" non-linearity: rescales every vector lying along `axis` so that
    long vectors end up with a length close to 1 and short vectors shrink towards 0.
    :param vectors: N-dim tensor holding the vectors to squash
    :param axis: the axis along which each vector lies (defaults to the last one)
    :return: a Tensor with the same shape as `vectors`
    """
    # Squared length of each vector; keepdims=True so it broadcasts against `vectors`.
    squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    # Fraction of unit length that the squashed vector should have.
    shrink = squared_norm / (1 + squared_norm)
    # K.epsilon() keeps the square root strictly positive for all-zero vectors.
    norm = K.sqrt(squared_norm + K.epsilon())
    return (shrink / norm) * vectors

In [5]:
def dynamic_routing(inputs, input_num_capsule=32, input_dim_vector=8, num_capsule=10, dim_vector=16):
    """
    Build a small fully-connected capsule model: image -> input capsules -> output capsules.

    Note that, despite the name, no iterative routing-by-agreement is performed here;
    the two capsule stages are plain Dense projections followed by `squash`.

    :param inputs: currently unused; kept so existing call sites keep working
    :param input_num_capsule: number of capsules in the first stage
    :param input_dim_vector: length of each first-stage capsule vector
    :param num_capsule: number of output capsules
    :param dim_vector: length of each output capsule vector
    :return: a Model mapping an image of shape `img_shape` to a tensor of shape
             (num_capsule, dim_vector)
    """
    x_inp = Input(shape=img_shape)

    # Dense only acts on the last axis, so the image has to be flattened first;
    # otherwise the later Reshape to (num_capsule, dim_vector) cannot succeed.
    x = Flatten()(x_inp)

    # First capsule stage: project, group into capsules, squash each capsule vector
    # (squash works along the last axis, i.e. per capsule after the Reshape).
    x = Dense(input_num_capsule * input_dim_vector)(x)
    x = Reshape((input_num_capsule, input_dim_vector))(x)
    x = Lambda(squash)(x)

    # Second capsule stage, same pattern.
    x = Flatten()(x)
    x = Dense(num_capsule * dim_vector)(x)
    # Reshape takes the target shape as ONE tuple and must be applied to the tensor.
    x = Reshape((num_capsule, dim_vector))(x)
    out_digs = Lambda(squash)(x)

    return Model(x_inp, out_digs)

In [5]:
# Capsule layout constants (same values as the defaults of dynamic_routing):
# number of input capsules and the length of each input capsule vector ...
input_num_capsule, input_dim_vector = 32, 8
# ... and number of output capsules and the length of each output capsule vector.
num_capsule, dim_vector = 10, 16

In [106]:
# discriminator structure
def build_discriminator():
    """
    Build the discriminator: a strided convolution followed by three
    conv -> squash -> dropout -> batch-norm stages and a single sigmoid unit.

    For a 28x28x1 input the first layers produce 14x14x32 (conv1), 7x7x256
    (second conv) and, after the first regrouping Reshape, 8x8x196.

    NOTE(review): Reshape((8, 8, -1)) regroups values across spatial positions and
    channels rather than keeping the convolutional grid -- confirm this is intended.
    NOTE(review): no CapsuleLayer ("digitcaps") stage is wired in; the last stage
    feeds Flatten/Dense directly.

    :return: a Model mapping an image to a validity score in (0, 1)
    """
    img_shape = (img_rows, img_cols, channels)
    img = Input(shape=img_shape)

    # Layer 1: Just a conventional Conv2D layer (with squash as its activation)
    x = Conv2D(filters=32, kernel_size=3, strides=2, padding='same', activation=squash, name='conv1')(img)

    # One entry per stage: (name index, conv filters, conv stride, regroup to 8x8 afterwards)
    stages = [
        (2, 4 * 64, 2, True),
        (3, 8 * 128, 2, True),
        (4, 16 * 256, 1, False),
    ]
    for idx, n_filters, stride, regroup in stages:
        x = Conv2D(filters=n_filters, kernel_size=3, strides=stride, padding='same',
                   name='primarycap_conv2d_%d' % idx)(x)
        x = Lambda(squash, name='primarycap_squash_%d' % idx)(x)
        x = Dropout(0.25)(x)
        x = BatchNormalization(momentum=0.8)(x)
        if regroup:
            # -1 lets Keras infer the channel count for the 8x8 layout.
            x = Reshape((8, 8, -1))(x)

    # Collapse everything into one real/fake probability.
    x = Flatten()(x)
    validity = Dense(1, activation='sigmoid')(x)

    return Model(img, validity)

In [107]:
# generator structure
def build_generator():
    """
    Build the generator: a 100-dim noise vector is projected to 7x7x128 and
    upsampled twice, ending in a single-channel tanh image of shape 28x28x1.

    Prints the summary of the inner Sequential stack as a side effect.

    :return: a Model mapping noise of shape (100,) to a generated image
    """
    noise_shape = (100,)

    stack = [
        # Project the noise and fold it into a 7x7 feature map with 128 channels.
        Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape),
        Reshape((7, 7, 128)),
        BatchNormalization(momentum=0.8),
        # 7x7 -> 14x14
        UpSampling2D(),
        Conv2D(128, kernel_size=3, padding="same"),
        Activation("relu"),
        BatchNormalization(momentum=0.8),
        # 14x14 -> 28x28
        UpSampling2D(),
        Conv2D(64, kernel_size=3, padding="same"),
        Activation("relu"),
        BatchNormalization(momentum=0.8),
        # Single output channel; tanh matches the [-1, 1] scaling of the training images.
        Conv2D(1, kernel_size=3, padding="same"),
        Activation("tanh"),
    ]

    model = Sequential()
    for layer in stack:
        model.add(layer)

    model.summary()

    noise = Input(shape=noise_shape)
    return Model(noise, model(noise))

In [108]:
# Adam optimizer; the two positional arguments are the learning rate (0.0002)
# and beta_1 (0.5). This single instance is passed to the compile() call of the
# discriminator, the generator and the combined model.
optimizer = Adam(0.0002, 0.5)

In [109]:
# Build and compile the discriminator.
# metrics=['accuracy'] makes train_on_batch return [loss, accuracy], which is
# what train() reads as d_loss[0] and d_loss[1].
discriminator = build_discriminator()
discriminator.summary()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_26 (InputLayer)        (None, 28, 28, 1)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 14, 14, 32)        320       
_________________________________________________________________
primarycap_conv2d_2 (Conv2D) (None, 7, 7, 256)         73984     
_________________________________________________________________
primarycap_squash_2 (Lambda) (None, 7, 7, 256)         0         
_________________________________________________________________
dropout_31 (Dropout)         (None, 7, 7, 256)         0         
_________________________________________________________________
batch_normalization_33 (Batc (None, 7, 7, 256)         1024      
_________________________________________________________________
reshape_19 (Reshape)         (None, 8, 8, 196)         0         
__________

In [110]:
# Build and compile the generator (build_generator also prints its summary).
# NOTE(review): train() and save_imgs() only call generator.predict, and the
# generator is updated through `combined`; this compile looks unnecessary -- confirm.
generator = build_generator()
generator.compile(loss='binary_crossentropy', optimizer=optimizer)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_16 (Dense)             (None, 6272)              633472    
_________________________________________________________________
reshape_21 (Reshape)         (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization_36 (Batc (None, 7, 7, 128)         512       
_________________________________________________________________
up_sampling2d_7 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147584    
_________________________________________________________________
activation_10 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
batch_normalization_37 (Batc (None, 14, 14, 128)       512       
__________

In [111]:
# Symbolic wiring: a 100-dim noise placeholder is fed through the generator.
# `img` is a symbolic tensor (printed below), not actual image data.
z = Input(shape=(100,))
img = generator(z)
print(img)

Tensor("model_19/sequential_4/activation_12/Tanh:0", shape=(?, 28, 28, 1), dtype=float32)


In [112]:
# For the combined model we only want to train the generator, so the
# discriminator is frozen before the combined model is built and compiled.
# NOTE(review): this flag is flipped after discriminator.compile(); presumably the
# standalone discriminator still trains because Keras collects trainable weights
# at compile time -- confirm for the Keras version in use.
discriminator.trainable = False

In [113]:
# Score the generated images with the discriminator (symbolic output tensor).
valid = discriminator(img)

In [114]:
# The combined model (stacked generator and discriminator) takes
# noise as input => generates images => determines validity.
# It is compiled without metrics, so train_on_batch returns a single loss value.
combined = Model(z, valid)
print('COMBINED:')
combined.summary()
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

COMBINED:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_28 (InputLayer)        (None, 100)               0         
_________________________________________________________________
model_19 (Model)             (None, 28, 28, 1)         856705    
_________________________________________________________________
model_18 (Model)             (None, 1)                 11606593  
Total params: 12,463,298
Trainable params: 856,065
Non-trainable params: 11,607,233
_________________________________________________________________


In [115]:
def train(epochs, batch_size=32, save_interval=50):
    """
    Run the adversarial training loop on MNIST.

    Each "epoch" here is a single iteration: one discriminator update on half a
    batch of real and half a batch of generated images, followed by one generator
    update (through `combined`) on a full batch of noise.

    Relies on the module-level `generator`, `discriminator`, `combined` models and
    on `save_imgs`.

    :param epochs: number of training iterations to run
    :param batch_size: samples per generator update; the discriminator sees
                       batch_size // 2 real and batch_size // 2 generated samples
    :param save_interval: a grid of generated samples is saved whenever
                          epoch % save_interval == 0
    :return: dict with one list per key ('d_loss', 'd_acc', 'g_loss'), holding one
             float per iteration
    :raises ValueError: if batch_size is smaller than 2 (the half batch would be empty)
    """
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))

    # Load the dataset (labels and the test split are not needed)
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1 and add the channel axis
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    half_batch = batch_size // 2

    # Target labels never change, so build them once instead of every iteration.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    # The generator wants the discriminator to label the generated samples as valid (ones)
    generator_targets = np.ones((batch_size, 1))

    history = {'d_loss': [], 'd_acc': [], 'g_loss': []}

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random half batch of images
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]

        # Generate a half batch of new images
        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_imgs = generator.predict(noise)

        # Train the discriminator; each call returns [loss, accuracy]
        d_loss_real = discriminator.train_on_batch(imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = combined.train_on_batch(noise, generator_targets)

        history['d_loss'].append(float(d_loss[0]))
        history['d_acc'].append(float(d_loss[1]))
        history['g_loss'].append(float(g_loss))

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        if epoch % save_interval == 0:
            save_imgs(epoch)

    return history

In [116]:
def save_imgs(epoch):
    """
    Generate a 5x5 grid of samples with the module-level `generator` and save it
    as images/mnist_<epoch>.png, creating the directory if needed.

    :param epoch: iteration number, used only in the output file name
    """
    directory = "images"

    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)

    # Rescale images from the generator's tanh range [-1, 1] to [0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the sample order.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(directory, exist_ok=True)

    # Build the path from `directory` so the two cannot drift apart.
    fig.savefig(os.path.join(directory, "mnist_%d.png" % epoch))
    # Close this specific figure so figures do not pile up across calls.
    plt.close(fig)

In [None]:
# Run 30000 training iterations with batches of 32; a sample grid is saved
# every 500 iterations.
history = train(epochs=30000, batch_size=32, save_interval=500)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.839026, acc.: 9.38%] [G loss: 1.379021]
1 [D loss: 0.348902, acc.: 53.12%] [G loss: 2.377234]
2 [D loss: 0.476814, acc.: 65.62%] [G loss: 2.562446]
3 [D loss: 6.636912, acc.: 50.00%] [G loss: 1.445375]
4 [D loss: 1.040923, acc.: 81.25%] [G loss: 1.777034]
5 [D loss: 0.498869, acc.: 90.62%] [G loss: 0.640130]
6 [D loss: 0.000503, acc.: 100.00%] [G loss: 0.593138]
7 [D loss: 0.294135, acc.: 96.88%] [G loss: 0.013981]
8 [D loss: 0.000009, acc.: 100.00%] [G loss: 0.127763]
9 [D loss: 0.000000, acc.: 100.00%] [G loss: 0.015772]
10 [D loss: 0.000002, acc.: 100.00%] [G loss: 0.139718]
11 [D loss: 0.000001, acc.: 100.00%] [G loss: 0.173629]
12 [D loss: 0.000103, acc.: 100.00%] [G loss: 0.005854]
13 [D loss: 0.010384, acc.: 100.00%] [G loss: 0.172729]
14 [D loss: 0.000282, acc.: 100.00%] [G loss: 0.000015]
15 [D loss: 0.047752, acc.: 96.88%] [G loss: 0.457836]
16 [D loss: 0.110405, acc.: 96.88%] [G loss: 0.028939]
17 [D loss: 0.064535, acc.: 96.88%] [G loss: 0.001133]
18 [D loss: 1

147 [D loss: 0.000000, acc.: 100.00%] [G loss: 13.638155]
148 [D loss: 0.008160, acc.: 100.00%] [G loss: 14.962907]
149 [D loss: 0.498200, acc.: 96.88%] [G loss: 14.712296]
150 [D loss: 0.996399, acc.: 93.75%] [G loss: 13.504635]
151 [D loss: 0.000000, acc.: 100.00%] [G loss: 14.548648]
152 [D loss: 0.000000, acc.: 100.00%] [G loss: 13.184855]
153 [D loss: 0.440861, acc.: 93.75%] [G loss: 15.104993]
154 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.337946]
155 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.495682]
156 [D loss: 0.498200, acc.: 96.88%] [G loss: 14.738821]
157 [D loss: 0.101910, acc.: 96.88%] [G loss: 15.253975]
158 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.614407]
159 [D loss: 0.000000, acc.: 100.00%] [G loss: 14.678848]
160 [D loss: 0.000000, acc.: 100.00%] [G loss: 14.607024]
161 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
162 [D loss: 0.006849, acc.: 100.00%] [G loss: 14.103333]
163 [D loss: 0.498200, acc.: 96.88%] [G loss: 16.118095]
164 [D loss: 0.99639

290 [D loss: 1.494599, acc.: 90.62%] [G loss: 13.205391]
291 [D loss: 0.608623, acc.: 93.75%] [G loss: 12.594759]
292 [D loss: 0.164858, acc.: 96.88%] [G loss: 14.661697]
293 [D loss: 1.058335, acc.: 90.62%] [G loss: 16.118095]
294 [D loss: 0.219072, acc.: 96.88%] [G loss: 16.118095]
295 [D loss: 0.116462, acc.: 96.88%] [G loss: 16.099258]
296 [D loss: 0.000081, acc.: 100.00%] [G loss: 16.118095]
297 [D loss: 0.503691, acc.: 96.88%] [G loss: 16.118095]
298 [D loss: 0.503709, acc.: 96.88%] [G loss: 16.118095]
299 [D loss: 1.008345, acc.: 93.75%] [G loss: 16.118095]
300 [D loss: 0.029096, acc.: 96.88%] [G loss: 16.118095]
301 [D loss: 0.498280, acc.: 96.88%] [G loss: 15.614405]
302 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
303 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.874058]
304 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
305 [D loss: 0.000001, acc.: 100.00%] [G loss: 15.638876]
306 [D loss: 0.008047, acc.: 100.00%] [G loss: 16.118095]
307 [D loss: 0.000000, ac

434 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.507532]
435 [D loss: 6.974793, acc.: 56.25%] [G loss: 0.503691]
436 [D loss: 7.472993, acc.: 53.12%] [G loss: 0.000000]
437 [D loss: 6.974793, acc.: 56.25%] [G loss: 0.000000]
438 [D loss: 7.472993, acc.: 53.12%] [G loss: 0.503691]
439 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.000000]
440 [D loss: 7.472993, acc.: 53.12%] [G loss: 0.870756]
441 [D loss: 7.971192, acc.: 50.00%] [G loss: 1.011464]
442 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.000000]
443 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.000000]
444 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.503691]
445 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.503691]
446 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.503710]
447 [D loss: 7.971192, acc.: 50.00%] [G loss: 2.014762]
448 [D loss: 6.476594, acc.: 59.38%] [G loss: 1.007381]
449 [D loss: 7.472993, acc.: 53.12%] [G loss: 0.212057]
450 [D loss: 7.971192, acc.: 50.00%] [G loss: 0.503691]
451 [D loss: 7.472993, acc.: 53.12%] [G loss: 0.

580 [D loss: 4.489287, acc.: 71.88%] [G loss: 5.959420]
581 [D loss: 5.520773, acc.: 62.50%] [G loss: 9.559195]
582 [D loss: 5.413846, acc.: 59.38%] [G loss: 11.927180]
583 [D loss: 5.507875, acc.: 65.62%] [G loss: 12.608532]
584 [D loss: 7.076441, acc.: 53.12%] [G loss: 8.059080]
585 [D loss: 5.024969, acc.: 65.62%] [G loss: 7.744893]
586 [D loss: 4.303747, acc.: 71.88%] [G loss: 12.088031]
587 [D loss: 7.943032, acc.: 50.00%] [G loss: 11.283926]
588 [D loss: 6.782081, acc.: 56.25%] [G loss: 10.073810]
589 [D loss: 5.839043, acc.: 62.50%] [G loss: 7.051667]
590 [D loss: 5.485750, acc.: 65.62%] [G loss: 6.777018]
591 [D loss: 5.491205, acc.: 65.62%] [G loss: 5.540595]
592 [D loss: 4.981996, acc.: 68.75%] [G loss: 5.896921]
593 [D loss: 4.982167, acc.: 68.75%] [G loss: 5.540595]
594 [D loss: 5.480195, acc.: 65.62%] [G loss: 6.547988]
595 [D loss: 4.441031, acc.: 71.88%] [G loss: 5.037030]
596 [D loss: 6.904419, acc.: 56.25%] [G loss: 5.036905]
597 [D loss: 6.476762, acc.: 59.38%] [G los

725 [D loss: 9.096541, acc.: 40.62%] [G loss: 15.110714]
726 [D loss: 9.055447, acc.: 43.75%] [G loss: 15.110716]
727 [D loss: 8.557247, acc.: 46.88%] [G loss: 15.111976]
728 [D loss: 8.784017, acc.: 43.75%] [G loss: 15.628025]
729 [D loss: 8.557247, acc.: 46.88%] [G loss: 14.924149]
730 [D loss: 9.553646, acc.: 40.62%] [G loss: 15.614405]
731 [D loss: 8.238914, acc.: 46.88%] [G loss: 16.118095]
732 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.105194]
733 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
734 [D loss: 8.059048, acc.: 50.00%] [G loss: 15.614405]
735 [D loss: 8.059048, acc.: 50.00%] [G loss: 15.614405]
736 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
737 [D loss: 8.059048, acc.: 50.00%] [G loss: 15.657887]
738 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
739 [D loss: 8.059048, acc.: 50.00%] [G loss: 15.113022]
740 [D loss: 8.059048, acc.: 50.00%] [G loss: 15.614405]
741 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.062578]
742 [D loss: 8.059048, acc.: 50

868 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
869 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
870 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
871 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
872 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
873 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
874 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
875 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
876 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
877 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
878 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
879 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
880 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
881 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
882 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
883 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
884 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
885 [D loss: 8.059048, acc.: 50

1011 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1012 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1013 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1014 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1015 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1016 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1017 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1018 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1019 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1020 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1021 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1022 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1023 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1024 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1025 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1026 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1027 [D loss: 8.059048, acc.: 50.00%] [G loss: 16.118095]
1028 [D loss: 