In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from capsnet import CapsNet

from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST dataset object from the local 'MNIST_data/' directory.
# The cells below read its .train, .validation and .test splits.
# NOTE(review): presumably the helper downloads the data if the directory is
# empty — confirm before running offline.
mnist = input_data.read_data_sets('MNIST_data/')
# Mini-batch size: train() reads this global to compute iterations per epoch
# (floor division) and to size every next_batch() call.
batch_size = 125

In [None]:
# Start from a clean default graph so re-running this cell does not stack ops.
tf.reset_default_graph()

# Seed TensorFlow (graph-level) and NumPy for reproducible runs.
# NOTE: the TensorFlow seed is attached to the current default graph; any later
# tf.reset_default_graph() replaces that graph, so the seed has to be set again
# afterwards to affect ops created after the reset.
tf.random.set_random_seed(0)
np.random.seed(0)

# Path prefix used by train() for tf.train.Saver save/restore.
checkpoint_file = './tmp/model.ckpt'

In [None]:
import functools

def count_params():
    """Print the number of trainable parameters in the default graph, in millions.

    Sums the element counts of every variable returned by
    ``tf.trainable_variables()`` and prints the total as
    ``Model size (Trainable): X.XM`` followed by a blank line.
    """
    def size(v):
        # Initial value 1 makes scalar variables (shape []) count as one
        # element instead of raising on an empty reduce().
        return functools.reduce(lambda x, y: x * y, v.get_shape().as_list(), 1)

    n_trainable = sum(size(v) for v in tf.trainable_variables())
    print("Model size (Trainable): {:.1f}M\n".format(n_trainable/1000000.0))

In [None]:
def train(model, restore = False, n_epochs = 50):
    """Train ``model`` on MNIST and print per-epoch train/validation/test metrics.

    Relies on the notebook-level globals ``mnist``, ``batch_size``,
    ``checkpoint_file`` and ``count_params``.

    Args:
        model: Object exposing the graph handles used below: ``train_op``,
            ``batch_loss``, ``margn_loss``, ``recnst_loss_scale``, ``accuracy``
            and the placeholders ``X``, ``X_cropped``, ``y``, ``reconstruction``.
        restore: If True and a checkpoint exists at ``checkpoint_file``,
            variables are restored from it; otherwise they are freshly
            initialized.
        n_epochs: Number of passes over the training split.

    Side effects:
        Writes the session graph to the ``"output"`` summary directory and
        saves a checkpoint to ``checkpoint_file`` after every epoch.
    """
    init = tf.global_variables_initializer()

    # Floor division: a trailing partial batch is not counted in any split.
    n_iter_train_per_epoch = mnist.train.num_examples // batch_size
    n_iter_valid_per_epoch = mnist.validation.num_examples // batch_size
    n_iter_test_per_epoch  = mnist.test.num_examples // batch_size

    saver = tf.train.Saver()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter("output", sess.graph)
        try:
            # Pass the path variable, not the literal string 'checkpoint_file';
            # with the literal the check was always False and restore never ran.
            if restore and tf.train.checkpoint_exists(checkpoint_file):
                saver.restore(sess, checkpoint_file)
            else:
                init.run()

            print('\n\nRunning CapsNet ...\n')
            count_params()

            print("\ntr_loss\t: training loss")
            print("tr_ml       : training margin loss")
            print("tr_rl       : training reconstruction loss")
            print("v_loss      : validation loss")
            print("tr_acc      : training accuracy(%)")
            print("v_acc       : validation accuracy(%)")
            print("te_acc      : test accuracy(%)\n")

            for epoch in range(n_epochs):
                # ---- training pass ----
                loss_train_ep = []
                margin_loss_train_ep = []
                recnst_loss_train_ep = []
                acc_train_ep  = []
                for it in range(1, n_iter_train_per_epoch + 1):
                    X_batch, y_batch = mnist.train.next_batch(batch_size)

                    # NOTE(review): training feeds model.X while evaluation
                    # feeds model.X_cropped — presumably X_cropped is derived
                    # from X inside CapsNet; confirm against its definition.
                    _, loss_batch, margin_loss_batch, recnst_loss_batch, acc_batch = sess.run(
                                    [model.train_op,
                                     model.batch_loss,
                                     model.margn_loss,
                                     model.recnst_loss_scale,
                                     model.accuracy],
                                    feed_dict = {model.X: X_batch.reshape([-1, 28, 28, 1]),
                                                 model.y: y_batch,
                                                 model.reconstruction: True})

                    # "\r" rewrites the same console line as a progress meter.
                    print("\rIter: {}/{} [{:.1f}%] loss : {:.5f}".format(
                        it, n_iter_train_per_epoch, 100.0 * it / n_iter_train_per_epoch, loss_batch), end="")

                    loss_train_ep.append(loss_batch)
                    margin_loss_train_ep.append(margin_loss_batch)
                    recnst_loss_train_ep.append(recnst_loss_batch)
                    acc_train_ep.append(acc_batch)

                # Epoch metrics are unweighted means over the batches.
                loss_train = np.mean(loss_train_ep)
                margin_loss_train = np.mean(margin_loss_train_ep)
                recnst_loss_train = np.mean(recnst_loss_train_ep)
                acc_train = np.mean(acc_train_ep)

                # ---- validation pass (no train_op, so no weight updates) ----
                loss_val_ep = []
                acc_val_ep  = []

                for it in range(1, n_iter_valid_per_epoch + 1):
                    X_batch, y_batch = mnist.validation.next_batch(batch_size)
                    loss_batch_val, acc_batch_val = sess.run(
                                    [model.batch_loss, model.accuracy],
                                    feed_dict = {model.X_cropped: X_batch.reshape([-1, 28, 28, 1]),
                                                 model.y: y_batch,
                                                 model.reconstruction: False})

                    loss_val_ep.append(loss_batch_val)
                    acc_val_ep.append(acc_batch_val)

                    print("\rValidation {}/{} {:.1f}%".format(it, n_iter_valid_per_epoch, 100.0 * it / n_iter_valid_per_epoch), end=" "*30)

                loss_val = np.mean(loss_val_ep)
                acc_val  = np.mean(acc_val_ep)

                # ---- test pass (accuracy only) ----
                acc_test_ep  = []

                for it in range(1, n_iter_test_per_epoch + 1):
                    X_batch, y_batch = mnist.test.next_batch(batch_size)
                    acc_batch_test = sess.run(model.accuracy,
                                        feed_dict = { model.X_cropped: X_batch.reshape([-1, 28, 28, 1]),
                                            model.y: y_batch,
                                            model.reconstruction: False})
                    acc_test_ep.append(acc_batch_test)

                    print("\rTesting {}/{} {:.1f}%".format(it, n_iter_test_per_epoch, 100.0 * it / n_iter_test_per_epoch), end=" "*30)

                acc_test  = np.mean(acc_test_ep)

                print("\rEp {:2d}: tr_loss:{:.4f}, tr_ml:{:.4f}, tr_rl:{:.4f}, v_loss:{:.4f}, tr_acc:{:.3f}, v_acc:{:.2f}, te_acc: {:.2f}".format(
                    epoch + 1,
                    loss_train,
                    margin_loss_train,
                    recnst_loss_train,
                    loss_val,
                    acc_train * 100.0,
                    acc_val * 100.0,
                    acc_test * 100.0))

                # Checkpoint after every epoch (overwrites the previous one).
                saver.save(sess, checkpoint_file)
        finally:
            # Close the summary writer even if training is interrupted or fails.
            writer.close()

In [None]:
# Build the model in a fresh default graph.
tf.reset_default_graph()
# The graph-level seed belongs to the default graph, and the reset above just
# replaced that graph, so the seed must be re-applied before any ops are
# created or the model's random initialization is not reproducible.
tf.random.set_random_seed(0)
batch_size = 125
model = CapsNet(rounds = 3, alpha = 0.0001, batch_size=batch_size, reconstruction_net = True)

In [None]:
# Train from scratch (no checkpoint restore) for 50 epochs.
train(model, restore=False, n_epochs=50)

In [None]:
# NOTE(review): `test` is not defined or imported in any cell visible here, so
# on a fresh kernel this raises NameError unless it comes from elsewhere — confirm.
test(model)