# Libs and data

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Lambda, Compose, Normalize
from collections import defaultdict
from torch.utils.data import random_split
import copy

In [None]:
# Report whether a CUDA-capable GPU is visible, then use it as the compute
# device if so; otherwise fall back to the CPU.
print(torch.cuda.is_available())
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import NN_models as models
import uncertainty.comparison as unc_comp
import uncertainty.quantifications as unc
import uncertainty.losses as losses
import uncertainty.train_and_eval_with_g as TE_g
import NN_utils as utils
import NN_utils.train_and_eval as TE

In [None]:
# Preprocessing pipelines. Both convert to a tensor and normalise per channel
# with the same statistics; the training pipeline additionally applies
# random-crop (with 4-pixel padding) and horizontal-flip augmentation.
# NOTE(review): the std triple (0.2023, 0.1994, 0.2010) is the widely copied
# CIFAR-10 value; the per-pixel std of the training set is often reported as
# roughly (0.247, 0.243, 0.261) — confirm which one is intended.
transforms_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])
transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

In [None]:
# CIFAR-10 train split (augmenting transform) and test split (deterministic
# transform). Both live under ./data and are downloaded if missing.
training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                 transform=transforms_train)
test_data = datasets.CIFAR10(root="data", train=False, download=True,
                             transform=transforms_test)

# Hold out 5% of the training images for validation. The split draws from the
# global torch RNG, so it differs between runs unless a seed is set earlier.
train_size = int(0.95 * len(training_data))
val_size = len(training_data) - train_size
training_data, validation_data = random_split(
    training_data, lengths=[train_size, val_size])

# Both subsets wrap the same CIFAR10 object. Deep-copying the validation
# subset (which copies the whole underlying dataset, not only its indices)
# lets it switch to the non-augmented test transform while the training
# subset keeps transforms_train.
validation_data = copy.deepcopy(validation_data)
validation_data.dataset.transform = transforms_test

In [None]:
# Mini-batch loaders. Only the training loader shuffles; validation and test
# are iterated in a fixed order. All three use the same batch size (the test
# loader previously hard-coded 100 instead of reading batch_size).
batch_size = 100
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Definitions

In [None]:
def dot_grads(grads_1, grads_2):
    """Return the negative dot product of two 1-D gradient tensors.

    Minimising this value maximises the alignment of the two gradients.
    """
    return -torch.dot(grads_1, grads_2)


class loss_grads(torch.nn.Module):
    """Gradient-alignment loss summed over a model's parameters.

    For every parameter yielded by ``params()`` whose name does not contain
    ``'g_layer'``, the two corresponding gradients are flattened and combined
    with ``criterion``. By default that is ``dot_grads``, the negative dot
    product. The per-parameter results are summed.

    Args:
        params: zero-argument callable yielding ``(name, parameter)`` pairs,
            e.g. a model's bound ``named_parameters`` method. Its order must
            match the order of the gradient sequences given to ``forward``.
            This holds for ``torch.autograd.grad(..., model.parameters())``.
        criterion: callable ``(flat_grad_1, flat_grad_2) -> 0-dim tensor``.
    """

    def __init__(self, params, criterion=dot_grads):
        super().__init__()
        self.criterion = criterion
        self.params = params

    def forward(self, grads_1, grads_2):
        """Sum ``self.criterion`` over index-aligned gradient pairs.

        Args:
            grads_1, grads_2: sequences of gradient tensors (entries may be
                ``None``), index-aligned with ``params()``.

        Returns:
            A 0-dim tensor.

        Pairs in which either gradient is ``None`` are skipped.
        ``torch.autograd.grad(..., allow_unused=True)`` yields ``None`` for a
        parameter the loss does not depend on, i.e. a zero gradient. For the
        default dot-product criterion, skipping it is exactly equivalent.
        """
        # Leaf with requires_grad=True so the result still supports
        # .backward() even if every parameter is skipped.
        loss_grad = torch.tensor(0., requires_grad=True)
        for i, (name, _) in enumerate(self.params()):
            if 'g_layer' in name:
                continue
            g1, g2 = grads_1[i], grads_2[i]
            if g1 is None or g2 is None:
                continue
            # reshape (not view) also accepts non-contiguous gradients.
            loss_grad = loss_grad + self.criterion(g1.reshape(-1), g2.reshape(-1))
        return loss_grad

In [None]:
# Model with a main output (output[0]) and an auxiliary "g" output (output[1]).
# The 10 is presumably the number of CIFAR-10 classes; confirm against NN_models.
# SGD with momentum over all parameters.
model = models.Model_CNN_with_g(10).to(dev)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# NOTE(review): NLLLoss expects log-probabilities, so output[0] is assumed to
# be log-softmax scores — confirm against Model_CNN_with_g.
loss_criterion_0 = nn.NLLLoss()                           # batch-mean NLL (scalar)
loss_criterion_per_sample = nn.NLLLoss(reduction='none')  # one NLL per sample

# Plain mean NLL of the main output.
loss_criterion_1 = lambda x, label: torch.mean(loss_criterion_0(x[0], label))
# Each sample's NLL weighted by its g value, then averaged. The per-sample
# reduction is essential here. With the default 'mean' reduction the product
# collapses to mean(g) * mean(NLL), which ignores which samples g weights.
# Assumes output[1] holds one g value per sample (e.g. shape (batch, 1)) — TODO confirm.
loss_criterion_g = lambda x, label: torch.mean(
    x[1].view(-1) * loss_criterion_per_sample(x[0], label))

# Gradient-alignment loss over the model's named parameters.
loss_fn = loss_grads(model.named_parameters)

# Train

In [None]:
# Alternating training. Each epoch has two phases:
#   1. Train on the plain NLL of output[0].
#   2. Update only the parameters left trainable after
#      utils.ignore_layers(model, ['main_layer', 'classifier_layer'], ...).
#      The objective maximises alignment between the gradient of the
#      g-weighted training loss and the gradient of the validation loss.
# TODO: check whether the remaining retain_graph flags are needed, and whether
# loss_v / grads_v can be computed once outside the inner loop.
utils.unfreeze_params(model)
model.train()
n_epochs = 100
for epoch in range(n_epochs):

    # Phase 1: standard supervised step on every training batch.
    for image, label in train_dataloader:
        image, label = image.to(dev), label.to(dev)
        optimizer.zero_grad()
        output = model(image)
        loss_t = loss_criterion_1(output, label)
        loss_t.backward()
        optimizer.step()

    # Phase 2: gradient-alignment step on every training batch.
    for image, label in train_dataloader:
        image, label = image.to(dev), label.to(dev)
        # set_to_none=True is explicit because it is only the default from
        # PyTorch 2.0. optimizer.step() skips parameters whose .grad is None,
        # so SGD momentum cannot keep moving the layers frozen below.
        # Zero-filled grads would still get a momentum update, unless
        # ignore_layers(reset=True) already clears them — unverified.
        optimizer.zero_grad(set_to_none=True)
        output = model(image)

        # create_graph=True keeps these gradients differentiable so the
        # alignment loss can be back-propagated through them.
        loss_t = loss_criterion_g(output, label)
        grads_t = torch.autograd.grad(loss_t, model.parameters(), retain_graph=True,
                                      create_graph=True, allow_unused=True)

        # Validation loss, recomputed with its own graph on every step.
        # NOTE(review): the model is in train mode here, so BatchNorm/Dropout
        # layers (if any) behave as in training — confirm this is intended.
        loss_v = TE.calc_loss_batch(model, loss_criterion_1, validation_dataloader)
        grads_v = torch.autograd.grad(loss_v, model.parameters(), retain_graph=True,
                                      create_graph=True, allow_unused=True)

        utils.ignore_layers(model, ['main_layer', 'classifier_layer'], reset=True)
        loss = loss_fn(grads_t, grads_v)
        # No retain_graph: every graph used here is rebuilt on the next
        # iteration, so retaining it would only hold memory.
        loss.backward()
        optimizer.step()
        utils.unfreeze_params(model)
        model.train()

    # Values from the last phase-2 batch; loss_t is the g-weighted loss there.
    print(f'Epoch = {epoch}, main_loss = {loss_t.item()}, grad_loss = {loss.item()}')

In [None]:
# Turn on autograd anomaly detection. It only affects code run after this
# cell and slows autograd down.
torch.autograd.set_detect_anomaly(mode=True)

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
# Without CUDA there is nothing to release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Dump every parameter's name followed by its current .grad
# (None if it received no gradient).
for param_name, param in model.named_parameters():
    print(param_name, param.grad, sep='\n')

In [None]:
# Re-evaluate the gradient-alignment loss on the last grads_t / grads_v
# (the notebook displays the resulting tensor).
loss_fn(grads_1=grads_t, grads_2=grads_v)