In [123]:
import numpy as np
import os
import pandas as pd
import keras
from keras.datasets import mnist
import keras.backend as K
import tensorflow as tf
from keras import initializers, layers, callbacks
import os

In [126]:
os.environ["CUDA_VISIBLE_DEVICES"]="0"
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.45
K.tensorflow_backend.set_session(tf.Session(config=config))

### Get MNIST

In [127]:
NCLASSES = 10
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes=NCLASSES)
y_test = keras.utils.to_categorical(y_test, num_classes=NCLASSES)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

### Helper functions

In [145]:
class Length(layers.Layer):
    def call(self, inputs, **kwargs):
        return K.sqrt(K.sum(K.square(inputs), -1))
    def compute_output_shape(self, input_shape):
        return input_shape[:-1]
def squash(inputs):
    s_squared_norm = K.sum(K.square(inputs), axis=-1, keepdims=True)
    scale = (s_squared_norm / (1 + s_squared_norm))
    return scale * inputs / K.sqrt(s_squared_norm)    
def margin_loss(y_true, y_pred):
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) +  0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, 1))

### Capsule network


#### capsule layer

In [159]:
class DigitCaps(layers.Layer):
    
    def __init__(self, num_capsule, dim_vector, num_routing=3,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 name='digitcaps'):
        super(DigitCaps, self).__init__(name=name) # só  pra  colocar nome nesta merda
        self.num_capsule = num_capsule
        self.dim_vector = dim_vector
        self.num_routing = num_routing
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
    def build(self, input_shape):
        # shoud be [None,input_num_capsule,input_dim_vector]
        self.input_num_capsule = input_shape[1]
        self.input_dim_vector = input_shape[2]
        # Create a trainable weight variable for this layer.
        self.W_ij = self.add_weight(shape=[self.input_num_capsule, self.num_capsule, self.input_dim_vector, self.dim_vector],
                                    initializer=self.kernel_initializer,name='W_ij')
        self.b_i = self.add_weight(shape=[1, self.input_num_capsule, self.num_capsule, 1, 1],
                                   initializer=self.bias_initializer,name='b_i', trainable=False)
    def call(self, inputs, training=None):
        # [None, input_num_capsule, input_dim_vector]
        u_i = K.expand_dims(K.expand_dims(inputs, 2), 2)# [None, input_num_capsule, 1, 1, input_dim_vector]
        u_i = K.tile(u_i, [1, 1, self.num_capsule, 1, 1]) # [None, input_num_capsule, num_capsule, 1, input_dim_vector]
        u_hat_ji = tf.scan(lambda ac, x: K.batch_dot(x, self.W_ij, [3, 2]), elems=u_i,
                             initializer=K.zeros([self.input_num_capsule, self.num_capsule, 1, self.dim_vector]))
        for i in range(self.num_routing):
            c_ij = tf.nn.softmax(self.b_i, dim=2)  
            v_j = squash(K.sum(c_ij * u_hat_ji, 1, keepdims=True))
            if i != self.num_routing - 1:
                self.b_i += K.sum(u_hat_ji * v_j, -1, keepdims=True)
        return K.reshape(v_j, [-1, self.num_capsule, self.dim_vector])
    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_vector])


#### primary caps


In [160]:
def PrimaryCap(inputs, dim_vector, n_channels, kernel_size, strides, padding):
    output = layers.Conv2D(filters=dim_vector*n_channels, kernel_size=kernel_size, strides=strides, padding=padding)(inputs)
    outputs = layers.Reshape(target_shape=[-1, dim_vector])(output)
    return layers.Lambda(squash)(outputs)

### CapsNet

In [161]:
input_shape=[28, 28, 1]
num_routing=3
x = layers.Input(shape=input_shape)
# Conv1
conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)
# primary caps
primarycaps = PrimaryCap(conv1, dim_vector=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
# digit caps
digitcaps = DigitCaps(num_capsule=NCLASSES, dim_vector=16, num_routing=num_routing,name='digitcaps')(primarycaps)
# predicted label
#     v_j_abs = keras.layers.Lambda(Length)(digitcaps)
v_j_abs = Length(name='vj_abs')(digitcaps)
## Reconstruction for regulariation
y = layers.Input(shape=(NCLASSES,))

def Mask(inputs): 
    return K.batch_dot(inputs[0], inputs[1], [1, 1])

masked = layers.Lambda(Mask)([digitcaps, y])
# masked = Mask()([digitcaps, y])
l = layers.Dense(512, activation='relu')(masked)
l = layers.Dense(1024, activation='relu')(l)
l = layers.Dense(784, activation='sigmoid')(l)
print(l.shape)
decoded = layers.Reshape(target_shape=[28, 28, 1], name='dsf')(l)
print(decoded)
model = keras.models.Model([x, y], [v_j_abs, decoded])
model.summary()

(?, 784)
Tensor("dsf_7/Reshape:0", shape=(?, 28, 28, 1), dtype=float32)
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_72 (InputLayer)            (None, 28, 28, 1)     0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 20, 20, 256)   20992       input_72[0][0]                   
____________________________________________________________________________________________________
conv2d_40 (Conv2D)               (None, 6, 6, 256)     5308672     conv1[0][0]                      
____________________________________________________________________________________________________
reshape_40 (Reshape)             (None, 1152, 8)       0           conv2d_40[0][0]                  
___________________

In [165]:
def train(model, data, epocs):
    (x_train, y_train), (x_test, y_test) = data
    model.compile(optimizer=keras.optimizers.Adam(lr=0.001, decay=0.0001),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., 28*28*0.0005],
                  metrics={'v_j_abs': 'accuracy'})    
    model.fit([x_train, y_train], [y_train, x_train], batch_size=256, epochs=epocs,
              validation_data=[[x_test, y_test], [y_test, x_test]])

In [None]:
train(model=model, data=((x_train, y_train), (x_test, y_test)), epocs=20)