# Overview

The new CapsuleNet model proposed by Sara Sabour (and Geoffrey Hinton) claims to deliver state-of-the-art results on [MNIST](https://arxiv.org/abs/1710.09829). This kernel aims to build and train the model on the Kaggle dataset and then make a submission to see where it actually ends up. Running in a Kaggle Kernel means the model can't be trained for as long as we would like, or with GPUs, but IMHO if a model can't be reasonably well trained in an hour on a 28x28 dataset, it probably won't be too useful in the immediate future.

## Implementation Details

* Keras implementation of CapsNet in Hinton's paper Dynamic Routing Between Capsules.
* Code adapted from https://github.com/XifengGuo/CapsNet-Keras/blob/master/capsulenet.py, and modified by K Scott Mader: https://www.kaggle.com/kmader/capsulenet-on-mnist
* Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Keras`
* The code has been completely ported to TF2.
* The entire CapsNet model is wrapped as a scikit-learn model, and hyperparameter tuning is demonstrated using GridSearchCV.
* This also enables the model to be used in sklearn pipelines and other workflows.

Result:
    Validation accuracy > 99.5% after 20 epochs. Still under-fitting.
    About 110 seconds per epoch on a single GTX1070 GPU card
    


In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
# !python -m pip install tensorflow==2.1.0
# !python -m pip install keras==2.3.1
!python -m pip install scikeras

from tensorflow.python.framework.ops import disable_eager_execution

disable_eager_execution()

Collecting scikeras
  Downloading https://files.pythonhosted.org/packages/3f/bb/f423fcc01cc51c6bc071175aadfae4bde246d33482375fe00fb0fc6e5caf/scikeras-0.2.1-py3-none-any.whl
Installing collected packages: scikeras
Successfully installed scikeras-0.2.1


In [4]:
import numpy as np
import os
import pandas as pd

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, initializers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from PIL import Image


from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import callbacks
from tensorflow.keras.utils import plot_model as model_plotter

from scikeras.wrappers import KerasClassifier, KerasRegressor, BaseWrapper

from sklearn.model_selection import GridSearchCV

K.set_image_data_format('channels_last')

In [5]:
tf.__version__

'2.4.0'

In [6]:
tf.keras.layers.Input

<function tensorflow.python.keras.engine.input_layer.Input>

# Capsule Layers 
Here is the implementation of the necessary layers for the CapsuleNet. These are not optimized yet and can be made significantly more performant. 

In [7]:
class Length(layers.Layer):
    """
    Compute the length of vectors. This is used to compute a Tensor that has the same shape with y_true in margin_loss.
    Using this layer as model's output can directly predict labels by using `y_pred = np.argmax(model.predict(x), 1)`
    inputs: shape=[None, num_vectors, dim_vector]
    output: shape=[None, num_vectors]
    """
    def call(self, inputs, **kwargs):
        return tf.sqrt(tf.reduce_sum(tf.square(inputs), -1) + K.epsilon())

    def compute_output_shape(self, input_shape):
        return input_shape[:-1]

    def get_config(self):
        config = super(Length, self).get_config()
        return config

In [8]:
class Mask(layers.Layer):
    """
    Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an additional 
    input mask. Except the max-length capsule (or specified capsule), all vectors are masked to zeros. Then flatten the
    masked Tensor.
    For example:
        ```
        x = keras.layers.Input(shape=[8, 3, 2])  # batch_size=8, each sample contains 3 capsules with dim_vector=2
        y = keras.layers.Input(shape=[8, 3])  # True labels. 8 samples, 3 classes, one-hot coding.
        out = Mask()(x)  # out.shape=[8, 6]
        # or
        out2 = Mask()([x, y])  # out2.shape=[8,6]. Masked with true labels y. Of course y can also be manipulated.
        ```
    """
    def call(self, inputs, **kwargs):
        if isinstance(inputs, (list, tuple)):  # true label is provided with shape = [None, n_classes], i.e. one-hot code.
            assert len(inputs) == 2
            inputs, mask = inputs
        else:  # if no true label, mask by the max length of capsules. Mainly used for prediction
            # compute lengths of capsules
            x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            # generate the mask which is a one-hot code.
            # mask.shape=[None, n_classes]=[None, num_capsule]
            mask = tf.one_hot(indices=tf.argmax(x, 1), depth=x.shape[1])

        # inputs.shape=[None, num_capsule, dim_capsule]
        # mask.shape=[None, num_capsule]
        # masked.shape=[None, num_capsule * dim_capsule]
        masked = K.batch_flatten(inputs * tf.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        if type(input_shape[0]) is tuple:  # true label provided
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:  # no true label provided
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        config = super(Mask, self).get_config()
        return config
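
The prediction-time branch of `Mask` can be traced by hand with a small NumPy sketch. The array `caps` below is a made-up batch (2 samples, 3 capsules of dimension 2), not data from this notebook, and the code mirrors the layer's logic rather than calling it.

```python
import numpy as np

# Hypothetical batch: 2 samples, 3 capsules, 2-dimensional capsule vectors.
caps = np.array([[[0.1, 0.0], [0.6, 0.8], [0.0, 0.2]],
                 [[0.3, 0.4], [0.0, 0.1], [0.1, 0.1]]])

# Length of every capsule vector, shape (2, 3). This is what `Length` computes.
lengths = np.sqrt(np.sum(np.square(caps), axis=-1))

# Index of the longest capsule per sample, then its one-hot mask, shape (2, 3).
winners = np.argmax(lengths, axis=1)
mask = np.eye(caps.shape[1])[winners]

# Zero out every capsule except the winner and flatten, shape (2, 6).
masked = (caps * mask[:, :, None]).reshape(len(caps), -1)

print(winners)
print(masked)
```

Only the two components of the winning capsule stay non-zero in each row of `masked`, which is the vector the decoder receives.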

In [9]:
def squash(vectors, axis=-1):
    """
    The non-linear activation used in Capsule. It drives the length of a large vector to near 1 and small vector to 0
    :param vectors: some vectors to be squashed, N-dim tensor
    :param axis: the axis to squash
    :return: a Tensor with same shape as input vectors
    """
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors
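
To see what `squash` does to vector lengths, here is a minimal NumPy mirror of the same formula. The helper name `squash_np` and the two sample vectors are assumptions made for illustration only.

```python
import numpy as np

def squash_np(vectors, axis=-1, eps=1e-7):
    """NumPy mirror of the squash non-linearity, for illustration only."""
    sq_norm = np.sum(np.square(vectors), axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm) / np.sqrt(sq_norm + eps)
    return scale * vectors

short = np.array([[0.1, 0.0]])    # length 0.1
long = np.array([[30.0, 40.0]])   # length 50

short_len = np.linalg.norm(squash_np(short), axis=-1)
long_len = np.linalg.norm(squash_np(long), axis=-1)
print(short_len, long_len)
```

The output length is `||s||^2 / (1 + ||s||^2)`, so the short vector shrinks to roughly 0.01 while the long one ends up just below 1. The direction of each vector is unchanged.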

In [10]:
class CapsuleLayer(layers.Layer):
    """
    The capsule layer. It is similar to Dense layer. Dense layer has `in_num` inputs, each is a scalar, the output of the
    neuron from the former layer, and it has `out_num` output neurons. CapsuleLayer just expand the output of the neuron
    from scalar to vector. So its input shape = [None, input_num_capsule, input_dim_capsule] and output shape = \
    [None, num_capsule, dim_capsule]. For Dense Layer, input_dim_capsule = dim_capsule = 1.
    :param num_capsule: number of capsules in this layer
    :param dim_capsule: dimension of the output vectors of the capsules in this layer
    :param routings: number of iterations for the routing algorithm
    """
    def __init__(self, num_capsule, dim_capsule, routings=3,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]

        # Transform matrix, from each input capsule to each output capsule, there's a unique weight as in Dense layer.
        self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                        self.dim_capsule, self.input_dim_capsule],
                                 initializer=self.kernel_initializer,
                                 name='W')

        self.built = True

    def call(self, inputs, training=None):
        # inputs.shape=[None, input_num_capsule, input_dim_capsule]
        # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule, 1]
        inputs_expand = tf.expand_dims(tf.expand_dims(inputs, 1), -1)

        # Replicate num_capsule dimension to prepare being multiplied by W
        # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule, 1]
        inputs_tiled = tf.tile(inputs_expand, [1, self.num_capsule, 1, 1, 1])

        # Compute `inputs * W` by scanning inputs_tiled on dimension 0.
        # W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
        # x.shape=[num_capsule, input_num_capsule, input_dim_capsule, 1]
        # Regard the first two dimensions as `batch` dimension, then
        # matmul(W, x): [..., dim_capsule, input_dim_capsule] x [..., input_dim_capsule, 1] -> [..., dim_capsule, 1].
        # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
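        # Squeeze only the trailing singleton axis so that a batch of size 1 keeps its batch dimension.
        inputs_hat = tf.squeeze(tf.map_fn(lambda x: tf.matmul(self.W, x), elems=inputs_tiled), axis=-1)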

        # Begin: Routing algorithm ---------------------------------------------------------------------#
        # The prior for coupling coefficient, initialized as zeros.
        # b.shape = [None, self.num_capsule, 1, self.input_num_capsule].
        b = tf.zeros(shape=[inputs.shape[0], self.num_capsule, 1, self.input_num_capsule])

        assert self.routings > 0, 'The routings should be > 0.'
        for i in range(self.routings):
            # c.shape=[batch_size, num_capsule, 1, input_num_capsule]
            c = tf.nn.softmax(b, axis=1)

            # c.shape = [batch_size, num_capsule, 1, input_num_capsule]
            # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
            # The first two dimensions as `batch` dimension,
            # then matmul: [..., 1, input_num_capsule] x [..., input_num_capsule, dim_capsule] -> [..., 1, dim_capsule].
            # outputs.shape=[None, num_capsule, 1, dim_capsule]
            outputs = squash(tf.matmul(c, inputs_hat))  # [None, 10, 1, 16]

            if i < self.routings - 1:
                # outputs.shape =  [None, num_capsule, 1, dim_capsule]
                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions as `batch` dimension, then
                # matmul: [..., 1, dim_capsule] x [..., input_num_capsule, dim_capsule]^T -> [..., 1, input_num_capsule].
                # b.shape=[batch_size, num_capsule, 1, input_num_capsule]
                b += tf.matmul(outputs, inputs_hat, transpose_b=True)
        # End: Routing algorithm -----------------------------------------------------------------------#

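        # Drop only the singleton axis: [None, num_capsule, 1, dim_capsule] -> [None, num_capsule, dim_capsule]
        return tf.squeeze(outputs, axis=2)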

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

    def get_config(self):
        config = {
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings
        }
        base_config = super(CapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
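
The routing loop in `CapsuleLayer.call` can be followed outside TensorFlow with a NumPy sketch. It assumes the prediction vectors `inputs_hat` are already computed (random values stand in for them here), and the helper names `squash_np`, `softmax_np`, and `route` are illustrative rather than part of the notebook.

```python
import numpy as np

def squash_np(v, axis=-1, eps=1e-7):
    sq = np.sum(np.square(v), axis=axis, keepdims=True)
    return sq / (1.0 + sq) / np.sqrt(sq + eps) * v

def softmax_np(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def route(inputs_hat, routings=3):
    """inputs_hat: [batch, num_capsule, input_num_capsule, dim_capsule]."""
    batch, num_capsule, input_num_capsule, _ = inputs_hat.shape
    # Routing logits start at zero, so the first pass couples uniformly.
    b = np.zeros((batch, num_capsule, 1, input_num_capsule))
    for i in range(routings):
        # Each input capsule spreads its vote over the output capsules (axis 1).
        c = softmax_np(b, axis=1)
        # Weighted sum of predictions, then squash: [batch, num_capsule, 1, dim_capsule]
        outputs = squash_np(np.matmul(c, inputs_hat))
        if i < routings - 1:
            # Agreement = dot product between each output and each prediction.
            b = b + np.matmul(outputs, np.swapaxes(inputs_hat, -1, -2))
    return outputs[:, :, 0, :], c

rng = np.random.default_rng(0)
inputs_hat = rng.normal(size=(2, 10, 6, 16))  # stand-in predictions
v, c = route(inputs_hat)
print(v.shape, c.shape)
```

The coupling coefficients `c` sum to 1 over the output-capsule axis for every input capsule, and each output vector has length below 1 because of the squash.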



In [11]:
def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    """
    Apply Conv2D `n_channels` times and concatenate all capsules
    :param inputs: 4D tensor, shape=[None, width, height, channels]
    :param dim_capsule: the dim of the output vector of capsule
    :param n_channels: the number of types of capsules
    :return: output tensor, shape=[None, num_capsule, dim_capsule]
    """
    output = layers.Conv2D(filters=dim_capsule*n_channels, kernel_size=kernel_size, strides=strides, padding=padding,
                           name='primarycap_conv2d')(inputs)
    outputs = layers.Reshape(target_shape=[-1, dim_capsule], name='primarycap_reshape')(output)
    return layers.Lambda(squash, name='primarycap_squash')(outputs)

In [12]:
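# Import only the helpers used below; a wildcard import would shadow builtins such as `sum`, `max` and `abs`.
from tensorflow.keras.backend import ndim, expand_dims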
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

def own_batch_dot(x, y, axes=None):
  """Batchwise dot product.
  `batch_dot` is used to compute dot product of `x` and `y` when
  `x` and `y` are data in batch, i.e. in a shape of
  `(batch_size, :)`.
  `batch_dot` results in a tensor or variable with less dimensions
  than the input. If the number of dimensions is reduced to 1,
  we use `expand_dims` to make sure that ndim is at least 2.
  Arguments:
      x: Keras tensor or variable with `ndim >= 2`.
      y: Keras tensor or variable with `ndim >= 2`.
      axes: list of (or single) int with target dimensions.
          The lengths of `axes[0]` and `axes[1]` should be the same.
  Returns:
      A tensor with shape equal to the concatenation of `x`'s shape
      (less the dimension that was summed over) and `y`'s shape
      (less the batch dimension and the dimension that was summed over).
      If the final rank is 1, we reshape it to `(batch_size, 1)`.
  Examples:
      Assume `x = [[1, 2], [3, 4]]` and `y = [[5, 6], [7, 8]]`
      `batch_dot(x, y, axes=1) = [[17], [53]]` which is the main diagonal
      of `x.dot(y.T)`, although we never have to calculate the off-diagonal
      elements.
      Shape inference:
      Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
      If `axes` is (1, 2), to find the output shape of resultant tensor,
          loop through each dimension in `x`'s shape and `y`'s shape:
      * `x.shape[0]` : 100 : append to output shape
      * `x.shape[1]` : 20 : do not append to output shape,
          dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
      * `y.shape[0]` : 100 : do not append to output shape,
          always ignore first dimension of `y`
      * `y.shape[1]` : 30 : append to output shape
      * `y.shape[2]` : 20 : do not append to output shape,
          dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
      `output_shape` = `(100, 30)`
  ```python
      >>> x_batch = K.ones(shape=(32, 20, 1))
      >>> y_batch = K.ones(shape=(32, 30, 20))
      >>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=[1, 2])
      >>> K.int_shape(xy_batch_dot)
      (32, 1, 30)
  ```
  """
  if isinstance(axes, int):
    axes = (axes, axes)
  x_ndim = ndim(x)
  y_ndim = ndim(y)
  if axes is None:
    # behaves like tf.batch_matmul as default
    axes = [x_ndim - 1, y_ndim - 2]
  if x_ndim > y_ndim:
    diff = x_ndim - y_ndim
    y = array_ops.reshape(y,
                          array_ops.concat(
                              [array_ops.shape(y), [1] * (diff)], axis=0))
  elif y_ndim > x_ndim:
    diff = y_ndim - x_ndim
    x = array_ops.reshape(x,
                          array_ops.concat(
                              [array_ops.shape(x), [1] * (diff)], axis=0))
  else:
    diff = 0
  if ndim(x) == 2 and ndim(y) == 2:
    if axes[0] == axes[1]:
      out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
    else:
      out = math_ops.reduce_sum(
          math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
  else:
    adj_x = None if axes[0] == ndim(x) - 1 else True
    adj_y = True if axes[1] == ndim(y) - 1 else None
    out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
  if diff:
    if x_ndim > y_ndim:
      idx = x_ndim + y_ndim - 3
    else:
      idx = x_ndim - 1
    out = array_ops.squeeze(out, list(range(idx, idx + diff)))
  if ndim(out) == 1:
    out = expand_dims(out, 1)
  return out

# Build the Model
Here we use the layers to build up the model. The model differs a bit from a standard $X\rightarrow y$ model: it is $(X,y)\rightarrow (y,X)$, meaning it predicts the class from the image and, at the same time, reconstructs the image from the capsule of that class. The approach appears very cGAN-like, in that the reconstruction task helps the model 'understand' the image data better.

In [13]:
def CapsNet(input_shape, n_class, routings, batch_size,
            n_filters_c1=256, 
            dim_caps_prim=8, 
            n_channels_prim=32, 
            dim_caps_sec=16):
    """
    A Capsule Network on MNIST.
    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param routings: number of routing iterations
    :param batch_size: size of batch
    :return: Three Keras Models: the first one used for training, the second one for evaluation, and the third one
            for manipulating the digit capsules. `eval_model` can also be used for training.
    """
    x = layers.Input(shape=input_shape, batch_size=batch_size)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=n_filters_c1, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
    primarycaps = PrimaryCap(conv1, dim_capsule=dim_caps_prim, n_channels=n_channels_prim, kernel_size=9, strides=2, padding='valid')

    # Layer 3: Capsule layer. Routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=dim_caps_sec, routings=routings, name='digitcaps')(primarycaps)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    # If using tensorflow, this will not be necessary. :)
    out_caps = Length(name='capsnet')(digitcaps)

    # Decoder network.
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(digitcaps)  # Mask using the capsule with maximal length. For prediction

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=dim_caps_sec * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    # decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = models.Model(inputs=[x, y], outputs=[out_caps, decoder(masked_by_y)])
    eval_model = models.Model(inputs=x, outputs=[out_caps, decoder(masked)])
    # train_model = models.Model(inputs=x, outputs=[out_caps, decoder(masked_by_y)])
    # eval_model = models.Model(inputs=x, outputs=[out_caps, decoder(masked)])

    # manipulate model
    noise = layers.Input(shape=(n_class, dim_caps_sec))
    noised_digitcaps = layers.Add()([digitcaps, noise])
    masked_noised_y = Mask()([noised_digitcaps, y])
    manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))
    return train_model, eval_model, manipulate_model
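
The tensor shapes flowing through `CapsNet` follow from plain arithmetic. The sketch below works them out for the default arguments and a 28x28 input; the helper `conv_out` is an illustrative name and assumes `'valid'` padding as used above.

```python
def conv_out(size, kernel, stride):
    """Output side length of a convolution with 'valid' padding."""
    return (size - kernel) // stride + 1

# conv1: 28x28 input, 9x9 kernel, stride 1
side_conv1 = conv_out(28, kernel=9, stride=1)
# primary capsules: 9x9 kernel, stride 2 on top of conv1
side_prim = conv_out(side_conv1, kernel=9, stride=2)

n_channels_prim, dim_caps_prim = 32, 8
num_primary_caps = side_prim * side_prim * n_channels_prim

n_class, dim_caps_sec = 10, 16
decoder_input_dim = n_class * dim_caps_sec

# W in CapsuleLayer has shape [num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
w_params = n_class * num_primary_caps * dim_caps_sec * dim_caps_prim

print(side_conv1, side_prim, num_primary_caps, decoder_input_dim, w_params)
```

With the defaults this gives a 20x20 map after `conv1`, a 6x6 grid of primary capsules, 1152 primary capsules of dimension 8, a 160-dimensional decoder input, and 1,474,560 weights in the routing transform.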

In [14]:
def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq.(4). When y_true[i, :] contains more than one `1`, this loss should work too. Not tested.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    # return tf.reduce_mean(tf.square(y_pred))
    L = y_true * tf.square(tf.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * tf.square(tf.maximum(0., y_pred - 0.1))

    return tf.reduce_mean(tf.reduce_sum(L, 1))
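
A small worked example makes the margins concrete. This is a NumPy re-statement of the same formula with the constants 0.9, 0.1 and 0.5 exposed as parameters; the function name and the two sample predictions are made up for illustration.

```python
import numpy as np

def margin_loss_np(y_true, y_pred, m_plus=0.9, m_minus=0.1, lam=0.5):
    """NumPy mirror of margin_loss, for checking values by hand."""
    loss = (y_true * np.square(np.maximum(0.0, m_plus - y_pred))
            + lam * (1 - y_true) * np.square(np.maximum(0.0, y_pred - m_minus)))
    return np.mean(np.sum(loss, axis=1))

y_true = np.array([[0.0, 1.0, 0.0]])
confident = np.array([[0.05, 0.95, 0.05]])  # inside both margins
unsure = np.array([[0.40, 0.50, 0.10]])     # true class too short, class 0 too long

loss_confident = margin_loss_np(y_true, confident)
loss_unsure = margin_loss_np(y_true, unsure)
print(loss_confident, loss_unsure)
```

The confident prediction costs nothing because every capsule length is already on the right side of its margin. The unsure one pays `(0.9 - 0.5)^2 = 0.16` for the true class plus `0.5 * (0.4 - 0.1)^2 = 0.045` for class 0, giving 0.205.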

In [15]:
import numpy as np
from matplotlib import pyplot as plt
import csv
import math
import pandas

def plot_log(filename, show=True):

    data = pandas.read_csv(filename)

    fig = plt.figure(figsize=(4,6))
    fig.subplots_adjust(top=0.95, bottom=0.05, right=0.95)
    fig.add_subplot(211)
    for key in data.keys():
        if key.find('loss') >= 0 and not key.find('val') >= 0:  # training loss
            plt.plot(data['epoch'].values, data[key].values, label=key)
    plt.legend()
    plt.title('Training loss')

    fig.add_subplot(212)
    for key in data.keys():
        if key.find('acc') >= 0:  # acc
            plt.plot(data['epoch'].values, data[key].values, label=key)
    plt.legend()
    plt.title('Training and validation accuracy')

    # fig.savefig('result/log.png')
    if show:
        plt.show()

#### Define a Multi-Output Transformer


In [22]:
from typing import List

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelEncoder, FunctionTransformer, OneHotEncoder


class MultiOutputTransformer(BaseEstimator, TransformerMixin):

    def fit(self, y):
        # y_bin, y_cat = y[:, 0], y[:, 1]
        y_caps, y_recons = y[:,:10],y[:,10:]
        # Create internal encoders to ensure labels are 0, 1, 2...
        # self.caps_encoder_ = OneHotEncoder()
        # self.recons_encoder_ = FunctionTransformer(func=lambda t: t)
        # # Fit them to the input data
        # self.caps_encoder_.fit(y_caps)
        # self.recons_encoder_.fit(y_recons)
        # Save the number of classes
        self.n_classes_ = [
            y_caps.shape[1],
            y_recons.shape[1],
        ]
        # Save number of expected outputs in the Keras model
        # SciKeras will automatically use this to do error-checking
        self.n_outputs_expected_ = 2
        return self

    def transform(self, y: np.ndarray) -> List[np.ndarray]:
        y_caps, y_recons = y[:,:10],y[:,10:]
        # Apply transformers to input array
        # y_caps = self.caps_encoder_.transform(y_caps)
        # y_recons = self.recons_encoder_.transform(y_recons)
        # Split the data into a list
        return [y_caps, y_recons]

    def inverse_transform(self, y: List[np.ndarray], return_proba: bool = False) -> np.ndarray:
        y_pred_proba = y  # rename for clarity, what Keras gives us are probs
        if return_proba:
            return np.column_stack(y_pred_proba)
        # Get class predictions from probabilities
        y_pred_caps = to_categorical(np.argmax(y_pred_proba[0], axis=1), num_classes=y_pred_proba[0].shape[1])
        y_pred_recons = y_pred_proba[1]
        # y_pred_cat = np.argmax(y_pred_proba[1], axis=1)
        # Pass back through LabelEncoder
        # y_pred_caps = self.caps_encoder_.inverse_transform(y_pred_caps)
        # y_pred_recons = self.recons_encoder_.inverse_transform(y_pred_recons)
        return np.column_stack([y_pred_caps, y_pred_recons])
    
    def get_metadata(self):
        return {
            "n_classes_": self.n_classes_,
            "n_outputs_expected_": self.n_outputs_expected_,
        }
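
The packing that `inverse_transform` performs can be checked in isolation. The sketch below uses made-up model outputs for two samples (capsule lengths and all-zero reconstructions) and only NumPy; it mirrors the logic instead of instantiating the transformer.

```python
import numpy as np

n_class = 10
# Hypothetical model outputs for 2 samples: capsule lengths and flat reconstructions.
caps_lengths = np.array([[0.1] * 9 + [0.8],
                         [0.7] + [0.2] * 9])
recons = np.zeros((2, 784))

# Longest capsule wins, then re-encode the label as one-hot.
labels = np.argmax(caps_lengths, axis=1)
one_hot = np.eye(n_class)[labels]

# Single 2D array in the layout scikit-learn sees: 10 label columns, then 784 pixels.
packed = np.column_stack([one_hot, recons])

# The scorers split it back with the same slicing.
y_caps, y_recons = packed[:, :n_class], packed[:, n_class:]
print(labels, packed.shape)
```

Keeping both outputs in one array is what lets scikit-learn utilities, which expect a single `y`, carry the reconstruction target alongside the labels.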

#### Compile a MIMOEstimator

In [23]:
def input_reshaper(X):
    return [X[:,:-10].reshape(X.shape[0],28,28,1), X[:,-10:]]

In [24]:
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score, make_scorer

# Sample Scorer if you want to define your own scorer via Grid-Search CV
def capsnet_scorer(estimator, X, y):
  y_pred = estimator.predict(X)
  y_pred_caps, y_pred_recons = y_pred[:,:10],y_pred[:,10:]
  y_caps, y_recons = y[:,:10],y[:,10:]

  return accuracy_score(y_caps, y_pred_caps)

In [25]:
class MIMOEstimator(BaseWrapper):

  @property
  def feature_encoder(self):
      return FunctionTransformer(
          func=input_reshaper,
      )

  @staticmethod
  def scorer(X, #should be y_true according to documentation but present API has this definition
             y, #should be y_pred according to documentation but present API has this definition
             **kwargs) -> float:
    y_pred_caps, y_pred_recons = y[:,:10],y[:,10:]
    y_caps, y_recons = X[:,:10],X[:,10:]

    return accuracy_score(y_caps, y_pred_caps)
  
  @property
  def target_encoder(self):
      return MultiOutputTransformer()

# Load MNIST Data

In [20]:
def load_mnist():
    # the data, shuffled and split between train and test sets
    from tensorflow.keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
    x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
    y_train = to_categorical(y_train.astype('float32'))
    y_test = to_categorical(y_test.astype('float32'))
    return (x_train, y_train), (x_test, y_test)

# Train the Model

#### Set Arguments/Parameters for the CV run

In [21]:
class arguments():
    def __init__(self, epochs = 50, batch_size=100, lr=0.001, lr_decay=0.9, lam_recon=0.392, r=3, 
                 shift_fraction=0.1, debug=True, save_dir="/content/gdrive/My Drive/EEG MLSP/Emotion Classification/outputprocesseddata", t=False, digit=5, weights=None):
        self.epochs=epochs
        self.batch_size=batch_size
        self.lr = lr
        self.lr_decay = lr_decay
        self.lam_recon = lam_recon
        self.routings = r
        self.shift_fraction = shift_fraction
        self.debug = debug
        self.save_dir = save_dir
        self.testing = t
        self.digit = digit
        self.weights = weights

args = arguments()

### Training and Testing Single Run
We are using Grid Search CV, so we won't be using this training/testing code in this iteration.

In [16]:
def train(model,  # type: models.Model
          data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_accuracy',
                                           save_best_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(learning_rate=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = next(generator)
            yield (x_batch, y_batch), (y_batch, x_batch)

    # Training with data augmentation. If shift_fraction=0., no augmentation.
    model.fit(train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
              steps_per_epoch=int(y_train.shape[0] / args.batch_size),
              epochs=args.epochs,
              validation_data=((x_test, y_test), (y_test, x_test)), batch_size=args.batch_size,
                                callbacks=[log, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
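
The `LearningRateScheduler` callback above applies an exponential decay. A plain-Python sketch of the same schedule, using the default `lr=0.001` and `lr_decay=0.9` from the `arguments` class, shows how quickly the rate falls.

```python
lr, lr_decay = 0.001, 0.9  # defaults from the `arguments` class

def schedule(epoch):
    """Same rule as the lambda passed to LearningRateScheduler."""
    return lr * (lr_decay ** epoch)

rates = [schedule(epoch) for epoch in range(5)]
for epoch, rate in enumerate(rates):
    print(epoch, rate)
```

Each epoch multiplies the rate by 0.9, so after three epochs it is `0.001 * 0.729`. Over a 50-epoch run the rate drops by more than two orders of magnitude, which is worth keeping in mind when choosing `epochs`.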

In [17]:
def test(model, data, args):
    x_test, y_test = data
    y_pred, x_recon = model.predict(x_test, batch_size=100)
    print('-' * 30 + 'Begin: test' + '-' * 30)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1)) / y_test.shape[0])

    img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png")
    print()
    print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir)
    print('-' * 30 + 'End: test' + '-' * 30)
    plt.imshow(plt.imread(args.save_dir + "/real_and_recon.png"))
    plt.show()

In [18]:
def manipulate_latent(model, data, args):
    print('-' * 30 + 'Begin: manipulate' + '-' * 30)
    x_test, y_test = data
    index = np.argmax(y_test, 1) == args.digit
    number = np.random.randint(low=0, high=sum(index) - 1)
    x, y = x_test[index][number], y_test[index][number]
    x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
    noise = np.zeros([1, 10, 16])
    x_recons = []
    for dim in range(16):
        for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
            tmp = np.copy(noise)
            tmp[:, :, dim] = r
            x_recon = model.predict([x, y, tmp])
            x_recons.append(x_recon)

    x_recons = np.concatenate(x_recons)

    img = combine_images(x_recons, height=16)
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)
    print('manipulated result saved to %s/manipulate-%d.png' % (args.save_dir, args.digit))
    print('-' * 30 + 'End: manipulate' + '-' * 30)

In [19]:
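def combine_images(generated_images, height=None, width=None):
    # Tile a batch of single-channel images into one grid image.
    # `height`/`width` are given in tiles; whichever is missing is derived from the other.
    num = generated_images.shape[0]
    if width is None and height is None:
        width = int(np.sqrt(num))
        height = int(np.ceil(float(num)/width))
    elif height is None:  # only width given
        height = int(np.ceil(float(num)/width))
    elif width is None:  # only height given
        width = int(np.ceil(float(num)/height))
    shape = generated_images.shape[1:3]
    image = np.zeros((height*shape[0], width*shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = int(index/width)
        j = index % width
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \
            img[:, :, 0]
    return image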

def test(model, data):
    x_test, y_test = data
    y_pred, x_recon = model.predict([x_test, y_test], batch_size=100)
    print('-'*50)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

    import matplotlib.pyplot as plt
    from PIL import Image

    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save("real_and_recon.png")
    print()
    print('Reconstructed images are saved to ./real_and_recon.png')
    print('-'*50)
    plt.imshow(plt.imread("real_and_recon.png", ))
    plt.show()

### Cross-Validated Training

#### Model Getter
Called from MIMOEstimator

In [26]:
def get_model(input_shape,
              n_class,
              routings,
              n_filters_c1,
              batch_size=args.batch_size,
              model_type='train'):
  
  model, eval_model, manipulate_model = CapsNet(input_shape=input_shape,
                                              n_class=n_class,
                                              routings=routings,
                                              n_filters_c1 = n_filters_c1,
                                              batch_size=batch_size)
  model.summary()

  # Plot model graph
  model_plotter(model, show_shapes=True, show_layer_names=True, to_file='model.png')
  from IPython.display import Image as ipy_img
  ipy_img(retina=True, filename='model.png')
  
  if model_type == 'train':
    # compile the model
    model.compile(optimizer=optimizers.Adam(learning_rate=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})
    return model
  elif model_type == 'test':
    return eval_model
  elif model_type == 'manipulate':
    return manipulate_model
  else:
    raise ValueError("model_type must be 'train', 'test', or 'manipulate'")

#### Change data into scikit format

In [27]:
def get_sciki_xy(X,y):
  X_sciki = np.column_stack([X.reshape((y.shape[0], np.prod(X.shape[1:]))), y])
  y_sciki = np.column_stack([y,X.reshape((y.shape[0], np.prod(X.shape[1:])))])

  return X_sciki,y_sciki
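
A quick round trip shows that the packing done by `get_sciki_xy` is undone by the slicing in `input_reshaper`. The arrays below are random stand-ins for four MNIST images and labels, and the code repeats the two operations inline so it runs on its own.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((4, 28, 28, 1)).astype("float32")  # stand-in images
y = np.eye(10, dtype="float32")[[3, 1, 4, 1]]     # stand-in one-hot labels

flat = X.reshape(len(X), -1)
X_packed = np.column_stack([flat, y])  # images first, labels last: (4, 794)
y_packed = np.column_stack([y, flat])  # labels first, images last: (4, 794)

# Same slicing as input_reshaper: split the packed X back into the two model inputs.
images, labels = X_packed[:, :-10].reshape(-1, 28, 28, 1), X_packed[:, -10:]

print(X_packed.shape, y_packed.shape)
```

Note the mirrored column order: `X` carries the 784 pixels followed by the 10 label columns, while `y` carries the label columns first, which is why the target code slices with `[:, :10]` and `[:, 10:]`.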

#### Define the function to setup cross validation and make a run

In [30]:
def get_cross_val():
  (x_train_, y_train_), (x_test_, y_test_) = load_mnist()

  x_train, y_train = get_sciki_xy(x_train_[:1000], y_train_[:1000])

  x_test, y_test = get_sciki_xy(x_test_[:100],y_test_[:100])

  # callbacks
  log = callbacks.CSVLogger(args.save_dir + '/log.csv')
  checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                          save_best_only=True, verbose=1)
  lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

  clf = MIMOEstimator(model = get_model,
                        model__input_shape=x_train_.shape[1:],
                        model__n_class=len(np.unique(np.argmax(y_train_, 1))),
                        model__routings=args.routings,
                        model__batch_size = args.batch_size,
                        model__n_filters_c1=256, 
                        # epochs=args.epochs, 
                        # callbacks=[log, checkpoint, lr_decay],
                        model__model_type = 'train')
  print("X input shape = ", x_train.shape)
  print("Y input shape = ", y_train.shape)
  # clf.fit(X=x_train, 
  #         y=y_train,
  #         )
  # print('Score = ', clf.score(x_test, y_test))

  params = {'model__n_filters_c1': [128,256],
            'model__routings': [4,5]}
  
  # The model is built with a fixed batch size, so the number of samples in
  # every CV train and validation fold must be divisible by batch_size.
  gs = GridSearchCV(estimator=clf, param_grid=params, cv=5, scoring=capsnet_scorer, verbose=True)
  gs_res = gs.fit(X=x_train, 
                  y=y_train)
  print("Grid Search Results: ")
  print(gs_res)
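
The divisibility constraint noted in the comment can be checked before launching a long grid search. This is a minimal sketch in plain Python, assuming the search uses an unshuffled K-fold split (scikit-learn's `KFold` gives the first `n_samples % n_splits` folds one extra sample). `kfold_sizes` and `batch_compatible` are hypothetical helper names, not part of the notebook.

```python
def kfold_sizes(n_samples, n_splits):
    """Validation-fold sizes for an unshuffled K-fold split.

    The first n_samples % n_splits folds get one extra sample.
    """
    base, extra = divmod(n_samples, n_splits)
    return [base + 1 if i < extra else base for i in range(n_splits)]

def batch_compatible(n_samples, n_splits, batch_size):
    """True if every train and validation fold is a multiple of batch_size."""
    for val_size in kfold_sizes(n_samples, n_splits):
        train_size = n_samples - val_size
        if train_size % batch_size or val_size % batch_size:
            return False
    return True

print(kfold_sizes(1000, 5))            # [200, 200, 200, 200, 200]
print(batch_compatible(1000, 5, 100))  # True: 800 train / 200 validation
print(batch_compatible(1000, 3, 100))  # False: folds of 334, 333, 333
```

With 1000 samples, `cv=5` and a batch size of 100, each fit trains on 800 samples and predicts on 200, which matches the shapes printed in the run.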
  

In [31]:
get_cross_val()

X input shape =  (1000, 794)
Y input shape =  (1000, 794)
Fitting 5 folds for each of 4 candidates, totalling 20 fits
Capsule Y shape:  (800, 10)
Recons Y shape:  (800, 784)


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           [(100, 28, 28, 1)]   0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (100, 20, 20, 128)   10496       input_10[0][0]                   
__________________________________________________________________________________________________
primarycap_conv2d (Conv2D)      (100, 6, 6, 256)     2654464     conv1[0][0]                      
__________________________________________________________________________________________________
primarycap_reshape (Reshape)    (100, 1152, 8)       0           primarycap_conv2d[0][0]          
____________________________________________________________________________________________



Shape of y-pred:  (200, 10) (200, 784)
Caps Prediction Shape:  (200, 10)
Recons Prediction Shape:  (200, 784)
[Output for fits 2 to 20 condensed: each fit repeats the lines above, training on 800 samples and predicting on 200, with a model summary truncated at primarycap_reshape. Only the model name and the conv1 width differ.]

Fits 2-10:  "model_12" to "model_36", conv1 (100, 20, 20, 128), 10496 params; primarycap_conv2d (100, 6, 6, 256), 2654464 params
Fits 11-20: "model_39" to "model_66", conv1 (100, 20, 20, 256), 20992 params; primarycap_conv2d (100, 6, 6, 256), 5308672 params

Capsule Y shape:  (1000, 10)
Recons Y shape:  (1000, 784)


[Parallel(n_jobs=1)]: Done  20 out of  20 | elapsed: 12.8min finished


Model: "model_69"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_70 (InputLayer)           [(100, 28, 28, 1)]   0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (100, 20, 20, 128)   10496       input_70[0][0]                   
__________________________________________________________________________________________________
primarycap_conv2d (Conv2D)      (100, 6, 6, 256)     2654464     conv1[0][0]                      
__________________________________________________________________________________________________
primarycap_reshape (Reshape)    (100, 1152, 8)       0           primarycap_conv2d[0][0]          
___________________________________________________________________________________________