In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda, Compose, Normalize
from collections import defaultdict
from torch.utils.data import random_split
import copy

from NN_utils import *
from NN_utils.train_and_eval import *
from uncertainty import train_NN_with_g
from uncertainty.losses import penalized_uncertainty
import uncertainty.comparison as unc_comp
import uncertainty.quantifications as unc


from torch.utils.data import random_split

In [None]:
# Report whether a GPU is visible and pick the device every later cell moves data/models to.
cuda_available = torch.cuda.is_available()
print(cuda_available)
dev = torch.device('cuda' if cuda_available else 'cpu')

In [None]:
# Per-channel normalization statistics shared by the training and evaluation pipelines.
# NOTE(review): these look like the widely circulated CIFAR-10 statistics rather than
# values computed on CIFAR-100 — confirm this is intentional.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop from a 4-pixel-padded image, random horizontal
# flip, then tensor conversion and normalization.
transforms_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
transforms_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

In [None]:
# CIFAR-100 train and test splits, stored under ./data and downloaded if missing.
# The training split uses the augmenting pipeline; the test split does not.
training_data = datasets.CIFAR100(root="data", train=True, download=True,
                                  transform=transforms_train)
test_data = datasets.CIFAR100(root="data", train=False, download=True,
                              transform=transforms_)

# Disabled: 50/50 train/validation split of the training set.
#train_size = int(0.5*len(training_data))
#val_size = len(training_data) - train_size
#training_data, validation_data = random_split(training_data, [train_size, val_size])

In [None]:
# Split the 100 classes into the first 90 and the last 10.
# Presumably dataset_cut_classes keeps only samples whose class is in `indices` — confirm.
known_classes = range(90)
extra_classes = range(90, 100)

training_data_90 = dataset_cut_classes(training_data, indices=known_classes)
test_data_90 = dataset_cut_classes(test_data, indices=known_classes)
test_data_extra = dataset_cut_classes(test_data, indices=extra_classes)

In [None]:
batch_size = 12        # mini-batch size of the (shuffled) training loaders
eval_batch_size = 100  # batch size of the test loaders

# Training loaders reshuffle every epoch.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
train_dataloader_90 = DataLoader(training_data_90, batch_size=batch_size, shuffle=True)
# No validation loader: the validation split is currently commented out.
#validation_dataloader = DataLoader(validation_data, batch_size=batch_size,shuffle = True)

# Test loaders keep a fixed order (no shuffling).
test_dataloader, test_dataloader_90, test_dataloader_extra = (
    DataLoader(split, batch_size=eval_batch_size)
    for split in (test_data, test_data_90, test_data_extra)
)

In [None]:
class hist_train():
    """Records the loss and accuracy of a model on a dataset, one entry per update.

    Parameters
    ----------
    model : torch.nn.Module
        The model being tracked.
    loss_criterion : callable(output, label) -> scalar tensor
        Assumed to be mean-reduced over the batch (Trainer sets reduction='mean').
    data : iterable of (image, label) batches
        Default dataset used by update_hist (e.g. a DataLoader).
    """
    def __init__(self,model,loss_criterion,data):

        self.model = model
        self.loss_criterion = loss_criterion
        self.data = data

        self.acc_list = []
        self.loss_list = []
        self.bce_iter = []  # kept for compatibility; not populated by this class

    def update_hist(self,data = None):
        """Evaluate the model on `data` (default: self.data) and append ONE aggregated
        accuracy and loss value.

        Loss is the sample-weighted mean of the per-batch (mean-reduced) losses, and
        accuracy is total correct / total samples. The model is put in eval mode during
        evaluation (so BatchNorm running stats are not updated and dropout is off), and
        its previous train/eval mode is restored afterwards. Empty data appends nothing.
        """
        if data is None:
            data = self.data

        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        total_correct = 0.0
        n_seen = 0
        try:
            with torch.no_grad():
                for image,label in data:
                    image,label = image.to(dev),label.to(dev)
                    output = self.model(image)
                    batch_n = label.size(0)
                    # Undo the per-batch mean so batches of different sizes weigh correctly.
                    total_loss += self.loss_criterion(output,label).item() * batch_n
                    # correct_total is divided by the batch size elsewhere, i.e. it is a count.
                    total_correct += float(correct_total(output,label))
                    n_seen += batch_n
        finally:
            self.model.train(was_training)

        if n_seen == 0:
            return
        self.acc_list.append(total_correct / n_seen)
        self.loss_list.append(total_loss / n_seen)

    def __str__(self):
        """Short summary of the most recent recorded loss/accuracy."""
        name = type(self).__name__
        if not self.loss_list:
            return f"{name}(no history)"
        return (f"{name}(n={len(self.loss_list)}, "
                f"loss={float(self.loss_list[-1]):.4f}, "
                f"acc={float(self.acc_list[-1]):.4f})")
            
                
class hist_train_g(hist_train):
    """History recorder for models that also produce a g score.

    In addition to the loss/accuracy lists of hist_train, each update appends the mean
    of g and, when c > 0, two coverage-based accuracies (one ranked by 1 - g, one by
    unc.MCP_unc). `c` is passed to unc_comp.acc_coverage; presumably a coverage
    fraction — confirm.
    """
    def __init__(self,model,loss_criterion,data,c = 0):
        super().__init__(model,loss_criterion,data)

        self.c = c
        self.g_list = []
        # Always defined (left empty when c <= 0) so callers can read them unconditionally.
        self.acc_c_g = []
        self.acc_c_mcp = []

    def update_hist(self,data = None):
        """Evaluate on `data` (default: self.data) and append one entry per tracked metric."""
        if data is None:
            data = self.data

        with torch.no_grad():
            # accumulate_results returns labels, model outputs and g over the whole dataset.
            label, output, g = accumulate_results(self.model,data)
            loss = self.loss_criterion(output,label).item()
            # Outputs are exponentiated before accuracy/MCP; presumably they are
            # log-probabilities (the criterion is an NLLLoss copy in this notebook).
            output = torch.exp(output)
            acc = correct_total(output,label)/label.size(0)

            self.g_list.append(torch.mean(g).item())
            self.acc_list.append(acc)
            self.loss_list.append(loss)

            if self.c>0:
                # MCP is only needed for the coverage comparison, so compute it lazily.
                mcp = unc.MCP_unc(output)
                self.acc_c_g.append(unc_comp.acc_coverage(output,label,1-g,self.c))
                self.acc_c_mcp.append(unc_comp.acc_coverage(output,label,mcp,self.c))
            
        
    
class Trainer():
    """Epoch-wise trainer that logs training and validation history.

    Parameters
    ----------
    model : torch.nn.Module
    optimizer : torch optimizer over the model's parameters
    loss_criterion : wrapper object exposing a `.criterion` attribute (the per-sample
        criterion); a mean-reduced deep copy of it is used for history logging.
    training_data, validation_data : iterables of batches tracked by the histories.
    """
    def __init__(self,model,optimizer,loss_criterion,training_data,validation_data):
        super().__init__()

        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_criterion

        # Both histories share one mean-reduced copy; the original criterion is untouched.
        logging_criterion = copy.deepcopy(loss_criterion.criterion)
        logging_criterion.reduction = 'mean'
        self.hist_train = hist_train(model, logging_criterion, training_data)
        self.hist_val = hist_train(model, logging_criterion, validation_data)

    def fit(self,data,n_epochs):
        """Train on `data` for `n_epochs` epochs, updating both histories after each epoch."""
        epoch_criterion = copy.deepcopy(self.loss_fn.criterion)
        epoch_criterion.reduction = 'mean'
        for _ in range(n_epochs):
            train_NN(self.model, self.optimizer, data, epoch_criterion, 1, set_train_mode=True)
            self.hist_train.update_hist()
            self.hist_val.update_hist()
            
            
            
class Trainer_with_g(Trainer):
    """Trainer for models with an auxiliary g output, trained via train_NN_with_g.

    Replaces the base-class histories with hist_train_g recorders; only the validation
    recorder receives the coverage parameter `c`. `loss_fn` must expose `.criterion`
    and `update_L0(loss)`.
    """
    def __init__(self,model,optimizer,loss_fn,training_data,validation_data, c = 0.2):
        super().__init__(model,optimizer,loss_fn,training_data,validation_data)

        # Mean-reduced copy of the wrapped criterion, used only for history logging.
        loss_criterion = copy.deepcopy(loss_fn.criterion)
        loss_criterion.reduction = 'mean'

        self.hist_train = hist_train_g(model,loss_criterion,training_data)
        self.hist_val = hist_train_g(model,loss_criterion,validation_data,c=c)

    def fit_all(self,data,n_epochs):
        """Train all parameters for `n_epochs` epochs.

        After each epoch both histories are updated and the latest *training* loss is
        passed to loss_fn.update_L0.
        """
        unfreeze_params(self.model)
        for _ in range(n_epochs):
            train_NN_with_g(self.model,self.optimizer,data,self.loss_fn,n_epochs=1, print_loss = True,set_train_mode = True)
            self.hist_train.update_hist()
            self.hist_val.update_hist()
            self.loss_fn.update_L0(self.hist_train.loss_list[-1])

    def fit_g(self,data,n_epochs,ignored_layers = None):
        """Train with the layers named in `ignored_layers` handled by ignore_layers
        (presumably frozen — confirm), for `n_epochs` epochs.

        Parameters
        ----------
        ignored_layers : list of str, optional
            Defaults to ['main_layer', 'classifier_layer'].

        The latest *validation* loss is passed to loss_fn.update_L0 before training and
        after each epoch. If no validation loss has been recorded yet, it is measured
        first (this appends one entry to hist_val).
        """
        if ignored_layers is None:
            # Fresh list per call instead of a shared mutable default argument.
            ignored_layers = ['main_layer','classifier_layer']

        if not self.hist_val.loss_list:
            # Avoid IndexError when fit_g runs before any history was recorded.
            self.hist_val.update_hist()
        self.loss_fn.update_L0(self.hist_val.loss_list[-1])

        ignore_layers(self.model,ignored_layers, reset = True)
        for _ in range(n_epochs):
            train_NN_with_g(self.model,self.optimizer,data,self.loss_fn,n_epochs=1, print_loss = True,set_train_mode = False)
            self.hist_train.update_hist()
            self.hist_val.update_hist()
            ignore_layers(self.model,ignored_layers, reset = False)
            self.loss_fn.update_L0(self.hist_val.loss_list[-1])

In [None]:

from NN_models.CIFAR100 import Model_CNN_100_with_g
from NN_models import Model_CNN_with_g

# Build both architectures on the selected device. The training cells below create
# fresh instances under the same names.
model_100 = Model_CNN_100_with_g()
model_100.to(dev)
model_90 = Model_CNN_with_g(90)
model_90.to(dev);

In [None]:
# Train the 100-class model with its g head on all CIFAR-100 classes.
# There is no validation split (it is commented out where the datasets are loaded),
# so the test loader is used for the "validation" history. The previously referenced
# `validation_dataloader` is never defined and raised NameError.
# NOTE: hist_val therefore reports test-set performance.
model_100 = Model_CNN_100_with_g()
model_100.to(dev)
optimizer = torch.optim.SGD(model_100.parameters(), lr=1e-3)

# Per-sample NLL ('none' reduction), wrapped by penalized_uncertainty for training.
loss_criterion = nn.NLLLoss(reduction = 'none')
loss_fn = penalized_uncertainty(loss_criterion,np.log(10))

model_trainer_100 = Trainer_with_g(model_100,optimizer,loss_fn, train_dataloader,test_dataloader,c = 0.2)
model_trainer_100.fit_all(train_dataloader,80)  # 80 epochs
acc, g, bce = model_metrics(model_100,loss_criterion,train_dataloader)
print('Conjunto de treinamento: acc = ', acc, 'média de g = ', g, 'média de bce = ', bce, '\n')
acc, g, bce = model_metrics(model_100,loss_criterion,test_dataloader)
print('Conjunto de teste: acc = ', acc, 'média de g = ', g, 'média de bce = ', bce, '\n')

In [None]:
# Train the 90-class model only on the first 90 classes. The original cell used the
# full 100-class loaders, whose labels 90-99 are out of range for a 90-way NLLLoss
# (presuming Model_CNN_with_g(90) emits 90 class scores). It also referenced the
# undefined `validation_dataloader`. The in-distribution test loader serves as the
# validation history, so hist_val reports test-set performance.
model_90 = Model_CNN_with_g(90)
model_90.to(dev);
optimizer = torch.optim.SGD(model_90.parameters(), lr=1e-3)

# Per-sample NLL ('none' reduction), wrapped by penalized_uncertainty for training.
loss_criterion = nn.NLLLoss(reduction = 'none')
loss_fn = penalized_uncertainty(loss_criterion,np.log(10))

model_trainer_90 = Trainer_with_g(model_90,optimizer,loss_fn, train_dataloader_90,test_dataloader_90,c = 0.2)
model_trainer_90.fit_all(train_dataloader_90,80)  # 80 epochs
acc, g, bce = model_metrics(model_90,loss_criterion,train_dataloader_90)
print('Conjunto de treinamento: acc = ', acc, 'média de g = ', g, 'média de bce = ', bce, '\n')
acc, g, bce = model_metrics(model_90,loss_criterion,test_dataloader_90)
print('Conjunto de teste: acc = ', acc, 'média de g = ', g, 'média de bce = ', bce, '\n')