Add AdaFactor optimizer from fairseq #6722

Merged: 17 commits into huggingface:master on Aug 27, 2020

Conversation

@moscow25 (Contributor) commented Aug 25, 2020

Tested for T5 fine-tuning and MLM -- reduced memory consumption compared to Adam.

Fixes #1256
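For context, a minimal sketch of how the ported optimizer can be constructed for T5 fine-tuning (the checkpoint name and learning rate here are illustrative, not taken from this PR):

from transformers import T5ForConditionalGeneration
from transformers.optimization import Adafactor

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Fixed-LR mode: disable Adafactor's internal step-size heuristics and
# drive it with an external learning rate, as you would with AdamW.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,  # illustrative value, not tuned
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

The memory savings come from Adafactor keeping factored second-moment estimates rather than Adam's full per-parameter moment tensors.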

Commit: AdaFactor optimizer ported from fairseq. Tested for T5 finetuning and MLM -- reduced memory consumption compared to ADAM.
@moscow25 (Contributor, PR author)

Hey @sshleifer -- here is the belated PR for AdaFactor. Please let me know how to edit this properly, and what tests or examples we should add. Thanks!

@sshleifer (Contributor) left a comment

This is gonna be awesome!

Want to add a test similar to test_adamw here?

Also I can take over whenever!
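For reference, a rough sketch of what such a test could look like, modeled on test_adamw (the hyperparameters, loop length, and tolerance are illustrative):

import torch

from transformers.optimization import Adafactor

def test_adafactor():
    # Fit a small parameter tensor to a fixed target with MSE, as test_adamw does.
    w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
    target = torch.tensor([0.4, 0.2, -0.5])
    criterion = torch.nn.MSELoss()
    optimizer = Adafactor(
        params=[w],
        lr=1e-2,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    # Adafactor tends to need more steps than AdamW on this toy problem.
    for _ in range(1000):
        loss = criterion(w, target)
        loss.backward()
        optimizer.step()
        w.grad.detach_()  # drop the graph reference between steps
        w.grad.zero_()
    assert torch.allclose(w, target, atol=1e-2)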

@sshleifer (Contributor)

We will integrate this into examples/ in a separate PR, I think.

@moscow25 (Contributor, PR author)

Thanks @sshleifer -- let me try to make those changes.

Agree that I should be able to add a single test -- appreciate the link -- and you can add examples in a separate PR.

If I don't get this figured out soon, yes, happy for you to make the changes yourself :-)

@moscow25 (Contributor, PR author)

Hey @sshleifer -- I think I got a test working finally. We can squash the commits.

Still not sure what I need to clean up for the code standards/linter.

Please advise, thanks!

@sshleifer (Contributor)

For local style checking, you need: pip install isort --upgrade
Then run make style and make quality; both should report no errors.
They should autofix things or at least give error messages. My workflow is to define

sty () {
	make style
	flake8 examples templates tests src utils
}

and then run sty a lot.

@sshleifer changed the title from "AdaFactor optimizer ported from fairseq." to "Add AdaFactor optimizer from fairseq" on Aug 25, 2020
@sshleifer (Contributor)

Also, squashing happens automatically at merge time, so don't worry about that.

@moscow25 (Contributor, PR author)

> For local style checking, you need: pip install isort --upgrade
> Then run make style and make quality; both should report no errors.
> They should autofix things or at least give error messages. My workflow is to define
>
> sty () {
> 	make style
> 	flake8 examples templates tests src utils
> }
>
> and then run sty a lot.

Hmm. Is there a way for make style to tell me the location in the offending file? The output seems pretty minimal.

@sshleifer (Contributor)

If you also run the flake8 command, it should just fix it.

@moscow25 (Contributor, PR author)

I think I fixed the formatting, as requested. Took a sec to figure that all out...

Review comment on src/transformers/optimization.py (outdated, resolved):
#
# Alternatively, relative_step with warmup_init can also be used.
# Training without LR warmup or clip threshold is not recommended. Additional optimizer operations
# like gradient clipping should not be used.
@sshleifer (Contributor) commented Aug 26, 2020

(nit)
This "second docstring" breaks style convention, I am OK to leave it here, because it is very useful, but would prefer to consolidate with the class docstring below.

@moscow25 (Contributor, PR author)

Gotcha. It's up to you. Happy to move it now, or you can consolidate the docstring in a future PR.

Let me try to make the change and see if you like it.
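For reference, the two usage modes this docstring describes look roughly like this (a sketch; the Linear module and lr value are stand-ins):

import torch

from transformers.optimization import Adafactor

model = torch.nn.Linear(8, 8)  # stand-in for any nn.Module

# Mode 1: external learning rate; Adafactor's internal schedule is disabled.
opt_external = Adafactor(
    model.parameters(),
    lr=1e-3,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False,
)

# Mode 2: relative-step sizing with warmup_init, as the docstring recommends;
# lr must stay None so the step size is computed internally.
opt_relative = Adafactor(
    model.parameters(),
    lr=None,
    relative_step=True,
    scale_parameter=True,
    warmup_init=True,
)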

moscow25 and others added 2 commits on August 26, 2020 (Co-authored-by: Sam Shleifer <sshleifer@gmail.com>)
@moscow25 (Contributor, PR author)

@sshleifer -- any idea what happened with the black / code quality changes overnight? I'm very confused. It seems as if the standard changed from yesterday...

@sshleifer (Contributor) commented Aug 26, 2020

Yes, they did, sorry about that. I did some cleanup on this branch.
If you are curious about the style change: I tried to future-proof it in #6748.

@codecov bot commented Aug 26, 2020

Codecov Report

Merging #6722 into master will decrease coverage by 0.02%.
The diff coverage is 68.23%.


@@            Coverage Diff             @@
##           master    #6722      +/-   ##
==========================================
- Coverage   78.96%   78.94%   -0.03%     
==========================================
  Files         157      157              
  Lines       28486    28571      +85     
==========================================
+ Hits        22495    22555      +60     
- Misses       5991     6016      +25     
Impacted Files                             Coverage Δ
src/transformers/__init__.py               99.28% <ø> (ø)
src/transformers/optimization.py           82.28% <68.23%> (-13.27%) ⬇️
src/transformers/file_utils.py             82.41% <0.00%> (-0.26%) ⬇️
src/transformers/generation_tf_utils.py    83.70% <0.00%> (+0.75%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@moscow25 (Contributor, PR author)

Awesome. Thanks @sshleifer. I'll start working more on the other less mature PRs we discussed. And please ping me if/when you write tests or examples for this. Happy to contribute to that as well if you need.

@LysandreJik (Member) left a comment

Great, thanks a lot! Cool test as well.

@LysandreJik merged commit 971d180 into huggingface:master on Aug 27, 2020
@LysandreJik (Member)

I've added Adafactor to the docs and slightly changed the style of the docstrings in #6765.

@sshleifer (Contributor)

Thanks! I'll add an --adafactor option to lightning_base and trainer in 2 PRs.
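One plausible wiring for such a flag (a hypothetical sketch, not the code from those follow-up PRs; build_optimizer is an invented helper):

import argparse

import torch
from transformers.optimization import Adafactor, AdamW

def build_optimizer(model, args):
    # Hypothetical helper: pick the optimizer based on the CLI flag.
    if args.adafactor:
        # External-LR mode keeps any surrounding LR scheduler meaningful.
        return Adafactor(
            model.parameters(), lr=args.learning_rate,
            scale_parameter=False, relative_step=False,
        )
    return AdamW(model.parameters(), lr=args.learning_rate)

parser = argparse.ArgumentParser()
parser.add_argument("--adafactor", action="store_true", help="Use Adafactor instead of AdamW.")
parser.add_argument("--learning_rate", type=float, default=3e-5)
args = parser.parse_args()

model = torch.nn.Linear(8, 8)  # stand-in for the actual model
optimizer = build_optimizer(model, args)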

Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
* AdaFactor optimizer ported from fairseq. Tested for T5 finetuning and MLM -- reduced memory consumption compared to ADAM.

* update PR fixes, add basic test

* bug -- incorrect params in test

* bugfix -- import Adafactor into test

* bugfix -- removed accidental T5 include

* resetting T5 to master

* bugfix -- include Adafactor in __init__

* longer loop for adafactor test

* remove double error class declare

* lint

* black

* isort

* Update src/transformers/optimization.py

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* single docstring

* Cleanup docstring

Co-authored-by: Nikolai Y <nikolai.yakovenko@point72.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020