# PusH SWAG

## Introduction

In this notebook, we run SWAG using PusH. SWAG (Stochastic Weight Averaging Gaussian) builds a distribution over the parameters of a pre-trained network from the parameter values visited during a set number of SWAG epochs. We first train a standard neural network for some number of pre-training epochs, then continue training for the SWAG epochs while tracking the first and second moments of the model's parameters. The first moment is a running average of the parameters, updated after each SWAG epoch. The second moment is a running average of the squared parameters (not the square of the first moment). After the SWAG epochs finish, we use the first and second moments to sample parameter states.
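
A minimal sketch of the moment tracking described above, in plain PyTorch. The three snapshots are hypothetical stand-ins for the flattened parameters recorded after each SWAG epoch, and `update_moments` is a helper written for this illustration, not part of PusH.

```python
import torch

def update_moments(mean, sq_mean, params, n):
    """Fold the n-th (0-indexed) parameter snapshot into the running moments."""
    # First moment: running average of the parameters
    mean = (n * mean + params) / (n + 1)
    # Second moment: running average of the squared parameters
    sq_mean = (n * sq_mean + params ** 2) / (n + 1)
    return mean, sq_mean

# Hypothetical snapshots of a 3-parameter model, one per SWAG epoch
snapshots = [torch.tensor([1.0, 2.0, 0.0]),
             torch.tensor([3.0, 2.0, 1.0]),
             torch.tensor([2.0, 2.0, 2.0])]

mean = torch.zeros(3)
sq_mean = torch.zeros(3)
for n, theta in enumerate(snapshots):
    mean, sq_mean = update_moments(mean, sq_mean, theta, n)

# Diagonal variance: second moment minus the squared first moment (clamped at 0)
var = torch.clamp(sq_mean - mean ** 2, min=0.0)
print(mean)  # tensor([2., 2., 1.])
```

The second parameter never moves across snapshots, so its variance is zero; the other two vary and get a variance of 2/3.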

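To get an inference result, we draw a number of parameter states (say 20), run the model with each one, and average the resulting predictions. The first and second moments define the distribution we sample from; in the simplest case it is a Gaussian whose mean is the first moment and whose diagonal variance is the second moment minus the square of the first moment.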
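
The sampling-and-averaging step can be sketched as below, assuming the diagonal Gaussian just described. The tiny `nn.Linear` model, the moment values, and `swag_predict` are hypothetical stand-ins for SimpleNet and PusH's own prediction code; the sketch only uses the standard `torch.nn.utils` helpers for flattening and restoring parameters.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)

# Hypothetical tiny classifier standing in for SimpleNet
model = nn.Linear(4, 3)

# Hypothetical SWAG moments: mean = current weights, diagonal variance = 1e-4
mean = parameters_to_vector(model.parameters()).detach().clone()
sq_mean = mean ** 2 + 1e-4

def swag_predict(model, x, mean, sq_mean, num_samples=20):
    """Average class probabilities over parameter samples from N(mean, diag(var))."""
    var = torch.clamp(sq_mean - mean ** 2, min=0.0)
    probs = []
    with torch.no_grad():
        for _ in range(num_samples):
            # Draw one parameter state and load it into the model
            theta = mean + var.sqrt() * torch.randn_like(mean)
            vector_to_parameters(theta, model.parameters())
            probs.append(torch.softmax(model(x), dim=-1))
    # Put the mean weights back so the model is left in a known state
    vector_to_parameters(mean.clone(), model.parameters())
    return torch.stack(probs).mean(dim=0)

x = torch.randn(5, 4)
avg_probs = swag_predict(model, x, mean, sq_mean)
print(avg_probs.argmax(dim=-1))  # predicted class per input
```

Averaging probabilities rather than hard labels keeps information about how confident each sampled model is.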

In [1]:
import torch.nn as nn
import torch.optim as optim
from experiments.nns.simplenet.simplenet import SimpleNet

# A single SimpleNet for 10-class, 1-channel (MNIST) input.
# Note: train_mswag below receives the SimpleNet class and its constructor
# arguments, not this instance or this optimizer.
model = SimpleNet(num_classes=10, in_chans=1, scale=1, network_idx=1, mode=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

### Dataset

In [2]:
import os

import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# MNIST lives at this absolute path. (os.path.join discards every component
# before an absolute path, so navigating up from the notebook directory had
# no effect and is not needed.)
mnist_directory = "/usr/data1/vision/data/"

# Convert images to tensors and rescale pixel values from [0, 1] to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the MNIST training and test datasets (already on disk)
train_dataset = datasets.MNIST(root=mnist_directory, train=True, download=False, transform=transform)
test_dataset = datasets.MNIST(root=mnist_directory, train=False, download=False, transform=transform)

# Load the previously saved training-subset indices from mnist_directory
subset_save_path = os.path.join(mnist_directory, "subset_indices.pth")
loaded_subset_indices = torch.load(subset_save_path)

# Restrict training to that subset
loaded_train_subset = Subset(train_dataset, loaded_subset_indices)
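
The notebook loads `subset_indices.pth` but does not show how it was produced. One way such a file could be created is sketched below; the subset size, the seed, the temporary directory, and the label-only stand-in for the 60,000-image MNIST training set are all assumptions made so the sketch runs without MNIST on disk.

```python
import os
import tempfile

import torch
from torch.utils.data import Subset, TensorDataset

# Stand-in for the 60,000-image MNIST training set (labels only, to stay light)
train_dataset = TensorDataset(torch.randint(0, 10, (60000,)))

# Hypothetical subset size and seed; the notebook states neither
subset_size = 10000
generator = torch.Generator().manual_seed(0)
subset_indices = torch.randperm(len(train_dataset), generator=generator)[:subset_size]

# Save once; later runs reload the file instead of redrawing the subset
save_dir = tempfile.mkdtemp()
subset_save_path = os.path.join(save_dir, "subset_indices.pth")
torch.save(subset_indices, subset_save_path)

loaded_subset_indices = torch.load(subset_save_path)
loaded_train_subset = Subset(train_dataset, loaded_subset_indices)
print(len(loaded_train_subset))  # 10000
```

Fixing the seed and saving the indices means repeated runs train on exactly the same examples.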

In [3]:
# Create data loaders
batch_size = 128
train_loader = DataLoader(loaded_train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [4]:
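import torch
import push.bayes.swag

swag_epochs = 20
pretrain_epochs = 20
lr = 1e-3
loss_fn = torch.nn.CrossEntropyLoss()

# Train 4 SWAG models ("particles") using 2 devices
four_particle_mswag = push.bayes.swag.train_mswag(
    train_loader,
    loss_fn,
    pretrain_epochs,
    swag_epochs,
    # SimpleNet and its constructor arguments, in the same order as the first cell:
    # num_classes=10, in_chans=1, scale=1, network_idx=1, mode=2
    SimpleNet, 10, 1, 1, 1, 2,
    num_devices=2,
    num_models=4,
    lr=lr,
)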


Training log: the first progress bar ran 20/20 iterations in 02:14 (about 6.71 s/it) and printed one average loss per iteration:

| Epoch | Average epoch loss |
|------:|-------------------:|
| 1 | 2.2297 |
| 2 | 1.8378 |
| 3 | 1.5450 |
| 4 | 1.3046 |
| 5 | 1.1109 |
| 6 | 0.9448 |
| 7 | 0.8101 |
| 8 | 0.6985 |
| 9 | 0.6001 |
| 10 | 0.5144 |
| 11 | 0.4404 |
| 12 | 0.3828 |
| 13 | 0.3311 |
| 14 | 0.2864 |
| 15 | 0.2515 |
| 16 | 0.2186 |
| 17 | 0.1929 |
| 18 | 0.1730 |
| 19 | 0.1532 |
| 20 | 0.1344 |

A second progress bar then ran 20/20 iterations in 02:30 (about 7.54 s/it) without printing losses.
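
For readers who want to see the mechanics behind `train_mswag`, here is a minimal pure-PyTorch sketch of MultiSWAG, assuming it follows the usual recipe: train several models independently from different initializations, and give each its own SWAG moments. The toy two-blob data, the two-layer network, the full-batch steps standing in for epochs, and `train_swag_particle` are all hypothetical; nothing here reflects PusH's internals or its multi-device scheduling.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

def train_swag_particle(seed, x, y, pretrain_epochs=20, swag_epochs=20, lr=1e-3):
    """Pre-train one model, then track SWAG moments once per SWAG epoch."""
    torch.manual_seed(seed)  # a different seed gives each particle its own init
    model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    def epoch():
        # One full-batch step stands in for an epoch in this toy setting
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()

    for _ in range(pretrain_epochs):
        epoch()

    mean = torch.zeros_like(parameters_to_vector(model.parameters()))
    sq_mean = torch.zeros_like(mean)
    for n in range(swag_epochs):
        epoch()
        theta = parameters_to_vector(model.parameters()).detach()
        mean = (n * mean + theta) / (n + 1)            # first moment
        sq_mean = (n * sq_mean + theta ** 2) / (n + 1)  # second moment
    return mean, sq_mean

# Hypothetical toy data: two Gaussian blobs, one per class
torch.manual_seed(0)
x = torch.cat([torch.randn(50, 2) - 2, torch.randn(50, 2) + 2])
y = torch.cat([torch.zeros(50, dtype=torch.long), torch.ones(50, dtype=torch.long)])

# Four independent SWAG particles, mirroring num_models=4 above
particles = [train_swag_particle(seed, x, y) for seed in range(4)]
```

At prediction time, a MultiSWAG model would draw parameter samples from each particle's Gaussian and average predictions across all of them.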


In [5]:
median_four_model_outputs = four_particle_mswag.posterior_pred(test_loader, loss_fn, num_samples=2, mode="median", f_reg=False)
mode_four_model_outputs = four_particle_mswag.posterior_pred(test_loader, loss_fn, num_samples=2, mode="mode", f_reg=False)
mean_four_model_outputs = four_particle_mswag.posterior_pred(test_loader, loss_fn, num_samples=2, mode="mean", f_reg=False)



In [6]:
print("median_four_model_outputs: ", median_four_model_outputs)
print("mean_four_model_outputs: ", mean_four_model_outputs)
print("mode_four_model_outputs: ", mode_four_model_outputs)

median_four_model_outputs:  tensor([7, 2, 1,  ..., 4, 5, 6])
mean_four_model_outputs:  tensor([7, 2, 1,  ..., 4, 5, 6])
mode_four_model_outputs:  tensor([7, 2, 1,  ..., 4, 5, 6])
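
The three `posterior_pred` calls differ only in `mode`. The sketch below shows one plausible reading of those modes, assumed rather than taken from PusH: "mean" averages the sampled class probabilities, "median" takes their element-wise median, and "mode" takes a majority vote over each sample's predicted label. `aggregate` and the probability tensor are hypothetical.

```python
import torch

def aggregate(probs, mode="mean"):
    """Turn per-sample class probabilities of shape (S, N, C) into N labels.

    Illustrative rules only; not taken from PusH's source.
    """
    if mode == "mean":
        return probs.mean(dim=0).argmax(dim=-1)
    if mode == "median":
        return probs.median(dim=0).values.argmax(dim=-1)
    if mode == "mode":
        votes = probs.argmax(dim=-1)  # (S, N): hard label from each sample
        return torch.mode(votes, dim=0).values
    raise ValueError(f"unknown mode: {mode}")

# Hypothetical probabilities from 3 parameter samples on 2 inputs, 3 classes
probs = torch.tensor([
    [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]],
    [[0.5, 0.4, 0.1], [0.1, 0.2, 0.7]],
    [[0.1, 0.8, 0.1], [0.2, 0.6, 0.2]],
])
for m in ("mean", "median", "mode"):
    print(m, aggregate(probs, m).tolist())
# mean [1, 1]
# median [0, 1]
# mode [0, 1]
```

On the first input, the averaged probabilities favor class 1 while the median and the majority vote favor class 0, so the rules need not agree in general, even though the three printed outputs above agree on the entries shown.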
