# Refinement and G-Prior

In [1]:
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
import numpy as np
import torch
import copy
import plotly.graph_objects as go
import math

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import ConcatDataset, DataLoader
from torch.distributions.normal import Normal
from torch.nn.utils import parameters_to_vector

from data.snelson1d import Snelson1D
from main import set_seed, get_device
from util.plots import plot_data, plot_regression, plot_bayesian_regression
from models.nets import create_mlp
from trainer import ModelTrainer, NegativeLogLikelihood

from laplace import Laplace

from backpack import backpack, extend
from backpack.extensions import BatchGrad

In [3]:
from hydra import initialize, compose
from omegaconf import OmegaConf
try:
    initialize(version_base=None, config_path="configuration")
except Exception as e:
    print(e)
config = compose(config_name="snelson.yaml")
set_seed(config.seed)

## Refined Laplace

### Helper functions

In [4]:
def flatten_nn(model):
    weights = []
    for name, param in model.named_parameters():
        weights.append(param.detach().flatten())
    return torch.cat(weights, dim=0)

### RefinedLaplace Module

In [5]:
class RefinedLaplace(nn.Module):
    def __init__(self, model, output_dim, posterior_covariance):
        super(RefinedLaplace, self).__init__()
        self.model = model
        self.output_dim = output_dim
        self.weights_map = flatten_nn(self.model)
        self.weights = torch.nn.Parameter(torch.ones_like(self.weights_map).detach(), requires_grad=True)
        self.posterior_covariance = posterior_covariance

    def get_parameters(self):
        return self.weights

    def forward(self, X):
        with torch.no_grad():
            f = self.model(X)
        J = self._jacobian(X) 
        out = torch.einsum("ijk,k->ij", J, (self.weights - self.weights_map)) + f
        return out
    
    def predict(self, X):
        with torch.no_grad():
            f = self.model(X)
        J = self._jacobian(X)
        mean = torch.einsum("ijk,k->ij", J, (self.weights - self.weights_map)) + f
        return mean, self._functional_variance(J)
        
    def _functional_variance(self, Js):
        return torch.einsum('ncp,pq,nkq->nck', Js, self.posterior_covariance, Js)
 

    def _jacobian(self, X):
        """
        Compute the jacobian of the model with respect to the input X
        Args:
            X: input tensor
        Returns:
            J: jacobian of the model with respect to X
        """
        model = copy.deepcopy(self.model)
        model.eval()
        model = extend(model)
        Js = []
        for o in range(self.output_dim):
            f = model(X)
            f_o = f.sum(dim=0)[o]

            with backpack(BatchGrad()):
                f_o.backward()
            Jo = []
            for name, param in model.named_parameters():    
                batch_size = param.grad_batch.size(0)
                Jo.append(param.grad_batch.reshape(batch_size, -1))
            Jo = torch.cat(Jo, dim=1)
            Js.append(Jo)
        return torch.stack(Js, dim=1)


### Trainer for refinement

In [6]:
def log_likelihood(y, mu, std):
    dist = Normal(mu.squeeze(), std.squeeze())
    log_probs = dist.log_prob(y.squeeze())
    return log_probs.squeeze().sum().item()

def evaluate_predictive(model, sigma, dataloader, device):
    ll = 0.0
    count = 0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        f_mu, f_var = model.predict(X)
        f_sigma = torch.sqrt(f_var)
        pred_std = torch.sqrt(f_sigma**2 + sigma**2)
        ll += log_likelihood(y, f_mu, pred_std)
        count += X.shape[0]
    return -ll / count

def evaluate(model, sigma, dataloader, device):
    model.eval()
    criteria = torch.nn.MSELoss(reduction='sum')
    err = 0.0
    nll = 0.0
    count = 0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        batch_size = X.shape[0]
        out = model(X)
        loss = criteria(out, y).mean()
        err += F.mse_loss(out, y, reduction="mean").sqrt().item() * batch_size
        nll += loss.item() * batch_size
        count += batch_size

    nll = nll / count
    err = err / count

    return nll, err

def train(model, sigma, delta, train_dataloader, val_dataloader, epochs, lr, device):
    criteria =  NegativeLogLikelihood(sigma=sigma).to(device)
    optimizer = torch.optim.Adam(
        [model.get_parameters()], lr=lr
    )
    theta = model.get_parameters()
    best_val_nll = math.inf
    best_model = copy.deepcopy(model)
    best_epoch = 0
    for i in range(epochs):
        epoch_err = 0.0
        epoch_nll = 0.0
        count = 0
        model.train()
        for X, y in train_dataloader:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss =  criteria(out, y).sum() + (0.5 * (delta * theta) @ theta)
            loss.backward()
            optimizer.step()
            batch_size = X.shape[0]
            epoch_err += (
                F.mse_loss(out, y, reduction="mean").sqrt().item() * batch_size
            )
            epoch_nll += loss * batch_size
            count += batch_size

        epoch_nll = epoch_nll / count
        epoch_err = epoch_err / count
        val_nll, val_err = evaluate(model, sigma, val_dataloader, device)
        if (i + 1) % 100 == 0:
            print(f"Epoch {i} | Train NLL {epoch_nll} | Val NLL {val_nll} | Train Err {epoch_err} | Val Err {val_err}")
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_model = copy.deepcopy(model)
            best_epoch = i
    print(f"Best epoch {best_epoch} | Val NLL {best_val_nll}")
    return best_model

### Train a MAP model

In [7]:
snelson1d = Snelson1D(config.data.path)
train_dataloader, val_dataloader, test_dataloader = snelson1d.get_dataloaders(batch_size=config.trainer.batch_size, val_size=config.data.val_size, random_state=config.data.seed)
device = get_device()
trainer = ModelTrainer(config.trainer, device=device)
  
input_size = 1
output_size = 1

model = create_mlp(
        input_size=input_size,
        hidden_sizes=config.model.hidden_sizes,
        output_size=output_size,
    )
model = model.to(device=device, dtype=torch.float64)
map_model, sigma = trainer.train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )
print(f"Sigma: {sigma}")

Sigma: 0.08778076618909836


### Train Full Laplace

In [8]:
prior_precisions = np.logspace(0.1, 1, num=5, base=10).tolist()[:-1]  + np.logspace(1, 2, num=10, base=10).tolist()
model_copy = copy.deepcopy(map_model)
la, prior_precision = trainer.train_la_posthoc(
                model=model_copy,
                dataloader=train_dataloader,
                subset_of_weights="all",
                hessian_structure="full",
                sigma_noise=sigma,
                prior_mean=config.trainer.la.prior_mean,
                val_dataloader=val_dataloader,
                prior_precisions=prior_precisions
            )
posterior_covariance = la.posterior_covariance
print(f"Prior precision: {prior_precision}")


Prior precision: 100.0


### Evaluate Full Laplace

In [9]:
nll_la = trainer.evaluate_la(la, test_dataloader)
print(f"NLL of Full Laplace for the test dataset: {nll_la}")

NLL of Full Laplace for the test dataset: 0.20758638481663608


In [10]:
refined_model = RefinedLaplace(model=map_model,
                               output_dim=output_size,
                               posterior_covariance=posterior_covariance)
refined_model = train(model=refined_model,
                      sigma=sigma,
                      delta=prior_precision,
                      train_dataloader=train_dataloader,
                      val_dataloader=val_dataloader,
                      epochs=5000,
                      lr=1E-3,
                      device=device)

Epoch 99 | Train NLL 233414.0571754079 | Val NLL 1156.1304200179866 | Train Err 8.558508630166422 | Val Err 8.500479471837114
Epoch 199 | Train NLL 99652.39754992897 | Val NLL 285.1373845420672 | Train Err 4.1171036306663416 | Val Err 4.221502876213541
Epoch 299 | Train NLL 67391.62685267218 | Val NLL 145.77366069932168 | Train Err 3.2991782291709097 | Val Err 3.018419088481188
Epoch 399 | Train NLL 51265.571458328355 | Val NLL 106.7673282727636 | Train Err 2.858250334086449 | Val Err 2.583206924938017
Epoch 499 | Train NLL 40882.28696702542 | Val NLL 82.7769696721826 | Train Err 2.4823640441568346 | Val Err 2.274546241453757
Epoch 599 | Train NLL 33114.20258941136 | Val NLL 63.03082495407801 | Train Err 2.1486670959679546 | Val Err 1.9847988713292528
Epoch 699 | Train NLL 26479.66879691529 | Val NLL 46.829598021596055 | Train Err 1.854338488348938 | Val Err 1.7108038684635225
Epoch 799 | Train NLL 21290.906690468593 | Val NLL 33.954570330100466 | Train Err 1.5360553052158963 | Val Err

In [11]:
nll_refined = evaluate_predictive(model=refined_model, sigma=sigma, dataloader=test_dataloader, device=device)
print(f"NLL of Refined Laplace for the test dataset: {nll_refined}")

NLL of Refined Laplace for the test dataset: 2.708476848621137


### Plot Refined Laplace

In [12]:
def bayesian_regression(model, sigma, train_dataloader, test_dataloader, title):
    X_train = train_dataloader.dataset.X.numpy().squeeze()
    y_train = train_dataloader.dataset.y.numpy().squeeze()
    X_test = test_dataloader.dataset.X.numpy().squeeze()
    y_test = test_dataloader.dataset.y.numpy().squeeze()
    X_test = np.concatenate([X_train, X_test]).reshape(-1, 1)
    X = torch.from_numpy(X_test).to(device=device, dtype=torch.float64)
    f_mu, f_var = model.predict(X)
    f_mu = f_mu.detach().squeeze().cpu().numpy()
    pred_std = torch.sqrt(f_var.squeeze() + sigma**2).detach().cpu().numpy()
    return plot_bayesian_regression(X_train=X_train, y_train=y_train, X_test=X.squeeze().detach().cpu().numpy(), y_test=f_mu, y_std=pred_std, title=title)


In [13]:
fig = bayesian_regression(refined_model, sigma, train_dataloader, test_dataloader, title="Refined Laplace")
fig.show()

### Plot Full Laplace

In [14]:
def bayesian_regression(model, train_dataloader, test_dataloader, title):
    X_train = train_dataloader.dataset.X.numpy().squeeze()
    y_train = train_dataloader.dataset.y.numpy().squeeze()
    X_test = test_dataloader.dataset.X.numpy().squeeze()
    y_test = test_dataloader.dataset.y.numpy().squeeze()
    X_test = np.concatenate([X_train, X_test]).reshape(-1, 1)
    X = torch.from_numpy(X_test).to(device=device, dtype=torch.float64)
    f_mu, f_var = model(x=X)
    f_mu = f_mu.detach().squeeze().cpu().numpy()
    pred_std = torch.sqrt(f_var.squeeze() + model.sigma_noise**2).detach().cpu().numpy()
    return plot_bayesian_regression(X_train=X_train, y_train=y_train, X_test=X.squeeze().detach().cpu().numpy(), y_test=f_mu, y_std=pred_std, title=title)


In [15]:
fig = bayesian_regression(la, train_dataloader, test_dataloader, title="Laplace")
fig.show()