Add AdamW optimizer #4050

Merged
merged 7 commits into chainer:master on Dec 21, 2017

Conversation

@tkerola (Contributor) commented Dec 6, 2017

This PR implements AdamW, which was proposed in the following paper: https://openreview.net/forum?id=rk6qdGgCZ

As shown in the paper, the way weight decay is currently implemented in Chainer does not work properly with Adam. AdamW is a modified version of Adam that decouples weight decay from the gradient-based update, and it was shown to improve results.

While the original paper calls the algorithm AdamW, I call it AdamWeightDecay in this implementation, since I thought that it better spells out the purpose of the algorithm. Please let me know if you think the name should be changed to AdamW.

Note that this modification of Adam is theoretically applicable to AMSGrad as well (#4032).
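
For readers skimming the thread, here is a minimal NumPy sketch (not the code in this PR) contrasting L2-regularized Adam with the decoupled weight decay that AdamW proposes; the names are simplified and the paper's schedule multiplier eta_t is omitted:

```python
import numpy as np

def adam_step(w, g, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.0, decoupled=False):
    """One Adam update; decoupled=True gives the AdamW-style weight decay."""
    if not decoupled:
        # Classic approach: an L2 term is folded into the gradient, so it is
        # also rescaled by the adaptive learning rate below.
        g = g + weight_decay * w
    m = beta1 * m + (1 - beta1) * g          # first moment estimate
    v = beta2 * v + (1 - beta2) * g ** 2     # second moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - alpha * m_hat / (np.sqrt(v_hat) + eps)
    if decoupled:
        # AdamW: decay the weights directly, outside the adaptive scaling
        # (the paper additionally multiplies this term by a schedule eta_t).
        w = w - weight_decay * w
    return w, m, v
```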

@hvy (Member) commented Dec 6, 2017

Thank you for this PR! It looks very interesting. I have only skimmed the paper, but do you think it would be possible to reproduce the experiments in the paper for CIFAR-10 (i.e., make plots)?

@hvy added the cat:feature label (Implementation that introduces new interfaces) on Dec 6, 2017
@tkerola (Contributor, Author) commented Dec 6, 2017

Sure, I will try and post the results here!

@tkerola (Contributor, Author) commented Dec 7, 2017

I modified the Chainer CIFAR10 example to compare SGDM, Adam and AdamW with VGG16 (so the experiment is different from the paper).
https://gist.github.com/tkerola/5f643a20c4dc3f2a2b9831bebff3af61

Loss
[plot: cifar10_sgdm_adam_adamw_loss]

Accuracy
[plot: cifar10_sgdm_adam_adamw_accuracy]

As the paper suggests, AdamW seems to beat Adam in terms of accuracy in the latter half of training (although its validation loss is higher, which is strange) and becomes competitive with SGDM.
I think SGDM beats the adaptive methods because its hyperparameters are probably fine-tuned for this example, but as the paper argues, AdamW ends up closer to SGDM while using only default hyperparameters.
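
For reference, the optimizer setup in the comparison script is roughly like the sketch below (simplified from the gist above, not the exact code; the AdamW branch assumes the weight_decay_rate hyperparameter that Adam gains later in this PR):

```python
import chainer
from chainer import optimizers

def make_optimizer(name, model, decay=5e-4):
    """Builds one of the compared optimizers for the CIFAR-10 VGG16 model."""
    if name == 'sgdm':
        opt = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
        opt.setup(model)
        opt.add_hook(chainer.optimizer.WeightDecay(decay))  # coupled L2 decay
    elif name == 'adam':
        opt = optimizers.Adam()
        opt.setup(model)
        opt.add_hook(chainer.optimizer.WeightDecay(decay))  # coupled L2 decay
    elif name == 'adamw':
        opt = optimizers.Adam(weight_decay_rate=decay)      # decoupled decay
        opt.setup(model)
    return opt
```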

@niboshi (Member) commented Dec 7, 2017

Thank you for the PR!
We discussed the implementation internally and concluded that it would be simpler to add additional hyperparameters (e.g. eta and weight_decay) to Adam.
Do you think it is OK to implement it that way?

@tkerola (Contributor, Author) commented Dec 7, 2017

Sure, that is a good idea and much more maintainable! I will change the implementation accordingly.

@tkerola (Contributor, Author) commented Dec 7, 2017

I merged AdamW into Adam. I set _default_hyperparam.weight_decay_rate = 0 to keep it backwards-compatible.
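
For clarity, usage of the merged interface looks roughly like this (weight_decay_rate defaults to 0, so existing code behaves exactly as before):

```python
from chainer import optimizers

plain_adam = optimizers.Adam()                        # weight_decay_rate=0: original Adam
adamw = optimizers.Adam(weight_decay_rate=5e-4)       # decoupled weight decay, i.e. AdamW
```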

@niboshi (Member) commented Dec 8, 2017

Thank you for the fix!
LGTM about the implementation.

As for the documentation: users unfamiliar with AdamW might think Chainer's implementation of Adam differs from the original Adam. I think it's better to explicitly state that this is an additional feature, mentioning eta and weight_decay_rate.

We should also mention the name AdamW, so that users who know about it will instantly understand what it means.

@tkerola (Contributor, Author) commented Dec 8, 2017

I updated the documentation. Did you have something like this in mind?

@kashif (Contributor) commented Dec 12, 2017

@tkerola can you kindly also test my branch to make the nice graphs? I am not on my Linux box for a while and I cannot get it running on my Mac's GPU for some reason...

@tkerola (Contributor, Author) commented Dec 13, 2017

Sure, I will test it with the same training script.

@tkerola (Contributor, Author) commented Dec 13, 2017

As requested by @kashif, I redid the experiment above with AMSGrad added, so all 4 methods are compared.

Loss:
[plot: compare4_loss]

Accuracy:
[plot: compare4_accuracy]

@kashif (Contributor) commented Dec 15, 2017

@tkerola I also wanted to ask: would it make sense to fix the weight decay in Chainer using this method, rather than implementing new optimizers? I am thinking along similar lines for the AMSGrad PR... what do you think?

@tkerola (Contributor, Author) commented Dec 17, 2017

Hmm, do you have any idea how to implement that efficiently? WeightDecay would need to be modified to update param.data after param.update() is called (instead of before, as is done now), using a new variable such as param.prev_data instead of param.grad. This would not be backwards compatible, and it would also require twice the memory, since we would need to store the previous weights for all layers.
https://github.com/chainer/chainer/blob/master/chainer/optimizer.py#L693
Maybe adding a "post-param-update" type of optimizer extension, so that call_hooks() is called after param.update(), would allow an implementation that only requires a constant amount of extra memory: https://github.com/chainer/chainer/blob/master/chainer/optimizer.py#L594
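
Roughly, the idea would be something like the sketch below; note that the hook class and the post-update call site are hypothetical, not an existing Chainer API:

```python
class DecoupledWeightDecay(object):
    """Hypothetical hook intended to run *after* param.update()."""
    name = 'DecoupledWeightDecay'
    call_for_each_param = True

    def __init__(self, rate):
        self.rate = rate

    def __call__(self, rule, param):
        # Decays the already-updated weights directly, so no copy of the
        # previous parameters is needed (constant extra memory).
        if param.data is not None:
            param.data -= self.rate * param.data
```

The remaining change would be that call_hooks() for this kind of hook is invoked right after param.update() rather than before it, as discussed above.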

@niboshi (Member) commented Dec 21, 2017

jenkins, test this please

@niboshi (Member) commented Dec 21, 2017

jenkins, test this please

@niboshi (Member) commented Dec 21, 2017

LGTM!

@niboshi niboshi merged commit 0a8059e into chainer:master Dec 21, 2017
@niboshi niboshi added this to the v4.0.0b3 milestone Dec 21, 2017