In [8]:
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
import numpy as np
import time
from itertools import product
import chainer.functions as F
from branchynet.augment import augmentation
from scipy.stats import entropy


def test_suite_B(branchyNet,x_test,y_test,batchsize=10000,ps=np.linspace(0.1,2.0,10)):
    """Evaluate the network once per exit threshold in ``ps``.

    For every threshold ``p`` the function assigns ``p`` to
    ``branchyNet.thresholdExits`` and runs ``test1`` on the given data.
    The attribute is left at the last value of ``ps`` when the sweep ends.

    Args:
        branchyNet: network object; must expose a writable ``thresholdExits``
            attribute and whatever ``test1`` requires.
        x_test: inputs forwarded unchanged to ``test1``.
        y_test: labels forwarded unchanged to ``test1``; ``len(y_test)`` is
            also used to scale the second returned array.
        batchsize: batch size forwarded to ``test1``.
        ps: iterable of threshold values to sweep.

    Returns:
        A 4-tuple ``(ps, accs, diffs, num_exit)``:
        ``accs`` is an array of the first value returned by ``test1`` per
        threshold, ``diffs`` is an array of the second value returned by
        ``test1`` per threshold divided by ``len(y_test)``, and ``num_exit``
        is the third value returned by ``test1`` for the LAST threshold only
        (``None`` when ``ps`` is empty).
    """
    accs = []
    diffs = []
    # Pre-bind so an empty sweep returns None instead of raising
    # UnboundLocalError at the return statement.
    num_exit = None
    for p in ps:
        branchyNet.thresholdExits = p
        acc,diff,num_exit,_ = test1(branchyNet,x_test,y_test,batchsize=batchsize)
        accs.append(acc)
        diffs.append(diff)

    # NOTE(review): only the exit counts of the final threshold are returned;
    # callers in this file print them as a single array, so that shape is kept.
    return ps, np.array(accs), np.array(diffs)/float(len(y_test)), num_exit

def test2(branchynet, x,t=None):
    """Run one batch through the models of ``branchynet`` with early exits.

    Every model after the first is applied to the samples that have not yet
    exited. Samples whose softmax entropy falls below the configured
    threshold (or the lowest-entropy fraction given by ``percentTestExits``
    when ``thresholdExits`` is None) exit at that model; the last model
    takes every remaining sample.

    Args:
        branchynet: object exposing ``models``, ``gpu``, ``xp``, ``verbose``,
            ``thresholdExits`` and ``percentTestExits``.
        x: input Variable for the batch (``x.data`` and ``x.volatile`` are read).
        t: label Variable for the batch. When it is None no model is
            evaluated and zeros are recorded for each processed model.

    Returns:
        ``(overall, accuracies, numexits, totaltime)``: the exit-weighted
        accuracy, the per-model accuracy list, the per-model exit-count list
        and the accumulated wall-clock time of the
        ``model.test(..., starti, endi)`` calls. ``overall`` is 0 when no
        sample exited anywhere.
    """
    numexits = []
    accuracies = []

    remainingXVar = x
    remainingTVar = t

    nummodels = len(branchynet.models)
    numsamples = x.data.shape[0]

    totaltime = 0

    for i,model in enumerate(branchynet.models):
        print(model, model.starti, model.endi)
        if i == 0:
            # NOTE(review): the first model is skipped without recording an
            # entry, so numexits/accuracies hold one item fewer than there
            # are models. Kept as in the original; confirm this is intended.
            continue
        if remainingXVar is None or remainingTVar is None:
            numexits.append(0)
            accuracies.append(0)
            continue

        # Only the starti..endi segment is timed.
        start_time = time.time()
        print(remainingXVar)
        h = model.test(remainingXVar,model.starti,model.endi)
        endtime = time.time()
        totaltime += endtime - start_time

        smh = model.test(h,model.endi)
        softmax = F.softmax(smh)

        # Entropy is computed on the host for both CPU and GPU runs:
        # cuda.to_cpu returns host arrays unchanged and copies device arrays.
        # (The original GPU branch called entropy_gpu, which is neither
        # defined nor imported in this file.)
        probs = cuda.to_cpu(softmax.data)
        entropy_value = np.array([entropy(s) for s in probs])

        total = entropy_value.shape[0]
        idx = np.zeros(total,dtype=bool)
        if i == nummodels-1:
            # Final model: everything still in flight exits here.
            idx[:] = True
            numexit = int(np.count_nonzero(idx))
        elif branchynet.thresholdExits is not None:
            min_ent = 0
            if isinstance(branchynet.thresholdExits,list):
                threshold = branchynet.thresholdExits[i]
            else:
                threshold = branchynet.thresholdExits
            idx[entropy_value < min_ent+threshold] = True
            numexit = int(np.count_nonzero(idx))
        else:
            if isinstance(branchynet.percentTestExits,list):
                numexit = int((branchynet.percentTestExits[i])*numsamples)
            else:
                numexit = int(branchynet.percentTestExits*total)
            # Lowest-entropy samples exit first.
            esorted = entropy_value.argsort()
            idx[esorted[:numexit]] = True

        numkeep = total-numexit
        numexits.append(numexit)

        # Boolean-mask selection is done on host copies; the activations h
        # (not the raw inputs) are what gets passed on to the next model.
        xdata = cuda.to_cpu(h.data)
        tdata = cuda.to_cpu(remainingTVar.data)

        if numkeep > 0:
            xdata_keep = xdata[~idx]
            tdata_keep = tdata[~idx]
            remainingXVar = Variable(branchynet.xp.array(xdata_keep,dtype=x.data.dtype),volatile=x.volatile)
            remainingTVar = Variable(branchynet.xp.array(tdata_keep,dtype=t.data.dtype),volatile=t.volatile)
        else:
            remainingXVar = None
            remainingTVar = None

        if numexit > 0:
            xdata_exit = xdata[idx]
            tdata_exit = tdata[idx]
            exitXVar = Variable(branchynet.xp.array(xdata_exit,dtype=x.data.dtype),volatile=x.volatile)
            exitTVar = Variable(branchynet.xp.array(tdata_exit,dtype=t.data.dtype),volatile=t.volatile)

            exitH = model.test(exitXVar,model.endi)
            accuracy = F.accuracy(exitH,exitTVar)
            accuracies.append(cuda.to_cpu(accuracy.data))
        else:
            accuracies.append(0.)

    overall = 0
    for accuracy, numexit in zip(accuracies, numexits):
        overall += accuracy*numexit
    total_exits = np.sum(numexits)
    # Avoid a 0/0 when nothing exited (e.g. no model was evaluated).
    if total_exits > 0:
        overall /= total_exits

    if branchynet.verbose:
        # %-formatting yields the same text under Python 2 and Python 3.
        print("numexits %s" % (numexits,))
        print("accuracies %s" % (accuracies,))
        print("overall accuracy %s" % (overall,))

    return overall, accuracies, numexits, totaltime

def test(branchyNet,x_test,y_test=None,batchsize=10000,main=False):
    """Evaluate ``branchyNet`` over a dataset in batches.

    Args:
        branchyNet: network object; ``xp``, ``numexits()`` and either
            ``test_main`` (when ``main`` is true) or the data needed by
            ``test2`` are used.
        x_test: array of inputs; the first axis is the sample axis.
        y_test: array of labels aligned with ``x_test``. Required despite the
            ``None`` default, because labels are sliced for every batch.
        batchsize: number of samples per batch.
        main: when true, evaluate with ``branchyNet.test_main`` and skip the
            per-exit bookkeeping.

    Returns:
        ``(overall, totaltime, num_exits, accbreakdowns)``: sample-weighted
        accuracy, summed time reported by the evaluation calls, per-exit
        sample counts and per-exit accuracy. For an empty dataset the
        accuracy is 0.0 and the arrays are all zeros.

    Raises:
        TypeError: if ``y_test`` is None.
    """
    if y_test is None:
        # Same exception type the original produced when slicing None,
        # but raised up front with an actionable message.
        raise TypeError("y_test is required: labels are needed to compute accuracy")

    datasize = x_test.shape[0]

    overall = 0.
    totaltime = 0.
    nsamples = 0
    num_exits = np.zeros(branchyNet.numexits()).astype(int)
    accbreakdowns = np.zeros(branchyNet.numexits())

    for start in range(0, datasize, batchsize):
        input_data = x_test[start : start + batchsize]
        label_data = y_test[start : start + batchsize]

        input_data = branchyNet.xp.asarray(input_data, dtype=branchyNet.xp.float32)
        label_data = branchyNet.xp.asarray(label_data, dtype=branchyNet.xp.int32)

        x = Variable(input_data, volatile=True)
        t = Variable(label_data, volatile=True)
        if main:
            acc, diff = branchyNet.test_main(x,t)
        else:
            acc, accuracies, test_exits, diff = test2(branchyNet, x,t)

            # NOTE(review): indexing assumes test2 returns one entry per exit
            # reported by branchyNet.numexits(); confirm against test2.
            for k, exits in enumerate(test_exits):
                num_exits[k] += exits
            for k in range(branchyNet.numexits()):
                accbreakdowns[k] += accuracies[k]*test_exits[k]

        totaltime += diff
        overall += input_data.shape[0]*acc
        nsamples += input_data.shape[0]

    # An empty dataset used to raise ZeroDivisionError here.
    if nsamples > 0:
        overall /= nsamples

    # Turn the exit-weighted sums into per-exit mean accuracies.
    for k in range(branchyNet.numexits()):
        if num_exits[k] > 0:
            accbreakdowns[k]/=num_exits[k]

    return overall, totaltime, num_exits, accbreakdowns


def test1(branchyNet,x_test,y_test=None,batchsize=10000,main=False):
    """Evaluate ``branchyNet`` in batches and report the lowest first-model entropy.

    In addition to the accuracy/exit bookkeeping, the non-``main`` path
    records, for every batch, the softmax entropy produced by the first entry
    of ``branchyNet.models`` together with ``test_exits[1]`` as returned by
    ``branchyNet.test``. After the run it prints the batch index whose first
    recorded entropy value is smallest among the batches with
    ``test_exits[1] == 0``.

    Only the first entropy value of each batch takes part in that
    comparison, so the report is per-sample only when ``batchsize`` is 1
    (which is how the callers in this file invoke it).

    Args:
        branchyNet: network object; ``xp``, ``numexits()``, ``models`` and
            either ``test_main`` or ``test`` are used.
        x_test: array of inputs; the first axis is the sample axis.
        y_test: array of labels aligned with ``x_test``. Required despite the
            ``None`` default.
        batchsize: number of samples per batch.
        main: when true, evaluate with ``branchyNet.test_main`` only; no
            entropy is recorded and nothing is printed.

    Returns:
        ``(overall, totaltime, num_exits, accbreakdowns)``: sample-weighted
        accuracy, summed time reported by the evaluation calls, per-exit
        sample counts and per-exit accuracy. For an empty dataset the
        accuracy is 0.0 and the arrays are all zeros.

    Raises:
        TypeError: if ``y_test`` is None.
    """
    if y_test is None:
        raise TypeError("y_test is required: labels are needed to compute accuracy")

    datasize = x_test.shape[0]

    entropy_list = []
    exit_result = []

    overall = 0.
    totaltime = 0.
    nsamples = 0
    num_exits = np.zeros(branchyNet.numexits()).astype(int)
    accbreakdowns = np.zeros(branchyNet.numexits())

    for start in range(0, datasize, batchsize):
        input_data = x_test[start : start + batchsize]
        label_data = y_test[start : start + batchsize]

        input_data = branchyNet.xp.asarray(input_data, dtype=branchyNet.xp.float32)
        label_data = branchyNet.xp.asarray(label_data, dtype=branchyNet.xp.int32)

        x = Variable(input_data, volatile=True)
        t = Variable(label_data, volatile=True)
        if main:
            acc, diff = branchyNet.test_main(x,t)
        else:
            acc, accuracies, test_exits, diff = branchyNet.test(x,t)

            # Recompute the first model's output entropy for this batch.
            # The two lists are appended together so their indices stay
            # aligned (one entry per batch).
            first_model = next(iter(branchyNet.models), None)
            if first_model is not None:
                h = first_model.test(x,first_model.starti,first_model.endi)
                smh = first_model.test(h,first_model.endi)
                softmax = F.softmax(smh)
                # NOTE(review): iterates softmax.data on the host; assumes the
                # network was moved to CPU, as the callers here do.
                entropy_value = np.array([entropy(s) for s in softmax.data])
                entropy_list.append(entropy_value.tolist())
                exit_result.append(test_exits[1])

            for k, exits in enumerate(test_exits):
                num_exits[k] += exits
            for k in range(branchyNet.numexits()):
                accbreakdowns[k] += accuracies[k]*test_exits[k]

        totaltime += diff
        overall += input_data.shape[0]*acc
        nsamples += input_data.shape[0]

    # An empty dataset used to raise ZeroDivisionError here.
    if nsamples > 0:
        overall /= nsamples

    # Turn the exit-weighted sums into per-exit mean accuracies.
    for k in range(branchyNet.numexits()):
        if num_exits[k] > 0:
            accbreakdowns[k] /= num_exits[k]

    # Search the batches whose second exit count is zero for the smallest
    # first recorded entropy value.
    min_val = float("inf")
    index = -1
    for k, val in enumerate(entropy_list):
        if exit_result[k] == 0 and min_val > val[0]:
            min_val = val[0]
            index = k

    # Without a match the original indexed with -1, which reported the last
    # batch (or raised IndexError on an empty list); report only real matches.
    if index >= 0:
        print("minimum entropy is ", index, entropy_list[index], exit_result[index])
    return overall, totaltime, num_exits, accbreakdowns

    

In [10]:
# Notebook cell: load MNIST and a saved LeNet BranchyNet, then evaluate the
# network separately on the test samples of each digit class.
import dill

from datasets import mnist
# The MNIST database of handwritten digits has a training set of 60,000
# examples and a test set of 10,000 examples.

X_train, Y_train, X_test, Y_test = mnist.get_data()
print(X_train.shape)

# Restore the serialized network object.
# NOTE(review): dill.load can execute arbitrary code contained in the file;
# only load model files from a trusted source.
branchyNet = None
with open("_models/lenet_mnist.bn", "rb") as f:
    branchyNet = dill.load(f)
    # Per the original note: switch the network to inference mode (used for
    # measuring the baseline).
branchyNet.testing()
branchyNet.verbose = False

# TEST_BATCHSIZE is defined here but not used in this cell; the evaluation
# below passes batchsize=1 explicitly.
TEST_BATCHSIZE = 10    
thresholds = [0.1]
    
# Run on the CPU.
branchyNet.to_cpu()

# One sweep per digit class: the boolean mask Y_test==i selects that class.
for i in range(10):
    ts, accs, diffs, exits = test_suite_B( branchyNet, X_test[Y_test==i], Y_test[Y_test==i], batchsize=1, ps=thresholds)
    print(ts, accs, diffs, exits)



(60000, 1, 28, 28)
('minimum entropy is ', 579, [2.0858755261227158e-12], 0)
[0, 1, 2, 5, 7, 14, 20, 35, 58, 577, 579]
([0.1], array([0.99591837]), array([0.00061922]), array([949,  31]))
('minimum entropy is ', 274, [1.7575597643926244e-09], 0)
[0, 1, 2, 8, 15, 20, 35, 73, 274]
([0.1], array([0.99823789]), array([0.00047021]), array([1118,   17]))
('minimum entropy is ', 564, [1.1183943333792052e-21], 0)
[0, 4, 5, 7, 40, 55, 120, 206, 412, 564]
([0.1], array([0.99224806]), array([0.00056826]), array([995,  37]))
('minimum entropy is ', 823, [4.557630708686806e-16], 0)
[1, 2, 6, 27, 44, 45, 75, 79, 95, 707, 713, 823]
([0.1], array([0.99306931]), array([0.000519]), array([985,  25]))
('minimum entropy is ', 408, [6.919652025851042e-14], 0)
[0, 2, 4, 8, 9, 14, 46, 150, 247, 342, 408]
([0.1], array([0.99185336]), array([0.00090247]), array([924,  58]))
('minimum entropy is ', 760, [2.8431236856994724e-15], 0)
[0, 2, 4, 15, 58, 381, 575, 577, 760]
([0.1], array([0.98878924]), array([0.0005

In [49]:
[815, 441, 840, 481, 95, 203, 333, 150, 795, 343]
[577, 1011, 564, 823, 408, 577, 568, 923, 571, 614]

[577, 1011, 564, 823, 408, 577, 568, 923, 571, 614]

In [25]:
# Notebook cell: load Fashion-MNIST and a saved LeNet BranchyNet, then
# evaluate the network separately on the test samples of each class.
import dill

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Load the fashion-mnist pre-shuffled train data and test data
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.fashion_mnist.load_data()
print("x_train shape:", X_train.shape, "y_train shape:", Y_train.shape)
# NOTE(review): this self-assignment changes nothing; presumably a scaling
# step was removed here — confirm whether normalization is expected.
X_train, X_test = X_train , X_test

# Insert a singleton axis so every image becomes (1, 28, 28).
X_train = X_train.reshape(-1, 1, 28, 28)
X_test = X_test.reshape(-1, 1, 28, 28) 


# Restore the serialized network object.
# NOTE(review): dill.load can execute arbitrary code contained in the file;
# only load model files from a trusted source.
branchyNet = None
with open("_models/lenet_fashion_mnist.bn", "rb") as f:
    branchyNet = dill.load(f)
    # Per the original note: switch the network to inference mode (used for
    # measuring the baseline).
branchyNet.testing()
branchyNet.verbose = False

# TEST_BATCHSIZE is defined here but not used in this cell; the evaluation
# below passes batchsize=1 explicitly.
TEST_BATCHSIZE = 10    
thresholds = [0.5]
    
# Run on the CPU.
branchyNet.to_cpu()
    
# One sweep per class label: the boolean mask Y_test==i selects that class.
for i in range(10):
    print("number is ", i)
    ts, accs, diffs, exits = test_suite_B(branchyNet, X_test[Y_test==i], Y_test[Y_test==i], batchsize=1, ps=thresholds)
    print(ts, accs, diffs, exits)

('x_train shape:', (60000, 28, 28), 'y_train shape:', (60000,))
('number is ', 0)
('minimum entropy is ', 259, [3.577666575438343e-05], 0)
[0, 5, 10, 29, 51, 144, 259]
([0.5], array([0.826]), array([0.00299554]), array([671, 329]))
('number is ', 1)
('minimum entropy is ', 137, [1.790095442754172e-13], 0)
[0, 2, 16, 48, 137]
([0.5], array([0.978]), array([0.00082396]), array([976,  24]))
('number is ', 2)
('minimum entropy is ', 24, [0.0029745555948466063], 0)
[0, 2, 9, 24]
([0.5], array([0.775]), array([0.0040914]), array([514, 486]))
('number is ', 3)
('minimum entropy is ', 617, [7.744661729702784e-07], 0)
[0, 9, 21, 24, 34, 59, 127, 361, 617]
([0.5], array([0.882]), array([0.00242143]), array([760, 240]))
('number is ', 4)
('minimum entropy is ', 885, [0.0022552437148988247], 0)
[0, 19, 45, 73, 278, 427, 885]
([0.5], array([0.762]), array([0.00329175]), array([537, 463]))
('number is ', 5)
('minimum entropy is ', 513, [2.1382970224697387e-21], 0)
[0, 1, 6, 18, 21, 30, 38, 60, 190, 

In [12]:
# Notebook cell: load K-MNIST from local .npz archives and a saved LeNet
# BranchyNet, then evaluate the network separately on each class.
import dill

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Load the K-MNIST train data and test data from local .npz archives
import numpy as np
# Dataset source: https://www.kaggle.com/aakashnain/kmnist-mnist-replacement
# Directory holding the K-MNIST archives
input_path = "datasets/data/kmnist/"

# Path to training images and corresponding labels provided as numpy arrays
kmnist_train_images_path = input_path+"kmnist-train-imgs.npz"
kmnist_train_labels_path = input_path+"kmnist-train-labels.npz"

# Path to the test images and corresponding labels
kmnist_test_images_path = input_path+"kmnist-test-imgs.npz"
kmnist_test_labels_path = input_path+"kmnist-test-labels.npz"

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the with-blocks close each archive once its array has been read.
# Load the training data from the corresponding npz files
with np.load(kmnist_train_images_path) as npz:
    kmnist_train_images = npz['arr_0']
with np.load(kmnist_train_labels_path) as npz:
    kmnist_train_labels = npz['arr_0']

# Load the test data from the corresponding npz files
with np.load(kmnist_test_images_path) as npz:
    kmnist_test_images = npz['arr_0']
with np.load(kmnist_test_labels_path) as npz:
    kmnist_test_labels = npz['arr_0']

print("Number of training samples: {} where each sample is of size: {}".format(
    len(kmnist_train_images), kmnist_train_images.shape[1:] ))
print("Number of test samples: {} where each sample is of size: {}".format(
    len(kmnist_test_images), kmnist_test_images.shape[1:]))

# Insert a singleton axis so every image becomes (1, 28, 28). The sample
# count is inferred (-1) instead of being hard-coded to 60000/10000.
X_train = kmnist_train_images.reshape(-1, 1, 28, 28)
X_test = kmnist_test_images.reshape(-1, 1, 28, 28)
Y_train = kmnist_train_labels
Y_test = kmnist_test_labels

# Restore the serialized network object.
# NOTE(review): dill.load can execute arbitrary code contained in the file;
# only load model files from a trusted source.
branchyNet = None
with open("_models/lenet_k_mnist.bn", "rb") as f:
    branchyNet = dill.load(f)
    # Per the original note: switch the network to inference mode (used for
    # measuring the baseline).
branchyNet.testing()
branchyNet.verbose = False

# TEST_BATCHSIZE is defined here but not used in this cell; the evaluation
# below passes batchsize=1 explicitly.
TEST_BATCHSIZE = 10
thresholds = [0.05]

# Run on the CPU.
branchyNet.to_cpu()

# One sweep per class label: the boolean mask Y_test==i selects that class.
for i in range(10):
    print("number is ", i)
    ts, accs, diffs, exits = test_suite_B(branchyNet, X_test[Y_test==i], Y_test[Y_test==i], batchsize=1, ps=thresholds)
    print(ts, accs, diffs, exits)

Number of training samples: 60000 where each sample is of size: (28, 28)
Number of test samples: 10000 where each sample is of size: (28, 28)
('number is ', 0)
('minimum entropy is ', 396, [1.4808608506111653e-10], 0)
[1, 2, 5, 6, 8, 9, 16, 57, 91, 396]
([0.05], array([0.928]), array([0.00144799]), array([686, 314]))
('number is ', 1)
('minimum entropy is ', 909, [4.9983203340153715e-12], 0)
[0, 1, 18, 111, 299, 879, 909]
([0.05], array([0.91]), array([0.00127394]), array([719, 281]))
('number is ', 2)
('minimum entropy is ', 471, [2.260422071742596e-08], 0)
[0, 4, 6, 10, 13, 25, 49, 91, 224, 471]
([0.05], array([0.867]), array([0.00150766]), array([617, 383]))
('number is ', 3)
('minimum entropy is ', 937, [1.478038108571056e-10], 0)
[0, 1, 7, 15, 937]
([0.05], array([0.953]), array([0.00138324]), array([733, 267]))
('number is ', 4)
('minimum entropy is ', 799, [4.124938612903861e-09], 0)
[0, 3, 31, 43, 114, 629, 799]
([0.05], array([0.895]), array([0.00208745]), array([631, 369]))
(

In [None]:
MNIST: [579, 274, 564, 823, 408, 760, 568, 923, 571, 614]
FASHION-MNIST: [350, 601, 401, 145, 906, 814, 962, 253, 35, 906]
[259, 137, 24, 617, 885, 513, 372, 765, 293, 534]
K-MNIST: [396, 909, 471, 937, 799, 437, 645, 847, 772, 227]