In [45]:
import sys
sys.path.insert(0, '..')
import argparse
from typing import Dict, Any
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from data_utils import DATASIZE_DICT, FIGURE_DIR, RESULTS_DIR
from data_utils import DATASET_NAMES, TOPK_DICT
import seaborn as sns
import pandas as pd
import pickle

RUNS = 100
LOG_FREQ = 10
COLUMN_WIDTH = 3.25  # Inches
TEXT_WIDTH = 6.299213  # Inches
GOLDEN_RATIO = 1.61803398875
RESULTS_DIR = '../output/'
dataset_names = TOPK_DICT.keys()
group_method = 'predicted_class'

In [46]:
pseudocount = 2


def compute(METRIC, MODE, TOPK, eval_metric):
    METHOD_NAME_LIST = ['random_arm', 'random_data', 'random_arm_informed', \
                            'random_data_informed', 'ts_uniform', 'ts_informed']
    task =  {('accuracy', 'min'): 'least_accurate/',
             ('accuracy', 'max'): 'most_accurate/',
             ('ece', 'min'): 'most_biased/',
             ('ece', 'min'): 'least_biased/',
            }[(METRIC, MODE)]
    counts = np.zeros((len(dataset_names), len(METHOD_NAME_LIST)))
    for i, dataset_name in enumerate(dataset_names):
        if TOPK:
            topk = TOPK_DICT[dataset_name]
        else:
            topk = 1
        experiment_name = '%s_groupby_%s_top%d_pseudocount%.2f/' % (dataset_name, group_method, topk, pseudocount)
        mrr_dict = pickle.load(open(RESULTS_DIR + task + experiment_name + ('mrr.pkl'), "rb" ))
        # method: num_runs, num_samples // LOG_FREQ
        for j, method_name in enumerate(METHOD_NAME_LIST):
            metric_eval = np.mean(mrr_dict[method_name], axis=1)
            metric_eval = np.argmax(metric_eval > min(0.99, metric_eval.max()*0.99))
            print(metric_eval)
            counts[i][j] = int(metric_eval * LOG_FREQ + LOG_FREQ) * 1.0 / DATASIZE_DICT[dataset_name]
    df = pd.DataFrame(np.round(counts.T*100, 1), 
                      index=METHOD_NAME_LIST, 
                      columns=dataset_names)
    return df

In [47]:
DATASIZE_DICT

{'cifar100': 10000,
 'imagenet': 50000,
 'imagenet2_topimages': 10000,
 '20newsgroup': 7532,
 'svhn': 26032,
 'dbpedia': 70000}

In [48]:
results = {}
results['accuracy_min_top1'] = compute('accuracy', 'min', False, 'mrr')
results['accuracy_min_topm'] = compute('accuracy', 'min', True, 'mrr')
# results['ece_max_top1'] = compute('calibration_error', 'max', False, 'mrr')
# results['ece_max_topm'] = compute('calibration_error', 'max', True, 'mrr')

894
810
941
833
261
248
4560
4844
4888
4735
680
466
1798
2354
1845
2337
736
2155
292
405
294
416
97
126
596
559
564
528
320
814
999
997
999
997
541
550
4969
4980
4954
4923
1238
855
2353
2602
2355
2601
2047
2499
735
692
733
696
235
319
6449
6435
6330
6310
3487
3998


In [49]:
dataset_print= {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['accuracy_min_top1','accuracy_min_topm','ece_max_top1','ece_max_topm']

print('\\begin{tabular}{@{}rrrccccccccccccc@{}}')
print('\\toprule ')
print('& ')
print('& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}')
print('& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}')
print('& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}')
print('& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\\ ')
print('\cmidrule{4-6} \cmidrule{8-10} \cmidrule{12-13} \cmidrule{15-16}')
print('\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI && R &TS && R &TS \\\ \midrule')
for i in dataset_print.keys():
    print('\multicolumn{2}{c}{%14s}  ' % dataset_print[i], end = '')
    vals = (results['accuracy_min_top1'][i]['random_data'],
            results['accuracy_min_top1'][i]['random_data_informed'],
            results['accuracy_min_top1'][i]['ts_informed'],
            results['accuracy_min_topm'][i]['random_data'],
            results['accuracy_min_topm'][i]['random_data_informed'],
            results['accuracy_min_topm'][i]['ts_informed'],
#             results['ece_max_top1'][i]['non-active'],
#             results['ece_max_top1'][i]['ts'],
#             results['ece_max_topm'][i]['non-active'],
#             results['ece_max_topm'][i]['ts'])
            results['accuracy_min_top1'][i]['ts_informed'],
            results['accuracy_min_topm'][i]['random_data'],
            results['accuracy_min_topm'][i]['random_data_informed'],
            results['accuracy_min_topm'][i]['ts_informed'])
    print('&&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &&%4.1f &%4.1f\\' % vals, end = '')
    print('\\ \n', end = '');
print('\\bottomrule')
print('\\end{tabular}')

\begin{tabular}{@{}rrrccccccccccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ 
\cmidrule{4-6} \cmidrule{8-10} \cmidrule{12-13} \cmidrule{15-16}
\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI && R &TS && R &TS \\ \midrule
\multicolumn{2}{c}{     CIFAR-100}  &&81.1 &83.4 &24.9  &&99.8 &99.8 &55.1  &&24.9 &99.8 &&99.8 &55.1\\ 
\multicolumn{2}{c}{      ImageNet}  &&96.9 &94.7 & 9.3  &&99.6 &98.5 &17.1  && 9.3 &99.6 &&98.5 &17.1\\ 
\multicolumn{2}{c}{          SVHN}  &&90.5 &89.8 &82.8  &&100.0 &100.0 &96.0  &&82.8 &100.0 &&100.0 &96.0\\ 
\multicolumn{2}{c}{ 20 Newsgroups}  &&53.9 &55.4 &16.9  &&92.0 &92.5 &42.5  &&16.9 &92.0 &&92.5 &42.5\\ 
\multicolumn{2}{c}{       DBpedia}  && 8.0 & 7.6 &11.6  &&91.9 &90.2 &57.1  &&11.6 &91.9 &&90.2 &57.1\\ 
\bottomrule
\end{tabular}


In [53]:
dataset_print= {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['accuracy_min_top1','accuracy_min_topm']

print('\\begin{tabular}{@{}rrrccccccc@{}}')
print('\\toprule ')
print('& ')
print('& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}')
print('& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}\\\ ')
print('\cmidrule{4-6} \cmidrule{8-10}')
print('\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI  \\\ \midrule')
for i in dataset_print.keys():
    print('\multicolumn{2}{c}{%14s}  ' % dataset_print[i], end = '')
    vals =(results['accuracy_min_top1'][i]['random_data'],
            results['accuracy_min_top1'][i]['random_data_informed'],
            results['accuracy_min_top1'][i]['ts_informed'],
            results['accuracy_min_topm'][i]['random_data'],
            results['accuracy_min_topm'][i]['random_data_informed'],
            results['accuracy_min_topm'][i]['ts_informed'],)
    print('&&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &%4.1f \\' % vals, end = '')
    print('\\ \n', end = '');
print('\\bottomrule')
print('\\end{tabular}')

\begin{tabular}{@{}rrrccccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}\\ 
\cmidrule{4-6} \cmidrule{8-10}
\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI  \\ \midrule
\multicolumn{2}{c}{     CIFAR-100}  &&81.1 &83.4 &24.9  &&99.8 &99.8 &55.1 \\ 
\multicolumn{2}{c}{      ImageNet}  &&96.9 &94.7 & 9.3  &&99.6 &98.5 &17.1 \\ 
\multicolumn{2}{c}{          SVHN}  &&90.5 &89.8 &82.8  &&100.0 &100.0 &96.0 \\ 
\multicolumn{2}{c}{ 20 Newsgroups}  &&53.9 &55.4 &16.9  &&92.0 &92.5 &42.5 \\ 
\multicolumn{2}{c}{       DBpedia}  && 8.0 & 7.6 &11.6  &&91.9 &90.2 &57.1 \\ 
\bottomrule
\end{tabular}
