In [17]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from capsnet import CapsNet

from tensorflow.examples.tutorials.mnist import input_data

# Mini-batch size shared by the training, validation and test loops below.
batch_size = 100

# Load the MNIST train / validation / test splits from MNIST_data/.
mnist = input_data.read_data_sets('MNIST_data/')

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [18]:
SEED = 0

# Start from an empty graph so re-running the notebook does not duplicate ops.
tf.reset_default_graph()

# Seed TensorFlow (graph-level) and NumPy for repeatable runs.
tf.random.set_random_seed(SEED)
np.random.seed(SEED)

# Checkpoint path that `train` saves to and restores from.
checkpoint_file = './tmp/model.ckpt'

In [19]:
import functools

def count_params(variables=None):
    """Print (and return) the total number of trainable parameters.

    Args:
        variables: optional iterable of variables exposing
            ``get_shape().as_list()``. Defaults to ``tf.trainable_variables()``.

    Returns:
        int: total number of scalar parameters across ``variables``.
    """
    if variables is None:
        variables = tf.trainable_variables()
    # Seed the product with 1: a scalar variable has shape [] and would make
    # functools.reduce raise TypeError on an empty sequence without it.
    n_trainable = sum(
        functools.reduce(lambda x, y: x * y, v.get_shape().as_list(), 1)
        for v in variables)
    print("Model size (Trainable): {:.1f}M\n".format(n_trainable/1000000.0))
    return n_trainable

In [20]:
def _print_metric_legend():
    # Print the abbreviations used in the per-epoch summary line.
    print("\nt_loss : training loss")
    print("t_ml   : training margin loss")
    print("t_rl   : training reconstruction loss")
    print("v_loss : validation loss")
    print("t_acc  : training accuracy")
    print("v_acc  : validation accuracy\n")


def _run_train_epoch(sess, model, n_iter):
    """Run `n_iter` optimisation steps on mnist.train and return epoch means.

    Returns:
        (loss, margin_loss, reconstruction_loss, accuracy): the mean of the
        per-batch values fetched from `model` over the epoch.
    """
    loss_ep, margin_ep, recnst_ep, acc_ep = [], [], [], []
    for it in range(1, n_iter + 1):
        X_batch, y_batch = mnist.train.next_batch(batch_size)

        # NOTE(review): `recnst_loss_scale` is reported as the reconstruction
        # loss ("t_rl"); presumably it is the scaled reconstruction term —
        # confirm against capsnet.py.
        _, loss_batch, margin_batch, recnst_batch, acc_batch = sess.run(
            [model.train_op,
             model.batch_loss,
             model.margn_loss,
             model.recnst_loss_scale,
             model.accuracy],
            feed_dict={model.X: X_batch.reshape([-1, 28, 28, 1]),
                       model.y: y_batch,
                       model.reconstruction: True})

        print("\rIter: {}/{} [{:.1f}%] loss : {:.5f}".format(
            it, n_iter, 100.0 * it / n_iter, loss_batch), end="")

        loss_ep.append(loss_batch)
        margin_ep.append(margin_batch)
        recnst_ep.append(recnst_batch)
        acc_ep.append(acc_batch)

    return np.mean(loss_ep), np.mean(margin_ep), np.mean(recnst_ep), np.mean(acc_ep)


def _run_validation_epoch(sess, model, n_iter):
    """Evaluate `n_iter` batches of mnist.validation; return (mean loss, mean accuracy)."""
    loss_ep, acc_ep = [], []
    for it in range(1, n_iter + 1):
        X_batch, y_batch = mnist.validation.next_batch(batch_size)
        # NOTE(review): validation feeds `model.X_cropped` (training feeds
        # `model.X`) and does not feed `model.reconstruction`; presumably the
        # graph provides a default for it — confirm against capsnet.py.
        loss_batch, acc_batch = sess.run(
            [model.batch_loss, model.accuracy],
            feed_dict={model.X_cropped: X_batch.reshape([-1, 28, 28, 1]),
                       model.y: y_batch})

        loss_ep.append(loss_batch)
        acc_ep.append(acc_batch)

        print("\rValidation {}/{} {:.1f}%".format(it, n_iter, 100.0 * it / n_iter), end=" "*30)

    return np.mean(loss_ep), np.mean(acc_ep)


def train(model, restore=False, n_epochs=50):
    """Train `model` on MNIST, validating and checkpointing after each epoch.

    Args:
        model: CapsNet instance exposing the placeholders/tensors fed and
            fetched below (X, X_cropped, y, reconstruction, train_op,
            batch_loss, margn_loss, recnst_loss_scale, accuracy).
        restore: if True and a checkpoint exists at `checkpoint_file`, resume
            from it; otherwise initialise all global variables.
        n_epochs: number of passes over the training split.

    Side effects:
        Writes the graph to the "output" summary directory and saves a
        checkpoint to `checkpoint_file` whenever validation loss improves.
    """
    import os

    init = tf.global_variables_initializer()

    n_iter_train_per_epoch = mnist.train.num_examples // batch_size
    n_iter_valid_per_epoch = mnist.validation.num_examples // batch_size

    # np.inf (np.infty was removed in NumPy 2.0).
    best_loss_val = np.inf

    saver = tf.train.Saver()

    # TF1's Saver.save refuses to write when the parent directory is missing.
    os.makedirs(os.path.dirname(checkpoint_file) or '.', exist_ok=True)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter("output", sess.graph)
        try:
            # Check the actual checkpoint path (the original passed the literal
            # string 'checkpoint_file', so restore never triggered).
            if restore and tf.train.checkpoint_exists(checkpoint_file):
                saver.restore(sess, checkpoint_file)
            else:
                init.run()

            print('\n\nRunning CapsNet ...\n')
            count_params()
            _print_metric_legend()

            for epoch in range(n_epochs):
                loss_train, margin_loss_train, recnst_loss_train, acc_train = _run_train_epoch(
                    sess, model, n_iter_train_per_epoch)
                loss_val, acc_val = _run_validation_epoch(
                    sess, model, n_iter_valid_per_epoch)

                improved = loss_val < best_loss_val
                print("\rep: {:3d} t_loss: {:.5f}, t_ml: {:.5f}, t_rl: {:.5f}, v_loss: {:.5f}, t_acc: {:.4f}%, v_acc: {:.4f}% {}".format(
                    epoch + 1,
                    loss_train,
                    margin_loss_train,
                    recnst_loss_train,
                    loss_val,
                    acc_train * 100.0,
                    acc_val * 100.0,
                    "(imp)" if improved else ""))

                # Checkpoint only on improvement so `checkpoint_file` always
                # holds the best-validation model that `test` later restores.
                if improved:
                    best_loss_val = loss_val
                    saver.save(sess, checkpoint_file)
        finally:
            writer.close()

In [21]:
# Build the CapsNet graph with 3 routing rounds (passed as `rounds`).
model = CapsNet(rounds=3)

In [22]:
train(model, restore=False, n_epochs=30)



Running CapsNet ...

Model size (Trainable): 8.2M

epoch: 1 loss_train: 0.15213, margin_loss: 0.12496, recnst_loss: 0.02717, loss_val: 0.04223, train_acc: 86.3327%, valid_acc: 98.4600% (imp)
epoch: 2 loss_train: 0.04533, margin_loss: 0.02245, recnst_loss: 0.02288, loss_val: 0.03299, train_acc: 98.2255%, valid_acc: 98.8200% (imp)
epoch: 3 loss_train: 0.03630, margin_loss: 0.01568, recnst_loss: 0.02062, loss_val: 0.02831, train_acc: 98.7236%, valid_acc: 99.0800% (imp)
epoch: 4 loss_train: 0.03059, margin_loss: 0.01190, recnst_loss: 0.01869, loss_val: 0.02663, train_acc: 99.0455%, valid_acc: 99.1600% (imp)
epoch: 5 loss_train: 0.02768, margin_loss: 0.01030, recnst_loss: 0.01738, loss_val: 0.02410, train_acc: 99.2127%, valid_acc: 99.1600% (imp)
epoch: 6 loss_train: 0.02453, margin_loss: 0.00854, recnst_loss: 0.01599, loss_val: 0.02095, train_acc: 99.3491%, valid_acc: 99.3600% (imp)
epoch: 7 loss_train: 0.02260, margin_loss: 0.00761, recnst_loss: 0.01499, loss_val: 0.02062, train_acc: 99.

In [23]:
def test(model):
    """Evaluate the latest saved checkpoint on the MNIST test split.

    Restores the most recent checkpoint found in the directory of
    `checkpoint_file` (where `train` saves), then prints mean accuracy and
    loss over `mnist.test.num_examples // batch_size` batches.

    Raises:
        FileNotFoundError: if no checkpoint exists in that directory.
    """
    import os

    n_iter_test_per_epoch = mnist.test.num_examples // batch_size

    # Derive the directory from `checkpoint_file` instead of hard-coding 'tmp/'
    # so train and test cannot drift apart.
    checkpoint_dir = os.path.dirname(checkpoint_file) or '.'
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        # Fail with a clear message instead of saver.restore(sess, None).
        raise FileNotFoundError(
            "No checkpoint found in {!r}; run train() first.".format(checkpoint_dir))

    loss_test_ep = []
    acc_test_ep = []
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, latest)

        print('\n\nTest\n')
        for it in range(1, n_iter_test_per_epoch + 1):
            X_batch, y_batch = mnist.test.next_batch(batch_size)
            loss_batch_test, acc_batch_test = sess.run(
                [model.batch_loss, model.accuracy],
                feed_dict={model.X_cropped: X_batch.reshape([-1, 28, 28, 1]),
                           model.y: y_batch,
                           model.reconstruction: False})

            loss_test_ep.append(loss_batch_test)
            acc_test_ep.append(acc_batch_test)
            print("\rTesting {}/{} {:.1f}%".format(it,
                                                   n_iter_test_per_epoch,
                                                   100.0 * it / n_iter_test_per_epoch),
                  end=" "*30)

        loss_test = np.mean(loss_test_ep)
        acc_test = np.mean(acc_test_ep)

        print("\r(Testing) accuracy: {:.3f}%, loss: {:.4f}".format(acc_test*100.0, loss_test))

In [24]:
test(model=model)

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from tmp/model.ckpt


Test

(Testing) accuracy: 99.640%, loss: 0.0116           


In [25]:
# NOTE(review): version-control housekeeping is not part of the analysis and
# breaks Restart & Run All reproducibility; run git from a terminal instead.
!git status

On branch master
Your branch and 'origin/master' have diverged,
and have 1 and 1 different commits each, respectively.
  (use "git pull" to merge the remote branch into yours)
You have unmerged paths.
  (fix conflicts and run "git commit")
  (use "git merge --abort" to abort the merge)

Changes to be committed:

	[32mmodified:   capsules.py[m
	[32mmodified:   main.py[m

Unmerged paths:
  (use "git add <file>..." to mark resolution)

	[31mboth modified:   capsnet-nb.ipynb[m
	[31mboth modified:   capsnet.py[m

Untracked files:
  (use "git add <file>..." to include in what will be committed)

	[31mtest-nb.ipynb[m



In [None]:
# NOTE(review): per the status output of the previous cell, `git add .` would
# mark the conflicted files (capsnet-nb.ipynb, capsnet.py) as resolved and also
# stage the untracked test-nb.ipynb. Resolve the conflicts first, stage files
# explicitly, and do this outside the notebook.
!git add .