# Refinement and G-Prior

In [1]:
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
import numpy as np
import torch
import copy
import plotly.graph_objects as go
import math

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import ConcatDataset, DataLoader
from torch.distributions.normal import Normal

from data.uci_datasets import UCIData
from main import set_seed, get_device
from util.plots import plot_data, plot_regression, plot_bayesian_regression
from models.nets import create_mlp
from trainer import ModelTrainer, NegativeLogLikelihood

from laplace import Laplace

from backpack import backpack, extend
from backpack.extensions import BatchGrad

In [3]:
from hydra import initialize, compose
from omegaconf import OmegaConf
try:
    initialize(version_base=None, config_path="configuration")
except Exception as e:
    print(e)
config = compose(config_name="uci.yaml")
set_seed(config.seed)

## Refined Laplace

### Helper functions

In [4]:
def flatten_nn(model):
    weights = []
    for name, param in model.named_parameters():
        weights.append(param.detach().flatten())
    return torch.cat(weights, dim=0)

### RefinedLaplace Module

In [5]:
class RefinedLaplace(nn.Module):
    def __init__(self, model, output_dim, posterior_covariance):
        super(RefinedLaplace, self).__init__()
        self.model = model
        self.output_dim = output_dim
        self.weights_map = flatten_nn(self.model)
        self.weights = torch.nn.Parameter(torch.zeros_like(self.weights_map), requires_grad=True)
        self.posterior_covariance = posterior_covariance

    def forward(self, X):
        with torch.no_grad():
            f = self.model(X)
        J = self._jacobian(X) 
        out = torch.einsum("ijk,k->ij", J, (self.weights - self.weights_map))
        return out
    
    def predict(self, X):
        with torch.no_grad():
            f = self.model(X)
        J = self._jacobian(X)
        mean = torch.einsum("ijk,k->ij", J, (self.weights - self.weights_map)) + f
        return mean, self._functional_variance(J)
        
    def _functional_variance(self, Js):
        return torch.einsum('ncp,pq,nkq->nck', Js, self.posterior_covariance, Js)
 

    def _jacobian(self, X):
        """
        Compute the jacobian of the model with respect to the input X
        Args:
            X: input tensor
        Returns:
            J: jacobian of the model with respect to X
        """
        model = copy.deepcopy(self.model)
        model.eval()
        model = extend(model)
        Js = []
        for o in range(self.output_dim):
            f = model(X)
            f_o = f.sum(dim=0)[o]

            with backpack(BatchGrad()):
                f_o.backward()
            Jo = []
            for name, param in model.named_parameters():    
                batch_size = param.grad_batch.size(0)
                Jo.append(param.grad_batch.reshape(batch_size, -1))
            Jo = torch.cat(Jo, dim=1)
            Js.append(Jo)
        return torch.stack(Js, dim=1)


### Trainer for refinement

In [6]:
def log_likelihood(y, mu, std):
    dist = Normal(mu.squeeze(), std.squeeze())
    log_probs = dist.log_prob(y.squeeze())
    return log_probs.squeeze().sum().item()

def evaluate(model, sigma, dataloader, device):
    ll = 0.0
    count = 0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        f_mu, f_var = model.predict(X)
        f_sigma = torch.sqrt(f_var)
        pred_std = torch.sqrt(f_sigma**2 + sigma**2)
        ll += log_likelihood(y, f_mu, pred_std)
        count += X.shape[0]
    return -ll / count

def train(model, sigma, train_dataloader, val_dataloader, epochs, lr, device):
    criteria = NegativeLogLikelihood(sigma=sigma).to(device)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-3
    )
    best_val_nll = math.inf
    best_model = copy.deepcopy(model)
    for i in range(epochs):
        epoch_err = 0.0
        epoch_nll = 0.0
        count = 0
        model.train()
        for X, y in train_dataloader:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss = criteria(out, y).mean()    
            loss.backward()
            optimizer.step()
            batch_size = X.shape[0]
            epoch_err += (
                F.mse_loss(out, y, reduction="mean").sqrt().item() * batch_size
            )
            epoch_nll += loss * batch_size
            count += batch_size

        epoch_nll = epoch_nll / count
        epoch_err = epoch_err / count
        val_nll = evaluate(model, sigma, val_dataloader, device)
        print(f"Epoch {i} | Train NLL {epoch_nll} | Val NLL {val_nll}")
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_model = copy.deepcopy(model)
            
    return best_model

### Train a MAP model

In [7]:
data = UCIData(config.data.path)
meta_data = data.get_metadata()
device = get_device()
train_dataloader, val_dataloader, test_dataloader = data.get_dataloaders(
        dataset=config.data.name,
        batch_size=config.trainer.batch_size,
        seed=config.data.seed,
        val_size=config.data.val_size,
        split_index=config.data.split_index,
        gap=(config.data.split == "GAP"),
    )
trainer = ModelTrainer(config.trainer, device=device)
  

model = create_mlp(
        input_size=meta_data[config.data.name]["input_dim"],
        hidden_sizes=config.model.hidden_sizes,
        output_size=meta_data[config.data.name]["output_dim"],
    )
model = model.to(device=device, dtype=torch.float64)
map_model, sigma = trainer.train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )

In [8]:
print(f"Sigma: {sigma}")

Sigma: 0.6844667196273804


### Train Full Laplace

In [9]:
prior_precisions = np.logspace(0.1, 1, num=5, base=10).tolist()[:-1]  + np.logspace(1, 2, num=10, base=10).tolist()
print(prior_precisions)

[1.2589254117941673, 2.113489039836647, 3.548133892335755, 5.956621435290105, 10.0, 12.91549665014884, 16.68100537200059, 21.544346900318832, 27.825594022071243, 35.938136638046274, 46.41588833612777, 59.94842503189409, 77.4263682681127, 100.0]


In [10]:
model_copy = copy.deepcopy(map_model)
la, prior_precision = trainer.train_la_posthoc(
                model=model_copy,
                dataloader=train_dataloader,
                subset_of_weights="all",
                hessian_structure="full",
                sigma_noise=sigma,
                prior_mean=config.trainer.la.prior_mean,
                val_dataloader=val_dataloader,
                prior_precisions=prior_precisions
            )


In [11]:
print(f"Prior precision: {prior_precision}")

Prior precision: 12.91549665014884


In [12]:
posterior_covariance = la.posterior_covariance

In [13]:
trainer.evaluate_la(la, test_dataloader)

1.1229294272116894

In [14]:
refined_model = RefinedLaplace(model=map_model,
                               output_dim=meta_data[config.data.name]["output_dim"],
                               posterior_covariance=posterior_covariance)

In [15]:
refined_model = train(model=refined_model, sigma=sigma, train_dataloader=train_dataloader, val_dataloader=val_dataloader, epochs=100, lr=config.trainer.lr, device=device)

Epoch 0 | Train NLL 3.8988862386182075 | Val NLL 1.5421363804779908
Epoch 1 | Train NLL 1.8634767322968617 | Val NLL 1.8851955220958183
Epoch 2 | Train NLL 1.7045278392984051 | Val NLL 1.7735799267865997
Epoch 3 | Train NLL 1.625040727521054 | Val NLL 1.6767984284934538
Epoch 4 | Train NLL 1.5712423662593502 | Val NLL 1.7112473170944171
Epoch 5 | Train NLL 1.5218078680311922 | Val NLL 1.673452442461878
Epoch 6 | Train NLL 1.4857665014502666 | Val NLL 1.6566689417755698
Epoch 7 | Train NLL 1.4564540302919082 | Val NLL 1.6879857617765508
Epoch 8 | Train NLL 1.4335647523985535 | Val NLL 1.7116341245895563
Epoch 9 | Train NLL 1.4140337012072983 | Val NLL 1.6887787293293717
Epoch 10 | Train NLL 1.3996436983239993 | Val NLL 1.6709366781102506
Epoch 11 | Train NLL 1.3772540326105247 | Val NLL 1.6424996599877384
Epoch 12 | Train NLL 1.3624814909216192 | Val NLL 1.676260540699261
Epoch 13 | Train NLL 1.348009006260022 | Val NLL 1.6911219140112466
Epoch 14 | Train NLL 1.334487074329372 | Val NLL

In [16]:
evaluate(model=refined_model, sigma=sigma, dataloader=test_dataloader, device=device)

1.3661867485730306