# SWAG-Laplace Example

In [1]:
import glob
import os
import pickle
import torch
import torch.nn as nn

from datetime import datetime
from torchvision import transforms, datasets
from torchvision.models import resnet18
from torch.utils.data import DataLoader, random_split
from laplace.curvature.asdl import AsdlGGN
from laplace.marglik_training import marglik_training
from laplace.swag_laplace import SWAGLaplace

In [2]:
DATA_ROOT = './data'
BATCH_SIZE = 128
LIKELIHOOD = 'classification'
EPOCHS = 1
MARGLIK_FREQUENCY = 1
N_MODELS = 20

# Create a directory to save models if it doesn't exist
save_dir = './models'
os.makedirs(save_dir, exist_ok=True)

# Add this to the beginning of your notebook after imports
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


### Step 1: Prepare dataset.

In [3]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

full_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)

### Step 2: Initialize and train the model.

In [4]:
model = resnet18(weights='IMAGENET1K_V1')

# Replace the final layer with one that outputs 10 classes for CIFAR-10
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

model = model.to(device)

# lap, trained_model, train_loss, val_loss = marglik_training(
#     model=model,
#     train_loader=train_loader,
#     likelihood=LIKELIHOOD,
#     n_epochs=EPOCHS,
#     marglik_frequency=MARGLIK_FREQUENCY,
#     hessian_structure='diag',
#     backend=AsdlGGN,
#     progress_bar=True,
#     device=device
# )

# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# model_path = f"{save_dir}/marglik_trained_model_{timestamp}.pkl"

# with open(model_path, 'wb') as f:
#     pickle.dump({
#         'model_state': trained_model.state_dict(),
#         'training_params': {
#             'epochs': EPOCHS,
#             'likelihood': LIKELIHOOD,
#             'marglik_frequency': MARGLIK_FREQUENCY
#         }
#     }, f)

# print(f"Marglik trained model saved to {model_path}")

### Step 3: Initialize and train SWAG Laplace.

In [5]:
swag_laplace = SWAGLaplace(
    model=model,
    likelihood=LIKELIHOOD,
    n_models=N_MODELS,
    start_epoch=0,
    swa_freq=1,
    device=device,
)

criterion = nn.CrossEntropyLoss().to(device)

swag_laplace.fit(
    train_loader=train_loader,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=criterion,
    epochs=EPOCHS,
    progress_bar=True,
)

train_accuracy = swag_laplace.evaluate(train_loader)
val_accuracy = swag_laplace.evaluate(val_loader)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = f"{save_dir}/swag_laplace_model_{timestamp}.pkl"

with open(model_path, 'wb') as f:
    pickle.dump({
        'swag_laplace_state': {
            'model_state': model.state_dict(),
            'swag_mean': swag_laplace.swag_mean,
            'swag_covariance': swag_laplace.swag_covariance,
        },
        'training_params': {
            'epochs': EPOCHS,
            'likelihood': LIKELIHOOD,
            'n_models': N_MODELS,
            'accuracies': {
                'train': train_accuracy,
                'validation': val_accuracy
            }
        }
    }, f)

print(f"SWAG-Laplace model saved to {model_path}")

Epochs: 100%|██████████| 1/1 [02:29<00:00, 149.91s/it, loss=1.0837]


SWAG-Laplace model saved to ./models/swag_laplace_model_20250603_231150.pkl


### Step 4: Get predictions using SWAG-Laplace.

In [6]:
saved_models = glob.glob(f"{save_dir}/swag_laplace_model_*.pkl")

if not saved_models:
    print("No saved models found!")
else:
    model_path = max(saved_models, key=os.path.getctime)  # Get the most recent file
    print(f"Loading model from: {model_path}")

    # Load the saved model
    with open(model_path, 'rb') as f:
        saved_data = pickle.load(f)

    # Create a new instance of your model
    loaded_model = resnet18(weights='IMAGENET1K_V1')
    num_features = loaded_model.fc.in_features
    loaded_model.fc = nn.Linear(num_features, 10)
    loaded_model.load_state_dict(saved_data['swag_laplace_state']['model_state'])
    loaded_model = loaded_model.to(device)

    # Create a new SWAGLaplace instance
    loaded_swag_laplace = SWAGLaplace(
        model=loaded_model,
        likelihood=saved_data['training_params']['likelihood'],
        n_models=saved_data['training_params']['n_models'],
        device=device,
    )

    # Restore SWAG statistics
    loaded_swag_laplace.swag_mean = saved_data['swag_laplace_state']['swag_mean']
    loaded_swag_laplace.swag_covariance = saved_data['swag_laplace_state']['swag_covariance']

    # Get predictions using loaded model
    test_inputs, test_targets = next(iter(val_loader))
    test_inputs = test_inputs.to(device)

    # Retrieve stored accuracies
    train_accuracy = saved_data['training_params']['accuracies']['train']
    val_accuracy = saved_data['training_params']['accuracies']['validation']
    
    print(f'Loaded model train accuracy: {train_accuracy:.2f}%')
    print(f'Loaded model validation accuracy: {val_accuracy:.2f}%')
    
    # Optional: Verify the loaded model's performance
    print("Evaluating loaded model on validation data...")
    loaded_val_accuracy = loaded_swag_laplace.evaluate(val_loader)
    print(f'Re-evaluated validation accuracy: {loaded_val_accuracy:.2f}%')

Loading model from: ./models/swag_laplace_model_20250603_231150.pkl
Loaded model train accuracy: 69.66%
Loaded model validation accuracy: 68.62%
Evaluating loaded model on validation data...
Re-evaluated validation accuracy: 69.07%
