In [1]:
# Import needed files and basic setup
import numpy as np
from sklearn import linear_model
import sklearn

from ipywidgets import Output
from IPython.display import display, Markdown, Latex, Math, clear_output
%matplotlib inline

from IPython.display import HTML, display
import tabulate

In [2]:
# Display name -> file-name stem under which that network's region arrays were saved.
networks = {
    'Dense': 'dense',
    'Conv1': 'conv1',
    'Conv2': 'conv2',
    'Inception': 'inception',
    'ResNet': 'resnet',
}

# Modify this location to be where networks are saved
startPath = '/s/red/b/nobackup/data/bsattelb/linearRegions/'
# The two placeholders are filled later with .format(split, classIndex),
# where split is 'Train' or 'Test'.
endPath = 'MNIST{}Region{}.npy'

# Turn every stem into a full path template: <startPath><stem>MNIST{}Region{}.npy
networks = {name: ''.join((startPath, stem, endPath)) for name, stem in networks.items()}

In [3]:
# Network display names in dict insertion order; this order fixes the
# row/column order of the accuracy matrices and tables below.
keyLabels = list(networks)

In [4]:
# Accuracies calculated in other notebooks, one per network evaluated on its own.
# NOTE(review): these appear to be counts of correctly classified test images
# (same units as the count returned by calcAccuracies), not fractions — confirm.
(denseAccuracy,
 conv1Accuracy,
 conv2Accuracy,
 inceptionAccuracy,
 resnetAccuracy) = (9603, 9807, 9804, 9908, 9892)

In [15]:
# Calculate the accuracies arising from mapping between two networks linear regions
def calcAccuracies(keyFrom, keyTo, networks, use_bias=False, fit_intercept=False,
                   greedy=False, repitition=False, num_classes=10,
                   inputsPath='MNISTTestInputs.npy', labelsPath='MNISTTestLabels.npy'):
    """Linearly map one network's saved regions onto another's and score the result.

    For every class index ``i`` a ``LinearRegression`` is fitted that predicts the
    ``keyTo`` network's saved region array from the ``keyFrom`` network's one
    (``'Train'`` split). The fitted map is applied to the ``keyFrom`` ``'Test'``
    regions, each mapped row is dotted with the matching test input, and the class
    whose dot product is largest is taken as the prediction for that input.

    NOTE(review): presumably row k of a class-i region file is the effective linear
    weight vector of that network's class-i output at test input k — confirm.

    Parameters
    ----------
    keyFrom, keyTo : str
        Keys of ``networks`` for the source and target network.
    networks : dict
        Maps a key to a path template filled as ``.format(split, classIndex)``
        with ``split`` in ``{'Train', 'Test'}``.
    use_bias : bool
        If True, keep the last column of every region array and append a column of
        ones to the test inputs so both stay aligned; if False, drop the last
        column of every region array.
    fit_intercept : bool
        Passed straight to ``LinearRegression``.
    greedy, repitition : bool
        Ignored. Accepted so callers that pass these options (e.g. ``getAccuracies``)
        do not fail with ``TypeError``.
    num_classes : int
        Number of per-class region files to process (10 for MNIST).
    inputsPath, labelsPath : str
        ``.npy`` files holding the test inputs and their true labels.

    Returns
    -------
    scores : np.ndarray, shape (num_classes,)
        ``reg.score`` (R^2) of each per-class map on the test split.
    accuracy : integer
        Number of test inputs whose predicted class equals the true label.

    Raises
    ------
    ValueError
        If the mapped regions do not have the same shape as the test inputs.
    """
    print('from', keyFrom, 'to', keyTo)

    def _loadRegions(key, split, classIndex):
        # Load one saved region array, dropping the bias column unless it is mapped too.
        regions = np.load(networks[key].format(split, classIndex))
        return regions if use_bias else regions[:, :-1]

    testInputs = np.load(inputsPath)
    trueLabels = np.load(labelsPath)

    if use_bias:
        # Constant input of 1 that multiplies the (kept) bias column of each region.
        testInputs = np.hstack((testInputs, np.ones((testInputs.shape[0], 1))))

    scores = np.zeros(num_classes)
    # One column of logits per class; row count comes from the data, not a fixed 10000.
    out = np.zeros((testInputs.shape[0], num_classes))

    for i in range(num_classes):
        origNetTrain = _loadRegions(keyFrom, 'Train', i)
        newNetTrain = _loadRegions(keyTo, 'Train', i)

        reg = linear_model.LinearRegression(fit_intercept=fit_intercept)
        reg.fit(origNetTrain, newNetTrain)

        # Release the training arrays before loading the test split (keeps peak memory down).
        del origNetTrain, newNetTrain

        origNetTest = _loadRegions(keyFrom, 'Test', i)
        newNetTest = _loadRegions(keyTo, 'Test', i)

        scores[i] = reg.score(origNetTest, newNetTest)
        print('\t', i, scores[i])

        predRegions = reg.predict(origNetTest)
        if predRegions.shape != testInputs.shape:
            raise ValueError(
                'mapped regions have shape {} but test inputs have shape {} '
                '(use_bias={})'.format(predRegions.shape, testInputs.shape, use_bias))

        # Row-wise dot product of each test input with its mapped region -> class-i logit.
        out[:, i] = np.sum(testInputs * predRegions, axis=1)

    preds = np.argmax(out, axis=1)

    accuracy = np.sum(preds == trueLabels)
    print('\t', accuracy)

    return scores, accuracy

In [7]:
# Calculate the accuracies for all mappings
def getAccuracies(greedy=False, repitition=False, use_bias=False, fit_intercept=False):
    """Build the pairwise mapped-accuracy matrix for every ordered pair of networks.

    Entry ``[i, j]`` is the count returned by ``calcAccuracies`` when network
    ``keyLabels[i]``'s regions are mapped onto ``keyLabels[j]``'s. Diagonal entries
    are each network's own accuracy, taken from the constants computed in other
    notebooks, looked up by label so reordering ``networks`` cannot mislabel them.

    Parameters
    ----------
    greedy, repitition : bool
        Kept for backward compatibility only. ``calcAccuracies`` has no such
        options, so they are not forwarded (forwarding them raised ``TypeError``).
    use_bias : bool
        Forwarded to ``calcAccuracies``: also map the bias column of each region.
    fit_intercept : bool
        Forwarded to ``calcAccuracies`` (and from there to ``LinearRegression``).

    Returns
    -------
    accuracies : np.ndarray of int, shape (len(keyLabels), len(keyLabels))
    fullScores : dict
        Maps ``(keyFrom, keyTo)`` to the per-class scores from ``calcAccuracies``
        for every off-diagonal pair.

    Raises
    ------
    KeyError
        If a label in ``keyLabels`` has no known own-accuracy constant.
    """
    # Each network evaluated on its own regions: no mapping is needed.
    ownAccuracy = {
        'Dense': denseAccuracy,
        'Conv1': conv1Accuracy,
        'Conv2': conv2Accuracy,
        'Inception': inceptionAccuracy,
        'ResNet': resnetAccuracy,
    }

    numNetworks = len(keyLabels)
    # Builtin int instead of np.int, which was removed in NumPy 1.24.
    accuracies = np.zeros((numNetworks, numNetworks), dtype=int)
    fullScores = {}

    for i, keyFrom in enumerate(keyLabels):
        for j, keyTo in enumerate(keyLabels):
            if i == j:
                accuracies[i, j] = ownAccuracy[keyFrom]
                continue

            scores, accuracy = calcAccuracies(keyFrom, keyTo, networks,
                                              use_bias=use_bias, fit_intercept=fit_intercept)
            accuracies[i, j] = accuracy
            fullScores[(keyFrom, keyTo)] = scores

    return accuracies, fullScores

In [8]:
# Do not include the biases as elements to map: calcAccuracies drops the last
# column of every region array, and no intercept is fitted. Only the accuracy
# matrix is kept (note: this rebinds IPython's `_`).
(accuracies, _) = getAccuracies(fit_intercept=False, use_bias=False)

from Dense to Conv1
	 0 0.5218040495401673
	 1 0.5399280839436884
	 2 0.48055967802846056
	 3 0.5079179873155751
	 4 0.5437807991526866
	 5 0.5145375895568266
	 6 0.5108720374169542
	 7 0.5155813826138865
	 8 0.48575849414565214
	 9 0.5118821870123654
	 9527
from Dense to Conv2
	 0 0.4958516091568873
	 1 0.5320981744400942
	 2 0.47289181235661937
	 3 0.5086833739237702
	 4 0.551109989385051
	 5 0.494184827774396
	 6 0.5103991684261818
	 7 0.5151038675199342
	 8 0.47603509350782935
	 9 0.5048411746503333
	 9517
from Dense to Inception
	 0 0.08481188948905696
	 1 0.1266396392891022
	 2 0.12318482221059313
	 3 0.10491925172969527
	 4 0.12183428568667845
	 5 0.09732785996827667
	 6 0.09846037699802392
	 7 0.126230238716925
	 8 0.11760996170991463
	 9 0.12210644545848519
	 7437
from Dense to ResNet
	 0 0.18749705764152622
	 1 0.2913659851270455
	 2 0.19056603475939088
	 3 0.2615041739328434
	 4 0.23923684704031845
	 5 0.18154819726265628
	 6 0.2221711826503318
	 7 0.2463207838898032
	 8 0.1

In [9]:
print('Input matching, no bias')
# Header row of network names, then one row per source network: its label
# followed by that row of the accuracy matrix.
table = [[''] + keyLabels] + [[label] + list(row) for label, row in zip(keyLabels, accuracies)]

display(HTML(tabulate.tabulate(table, tablefmt='html')))

Input matching, no bias


0,1,2,3,4,5
,Dense,Conv1,Conv2,Inception,ResNet
Dense,9603,9527,9517,7437,6927
Conv1,9572,9807,9774,7979,7240
Conv2,9580,9783,9804,8105,7179
Inception,9032,9414,9452,9908,7028
ResNet,9375,9724,9720,7635,9892


In [16]:
# Do include the biases as elements to map: calcAccuracies keeps the last column
# of every region array and appends a column of ones to the test inputs; no
# intercept is fitted. Only the accuracy matrix is kept (rebinds IPython's `_`).
(accuracies, _) = getAccuracies(fit_intercept=False, use_bias=True)

from Dense to Conv1
	 0 0.5222996713111914
	 1 0.5404961864700313
	 2 0.4815911154618716
	 3 0.5085001675616007
	 4 0.5439206439234548
	 5 0.5152989209591956
	 6 0.5111021158426358
	 7 0.5164359873448261
	 8 0.4863966170023193
	 9 0.512402004487399
	 9536
from Dense to Conv2
	 0 0.49702766669449283
	 1 0.5336192309652515
	 2 0.4768072954500311
	 3 0.5103158125622598
	 4 0.5519339766325063
	 5 0.49728858479788174
	 6 0.5116287433849783
	 7 0.5166430129773149
	 8 0.47711514189168247
	 9 0.5064392519902755
	 9519
from Dense to Inception
	 0 0.8842880719490174
	 1 0.8483860128594392
	 2 0.7890161952122041
	 3 0.7587503734650197
	 4 0.7801529222951854
	 5 0.7589316823389204
	 6 0.8551243342455553
	 7 0.7866316659404753
	 8 0.6967819818804166
	 9 0.7428354109849855
	 9290
from Dense to ResNet
	 0 0.7958437705074496
	 1 0.7382309526439395
	 2 0.5268988061190881
	 3 0.6276407142314155
	 4 0.7682251919527844
	 5 0.68114454285617
	 6 0.8086969715483111
	 7 0.6482507823471708
	 8 0.68965731184263

In [17]:
print('Input matching, bias')
# Header row of network names, then one row per source network: its label
# followed by that row of the accuracy matrix.
table = [[''] + keyLabels] + [[label] + list(row) for label, row in zip(keyLabels, accuracies)]

display(HTML(tabulate.tabulate(table, tablefmt='html')))

Input matching, bias


0,1,2,3,4,5
,Dense,Conv1,Conv2,Inception,ResNet
Dense,9603,9536,9519,9290,9068
Conv1,9567,9807,9776,9662,9588
Conv2,9562,9786,9804,9644,9579
Inception,8868,9488,9511,9908,9536
ResNet,9320,9738,9739,9838,9892
