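# SWAG Laplace

Train a ResNet-18 on CIFAR-10 with Laplace marginal-likelihood training (diagonal Hessian, `AsdlGGN` backend). Then fit a SWAG-Laplace posterior on the trained network and report train/validation accuracy, together with predictions and uncertainties for one validation batch.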

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet18

from laplace.curvature.asdl import AsdlGGN
from laplace.marglik_training import marglik_training
from laplace.swag_laplace import SWAGLaplace

In [2]:
# Experiment configuration: everything a reader might want to tune lives here.
DATA_ROOT = './data'           # where CIFAR-10 is downloaded / cached
BATCH_SIZE = 64                # used by both train and validation loaders
LIKELIHOOD = 'classification'  # passed to marglik_training and SWAGLaplace
EPOCHS = 100                   # passed to marglik_training as n_epochs
MARGLIK_FREQUENCY = 1          # passed to marglik_training as marglik_frequency
N_MODELS = 20                  # passed to SWAGLaplace as n_models
SEED = 42

# Seed torch's global RNG so the random train/val split, weight initialisation,
# data-loader shuffling and random augmentations are repeatable on
# Restart & Run All (cuDNN kernels may still be nondeterministic on GPU).
torch.manual_seed(SEED)

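### Step 1: Prepare the dataset.

Load the CIFAR-10 training set and split it randomly into 80% training and 20% validation data.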

In [3]:
# Augmentation is applied to the training subset only. Both dataset objects
# read the same CIFAR-10 training files. The validation subset reuses the
# indices chosen by random_split but reads them through the deterministic
# eval transform, so validation metrics are not computed on randomly
# flipped/cropped images.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

full_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
# Files were fetched by the call above, so no second download is needed.
eval_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=False, transform=eval_transform)

# Random 80/20 train/validation split of the CIFAR-10 training set.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_split = random_split(full_dataset, [train_size, val_size])
val_dataset = Subset(eval_dataset, val_split.indices)

assert len(train_dataset) + len(val_dataset) == len(full_dataset)
assert set(train_dataset.indices).isdisjoint(val_dataset.indices), "train/val overlap"

# The training loader drops the incomplete final batch. The validation loader
# keeps it so every validation sample is evaluated.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

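### Step 2: Initialize and train the model.

Train a ResNet-18 from scratch with `marglik_training` for `EPOCHS` epochs, using a diagonal Hessian structure and the `AsdlGGN` curvature backend.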

In [None]:
# ResNet-18 trained from scratch (no pretrained weights). The classifier head is
# sized to the number of CIFAR-10 classes instead of torchvision's 1000-way
# ImageNet default.
model = resnet18(weights=None, num_classes=len(full_dataset.classes))

# marglik_training returns the fitted Laplace approximation, the trained
# network, and the marginal-likelihood and training-loss histories. No
# validation loader is passed, so no validation loss is produced here.
lap, trained_model, margliks, losses = marglik_training(
    model=model,
    train_loader=train_loader,
    likelihood=LIKELIHOOD,
    n_epochs=EPOCHS,
    marglik_frequency=MARGLIK_FREQUENCY,
    hessian_structure='diag',
    backend=AsdlGGN,
    progress_bar=True,
)

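### Step 3: Fit SWAG Laplace and evaluate.

Build a SWAG-Laplace posterior around the trained network (`n_models=N_MODELS`) and fit it on the training data. Then evaluate it on the training and validation splits and compute predictions with uncertainties for one validation batch.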

In [None]:
# Build the SWAG-Laplace posterior around the marglik-trained network and fit
# it on the (augmented) training data.
swag_laplace = SWAGLaplace(
    model=trained_model,
    likelihood=LIKELIHOOD,
    n_models=N_MODELS,
    start_epoch=0,
    swa_freq=1,
)
swag_laplace.fit(train_loader)

# Accuracy on both splits.
# NOTE(review): the '%' suffix printed below assumes evaluate() returns a
# percentage rather than a fraction; confirm against SWAGLaplace.evaluate.
train_accuracy = swag_laplace.evaluate(train_loader)
val_accuracy = swag_laplace.evaluate(val_loader)

# Predictions and uncertainties for the first validation batch; the batch's
# labels are discarded.
first_val_batch = next(iter(val_loader))
test_inputs, _ = first_val_batch
predictions, uncertainties = swag_laplace(test_inputs, pred_type='glm')

print(f'Train Accuracy: {train_accuracy:.2f}%')
print(f'Validation Accuracy: {val_accuracy:.2f}%')
print(f'Predictions shape: {predictions.shape}')
print(f'Uncertainties shape: {uncertainties.shape}')