## Using stochastic weight averaging

Stochastic Weight Averaging-Gaussian (SWAG) is a method for approximating the posterior distribution over a
neural network's weights. It fits a Gaussian distribution to the weights visited by the stochastic gradient
descent (SGD) iterates, and then samples network weights from that Gaussian.
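To make the idea concrete, here is a minimal NumPy sketch of the bookkeeping behind SWAG on a single flattened weight vector. It keeps a running mean and second moment of the collected iterates (giving a diagonal variance) plus a buffer of the most recent deviations from the mean (giving a low-rank term), and samples as `mean + sqrt(var) * z1 / sqrt(2) + D @ z2 / sqrt(2 * (K - 1))`. The class `SwagSketch` and its methods are hypothetical illustrations of the technique, not baal's internal implementation.

```python
import numpy as np
from collections import deque


class SwagSketch:
    """Hypothetical, minimal SWAG bookkeeping for a flat weight vector.

    Not baal's implementation: an illustration of the SWA-Gaussian recipe.
    """

    def __init__(self, n_params, n_deviations=20):
        self.n = 0
        self.mean = np.zeros(n_params)      # running mean of the iterates
        self.sq_mean = np.zeros(n_params)   # running mean of squared iterates
        # bounded buffer: appending to a full deque drops the oldest deviation
        self.deviations = deque(maxlen=n_deviations)

    def collect(self, weights):
        # incremental update of the first and second moments
        self.n += 1
        self.mean += (weights - self.mean) / self.n
        self.sq_mean += (weights ** 2 - self.sq_mean) / self.n
        # deviation of this iterate from the updated running mean
        self.deviations.append(weights - self.mean)

    def sample(self, rng):
        # diagonal variance, clipped at zero to guard against round-off
        var = np.clip(self.sq_mean - self.mean ** 2, 0.0, None)
        z1 = rng.standard_normal(self.mean.shape)
        sample = self.mean + np.sqrt(var) * z1 / np.sqrt(2.0)
        k = len(self.deviations)
        if k > 1:
            d = np.stack(list(self.deviations), axis=1)  # shape (n_params, k)
            z2 = rng.standard_normal(k)
            sample += d @ z2 / np.sqrt(2.0 * (k - 1))
        return sample


rng = np.random.default_rng(0)
sketch = SwagSketch(n_params=3, n_deviations=2)
for w in [np.array([1.0, 2.0, 3.0]), np.array([3.0, 2.0, 1.0]), np.array([2.0, 2.0, 2.0])]:
    sketch.collect(w)
weights_sample = sketch.sample(rng)
```

The middle weight never changes across the three iterates, so its variance and deviations are zero and every sample leaves it at its mean; the other two weights are perturbed around the mean `[2, 2, 2]`.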

We implement this in `baal` as an optimiser, since the samples are taken during the optimisation steps.

This optimiser class lives in the `swag` submodule. It accepts the same parameters as the standard SGD
optimiser in PyTorch, plus three parameters that control the SWAG behaviour:

- `swa_burn_in`: how many SGD steps are taken before SWAG starts collecting samples
- `swa_steps`: how many SGD steps (i.e. mini-batches) are taken between consecutive SWAG samples
- `n_deviations`: how many of the most recent deviations (from the running mean of the weights) are used to fit the Gaussian
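A rough sketch of how these three parameters interact, in plain Python. The rule in `is_collection_step` (a hypothetical helper) assumes steps are counted from 0 and that the first sample is taken exactly at step `swa_burn_in`; baal's exact counting convention may differ by one step. The bounded `deque` mimics keeping only the `n_deviations` most recent deviations, with step numbers standing in for the deviation vectors.

```python
from collections import deque


def is_collection_step(step, swa_burn_in, swa_steps):
    """Hypothetical rule: sample at step `swa_burn_in`, then every `swa_steps` steps.

    `step` counts optimiser steps from 0 (an assumption for illustration).
    """
    return step >= swa_burn_in and (step - swa_burn_in) % swa_steps == 0


# Steps at which a snapshot would be taken during 200 SGD steps
steps = [s for s in range(200) if is_collection_step(s, swa_burn_in=100, swa_steps=20)]
print(steps)  # [100, 120, 140, 160, 180]

# Keeping only the most recent entries (here n_deviations=3):
# a full deque silently discards its oldest element on append.
recent = deque(maxlen=3)
for s in steps:
    recent.append(s)
print(list(recent))  # [140, 160, 180]
```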

In [None]:
from baal import swag

# `standard_model` stands for any torch.nn.Module, e.g. a network defined earlier
optimiser = swag.StochasticWeightAveraging(
    standard_model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
    swa_burn_in=100,
    swa_steps=20,
    n_deviations=20,
)

When you use this optimiser to fit your neural network, it waits for the first 100 SGD steps and then,
every 20 SGD steps, takes a snapshot of the network's weights and uses it to update the running mean and
variance of each weight.

_Usually_, you should set `swa_burn_in` to tens of epochs or more, and `swa_steps` to roughly one
epoch. Note that the optimiser does not know how long an epoch is, so you need to specify both
values in SGD steps, i.e. mini-batches.
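Converting epochs into SGD steps only requires the number of mini-batches per epoch. A small sketch, assuming a map-style dataset and the default sampler, for which `len(DataLoader)` equals `ceil(dataset_size / batch_size)` (or the floor when `drop_last=True`). The helper `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(dataset_size, batch_size, drop_last=False):
    # Number of mini-batches one pass over the data yields; matches
    # len(DataLoader) for a map-style dataset with the default sampler.
    if drop_last:
        return dataset_size // batch_size
    return math.ceil(dataset_size / batch_size)


n_batches = steps_per_epoch(dataset_size=160, batch_size=8)
burn_in_steps = 50 * n_batches   # 50 epochs of burn-in
sample_every = 1 * n_batches     # one SWAG sample per epoch
print(n_batches, burn_in_steps, sample_every)  # 20 1000 20
```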

For example:

In [None]:
from baal import swag
import torch
import torch.utils.data

# the data shape:
batch_size = 8
dataset_size = 20 * batch_size  # i.e. 20 mini-batches per epoch
# create some dummy data:
x = torch.randn(dataset_size, 10)
y = torch.randint(low=0, high=2, size=(dataset_size,)).long()
# write loaders for these:
dummy_dataset = torch.utils.data.TensorDataset(x, y)
dummy_loader = torch.utils.data.DataLoader(dummy_dataset, batch_size=batch_size)
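# write a simple model; the final LogSoftmax makes it output
# log-probabilities, which is what NLLLoss expects:
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(5, 2),
    torch.nn.LogSoftmax(dim=-1),
)
criterion = torch.nn.NLLLoss()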
# create the SWAG optimiser:
optimiser = swag.StochasticWeightAveraging(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
    # burn in for 50 epochs:
    swa_burn_in=50*len(dummy_loader),
    # then collect samples every epoch:
    swa_steps=len(dummy_loader),
    n_deviations=20,
)

We can then train our model as normal:

In [None]:
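# 75 epochs: 50 epochs of burn-in, then one SWAG sample per epoch
model.train()
for epoch in range(75):
    for x_batch, y_batch in dummy_loader:
        model.zero_grad()  # clear the gradients from the previous step
        loss = criterion(model(x_batch), y_batch)
        loss.backward()
        optimiser.step()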

To obtain uncertainty estimates for our predictions, we repeatedly sample a set of weights from
the fitted Gaussian and make a prediction with each sampled model separately.

For example, to sample from the approximate posterior 1000 times, we do:

In [None]:
eval_batch = torch.randn(8, 10)

model.eval()

predictions = []

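with torch.no_grad():
    for _ in range(1000):
        optimiser.sample()  # draw new weights for `model` from the fitted Gaussian
        prediction = model(eval_batch)
        predictions.append(prediction)

# stack along a new last axis: (batch size, number of classes, number of samples)
predictions = torch.stack(predictions, dim=-1)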
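The stacked outputs can then be reduced to per-example summaries. The sketch below uses NumPy with random stand-in data of the same `(batch, classes, samples)` layout; `summarise_mc_predictions` is a hypothetical helper, not part of baal. It applies a softmax over the class axis, which works both for raw logits and for log-probabilities (it maps the latter back to probabilities), then computes the mean probability, its spread across posterior samples, and the predictive entropy.

```python
import numpy as np


def summarise_mc_predictions(outputs):
    """Summarise stacked outputs of shape (batch, classes, n_samples)."""
    # softmax over the class axis, shifted for numerical stability
    shifted = outputs - outputs.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    mean_probs = probs.mean(axis=-1)  # (batch, classes)
    std_probs = probs.std(axis=-1)    # spread across posterior samples
    # entropy of the averaged prediction; small epsilon avoids log(0)
    entropy = -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=1)
    return mean_probs, std_probs, entropy


rng = np.random.default_rng(0)
fake_outputs = rng.normal(size=(8, 2, 1000))  # stand-in for the stacked predictions
mean_probs, std_probs, entropy = summarise_mc_predictions(fake_outputs)
```

With the tensor computed above, the call would be `summarise_mc_predictions(predictions.numpy())`. Higher entropy or a larger `std_probs` indicates that the sampled models disagree more about that example.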

Again, this can be simplified by using the `swag.SwagModelWrapper` class:

In [None]:
model_wrapper = swag.SwagModelWrapper(
    model,
    criterion
)

with torch.no_grad():
    predictions = model_wrapper.predict_on_batch(
        eval_batch,
        optimiser,
        iterations=1000
    )

We now have 1000 predictions for every data point in our mini-batch:

In [None]:
predictions.shape

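We can also visualise the approximate posterior, here as the distribution of the first output for the
first example in the batch across all 1000 samples: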

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots()
ax.hist(predictions[0, 0, :].numpy(), bins=50);
plt.show()