In [4]:
# This notebook computes and plots the results for Table 2 in the paper, including also the EC values
# which are not included in the table to avoid clutter.

import matplotlib.pyplot as plt
import numpy as np
from expected_cost import ec, utils
from expected_cost.data import get_llks_for_multi_classif_task
from sklearn.metrics import roc_curve, roc_auc_score, f1_score, precision_recall_curve, accuracy_score, recall_score, precision_score
import re

outdir = "outputs/"
utils.mkdir_p(outdir)

In [5]:
p_agnews = 0.01
p_iemocap = 0.05
data_dir = '../data/'

datasets = {'SST2 GPT2-4sh':         'sst2_gpt2_4shot',
            'SST2 GPT2-0sh':         'sst2_gpt2',
            'SST2-Res1 GPT2-0sh':   ('sst2_gpt2', np.array([0.7, 0.3]), None),
            'SST2-Res2 GPT2-0sh':   ('sst2_gpt2', np.array([0.8, 0.2]), None),
            'SITW XvPLDA':           'sitw_plda',
            'SITW-Res1 XvPLDA':     ('sitw_plda', np.array([0.90, 0.10]), None),
            'SITW-Res2 XvPLDA':     ('sitw_plda', np.array([0.80, 0.20]), None),
            'FVCAUS XvPLDA':         'fvcaus_plda',
            'CIFAR-1vsO Resnet-20': ('cifar100_resnet-20/', None, 1),
            'CIFAR-2vsO Resnet-20': ('cifar100_resnet-20/', None, 2),
            'IEMOCAP W2V2':          'iemocap_wav2vec_pt',
            'IEMOCAP-Res1 W2V2':    ('iemocap_wav2vec_pt', np.array([(1-p_iemocap)/3] * 3 + [p_iemocap] ), None),
            'AGNEWS GPT2-0sh':       'agnews_gpt2',
            'AGNEWS-Res1 GPT2-0sh': ('agnews_gpt2', np.array([(1-p_agnews)/3] * 3 + [p_agnews] ), None),
            'CIFAR10 Resnet-20':     'cifar10_resnet-20/',
            'CIFAR10 Vgg19':         'cifar10_vgg19_bn/',
            'CIFAR10 RepVgg-a2':     'cifar10_repvgg_a2/',
            'CIFAR100 Resnet-20':    'cifar100_resnet-20/',
            'CIFAR100 Vgg19':        'cifar100_vgg19_bn/',
            'CIFAR100 RepVgg-a2':    'cifar100_repvgg_a2/'}

# Field separators for printing to screen or in latex format for the paper
print_style = ''

if print_style == 'latex':
    sep = ' & '
    sep2 = ''
    newline = '\\\\'
else:
    sep = '  ' 
    sep2 = '|' 
    newline = ''

first = True
targets_dict = {}

for dname_raw, dinfo in datasets.items():
    
    data_name = re.sub(' .*','',dname_raw)
    system_name = re.sub('.* ','',dname_raw)

    if type(dinfo) == tuple:
        (dpath, priors, one_vs_other) = dinfo
        targets, logpost_raw, logpost_cal = get_llks_for_multi_classif_task(data_dir+dpath, logpost=True, train_cal_on_test=False, priors=priors, one_vs_other=one_vs_other)
    else:
        dpath = dinfo
        targets, logpost_raw, logpost_cal = get_llks_for_multi_classif_task(data_dir+dpath, logpost=True, train_cal_on_test=False)
    
    targets_dict[dname_raw] = targets

    K = logpost_cal.shape[1]
    counts = np.bincount(targets)
    priors = counts/len(targets)
        
    # Define various costs matrices
    costs = {}

    # Standard 0-1 cost
    costs['C01'] = ec.CostMatrix.zero_one_costs(K)

    # Balanced error rate. The costs are inversely proportional to the priors.
    costm = (1 - np.eye(K))/np.atleast_2d(priors).T/K
    costs['CinvP']  = ec.CostMatrix(costm)

    # A generalization of the EC4 for the binary case
    costm = 1 - np.eye(K)
    costm[-1,:] *= 100
    costs['Cimb']  = ec.CostMatrix(costm)

    if print_style != 'latex':
        if first is True:
            print(f'\nSystem                {sep} MinPrior {sep}', end='')
            for costn, cost in costs.items():
                print(f" {sep2}  {costn:11s}{sep}", end='')
                if 'abs' in costn:
                    print(f"  Abs% ", end='')
            print(f'{sep2} 1-ACC   {sep}{sep2} 1-AveF1 {sep}{sep2} 1-F1AvePR {sep}{sep2} 1-BinF1 {sep}{sep2} 1-AUC  {sep}{sep2}   EER {newline}')
            if K == 2:
                f1_naive = 2 * priors[1] / (priors[1]+1)
            #    print(f'F1 naive: {f1_naive}')

            first = False
    else:
        print('\\multirow{2}{*}{%s} %s \\multirow{2}{*}{%s}'%(data_name, sep, system_name), end='')

    for logpostname, logpost in {'raw': logpost_raw, 'cal': logpost_cal}.items():

        dname = f'{dname_raw} {logpostname}'

        if print_style != 'latex':
            print(f'{dname:25s} {sep} {np.min(priors):4.2f} {sep}', end='')
        else:
            if logpostname == 'cal':
                print(f' {sep} ', end='')
            print(f' {sep} {logpostname} {sep}', end='')

        for costn, cost in costs.items():

            decisions, _ = ec.bayes_decisions(logpost, cost, score_type='log_posteriors')

            ecval  = ec.average_cost(targets, decisions, cost, adjusted=False)
            ecvaln = ec.average_cost(targets, decisions, cost, adjusted=True)

            if print_style != 'latex': 
                print(f" {sep2} {ecval:5.3f}{sep}{ecvaln:5.3f}{sep}", end='')
            else:
                # For the paper, print only the NEC values
                print(f" {sep2} {ecvaln:5.3f}{sep}", end='')

        acc = accuracy_score(targets, np.argmax(logpost, axis=1))

        argmax_decision = np.argmax(logpost, axis=1)
        # This F1 is simply the average F1 over all classes
        macro_f1 = f1_score(targets, decisions, average='macro')
    
        # This is the F1 of average precision and recalls
        r = recall_score(targets, decisions, average='macro')
        p = precision_score(targets, decisions, average='macro', zero_division=0)
        f1_of_macroPR = 2*r*p/(p+r)

        print(f"{sep2}  {1-acc:5.3f}{sep}  {sep2}  {1-macro_f1:5.3f}{sep}  {sep2}   {1-f1_of_macroPR:5.3f}{sep}   ", end='')


        if K == 2:
            auc = roc_auc_score(targets, logpost[:,1])
            fpr, tpr, threshold = roc_curve(targets, logpost[:,1], pos_label=1)
            eer = fpr[np.nanargmin(np.absolute((1-tpr - fpr)))]
            f1 = f1_score(targets, np.argmax(logpost, axis=1))
            
            print(f"{sep2}  {1-f1:5.3f}{sep}  {sep2}  {1-auc:5.3f}{sep} {sep2}  {eer:5.3f} {newline} ", end='')

        else:
            print(f"{sep2}    -   {sep} {sep2}   -   {sep} {sep2}  -  {newline} ", end='')


        print("")
    
    if print_style == 'latex':
        print("\\midrule")





System                   MinPrior    |  C01           |  CinvP         |  Cimb         | 1-ACC     | 1-AveF1   | 1-F1AvePR   | 1-BinF1   | 1-AUC    |   EER 
SST2 GPT2-4sh raw            0.50    | 0.497  0.996   | 0.496  0.992   | 0.501  1.000  |  0.497    |  0.667    |   0.667     |  0.332    |  0.048   |  0.116  
SST2 GPT2-4sh cal            0.50    | 0.113  0.226   | 0.113  0.225   | 0.461  0.921  |  0.113    |  0.486    |   0.329     |  0.115    |  0.048   |  0.116  
SST2 GPT2-0sh raw            0.50    | 0.414  0.828   | 0.413  0.826   | 0.501  1.000  |  0.414    |  0.667    |   0.667     |  0.294    |  0.072   |  0.152  
SST2 GPT2-0sh cal            0.50    | 0.155  0.310   | 0.154  0.308   | 0.454  0.907  |  0.155    |  0.571    |   0.363     |  0.156    |  0.072   |  0.152  
SST2-Res1 GPT2-0sh raw       0.30    | 0.579  1.931   | 0.496  0.991   | 0.700  1.000  |  0.579    |  0.769    |   0.769     |  0.493    |  0.074   |  0.159  
SST2-Res1 GPT2-0sh cal       0.30    | 0.135  0

In [6]:
# Print data statistics

for dname_raw, dinfo in datasets.items():

    data_name = re.sub(' .*','',dname_raw)

    targets = targets_dict[dname_raw]
    counts = np.bincount(targets)
    num_samples = len(targets)
    priors = counts/num_samples
    num_classes = len(priors)

    if num_classes>7:
        priors = np.unique(priors)

    prior_str = ' '.join([f'{p:.2f}' for p in priors])
    print(f'{data_name:20s}   {sep} {num_classes:4d}  {sep}  {num_samples:7d}  {sep} {prior_str} {newline}\n', end='')
    



SST2                         2         1821     0.50 0.50 
SST2                         2         1821     0.50 0.50 
SST2-Res1                    2         1301     0.70 0.30 
SST2-Res2                    2         1140     0.80 0.20 
SITW                         2       721788     0.99 0.01 
SITW-Res1                    2        36580     0.90 0.10 
SITW-Res2                    2        18290     0.80 0.20 
FVCAUS                       2       114072     0.98 0.02 
CIFAR-1vsO                   2        10000     0.99 0.01 
CIFAR-2vsO                   2        10000     0.99 0.01 
IEMOCAP                      4         5473     0.20 0.29 0.31 0.20 
IEMOCAP-Res1                 4         3483     0.32 0.32 0.32 0.05 
AGNEWS                       4         7600     0.25 0.25 0.25 0.25 
AGNEWS-Res1                  4         5757     0.33 0.33 0.33 0.01 
CIFAR10                     10        10000     0.10 
CIFAR10                     10        10000     0.10 
CIFAR10                   