In [1]:
import sys
sys.path.insert(0, '..')
import argparse
from typing import Dict, Any
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from data_utils import DATASIZE_DICT, FIGURE_DIR, RESULTS_DIR
from data_utils import DATASET_NAMES, TOPK_DICT
import seaborn as sns
import pandas as pd

# --- Experiment bookkeeping --------------------------------------------------
RUNS = 100      # appears as "runs%d" in each experiment directory name
LOG_FREQ = 100  # factor used to turn a logged curve index into a sample count

# --- Figure geometry ---------------------------------------------------------
COLUMN_WIDTH = 3.25    # Inches
TEXT_WIDTH = 6.299213  # Inches
GOLDEN_RATIO = 1.61803398875

# --- Sampling methods: method key -> display label ---------------------------
# Currently disabled keys: 'non-active_no_prior', 'ts_uniform',
# 'epsilon_greedy_no_prior', 'bayesian_ucb_no_prior'.
METHOD_NAME_DICT = {
    'non-active_uniform': 'non-active_uniform',
    'non-active_informed': 'non-active_informed',
    'ts_informed': 'TS (informative)',
}

# Narrow the imported RESULTS_DIR to the sub-directory used by this notebook.
RESULTS_DIR = RESULTS_DIR + 'active_learning_topk/'

# Dataset keys, taken from TOPK_DICT (a live view of the dict's keys).
dataset_names = TOPK_DICT.keys()

In [2]:
pseudocount = 2  # appears as "pseudocount%.2f" in each experiment directory name


def compute(METRIC, MODE, TOPK, eval_metric):
    """Tabulate how much of each dataset is used before a curve nears its peak.

    For every dataset in ``dataset_names`` and every sampling method that
    belongs to ``METRIC``, the saved ``<eval_metric>_<method>.npy`` array is
    loaded and averaged over axis 0 (presumably the repeated runs -- TODO
    confirm).  The averaged curve is then scanned, skipping its first five
    entries, for the first position whose value exceeds
    ``min(0.99, 0.99 * curve.max())``.  That position is converted to a
    sample count with ``LOG_FREQ`` and divided by ``DATASIZE_DICT[dataset]``.

    Args:
        METRIC: ``'accuracy'`` or ``'calibration_error'``.  Selects the list
            of methods and is part of the experiment directory name.
        MODE: string that is part of the experiment directory name
            (the notebook passes ``'min'`` and ``'max'``).
        TOPK: if truthy, ``TOPK_DICT[dataset]`` is used as the top-k value in
            the directory name; otherwise 1 is used.
        eval_metric: prefix of the ``.npy`` file name (e.g. ``'mrr'``).

    Returns:
        pandas.DataFrame with one row per method and one column per dataset,
        holding percentages rounded to one decimal place.

    Raises:
        ValueError: if ``METRIC`` is not one of the two supported values.
    """
    # Only the method keys are needed here (file names and row labels).
    if METRIC == 'accuracy':
        method_names = ['non-active_uniform',
                        'non-active_informed',
                        'ts_uniform',
                        'ts_informed']
    elif METRIC == 'calibration_error':
        method_names = ['non-active', 'ts']
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            "Unsupported METRIC %r; expected 'accuracy' or 'calibration_error'" % (METRIC,))

    datasets = list(dataset_names)
    counts = np.zeros((len(datasets), len(method_names)))
    for i, dataset_name in enumerate(datasets):
        topk = TOPK_DICT[dataset_name] if TOPK else 1
        experiment_name = '%s_%s_%s_top%d_runs%d_pseudocount%.2f/' % (
            dataset_name, METRIC, MODE, topk, RUNS, pseudocount)
        for j, method_name in enumerate(method_names):
            curve = np.load(RESULTS_DIR + experiment_name + ('%s_%s.npy' % (eval_metric, method_name)))
            curve = np.mean(curve, axis=0)
            threshold = min(0.99, curve.max() * 0.99)
            # NOTE(review): np.argmax returns 0 when no entry is True, so a
            # curve that never exceeds the threshold after its first five
            # entries is reported as crossing at index 5 -- confirm this is
            # the intended fallback.
            first_hit = np.argmax(curve[5:] > threshold) + 5
            # "+ LOG_FREQ" shifts the 0-based index by one logging step before
            # normalising by the dataset size.
            counts[i][j] = int(first_hit * LOG_FREQ + LOG_FREQ) * 1.0 / DATASIZE_DICT[dataset_name]

    # Transpose so that methods are rows and datasets are columns.
    return pd.DataFrame(np.round(counts.T * 100, 1),
                        index=method_names,
                        columns=datasets)

In [3]:
# One table per (metric, mode, top-k setting); every table is scored with 'mrr'.
# Each spec is (result key, METRIC, MODE, TOPK) and is evaluated in this order.
result_specs = [
    ('accuracy_min_top1', 'accuracy', 'min', False),
    ('accuracy_min_topm', 'accuracy', 'min', True),
    ('ece_max_top1', 'calibration_error', 'max', False),
    ('ece_max_topm', 'calibration_error', 'max', True),
]

results = {}
for result_key, metric_name, mode_name, use_topk in result_specs:
    results[result_key] = compute(metric_name, mode_name, use_topk, 'mrr')

In [12]:
# Print a LaTeX table with accuracy (R / RI / TSI) and ECE (R / TS) columns,
# each for the top-1 and top-m settings, one row per dataset.
dataset_print = {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['accuracy_min_top1', 'accuracy_min_topm', 'ece_max_top1', 'ece_max_topm']

# (results key, method key) for every numeric cell, in left-to-right order.
table_cells = [
    ('accuracy_min_top1', 'non-active_uniform'),
    ('accuracy_min_top1', 'non-active_informed'),
    ('accuracy_min_top1', 'ts_informed'),
    ('accuracy_min_topm', 'non-active_uniform'),
    ('accuracy_min_topm', 'non-active_informed'),
    ('accuracy_min_topm', 'ts_informed'),
    ('ece_max_top1', 'non-active'),
    ('ece_max_top1', 'ts'),
    ('ece_max_topm', 'non-active'),
    ('ece_max_topm', 'ts'),
]

# Raw strings keep every LaTeX backslash literal; the previous plain strings
# relied on invalid escape sequences such as "\p" and "\m", which Python
# deprecates. The printed text is unchanged.
header_lines = [
    r'\begin{tabular}{@{}rrrccccccccccccc@{}}',
    r'\toprule ',
    '& ',
    r'& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}',
    r'& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ ',
    r'\cmidrule{4-6} \cmidrule{8-10} \cmidrule{12-13} \cmidrule{15-16}',
    r'\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI && R &TS && R &TS \\ \midrule',
]
# Dataset label followed by the ten values; ends with the LaTeX row break.
row_template = (r'\multicolumn{2}{c}{%14s}  '
                r'&&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &&%4.1f &%4.1f\\ ')

print('\n'.join(header_lines))
for dataset_key, dataset_label in dataset_print.items():
    # results[task] is a DataFrame with dataset columns and method rows.
    row_values = tuple(results[task][dataset_key][method] for task, method in table_cells)
    print(row_template % ((dataset_label,) + row_values))
print(r'\bottomrule')
print(r'\end{tabular}')

\begin{tabular}{@{}rrrccccccccccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ 
\cmidrule{4-6} \cmidrule{8-10} \cmidrule{12-13} \cmidrule{15-16}
\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI && R &TS && R &TS \\ \midrule
\multicolumn{2}{c}{     CIFAR-100}  &&82.0 &82.0 &29.0  &&100.0 &100.0 &63.0  &&88.0 &43.0 &&90.0 &59.0\\ 
\multicolumn{2}{c}{      ImageNet}  &&99.4 &99.4 &12.8  &&100.0 &100.0 &20.2  &&89.6 &31.0 &&90.0 &41.2\\ 
\multicolumn{2}{c}{          SVHN}  &&65.7 &65.7 &66.5  &&94.1 &95.7 &94.9  &&58.8 &40.7 &&88.4 &77.6\\ 
\multicolumn{2}{c}{ 20 Newsgroups}  &&45.1 &45.1 &18.6  &&98.2 &98.2 &41.2  &&69.0 &27.9 &&90.3 &50.5\\ 
\multicolumn{2}{c}{       DBpedia}  &&14.4 &14.4 &11.7  &&89.6 &89.6 &52.9  &&27.9 & 8.1 &&89.1 &55.6\\ 
\bottomrule
\end{tabular}


In [5]:
# Print a LaTeX table comparing Random (non-active) against TS for accuracy
# and ECE, top-1 and top-m, one row per dataset. The TS value is always bold.
dataset_print = {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['accuracy_min_top1', 'accuracy_min_topm', 'ece_max_top1', 'ece_max_topm']

# (results key, method key) for every numeric cell, in left-to-right order.
table_cells = [
    ('accuracy_min_top1', 'non-active_uniform'),
    ('accuracy_min_top1', 'ts_uniform'),
    ('accuracy_min_topm', 'non-active_uniform'),
    ('accuracy_min_topm', 'ts_uniform'),
    ('ece_max_top1', 'non-active'),
    ('ece_max_top1', 'ts'),
    ('ece_max_topm', 'non-active'),
    ('ece_max_topm', 'ts'),
]

# Raw strings keep every LaTeX backslash literal; the previous plain strings
# relied on invalid escape sequences such as "\p" and "\m", which Python
# deprecates. The printed text is unchanged.
header_lines = [
    r'\begin{tabular}{@{}rrrccccccccccc@{}}',
    r'\toprule ',
    '& ',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ACC, Top 1$}',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ACC, Top m$}',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ ',
    r'\cmidrule{4-5} \cmidrule{7-8} \cmidrule{10-11} \cmidrule{13-14}',
    r'\multicolumn{2}{c}{Dataset}  &&Random&TS  &&Random&TS &&Random&TS &&Random&TS \\ \midrule',
]
# The label is followed by a real TAB character, so that piece must stay a
# normal (non-raw) string.
# NOTE(review): each row ends with a control space ("\ ") before the "\\" row
# break; it is reproduced as-is -- confirm whether it is intended.
row_template = (r'\multicolumn{2}{c}{%12s}' + '\t' +
                r'&&%4.1f &\textbf{%4.1f} &&%4.1f &\textbf{%4.1f} '
                r'&&%4.1f &\textbf{%4.1f} &&%4.1f &\textbf{%4.1f}\ \\ ')

print('\n'.join(header_lines))
for dataset_key, dataset_label in dataset_print.items():
    # results[task] is a DataFrame with dataset columns and method rows.
    row_values = tuple(results[task][dataset_key][method] for task, method in table_cells)
    print(row_template % ((dataset_label,) + row_values))
print(r'\bottomrule')
print(r'\end{tabular}')

\begin{tabular}{@{}rrrccccccccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{2}{c}{$ACC, Top 1$}
& \phantom{a} &  \multicolumn{2}{c}{$ACC, Top m$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ 
\cmidrule{4-5} \cmidrule{7-8} \cmidrule{10-11} \cmidrule{13-14}
\multicolumn{2}{c}{Dataset}  &&Random&TS  &&Random&TS &&Random&TS &&Random&TS \\ \midrule
\multicolumn{2}{c}{   CIFAR-100}	&&82.0 &\textbf{27.0} &&100.0 &\textbf{57.0} &&88.0 &\textbf{43.0} &&90.0 &\textbf{59.0}\ \\ 
\multicolumn{2}{c}{    ImageNet}	&&99.4 &\textbf{20.2} &&100.0 &\textbf{27.2} &&89.6 &\textbf{31.0} &&90.0 &\textbf{41.2}\ \\ 
\multicolumn{2}{c}{        SVHN}	&&65.7 &\textbf{35.0} &&94.1 &\textbf{81.4} &&58.8 &\textbf{40.7} &&88.4 &\textbf{77.6}\ \\ 
\multicolumn{2}{c}{20 Newsgroups}	&&45.1 &\textbf{14.6} &&98.2 &\textbf{33.2} &&69.0 &\textbf{27.9} &&90.3 &\textbf{50.5}\ \\ 
\multicolumn{2}{c}{     DBpedia}	&&14.4 &\textbf{ 7.7} &&89.6 &\textbf{50.7} &&27.9 &

In [6]:
# Display the dataset keys (the keys of TOPK_DICT) that compute() iterates over.
dataset_names

dict_keys(['cifar100', 'imagenet', 'svhn', '20newsgroup', 'dbpedia'])

In [23]:
# Print a LaTeX table with the accuracy columns only (R / RI / TSI), for the
# top-1 and top-m settings, one row per dataset.
dataset_print = {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['accuracy_min_top1', 'accuracy_min_topm']

# (results key, method key) for every numeric cell, in left-to-right order.
table_cells = [
    ('accuracy_min_top1', 'non-active_uniform'),
    ('accuracy_min_top1', 'non-active_informed'),
    ('accuracy_min_top1', 'ts_informed'),
    ('accuracy_min_topm', 'non-active_uniform'),
    ('accuracy_min_topm', 'non-active_informed'),
    ('accuracy_min_topm', 'ts_informed'),
]

# Raw strings keep every LaTeX backslash literal; the previous plain strings
# relied on invalid escape sequences such as "\p" and "\m", which Python
# deprecates. The printed text is unchanged.
header_lines = [
    r'\begin{tabular}{@{}rrrccccccc@{}}',
    r'\toprule ',
    '& ',
    r'& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}',
    r'& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}\\ ',
    r'\cmidrule{4-6} \cmidrule{8-10}',
    r'\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI  \\ \midrule',
]
# Dataset label followed by the six values; ends with the LaTeX row break.
row_template = (r'\multicolumn{2}{c}{%14s}  '
                r'&&%4.1f &%4.1f &%4.1f  &&%4.1f &%4.1f &%4.1f \\ ')

print('\n'.join(header_lines))
for dataset_key, dataset_label in dataset_print.items():
    # results[task] is a DataFrame with dataset columns and method rows.
    row_values = tuple(results[task][dataset_key][method] for task, method in table_cells)
    print(row_template % ((dataset_label,) + row_values))
print(r'\bottomrule')
print(r'\end{tabular}')

\begin{tabular}{@{}rrrccccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top 1$}
& \phantom{a} &  \multicolumn{3}{c}{$ACC, Top m$}\\ 
\cmidrule{4-6} \cmidrule{8-10}
\multicolumn{2}{c}{Dataset} && R  &RI &TSI  && R &RI &TSI  \\ \midrule
\multicolumn{2}{c}{     CIFAR-100}  &&82.0 &82.0 &29.0  &&100.0 &100.0 &63.0 \\ 
\multicolumn{2}{c}{      ImageNet}  &&99.4 &99.4 &12.8  &&100.0 &100.0 &20.2 \\ 
\multicolumn{2}{c}{          SVHN}  &&65.7 &65.7 &66.5  &&94.1 &95.7 &94.9 \\ 
\multicolumn{2}{c}{ 20 Newsgroups}  &&45.1 &45.1 &18.6  &&98.2 &98.2 &41.2 \\ 
\multicolumn{2}{c}{       DBpedia}  &&14.4 &14.4 &11.7  &&89.6 &89.6 &52.9 \\ 
\bottomrule
\end{tabular}


In [26]:
# Print a LaTeX table with the ECE columns only (R / TS), for the top-1 and
# top-m settings, one row per dataset.
dataset_print = {
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet',
    'svhn': 'SVHN',
    '20newsgroup': '20 Newsgroups',
    'dbpedia': 'DBpedia'
}
tasklist = ['ece_max_top1', 'ece_max_topm']

# (results key, method key) for every numeric cell, in left-to-right order.
table_cells = [
    ('ece_max_top1', 'non-active'),
    ('ece_max_top1', 'ts'),
    ('ece_max_topm', 'non-active'),
    ('ece_max_topm', 'ts'),
]

# Raw strings keep every LaTeX backslash literal; the previous plain strings
# relied on invalid escape sequences such as "\p" and "\m", which Python
# deprecates. The printed text is unchanged.
header_lines = [
    r'\begin{tabular}{@{}rrrccccc@{}}',
    r'\toprule ',
    '& ',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}',
    r'& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ ',
    r'\cmidrule{4-5} \cmidrule{7-8}',
    r'\multicolumn{2}{c}{Dataset} && R &TS && R &TS \\ \midrule',
]
# Dataset label followed by the four values; ends with the LaTeX row break.
row_template = (r'\multicolumn{2}{c}{%14s}  '
                r'&&%4.1f &%4.1f &&%4.1f &%4.1f\\ ')

print('\n'.join(header_lines))
for dataset_key, dataset_label in dataset_print.items():
    # results[task] is a DataFrame with dataset columns and method rows.
    row_values = tuple(results[task][dataset_key][method] for task, method in table_cells)
    print(row_template % ((dataset_label,) + row_values))
print(r'\bottomrule')
print(r'\end{tabular}')

\begin{tabular}{@{}rrrccccc@{}}
\toprule 
& 
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top 1$}
& \phantom{a} &  \multicolumn{2}{c}{$ECE, Top m$}\\ 
\cmidrule{4-5} \cmidrule{7-8}
\multicolumn{2}{c}{Dataset} && R &TS && R &TS \\ \midrule
\multicolumn{2}{c}{     CIFAR-100}  &&88.0 &43.0 &&90.0 &59.0\\ 
\multicolumn{2}{c}{      ImageNet}  &&89.6 &31.0 &&90.0 &41.2\\ 
\multicolumn{2}{c}{          SVHN}  &&58.8 &40.7 &&88.4 &77.6\\ 
\multicolumn{2}{c}{ 20 Newsgroups}  &&69.0 &27.9 &&90.3 &50.5\\ 
\multicolumn{2}{c}{       DBpedia}  &&27.9 & 8.1 &&89.1 &55.6\\ 
\bottomrule
\end{tabular}
