# Libs and pre-definitions

In [None]:
import os
# Root of the local checkout; every path below is derived from it.
REPOSITORY_PATH = r'/home/luis-felipe/UncEst'

# Data directory inside the repository.
DATA_PATH = os.path.join(REPOSITORY_PATH, 'data')
# CORRUPTED_DATA_PATH = os.path.join(DATA_PATH, 'corrupted')  # disabled by the author

# Directory for model files and, nested inside it, the trainer directory.
PATH_MODELS = os.path.join(REPOSITORY_PATH, 'torch_models')
PATH_TRAINER = os.path.join(PATH_MODELS, 'trainer')

### Standard Python libraries and PyTorch utilities

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from torch import nn

In [None]:
# Select the compute device: CUDA (GPU) when available, otherwise the CPU.
# The availability flag is printed so the notebook output records it.
print(torch.cuda.is_available())
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Project libraries (developed for this work)

https://github.com/lfpc/Uncertainty_Estimation

In [None]:
import NN_models as models
import NN_utils as utils
import NN_utils.train_and_eval as TE
import torch_data
from uncertainty import metrics

# Base Model

## Definitions

In [None]:
# Name of the base architecture; it is looked up by this name in the project's
# model module. Alternatives noted by the author: WideResNet, ResNet101,
# ResNet18, ResNet34, ResNet50, CNN8.
MODEL_ARC = 'ResNet50'
# Name of the dataset; it is looked up by this name in the project's data module.
DATA = 'Cifar100'
# Run identifier of the form "<architecture>_<dataset>_g".
NAME = f'{MODEL_ARC}_{DATA}_g'

In [None]:
# Value passed to the dataset as 'validation_size'.
VAL_SIZE = 0.1

# Parameters handed to the dataset constructor.
data_params = {
    'train_batch_size': 128,
    'validation_size': VAL_SIZE,
    'test_batch_size': 100,
}

In [None]:
# When True the base model is trained from scratch in this notebook; when
# False it is built from previously stored weights instead.
TRAIN = True

In [None]:
# If True, create the directory used to save the model (weights_path) when it is missing.
CREATE_DIR = True
# If True, plot the loss while training. If 'print', print the loss per epoch instead.
LIVE_PLOT = True
# If True, save (and keep updating) the model weights of the best epoch
# (smallest validation loss).
SAVE_CHECKPOINT = True
# If True, save the weights and the trainer at the end of training.
SAVE_ALL = False

In [None]:
# Resolve the dataset class by name inside the project's `torch_data` module
# and instantiate it with the notebook's data settings.
data = vars(torch_data)[DATA](
    data_dir=DATA_PATH,
    validation_as_train=True,
    params=data_params,
)
num_classes = data.n_classes

# Resolve the architecture class by name inside the project's model module.
model_class = vars(models)[MODEL_ARC]

# Directory where weights for this architecture/dataset pair are stored.
weights_path = os.path.join(PATH_MODELS, MODEL_ARC, DATA, 'Uncertainty_Regressor')

# Create the directory on request; an already existing directory is left alone.
if CREATE_DIR:
    os.makedirs(weights_path, exist_ok=True)

## Base Model

### Upload Base Model

In [None]:
# Runs only when training is skipped (TRAIN is False).
if not TRAIN and DATA == 'ImageNet':
    # ImageNet: use the weights entry registered for this architecture in
    # `models.pretrained_models`, and adopt its transforms for the test data.
    weights = models.pretrained_models[model_class]
    pre_model = model_class(weights=weights).to(dev)
    data.transforms_test = weights.transforms()
elif not TRAIN:
    # Any other dataset: build the architecture and load the state dict
    # returned by `utils.upload_weights` for `weights_path`.
    pre_model = model_class(num_classes=data.n_classes).to(dev)
    pre_model.load_state_dict(utils.upload_weights(weights_path))

### Train Base Model

In [None]:
# Number of epochs used to train the base model (also the cosine schedule length).
N_EPOCHS_BASE = 200
# Loss used to train the base classifier.
loss_criterion = nn.CrossEntropyLoss()
# No extra risk functions at this stage. Example of a non-empty value left by the author:
# {'selective_risk_mcp':  lambda x,label: unc_comp.selective_risk(x,label,unc_type = unc.MCP_unc)}
risk_dict = None

In [None]:

if TRAIN:
    # Fresh classifier trained with cross-entropy.
    loss_criterion = nn.CrossEntropyLoss()
    pre_model = model_class(num_classes=data.n_classes).to(dev)

    # SGD with Nesterov momentum; the learning rate follows a cosine schedule
    # spanning the whole training run.
    optimizer = torch.optim.SGD(
        pre_model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        nesterov=True,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=N_EPOCHS_BASE,
    )

    # The two positional `None` arguments are passed through unchanged; their
    # meaning is defined by the project's `TE.Trainer`.
    model_trainer = TE.Trainer(
        pre_model,
        optimizer,
        loss_criterion,
        None,
        None,
        lr_scheduler=scheduler,
    )
    model_trainer.fit(
        data.train_dataloader,
        N_EPOCHS_BASE,
        live_plot=LIVE_PLOT,
        save_checkpoint=SAVE_CHECKPOINT,
        PATH=weights_path,
    )

### Evaluate Base Model

In [None]:
# Put the base model in evaluation mode before measuring accuracy.
pre_model.eval()

# Report accuracy on the training, validation and test loaders, in that order.
# The labels are program output and are therefore kept exactly as they were.
for split_label, split_loader in (
    ('Conjunto de treinamento: acc = ', data.train_dataloader),
    ('Conjunto de validação: acc = ', data.validation_dataloader),
    ('Conjunto de teste: acc = ', data.test_dataloader),
):
    acc = TE.model_acc(pre_model, split_loader)
    print(split_label, acc)

# Uncertainty Estimator

In [None]:
# Grid of beta values: steps of 0.05 starting at 0.1 (below 1) followed by
# steps of 0.05 starting at 1 (below 5.1), rounded to two decimals.
beta_range = np.concatenate(
    (np.arange(0.1, 1, 0.05), np.arange(1, 5.1, 0.05))
).round(2)

## Model Definition

In [None]:
from torch.nn.functional import one_hot
class Uncertainty_Estimator(nn.Module):
    """Frozen base classifier followed by a trainable scalar head.

    The logits of the base model are fed to a small MLP that ends in a
    sigmoid, giving one value ``g`` in (0, 1) per sample. ``forward`` returns
    the one-hot encoding of the class index given by ``TE.predicted_class``,
    multiplied by ``g``.

    Args:
        model: Base classifier. It is switched to eval mode and handed to
            ``utils.freeze_params``.
        num_classes: Size of the logit vector produced by ``model``. When
            omitted, the notebook-level ``data.n_classes`` is used, which is
            what the class always did before this argument existed.
    """

    def __init__(self, model, num_classes=None) -> None:
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: read the size from the global dataset.
            num_classes = data.n_classes
        self.Unc_Regressor = nn.Sequential(nn.Linear(num_classes, 200),
                                        nn.ReLU(),
                                        #nn.Dropout(0.3),
                                        nn.Linear(200, 100),
                                        nn.ReLU(),
                                        nn.Linear(100, 100),
                                        nn.ReLU(),
                                        #nn.Dropout(0.3),
                                        nn.Linear(100, 1),
                                        nn.Sigmoid())
        self.name = 'Unc_Estimator'
        self.base_model = model
        self.base_model.eval()
        utils.freeze_params(self.base_model)

    def train(self, mode: bool = True):
        """Set the training mode of the head while keeping the base model in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override a call to ``.train()`` on the estimator would put the frozen
        base model back into training mode (affecting layers such as
        BatchNorm and Dropout) and change the logits the head is trained on.
        """
        super().train(mode)
        self.base_model.eval()
        return self

    def forward(self, x):
        # The base model is only used as a fixed feature extractor.
        with torch.no_grad():
            logits = self.base_model(x)
            y_pred = TE.predicted_class(logits).view(-1)
            y = one_hot(y_pred, logits.size(-1))
        # The head is evaluated outside no_grad so it can be trained.
        g = self.Unc_Regressor(logits)
        # g has a trailing dimension of size 1 and broadcasts over the one-hot row.
        return y*g

class Temp_Model(nn.Module):
    """Wrap a model and multiply its logits by a constant factor ``beta``.

    Args:
        base_model: Module whose output is rescaled.
        beta: Multiplicative factor applied to the logits (default 1.0, which
            leaves them unchanged).
    """

    def __init__(self, base_model, beta=1.0) -> None:
        super().__init__()
        self.beta = beta
        self.base_model = base_model

    def forward(self, x):
        """Return ``beta * base_model(x)``.

        The input ``x`` must be an explicit parameter: ``nn.Module.__call__``
        forwards its arguments to ``forward``, so a signature without ``x``
        fails as soon as the module is called.
        """
        logits = self.base_model(x)
        return self.beta*logits

## Model Training

In [None]:
from uncertainty.metrics import acc_coverage

class Acc_Coverage_AUX():
    """Holder for the latest pair returned by ``TE.accumulate_results``.

    A single instance can be shared by several metrics, so that the pair is
    produced once and then read many times.
    """

    def __init__(self):
        # Placeholders until the first call to `update`.
        self.y_pred = self.labels = 0

    def update(self, model, dataloader):
        """Run ``TE.accumulate_results`` and keep both returned values."""
        results = TE.accumulate_results(model, dataloader)
        self.y_pred, self.labels = results

    def get(self):
        """Return the stored ``(y_pred, labels)`` pair."""
        return (self.y_pred, self.labels)


# Shared instance used by the coverage metrics.
aux = Acc_Coverage_AUX()

class Acc_Coverage(nn.Module):
    """Metric module: ``acc_coverage`` evaluated at one fixed coverage value.

    Args:
        coverage: Coverage value forwarded to ``acc_coverage``.
        aux: Shared results holder exposing ``update(model, dataloader)`` and
            ``get()``.
    """

    def __init__(self, coverage, aux):
        super().__init__()
        self.aux = aux
        self.coverage = coverage

    @torch.no_grad()
    def forward(self, model, dataloader):
        # Only an instance whose coverage is below 0.15 refreshes the shared
        # holder; every other instance reuses whatever the holder contains.
        refresh_holder = self.coverage < 0.15
        if refresh_holder:
            self.aux.update(model, dataloader)
        outputs, targets = self.aux.get()
        # The largest entry along the last dimension is the per-sample score.
        top_score = torch.max(outputs, dim=-1).values
        return torch.tensor(acc_coverage(outputs, targets, top_score, self.coverage))

In [None]:
# Coverage levels: 0.1 up to (but excluding) 1 in steps of 0.1, rounded to one decimal.
c_list = np.arange(0.1, 1, 0.1).round(1)

# One metric per coverage level, keyed "Acc_<coverage in percent>"; all of
# them share the same results holder.
risk_dict = dict(
    (f'Acc_{int(c*100)}', Acc_Coverage(c, aux))
    for c in c_list
)

N_EPOCHS = 50

# Optimizer settings.
optim_params = {
    'lr': 0.1,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'nesterov': True,
}

### BCE Loss

#### Loss definition

In [None]:
class BCELoss(nn.BCELoss):
    """Binary cross-entropy between the largest output entry and ``TE.correct_class``.

    Keyword arguments are forwarded unchanged to ``nn.BCELoss``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, y_pred, y_true):
        # The target is built without tracking gradients. It is the float
        # cast of `TE.correct_class(y_pred, y_true)` - presumably a per-sample
        # correctness indicator; confirm against the project helper.
        with torch.no_grad():
            hits = TE.correct_class(y_pred, y_true).float()
        # The prediction is the largest entry along the last dimension.
        top_score = torch.max(y_pred, -1).values
        return super().forward(top_score, hits)


loss_criterion = BCELoss()

#### Temperature Test

In [None]:
# Wrap the base model in Temp_Model (default beta) and move it to the selected device.
model = Temp_Model(pre_model).to(dev)
# Name attribute attached to the wrapped model for this experiment.
model.name = 'Temperature_BCE'