In [13]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from capsnet import CapsNet

from tensorflow.examples.tutorials.mnist import input_data

# Read the MNIST splits from the local 'MNIST_data/' directory. train() and
# test() below draw their batches from mnist.train, mnist.validation and
# mnist.test.
mnist = input_data.read_data_sets('MNIST_data/')
# Mini-batch size shared by the CapsNet constructor, train() and test().
batch_size = 125

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [14]:
# Start from an empty default graph so re-running the notebook does not pile
# up duplicate ops/variables from a previous model build.
tf.reset_default_graph()

# Seed TensorFlow and NumPy for repeatable runs. The TensorFlow seed is set
# after the reset so that it applies to the fresh default graph.
tf.random.set_random_seed(0)
np.random.seed(0)

# Checkpoint path prefix that train() restores from and saves to.
checkpoint_file = './tmp/model.ckpt'

In [15]:
import functools

def count_params():
    """Print the number of trainable parameters, in millions.

    Sums the element counts of every variable returned by
    `tf.trainable_variables()` and prints the total as
    "Model size (Trainable): <x.x>M" followed by a blank line.

    Returns:
        None. The result is only printed.
    """
    def size(v):
        # Start the product at 1 so that a scalar variable (shape []) counts
        # as one parameter instead of making reduce() fail on an empty list.
        return functools.reduce(lambda x, y: x * y, v.get_shape().as_list(), 1)

    n_trainable = sum(size(v) for v in tf.trainable_variables())
    print("Model size (Trainable): {:.1f}M\n".format(n_trainable/1000000.0))

In [19]:
def train(model, restore = False, n_epochs = 50):
    """Train `model` on the MNIST training split, validating after every epoch.

    Each epoch runs `mnist.train.num_examples // batch_size` optimisation
    steps, then evaluates on `mnist.validation.num_examples // batch_size`
    validation batches, prints the mean loss/accuracy of both phases and
    saves a checkpoint to `checkpoint_file`.

    Args:
        model: Object exposing the placeholders `X`, `X_cropped` and `y` and
            the fetchable ops `train_op`, `batch_loss` and `accuracy`.
        restore: When True and a checkpoint exists at `checkpoint_file`,
            variables are restored from it; otherwise they are freshly
            initialised.
        n_epochs: Number of passes over the training split.

    Returns:
        None. Progress is printed, the graph is written to the "output"
        summary directory and the weights are checkpointed after each epoch.
    """
    init = tf.global_variables_initializer()

    n_iter_train_per_epoch = mnist.train.num_examples // batch_size
    n_iter_valid_per_epoch = mnist.validation.num_examples // batch_size

    saver = tf.train.Saver()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter("output", sess.graph)
        try:
            # Check the actual checkpoint prefix (the variable), not the
            # literal string 'checkpoint_file', which never exists on disk
            # and made restore=True silently re-initialise the weights.
            if restore and tf.train.checkpoint_exists(checkpoint_file):
                saver.restore(sess, checkpoint_file)
            else:
                init.run()

            print('\n\nRunning CapsNet ...\n')
            count_params()

            print("\ntr_loss   : training loss(margin loss)")
            print("val_loss  : validation loss")
            print("train_acc : training accuracy(%)")
            print("val_acc   : validation accuracy(%)\n")

            for epoch in range(n_epochs):
                # ---- training phase ----
                loss_train_ep = []
                acc_train_ep  = []
                for it in range(1, n_iter_train_per_epoch + 1):
                    X_batch, y_batch = mnist.train.next_batch(batch_size)

                    # NOTE(review): training feeds model.X while validation
                    # feeds model.X_cropped; presumably CapsNet derives
                    # X_cropped from X internally (e.g. augmentation) and
                    # evaluation bypasses that step. Confirm in capsnet.py.
                    _, loss_batch_train, acc_batch_train = sess.run(
                                    [model.train_op,
                                     model.batch_loss,
                                     model.accuracy],
                                    feed_dict = {model.X: X_batch.reshape([-1, 28, 28, 1]),
                                                    model.y: y_batch})

                    # "\r" rewrites the same console line to act as a progress bar.
                    print("\rIter: {}/{} [{:.1f}%] loss : {:.5f}".format(
                        it, n_iter_train_per_epoch, 100.0 * it / n_iter_train_per_epoch, loss_batch_train), end="")

                    loss_train_ep.append(loss_batch_train)
                    acc_train_ep.append(acc_batch_train)

                loss_train = np.mean(loss_train_ep)
                acc_train = np.mean(acc_train_ep)

                # ---- validation phase (no train_op, so weights are untouched) ----
                loss_val_ep = []
                acc_val_ep  = []

                for it in range(1, n_iter_valid_per_epoch + 1):
                    X_batch, y_batch = mnist.validation.next_batch(batch_size)
                    loss_batch_val, acc_batch_val = sess.run(
                                    [model.batch_loss, model.accuracy],
                                    feed_dict = {model.X_cropped: X_batch.reshape([-1, 28, 28, 1]),
                                                    model.y: y_batch})

                    loss_val_ep.append(loss_batch_val)
                    acc_val_ep.append(acc_batch_val)

                    # Trailing spaces blank out the longer training progress line.
                    print("\rValidation {}/{} {:.1f}%".format(it, n_iter_valid_per_epoch, 100.0 * it / n_iter_valid_per_epoch), end=" "*30)

                loss_val = np.mean(loss_val_ep)
                acc_val  = np.mean(acc_val_ep)

                print("\rEp {:2d}: train_loss:{:.4f}, valid_loss:{:.4f}, train_acc:{:.3f}, val_acc:{:.2f}".format(
                    epoch + 1,
                    loss_train,
                    loss_val,
                    acc_train * 100.0,
                    acc_val * 100.0))

                # Checkpoint after every epoch, overwriting the previous one.
                saver.save(sess, checkpoint_file)
        finally:
            # Close the event file even if training is interrupted or raises.
            writer.close()

In [17]:
# Build the CapsNet graph with the notebook-wide batch size.
# NOTE(review): `rounds` is presumably the number of dynamic-routing
# iterations and `reconstruction_net = False` presumably omits the
# reconstruction decoder; confirm against capsnet.py.
model = CapsNet(rounds = 3, batch_size = batch_size,reconstruction_net = False)

In [20]:
# Train from freshly initialised weights (no checkpoint restore) for 50 epochs.
train(model, restore=False, n_epochs=50)



Running CapsNet ...

Model size (Trainable): 6.8M


tr_loss   : training loss(margin loss)
val_loss  : validation loss
train_acc : training accuracy(%)
val_acc   : validation accuracy(%)

Ep  1: train_loss:0.1494, valid_loss:0.0253, train_acc:83.616, val_acc:97.92
Ep  2: train_loss:0.0270, valid_loss:0.0138, train_acc:97.867, val_acc:98.94
Ep  3: train_loss:0.0179, valid_loss:0.0120, train_acc:98.605, val_acc:99.10
Ep  4: train_loss:0.0136, valid_loss:0.0091, train_acc:98.951, val_acc:99.26
Ep  5: train_loss:0.0111, valid_loss:0.0081, train_acc:99.124, val_acc:99.38
Ep  6: train_loss:0.0098, valid_loss:0.0082, train_acc:99.227, val_acc:99.28
Ep  7: train_loss:0.0082, valid_loss:0.0070, train_acc:99.389, val_acc:99.38
Ep  8: train_loss:0.0073, valid_loss:0.0067, train_acc:99.444, val_acc:99.44
Ep  9: train_loss:0.0069, valid_loss:0.0062, train_acc:99.496, val_acc:99.46
Ep 10: train_loss:0.0061, valid_loss:0.0055, train_acc:99.551, val_acc:99.50
Ep 11: train_loss:0.0057, valid_loss:0.0

In [21]:
def test(model):
    """Evaluate `model` on the MNIST test split using the latest checkpoint.

    Restores the most recent checkpoint found under 'tmp/', runs
    `mnist.test.num_examples // batch_size` batches through
    `model.batch_loss` and `model.accuracy`, and prints the mean accuracy
    (as a percentage) and mean loss over those batches.

    Args:
        model: Object exposing the placeholders `X_cropped` and `y` and the
            fetchable ops `batch_loss` and `accuracy`.

    Returns:
        None. Results are only printed.
    """
    total_batches = mnist.test.num_examples // batch_size

    batch_losses, batch_accs = [], []
    fetches = [model.batch_loss, model.accuracy]

    restorer = tf.train.Saver()
    with tf.Session() as sess:
        restorer.restore(sess, tf.train.latest_checkpoint('tmp/'))
        print('\n\nTest\n')

        step = 0
        while step < total_batches:
            step += 1
            images, labels = mnist.test.next_batch(batch_size)

            # Flat image vectors are reshaped to NHWC 28x28x1 before feeding.
            feed = {
                model.X_cropped: images.reshape([-1, 28, 28, 1]),
                model.y: labels,
            }
            batch_loss, batch_acc = sess.run(fetches, feed_dict=feed)

            batch_losses.append(batch_loss)
            batch_accs.append(batch_acc)

            # "\r" plus trailing padding keeps the progress on a single line.
            percent_done = 100.0 * step / total_batches
            print("\rTesting {}/{} {:.1f}%".format(step, total_batches, percent_done),
                  end=" "*30)

        mean_loss = np.mean(batch_losses)
        mean_acc = np.mean(batch_accs)

        print("\r(Testing) accuracy: {:.3f}%, loss: {:.4f}".format(mean_acc*100.0, mean_loss))

In [22]:
# Evaluate the trained model on the MNIST test split; test() restores the
# latest checkpoint under 'tmp/' before running.
test(model)

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from tmp/model.ckpt


Test

(Testing) accuracy: 99.650%, loss: 0.0042         
