# Imports and pre-definitions

In [None]:
import os
# Root directories for this experiment. Each root can be overridden through an
# environment variable of the same name so the notebook runs on another machine
# without editing; the defaults reproduce the original layout.
MAIN_PATH = os.environ.get('MAIN_PATH', r'/home/luis-felipe')
DATA_PATH = os.environ.get('DATA_PATH', r'/data')
# Directory passed to `upload_logits` in the evaluation loop.
PATH_MODELS = os.path.join(MAIN_PATH,'torch_models')
# Every figure produced by this notebook is saved here.
FIGS_PATH = os.path.join(MAIN_PATH,'results','figs')

In [None]:
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import defaultdict

In [None]:
# Select the compute device: the GPU (cuda) when one is available, otherwise the CPU.
has_cuda = torch.cuda.is_available()
print(has_cuda)
dev = torch.device('cuda' if has_cuda else 'cpu')
# Tensors created without an explicit dtype default to float64 from here on.
torch.set_default_dtype(torch.float64)
# Seed the global torch and numpy RNGs so that runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
import sys
# Make the project's local packages importable from this notebook's directory:
# the parent and grandparent directories are inserted at position 1 of
# sys.path (i.e. right after sys.path[0]); '../..' ends up ahead of '..'.
sys.path.insert(1, '..')
sys.path.insert(1, '../..')

import models  # project module; list_models(DATASET) drives the evaluation loop
from utils import measures,metrics  # confidence measures (MSP, max_logit, ...) and metrics (AURC, N_AURC, ...)
from data_utils import upload_logits,split_data  # logit loading and val/test splitting
import post_hoc  # post-hoc tuning: optimize.T, optimize.p_and_T, normalize, significant

# Definitions

In [None]:
# Experiment configuration.
DATASET = 'ImageNet'
# Validation size handed to split_data.split_logits (presumably a fraction; previously 5000 samples).
VAL_SIZE = 0.1
# Fraction of the validation split actually kept (the first SUB_VAL_SIZE * n samples).
SUB_VAL_SIZE = 1
# Metric minimised when tuning T / p on validation data and when deciding the MSP fallback.
METRIC = metrics.AURC
# Number of random val/test splits (seed grows by 10 per split).
NUM_EXPERIMENTS = 10

In [None]:
# Confidence estimators under comparison: name -> function applied to (transformed) logits.
methods = {
    'MSP': measures.MSP,
    'SoftmaxMargin': measures.margin_softmax,
    #'Energy': lambda x:torch.logsumexp(x,-1),
    'MaxLogit': measures.max_logit,
    'LogitsMargin': measures.margin_logits,
    'NegativeEntropy': measures.negative_entropy,
    'NegativeGini': measures.negative_gini,
}

# Test-time metrics, each called as fn(confidence, risk).
# SAC is evaluated with 0.98 as its third argument (presumably the target accuracy).
optm_metrics = dict(
    naurc=metrics.N_AURC,
    aurc=metrics.AURC,
    auroc=metrics.AUROC,
    sac=lambda x, y: metrics.SAC(x, y, 0.98),
    ece=metrics.ECE(15),
)

# Logit transforms: 'raw' (identity), 'T_nll' (temperature fitted with cross-entropy),
# 'T' (temperature fitted with METRIC), 'p' (p-norm normalisation followed by a temperature).
transforms = ['raw', 'T_nll', 'T', 'p']

# Candidate p values searched by post_hoc.optimize.p_and_T.
p_range = torch.arange(10)

# Evaluate

In [None]:
def _per_transform_store():
    """Return an empty store indexed as [metric][method][transform][model_arc] -> list of per-run values."""
    return {metric: {method: {t: defaultdict(list) for t in transforms}
                     for method in methods}
            for metric in optm_metrics}

# Test metrics of every (metric, method, transform, model) combination.
results = _per_transform_store()
# Same, but holding the MSP value whenever the transformed method was worse than MSP on validation.
results_fallback = _per_transform_store()
# Chosen p per method and model (the string 'MSP' when the p-norm variant fell back to MSP).
p_list = {method: defaultdict(list) for method in methods}
acc = defaultdict(list)   # test accuracy per model, one entry per run
msps = defaultdict(list)  # mean test MSP per model, one entry per run

In [None]:
# Main evaluation loop: for every random split and model, tune the post-hoc
# transforms on the validation split and record every test metric, both raw
# (`results`) and with the MSP fallback (`results_fallback`).
seed = SEED
for i in range(NUM_EXPERIMENTS):
    print(i+1)
    for model_arc in models.list_models(DATASET):
        with torch.no_grad():
            logits_val,labels_val,logits_test,labels_test = split_data.split_logits(*upload_logits(model_arc,DATASET,PATH_MODELS, 
                                split = 'test', device = dev),VAL_SIZE,seed = seed)
            # Keep only the first SUB_VAL_SIZE fraction of the validation split.
            n_val = int(SUB_VAL_SIZE*labels_val.size(0))
            logits_val,labels_val = logits_val[:n_val],labels_val[:n_val]
            # Per-sample risk from measures.wrong_class, cast to float.
            risk_val = measures.wrong_class(logits_val,labels_val).float()
            risk_test = measures.wrong_class(logits_test,labels_test).float()
        # These depend only on the model/split, so compute them once here rather
        # than inside the method/transform/metric loops below.
        msp_test = measures.MSP(logits_test)
        msp_val_score = METRIC(measures.MSP(logits_val),risk_val)
        acc[model_arc].append(1-risk_test.mean().item())
        msps[model_arc].append(msp_test.mean().item())
        # Temperature fitted on the raw logits with cross-entropy; identical for every method.
        T_nll = post_hoc.optimize.T(logits_val,labels_val,method = lambda x:x,metric = torch.nn.CrossEntropyLoss())

        for m,method in methods.items():
            # Logit-based scores only consider T = 1.
            if m in ('MaxLogit', 'LogitsMargin'): T_range = [1]
            else: T_range = torch.arange(0.01,2,0.01)
            pT = post_hoc.optimize.p_and_T(logits_val,risk_val,method,METRIC,p_range=p_range,T_range=T_range)
            # Record 'MSP' when the tuned p-norm variant is worse than plain MSP on validation.
            if METRIC(method(post_hoc.normalize(logits_val,pT[0]).div(pT[1])),risk_val) > msp_val_score:
                p_list[m][model_arc].append('MSP')
            else: p_list[m][model_arc].append(pT[0].item())
            T = post_hoc.optimize.T(logits_val,risk_val,method,METRIC,T_range = T_range)

            # Transform name -> function of logits; rebuilt per method because T and pT change.
            transform_fns = {'raw': lambda z: z,
                             'T_nll': lambda z: z.div(T_nll),
                             'T': lambda z: z.div(T),
                             'p': lambda z: post_hoc.normalize(z,pT[0]).div(pT[1])}
            for t in transforms:
                fn = transform_fns[t]
                Z = fn(logits_test)
                # Fall back to MSP when this transform is worse than MSP on validation.
                fallback = METRIC(method(fn(logits_val)),risk_val) > msp_val_score
                for metric, metric_fn in optm_metrics.items():
                    metric_value = metric_fn(method(Z),risk_test)
                    results[metric][m][t][model_arc].append(metric_value)
                    if fallback:
                        results_fallback[metric][m][t][model_arc].append(metric_fn(msp_test,risk_test))
                    else:
                        results_fallback[metric][m][t][model_arc].append(metric_value)
    seed = seed+10

models_list = list(acc.keys())

# Plot

In [None]:
# Mean and standard deviation over runs of every stored test metric,
# indexed as [metric][method][transform][model_arc].
means = {}
std = {}
for metric, d_metric in results.items():
    means[metric] = {}
    std[metric] = {}
    for method, d_method in d_metric.items():
        means[metric][method] = {}
        std[metric][method] = {}
        for transform, d_t in d_method.items():
            means[metric][method][transform] = {model_arc: np.mean(v) for model_arc, v in d_t.items()}
            std[metric][method][transform] = {model_arc: np.std(v) for model_arc, v in d_t.items()}
# Mean test accuracy per model. It depends only on the model, so it is computed
# once here instead of inside the innermost loop above.
acc_mean = {model_arc: np.mean(acc[model_arc]) for model_arc in acc}

In [None]:
# MSP baseline NAURC: one row per model, one column per run.
baseline = np.array([runs for runs in results['naurc']['MSP']['raw'].values()])

## Figure 2

In [None]:
# Figure 2: NAURC gain over MSP for each method. Models are ordered by the mean
# gain of MaxLogit-pNorm (largest first); shaded bands are ±1 std over runs.
idx = np.argsort(baseline.mean(-1)-np.array(list(means['naurc']['MaxLogit']['p'].values())))[::-1]

naurc_results = results['naurc']
methods_plot = {'MSP-TS-AURC': naurc_results['MSP']['T'],
                'MSP-TS-NLL': naurc_results['MSP']['T_nll'],
                'MSP-pNorm': naurc_results['MSP']['p'],
                'MaxLogit-pNorm': naurc_results['MaxLogit']['p'],
                'LogitsMargin': naurc_results['LogitsMargin']['raw'],
                'LogitsMargin-pNorm': naurc_results['LogitsMargin']['p'],
                'NegativeGini-pNorm': naurc_results['NegativeGini']['p']}

palette = iter(['blue','gray','green','red','lime','y','violet','pink'])
positions = range(1, len(models_list)+1)
fig, ax = plt.subplots(figsize=(8,5))
for name, per_model in methods_plot.items():
    gains = baseline - np.array(list(per_model.values()))
    gains_mean = gains.mean(-1)[idx]
    gains_std = gains.std(-1)[idx]
    plot = ax.plot(positions, gains_mean, label=name, color=next(palette))
    ax.fill_between(positions, gains_mean-gains_std, gains_mean+gains_std, alpha=0.4, color=plot[0].get_color())
ax.axhline(0.01, color='k', linestyle='--')
ax.axhline(0, color='k', linestyle='--', alpha=0.5, label='MSP')
ax.legend(prop={'size': 12})
ax.text(5, 0.03, r'$\epsilon = 0.01$', fontsize=8,
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
ax.set_xlim(1, len(models_list))
ax.set_xlabel('Model')
ax.set_ylabel('NAURC gain over MSP', fontsize=13)
ax.grid()
ax.tick_params(axis='both', labelsize=12)
fig.savefig(os.path.join(FIGS_PATH,f'gains_methods_{DATASET}.pdf'), transparent=True, format='pdf', bbox_inches='tight')
plt.show()

### Zoom

In [None]:
# Zoomed view of Figure 2: 0-based model positions, x in [5, 50], gains in [0.05, 0.15].
idx = np.argsort(baseline.mean(-1)-np.array(list(means['naurc']['MaxLogit']['p'].values())))[::-1]

naurc_results = results['naurc']
methods_plot = {'MSP-TS-AURC': naurc_results['MSP']['T'],
                'MSP-TS-NLL': naurc_results['MSP']['T_nll'],
                'MSP-pNorm': naurc_results['MSP']['p'],
                'MaxLogit-pNorm': naurc_results['MaxLogit']['p'],
                'LogitsMargin': naurc_results['LogitsMargin']['raw'],
                'LogitsMargin-pNorm': naurc_results['LogitsMargin']['p'],
                'NegativeGini-pNorm': naurc_results['NegativeGini']['p']}

palette = iter(['blue','gray','green','red','lime','y','violet','pink'])
positions = range(len(models_list))
fig, ax = plt.subplots(figsize=(8,3))
for name, per_model in methods_plot.items():
    gains = baseline - np.array(list(per_model.values()))
    gains_mean = gains.mean(-1)[idx]
    gains_std = gains.std(-1)[idx]
    plot = ax.plot(positions, gains_mean, label=name, color=next(palette))
    ax.fill_between(positions, gains_mean-gains_std, gains_mean+gains_std, alpha=0.4, color=plot[0].get_color())
ax.axhline(0.01, color='k', linestyle='--')
ax.set_ylabel('NAURC gain over MSP', fontsize=13)
ax.legend()
ax.set_xlim(5, 50)
ax.set_ylim(0.05, 0.15)
ax.tick_params(axis='both', labelsize=12)
fig.savefig(os.path.join(FIGS_PATH,f'NAURC_gains_methods_{DATASET}_zoom.pdf'), transparent=True, format='pdf', bbox_inches='tight')
plt.show()

## Figure 3

In [None]:
from scipy.stats import spearmanr,mode
# Mean NAURC per model of MaxLogit-pNorm with the MSP fallback.
optimal_naurc = np.array(list(results_fallback['naurc']['MaxLogit']['p'].values())).mean(-1)

# Most frequent choice of p per model across runs. Entries of p_list are numeric p
# values or the string 'MSP'; a mixed list becomes a string array, which recent
# SciPy versions of `mode` reject. Encode 'MSP' as +inf: it is larger than every p,
# so ties still resolve to the smallest p, as with the original string ordering.
# Then decode it back to 'MSP'.
_p_choices = np.array([[np.inf if p == 'MSP' else float(p) for p in runs]
                       for runs in p_list['MaxLogit'].values()])
_p_mode = mode(_p_choices,-1,keepdims = False).mode
p_list_mode = np.array(['MSP' if np.isinf(v) else v for v in _p_mode], dtype=object)

### NAURC

In [None]:
# Figure 3 (NAURC): accuracy vs. NAURC per model for the MSP baseline (left) and
# MaxLogit-pNorm with MSP fallback (right), coloured by the modal p (-1 = MSP fallback).
fig,axes = plt.subplots(1,2,figsize = (10,4), sharey = True,sharex = True)
fig.tight_layout()

acc_per_model = np.array(list(acc.values())).mean(-1)
baseline_naurc = np.array(list(results_fallback['naurc']['MSP']['raw'].values())).mean(-1)
# Cast colour values to float explicitly. p_list_mode may hold numeric strings, and
# matplotlib treats a `c` list whose first entry is a string as colour names, not values.
p_colors = np.array([-1 if x=='MSP' else x for x in p_list_mode], dtype=float)

scatter = axes[0].scatter(acc_per_model,baseline_naurc,c = p_colors)
axes[1].scatter(acc_per_model,optimal_naurc,c = p_colors)
axes[0].set_title(fr"Baseline - $\rho$={spearmanr(acc_per_model,baseline_naurc).correlation:.4f}")
axes[1].set_title(fr"Optimized- $\rho$={spearmanr(acc_per_model,optimal_naurc).correlation:.4f}")
axes[0].set_ylabel('NAURC')
for ax in axes:
    ax.grid()
    ax.set_xlabel('Accuracy')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(15)
axes[0].set_xticks(axes[0].get_xticks()[1:-2])
# Assumes legend_elements yields one entry per unique colour value, in ascending
# order. The first entry is then -1 (MSP fallback) only if some model fell back.
handles, legend_labels = scatter.legend_elements()
legend_labels = ['p = '+i for i in legend_labels]
if (p_colors == -1).any():
    legend_labels[0] = 'MSP'

legend1 = axes[1].legend(handles, legend_labels, prop={'size': 13})

plt.savefig(os.path.join(FIGS_PATH,'NAURC.pdf'), transparent = True, format = 'pdf',bbox_inches = 'tight')
plt.show()

### AURC

In [None]:
# Figure 3 (AURC): accuracy vs. AURC per model for the MSP baseline (left) and
# MaxLogit-pNorm with MSP fallback (right), coloured by the modal p (-1 = MSP fallback).
optimal_aurc = np.array(list(results_fallback['aurc']['MaxLogit']['p'].values())).mean(-1)
fig,axes = plt.subplots(1,2,figsize = (10,4), sharey = True,sharex = True)
fig.tight_layout()

acc_per_model = np.array(list(acc.values())).mean(-1)
baseline_aurc = np.array(list(results_fallback['aurc']['MSP']['raw'].values())).mean(-1)
# Cast colour values to float explicitly. p_list_mode may hold numeric strings, and
# matplotlib treats a `c` list whose first entry is a string as colour names, not values.
p_colors = np.array([-1 if x=='MSP' else x for x in p_list_mode], dtype=float)

scatter = axes[0].scatter(acc_per_model,baseline_aurc,c = p_colors)
axes[1].scatter(acc_per_model,optimal_aurc,c = p_colors)
axes[0].set_title(fr"Baseline - $\rho$={spearmanr(acc_per_model,baseline_aurc).correlation:.4f}")
axes[1].set_title(fr"Optimized- $\rho$={spearmanr(acc_per_model,optimal_aurc).correlation:.4f}")
axes[0].set_ylabel('AURC')
for ax in axes:
    ax.grid()
    ax.set_xlabel('Accuracy')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(15)
axes[0].set_xticks(axes[0].get_xticks()[1:-2])
# Assumes legend_elements yields one entry per unique colour value, in ascending
# order. The first entry is then -1 (MSP fallback) only if some model fell back.
handles, legend_labels = scatter.legend_elements()
legend_labels = ['p = '+i for i in legend_labels]
if (p_colors == -1).any():
    legend_labels[0] = 'MSP'

legend1 = axes[1].legend(handles, legend_labels, prop={'size': 13})

# Saved under its own name; previously this overwrote NAURC.pdf.
plt.savefig(os.path.join(FIGS_PATH,'AURC.pdf'), transparent = True, format = 'pdf',bbox_inches = 'tight')
plt.show()

### AUROC

In [None]:
# Figure 3 (AUROC): accuracy vs. AUROC per model for the MSP baseline (left) and
# MaxLogit-pNorm with MSP fallback (right), coloured by the modal p (-1 = MSP fallback).
optimal_auroc = np.array(list(results_fallback['auroc']['MaxLogit']['p'].values())).mean(-1)
fig,axes = plt.subplots(1,2,figsize = (10,4), sharey = True,sharex = True)
fig.tight_layout()

acc_per_model = np.array(list(acc.values())).mean(-1)
baseline_auroc = np.array(list(results_fallback['auroc']['MSP']['raw'].values())).mean(-1)
# Cast colour values to float explicitly. p_list_mode may hold numeric strings, and
# matplotlib treats a `c` list whose first entry is a string as colour names, not values.
p_colors = np.array([-1 if x=='MSP' else x for x in p_list_mode], dtype=float)

scatter = axes[0].scatter(acc_per_model,baseline_auroc,c = p_colors)
axes[1].scatter(acc_per_model,optimal_auroc,c = p_colors)
axes[0].set_title(fr"Baseline - $\rho$={spearmanr(acc_per_model,baseline_auroc).correlation:.4f}")
axes[1].set_title(fr"Optimized- $\rho$={spearmanr(acc_per_model,optimal_auroc).correlation:.4f}")
axes[0].set_ylabel('AUROC')
for ax in axes:
    ax.grid()
    ax.set_xlabel('Accuracy')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(15)
axes[0].set_xticks(axes[0].get_xticks()[1:-2])
# Assumes legend_elements yields one entry per unique colour value, in ascending
# order. The first entry is then -1 (MSP fallback) only if some model fell back.
handles, legend_labels = scatter.legend_elements()
legend_labels = ['p = '+i for i in legend_labels]
if (p_colors == -1).any():
    legend_labels[0] = 'MSP'

legend1 = axes[1].legend(handles, legend_labels, prop={'size': 13})

# Saved under its own name; previously this overwrote NAURC.pdf.
plt.savefig(os.path.join(FIGS_PATH,'AUROC.pdf'), transparent = True, format = 'pdf',bbox_inches = 'tight')
plt.show()

### SAC

In [None]:
# Figure 3 (SAC): accuracy vs. SAC per model for the MSP baseline (left) and
# MaxLogit-pNorm with MSP fallback (right), coloured by the modal p (-1 = MSP fallback).
optimal_sac = np.array(list(results_fallback['sac']['MaxLogit']['p'].values())).mean(-1)
fig,axes = plt.subplots(1,2,figsize = (10,4), sharey = True,sharex = True)
fig.tight_layout()

acc_per_model = np.array(list(acc.values())).mean(-1)
baseline_sac = np.array(list(results_fallback['sac']['MSP']['raw'].values())).mean(-1)
# Cast colour values to float explicitly. p_list_mode may hold numeric strings, and
# matplotlib treats a `c` list whose first entry is a string as colour names, not values.
p_colors = np.array([-1 if x=='MSP' else x for x in p_list_mode], dtype=float)

scatter = axes[0].scatter(acc_per_model,baseline_sac,c = p_colors)
axes[1].scatter(acc_per_model,optimal_sac,c = p_colors)
axes[0].set_title(fr"Baseline - $\rho$={spearmanr(acc_per_model,baseline_sac).correlation:.4f}")
axes[1].set_title(fr"Optimized- $\rho$={spearmanr(acc_per_model,optimal_sac).correlation:.4f}")
# This panel shows SAC (previously mislabelled as AUROC).
axes[0].set_ylabel('SAC')
for ax in axes:
    ax.grid()
    ax.set_xlabel('Accuracy')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(15)
axes[0].set_xticks(axes[0].get_xticks()[1:-2])
# Assumes legend_elements yields one entry per unique colour value, in ascending
# order. The first entry is then -1 (MSP fallback) only if some model fell back.
handles, legend_labels = scatter.legend_elements()
legend_labels = ['p = '+i for i in legend_labels]
if (p_colors == -1).any():
    legend_labels[0] = 'MSP'

legend1 = axes[1].legend(handles, legend_labels, prop={'size': 13})

# Saved under its own name; previously this overwrote NAURC.pdf.
plt.savefig(os.path.join(FIGS_PATH,'SAC.pdf'), transparent = True, format = 'pdf',bbox_inches = 'tight')
plt.show()

## Table 1

In [None]:
# Table 1 (LaTeX rows): NAURC mean ± std over runs for two reference models.
# Temperature-based transforms of logit-based methods are shown as '-'.
table1_models = {'efficientnetv2_xl': 'EfficientNet-V2-XL', 'vgg16':'VGG16'}
for model_arc, model_name in table1_models.items():
    print(r'\midrule \multirow{6}{*}{' + model_name + '}' )
    for method, d_method in results['naurc'].items():
        cells = [f'& {method}']
        for transform, d_t in d_method.items():
            if 'T' in transform and 'Logit' in method:
                cells.append(" & -")
                continue
            runs = d_t[model_arc]
            cells.append(f" & {np.mean(runs,-1):.4f}"+r' {\footnotesize $\pm$'+f"{np.std(runs,-1):.4f}" + "}")
        print(''.join(cells) + r' \\')
            

## Table 2

In [None]:
# Table 2 (LaTeX rows): mean over models and runs of post_hoc.significant applied to
# the NAURC gain over the MSP baseline (with MSP fallback), ± std across runs of the
# per-run mean. Temperature-based transforms of logit-based methods are shown as '-'.
for method, d_method in results_fallback['naurc'].items():
    row = [method]
    for transform, d_t in d_method.items():
        if 'T' in transform and 'Logit' in method:
            row.append(" & -")
            continue
        significant_gain = post_hoc.significant(baseline - np.array(list(d_t.values())))
        row.append(f" & {significant_gain.mean():.5f}"+r' {\footnotesize $\pm$'+f"{significant_gain.mean(0).std():.5f}"+"}")
    print(''.join(row) + r' \\')

## Epsilon ablation - Figure 8

In [None]:
# Figure 8 data: mean (and std across runs) of post_hoc.significant(gain, epsilon)
# for every method's pNorm variant, swept over epsilon. Uses `results` (no MSP fallback).
# `defaultdict` comes from the imports cell at the top of the notebook.
epsilon_list = np.arange(0,0.051,0.001)
apgs_mean = defaultdict(list)
apgs_std = defaultdict(list)
for m in methods:
    # NAURC gain over the MSP baseline per (model, run); independent of epsilon, so hoisted.
    gains = baseline - np.array(list(results['naurc'][m]['p'].values()))
    for epsilon in epsilon_list:
        significant_gain = post_hoc.significant(gains,epsilon)
        apgs_mean[m].append(np.mean(significant_gain.mean()))
        apgs_std[m].append(np.mean(significant_gain.mean(0).std()))
    apgs_mean[m] = np.asarray(apgs_mean[m])
    apgs_std[m] = np.asarray(apgs_std[m])

In [None]:
# Figure 8: APG-NAURC of each method's pNorm variant as a function of epsilon (±1 std band).
fig, ax = plt.subplots(figsize=(8,6))
for m, apg in apgs_mean.items():
    plot = ax.plot(epsilon_list, apg, label=m+'-pNorm')
    ax.fill_between(epsilon_list, apg-apgs_std[m], apg+apgs_std[m], alpha=0.2, color=plot[0].get_color())
ax.set_xlabel(r'$\epsilon$', fontsize=15)
ax.set_ylabel('APG-NAURC', fontsize=15)
ax.grid()
ax.tick_params(axis='both', labelsize=12)
ax.legend(prop={'size': 12})
ax.set_xlim(0, 0.05)
fig.savefig(os.path.join(FIGS_PATH,'epsilon.pdf'), format='pdf', transparent=True, bbox_inches='tight')
plt.show()

## ECE - Table 12

In [None]:
# Table 12 (LaTeX row): ECE of MSP under each transform. The value is the mean over
# models and runs, ± std across runs of the per-run mean over models.
# Only the MSP row is produced (the original loop broke after the first method, MSP).
method = 'MSP'
d_method = results['ece'][method]
string = f'{method}'
for transform, d_t in d_method.items():
    string += f" & {np.mean(list(d_t.values())):.5f}"+r' {\footnotesize $\pm$'+f"{np.array(list(d_t.values())).mean(0).std():.5f}"+"}"
print(string + r' \\')

## Gain x MSP - Figure 15

In [None]:
# Per-model NAURC gain of MaxLogit-pNorm over the MSP baseline, averaged over runs
# (one value per model, in the same model order as `baseline`). Built explicitly from
# `results` instead of a loop variable left over from an earlier cell.
gains_mlp = (baseline - np.array(list(results['naurc']['MaxLogit']['p'].values()))).mean(-1)

In [None]:
# Figure 15: NAURC gain of MaxLogit-pNorm over MSP per model against MSP saturation.
# Red = gain above the 0.01 threshold, blue = gain at or below it.
# `msps` is a dict model -> per-run values, so convert its values (not the dict) to an array.
# NOTE(review): `msps` stores the *mean* test MSP of each run (see the evaluation loop),
# so this x value is the fraction of runs whose mean MSP exceeds 0.999, not the
# per-sample proportion named in the x label. Confirm which quantity is intended.
msp_saturation = np.mean(np.array(list(msps.values()))>0.999,axis=-1)
large_gain = gains_mlp > 0.01
small_gain = gains_mlp <= 0.01
plt.scatter(msp_saturation[large_gain],gains_mlp[large_gain], color = 'r')
plt.scatter(msp_saturation[small_gain],gains_mlp[small_gain], color = 'b')
plt.axhline(0.01,linestyle='--',color='k')
plt.grid()
plt.xlabel('Proportion of samples with MSP>0.999')
plt.ylabel('NAURC gain of MaxLogit-pNorm over MSP')
# Lower-cased DATASET gives 'msp_proportion_imagenet.pdf' for the default ImageNet run.
plt.savefig(os.path.join(FIGS_PATH,f'msp_proportion_{DATASET.lower()}.pdf'),format = 'pdf')
plt.show()