In [1]:
from __future__ import print_function
import sys
import os
import numpy as np
import torch

from edgeml_pytorch.graph.protoNN import ProtoNN
from edgeml_pytorch.trainer.protoNNTrainer import ProtoNNTrainer
import edgeml_pytorch.utils as utils
import helpermethods as helper

## USPS Data
It is assumed that the USPS data has already been downloaded and set up with the help of `fetch_usps.py` and is placed in the `./usps10` subdirectory.

In [2]:
# Load data
DATA_DIR = './usps10'
train, test = np.load(DATA_DIR + '/train.npy'), np.load(DATA_DIR + '/test.npy')
x_train, y_train = train[:, 1:], train[:, 0]
x_test, y_test = test[:, 1:], test[:, 0]

numClasses = max(y_train) - min(y_train) + 1
numClasses = max(numClasses, max(y_test) - min(y_test) + 1)
numClasses = int(numClasses)

y_train = helper.to_onehot(y_train, numClasses)
y_test = helper.to_onehot(y_test, numClasses)
dataDimension = x_train.shape[1]
numClasses = y_train.shape[1]

## Model Parameters

Note that ProtoNN is very sensitive to the value of the hyperparameter $\gamma$, here stored in valiable GAMMA. If GAMMA is set to None, median heuristic will be used to estimate a good value of $\gamma$ through the helper.getGamma() method. This method also returns the corresponding W and B matrices which should be used to initialize ProtoNN (as is done here).

In [3]:
PROJECTION_DIM = 60
NUM_PROTOTYPES = 60
REG_W = 0.000005
REG_B = 0.0
REG_Z = 0.00005
SPAR_W = 0.8
SPAR_B = 1.0
SPAR_Z = 1.0
LEARNING_RATE = 0.05
NUM_EPOCHS = 200
BATCH_SIZE = 32
GAMMA = 0.0015

In [4]:
W, B, gamma = helper.getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, x_train)

In [5]:
protoNNObj = ProtoNN(dataDimension, PROJECTION_DIM, NUM_PROTOTYPES, numClasses,
                     gamma, W=W, B=B)
protoNNTrainer = ProtoNNTrainer(protoNNObj, REG_W, REG_B, REG_Z, SPAR_W, SPAR_B, SPAR_W,
                                LEARNING_RATE, lossType='xentropy')

Using x-entropy loss


In [6]:
protoNNTrainer.train(BATCH_SIZE, NUM_EPOCHS, x_train, x_test, y_train, y_test, printStep=600, valStep=10)

Epoch 0 batch 0 loss 19.695423 acc 0.187500
Epoch 1 batch 0 loss 2.124979 acc 0.406250
Epoch 2 batch 0 loss 1.176522 acc 0.625000
Epoch 3 batch 0 loss 0.852282 acc 0.781250
Epoch 4 batch 0 loss 0.707985 acc 0.843750
Epoch 5 batch 0 loss 0.620102 acc 0.906250
Epoch 6 batch 0 loss 0.536484 acc 0.906250
Epoch 7 batch 0 loss 0.506774 acc 0.906250
Epoch 8 batch 0 loss 0.469522 acc 0.906250
Epoch 9 batch 0 loss 0.443830 acc 0.937500
Validation accuracy: 0.866467
Epoch 10 batch 0 loss 0.418728 acc 0.937500
Epoch 11 batch 0 loss 0.400192 acc 0.937500
Epoch 12 batch 0 loss 0.391262 acc 0.937500
Epoch 13 batch 0 loss 0.377898 acc 0.937500
Epoch 14 batch 0 loss 0.374987 acc 0.937500
Epoch 15 batch 0 loss 0.374003 acc 0.937500
Epoch 16 batch 0 loss 0.372961 acc 0.937500
Epoch 17 batch 0 loss 0.372122 acc 0.937500
Epoch 18 batch 0 loss 0.374060 acc 0.937500
Epoch 19 batch 0 loss 0.372302 acc 0.937500
Validation accuracy: 0.890882
Epoch 20 batch 0 loss 0.370008 acc 0.937500
Epoch 21 batch 0 loss 0.3

Epoch 174 batch 0 loss 0.297396 acc 0.937500
Epoch 175 batch 0 loss 0.297138 acc 0.937500
Epoch 176 batch 0 loss 0.296805 acc 0.937500
Epoch 177 batch 0 loss 0.296536 acc 0.937500
Epoch 178 batch 0 loss 0.296280 acc 0.937500
Epoch 179 batch 0 loss 0.295992 acc 0.937500
Validation accuracy: 0.910314
Epoch 180 batch 0 loss 0.295781 acc 0.937500
Epoch 181 batch 0 loss 0.295575 acc 0.937500
Epoch 182 batch 0 loss 0.294182 acc 0.937500
Epoch 183 batch 0 loss 0.294994 acc 0.937500
Epoch 184 batch 0 loss 0.294937 acc 0.937500
Epoch 185 batch 0 loss 0.294700 acc 0.937500
Epoch 186 batch 0 loss 0.294113 acc 0.937500
Epoch 187 batch 0 loss 0.292420 acc 0.937500
Epoch 188 batch 0 loss 0.291916 acc 0.937500
Epoch 189 batch 0 loss 0.291724 acc 0.937500
Validation accuracy: 0.911310
Epoch 190 batch 0 loss 0.291277 acc 0.937500
Epoch 191 batch 0 loss 0.291130 acc 0.937500
Epoch 192 batch 0 loss 0.291667 acc 0.937500
Epoch 193 batch 0 loss 0.291187 acc 0.937500
Epoch 194 batch 0 loss 0.290773 acc 0.93

## Evaluation

In [7]:
 x_, y_= torch.Tensor(x_test), torch.Tensor(y_test)
logits = protoNNObj.forward(x_)
_, predictions = torch.max(logits, dim=1)
_, target = torch.max(y_, dim=1)
acc, count = protoNNTrainer.accuracy(predictions, target)
W, B, Z, gamma  = protoNNObj.getModelMatrices()
matrixList = [W, B, Z]
matrixList = [x.detach().numpy() for x in matrixList]
sparcityList = [SPAR_W, SPAR_B, SPAR_Z]
nnz, size, sparse = helper.getModelSize(matrixList, sparcityList)
print("Final test accuracy", acc)
print("Model size constraint (Bytes): ", size)
print("Number of non-zeros: ", nnz)
nnz, size, sparse = helper.getModelSize(matrixList, sparcityList,
                                        expected=False)
print("Actual model size: ", size)
print("Actual non-zeros: ", nnz)

Final test accuracy tensor(0.9138, dtype=torch.float64)
Model size constraint (Bytes):  78240
Number of non-zeros:  19560
Actual model size:  78240
Actual non-zeros:  16368
