In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid
from torch.optim.swa_utils import SWALR

from swadist.data.loader import get_dataloaders
from swadist.utils import Trainer, show_imgs, LinearPolyLR
from swadist.models.resnet import ResNet

# Notebook-wide setup: plotting defaults, RNG seeds, device selection and the
# CIFAR-10 data loaders used by the training cells below.

# Save figures with a tight bounding box.
plt.rcParams["savefig.bbox"] = 'tight'

# Use the 'file_system' strategy for sharing tensors between processes
# (presumably to avoid file-descriptor limits with the DataLoader workers
# requested below -- TODO confirm).
torch.multiprocessing.set_sharing_strategy('file_system')

# Seed the torch and NumPy global RNGs so that a fresh
# "Restart Kernel -> Run All" starts from the same random state.
# NOTE(review): whether get_dataloaders / Trainer draw only from these
# generators is not visible here, so bit-exact reproducibility is not guaranteed.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
# (The former bare `torch.cuda.device_count()` call was removed: as a
# mid-cell expression its result was neither displayed nor used.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device}')

# load CIFAR-10 (training and validation loaders; test=False)
train_loader, valid_loader = get_dataloaders(
    dataset="cifar10",
    num_workers=4,
    batch_size=16,
    cuda=(device == 'cuda'),
    test=False,
)

Using cpu
Files already downloaded and verified
Files already downloaded and verified


### Two-phase training: SGD + SWA

We train ResNet-8 using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) for phase 1 (SGD with Nesterov momentum).

For the SWA phase, we follow [Izmailov et al. 2018](http://arxiv.org/abs/1803.05407) and use a constant learning rate that is 10× smaller than the initial learning rate of the first phase.

In [2]:
# Phase-1 (SGD) hyperparameters.
lr0 = 2**-8.5   # initial learning rate
alpha = 0.25    # scaling factor passed to LinearPolyLR
gamma = 0.97    # momentum (Nesterov)

# SWA phase: constant learning rate, one tenth of the initial rate.
swa_lr = lr0 / 10

# Epoch budget.
epochs = 10        # phase-1 (SGD) epochs
decay_epochs = 8   # lr scaling epochs passed to LinearPolyLR
epochs_swa = 5     # SWA epochs

# model: ResNet with one block per stack and no batch norm
resnet8 = ResNet(
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    n_classes=10,
    batch_norm=False,
)

# SGD with Nesterov momentum drives both phases; the two schedulers below
# share this single optimizer.
optimizer = torch.optim.SGD(
    resnet8.parameters(),
    lr=lr0,
    momentum=gamma,
    nesterov=True,
)

# Phase-1 schedule (verbose=True prints each learning-rate adjustment).
scheduler = LinearPolyLR(
    optimizer,
    alpha=alpha,
    decay_epochs=decay_epochs,
    verbose=True,
)

# SWA-phase schedule: anneal_epochs=0, i.e. no annealing period towards swa_lr.
# NOTE(review): torch schedulers take an initial step when constructed; with
# anneal_epochs=0 SWALR may therefore set the optimizer's lr to swa_lr right
# here -- confirm that Trainer / LinearPolyLR restore the phase-1 rate before
# the first SGD epoch.
swa_scheduler = SWALR(
    optimizer,
    anneal_strategy='linear',
    anneal_epochs=0,
    swa_lr=swa_lr,
)

trainer = Trainer(
    resnet8,
    train_loader,
    valid_loader,
    F.cross_entropy,
    optimizer,
    scheduler=scheduler,
    swa_scheduler=swa_scheduler,
    log=True,
    name='swa',
)

# begin training: `epochs` SGD epochs followed by `epochs_swa` SWA epochs
trainer(
    epochs=epochs,
    epochs_swa=epochs_swa,
    validations_per_epoch=4,
)

Adjusting learning rate of group 0 to 2.7621e-03.
Starting 15-epoch training loop...

Train epoch: 1 -- Accuracy: 0.290082 -- Avg. loss (cross_entropy): 1.872467 -- Batch: 2813/2813 (100%) -- Total steps: 2813
Validation accuracy: 0.389577 -- Avg. loss (cross_entropy): 1.656444 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.5550e-03.
Train epoch: 2 -- Accuracy: 0.440833 -- Avg. loss (cross_entropy): 1.517869 -- Batch: 2813/2813 (100%) -- Total steps: 5626
Validation accuracy: 0.515375 -- Avg. loss (cross_entropy): 1.380072 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.3478e-03.
Train epoch: 3 -- Accuracy: 0.578564 -- Avg. loss (cross_entropy): 1.175578 -- Batch: 2813/2813 (100%) -- Total steps: 8439
Validation accuracy: 0.613019 -- Avg. loss (cross_entropy): 1.102383 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.1407e-03.
Train epoch: 4 -- Accuracy: 0.638398 -- Avg. loss (cross_entropy): 1.020614 -- Batch: 2813/2813 (100%) -- T

In [3]:
# Validation accuracy gained by the SWA phase: the trainer's final validation
# accuracy minus the entry of `valid_accs` at index `epochs - 1`.
# NOTE(review): presumably `valid_accs` holds one entry per epoch, making that
# index the end of the SGD phase; training ran with validations_per_epoch=4,
# so confirm the granularity of `valid_accs`.
swa_gain = trainer.valid_acc - trainer.valid_accs[epochs - 1]
print(f'SWA validation accuracy gain: {swa_gain:.4f}')

SWA validation accuracy gain: 0.0160
