In [1]:
# Import needed files and basic setup
import numpy as np
from sklearn.cluster import MiniBatchKMeans
import sklearn

from ipywidgets import Output
from IPython.display import display, Markdown, Latex, Math, clear_output
%matplotlib inline

from IPython.display import HTML, display
import tabulate

In [2]:
# Modify to match the outputs of running the code in the Training and Linearization Calculation folder
startPath = '/s/red/b/nobackup/data/bsattelb/linearRegions/'
# Filename suffix shared by every network; the two '{}' placeholders are left
# unfilled here and are substituted later through str.format.
endPath = 'MNIST{}Region{}.npy'

# Network display name -> full path template (directory + filename prefix + suffix).
networks = {
    name: startPath + prefix + endPath
    for name, prefix in (
        ('Dense', 'dense'),
        ('Conv1', 'conv1'),
        ('Conv2', 'conv2'),
        ('Inception', 'inception'),
        ('ResNet', 'resnet'),
    )
}

In [3]:
def calcAccuracy(key, n_clusters, random_state=None):
    """Classify the MNIST test set using k-means-compressed region arrays.

    For each of the 10 class indices, a MiniBatchKMeans model is fit on the
    network's 'Train' array, every row of the matching 'Test' array is
    replaced by its nearest cluster center, and the class score of a test
    sample is the dot product of that center with the sample's input vector
    (extended by a trailing 1). The predicted label is the argmax over the
    10 class scores.

    NOTE(review): each row of the loaded arrays is presumably the local linear
    map (weights followed by a bias) of the network for one sample and one
    output class -- confirm against the code that produced the files.

    Parameters
    ----------
    key : str
        Key into the module-level `networks` dict selecting the path template.
    n_clusters : int
        Number of k-means clusters fit per class.
    random_state : int, RandomState instance or None, optional
        Forwarded to MiniBatchKMeans. The default (None) keeps the original
        unseeded behaviour; pass an int for reproducible results.

    Returns
    -------
    scores : numpy.ndarray of shape (10,)
        `kmeans.score` evaluated on the 'Test' array for each class index.
    accuracy : integer
        Number of test samples whose predicted label equals the true label
        (a count, not a fraction).
    """
    print(key)

    numClasses = 10  # one Train/Test array pair is loaded per class index

    testInputs = np.load('MNISTTestInputs.npy')
    # Trailing column of ones: the last entry of each cluster center then acts
    # as an additive (bias) term in the per-sample dot product below.
    testInputs = np.hstack((testInputs, np.ones((testInputs.shape[0], 1))))
    # assumes trueLabels is a 1-D array of integer class labels -- TODO confirm
    trueLabels = np.load('MNISTTestLabels.npy')

    # Size the score matrix from the loaded data instead of assuming exactly
    # 10000 test samples.
    out = np.zeros((testInputs.shape[0], numClasses))
    scores = np.zeros(numClasses)

    for i in range(numClasses):
        kmeans = MiniBatchKMeans(n_clusters=n_clusters,
                                 init_size=max(300, 3*n_clusters),
                                 random_state=random_state)
        netTrain = np.load(networks[key].format('Train', i))
        kmeans.fit(netTrain)

        # Release the training array before the test array is loaded.
        del netTrain

        netTest = np.load(networks[key].format('Test', i))
        scores[i] = kmeans.score(netTest)
        print('\t', i, scores[i])

        # Substitute every test row with the center of the cluster it falls in.
        predRegions = kmeans.cluster_centers_[kmeans.predict(netTest), :]

        # Row-wise dot product between each augmented input and its center.
        out[:, i] = np.sum(np.multiply(testInputs, predRegions), axis=1)

    preds = np.argmax(out, axis=1)

    accuracy = np.sum(preds == trueLabels)
    print('\t', accuracy)

    return scores, accuracy

In [4]:
# Network names in the dict's iteration order.
keyLabels = list(networks)

In [5]:
# Reference values, one per network.
# NOTE(review): presumably the number of correctly classified MNIST test
# images for each original (unclustered) network -- confirm.
(denseAccuracy,
 conv1Accuracy,
 conv2Accuracy,
 inceptionAccuracy,
 resnetAccuracy) = (9603, 9807, 9804, 9908, 9892)

In [6]:
# Cluster counts to evaluate: the powers of ten from 1 up to 10000.
numClusters = [10 ** exponent for exponent in range(5)]
# One column per network; one row per cluster count plus one extra leading row.
accuracies = np.zeros((1 + len(numClusters), len(keyLabels)))

In [7]:
# Fill the first five entries of row 0 with the reference values, in the
# column order Dense, Conv1, Conv2, Inception, ResNet.
accuracies[0, 0:5] = (
    denseAccuracy,
    conv1Accuracy,
    conv2Accuracy,
    inceptionAccuracy,
    resnetAccuracy,
)

In [8]:
# Rows 1.. of `accuracies` correspond to the entries of numClusters; columns to keyLabels.
for row, clusterCount in enumerate(numClusters, start=1):
    for col, label in enumerate(keyLabels):
        # Only compute entries that are still zero; presumably this lets the
        # cell be re-run after an interruption without redoing finished work.
        if accuracies[row, col] == 0:
            _, correct = calcAccuracy(label, clusterCount)
            accuracies[row, col] = correct

Dense
	 0 -685.8788207864054
	 1 -760.883156602938
	 2 -1189.3688533780264
	 3 -1142.0002064680657
	 4 -930.1911276524731
	 5 -1460.9690306197724
	 6 -866.934980888483
	 7 -1095.522193930247
	 8 -899.014979021417
	 9 -1328.9388930310583
	 8679
Conv1
	 0 -2295.916537667999
	 1 -2554.9582611247497
	 2 -3406.3900860847193
	 3 -3464.105606671022
	 4 -3696.280831277792
	 5 -5384.397233507858
	 6 -4542.342485523613
	 7 -3785.1781170726663
	 8 -2269.2680896135084
	 9 -3348.746806223222
	 6766
Conv2
	 0 -2356.3644018775594
	 1 -2602.0675384495307
	 2 -3980.843072937315
	 3 -3668.640541475965
	 4 -3599.281813890825
	 5 -4964.166876078076
	 6 -4766.231801999631
	 7 -3541.441408003224
	 8 -2186.114345386185
	 9 -3419.4968742538354
	 6366
Inception
	 0 -184833.87277436478
	 1 -125320.30443050737
	 2 -141907.34988082355
	 3 -143376.05834946263
	 4 -150579.99366836238
	 5 -90337.27143157921
	 6 -146713.2538905528
	 7 -135421.5544962114
	 8 -165855.78575492027
	 9 -92079.9239587177
	 974
ResNet
	 0 -

In [11]:
# First column of the table: a blank label for row 0, then the cluster counts.
rowLabels = [''] + numClusters

# Header row (blank corner cell + network names) followed by one labelled row
# per row of `accuracies`.
table = [[''] + keyLabels]
table += [[rowLabel] + list(values) for rowLabel, values in zip(rowLabels, accuracies)]

display(HTML(tabulate.tabulate(table, tablefmt='html')))

0,1,2,3,4,5
,Dense,Conv1,Conv2,Inception,ResNet
,9603.0,9807.0,9804.0,9908.0,9892.0
1.0,8679.0,6766.0,6366.0,974.0,1432.0
10.0,9231.0,8639.0,8672.0,9660.0,8166.0
100.0,9434.0,9382.0,9421.0,9689.0,8458.0
1000.0,9508.0,9586.0,9603.0,9695.0,8982.0
10000.0,9554.0,9696.0,9673.0,9752.0,9381.0
