# Imports and preliminary definitions

In [None]:
import os

# Base directory; the model and figure directories below are derived from it.
MAIN_PATH = '/home/luis-felipe'
# Directory that holds the datasets.
DATA_PATH = '/data'

# Location handed to the logits loader.
PATH_MODELS = os.path.join(MAIN_PATH, 'torch_models')
# Destination of the saved figures.
FIGS_PATH = os.path.join(os.path.join(MAIN_PATH, 'results'), 'figs')

In [None]:
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Select the compute device: CUDA (GPU) when one is available, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
print(_cuda_ok)
dev = torch.device('cuda' if _cuda_ok else 'cpu')

# Work in double precision by default and seed both the torch and NumPy RNGs.
SEED = 42
torch.set_default_dtype(torch.float64)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
import sys
# Add the parent and grandparent directories to the module search path.
# Both are inserted at index 1, so '../..' ends up ahead of '..' in sys.path.
sys.path.insert(1, '..')
sys.path.insert(1, '../..')

# Project-local modules (presumably resolved through the entries added above - TODO confirm).
import models
from utils import measures,metrics
from data_utils import upload_logits,split_data
import post_hoc
from scipy.stats import pearsonr

# Evaluate logits

In [None]:
# Experiment configuration.
DATASET = 'ImageNet'
NUM_SPLITS = 10  # number of repetitions of the validation/test split
VAL_SIZE = 0.1   # forwarded to split_data.split_logits (presumably the validation fraction)

# Metric handed to post_hoc.optimize.p when tuning p.
# Alternative tried before: lambda x, y: 1 - metrics.AUROC(x, y)
METRIC = metrics.AURC

In [None]:
# Shifted evaluation sets handled by name; the corrupted sets are enumerated below.
shifts = ['test', 'v2', 'sketch']

# Corruption types found on disk, minus the ones excluded from the evaluation.
# Filtering (instead of list.remove) does not fail when an excluded corruption is
# absent from disk, and sorting makes the order independent of os.listdir, which
# returns entries in arbitrary order.
extras = ['speckle_noise', 'gaussian_blur', 'spatter', 'saturate']
corruptions = sorted(
    name
    for name in os.listdir(os.path.join(DATA_PATH, DATASET, 'corrupted'))
    if name not in extras
)
# Corruption levels 1 to 5.
lvls = range(1, 6)

In [None]:
from collections import defaultdict

# One independent {shift: {model name: [value per split]}} store per tracked quantity.
acc, sac_opt, naurc_opt, naurc_baseline, sac_baseline = (
    {s: defaultdict(list) for s in shifts} for _ in range(5)
)

# Per model: the choice made on each split (the tuned p, or 'MSP' on fallback).
p_list = defaultdict(list)

In [None]:
# Add a {corruption: {level: {model name: [value per split]}}} entry to every store.
for _store in (acc, naurc_opt, sac_opt, naurc_baseline, sac_baseline):
    _store['corrupted'] = {
        c: {lvl: defaultdict(list) for lvl in lvls} for c in corruptions
    }
                

In [None]:
%%time
# Evaluate every model on NUM_SPLITS validation/test splits of the test logits.
# On each split, p is tuned on the validation part. The MSP baseline and the
# "optimized" confidence are then scored on every shift and, for a fixed subset of
# architectures, on every corruption/level. The optimized confidence is MaxLogit_p,
# or MSP itself when MSP reaches a lower validation N_AURC.
models_list = models.list_models()
# Hand-picked alternative used previously:
# ['resnet50','vgg16','alexnet','efficientnetv2_xl','efficientnet_b3','convnext_base',
#  'resnet18','vit_l_16_384','vit_b_32_sam','wide_resnet50_2','maxvit_t']

# Architectures that are additionally evaluated on the corrupted data.
corrupted_models = ('resnet50', 'alexnet', 'wide_resnet50_2', 'convnext_large', 'vgg11', 'efficientnet_b3')

seed = SEED
for i in range(NUM_SPLITS):
    print(i)
    for model_arc in models_list:
        print(model_arc)
        with torch.no_grad():
            logits_val, labels_val, logits_test, labels_test = split_data.split_logits(
                *upload_logits(model_arc, DATASET, PATH_MODELS, split='test', device=dev),
                VAL_SIZE, seed=seed)
            logits_val = post_hoc.centralize(logits_val)
            logits_test = post_hoc.centralize(logits_test)
            risk_val = measures.wrong_class(logits_val, labels_val).float()
            risk_test = measures.wrong_class(logits_test, labels_test).float()

        # Accuracy on the held-out part of the test split; passed to metrics.SAC below.
        acc_iid = (1 - risk_test.mean().item())
        p = post_hoc.optimize.p(logits_val, risk_val, measures.max_logit, METRIC)
        # Fall back to MSP when it has the lower N_AURC on the validation part.
        fallback = (metrics.N_AURC(risk_val, measures.MSP(logits_val))
                    < metrics.N_AURC(risk_val, post_hoc.MaxLogit_p(logits_val, p)))
        if fallback:
            p_list[model_arc].append('MSP')
            confidence_opt = measures.MSP
        else:
            p_list[model_arc].append(p.item())
            # Bind the current p explicitly so the function does not depend on later rebinding.
            confidence_opt = lambda z, p=p: post_hoc.MaxLogit_p(z, p)

        for shift in shifts:
            if shift == 'test':
                logits_shift, labels_shift = logits_test, labels_test
            else:
                logits_shift, labels_shift = upload_logits(
                    model_arc, DATASET, PATH_MODELS,
                    split=shift, device=dev, data_dir=DATA_PATH)

            logits_shift = post_hoc.centralize(logits_shift)
            risk_shift = measures.wrong_class(logits_shift, labels_shift).float()
            acc[shift][model_arc].append((1 - risk_shift.mean().item()))

            naurc_baseline[shift][model_arc].append(metrics.N_AURC(risk_shift, measures.MSP(logits_shift)))
            sac_baseline[shift][model_arc].append(metrics.SAC(risk_shift, measures.MSP(logits_shift), acc_iid))

            naurc_opt[shift][model_arc].append(metrics.N_AURC(risk_shift, confidence_opt(logits_shift)))
            sac_opt[shift][model_arc].append(metrics.SAC(risk_shift, confidence_opt(logits_shift), acc_iid))

        if model_arc in corrupted_models:
            for corruption in corruptions:
                for lvl in lvls:
                    # Keep only the last two outputs of the split (the test part),
                    # using the same seed as the clean split above.
                    logits_shift, labels_shift = split_data.split_logits(
                        *upload_logits(model_arc, DATASET, PATH_MODELS,
                                       split=('corrupted', corruption, str(lvl)), device=dev),
                        VAL_SIZE, seed=seed)[-2:]
                    logits_shift = post_hoc.centralize(logits_shift)
                    risk_shift = measures.wrong_class(logits_shift, labels_shift).float()
                    acc['corrupted'][corruption][lvl][model_arc].append((1 - risk_shift.mean().item()))

                    naurc_baseline['corrupted'][corruption][lvl][model_arc].append(
                        metrics.N_AURC(risk_shift, measures.MSP(logits_shift)))
                    sac_baseline['corrupted'][corruption][lvl][model_arc].append(
                        metrics.SAC(risk_shift, measures.MSP(logits_shift), acc_iid))

                    naurc_opt['corrupted'][corruption][lvl][model_arc].append(
                        metrics.N_AURC(risk_shift, confidence_opt(logits_shift)))
                    sac_opt['corrupted'][corruption][lvl][model_arc].append(
                        metrics.SAC(risk_shift, confidence_opt(logits_shift), acc_iid))

    # Use a different split seed on every repetition.
    seed += 10

# Names of the evaluated models. The top-level keys of `acc` are shift names,
# so the model names are read from one of the per-shift mappings.
models_list = list(acc['test'])

In [None]:
# Per corruption level: one list per corruption, each holding the per-model lists of values.
naurc_c_opt = {lvl: [] for lvl in lvls}
naurc_c_baseline = {lvl: [] for lvl in lvls}
acc_c = {lvl: [] for lvl in lvls}
# Per-model statistics taken over the last axis, keyed by shift name or corruption level.
naurc_opt_mean = {lvl: [] for lvl in lvls}
naurc_baseline_mean = {lvl: [] for lvl in lvls}
acc_mean = {lvl: [] for lvl in lvls}
naurc_opt_std = {lvl: [] for lvl in lvls}
naurc_baseline_std = {lvl: [] for lvl in lvls}

for lvl in shifts + list(lvls):
    if isinstance(lvl, int):
        # Corruption level: average over the corruption types first and store the
        # result under the integer key, next to the named shifts.
        for per_corruption, store in ((naurc_c_opt, naurc_opt),
                                      (naurc_c_baseline, naurc_baseline),
                                      (acc_c, acc)):
            for c in corruptions:
                per_corruption[lvl].append(list(store['corrupted'][c][lvl].values()))
            # Computed once per level instead of once per model; axis 0 is the corruption type.
            averaged = np.mean(per_corruption[lvl], 0)
            # Model names come from the last corruption, as in the original code;
            # every corruption is filled with the same models in the same order.
            model_names = store['corrupted'][corruptions[-1]][lvl]
            store[lvl] = {m: averaged[i] for i, m in enumerate(model_names)}

    acc_mean[lvl] = np.mean(list(acc[lvl].values()), -1)
    naurc_opt_mean[lvl] = np.mean(list(naurc_opt[lvl].values()), -1)
    naurc_baseline_mean[lvl] = np.mean(list(naurc_baseline[lvl].values()), -1)
    naurc_opt_std[lvl] = np.std(list(naurc_opt[lvl].values()), -1)
    naurc_baseline_std[lvl] = np.std(list(naurc_baseline[lvl].values()), -1)

In [None]:
# Legend text per data set: corruption levels first, then the named sets.
labels = {
    **{lvl: f"Imagenet-C - {lvl}" for lvl in lvls},
    'v2': 'ImageNetV2',
    'test': 'ImageNet (ID)',
    'sketch': 'ImageNet Sketch',
}

# Scatter marker per data set; all corruption levels share the triangle.
marker = {
    **{lvl: '^' for lvl in lvls},
    'v2': 'x',
    'test': 'o',
    'sketch': '*',
}

In [None]:
# Accuracy vs. NAURC for each data set: MSP baseline on the left panel,
# optimized confidence on the right one.
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True, sharex=True)
fig.tight_layout(w_pad=-1)

for lvl in shifts + list(lvls):
    # ImageNet Sketch is left out of this figure.
    if lvl != 'sketch':
        scatter = axes[0].scatter(acc_mean[lvl], naurc_baseline_mean[lvl],
                                  marker=marker[lvl], label=labels[lvl])
        axes[1].scatter(acc_mean[lvl], naurc_opt_mean[lvl],
                        marker=marker[lvl], label=labels[lvl])

axes[0].set_title('Baseline', fontsize=15)
axes[1].set_title('Optimized', fontsize=15)
axes[0].set_ylabel('NAURC', fontsize=15)
for ax in axes:
    ax.grid()
    ax.set_xlabel('Accuracy', fontsize=15)
    ax.tick_params(axis='both', labelsize=13)
axes[1].legend(prop={'size': 13})

plt.savefig(os.path.join(FIGS_PATH, 'NAURC_shift.pdf'),
            format='pdf', bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# NAURC on the ImageNet test split vs. NAURC on ImageNetV2, one point per model,
# for both the MSP baseline and the optimized confidence.
plt.figure(figsize = (5,4))
plt.scatter(naurc_baseline_mean['test'],naurc_baseline_mean['v2'], c = 'red', marker = 'x', alpha = 0.5, label = 'Baseline')
plt.scatter(naurc_opt_mean['test'],naurc_opt_mean['v2'], c = 'blue', marker = 'o', label = 'Optimized')

# Identity line for reference.
plt.plot([0.15,0.45],[0.15,0.45],'k--')
plt.xlabel('NAURC ImageNet (IID)',fontsize = 15)
plt.ylabel('NAURC ImageNetV2',fontsize = 15)
plt.grid()

plt.legend(prop={'size': 10})
plt.tick_params(axis='both',  labelsize=13)
plt.savefig(os.path.join(FIGS_PATH,'NAURC_consistency.pdf'),format = 'pdf', bbox_inches='tight', transparent = True)
plt.show()

# Correlation computed over the optimized and baseline points pooled together.
print('Pearson Correlation = ', pearsonr(np.r_[naurc_opt_mean['test'],naurc_baseline_mean['test']],np.r_[naurc_opt_mean['v2'],naurc_baseline_mean['v2']]).correlation)
# Optimized points only:
#print('Pearson Correlation = ', pearsonr(naurc_opt_mean['test'],naurc_opt_mean['v2']).correlation)

In [None]:
# NAURC improvement of the optimized confidence over the baseline, per data set.
gains = {
    lvl: naurc_baseline_mean[lvl] - optimized
    for lvl, optimized in naurc_opt_mean.items()
}

In [None]:
from collections import Counter

# Most frequent choice per model across the splits (a value of p, or 'MSP').
p_list_mode = {
    model_name: Counter(choices).most_common(1)[0][0]
    for model_name, choices in p_list.items()
}

In [None]:
# Gain on the ImageNet test split vs. gain on ImageNetV2, one point per model,
# coloured by the model's most frequent choice of p (the MSP fallback is encoded as -1).
plt.figure(figsize = (5,4))
point_colors = [-1 if x=='MSP' else x for x in p_list_mode.values()]
scatter = plt.scatter(gains['test'],gains['v2'], c = point_colors)
# Identity line for reference.
plt.plot([0,0.3],[0,0.3],'k--')
plt.xlabel('Gain ImageNet (IID)',fontsize = 15)
plt.ylabel('Gain ImageNetV2',fontsize = 15)
plt.grid()

# Build the legend from the colour values, querying the scatter only once.
handles, legend_texts = scatter.legend_elements()
legend_texts = ['p = ' + text for text in legend_texts]
# Relabel the first entry as MSP only when some model actually fell back to MSP.
# Assumes the legend entries are sorted and every p is greater than -1 - TODO confirm.
if 'MSP' in p_list_mode.values():
    legend_texts[0] = 'MSP'

legend1 = plt.legend(handles, legend_texts, prop={'size': 10})
plt.tick_params(axis='both',  labelsize=13)
plt.savefig(os.path.join(FIGS_PATH,'gains_v2.pdf'),format = 'pdf', bbox_inches='tight', transparent = True)
plt.show()

print('Pearson Correlation = ', pearsonr(gains['test'],gains['v2']).correlation)