# Refinement and G-Prior

In [1]:
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
import numpy as np
import torch
import copy
import plotly.graph_objects as go
import math

import torch.nn as nn
import torch.nn.functional as F

from torch.distributions.normal import Normal
from torch.nn.utils import parameters_to_vector

from data.uci_datasets import UCIData
from main import set_seed, get_device
from models.nets import create_mlp
from trainer import ModelTrainer, NegativeLogLikelihood

from laplace import Laplace

from backpack import backpack, extend
from backpack.extensions import BatchGrad

In [3]:
from hydra import initialize, compose
from omegaconf import OmegaConf
try:
    initialize(version_base=None, config_path="configuration")
except Exception as e:
    print(e)
config = compose(config_name="uci.yaml")
set_seed(config.seed)

## Refined Laplace

### Helper functions

In [4]:
def flatten_nn(model):
    weights = []
    for name, param in model.named_parameters():
        weights.append(param.detach().flatten())
    return torch.cat(weights, dim=0)

### RefinedLaplace Module

In [5]:
class RefinedLaplace(nn.Module):
    def __init__(self, model, output_dim, posterior_covariance):
        super(RefinedLaplace, self).__init__()
        self.model = model
        self.output_dim = output_dim
        self.weights_map = flatten_nn(self.model)
        self.weights = torch.nn.Parameter(torch.clone(self.weights_map), requires_grad=True)
        self.posterior_covariance = posterior_covariance

    def get_parameters(self):
        return self.weights

    def forward(self, X):
        with torch.no_grad():
            f = self.model(X)
        J = self._jacobian(X) 
        out = torch.einsum("ijk,k->ij", J, (self.weights - self.weights_map)) + f
        return out
    
    def predict(self, X):
        with torch.no_grad():
            f = self.model(X)
        J = self._jacobian(X)
        mean = torch.einsum("ijk,k->ij", J, (self.weights - self.weights_map)) + f
        return mean, self._functional_variance(J)
        
    def _functional_variance(self, Js):
        return torch.einsum('ncp,pq,nkq->nck', Js, self.posterior_covariance, Js)
 

    def _jacobian(self, X):
        """
        Compute the jacobian of the model with respect to the input X
        Args:
            X: input tensor
        Returns:
            J: jacobian of the model with respect to X
        """
        model = copy.deepcopy(self.model)
        model.eval()
        model = extend(model)
        Js = []
        for o in range(self.output_dim):
            f = model(X)
            f_o = f.sum(dim=0)[o]

            with backpack(BatchGrad()):
                f_o.backward()
            Jo = []
            for name, param in model.named_parameters():    
                batch_size = param.grad_batch.size(0)
                Jo.append(param.grad_batch.reshape(batch_size, -1))
            Jo = torch.cat(Jo, dim=1)
            Js.append(Jo)
        return torch.stack(Js, dim=1)


### Trainer for refinement

In [6]:
def log_likelihood(y, mu, std):
    dist = Normal(mu.squeeze(), std.squeeze())
    log_probs = dist.log_prob(y.squeeze())
    return log_probs.squeeze().sum().item()

def evaluate_predictive(model, sigma, dataloader, device):
    ll = 0.0
    count = 0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        f_mu, f_var = model.predict(X)
        f_sigma = torch.sqrt(f_var)
        pred_std = torch.sqrt(f_sigma**2 + sigma**2)
        ll += log_likelihood(y, f_mu, pred_std)
        count += X.shape[0]
    return -ll / count

def evaluate(model, sigma, dataloader, device):
    model.eval()
    criteria = NegativeLogLikelihood(sigma=sigma).to(device)
    err = 0.0
    nll = 0.0
    count = 0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        batch_size = X.shape[0]
        out = model(X)
        loss = criteria(out, y).mean()
        err += F.mse_loss(out, y, reduction="mean").sqrt().item() * batch_size
        nll += loss.item() * batch_size
        count += batch_size

    nll = nll / count
    err = err / count

    return nll, err

def train(model, sigma, delta, train_dataloader, val_dataloader, epochs, lr, device):
    criteria = NegativeLogLikelihood(sigma=sigma).to(device)
    delta = delta
    theta = model.get_parameters()
    optimizer = torch.optim.Adam(
        [model.get_parameters()], lr=lr
    )

    best_val_nll = math.inf
    best_model = copy.deepcopy(model)
    best_epoch = 0
    for i in range(epochs):
        epoch_err = 0.0
        epoch_nll = 0.0
        count = 0
        model.train()
        for X, y in train_dataloader:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss = criteria(out, y).sum() + (0.5 * (delta * theta) @ theta) 
            loss.backward()
            optimizer.step()
            batch_size = X.shape[0]
            epoch_err += (
                F.mse_loss(out, y, reduction="mean").sqrt().item() * batch_size
            )
            epoch_nll += loss 
            count += batch_size

        epoch_nll = epoch_nll / count
        epoch_err = epoch_err / count
        val_nll, val_err = evaluate(model, sigma, val_dataloader, device)
        if (i + 1) % 100 == 0:
            print(f"Epoch {i} | Train NLL {epoch_nll} | Val NLL {val_nll} | Train Err {epoch_err} | Val Err {val_err}")
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_model = copy.deepcopy(model)
            best_epoch = i
    
    print(f"Best epoch {best_epoch} | Val NLL {best_val_nll}")       
    return best_model

### Train a MAP model

In [7]:
data = UCIData(config.data.path)
meta_data = data.get_metadata()
device = get_device()
train_dataloader, val_dataloader, test_dataloader = data.get_dataloaders(
        dataset=config.data.name,
        batch_size=config.trainer.batch_size,
        seed=config.data.seed,
        val_size=config.data.val_size,
        split_index=config.data.split_index,
        gap=(config.data.split == "GAP"),
    )
trainer = ModelTrainer(config.trainer, device=device)
  

model = create_mlp(
        input_size=meta_data[config.data.name]["input_dim"],
        hidden_sizes=config.model.hidden_sizes,
        output_size=meta_data[config.data.name]["output_dim"],
    )
model = model.to(device=device, dtype=torch.float64)
map_model, sigma = trainer.train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )
print(f"Sigma: {sigma}")

Sigma: 0.6844667196273804


### Train Full Laplace

In [8]:
prior_precisions = np.logspace(0.1, 1, num=5, base=10).tolist()[:-1]  + np.logspace(1, 2, num=10, base=10).tolist()
model_copy = copy.deepcopy(map_model)
la, prior_precision = trainer.train_la_posthoc(
                model=model_copy,
                dataloader=train_dataloader,
                subset_of_weights="all",
                hessian_structure="full",
                sigma_noise=sigma,
                prior_mean=config.trainer.la.prior_mean,
                val_dataloader=val_dataloader,
                prior_precisions=prior_precisions
            )
posterior_covariance = la.posterior_covariance
print(f"Prior precision: {prior_precision}")
test_nll = trainer.evaluate_la(la, test_dataloader)
print(f"Test NLL of Full Laplace Model: {test_nll}")


Prior precision: 12.91549665014884
Test NLL of Full Laplace Model: 1.1229294272116894


In [9]:
refined_model = RefinedLaplace(model=map_model,
                               output_dim=meta_data[config.data.name]["output_dim"],
                               posterior_covariance=posterior_covariance)
refined_model = train(model=refined_model,
                      sigma=sigma,
                      delta=prior_precision,
                      train_dataloader=train_dataloader,
                      val_dataloader=val_dataloader,
                      epochs=1000,
                      lr=1E-3,
                      device=device)

Epoch 99 | Train NLL 1.4743169806260639 | Val NLL 1.244679601421234 | Train Err 0.8205121488803585 | Val Err 0.8112251779832637
Epoch 199 | Train NLL 1.4773469885959392 | Val NLL 1.2425784576949006 | Train Err 0.8255731062484511 | Val Err 0.8050657749520898
Epoch 299 | Train NLL 1.470458523124189 | Val NLL 1.2437502532741185 | Train Err 0.8200930783918375 | Val Err 0.8036910961462111
Epoch 399 | Train NLL 1.4733108158127943 | Val NLL 1.2444934330380664 | Train Err 0.8228016941348018 | Val Err 0.8084724552324211
Epoch 499 | Train NLL 1.4762731391387598 | Val NLL 1.2419099218647964 | Train Err 0.8270678294756338 | Val Err 0.7857737089917228
Epoch 599 | Train NLL 1.475428097998921 | Val NLL 1.2451238848282793 | Train Err 0.8244963341500053 | Val Err 0.8009532876734328
Epoch 699 | Train NLL 1.4767225985579688 | Val NLL 1.2473597660305986 | Train Err 0.8299873416660049 | Val Err 0.8048524221706379
Epoch 799 | Train NLL 1.4722817645542527 | Val NLL 1.2429215914473866 | Train Err 0.8218057859

In [10]:
test_nll = evaluate_predictive(model=refined_model, sigma=sigma, dataloader=test_dataloader, device=device)
print(f"Test NLL of Refined Laplace Model: {test_nll}")

Test NLL of Refined Laplace Model: 1.1130633958749314
