
Adding Stochastic Weight Averaging (SWA) #159

Open

kiudee opened this issue Mar 21, 2018 · 3 comments
@kiudee
Contributor

kiudee commented Mar 21, 2018

Since SWA was successful for Leela Zero in producing stronger network weights (see leela-zero/leela-zero#814, leela-zero/leela-zero#1030), I want to record this as a possible improvement here.

What is Stochastic Weight Averaging?

Izmailov et al. (2018) observed that SGD explores regions of the weight space where networks with good performance lie, but does not converge to the central point of those regions. By keeping a running average of the weights visited by SGD, they were able to find better solutions than SGD alone.
They also show that SWA leads to solutions in wider optima, which is conjectured to be important for generalization.
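
The averaging rule itself (Algorithm 1 in the paper) is just a cumulative mean over the SGD iterates sampled at the end of each cycle or epoch:

```math
w_{\mathrm{SWA}} \leftarrow \frac{n_{\mathrm{models}} \cdot w_{\mathrm{SWA}} + w}{n_{\mathrm{models}} + 1}
```

where w is the current SGD iterate and n_models is the number of snapshots averaged so far.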

Here is a comparison of SWA and SGD with a ResNet-110 on CIFAR-100:
[figure: test error curves, SWA vs. SGD, from Izmailov et al. (2018)]

Implementation

The implementation is straightforward: the only extra state we need is a running average of the weights, maintained alongside the current weight vector.
Since we use batch normalization, we also need to recompute the running means and variances of the batch-norm layers for the averaged network.

The algorithm can be seen here:
[figure: Algorithm 1 (SWA) from Izmailov et al. (2018)]

The authors recommend starting from a pretrained model before averaging the weights. We get this for free, since we always initialize training with the last best network.
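
For illustration, the averaging step could look roughly like the following (a minimal PyTorch-style sketch, not a proposal for this codebase; `update_swa` and the surrounding training loop are hypothetical names):

```python
import copy

import torch

def update_swa(swa_model, model, n_averaged):
    """Fold the current SGD weights into the running SWA average:
    w_swa <- (n * w_swa + w) / (n + 1)."""
    with torch.no_grad():
        for p_swa, p in zip(swa_model.parameters(), model.parameters()):
            p_swa.mul_(n_averaged).add_(p).div_(n_averaged + 1)
    return n_averaged + 1

# Usage sketch: start from the (pretrained) current best network,
# then fold in the SGD weights after every epoch/cycle of training.
#
#   swa_model = copy.deepcopy(model)        # pretrained starting point
#   n = 1
#   for epoch in range(num_epochs):
#       train_one_epoch(model, optimizer)   # hypothetical SGD epoch
#       n = update_swa(swa_model, model, n)
```

At export time `swa_model` holds the averaged weights; the batch-norm statistics then still need to be refreshed as noted above.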

@Error323
Collaborator

This makes a lot of sense to me. Nice! Let's do it!

@Error323
Collaborator

Let's do it once our training window contains solely V2 chunks. Then I can train the same net in parallel.

@remdu

remdu commented Mar 22, 2018

The paper says you may need to run the resulting net over the training set in order to recompute the batch-norm statistics as well. In LZ it seems that is not strictly necessary, though.
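
A minimal sketch of what recomputing those statistics could look like, assuming a PyTorch-style model (`recompute_bn_stats` and `train_loader` are illustrative names, not part of this project):

```python
import torch

def recompute_bn_stats(swa_model, train_loader, device="cpu"):
    """Re-estimate the batch-norm running means/variances for the averaged
    weights by forward passes over training data (no gradient updates)."""
    for m in swa_model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # None => cumulative moving average
    swa_model.train()          # BN layers only update stats in train mode
    with torch.no_grad():
        for inputs, _ in train_loader:
            swa_model(inputs.to(device))
    swa_model.eval()
```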
