In [1]:
import numpy as np
import tensorflow as tf
import os
import struct

  return f(*args, **kwds)


In [40]:
class MNIST():
    def __init__(self, directory):
        self._directory = directory
        
        self._training_data = self._load_binaries("./mnist_data/train-images.idx3-ubyte")
        self._training_labels = self._load_binaries("./mnist_data/train-labels.idx1-ubyte")
        self._test_data = self._load_binaries("./mnist_data/t10k-images.idx3-ubyte")
        self._test_labels = self._load_binaries("./mnist_data/t10k-labels.idx1-ubyte")
        
        np.random.seed(0)
        samples_n = self._training_labels.shape[0]
        random_indices = np.random.choice(samples_n, samples_n // 10, replace = False)
        np.random.seed()
        
        self._validation_data = self._training_data[random_indices]
        self._validation_labels = self._training_labels[random_indices]
        self._training_data = np.delete(self._training_data, random_indices, axis = 0)
        self._training_labels = np.delete(self._training_labels, random_indices)
    
    def _load_binaries(self, file_name):
        path = os.path.join(self._directory, file_name)
        
        with open(path, 'rb') as fd:
            check, items_n = struct.unpack(">ii", fd.read(8))

            if "images" in file_name and check == 2051:
                height, width = struct.unpack(">II", fd.read(8))
                images = np.fromfile(fd, dtype = 'uint8')
                return np.reshape(images, (items_n, height, width))
            elif "labels" in file_name and check == 2049:
                return np.fromfile(fd, dtype = 'uint8')
            else:
                raise ValueError("Not a MNIST file: " + path)
    
    
    def get_training_batch(self, batch_size):
        return self._get_batch(self._training_data, self._training_labels, batch_size)
    
    def get_validation_batch(self, batch_size):
        return self._get_batch(self._validation_data, self._validation_labels, batch_size)
    
    def get_test_batch(self, batch_size):
        return self._get_batch(self._test_data, self._test_labels, batch_size)
    
    def _get_batch(self, data, labels, batch_size):
        samples_n = labels.shape[0]
        if batch_size <= 0:
            batch_size = samples_n
        
        random_indices = np.random.choice(samples_n, samples_n, replace = False)
        data = data[random_indices]
        labels = labels[random_indices]
        for i in range(samples_n // batch_size):
            on = i * batch_size
            off = on + batch_size
            yield data[on:off], labels[on:off]
    
    
    def get_sizes(self):
        training_samples_n = self._training_labels.shape[0]
        validation_samples_n = self._validation_labels.shape[0]
        test_samples_n = self._test_labels.shape[0]
        return training_samples_n, validation_samples_n, test_samples_n

In [41]:
#STORING THE DATA
mnist_data = MNIST('.')

In [44]:
#INVESTIGATING THE DATA
for i, batch in enumerate(mnist_data.get_training_batch(1)):
    print(batch)
    if i > 1:
        break

(array([[[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         110, 138,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         178, 253,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           

In [65]:
#HYPERPARAMETERS

#TRAINING PARAMS
batch_size = 128
learning_rate = 0.0001

#INPUT
input_channels = 1 #mnist is grey value 
input_size = 28 #mnist is 28x28 pics

#CONVOLUTIONAL LAYER 1
kernel_size_conv1 = 9
stride_conv1 = 1
channels_conv1 = 256

#PRIMARY CAPSULES
kernel_size_conv2 = 9
dim_primary_caps = 8 #primary capsules are 8-D
channels_primary_caps = 32
no_neurons_primary_caps = 6*6*channels_primary_caps

#DIGIT CAPSULE
dim_digits_caps = 16 #capsules for digits are 16-D
no_output_classes = 10 #mnist depicts 10 numbers from 0 to 9

#LOSS
m_plus = 0.9
m_minus = 0.1
lamb = 0.5


In [7]:
#FUNCTIONS GET WEIGHTS AND BIASES
def get_weights(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def get_biases(shape):
    return tf.Variable(tf.constant(1.0, shape=shape))

In [8]:
#FUNCTIONS FOR BUILDING THE NETWORK

def conv_relu(x, weights, biases, stride):
    conv = tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1], padding='VALID')
    return tf.nn.relu(conv + biases)

def conv(x, weights, biases, stride):
    conv = tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1], padding='VALID')
    return conv + biases

In [80]:
#squashing function from paper
def squash(tensor, axis):
    
    #tensor with same dimensions as the tensor with the length of the 
    #vector along the specified axis stored in every component of this
    #vector, norm is the euclidean norm here
    
    #shape = tensor.get_shape()
    #ones_tensor = tf.ones(shape=shape, dtype=tf.float32)
    
    norm = tf.norm(tensor, keep_dims=True, axis=axis)
    normed_tensor = tensor/norm
    #norm_sq = tf.multiply(norm, norm)
    #denom = tf.add(ones_tensor, norm_sq)
    
    #factor = tf.divide(norm_sq, denom)
    #scaled_tensor = tf.divide(tensor, norm)
    
    squashing_factor = norm**2/(1+norm**2)
    
    return squashing_factor * normed_tensor
    
    
    
    
    
    

In [86]:
#routing algorithm 
def routing(prediction_vectors, iterations=3):
    
    #intitialize the logits b
    #log priors can be learned ?!?!?!?!?
    #for mnist they found its sufficient to set them to zero
    b = tf.zeros(shape=[batch_size,
                        no_neurons_primary_caps,
                        no_output_classes]
                )
    
    for i in range(iterations):
        
        #compute the softmax 
        c = tf.nn.softmax(b)
        c = tf.expand_dims(c, axis=-1)
        c = tf.tile(c, [1, 1, 1, dim_digits_caps])
        
        #compute the input
        #s = tf.multiply(c, prediction_vectors)
        #s = tf.reduce_sum(s, axis=1)
        s = tf.reduce_sum(c*prediction_vectors, axis=1)
        
        #compute the output
        v = squash(s, axis=2)
        
        #compute the agreement
        v_exp = tf.expand_dims(v, axis=1)
        v_exp = tf.tile(v_exp, [1, no_neurons_primary_caps, 1, 1])
        #a = tf.multiply(prediction_vectors, v_exp)
        #a = tf.reduce_sum(a, axis=-1)
        a = tf.reduce_sum(prediction_vectors*v_exp, axis=-1)
        
        #updating the logits
        b = b+a
        
    return v
        
    
        
    
    
    

In [87]:
#DESCIRBING THE DATAFLOW GRAPH
tf.reset_default_graph()

images = tf.placeholder(dtype=tf.float32, shape=[batch_size,
                                             input_size,
                                             input_size])
images_exp = tf.expand_dims(images, axis=-1)
labels = tf.placeholder(dtype=tf.int64, shape=[batch_size])


with tf.variable_scope('ReLU_Conv1'):
    
    weights = get_weights(shape = [kernel_size_conv1,
                                    kernel_size_conv1,
                                    input_channels,
                                    channels_conv1]
                            )
   
    biases = get_biases(shape = [channels_conv1])
    
    conv1 = conv_relu(images_exp, weights, biases, stride=1)
    
    

with tf.variable_scope('Primary_Caps'):
    
    weights = get_weights(shape = [kernel_size_conv2,
                                    kernel_size_conv2,
                                    channels_conv1,
                                    dim_primary_caps*channels_primary_caps]
                         )
    
    biases = get_biases(shape = [dim_primary_caps*channels_primary_caps])
    
    #biases? relu? try to exclude it?
    #paper doesnt state clear if relu is used. experiments online showed
    # better results with relu, but Id say paper says not to use relu
    primary_caps = conv(conv1, weights, biases, stride=2)
    primary_caps = tf.reshape(primary_caps, shape=[batch_size,6,6,
                                                   channels_primary_caps,
                                                   dim_primary_caps]
                             )
    primary_caps = squash(primary_caps, axis=4)
    
    
with tf.variable_scope('Digit_Caps'):
    
    weights = get_weights(shape = [1, no_neurons,
                                   no_output_classes,
                                   dim_primary_caps,
                                   dim_digits_caps]
                          )
    
    #get primary caps into the right dims for matrix multiplication
    primary_caps = tf.reshape(primary_caps, shape=[batch_size,
                                                  no_neurons_primary_caps,
                                                  1,
                                                  dim_primary_caps,
                                                  1])
    primary_caps = tf.tile(primary_caps, [1, 1, 10, 1, 1])
    weights = tf.tile(weights, [batch_size, 1, 1, 1, 1])
    
    prediction_vectors = tf.matmul(weights, primary_caps, transpose_a=True)
    prediction_vectors = tf.squeeze(prediction_vectors)
    
    digit_caps = routing(prediction_vectors)
    
    


with tf.variable_scope('Loss'):
    
    length_digit_caps = tf.norm(digit_caps, axis = 2)
    labels_one_hot = tf.one_hot(labels, depth=10)
    labels_one_hot_inv = tf.subtract(tf.ones(shape=[batch_size,10]),
                                      labels_one_hot)
    
    
    m_plus_tsr = tf.constant(m_plus, shape = [batch_size, 10])
    m_minus_tsr = tf.constant(m_minus, shape = [batch_size, 10])
    
    
    #relu for max(0, margin_loss)
    plus_loss = tf.nn.relu(tf.subtract(m_plus_tsr, length_digit_caps))
    minus_loss = tf.nn.relu(tf.subtract(length_digit_caps, m_minus_tsr))
    
    plus_loss = tf.multiply(labels_one_hot, plus_loss)
    minus_loss = tf.multiply(tf.scalar_mul(lamb, labels_one_hot_inv), minus_loss)
    
    loss = tf.add(plus_loss, minus_loss)
    loss = tf.reduce_sum(loss, axis=-1)
    loss = tf.reduce_mean(loss)
    
    correct_prediction = tf.equal(tf.argmax(length_digit_caps, 1), labels)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    


    
with tf.variable_scope('Optimizer'):
    optimizer = tf.train.AdamOptimizer(learning_rate)
    training_step = optimizer.minimize(loss)
    
    
#SUMMARIES
tf.summary.scalar('loss', loss)

merged_summaries = tf.summary.merge_all()
    
    
    
    
    
    
    
    
    
    
    
    
    
    

    
    
    
    

(128, 6, 6, 256)
(128, 6, 6, 32, 8)
(128, 6, 6, 32, 8)


In [88]:
train_writer = tf.summary.FileWriter("./summaries/train", tf.get_default_graph())

with tf.Session() as sess:
    step = 0
    sess.run(tf.global_variables_initializer())
    
    for epoch in range(10):
        batch_generator = mnist_data.get_training_batch(batch_size)
        
        for x, y in batch_generator:
            _loss, _accuracy, _summaries, _ = sess.run([loss,
                                    accuracy,
                                    merged_summaries, 
                                    training_step],
                                    feed_dict = {images: x,
                                                labels: y})
            train_writer.add_summary(_summaries, step)
            step += 1
            print("Loss: {}, Accuracy: {}".format(_loss, _accuracy))

Loss: 3.2933812141418457, Accuracy: 0.09375
Loss: 3.164635181427002, Accuracy: 0.125
Loss: 3.05485200881958, Accuracy: 0.1015625


KeyboardInterrupt: 

In [79]:
aa = tf.constant(1.0, shape=[3,3])
bb = 1.5+aa
with tf.Session() as sess:
    
    print(bb.eval())

[[ 2.5  2.5  2.5]
 [ 2.5  2.5  2.5]
 [ 2.5  2.5  2.5]]
