# ProtoNN in Tensorflow

This is a simple notebook that illustrates the usage of Tensorflow implementation of ProtoNN. We are using the USPS dataset. Please refer to `fetch_usps.py` for more details of downloading the dataset.

In [1]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import sys
import os
import numpy as np
import tensorflow as tf

sys.path.insert(0, '../../')
from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml.graph.protoNN import ProtoNN
import edgeml.utils as utils
import helpermethods as helper

# USPS Data

It is assumed that the USPS data has already been downloaded and set up with the help of [fetch_usps.py](fetch_usps.py) and is placed in the `./usps10` subdirectory.

In [4]:
# Load data
out = helper.loadData(DATA_DIR)
dataDimension = out[0]
numClasses = out[1]
x_train, y_train = out[2], out[3]
x_test, y_test = out[4], out[5]

# Model Parameters

In [3]:
DATA_DIR = './curet'
PROJECTION_DIM = 60
NUM_PROTOTYPES = 80
REG_W = 0.000005
REG_B = 0.0
REG_Z = 1.0
SPAR_W = 0.8
SPAR_B = 1.0
SPAR_Z = 1.0
LEARNING_RATE = 0.05
NUM_EPOCHS = 800
GAMMA = None

In [5]:
W, B, gamma = helper.getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, x_train)

Using median heuristic to estimate gamma.
Gamma estimate is: 0.001576




In [None]:
# Setup input and train protoNN
X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')
protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
                  NUM_PROTOTYPES, numClasses,
                  gamma, W=W, B=B)
trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
                         SPAR_W, SPAR_B, SPAR_Z,
                         LEARNING_RATE, X, Y, lossType='xentropy')
sess = tf.Session()
trainer.train(16, NUM_EPOCHS, sess, x_train, x_test, y_train, y_test,
              printStep=200)



Epoch:   0 Batch:   0 Loss: 2302.15259 Accuracy: 0.00000
Epoch:   0 Batch: 200 Loss: 4.23385 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 4.27443 Accuracy: 0.00000
Epoch:   1 Batch: 200 Loss: 4.22948 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 4.28290 Accuracy: 0.00000
Epoch:   2 Batch: 200 Loss: 4.23013 Accuracy: 0.00000
Test Loss: 4.31272 Accuracy: 0.01643
Epoch:   3 Batch:   0 Loss: 4.27636 Accuracy: 0.00000
Epoch:   3 Batch: 200 Loss: 4.24300 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 4.26667 Accuracy: 0.00000
Epoch:   4 Batch: 200 Loss: 4.25892 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 4.25636 Accuracy: 0.00000
Epoch:   5 Batch: 200 Loss: 4.27118 Accuracy: 0.00000
Test Loss: 4.31867 Accuracy: 0.01643
Epoch:   6 Batch:   0 Loss: 4.25208 Accuracy: 0.00000
Epoch:   6 Batch: 200 Loss: 4.27776 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 4.25267 Accuracy: 0.00000
Epoch:   7 Batch: 200 Loss: 4.28119 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 4.25397 Accuracy: 0.00000
Epoch

Epoch:  68 Batch: 200 Loss: 4.16465 Accuracy: 0.31250
Test Loss: 4.23129 Accuracy: 0.03854
Epoch:  69 Batch:   0 Loss: 4.19383 Accuracy: 0.06250
Epoch:  69 Batch: 200 Loss: 4.16498 Accuracy: 0.31250
Epoch:  70 Batch:   0 Loss: 4.19334 Accuracy: 0.06250
Epoch:  70 Batch: 200 Loss: 4.16435 Accuracy: 0.25000
Epoch:  71 Batch:   0 Loss: 4.19335 Accuracy: 0.06250
Epoch:  71 Batch: 200 Loss: 4.16346 Accuracy: 0.25000
Test Loss: 4.22754 Accuracy: 0.03423
Epoch:  72 Batch:   0 Loss: 4.19308 Accuracy: 0.06250
Epoch:  72 Batch: 200 Loss: 4.16202 Accuracy: 0.25000
Epoch:  73 Batch:   0 Loss: 4.19306 Accuracy: 0.06250
Epoch:  73 Batch: 200 Loss: 4.16138 Accuracy: 0.25000
Epoch:  74 Batch:   0 Loss: 4.19310 Accuracy: 0.06250
Epoch:  74 Batch: 200 Loss: 4.16081 Accuracy: 0.25000
Test Loss: 4.22419 Accuracy: 0.03423
Epoch:  75 Batch:   0 Loss: 4.19346 Accuracy: 0.06250
Epoch:  75 Batch: 200 Loss: 4.16034 Accuracy: 0.25000
Epoch:  76 Batch:   0 Loss: 4.19466 Accuracy: 0.06250
Epoch:  76 Batch: 200 Los

Epoch: 136 Batch: 200 Loss: 4.16408 Accuracy: 0.12500
Epoch: 137 Batch:   0 Loss: 4.15613 Accuracy: 0.12500
Epoch: 137 Batch: 200 Loss: 4.17396 Accuracy: 0.12500
Test Loss: 4.20137 Accuracy: 0.03068
Epoch: 138 Batch:   0 Loss: 4.15084 Accuracy: 0.06250
Epoch: 138 Batch: 200 Loss: 4.16869 Accuracy: 0.12500
Epoch: 139 Batch:   0 Loss: 4.15531 Accuracy: 0.12500
Epoch: 139 Batch: 200 Loss: 4.17448 Accuracy: 0.12500
Epoch: 140 Batch:   0 Loss: 4.16156 Accuracy: 0.12500
Epoch: 140 Batch: 200 Loss: 4.19304 Accuracy: 0.12500
Test Loss: 4.21350 Accuracy: 0.04148
Epoch: 141 Batch:   0 Loss: 4.14926 Accuracy: 0.06250
Epoch: 141 Batch: 200 Loss: 4.17832 Accuracy: 0.12500
Epoch: 142 Batch:   0 Loss: 4.14806 Accuracy: 0.06250
Epoch: 142 Batch: 200 Loss: 4.18121 Accuracy: 0.12500
Epoch: 143 Batch:   0 Loss: 4.15635 Accuracy: 0.06250
Epoch: 143 Batch: 200 Loss: 4.17002 Accuracy: 0.12500
Test Loss: 4.20967 Accuracy: 0.03987
Epoch: 144 Batch:   0 Loss: 4.15558 Accuracy: 0.18750
Epoch: 144 Batch: 200 Los

Epoch: 204 Batch: 200 Loss: 4.15973 Accuracy: 0.12500
Epoch: 205 Batch:   0 Loss: 4.15680 Accuracy: 0.06250
Epoch: 205 Batch: 200 Loss: 4.16918 Accuracy: 0.12500
Epoch: 206 Batch:   0 Loss: 4.16158 Accuracy: 0.06250
Epoch: 206 Batch: 200 Loss: 4.18637 Accuracy: 0.12500
Test Loss: 4.20473 Accuracy: 0.02633
Epoch: 207 Batch:   0 Loss: 4.16815 Accuracy: 0.06250
Epoch: 207 Batch: 200 Loss: 4.20155 Accuracy: 0.12500
Epoch: 208 Batch:   0 Loss: 4.15331 Accuracy: 0.06250
Epoch: 208 Batch: 200 Loss: 4.21725 Accuracy: 0.12500
Epoch: 209 Batch:   0 Loss: 4.15556 Accuracy: 0.06250
Epoch: 209 Batch: 200 Loss: 4.21017 Accuracy: 0.12500
Test Loss: 4.19638 Accuracy: 0.02992
Epoch: 210 Batch:   0 Loss: 4.15629 Accuracy: 0.06250
Epoch: 210 Batch: 200 Loss: 4.18241 Accuracy: 0.12500
Epoch: 211 Batch:   0 Loss: 4.17741 Accuracy: 0.06250
Epoch: 211 Batch: 200 Loss: 4.18889 Accuracy: 0.12500
Epoch: 212 Batch:   0 Loss: 4.16698 Accuracy: 0.00000
Epoch: 212 Batch: 200 Loss: 4.22029 Accuracy: 0.12500
Test Los

Epoch: 273 Batch: 200 Loss: 4.21385 Accuracy: 0.12500
Epoch: 274 Batch:   0 Loss: 4.15701 Accuracy: 0.06250
Epoch: 274 Batch: 200 Loss: 4.18336 Accuracy: 0.00000
Epoch: 275 Batch:   0 Loss: 4.14561 Accuracy: 0.00000
Epoch: 275 Batch: 200 Loss: 4.18469 Accuracy: 0.12500
Test Loss: 4.26565 Accuracy: 0.02206
Epoch: 276 Batch:   0 Loss: 4.24759 Accuracy: 0.06250
Epoch: 276 Batch: 200 Loss: 4.19521 Accuracy: 0.12500
Epoch: 277 Batch:   0 Loss: 4.16035 Accuracy: 0.06250
Epoch: 277 Batch: 200 Loss: 4.20492 Accuracy: 0.00000
Epoch: 278 Batch:   0 Loss: 4.16145 Accuracy: 0.00000
Epoch: 278 Batch: 200 Loss: 4.20184 Accuracy: 0.12500
Test Loss: 4.22822 Accuracy: 0.03703
Epoch: 279 Batch:   0 Loss: 4.20917 Accuracy: 0.06250
Epoch: 279 Batch: 200 Loss: 4.19517 Accuracy: 0.00000
Epoch: 280 Batch:   0 Loss: 4.16485 Accuracy: 0.06250
Epoch: 280 Batch: 200 Loss: 4.18876 Accuracy: 0.12500
Epoch: 281 Batch:   0 Loss: 4.15342 Accuracy: 0.00000
Epoch: 281 Batch: 200 Loss: 4.19337 Accuracy: 0.12500
Test Los

Epoch: 342 Batch: 200 Loss: 4.14822 Accuracy: 0.12500
Epoch: 343 Batch:   0 Loss: 4.15439 Accuracy: 0.06250
Epoch: 343 Batch: 200 Loss: 4.16589 Accuracy: 0.12500
Epoch: 344 Batch:   0 Loss: 4.16161 Accuracy: 0.06250
Epoch: 344 Batch: 200 Loss: 4.20064 Accuracy: 0.12500
Test Loss: 4.18929 Accuracy: 0.02775
Epoch: 345 Batch:   0 Loss: 4.15378 Accuracy: 0.06250
Epoch: 345 Batch: 200 Loss: 4.22324 Accuracy: 0.00000
Epoch: 346 Batch:   0 Loss: 4.15658 Accuracy: 0.00000
Epoch: 346 Batch: 200 Loss: 4.18437 Accuracy: 0.12500
Epoch: 347 Batch:   0 Loss: 4.16566 Accuracy: 0.06250
Epoch: 347 Batch: 200 Loss: 4.20185 Accuracy: 0.12500
Test Loss: 4.19846 Accuracy: 0.02846
Epoch: 348 Batch:   0 Loss: 4.15288 Accuracy: 0.00000
Epoch: 348 Batch: 200 Loss: 4.18111 Accuracy: 0.12500
Epoch: 349 Batch:   0 Loss: 4.14842 Accuracy: 0.00000
Epoch: 349 Batch: 200 Loss: 4.20119 Accuracy: 0.12500
Epoch: 350 Batch:   0 Loss: 4.15785 Accuracy: 0.00000
Epoch: 350 Batch: 200 Loss: 4.21108 Accuracy: 0.12500
Test Los

Epoch: 411 Batch: 200 Loss: 4.20120 Accuracy: 0.00000
Epoch: 412 Batch:   0 Loss: 4.16890 Accuracy: 0.00000
Epoch: 412 Batch: 200 Loss: 4.20125 Accuracy: 0.00000
Epoch: 413 Batch:   0 Loss: 4.17032 Accuracy: 0.00000
Epoch: 413 Batch: 200 Loss: 4.21191 Accuracy: 0.12500
Test Loss: 4.18195 Accuracy: 0.02633
Epoch: 414 Batch:   0 Loss: 4.15169 Accuracy: 0.00000
Epoch: 414 Batch: 200 Loss: 4.22261 Accuracy: 0.12500
Epoch: 415 Batch:   0 Loss: 4.15512 Accuracy: 0.00000
Epoch: 415 Batch: 200 Loss: 4.22287 Accuracy: 0.12500
Epoch: 416 Batch:   0 Loss: 4.15423 Accuracy: 0.00000
Epoch: 416 Batch: 200 Loss: 4.19651 Accuracy: 0.00000
Test Loss: 4.20906 Accuracy: 0.03916
Epoch: 417 Batch:   0 Loss: 4.15407 Accuracy: 0.00000
Epoch: 417 Batch: 200 Loss: 4.20123 Accuracy: 0.12500
Epoch: 418 Batch:   0 Loss: 4.16512 Accuracy: 0.06250
Epoch: 418 Batch: 200 Loss: 4.19428 Accuracy: 0.00000
Epoch: 419 Batch:   0 Loss: 4.16105 Accuracy: 0.00000
Epoch: 419 Batch: 200 Loss: 4.20189 Accuracy: 0.00000
Test Los

Epoch: 480 Batch: 200 Loss: 4.19919 Accuracy: 0.12500
Epoch: 481 Batch:   0 Loss: 4.16819 Accuracy: 0.00000
Epoch: 481 Batch: 200 Loss: 4.20642 Accuracy: 0.00000
Epoch: 482 Batch:   0 Loss: 4.14580 Accuracy: 0.00000
Epoch: 482 Batch: 200 Loss: 4.18494 Accuracy: 0.12500
Test Loss: 4.20064 Accuracy: 0.02846
Epoch: 483 Batch:   0 Loss: 4.14620 Accuracy: 0.06250
Epoch: 483 Batch: 200 Loss: 4.17041 Accuracy: 0.12500
Epoch: 484 Batch:   0 Loss: 4.16165 Accuracy: 0.00000
Epoch: 484 Batch: 200 Loss: 4.20608 Accuracy: 0.12500
Epoch: 485 Batch:   0 Loss: 4.14782 Accuracy: 0.06250
Epoch: 485 Batch: 200 Loss: 4.21873 Accuracy: 0.18750
Test Loss: 4.19338 Accuracy: 0.02921
Epoch: 486 Batch:   0 Loss: 4.15587 Accuracy: 0.00000
Epoch: 486 Batch: 200 Loss: 4.21017 Accuracy: 0.12500
Epoch: 487 Batch:   0 Loss: 4.15238 Accuracy: 0.00000
Epoch: 487 Batch: 200 Loss: 4.20181 Accuracy: 0.00000
Epoch: 488 Batch:   0 Loss: 4.15106 Accuracy: 0.00000
Epoch: 488 Batch: 200 Loss: 4.19859 Accuracy: 0.12500
Test Los

Epoch: 549 Batch: 200 Loss: 4.19964 Accuracy: 0.12500
Epoch: 550 Batch:   0 Loss: 4.17920 Accuracy: 0.00000
Epoch: 550 Batch: 200 Loss: 4.19198 Accuracy: 0.00000
Epoch: 551 Batch:   0 Loss: 4.16878 Accuracy: 0.00000
Epoch: 551 Batch: 200 Loss: 4.21746 Accuracy: 0.12500
Test Loss: 4.19258 Accuracy: 0.02992
Epoch: 552 Batch:   0 Loss: 4.16190 Accuracy: 0.00000
Epoch: 552 Batch: 200 Loss: 4.20192 Accuracy: 0.00000
Epoch: 553 Batch:   0 Loss: 4.14280 Accuracy: 0.00000
Epoch: 553 Batch: 200 Loss: 4.15487 Accuracy: 0.12500
Epoch: 554 Batch:   0 Loss: 4.16361 Accuracy: 0.06250
Epoch: 554 Batch: 200 Loss: 4.20296 Accuracy: 0.00000
Test Loss: 4.20176 Accuracy: 0.03987
Epoch: 555 Batch:   0 Loss: 4.17021 Accuracy: 0.06250
Epoch: 555 Batch: 200 Loss: 4.20877 Accuracy: 0.12500
Epoch: 556 Batch:   0 Loss: 4.15382 Accuracy: 0.06250
Epoch: 556 Batch: 200 Loss: 4.18987 Accuracy: 0.00000
Epoch: 557 Batch:   0 Loss: 4.14081 Accuracy: 0.00000
Epoch: 557 Batch: 200 Loss: 4.18020 Accuracy: 0.12500
Test Los

Epoch: 618 Batch: 200 Loss: 4.22094 Accuracy: 0.12500
Epoch: 619 Batch:   0 Loss: 4.17126 Accuracy: 0.06250
Epoch: 619 Batch: 200 Loss: 4.22616 Accuracy: 0.00000
Epoch: 620 Batch:   0 Loss: 4.15017 Accuracy: 0.00000
Epoch: 620 Batch: 200 Loss: 4.19900 Accuracy: 0.12500
Test Loss: 4.20684 Accuracy: 0.02140
Epoch: 621 Batch:   0 Loss: 4.17351 Accuracy: 0.06250
Epoch: 621 Batch: 200 Loss: 4.21771 Accuracy: 0.12500
Epoch: 622 Batch:   0 Loss: 4.15215 Accuracy: 0.00000
Epoch: 622 Batch: 200 Loss: 4.21499 Accuracy: 0.12500
Epoch: 623 Batch:   0 Loss: 4.14585 Accuracy: 0.06250
Epoch: 623 Batch: 200 Loss: 4.22682 Accuracy: 0.12500
Test Loss: 4.19945 Accuracy: 0.02704
Epoch: 624 Batch:   0 Loss: 4.16639 Accuracy: 0.06250
Epoch: 624 Batch: 200 Loss: 4.22619 Accuracy: 0.00000
Epoch: 625 Batch:   0 Loss: 4.15023 Accuracy: 0.00000
Epoch: 625 Batch: 200 Loss: 4.22199 Accuracy: 0.12500
Epoch: 626 Batch:   0 Loss: 4.16877 Accuracy: 0.06250
Epoch: 626 Batch: 200 Loss: 4.22575 Accuracy: 0.00000
Test Los

# Model Evaluation

In [None]:
acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
# W, B, Z are tensorflow graph nodes
W, B, Z, _ = protoNN.getModelMatrices()
matrixList = sess.run([W, B, Z])
sparcityList = [SPAR_W, SPAR_B, SPAR_Z]
nnz, size, sparse = helper.getModelSize(matrixList, sparcityList)
print("Final test accuracy", acc)
print("Model size constraint (Bytes): ", size)
print("Number of non-zeros: ", nnz)
nnz, size, sparse = helper.getModelSize(matrixList, sparcityList, expected=False)
print("Actual model size: ", size)
print("Actual non-zeros: ", nnz)