# Results analysis for ANN recognition with ANN masks

This notebook analyzes results for for the ANN recognition experiments with ANN masks done for the NeurIPS rebuttal.

To begin, follow the instructions in the notebook `create_ANN_recognition_data_ANN_masks.ipynb` to generate the data, then run the script `run_nn_recognition.sh` to produce outputs.

In [1]:
import numpy as np
from os.path import dirname, join as pjoin
import scipy.io as sio
import matplotlib.pyplot as plt
import fnmatch
import os
import itertools
import pandas as pd
import math

Read in the data.

In [2]:
results_dir = dirname('NN_recognition_outputs/')
results_files = os.listdir(results_dir)
assert results_files
results_files

['ImageNet_attention-branch-network_ResNet-101.txt',
 'ImageNet_baseline-cnns_AlexNet.txt',
 'CIFAR-100_baseline-cnns_VGG-19-BN.txt',
 'CIFAR-100_attention-branch-network_DenseNet-BC.txt',
 'CIFAR-100_baseline-cnns_ResNet-110.txt',
 'ImageNet_baseline-cnns_VGG-16-BN.txt',
 'ImageNet_baseline-cnns_EfficientNet-B0.txt',
 'CIFAR-100_attention-branch-network_ResNet-110.txt',
 'ImageNet_baseline-cnns_ResNet-101.txt',
 'CIFAR-100_baseline-cnns_AlexNet.txt',
 'CIFAR-100_learn-to-pay-attention_VGG.txt']

In [3]:
def combine_dicts(d1, d2):
    return {**d1, **d2}

        
def confs_to_ranking(confs):
    order = confs.argsort()[::-1]
    ranks = order.argsort()
    return ranks


def read_outputs(fname, fdir):
    with open(os.path.join(fdir, fname), 'r') as f:
        lines = f.readlines()

    def parse_line(line):
        line = line.strip('\n').split('\t')
        (_, condition, img_mask), confs = line[0].split('/'), line[1:]
        
        img, mask = img_mask.split('_')
        img = int(img[3:])
        if 'nomask' in mask:
            mask = -1
        else:
            mask = int(mask[:-4][4:])
        
        hparams = dict(h.split('=') for h in condition.split('_'))
        hparams['rotation'] = int(hparams['rotation'])
        hparams['exp'] = float(hparams['exp'])
        
        confs = np.array(confs, dtype=float)
             
        return combine_dicts({
            "img_id": img,
            "mask_id": mask,
            "confs": confs,
            "ranks": confs_to_ranking(confs),     
        }, hparams)

    parsed_lines = list(map(parse_line, lines))
    
    dataset, attention_type, model = fname[:-4].split('_')
    run_metadata = {
        'dataset': dataset, 
        'attention_type': attention_type, 
        'model': model,
    }
    parsed_lines = [combine_dicts(run_metadata, s) for s in parsed_lines]
    
    return parsed_lines

In [4]:
all_results = []
for results_file in results_files:
    all_results += read_outputs(results_file, results_dir)
data = pd.DataFrame(all_results)
data

Unnamed: 0,dataset,attention_type,model,img_id,mask_id,confs,ranks,mask,null,exp,threshold,rotation
0,ImageNet,attention-branch-network,ResNet-101,10,10,"[6.2831114e-06, 0.00017957437, 1.5464173e-05, ...","[849, 200, 654, 774, 502, 745, 982, 196, 802, ...",alexnet-gradcam11,black,1.0,,0
1,ImageNet,attention-branch-network,ResNet-101,10,115,"[3.6679044e-06, 0.000112102236, 1.8855877e-05,...","[987, 400, 766, 852, 522, 933, 745, 342, 814, ...",alexnet-gradcam11,black,1.0,,0
2,ImageNet,attention-branch-network,ResNet-101,10,125,"[2.37085e-06, 0.0002139487, 3.4340806e-06, 5.0...","[969, 241, 942, 904, 595, 931, 959, 152, 615, ...",alexnet-gradcam11,black,1.0,,0
3,ImageNet,attention-branch-network,ResNet-101,10,126,"[1.044508e-05, 9.635016e-05, 1.0928367e-05, 9....","[803, 394, 796, 816, 493, 867, 878, 330, 897, ...",alexnet-gradcam11,black,1.0,,0
4,ImageNet,attention-branch-network,ResNet-101,10,134,"[3.081696e-05, 0.0023515741, 7.4542455e-05, 6....","[898, 81, 748, 769, 534, 495, 825, 190, 557, 7...",alexnet-gradcam11,black,1.0,,0
...,...,...,...,...,...,...,...,...,...,...,...,...
14295,CIFAR-100,learn-to-pay-attention,VGG,95,790,"[0.0070521804, 0.006282379, 0.0071532615, 0.01...","[62, 72, 60, 10, 28, 83, 54, 30, 85, 5, 49, 74...",alexnet-sggbp,black,1.0,,0
14296,CIFAR-100,learn-to-pay-attention,VGG,95,81,"[0.002426063, 0.002831061, 0.003621754, 0.0210...","[86, 73, 51, 2, 32, 72, 77, 19, 85, 35, 78, 95...",alexnet-sggbp,black,1.0,,0
14297,CIFAR-100,learn-to-pay-attention,VGG,95,90,"[0.005813017, 0.0035832888, 0.003362819, 0.003...","[43, 78, 84, 89, 52, 80, 90, 41, 5, 1, 50, 92,...",alexnet-sggbp,black,1.0,,0
14298,CIFAR-100,learn-to-pay-attention,VGG,95,95,"[0.0020077196, 0.00403631, 0.001990376, 0.0034...","[92, 51, 93, 64, 10, 78, 37, 21, 79, 50, 94, 9...",alexnet-sggbp,black,1.0,,0


Analyze the results.

In [5]:
def ranking_distance(unmasked_ranks, correctly_masked_ranks, incorrectly_masked_ranks, k=10):
    # Get (unsorted) indices of top k. 
    idx = np.argpartition(unmasked_ranks, k)[:k]

    # Get the rankings at those locations.
    unmasked_top_k_ranks = unmasked_ranks[idx]
    correct_top_k_ranks = correctly_masked_ranks[idx]
    incorrect_top_k_ranks = incorrectly_masked_ranks[idx]

    # Do not normalize because we report fixed k.
    # binomial_coeff = math.comb(k, 2)
    
    return (
        sum(correct_top_k_ranks - unmasked_top_k_ranks),
        sum(incorrect_top_k_ranks - unmasked_top_k_ranks),
    )

In [6]:
# Unmasked data is duplicated across hparams and masks, but this makes it easier to iterate over.
unmasked_data = data[(data['mask_id'] == -1) & (data['rotation'] == 0)]
unmasked_data

Unnamed: 0,dataset,attention_type,model,img_id,mask_id,confs,ranks,mask,null,exp,threshold,rotation
25,ImageNet,attention-branch-network,ResNet-101,10,-1,"[5.0242652e-05, 2.3426071e-05, 3.1531763e-05, ...","[401, 605, 529, 968, 577, 896, 997, 432, 917, ...",alexnet-gradcam11,black,1.0,,0
51,ImageNet,attention-branch-network,ResNet-101,115,-1,"[3.0585632e-06, 7.050344e-07, 9.1366826e-07, 1...","[84, 275, 223, 38, 122, 24, 37, 906, 984, 652,...",alexnet-gradcam11,black,1.0,,0
77,ImageNet,attention-branch-network,ResNet-101,125,-1,"[1.4614342e-07, 2.5749837e-06, 2.9733073e-07, ...","[965, 599, 912, 836, 738, 556, 929, 734, 881, ...",alexnet-gradcam11,black,1.0,,0
103,ImageNet,attention-branch-network,ResNet-101,126,-1,"[3.1096151e-06, 2.2654494e-06, 6.363943e-08, 2...","[365, 422, 974, 840, 957, 669, 983, 560, 905, ...",alexnet-gradcam11,black,1.0,,0
129,ImageNet,attention-branch-network,ResNet-101,134,-1,"[2.2633388e-09, 1.5657108e-08, 6.1017985e-10, ...","[846, 575, 962, 944, 950, 347, 983, 760, 866, ...",alexnet-gradcam11,black,1.0,,0
...,...,...,...,...,...,...,...,...,...,...,...,...
14195,CIFAR-100,learn-to-pay-attention,VGG,76,-1,"[0.006354826, 0.013585589, 0.0051908703, 0.005...","[53, 14, 67, 64, 22, 7, 79, 48, 11, 18, 84, 59...",alexnet-sggbp,black,1.0,,0
14221,CIFAR-100,learn-to-pay-attention,VGG,790,-1,"[0.00040577984, 0.00068223744, 0.0022077095, 0...","[87, 49, 15, 44, 98, 4, 51, 65, 25, 32, 93, 50...",alexnet-sggbp,black,1.0,,0
14247,CIFAR-100,learn-to-pay-attention,VGG,81,-1,"[0.0002182761, 0.00020017404, 5.1607694e-05, 0...","[39, 42, 92, 14, 22, 2, 67, 87, 33, 4, 93, 7, ...",alexnet-sggbp,black,1.0,,0
14273,CIFAR-100,learn-to-pay-attention,VGG,90,-1,"[0.0011341472, 0.0017544234, 0.0013007402, 0.0...","[85, 49, 74, 68, 35, 15, 79, 51, 48, 30, 50, 7...",alexnet-sggbp,black,1.0,,0


In [7]:
correctly_masked_idx = (
    (data['img_id'] == data['mask_id']) & 
    (data['rotation'] == 0)
)
data[correctly_masked_idx]

Unnamed: 0,dataset,attention_type,model,img_id,mask_id,confs,ranks,mask,null,exp,threshold,rotation
0,ImageNet,attention-branch-network,ResNet-101,10,10,"[6.2831114e-06, 0.00017957437, 1.5464173e-05, ...","[849, 200, 654, 774, 502, 745, 982, 196, 802, ...",alexnet-gradcam11,black,1.0,,0
27,ImageNet,attention-branch-network,ResNet-101,115,115,"[1.2779348e-05, 0.0016233105, 0.00011320456, 3...","[996, 127, 699, 935, 354, 881, 934, 137, 285, ...",alexnet-gradcam11,black,1.0,,0
54,ImageNet,attention-branch-network,ResNet-101,125,125,"[3.4509048e-06, 0.00027638735, 2.8200178e-05, ...","[995, 329, 798, 922, 784, 845, 993, 347, 549, ...",alexnet-gradcam11,black,1.0,,0
81,ImageNet,attention-branch-network,ResNet-101,126,126,"[9.792601e-05, 0.00157715, 0.000108135624, 5.1...","[735, 122, 714, 865, 421, 842, 750, 312, 611, ...",alexnet-gradcam11,black,1.0,,0
108,ImageNet,attention-branch-network,ResNet-101,134,134,"[8.651271e-08, 1.1304604e-05, 6.064055e-06, 1....","[999, 597, 722, 919, 786, 932, 993, 412, 445, ...",alexnet-gradcam11,black,1.0,,0
...,...,...,...,...,...,...,...,...,...,...,...,...
14190,CIFAR-100,learn-to-pay-attention,VGG,76,76,"[0.00567973, 0.0058642714, 0.0056418753, 0.006...","[76, 72, 77, 68, 25, 37, 96, 82, 69, 35, 88, 8...",alexnet-sggbp,black,1.0,,0
14217,CIFAR-100,learn-to-pay-attention,VGG,790,790,"[0.007526817, 0.0069096256, 0.007149, 0.009399...","[58, 68, 64, 32, 42, 43, 26, 65, 97, 6, 27, 40...",alexnet-sggbp,black,1.0,,0
14244,CIFAR-100,learn-to-pay-attention,VGG,81,81,"[0.006737825, 0.0073460834, 0.007913602, 0.022...","[84, 75, 62, 1, 38, 60, 85, 33, 70, 35, 71, 98...",alexnet-sggbp,black,1.0,,0
14271,CIFAR-100,learn-to-pay-attention,VGG,90,90,"[0.0056537213, 0.0042599514, 0.0033437524, 0.0...","[47, 75, 93, 63, 36, 76, 97, 54, 7, 2, 58, 89,...",alexnet-sggbp,black,1.0,,0


In [8]:
incorrectly_masked_idx = (
    (data['mask_id'] != -1) &
    (data['img_id'] != data['mask_id'])
)
data[incorrectly_masked_idx]

Unnamed: 0,dataset,attention_type,model,img_id,mask_id,confs,ranks,mask,null,exp,threshold,rotation
1,ImageNet,attention-branch-network,ResNet-101,10,115,"[3.6679044e-06, 0.000112102236, 1.8855877e-05,...","[987, 400, 766, 852, 522, 933, 745, 342, 814, ...",alexnet-gradcam11,black,1.0,,0
2,ImageNet,attention-branch-network,ResNet-101,10,125,"[2.37085e-06, 0.0002139487, 3.4340806e-06, 5.0...","[969, 241, 942, 904, 595, 931, 959, 152, 615, ...",alexnet-gradcam11,black,1.0,,0
3,ImageNet,attention-branch-network,ResNet-101,10,126,"[1.044508e-05, 9.635016e-05, 1.0928367e-05, 9....","[803, 394, 796, 816, 493, 867, 878, 330, 897, ...",alexnet-gradcam11,black,1.0,,0
4,ImageNet,attention-branch-network,ResNet-101,10,134,"[3.081696e-05, 0.0023515741, 7.4542455e-05, 6....","[898, 81, 748, 769, 534, 495, 825, 190, 557, 7...",alexnet-gradcam11,black,1.0,,0
5,ImageNet,attention-branch-network,ResNet-101,10,146,"[4.83232e-06, 4.7614165e-05, 6.239172e-06, 2.1...","[746, 316, 698, 891, 420, 971, 932, 174, 768, ...",alexnet-gradcam11,black,1.0,,0
...,...,...,...,...,...,...,...,...,...,...,...,...
14293,CIFAR-100,learn-to-pay-attention,VGG,95,695,"[0.0065383264, 0.0054073436, 0.0072237127, 0.0...","[68, 90, 56, 84, 54, 97, 49, 43, 29, 41, 60, 3...",alexnet-sggbp,black,1.0,,0
14294,CIFAR-100,learn-to-pay-attention,VGG,95,76,"[0.005660284, 0.008285698, 0.0057829027, 0.007...","[78, 37, 76, 53, 18, 69, 84, 85, 65, 61, 93, 8...",alexnet-sggbp,black,1.0,,0
14295,CIFAR-100,learn-to-pay-attention,VGG,95,790,"[0.0070521804, 0.006282379, 0.0071532615, 0.01...","[62, 72, 60, 10, 28, 83, 54, 30, 85, 5, 49, 74...",alexnet-sggbp,black,1.0,,0
14296,CIFAR-100,learn-to-pay-attention,VGG,95,81,"[0.002426063, 0.002831061, 0.003621754, 0.0210...","[86, 73, 51, 2, 32, 72, 77, 19, 85, 35, 78, 95...",alexnet-sggbp,black,1.0,,0


In [9]:
results = []

for unmasked_row in unmasked_data.itertuples():
    
    img_id = getattr(unmasked_row, 'img_id')
    dataset = getattr(unmasked_row, 'dataset')
    attention_type = getattr(unmasked_row, 'attention_type')
    model = getattr(unmasked_row, 'model')
    mask = getattr(unmasked_row, 'mask')

    hparam_match_idx = (
        (data['dataset'] == dataset) &
        (data['attention_type'] == attention_type) &
        (data['model'] == model) &
        (data['mask'] == mask) &
        (data['null'] == getattr(unmasked_row, 'null')) &
        (data['exp'] == getattr(unmasked_row, 'exp')) &
        (data['threshold'] == getattr(unmasked_row, 'threshold')) &     
        # Average only over the null rotation.
        (data['rotation'] == 0)
    )
    
    correctly_masked_row = data[(data['img_id'] == img_id) & correctly_masked_idx & hparam_match_idx]
    assert len(correctly_masked_row) == 1
    correctly_masked_row = correctly_masked_row.iloc[0]
    
    incorrectly_masked_rows = data[(data['img_id'] == img_id) & incorrectly_masked_idx & hparam_match_idx]

    k_dists = {}
    for k in range(1, 5+1):
        cumu_dist = np.array([0. ,0.])
        count = 0
        for incorrectly_masked_row in incorrectly_masked_rows.itertuples():
            dist = ranking_distance(
                getattr(unmasked_row, 'ranks'),
                getattr(correctly_masked_row, 'ranks'), 
                getattr(incorrectly_masked_row, 'ranks'),
                k=k)
            cumu_dist += np.array(dist)
            count += 1
        k_dists[k] = cumu_dist/count

    results += [{**{
        'dataset': dataset,
        'attention_type': attention_type,
        'model': model,
        'img_id': img_id,
        'mask': mask,
        }, **{'{}-rank distance (correct mask)'.format(k): k_dists[k][0] for k in k_dists}, 
           **{'{}-rank distance (incorrect mask)'.format(k): k_dists[k][1] for k in k_dists}
    }]

results = pd.DataFrame(results)
results.to_csv("NN_recognition_results.csv", index=False)
results

Unnamed: 0,dataset,attention_type,model,img_id,mask,1-rank distance (correct mask),2-rank distance (correct mask),3-rank distance (correct mask),4-rank distance (correct mask),5-rank distance (correct mask),1-rank distance (incorrect mask),2-rank distance (incorrect mask),3-rank distance (incorrect mask),4-rank distance (incorrect mask),5-rank distance (incorrect mask)
0,ImageNet,attention-branch-network,ResNet-101,10,alexnet-gradcam11,2.0,2.0,0.0,50.0,76.0,19.416667,68.541667,71.208333,156.458333,229.125000
1,ImageNet,attention-branch-network,ResNet-101,115,alexnet-gradcam11,399.0,1232.0,2101.0,2582.0,3550.0,486.041667,1275.791667,2035.083333,2459.000000,3255.958333
2,ImageNet,attention-branch-network,ResNet-101,125,alexnet-gradcam11,423.0,434.0,452.0,453.0,599.0,375.541667,417.625000,600.750000,627.750000,755.666667
3,ImageNet,attention-branch-network,ResNet-101,126,alexnet-gradcam11,445.0,794.0,810.0,951.0,1198.0,338.250000,454.250000,458.875000,545.208333,723.333333
4,ImageNet,attention-branch-network,ResNet-101,134,alexnet-gradcam11,1.0,3.0,19.0,115.0,127.0,24.083333,86.083333,163.875000,528.000000,869.916667
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
545,CIFAR-100,learn-to-pay-attention,VGG,76,alexnet-sggbp,38.0,79.0,100.0,101.0,130.0,40.625000,57.708333,100.458333,103.625000,146.375000
546,CIFAR-100,learn-to-pay-attention,VGG,790,alexnet-sggbp,10.0,65.0,135.0,133.0,172.0,42.125000,73.250000,150.541667,162.875000,211.125000
547,CIFAR-100,learn-to-pay-attention,VGG,81,alexnet-sggbp,19.0,21.0,79.0,150.0,181.0,5.458333,21.500000,78.166667,139.416667,153.541667
548,CIFAR-100,learn-to-pay-attention,VGG,90,alexnet-sggbp,79.0,151.0,166.0,198.0,260.0,66.708333,148.250000,191.541667,245.375000,308.833333
