# Libs and pre-definitions

In [None]:
import os

# Root of the local repository checkout. The original machine-specific location is
# kept as the default so existing setups keep working; set the
# UNCEST_REPOSITORY_PATH environment variable to run the notebook elsewhere
# without editing this cell.
REPOSITORY_PATH = os.environ.get('UNCEST_REPOSITORY_PATH', r'/home/luis-felipe/UncEst')

# Directory handed to the dataset constructor as `data_dir`.
DATA_PATH = os.path.join(REPOSITORY_PATH, 'data')
#CORRUPTED_DATA_PATH = os.path.join(DATA_PATH,'corrupted')

# Root directory for saved model weights, and a sub-directory for trainer state.
PATH_MODELS = os.path.join(REPOSITORY_PATH, 'torch_models')
PATH_TRAINER = os.path.join(PATH_MODELS, 'trainer')

### Standard Python libraries and PyTorch utilities

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from torch import nn

In [None]:
# Select the compute device: the CUDA GPU when one is available, otherwise the CPU.
# The availability flag is printed so the chosen device is visible in the cell output.
has_cuda = torch.cuda.is_available()
print(has_cuda)
dev = torch.device('cuda' if has_cuda else 'cpu')

### In-house libraries

https://github.com/lfpc/Uncertainty_Estimation

In [None]:
import NN_models as models
import NN_utils as utils
import NN_utils.train_and_eval as TE
import torch_data
from uncertainty import MCP_unc, entropy
from uncertainty import metrics

# Train classifier

## Definitions

In [None]:
# Create the directory used to save model weights (weights_path) when it is missing.
CREATE_DIR = True

# True: plot the loss live while training. 'print': print the loss once per epoch.
LIVE_PLOT = True

# Save (and keep updating) the weights of the best epoch, i.e. the one with the
# smallest validation loss.
SAVE_CHECKPOINT = True

# Save the weights and the trainer once training has finished.
SAVE_ALL = False

In [None]:
# Architecture name, looked up by name in NN_models further down.
# Names listed by the author: WideResNet, ResNet101, ResNet18, ResNet34, ResNet50, CNN8.
MODEL_ARC = 'ResNet50'

# Dataset name, looked up by name in torch_data further down.
DATA = 'Cifar100'

# Identifier of this experiment: "<architecture>_<dataset>_g".
NAME = '_'.join((MODEL_ARC, DATA, 'g'))

In [None]:
# Validation split size, forwarded to the dataset as 'validation_size'.
VAL_SIZE = 0.1

# Classification loss used for training.
loss_criterion = nn.CrossEntropyLoss()

# Optional extra risk functions to track during training (disabled).
risk_dict = None#{'selective_risk_mcp':  lambda x,label: unc_comp.selective_risk(x,label,unc_type = unc.MCP_unc)}

# Parameters handed to the dataset constructor as `params`.
data_params = dict(
    train_batch_size=128,
    validation_size=VAL_SIZE,
    test_batch_size=100,
)

In [None]:
# Resolve the dataset class by name from the torch_data module and instantiate it.
dataset_class = vars(torch_data)[DATA]
data = dataset_class(data_dir = DATA_PATH, validation_as_train = True, params = data_params)
num_classes = data.n_classes

# Resolve the model class by name from the NN_models module.
model_class = vars(models)[MODEL_ARC]

# Weights for this architecture/dataset pair live in <PATH_MODELS>/<MODEL_ARC>/<DATA>.
weights_path = os.path.join(PATH_MODELS, MODEL_ARC, DATA)

# Create the weights directory (and any missing parents) only when requested.
if CREATE_DIR:
    os.makedirs(weights_path, exist_ok=True)

## Base Model

In [None]:
# If True, the base model is trained from scratch in this notebook; if False, it is
# built and loaded from existing weights instead (pretrained weights for ImageNet,
# otherwise the weights stored under weights_path).
TRAIN = True

### Upload Base Model

In [None]:
# Only runs when training is disabled: build the base model from existing weights.
if not TRAIN and DATA == 'ImageNet':
    # ImageNet: use the pretrained weights registered for this model class and
    # adopt the test-time transforms that ship with those weights.
    weights = models.pretrained_models[model_class]
    pre_model = model_class(weights = weights).to(dev)
    data.transforms_test = weights.transforms()
elif not TRAIN:
    # Any other dataset: rebuild the architecture and load the weights stored
    # under weights_path.
    pre_model = model_class(num_classes = data.n_classes).to(dev)
    saved_state = utils.upload_weights(weights_path)
    pre_model.load_state_dict(saved_state)

### Train Base Model

In [None]:
# Number of epochs used to train the base model (also the cosine schedule length).
N_EPOCHS_0 = 200
if TRAIN:
    loss_criterion = nn.CrossEntropyLoss()
    pre_model = model_class(num_classes = data.n_classes).to(dev)
    # round() instead of int(): int() truncates, so a float product such as
    # 0.29*100 == 28.999999999999996 would be named "Val28" instead of "Val29".
    pre_model.name = f'{MODEL_ARC}_{DATA}_Val{round(VAL_SIZE*100)}'
    optimizer = torch.optim.SGD(pre_model.parameters(), lr =0.1,momentum = 0.9,weight_decay = 5e-4,nesterov = True)
    # Cosine annealing of the learning rate over the whole training run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS_0)
    # NOTE(review): data.test_dataloader is passed as the fifth positional argument.
    # If that slot is the validation set used to pick the best checkpoint
    # (SAVE_CHECKPOINT), the test set leaks into model selection — confirm against
    # TE.Trainer's signature and consider data.validation_dataloader instead.
    model_trainer = TE.Trainer(pre_model,optimizer,loss_criterion, None,data.test_dataloader,lr_scheduler = scheduler)
    model_trainer.fit(data.train_dataloader,N_EPOCHS_0, live_plot = LIVE_PLOT,save_checkpoint = SAVE_CHECKPOINT,PATH = weights_path)

### Base Model

In [None]:
# Put the base model in evaluation mode and report its accuracy on each data split,
# in the order train -> validation -> test.
pre_model.eval()
for split_attr, message in (
    ('train_dataloader', 'Conjunto de treinamento: acc = '),
    ('validation_dataloader', 'Conjunto de validação: acc = '),
    ('test_dataloader', 'Conjunto de teste: acc = '),
):
    acc = TE.model_acc(pre_model, getattr(data, split_attr))
    print(message, acc)

## Temperature Analysis

### Model Definition

In [None]:
# Calibration helpers from the in-house `uncertainty` package: the Platt_Model
# wrapper and the expected-calibration-error loss.
from uncertainty.calibration import Platt_Model, _ECELoss

# Wrap the base model in a Platt_Model.
# NOTE(review): `model` is not referenced again in the visible part of this notebook
# (the temperature-training section builds its own `model_temperature`) — confirm
# whether this cell is still needed.
model = Platt_Model(pre_model)

In [None]:
# Grid from 0.05 up to (about) 1.0 in steps of 0.05; passed as `c_list` to
# metrics.RC_curve and used as the x-axis of the RC plots below.
c_list = np.arange(0.05, 1.05, 0.05)

# Collect the base model's outputs and the matching labels over the test set once,
# so the sweep below can reuse them without re-running the network.
test_results = TE.accumulate_results(pre_model, data.test_dataloader)
output, label = test_results

### Results

In [None]:
# Sweep the coolness beta: rescale the stored test outputs by beta, apply softmax,
# and for each beta record the RC curve obtained with MCP and with entropy as the
# uncertainty measure, plus the ECE of the rescaled outputs.
fig,axs = plt.subplots(1,2,figsize=(12, 8))
# Fine grid (step 0.05) below 1 and a coarse grid (step 0.5) from 1 up to 4.5.
beta_range = np.append(np.arange(0.1,1,0.05),np.arange(1,5,0.5))
ECE = _ECELoss(n_bins = 10)
# Curves are stored per beta so later cells can look them up by key.
RC_mcps = dict.fromkeys(beta_range)
RC_entropys = dict.fromkeys(beta_range)
eces = []  # one ECE value per beta, in beta_range order
for beta in beta_range:
    # Rescale once and reuse for both the softmax and the ECE computation.
    scaled_output = output*beta
    y_pred = torch.nn.functional.softmax(scaled_output,dim=-1)
    RC_mcp = metrics.RC_curve(y_pred,label,MCP_unc(y_pred),c_list = c_list)
    RC_mcps[beta] = RC_mcp
    RC_entropy = metrics.RC_curve(y_pred,label,entropy(y_pred),c_list = c_list)
    RC_entropys[beta] = RC_entropy
    # Two decimals: the grid step is 0.05, so one decimal would give neighbouring
    # betas (e.g. 0.15 and 0.20) the same legend label.
    curve_label = f'{beta:.2f}'
    axs[0].plot(c_list,RC_mcp,label = curve_label)
    axs[1].plot(c_list,RC_entropy,label = curve_label)
    eces.append(ECE(scaled_output,label).item())

axs[0].set_title('MCP')
axs[1].set_title('Entropy')
for ax in axs:
    ax.grid()
# Both panels plot the betas in the same order (hence the same colours), so a
# single legend on the right-hand panel serves both.
axs[1].legend(title=r'$\beta$')
plt.show()

In [None]:
from ipywidgets import interactive
%matplotlib inline
fig = plt.figure()

x = np.linspace(0, 2 * np.pi)
def plot_RC_widget(beta=1.0):
    plt.plot(c_list,RC_mcps[beta],label = 'MCP')
    plt.plot(c_list,RC_entropys[beta], label = 'Entropy')
    
    fig.canvas.draw_idle()
    plt.grid()
    plt.legend()
    plt.show()

interactive_plot = interactive(plot_RC_widget)
interactive_plot

In [None]:
# For every beta, sum the absolute gap between the MCP and the entropy RC curves
# over the points of the curve (a plain, unscaled sum), then plot the gap against beta.
difs = [
    torch.abs(RC_mcps[beta] - RC_entropys[beta]).sum().item()
    for beta in beta_range
]

plt.plot(beta_range, difs)
plt.xlabel('Beta (coolness)')
plt.title(r'$\int$ |RC(MCP)-RC(Entropy)|')
plt.grid()
plt.show()


### Calibration

In [None]:
# Show the ECE recorded during the sweep twice: against the coolness beta (left)
# and against the temperature T = 1/beta (right).
fig,(ax0,ax1) = plt.subplots(1,2,figsize=(10, 6))
panels = (
    (ax0, beta_range, r'Coolness ($\beta$)'),
    (ax1, 1/beta_range, 'Temperature (T)'),
)
for axis, x_values, panel_title in panels:
    axis.plot(x_values, eces)
    axis.set_ylabel('ECE_Loss')
    axis.set_title(panel_title)
    axis.grid()
plt.show()

In [1]:
# Pick the beta with the smallest recorded ECE and report it, together with the
# corresponding temperature T = 1/beta.
best_index = np.argmin(eces)
calibrated_beta = beta_range[best_index]
calibrated_T = 1/calibrated_beta
print(f'Best (calibration) empirical T = {calibrated_T}')
print(r'Best (calibration) empirical $\beta$ = ', str(calibrated_beta))

Best (calibration) empirical T = 1.0
Best (calibration) empirical $\beta$ =  1


### Temperature training

In [None]:
# Loss minimised while fitting the temperature.
loss_criterion = nn.CrossEntropyLoss()

# Number of epochs for the temperature fit. This rebinds the name that was also
# used for the base-model epoch count.
N_EPOCHS_0 = 50

In [None]:
# Wrap the base model in a fresh Platt_Model that will be fitted below.
model_temperature = Platt_Model(pre_model)
# Presumably disables gradients for every parameter of the base model so that only
# the calibration parameters are trained — confirm in NN_utils.freeze_params.
utils.freeze_params(pre_model)
# Exclude the wrapper's B parameter from optimisation as well.
# NOTE(review): this presumably leaves A as the only trainable parameter — confirm
# against the Platt_Model definition.
model_temperature.B.requires_grad = False

In [None]:
def get_T(model, *args):
    """Return the ``A`` attribute of ``model``.

    Any additional positional arguments are accepted and ignored, so the function
    can be called by code that passes extra values.
    """
    scale = model.A
    return scale

In [None]:
# Plain SGD with a small learning rate over the Platt wrapper's parameters
# (momentum / weight decay / Nesterov from the base-model run are left disabled).
optimizer = torch.optim.SGD(model_temperature.parameters(), lr =0.001)#,momentum = 0.9,weight_decay = 5e-4,nesterov = True)
# No learning-rate schedule is used here; `scheduler` is not passed to the Trainer.
scheduler = None#torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS_0)
# NOTE(review): `risk_dict` receives the function get_T directly, whereas the
# commented-out example elsewhere in this notebook is a dict mapping a name to a
# callable of (output, label). Confirm that TE.Trainer accepts a bare callable here
# and calls it with the model.
model_trainer = TE.Trainer(model_temperature,optimizer,loss_criterion, None,None,risk_dict=get_T)
# The fit runs on the validation split.
# NOTE(review): checkpoints go to the same `weights_path` as the base model —
# verify that this cannot overwrite the base model's saved weights.
model_trainer.fit(data.validation_dataloader,N_EPOCHS_0, live_plot = LIVE_PLOT,save_checkpoint = SAVE_CHECKPOINT,PATH = weights_path)