In [19]:
import sys
sys.path.insert(0, '..')
import argparse
from typing import Dict, Any
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from data_utils import DATASIZE_DICT, FIGURE_DIR, RESULTS_DIR
from data_utils import DATASET_NAMES, TOPK_DICT
import seaborn as sns
import pandas as pd
import pickle

RUNS = 100
LOG_FREQ = 100
METHOD_NAME_DICT = {#'non-active_no_prior': 'Non-active',
                    'non-active_uniform': 'non-active_uniform',
                    'non-active_informed': 'non-active_informed',
                    #'ts_uniform': 'TS',
                    'ts_informed': 'TS (informative)',
                    #'epsilon_greedy_no_prior': 'Epsilon greedy',
                    #'bayesian_ucb_no_prior': 'Bayesian UCB',
                    }
COLUMN_WIDTH = 3.25  # Inches
TEXT_WIDTH = 6.299213  # Inches
GOLDEN_RATIO = 1.61803398875
RESULTS_DIR = '../output/'
dataset_names = TOPK_DICT.keys()
group_method = 'predicted_class'

In [31]:
pseudocount = 2


def compute(METRIC, MODE, TOPK, eval_metric):
    
    if METRIC == 'accuracy':
        METHOD_NAME_LIST = ['random_arm', 'random_data', 'random_arm_informed', \
                            'random_data_informed', 'ts_uniform', 'ts_informed']
        if MODE == 'min':
            task = 'least_accurate/'
        else:
            task = 'most_accurate/'
    elif METRIC == 'calibration_error':
        METHOD_NAME_DICT = {'non-active': 'Non-active',
                            'ts': 'TS'}
        if MODE == 'min':
            task = 'least_biased/'
        else:
            task = 'most_biased/'
    counts = np.zeros((len(dataset_names), len(METHOD_NAME_LIST)))
    for i, dataset_name in enumerate(dataset_names):
        if TOPK:
            topk = TOPK_DICT[dataset_name]
        else:
            topk = 1
        experiment_name = '%s_groupby_%s_top%d_pseudocount%.2f/' % (dataset_name, group_method, topk, pseudocount)
        mrr_dict = pickle.load(open(RESULTS_DIR + task + experiment_name + ('mrr.pkl'), "rb" ))
        # method: num_runs, num_samples // LOG_FREQ
        for j, method_name in enumerate(METHOD_NAME_LIST):
            metric_eval = np.mean(mrr_dict[method_name], axis=1)
            metric_eval = np.argmax(metric_eval > min(0.99, metric_eval.max()*0.99)) + 5
            #metric_eval[metric_eval==0] = DATASIZE_DICT[dataset_name] / LOG_FREQ
            counts[i][j] = int(metric_eval * LOG_FREQ + LOG_FREQ) * 1.0 / DATASIZE_DICT[dataset_name]
    df = pd.DataFrame(np.round(counts.T*100, 1), 
                      index=METHOD_NAME_LIST, 
                      columns=dataset_names)
    return df

In [32]:
results = {}
results['accuracy_min_top1'] = compute('accuracy', 'min', False, 'mrr')
results['accuracy_min_topm'] = compute('accuracy', 'min', True, 'mrr')
# results['ece_max_top1'] = compute('calibration_error', 'max', False, 'mrr')
# results['ece_max_topm'] = compute('calibration_error', 'max', True, 'mrr')

In [30]:
dataset_print= {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['accuracy_min_top1','accuracy_min_topm','ece_max_top1','ece_max_topm']

print('\\begin{tabular}{@{}rrrccccccccccccc@{}}')
print('\\toprule ')
print('& ')
print('& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}')
print('& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}')
print('& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}')
print('& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\\ ')
print('\cmidrule{4-6} \cmidrule{8-10} \cmidrule{12-13} \cmidrule{15-16}')
print('\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI && R &TS && R &TS \\\ \midrule')
for i in dataset_print.keys():
    print('\multicolumn{2}{c}{%14s}  ' % dataset_print[i], end = '')
    vals = (results['accuracy_min_top1'][i]['random_data'],
            results['accuracy_min_top1'][i]['random_data_informed'],
            results['accuracy_min_top1'][i]['ts_informed'],
            results['accuracy_min_topm'][i]['random_data'],
            results['accuracy_min_topm'][i]['random_data_informed'],
            results['accuracy_min_topm'][i]['ts_informed'],
#             results['ece_max_top1'][i]['non-active'],
#             results['ece_max_top1'][i]['ts'],
#             results['ece_max_topm'][i]['non-active'],
#             results['ece_max_topm'][i]['ts'])
            results['accuracy_min_top1'][i]['ts_informed'],
            results['accuracy_min_topm'][i]['random_data'],
            results['accuracy_min_topm'][i]['random_data_informed'],
            results['accuracy_min_topm'][i]['ts_informed'])
    print('&&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &&%4.1f &%4.1f\\' % vals, end = '')
    print('\\ \n', end = '');
print('\\bottomrule')
print('\\end{tabular}')

\begin{tabular}{@{}rrrccccccccccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ 
\cmidrule{4-6} \cmidrule{8-10} \cmidrule{12-13} \cmidrule{15-16}
\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI && R &TS && R &TS \\ \midrule
\multicolumn{2}{c}{     CIFAR-100}  &&816.0 &839.0 &254.0  &&1003.0 &1003.0 &556.0  &&254.0 &1003.0 &&1003.0 &556.0\\ 
\multicolumn{2}{c}{      ImageNet}  &&970.0 &948.2 &94.4  &&997.2 &985.8 &172.2  &&94.4 &997.2 &&985.8 &172.2\\ 
\multicolumn{2}{c}{          SVHN}  &&906.6 &900.0 &830.1  &&1001.8 &1001.5 &962.3  &&830.1 &1001.8 &&1001.5 &962.3\\ 
\multicolumn{2}{c}{ 20 Newsgroups}  &&545.7 &560.3 &175.3  &&926.7 &932.0 &431.5  &&175.3 &926.7 &&932.0 &431.5\\ 
\multicolumn{2}{c}{       DBpedia}  &&80.7 &76.3 &117.1  &&920.1 &902.3 &572.0  &&117.1 &920.1 &&902.3 &572.0\\ 
\bottomrule
