# Training a Laplace BNN

### Setup

Import the required libraries and the project's Laplace BNN helpers.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from laplace import Laplace
from BNN.laplaceBNN import MLP, BayesianMLP, train_mlp, get_device


Select the compute device using the project's `get_device` helper.

In [None]:
# Resolve the compute device once via the project helper; later cells reuse `device`.
device = get_device()
print("Using device:", device)

Load the MNIST train and test splits, normalized with the standard MNIST mean and standard deviation.

In [None]:
# Load MNIST: convert images to tensors, then normalize with the
# commonly used MNIST pixel mean (0.1307) and std (0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Both splits live under the same root directory.
# NOTE(review): the test split is opened without download=True; this
# presumably relies on the train call having fetched the files — confirm.
train_dataset = datasets.MNIST(
    root='../data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(root='../data', train=False, transform=transform)

# Shuffle only the training data; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128)

### Training the model

Create and train the base (non-Bayesian) MLP.

In [None]:
# Build the base (non-Bayesian) MLP, place it on the selected device,
# and train it; the trained model returned by train_mlp replaces `mlp`.
mlp = MLP()
mlp = mlp.to(device)
mlp = train_mlp(mlp, train_loader, test_loader, device=device)

Wrap the trained MLP in its Bayesian version and fit the Laplace approximation (LA).

In [None]:
# Create Bayesian version and fit the Laplace approximation (LA).
# BayesianMLP comes from the project module BNN.laplaceBNN; it wraps the
# already-trained `mlp` — presumably building a Laplace approximation around it.
bayes_mlp = BayesianMLP(mlp)
# Fit the LA using the training data loader.
# NOTE(review): no device is passed here; confirm BayesianMLP picks up the
# device from `mlp` (which was moved to `device` before training).
bayes_mlp.fit(train_loader)

### Testing the model

In [None]:
# Take one batch from the test loader and evaluate the Bayesian predictive on it.
# (The labels y_test are not used below.)
x_test, y_test = next(iter(test_loader))
x_test = x_test.to(device)

# Predictive class probabilities that account for parameter uncertainty.
pred_probs = bayes_mlp.predict(x_test)
print("\nPredictive distribution shape:", pred_probs.shape)

# Confidence (largest predicted probability along dim 1) for the first five
# samples, moved to CPU before printing.
# Assumes pred_probs is laid out as (batch, num_classes) — TODO confirm.
first_five = pred_probs[:5]
max_probs = first_five.max(dim=1).values.cpu()
print("Max probability:", max_probs)