# Kronecker-Factored Laplace Approximation

_Kronecker-factored Laplace approximation (KFLA)_ is a post hoc method for approximating the posterior over model weights. It applies the Laplace approximation to the negative log-posterior around the trained weights and Kronecker-factors the resulting curvature approximation to keep it tractable. In Bensemble, this powerful post hoc tool is implemented in the `LaplaceApproximation` class. In this notebook, we'll walk you through the basic usage of this class for probabilistic machine learning tasks and model ensembling.

## Prerequisites

In [None]:
!pip install torchvision tqdm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np

from tqdm import tqdm

from bensemble.methods.laplace_approximation import LaplaceApproximation

## Example model and training

Below we provide an example of a small neural network model and a training loop to train it on the MNIST dataset.

In [None]:
# Neural Network Model

class SimpleNN(nn.Module):
    """Simple feedforward (MLP) classifier for MNIST.

    The network is a stack of ``Linear -> ReLU [-> Dropout]`` hidden layers
    followed by a final ``Linear`` layer that produces raw logits.

    Args:
        input_dim: Size of each flattened input example.
        hidden_dims: Sizes of the hidden layers, in order. Any iterable of ints
            works; it is only read, never mutated. The default is an immutable
            tuple to avoid the shared-mutable-default pitfall.
        output_dim: Number of output logits (classes).
        dropout_rate: Dropout probability applied after every hidden ReLU.
            With ``0`` (the default) no Dropout layers are created at all.
    """

    def __init__(
        self, input_dim=784, hidden_dims=(1200, 1200), output_dim=10, dropout_rate=0.0
    ):
        super().__init__()
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            # Only insert a Dropout module when it would actually drop units.
            if dropout_rate > 0:
                layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten every dimension after the batch axis and return raw logits."""
        # reshape (unlike view) also accepts non-contiguous inputs; for
        # contiguous inputs it returns the same view without copying.
        x = x.reshape(x.size(0), -1)
        return self.network(x)

In [None]:
"""Training Function"""

def train_model(
    model, train_loader, val_loader, num_epochs=50, lr=1e-3, weight_decay=1e-4
):
    """Train a classifier with Adam and cross-entropy, validating after each epoch.

    The model is moved to CUDA when available (CPU otherwise) and is trained in
    place. Progress is printed every 100 batches and after every epoch.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        train_loader: Iterable of ``(data, target)`` training batches.
        val_loader: Iterable of ``(data, target)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        weight_decay: Adam ``weight_decay`` coefficient.

    Returns:
        Dict of per-epoch lists: ``"train_losses"`` / ``"val_losses"`` (mean
        loss per batch) and ``"train_accuracies"`` / ``"val_accuracies"`` (in
        percent). A metric whose loader yielded no data is recorded as ``nan``
        instead of raising ``ZeroDivisionError``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or ``nan`` if nothing was counted."""
        return numerator / denominator if denominator else float("nan")

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Read the scalar once; on GPU every .item() forces a device sync.
            batch_loss = loss.item()
            train_loss += batch_loss
            train_batches += 1
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

            if batch_idx % 100 == 0:
                # Epochs are reported 1-based, matching the per-epoch summary.
                print(f"Epoch: {epoch + 1}, Batch: {batch_idx}, Loss: {batch_loss:.4f}")

        train_accuracy = _ratio(100.0 * train_correct, train_total)
        train_losses.append(_ratio(train_loss, train_batches))
        train_accuracies.append(train_accuracy)

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_batches = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item()
                val_batches += 1
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        val_accuracy = _ratio(100.0 * val_correct, val_total)
        val_losses.append(_ratio(val_loss, val_batches))
        val_accuracies.append(val_accuracy)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracy:.2f}%")
        print(f"  Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accuracy:.2f}%")
        print("-" * 50)

    return {
        "train_losses": train_losses,
        "val_losses": val_losses,
        "train_accuracies": train_accuracies,
        "val_accuracies": val_accuracies,
    }

In [None]:
"""Training pipeline"""
# Set random seeds for reproducibility
# NOTE: statement order in this cell likely matters for reproducibility:
# random_split (called without an explicit generator), SimpleNN's weight
# initialisation and the shuffling train_loader presumably all draw from the
# global torch RNG seeded here. Reordering them would change the results.
torch.manual_seed(42)
np.random.seed(42)

# Data preparation
print("Loading and preparing data...")
# Convert images to tensors and standardise them. The constants
# (0.1307, 0.3081) are presumably the MNIST training-set mean/std — confirm.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

train_dataset = datasets.MNIST(
    "../data", train=True, download=True, transform=transform
)
# NOTE(review): no download=True here, so this relies on the test files already
# being present under ../data (e.g. fetched by the call above) — confirm.
test_dataset = datasets.MNIST("../data", train=False, transform=transform)

# Hold out 10% of the official training set as a validation split.
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# Only the training loader shuffles; validation and test order stay fixed.
# Later cells reuse these loaders (posterior computation and evaluation).
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Model training
print("\nTraining model...")
model = SimpleNN(input_dim=784, hidden_dims=[1200, 1200], output_dim=10)

# A short 2-epoch run is enough for this demo; `model` is trained in place and
# reused by every following cell.
training_history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=2,
    lr=1e-3,
    weight_decay=1e-4,
)

## Creating Laplace approximation class

The first thing you will want to do with `LaplaceApproximation` is to create an instance of the class and then run either the `compute_posterior` or the `fit` method, which computes the Kronecker factors later used to sample models for prediction. Both `compute_posterior` and `fit` take a PyTorch `DataLoader`, a prior precision and the number of training samples as input, and internally compute the Kronecker factors of the approximation as well as their square roots for sampling.

Since KFLA's main strength is that it is computed post hoc, the standard usage scenario for `LaplaceApproximation` is building the approximation for a pretrained model. Thus, the `pretrained` parameter of `LaplaceApproximation` is set to `True` by default, in which case the `fit` method does exactly the same thing as `compute_posterior`. If you really need it, you can instead set `pretrained=False` when creating the Laplace approximation instance. In this case, the `fit` method will first train the model with a simple built-in training loop and then compute the posterior approximation.

In [None]:
print("Creating KFLA...")

# Make sure the model is on the correct device. `device` stays a notebook-level
# variable because later cells move batches onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

### Create the Laplace approximation. ###
# Option to set likelihood to 'regression' for regression tasks with an MSE loss
# Verbose is off by default, but here we turn it on for demonstration purposes.
# You can toggle verbose on and off with the .toggle_verbose() method
laplace = LaplaceApproximation(model, likelihood="classification", verbose=True)

# Compute the posterior.
# `train_loader` iterates over `train_subset` (90% of the MNIST training set),
# so the sample count is taken from the loader's own dataset. Using
# len(train_dataset) would overstate it by the size of the validation split.
# NOTE(review): assumes `num_samples` is the number of data points the loader
# covers (dataset size N) — confirm against the LaplaceApproximation docs.
print("Computing posterior...")
laplace.compute_posterior(
    train_loader=train_loader,
    prior_precision=1.0,
    num_samples=len(train_loader.dataset),
)

## KFLA evaluation

In the `predict` method, `LaplaceApproximation` samples `n_samples` models from the posterior using precomputed sampling factors and makes a prediction with each one, after which the results are aggregated to produce a final answer as well as the uncertainty of the answer.

In [None]:
# Evaluation functions


def evaluate_laplace_model(laplace, test_loader, num_samples=10, temperature=1.0):
    """Evaluate the Laplace posterior-predictive ensemble on ``test_loader``.

    For every batch, ``laplace.predict`` is called with ``n_samples=num_samples``
    and ``temperature=temperature``. It returns mean class probabilities plus an
    uncertainty score; the class with the highest mean probability is taken as
    the prediction.

    Args:
        laplace: Fitted Laplace approximation exposing ``model`` and ``predict``.
        test_loader: Iterable of ``(data, target)`` batches.
        num_samples: Number of posterior samples drawn per batch.
        temperature: Posterior temperature forwarded to ``predict``.

    Returns:
        ``(accuracy, avg_uncertainty, uncertainties)``: accuracy in percent, the
        mean of ``uncertainties``, and the flat list of collected uncertainty
        values (presumably one per test example — confirm against ``predict``).
        Both scalars are ``nan`` when the loader yields no data.
    """
    laplace.model.eval()
    # Move batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable.
    first_param = next(laplace.model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    correct = 0
    total = 0
    uncertainties = []

    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)

            mean_probs, uncertainty = laplace.predict(
                data, n_samples=num_samples, temperature=temperature
            )
            _, predicted = mean_probs.max(1)

            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            uncertainties.extend(uncertainty.cpu().numpy())

    if total == 0:
        # Nothing was evaluated: avoid ZeroDivisionError / mean-of-empty warning.
        return float("nan"), float("nan"), uncertainties

    accuracy = 100.0 * correct / total
    avg_uncertainty = np.mean(uncertainties)
    return accuracy, avg_uncertainty, uncertainties


def evaluate_standard_model(model, test_loader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    Batches are moved to the device of the model's parameters, so the function
    does not depend on a notebook-level ``device`` variable.

    Args:
        model: ``nn.Module`` producing class logits.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in percent, or ``nan`` if the loader yields no samples.
    """
    model.eval()
    # Parameter-free modules have no device of their own; fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    if total == 0:
        return float("nan")
    return 100.0 * correct / total

In [None]:
# Baseline first: top-1 accuracy of the deterministic network, no posterior sampling.
print("\nEvaluating standard model...")
standard_accuracy = evaluate_standard_model(model=model, test_loader=test_loader)
print(f"Standard Model Accuracy: {standard_accuracy:.2f}%")

In [None]:
# Posterior-predictive evaluation: 10 sampled models per prediction at T = 1.0.
# The per-example uncertainty list (third item) is not needed here.
accuracy, uncertainty, _ = evaluate_laplace_model(
    laplace=laplace,
    test_loader=test_loader,
    num_samples=10,
    temperature=1.0,
)
print(f"10 samples, T = 1.0: Accuracy = {accuracy:.2f}%, Uncertainty = {uncertainty:.4f}")

Although we have optimized model sampling as far as possible, inference is still relatively slow, because every prediction involves random sampling and an additional matrix multiplication for each sampled model. One way to make predictions faster is to sample an ensemble of models once, beforehand, and then predict with that ensemble. We'll show an example of model sampling later in this demo.

In [None]:
print("\nTesting Laplace with different sample counts:")
# Sweep the number of posterior samples per prediction; every extra sample
# means one more sampled model to evaluate.
for n_samples in [10, 20, 50, 100]:
    accuracy, uncertainty, _ = evaluate_laplace_model(
        laplace=laplace,
        test_loader=test_loader,
        num_samples=n_samples,
        temperature=1.0,
    )
    print(f"{n_samples} samples, T = 1.0: Accuracy = {accuracy:.2f}%, Uncertainty = {uncertainty:.4f}")

You can also change the posterior temperature in the `predict` method to control how certain the model is about its predictions.

In [None]:
print("\nTesting Laplace with different temperatures:")
# Posterior temperatures passed to `predict`; the sample count stays at 10.
temperatures = [0.1, 0.5, 1, 2]

for temp in temperatures:
    accuracy, uncertainty, _ = evaluate_laplace_model(
        laplace=laplace,
        test_loader=test_loader,
        num_samples=10,
        temperature=temp,
    )
    print(f"Temperature {temp}: Accuracy = {accuracy:.2f}%, Uncertainty = {uncertainty:.4f}")

Here, we compare the average uncertainty of correct and incorrect predictions. Note that uncertainty is, on average, slightly larger for incorrect answers, which is the expected behavior.

In [None]:
def analyze_uncertainty_by_accuracy(laplace, test_loader, num_samples=10):
    """Split collected uncertainties by whether the prediction was correct.

    Args:
        laplace: Fitted Laplace approximation exposing ``model`` and ``predict``.
        test_loader: Iterable of ``(data, target)`` batches.
        num_samples: Posterior samples drawn by ``laplace.predict`` per batch.

    Returns:
        ``(correct_uncertainties, incorrect_uncertainties)``: two flat lists of
        uncertainty values for correctly and incorrectly classified examples,
        where correctness means the arg-max of the mean probabilities equals
        the target.
    """
    laplace.model.eval()
    # Move batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable.
    first_param = next(laplace.model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    correct_uncertainties = []
    incorrect_uncertainties = []

    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)

            mean_probs, uncertainty = laplace.predict(data, n_samples=num_samples)
            _, predicted = mean_probs.max(1)

            # Separate uncertainties for correct and incorrect predictions.
            # An all-False mask selects an empty array, so extend() is then a
            # no-op and no `.any()` guard is needed.
            correct_mask = predicted.eq(target)
            correct_uncertainties.extend(uncertainty[correct_mask].cpu().numpy())
            incorrect_uncertainties.extend(uncertainty[~correct_mask].cpu().numpy())

    return correct_uncertainties, incorrect_uncertainties

In [None]:
correct_unc, incorrect_unc = analyze_uncertainty_by_accuracy(
    laplace=laplace, test_loader=test_loader
)

# Report the mean uncertainty of each group, then how many predictions it holds.
print("\nUncertainty Analysis:")
print(f"Average uncertainty for correct predictions: {np.mean(correct_unc):.4f}")
print(f"Average uncertainty for incorrect predictions: {np.mean(incorrect_unc):.4f}")
print(f"Number of correct predictions: {len(correct_unc)}")
print(f"Number of incorrect predictions: {len(incorrect_unc)}")

Let's also take a look at single predictions and their uncertainties:

In [None]:
# Example of getting predictions with uncertainty for a single batch
print("\nTesting on a single batch...")
data_iter = iter(test_loader)
test_data, test_target = next(data_iter)
test_data, test_target = test_data.to(device), test_target.to(device)

# Inference only: skip autograd bookkeeping, as the evaluation helpers above
# do when they call `predict`.
with torch.no_grad():
    mean_probs, uncertainty = laplace.predict(test_data, n_samples=10)
_, predicted = mean_probs.max(1)

print("Predictions for first 10 test samples:")
for i in range(min(10, len(test_data))):
    print(
        f"  Sample {i+1}: True={test_target[i].item()}, Pred={predicted[i].item()}, "
        f"Uncertainty={uncertainty[i].item():.4f}, "
        f"Correct={predicted[i].eq(test_target[i]).item()}"
    )

## Model sampling

The sampling pipeline is straightforward: just call the `sample_models` method with the desired number of models, `n_models`. The method returns a Python list of `nn.Module` instances, which can then be used as an ensemble for prediction. Keep in mind that for ensembles of large models you'll likely need a decent amount of memory to store them.

In [None]:
# Draw 10 networks from the Laplace posterior. Each entry is a separate model
# instance, so memory use grows with n_models.
model_samples = laplace.sample_models(n_models=10)
# Bare last expression: Jupyter renders the first sampled network's repr.
model_samples[0]