# Laplace

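Adapted from the [full example in the Laplace library docs](https://aleximmer.github.io/Laplace/#full-example-post-hoc-laplace-on-a-large-image-classifier) (post-hoc Laplace on a large image classifier), scaled down here to a small MLP trained on the MNIST dataset.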

## Simple MNIST dataset and model

In [None]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torchvision import datasets, transforms

In [None]:
# Flatten each image tensor into a 1-D vector (784 values for MNIST, matching
# the input width of the first Linear layer below).
flatten = transforms.Lambda(lambda x: x.view(-1))

# One preprocessing pipeline shared by both splits: ToTensor, then flatten.
mnist_transform = transforms.Compose([transforms.ToTensor(), flatten])

# Load the MNIST train / test splits into ./data (download=True fetches them if missing).
train_dataset = datasets.MNIST(
    "data", train=True, download=True, transform=mnist_transform
)
test_dataset = datasets.MNIST(
    "data", train=False, download=True, transform=mnist_transform
)

# MLP classifier 784 -> 256 -> 64 -> 10; each hidden Linear is followed by
# Dropout(0.2) and ReLU. Layers are created in the same order as a literal
# nn.Sequential listing, so parameter init and state_dict keys are unchanged.
layer_widths = [784, 256, 64]
hidden_layers = []
for width_in, width_out in zip(layer_widths, layer_widths[1:]):
    hidden_layers += [nn.Linear(width_in, width_out), nn.Dropout(0.2), nn.ReLU()]
model = nn.Sequential(*hidden_layers, nn.Linear(layer_widths[-1], 10))

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of rows of *outputs* whose argmax equals *labels*.

    Args:
        outputs: scores/logits, one row per sample, classes along dim 1.
        labels: integer class labels, one per row of *outputs*.

    Returns:
        A 0-d float32 tensor in [0, 1], on the same device as *outputs*.
    """
    preds = outputs.argmax(dim=1)
    # Stay in tensor land: no .item() round-trip (host sync) per batch.
    return (preds == labels).float().mean()

In [None]:
# Train the model
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

n_epochs = 5
losses = []  # per-batch training losses, across all epochs
eval_losses = []  # per-batch test losses, across all epochs
eval_accuracies = []  # per-batch test accuracies, across all epochs
for epoch in range(n_epochs):
    model.train()
    progress_bar = tqdm(train_loader, total=len(train_loader))
    for imgs, labels in progress_bar:
        # The dataset transform already flattens images; this keeps the
        # input 2-D (batch, features) regardless.
        imgs = imgs.view(imgs.shape[0], -1)
        optimizer.zero_grad()
        output = model(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        rolling_loss = torch.tensor(losses[-100:]).mean()
        progress_bar.set_description(
            f"Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}"
        )

    # Evaluate the model. eval() disables Dropout; no_grad() avoids building
    # an autograd graph that evaluation never backpropagates through.
    model.eval()
    # Index of this epoch's first eval entry, so the displayed metrics cover
    # this epoch only (a fixed [-100:] window would span several epochs'
    # evaluations, since the test loader has far fewer than 100 batches).
    epoch_start = len(eval_losses)
    progress_bar = tqdm(test_loader, total=len(test_loader))
    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs = imgs.view(imgs.shape[0], -1)
            output = model(imgs)
            loss = loss_fn(output, labels)
            acc = accuracy(output, labels)
            eval_losses.append(loss.item())
            eval_accuracies.append(acc.item())
            epoch_loss = torch.tensor(eval_losses[epoch_start:]).mean()
            epoch_acc = torch.tensor(eval_accuracies[epoch_start:]).mean()
            progress_bar.set_description(
                f"Epoch {epoch+1}/{n_epochs}, loss: {epoch_loss.item():.4f}, acc: {epoch_acc.item():.4f}"
            )
# Note: the model is left in eval mode here (Dropout off), which the
# Laplace fit below relies on.

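## Post-hoc Laplace approximation

Fit a Laplace approximation over all weights of the trained network (Kronecker-factored Hessian). Then tune the prior precision by maximizing the marginal likelihood, and draw predictive samples to inspect per-class uncertainty.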


In [None]:
from laplace import Laplace

# Post-hoc Laplace approximation over every weight of the trained network.
# hessian_structure="kron" selects a Kronecker-factored Hessian approximation
# (per the laplace-torch docs). The training cell ends with model.eval(), so
# Dropout is inactive while the Hessian is accumulated.
laplace_kwargs = {"subset_of_weights": "all", "hessian_structure": "kron"}
la = Laplace(model, "classification", **laplace_kwargs)

# Accumulate the Hessian approximation over the training set, then tune the
# prior precision by maximizing the marginal likelihood.
la.fit(train_loader)
la.optimize_prior_precision(method="marglik")

In [None]:
# Build the first 25 test inputs through the dataset's own transform
# (ToTensor + flatten) so they are scaled exactly like the training inputs.
# torchvision's `MNIST.data` holds raw uint8 pixels (0-255); casting it with
# .float() without rescaling feeds the model out-of-range inputs.
input_data = torch.stack([test_dataset[i][0] for i in range(25)])

In [None]:
from tqdm.auto import tqdm
import numpy as np

In [None]:
# Query batch: the first 20 test images, preprocessed through the dataset's
# transform (ToTensor + flatten) exactly like the training data. Raw
# `test_dataset.data` is uint8 0-255 and must not be fed to the model directly.
n_inputs = 20
input_data = torch.stack([test_dataset[i][0] for i in range(n_inputs)])

# Monte-Carlo samples from the Laplace predictive distribution. The cells below
# reduce over dim 0 and index [input, class], so the layout is presumably
# (n_samples, n_inputs, n_classes) -- confirm against the laplace-torch docs.
samples = la.predictive_samples(input_data, n_samples=100)

In [None]:
# Reduce over the sample dimension (dim 0) to get, for every input and class,
# the mean predicted probability and its spread across samples.
sample_mean = samples.mean(dim=0)
sample_std = samples.std(dim=0)

In [None]:
# Score each input by its predictive std averaged over classes, then keep the
# indices of the 5 highest scores (np.argsort is ascending, so the single most
# uncertain input is the last entry).
per_input_std = sample_std.mean(dim=1)
most_uncertain = np.argsort(per_input_std)[-5:]
most_uncertain

In [None]:
# print the mean and std of the most uncertain samples
for i in most_uncertain:
    # Report the predicted class's own mean and std. Taking max() over all
    # classes' stds could pair the prediction with another class's spread.
    pred_class = sample_mean[i].argmax().item()
    mean_prob = sample_mean[i, pred_class].item()
    std_prob = sample_std[i, pred_class].item()
    print(
        f"Sample {i}: predicted class {pred_class} with a mean prob {mean_prob:.4f} std {std_prob:.4f}"
    )

In [None]:
# Per-class predictive mean ± std for a single test input. The class count is
# taken from the tensors rather than hard-coded.
example = 0
class_means = sample_mean[example].tolist()
class_stds = sample_std[example].tolist()
for i, (mean_prob, std_prob) in enumerate(zip(class_means, class_stds)):
    print(f"Class {i}: {mean_prob:.2f} ± {std_prob:.2f}")