# Imports and pre-definitions

In [None]:
import os
# Root directory (machine-specific) under which data, saved torch models and
# figures are stored; the other paths are sub-directories of it.
MAIN_PATH = r'/home/luis-felipe'
DATA_PATH, PATH_MODELS, FIGS_PATH = (
    os.path.join(MAIN_PATH, *parts)
    for parts in (('data',), ('torch_models',), ('results', 'figs'))
)

In [None]:
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Select the compute device: CUDA (GPU) when available, otherwise CPU.
# The availability flag is also printed for a quick visual check.
_has_cuda = torch.cuda.is_available()
print(_has_cuda)
dev = torch.device('cuda' if _has_cuda else 'cpu')
# Newly created floating-point tensors default to double precision.
torch.set_default_dtype(torch.float64)
# Seed both torch and numpy RNGs for reproducibility.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
import sys
sys.path.insert(1, '..')
sys.path.insert(1, '../..')

import models
from utils import measures,metrics
from data_utils import upload_logits,split_data
import post_hoc

# Evaluate logits

In [None]:
# Architecture and dataset that select which saved logits `upload_logits` loads.
MODEL_ARC, DATASET = 'wide_resnet50_2', 'ImageNet'

In [None]:
# Load the test-split logits and labels of the selected model/dataset onto `dev`.
logits, labels = upload_logits(
    MODEL_ARC,
    DATASET,
    PATH_MODELS,
    split='test',
    device=dev,
    data_dir=DATA_PATH,
)
# Per-sample risk as given by measures.wrong_class (presumably 1 for a
# misclassified sample, 0 otherwise — confirm in utils.measures).
# NOTE(review): .float() casts to float32 even though the default dtype was set
# to float64 above — confirm this is intended.
risk = measures.wrong_class(logits, labels).float()

In [None]:
# NAURC of three confidence scores, evaluated in this fixed order:
#   naurc_i  -> baseline MSP
#   naurc_p  -> post-hoc MaxLogit-p (fitted with access to `risk`)
#   naurc_pT -> post-hoc MSP-p (fitted with access to `risk`)
# Each score is computed lazily, right before its NAURC, preserving call order.
_score_fns = (
    lambda: measures.MSP(logits),
    lambda: post_hoc.MaxLogit_p(logits, risk=risk),
    lambda: post_hoc.MSP_p(logits, risk=risk),
)
naurc_i, naurc_p, naurc_pT = (metrics.N_AURC(risk, fn()) for fn in _score_fns)

# Histograms

## Baseline

In [None]:
# Density histogram of the baseline MSP confidence scores.
# The baseline has no gain over itself, so the title reports its absolute NAURC
# (naurc_i) rather than a gain belonging to one of the optimized scores.
plt.figure(figsize=(8,6))
plt.hist(measures.MSP(logits).cpu().numpy(),bins = 'auto', density = True)
plt.xlabel('MSP')
plt.ylabel('Density')
plt.title(f'NAURC = {naurc_i:.4f}')
plt.show()

## Optimized

### MSP-p

In [None]:
# Density histogram of the post-hoc optimized MSP-p scores; the title shows the
# NAURC reduction relative to the baseline MSP (naurc_i - naurc_pT).
_ax = plt.figure(figsize=(8, 6)).gca()
_scores = post_hoc.MSP_p(logits, risk=risk).cpu().numpy()
_ax.hist(_scores, bins='auto', density=True)
_ax.set_xlabel('MSP')
_ax.set_ylabel('Density')
_ax.set_title(f'NAURC gain = {naurc_i-naurc_pT:.4f}')
plt.show()

### MaxLogit-p

In [None]:
# Density histogram of the post-hoc optimized MaxLogit-p scores; the title shows
# the NAURC reduction relative to the baseline MSP (naurc_i - naurc_p).
# The x-axis names the score actually plotted (it was mislabelled 'MSP').
plt.figure(figsize=(8,6))
plt.hist(post_hoc.MaxLogit_p(logits,risk = risk).cpu().numpy(),bins = 'auto', density = True)
plt.xlabel('MaxLogit-p')
plt.ylabel('Density')
plt.title(f'NAURC gain = {naurc_i-naurc_p:.4f}')
plt.show()