In [1]:
import numpy as np

import tensorflow as tf

from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model

import matplotlib.pyplot as plt
import sys
import os
import tqdm

from tensorflow.python.client import device_lib
print('Devices:', device_lib.list_local_devices())

%matplotlib inline

# check for a GPU
if not tf.test.gpu_device_name():
    print('No GPU found.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))

print('Modules imported.')

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


Devices: [name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 15880288978476585256
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 1693306880
locality {
  bus_id: 1
}
incarnation: 3910100126469382704
physical_device_desc: "device: 0, name: GeForce GTX 660, pci bus id: 0000:01:00.0, compute capability: 3.0"
]
Default GPU Device: /device:GPU:0
Modules imported.


## GAN

In [2]:
width = 28
height = 28
channels = 1

In [17]:
# discriminator structure
def build_discriminator():

        img_shape = (width, height, channels)
        
        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

In [18]:
# generator structure
def build_generator():

        noise_shape = (100,)
        
        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))
        model.add(Reshape((28, 28, 1)))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

In [19]:
# defining an optimizer
optimizer = Adam(0.0002, 0.5)

In [20]:
# build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________


In [21]:
# build and compile the generator
generator = build_generator()
generator.compile(loss='binary_crossentropy', optimizer=optimizer)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_11 (Dense)             (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 256)               1024      
_________________________________________________________________
dense_12 (Dense)             (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_5 (Batch (None, 512)               2048      
_________________________________________________________________
dense_13 (Dense)             (None, 1024)              525312    
__________

In [8]:
# feeding noise to generator
z = Input(shape=(100,))
img = generator(z)

In [9]:
# for the combined model we will only train the generator
discriminator.trainable = False

In [10]:
# try to discriminate generated images
valid = discriminator(img)

In [11]:
# The combined model  (stacked generator and discriminator) takes
# noise as input => generates images => determines validity 
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

In [12]:
def train(epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # Generate a half batch of new images
            gen_imgs = generator.predict(noise)

            # Train the discriminator
            d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator
            g_loss = combined.train_on_batch(noise, valid_y)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                save_imgs(epoch)

In [13]:
def save_imgs(epoch):
        directory = "images"
        
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        
        if not os.path.exists(directory):
            os.makedirs(directory)
            
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()

In [15]:
history = train(epochs=30000, batch_size=128, save_interval=1000)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.618875, acc.: 67.97%] [G loss: 0.802113]
1 [D loss: 0.639175, acc.: 64.06%] [G loss: 0.823674]
2 [D loss: 0.629066, acc.: 66.41%] [G loss: 0.825397]
3 [D loss: 0.651026, acc.: 56.25%] [G loss: 0.807527]
4 [D loss: 0.641097, acc.: 60.16%] [G loss: 0.827198]
5 [D loss: 0.631091, acc.: 63.28%] [G loss: 0.852882]
6 [D loss: 0.609378, acc.: 71.88%] [G loss: 0.877632]
7 [D loss: 0.625812, acc.: 69.53%] [G loss: 0.887825]
8 [D loss: 0.619630, acc.: 71.88%] [G loss: 0.877363]
9 [D loss: 0.624530, acc.: 67.97%] [G loss: 0.859427]
10 [D loss: 0.623899, acc.: 68.75%] [G loss: 0.849812]
11 [D loss: 0.626899, acc.: 66.41%] [G loss: 0.872907]
12 [D loss: 0.621005, acc.: 69.53%] [G loss: 0.842928]
13 [D loss: 0.613719, acc.: 69.53%] [G loss: 0.834394]
14 [D loss: 0.606841, acc.: 72.66%] [G loss: 0.849864]
15 [D loss: 0.603702, acc.: 68.75%] [G loss: 0.856017]
16 [D loss: 0.593130, acc.: 74.22%] [G loss: 0.850864]
17 [D loss: 0.607741, acc.: 66.41%] [G loss: 0.834288]
18 [D loss: 0.611515

294 [D loss: 0.565179, acc.: 76.56%] [G loss: 0.959823]
295 [D loss: 0.593797, acc.: 74.22%] [G loss: 0.935532]
296 [D loss: 0.574642, acc.: 72.66%] [G loss: 0.961534]
297 [D loss: 0.559794, acc.: 77.34%] [G loss: 0.986102]
298 [D loss: 0.597754, acc.: 75.00%] [G loss: 0.932431]
299 [D loss: 0.580030, acc.: 70.31%] [G loss: 0.924178]
300 [D loss: 0.582159, acc.: 71.88%] [G loss: 0.910555]
301 [D loss: 0.585752, acc.: 73.44%] [G loss: 0.928123]
302 [D loss: 0.610191, acc.: 67.19%] [G loss: 0.917465]
303 [D loss: 0.585693, acc.: 69.53%] [G loss: 0.908503]
304 [D loss: 0.595971, acc.: 65.62%] [G loss: 0.893332]
305 [D loss: 0.574341, acc.: 71.88%] [G loss: 0.901649]
306 [D loss: 0.620162, acc.: 67.19%] [G loss: 0.949387]
307 [D loss: 0.586688, acc.: 78.12%] [G loss: 0.902039]
308 [D loss: 0.557538, acc.: 73.44%] [G loss: 0.911080]
309 [D loss: 0.572605, acc.: 75.00%] [G loss: 0.961598]
310 [D loss: 0.598426, acc.: 65.62%] [G loss: 0.911406]
311 [D loss: 0.579988, acc.: 70.31%] [G loss: 0.

586 [D loss: 0.573569, acc.: 78.12%] [G loss: 1.059465]
587 [D loss: 0.617965, acc.: 70.31%] [G loss: 1.004442]
588 [D loss: 0.596242, acc.: 71.88%] [G loss: 1.017766]
589 [D loss: 0.623021, acc.: 64.06%] [G loss: 1.003483]
590 [D loss: 0.581931, acc.: 68.75%] [G loss: 1.090993]
591 [D loss: 0.582219, acc.: 71.88%] [G loss: 1.070310]
592 [D loss: 0.558309, acc.: 72.66%] [G loss: 1.074763]
593 [D loss: 0.582613, acc.: 75.00%] [G loss: 1.051298]
594 [D loss: 0.566128, acc.: 77.34%] [G loss: 1.057800]
595 [D loss: 0.574451, acc.: 72.66%] [G loss: 1.010246]
596 [D loss: 0.578024, acc.: 70.31%] [G loss: 0.977121]
597 [D loss: 0.584743, acc.: 74.22%] [G loss: 0.968903]
598 [D loss: 0.617654, acc.: 63.28%] [G loss: 1.000791]
599 [D loss: 0.569237, acc.: 75.00%] [G loss: 0.973988]
600 [D loss: 0.586063, acc.: 72.66%] [G loss: 0.974901]
601 [D loss: 0.578633, acc.: 72.66%] [G loss: 0.945917]
602 [D loss: 0.608126, acc.: 63.28%] [G loss: 0.957294]
603 [D loss: 0.582138, acc.: 75.00%] [G loss: 0.

KeyboardInterrupt: 

## Capsule Net

In [13]:
def CapsNet(input_shape, n_class, num_routing):
    """
    A Capsule Network on MNIST.
    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param num_routing: number of routing iterations
    :return: A Keras Model with 2 inputs and 2 outputs
    """
    img = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=8, kernel_size=5, strides=1, padding='valid', activation='relu', name='conv1')(img)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_vector]
    primarycaps = PrimaryCap(conv1, dim_vector=8, n_channels=8, kernel_size=9, strides=2, padding='valid')

    # Layer 3: Capsule layer. Routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_vector=16, num_routing=num_routing, name='digitcaps')(primarycaps)
    
    x = layers.Flatten()(digitcaps)
    prediction = Dense(1, activation='sigmoid')(x)

    # two-input-two-output keras Model
    return models.Model(img, prediction)

discriminator = CapsNet([28, 28, 1], 10, 3)
discriminator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 24, 24, 8)         208       
_________________________________________________________________
primarycap_conv2d (Conv2D)   (None, 8, 8, 64)          41536     
_________________________________________________________________
primarycap_reshape (Reshape) (None, 512, 8)            0         
_________________________________________________________________
primarycap_squash (Lambda)   (None, 512, 8)            0         
_________________________________________________________________
digitcaps (CapsuleLayer)     (None, 10, 16)            660480    
_________________________________________________________________
flatten_3 (Flatten)          (None, 160)               0         
__________

In [14]:
def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))

In [15]:
def train(model, data):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=save_dir + '/tensorboard-logs',
                               batch_size=batch_size, histogram_freq=debug)
    checkpoint = callbacks.ModelCheckpoint(save_dir + '/weights-{epoch:02d}.h5',
                                           save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: lr * (0.9 ** epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=lr),
                  loss=[margin_loss],
                  metrics=['accuracy'])

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. If shift_fraction=0., also no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, batch_size, shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / batch_size),
                        epochs=epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(save_dir + '/trained_model_capsnet.h5')
    print('Trained model saved to \'%s/trained_model_capsnet.h5\'' % save_dir)

    from utils import plot_log
    plot_log(save_dir + '/log.csv', show=True)

    return model

In [16]:
def test(model, data):
    x_test, y_test = data
    y_pred, x_recon = model.predict([x_test, y_test], batch_size=100)
    print('-'*50)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

    import matplotlib.pyplot as plt
    from utils import combine_images
    from PIL import Image

    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save("real_and_recon.png")
    print()
    print('Reconstructed images are saved to ./real_and_recon.png')
    print('-'*50)
    plt.imshow(plt.imread("real_and_recon.png", ))
    plt.show()

In [17]:
def load_mnist():
    # the data, shuffled and split between train and test sets
    from keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
    x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
    y_train = to_categorical(y_train.astype('float32'))
    y_test = to_categorical(y_test.astype('float32'))
    return (x_train, y_train), (x_test, y_test)

In [18]:
batch_size = 32
epochs = 30
lam_recon = 0.392  # 784 * 0.0005, paper uses sum of SE, here uses MSE
num_routing = 3
shift_fraction = 0.1
debug = 1  # debug>0 will save weights by TensorBoard
save_dir ='./result'
is_training = 1
weights = None
lr = 0.001

In [19]:
# load data
(x_train, y_train), (x_test, y_test) = load_mnist()

# define model
model = CapsNet(input_shape=[28, 28, 1],
                n_class=len(np.unique(np.argmax(y_train, 1))),
                num_routing=num_routing)
model.summary()
#plot_model(model, to_file=save_dir+'/model.png', show_shapes=True)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 24, 24, 8)         208       
_________________________________________________________________
primarycap_conv2d (Conv2D)   (None, 8, 8, 64)          41536     
_________________________________________________________________
primarycap_reshape (Reshape) (None, 512, 8)            0         
_________________________________________________________________
primarycap_squash (Lambda)   (None, 512, 8)            0         
_________________________________________________________________
digitcaps (CapsuleLayer)     (None, 10, 16)            660480    
_________________________________________________________________
flatten_4 (Flatten)          (None, 160)               0         
__________

In [20]:
# compile the model
model.compile(optimizer=optimizers.Adam(lr=lr),
loss=[margin_loss],
metrics=['accuracy'])

In [21]:
# train or test
if weights is not None:  # init the model weights with provided one
    model.load_weights(args.weights)
if is_training:
    history = model.train_on_batch(x_train, y_train)
else:  # as long as weights are given, will run testing
    if args.weights is None:
        print('No weights are provided. Will test using random initialized weights.')
    model.train_on_batch(x_test, y_test)

ResourceExhaustedError: OOM when allocating tensor with shape[9,9,8,64]
	 [[Node: training_1/Adam/mul_11 = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](Adam_1/beta_1/read, training_1/Adam/Variable_2/read)]]
	 [[Node: digitcaps_3/Reshape_2/_267 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_369_digitcaps_3/Reshape_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'training_1/Adam/mul_11', defined at:
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\runpy.py", line 170, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel\kernelapp.py", line 478, in start
    self.io_loop.start()
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tornado\ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel\kernelbase.py", line 281, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel\kernelbase.py", line 232, in dispatch_shell
    handler(stream, idents, msg)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel\kernelbase.py", line 397, in execute_request
    user_expressions, allow_stdin)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel\ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\IPython\core\interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\IPython\core\interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\IPython\core\interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-21-0c77fd4d45a4>", line 5, in <module>
    history = model.train_on_batch(x_train, y_train)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\keras\engine\training.py", line 1838, in train_on_batch
    self._make_train_function()
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\keras\engine\training.py", line 990, in _make_train_function
    loss=self.total_loss)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\keras\legacy\interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\keras\optimizers.py", line 432, in get_updates
    m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\ops\variables.py", line 754, in _run_op
    return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\ops\math_ops.py", line 894, in binary_op_wrapper
    return func(x, y, name=name)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1117, in _mul_dispatch
    return gen_math_ops._mul(x, y, name=name)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 2725, in _mul
    "Mul", x=x, y=y, name=name)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\framework\ops.py", line 2956, in create_op
    op_def=op_def)
  File "C:\Users\husey_000\Miniconda3\envs\capsule-gans\lib\site-packages\tensorflow\python\framework\ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[9,9,8,64]
	 [[Node: training_1/Adam/mul_11 = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](Adam_1/beta_1/read, training_1/Adam/Variable_2/read)]]
	 [[Node: digitcaps_3/Reshape_2/_267 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_369_digitcaps_3/Reshape_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
