# Libs and pre-definitions

In [None]:
import os
# Root of the local checkout. The default is the original author's location;
# set the UNCEST_REPOSITORY_PATH environment variable to run the notebook on
# another machine without editing this cell.
REPOSITORY_PATH = os.environ.get('UNCEST_REPOSITORY_PATH', r'/home/luis-felipe/UncEst')
# Directory passed to the dataset wrapper as `data_dir`.
DATA_PATH = os.path.join(REPOSITORY_PATH, 'data')
#CORRUPTED_DATA_PATH = os.path.join(DATA_PATH,'corrupted')

# Root directory for model weights, and a sub-directory named after the trainer.
PATH_MODELS = os.path.join(REPOSITORY_PATH, 'torch_models')
PATH_TRAINER = os.path.join(PATH_MODELS, 'trainer')

### Standard Python libraries and PyTorch utilities

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from torch import nn

In [None]:
# Select the compute device: CUDA (GPU) when it is available, CPU otherwise.
print(torch.cuda.is_available())
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Project libraries (developed in-house)

https://github.com/lfpc/Uncertainty_Estimation

In [None]:
import NN_models as models
import NN_utils as utils
import NN_utils.train_and_eval as TE
import torch_data
from uncertainty import metrics

# Train classifier

## Definitions

In [None]:
# If True, create the directory where model weights are saved (weights_path).
CREATE_DIR = True
# If True, plot the loss while training. If 'print', print the loss per epoch.
LIVE_PLOT = True
# If True, save (and keep updating) the weights of the best epoch
# (smallest validation loss).
SAVE_CHECKPOINT = True
# If True, save weights and trainer at the end of training.
SAVE_ALL = False

In [None]:
# Architecture of the base classifier. Names listed by the author:
# WideResNet, ResNet101, ResNet18, ResNet34, ResNet50, CNN8.
MODEL_ARC = 'ResNet50'
# Name of the dataset wrapper to use.
DATA = 'Cifar100'
# Run identifier built from architecture and dataset.
NAME = '{}_{}_g'.format(MODEL_ARC, DATA)

In [None]:
# Size of the validation split, forwarded to the dataset wrapper via data_params.
VAL_SIZE = 0.1

# Loss used for the base classifier.
loss_criterion = nn.CrossEntropyLoss()
# No additional risk metrics here. The author's earlier example, kept for reference:
# {'selective_risk_mcp':  lambda x,label: unc_comp.selective_risk(x,label,unc_type = unc.MCP_unc)}
risk_dict = None

# Batch sizes and validation size handed to the dataset wrapper.
data_params = dict(
    train_batch_size=128,
    validation_size=VAL_SIZE,
    test_batch_size=100,
)

In [None]:
# Resolve the dataset wrapper and the model constructor by name. getattr is the
# idiomatic way to look up a module attribute (rather than reaching into
# the module's __dict__).
data = getattr(torch_data, DATA)(data_dir=DATA_PATH, validation_as_train=True, params=data_params)
num_classes = data.n_classes
model_class = getattr(models, MODEL_ARC)

# Weights are stored per architecture and per dataset.
weights_path = os.path.join(PATH_MODELS, MODEL_ARC, DATA)

if CREATE_DIR:
    # exist_ok=True replaces the separate isdir() check and removes the
    # check-then-create race; an existing directory is simply left alone.
    os.makedirs(weights_path, exist_ok=True)

## Base Model

In [None]:
# When True, the base classifier is trained in the "Train Base Model" cell;
# when False, it is loaded from existing weights instead.
TRAIN = True

### Load Base Model

In [None]:
# Obtain a ready-made base model instead of training one (only when TRAIN is False).
if not TRAIN and DATA == 'ImageNet':
    # ImageNet: use the pretrained weights registered for this architecture and
    # the preprocessing transforms that come with those weights for test data.
    weights = models.pretrained_models[model_class]
    pre_model = model_class(weights = weights).to(dev)
    data.transforms_test = weights.transforms()
elif not TRAIN:
    # Any other dataset: build the architecture and load the state dict that
    # utils.upload_weights returns for weights_path.
    pre_model = model_class(num_classes = data.n_classes).to(dev)
    pre_model.load_state_dict(utils.upload_weights(weights_path))

### Train Base Model

In [None]:
# Number of epochs for the base classifier; also the cosine schedule length.
N_EPOCHS_0 = 200
if TRAIN:
    # Re-create the classification loss so this cell works even after
    # loss_criterion has been rebound elsewhere in the notebook.
    loss_criterion = nn.CrossEntropyLoss()
    pre_model = model_class(num_classes = data.n_classes).to(dev)
    optimizer = torch.optim.SGD(
        pre_model.parameters(),
        lr = 0.1,
        momentum = 0.9,
        weight_decay = 5e-4,
        nesterov = True,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = N_EPOCHS_0)
    # NOTE(review): the test dataloader is passed as the trainer's second
    # dataset; if checkpoints are selected on it, the reported test accuracy is
    # optimistic -- confirm this is intended.
    model_trainer = TE.Trainer(
        pre_model,
        optimizer,
        loss_criterion,
        None,
        data.test_dataloader,
        lr_scheduler = scheduler,
    )
    model_trainer.fit(
        data.train_dataloader,
        N_EPOCHS_0,
        live_plot = LIVE_PLOT,
        save_checkpoint = SAVE_CHECKPOINT,
        PATH = weights_path,
    )

### Evaluate Base Model

In [None]:
# Report the base model's accuracy on the training, validation and test splits.
pre_model.eval()
for message, loader in (
    ('Conjunto de treinamento: acc = ', data.train_dataloader),
    ('Conjunto de validação: acc = ', data.validation_dataloader),
    ('Conjunto de teste: acc = ', data.test_dataloader),
):
    acc = TE.model_acc(pre_model, loader)
    print(message, acc)

## Uncertainty Estimator

### Model Definition

In [None]:
from torch.nn.functional import one_hot
class Uncertainty_Estimator(nn.Module):
    """Trainable confidence head on top of a frozen base classifier.

    The base model produces logits; a small MLP (`Unc_Regressor`) maps those
    logits to one value in (0, 1) per sample (the last layer is a sigmoid).
    `forward` returns a tensor shaped like the logits that holds that value at
    the position of the predicted class and zeros everywhere else.

    Args:
        model: base classifier. It is frozen via `utils.freeze_params` and kept
            in eval mode for the lifetime of this module.
        n_classes: width of the base model's output, i.e. the input size of the
            regressor. Defaults to `data.n_classes` from the notebook's global
            dataset object, which is what the original code always used.
    """
    def __init__(self, model, n_classes=None) -> None:
        super().__init__()
        if n_classes is None:
            # Backward-compatible default: the notebook-level dataset object.
            n_classes = data.n_classes
        self.Unc_Regressor = nn.Sequential(nn.Linear(n_classes,200),
                                        nn.ReLU(),
                                        #nn.Dropout(0.3),
                                        nn.Linear(200,100),
                                        nn.ReLU(),
                                        nn.Linear(100,100),
                                        nn.ReLU(),
                                        #nn.Dropout(0.3),
                                        nn.Linear(100,1), #globalpooling-max
                                        nn.Sigmoid())
        self.name = 'Unc_Estimator'
        self.base_model = model
        self.base_model.eval()
        utils.freeze_params(self.base_model)

    def train(self, mode: bool = True):
        """Set the training mode of the head while keeping the base model in eval.

        `nn.Module.train` recurses into every sub-module, so without this
        override a `model.train()` call would put the frozen base model back in
        training mode; layers such as BatchNorm and Dropout would then behave
        differently (and BatchNorm would update its running statistics).
        """
        super().train(mode)
        self.base_model.eval()
        return self

    def forward(self,x):
        # The base model is only used for inference: no gradients are needed.
        with torch.no_grad():
            logits = self.base_model(x)
            # Presumably the arg-max class index per sample -- confirm in TE.
            y_pred = TE.predicted_class(logits).view(-1)
            y = one_hot(y_pred,logits.size(-1))
        # Outside no_grad so the regressor receives gradients.
        g = self.Unc_Regressor(logits)
        # Broadcasts the per-sample value onto the predicted-class position.
        return y*g

### Model Training

In [None]:
from uncertainty.calibration import Platt_Model

In [None]:
def hits_labels(data,model,transforms = None):
    """Collect, for every sample in `data`, whether `model` classified it correctly.

    Args:
        data: iterable yielding `(inputs, labels)` batches of tensors.
        model: callable applied to each input batch after it is moved to `dev`.
        transforms: unused; kept so existing callers keep working.

    Returns:
        An int64 tensor with the concatenated per-batch results of
        `TE.correct_class(output, label)` (presumably 1 for a hit and 0 for a
        miss -- confirm in TE). Empty when `data` yields no batches.
    """
    batch_hits = []
    with torch.no_grad():
        for im,label in data:
            im,label = im.to(dev),label.to(dev)
            batch_hits.append(TE.correct_class(model(im),label))
    if not batch_hits:
        # Same result the original produced for an empty dataloader.
        return torch.tensor([],device = dev).to(torch.int64)
    # Concatenate once at the end: calling torch.cat inside the loop recopies
    # all previously gathered results on every batch.
    return torch.cat(batch_hits).to(torch.int64)

from uncertainty.utils import dontknow_mask

class Acc_Coverage(nn.Module):
    """Accuracy on the retained samples for each coverage in a list.

    The score of each sample is the maximum of the model output over the last
    dimension. For every coverage `c`, `dontknow_mask(score, c)` is inverted to
    select samples (presumably the ones that are not rejected at that coverage
    -- confirm in uncertainty.utils), and the accuracy over them is reported.
    """
    def __init__(self,coverage_list):
        super().__init__()
        self.coverage = coverage_list

    def forward(self,model,dataloader):
        """Return a list with one accuracy (a Python float) per coverage value."""
        with torch.no_grad():
            outputs,targets = TE.accumulate_results(model,dataloader)
            score = outputs.max(dim=-1).values
            correct = TE.correct_class(outputs,targets).float()
            return [
                correct[torch.logical_not(dontknow_mask(score, cov).bool())].mean().item()
                for cov in self.coverage
            ]

In [None]:
# Extra metric tracked while training the estimator: accuracy at the coverages
# produced by np.arange(0.1, 1, 0.1).
risk_dict = {'RC_curve': Acc_Coverage(np.arange(0.1, 1, 0.1))}
# Number of epochs for the uncertainty estimator (also the scheduler's T_max).
N_EPOCHS = 50
# Keyword arguments shared by the SGD optimizers created below.
optim_params = dict(lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)

#### BCE Loss

##### Loss definition

In [None]:
class BCELoss(nn.BCELoss):
    """Binary cross-entropy between the model's top score and hit/miss targets.

    The target for each sample is `TE.correct_class(y_pred, y_true)` cast to
    float; the prediction is the maximum of `y_pred` over its last dimension.
    Keyword arguments are forwarded unchanged to `nn.BCELoss`.
    """
    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def forward(self,y_pred,y_true):
        # Targets are constants: keep them out of the autograd graph.
        with torch.no_grad():
            target = TE.correct_class(y_pred,y_true).float()
        top_score = y_pred.max(dim=-1).values
        return super().forward(top_score,target)

loss_criterion = BCELoss()

##### Temperature Test

In [None]:
# Estimator whose optimizer is built over model.parameters(), i.e. every
# parameter the module exposes, with a cosine schedule spanning N_EPOCHS.
# NOTE(review): `model`, `optimizer` and `scheduler` are rebound by the "Train"
# cell that follows, so this setup only takes effect if that cell is skipped.
model = Uncertainty_Estimator(pre_model)
model.name = 'Temperature_BCE'
optimizer = torch.optim.SGD(model.parameters(), **optim_params)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)

##### Train

In [None]:
# Fresh estimator on top of the base model. Only the regressor head's
# parameters are handed to SGD; the learning rate follows a cosine schedule
# spanning N_EPOCHS.
model = Uncertainty_Estimator(pre_model)
model.name = 'Unc_Estimator_BCE'
optimizer = torch.optim.SGD(model.Unc_Regressor.parameters(), **optim_params)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)

In [None]:
# Fit the uncertainty estimator on the validation dataloader; the test
# dataloader is passed as the trainer's second dataset and risk_dict supplies
# the extra metrics recorded during training.
# NOTE(review): checkpoints are written under the same weights_path as the base
# model -- confirm the trainer names files so that the two cannot collide.
model_trainer = TE.Trainer(model,optimizer,loss_criterion, data.validation_dataloader,data.test_dataloader,lr_scheduler = scheduler, risk_dict = risk_dict)
model_trainer.fit(data.validation_dataloader,N_EPOCHS, live_plot = LIVE_PLOT,save_checkpoint = SAVE_CHECKPOINT,PATH = weights_path)

In [None]:
# Plot the per-epoch risk histories recorded by the trainer: the left panel
# shows the data used for fitting (the validation split), the right panel the
# second dataset given to the trainer (the test split).
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,sharey = True,figsize=(16, 6))
panels = (
    (ax1, model_trainer.hist_train.risk, 'Coverage accuracy - Train (val) set'),
    (ax2, model_trainer.hist_val.risk, 'Coverage accuracy - Test set'),
)
for ax, risk_history, title in panels:
    for n,risk in risk_history.items():
        # The last two characters of the metric name are used as the curve label.
        c = n[-2:]
        ax.plot(risk,label = f'{c=}%')
    ax.legend()
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel('Epoch')
# The tracked metric (Acc_Coverage) is an accuracy, not an error rate, so the
# shared y-axis is labelled accordingly.
ax1.set_ylabel('Accuracy')
plt.subplots_adjust(wspace=0.05)
plt.show()

##### Test

In [None]:
#Validation (Train2) data
RC = metrics.selective_metrics(pre_model,data.validation_dataloader)
RC.add_uncs({'Trained Activation':model.Unc_Regressor})
RC.d_uncs['Trained Activation'] = -RC.d_uncs['Trained Activation'].view(-1) #Confidence to uncertainty
RC.plot_RC(optimum = True)

In [None]:
#Test Data
RC = metrics.selective_metrics(pre_model,data.test_dataloader)
RC.add_uncs({'Trained Activation':model.Unc_Regressor})
RC.d_uncs['Trained Activation'] = -RC.d_uncs['Trained Activation'].view(-1) #Confidence to uncertainty
RC.plot_RC(optimum = True)