
Add support for bitsandbytes #15622

Merged: 18 commits into huggingface:main on Apr 19, 2022

Conversation

@manuelciosici (Contributor) commented Feb 11, 2022

What does this PR do?

Fixes #14819

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@stas00 @sgugger @TimDettmers

Status

  • Need to instrument CI to install bnb (a binary package, so a bit trickier than a normal dependency)
  • I have implemented a CLI parameter to support bitsandbytes
  • I did not write any documentation yet
  • I followed @TimDettmers 's suggestion to override the embedding layers. However, I am unsure about a couple of things:
    • Does the override need to happen before the model is loaded onto the GPU as the official documentation describes for other overrides?
    • Are there any pitfalls to my current approach to identifying Embedding layers? It seems to work fine for RoBERTa and for GPT-2.
  • So far, I've used run_mlm.py and run_clm.py from the examples directory to check that the code runs. Using RTX A6000 GPUs, I see the following memory usage (a sample invocation is sketched after the table):

| Model | Visible devices | Optimizer | Per-device batch size | GPU memory (per GPU) |
|---|---|---|---|---|
| gpt2-large | 0 | adamw_torch | 2 | 48638MiB / 49140MiB |
| gpt2-large | 0 | adamw_bnb | 2 | 42412MiB / 49140MiB |
| gpt2-large | 0,1 | adamw_torch | 1 | 30724MiB / 49140MiB, 21040MiB / 49140MiB |
| gpt2-large | 0,1 | adamw_torch | 2 | OOM |
| gpt2-large | 0,1 | adamw_bnb | 1 | 26820MiB / 49140MiB, 21042MiB / 49140MiB |
| gpt2-large | 0,1 | adamw_bnb | 2 | 44458MiB / 49140MiB, 36906MiB / 49140MiB |
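
For reference, a sketch of the kind of command line used for the runs above. This is illustrative only: it assumes the optimizer flag is exposed as --optim adamw_bnb (the value shown in the table; the option name was later changed during review), and the dataset arguments are just an example, not necessarily the ones used for these measurements.

# Hypothetical single-GPU run mirroring the first adamw_bnb row of the table.
CUDA_VISIBLE_DEVICES=0 python examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path gpt2-large \
    --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
    --do_train --per_device_train_batch_size 2 \
    --optim adamw_bnb \
    --output_dir /tmp/gpt2-large-bnb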

@HuggingFaceDocBuilder commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sgugger (Collaborator) left a review:

Thanks a lot for working on this! I've left a couple of comments.

Review threads (outdated, resolved): src/transformers/trainer.py (two threads), src/transformers/training_args.py, tests/extended/test_trainer_ext.py
@stas00 (Contributor) left a review:

Great work, @manuelciosici!

Let's add an actual test to it before merging this.

Review threads (outdated, resolved): src/transformers/trainer.py, tests/extended/test_trainer_ext.py (two threads)
@stas00 (Contributor) commented Feb 11, 2022

> I followed @TimDettmers's #14819 (comment) to override the embedding layers. However, I am unsure about a couple of things:
>
> Does the override need to happen before the model is loaded onto the GPU, as the official documentation describes for other overrides?

This requirement would be a problem in some use cases, e.g. with DeepSpeed ZeRO-3, which pre-loads the model directly onto the GPU during from_pretrained. But given that DeepSpeed already shards the optimizer states over multiple GPUs, it doesn't really need BNB. So we might need to instrument a counter-indication for the BNB + DS ZeRO-3 combination.

I'm not sure about other cases where a model ends up on GPU - if I'm not mistaken, DS is the only one, @sgugger?

> Are there any pitfalls to my current approach to identifying Embedding layers? It seems to work fine for RoBERTa and for GPT-2.

Commented here: #15622 (comment)

@sgugger (Collaborator) commented Feb 11, 2022

Normally, when creating the optimizer, the model has been moved to the proper device already, except in the following cases:

  • model parallelism (it has been split across devices already)
  • deepspeed and fairscale
  • evaluation-only in fp16/bf16 full eval.

@stas00 (Contributor) commented Feb 11, 2022

> Normally, when creating the optimizer, the model has been moved to the proper device already, except in the following cases:
>
> * model parallelism (it has been split across devices already)

So this is not an exception: the model is not on CPU, we just don't move it in the Trainer - the modeling code does.

> * deepspeed and fairscale

Yes, except DeepSpeed ZeRO-3, where it's already moved to GPU - we just don't do it in the Trainer.

> * evaluation-only in fp16/bf16 full eval.

Check - but it's irrelevant to the optimizer.

So, to summarize Sylvain's list of exceptions: in the general case, the model should already be on GPU.

So we need to wait for Tim to let us know whether that's a problem or whether it has a workaround.

@stas00 added the WIP label Mar 13, 2022
@huggingface deleted a comment from github-actions bot Mar 13, 2022
@stas00 (Contributor) commented Mar 13, 2022

@TimDettmers, if you get a chance, could you please address the questions directed to you so that this PR can be unblocked and the BNB integration added to the HF Trainer? Thank you!

manuelciosici and others added 2 commits March 30, 2022 13:53
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
@TimDettmers (Contributor) commented:
> I followed @TimDettmers's #14819 (comment) to override the embedding layers. However, I am unsure about a couple of things:
> Does the override need to happen before the model is loaded onto the GPU, as the official documentation describes for other overrides?
>
> This requirement would be a problem in some use cases, e.g. with DeepSpeed ZeRO-3, which pre-loads the model directly onto the GPU during from_pretrained. But given that DeepSpeed already shards the optimizer states over multiple GPUs, it doesn't really need BNB. So we might need to instrument a counter-indication for the BNB + DS ZeRO-3 combination.
>
> I'm not sure about other cases where a model ends up on GPU - if I'm not mistaken, DS is the only one, @sgugger?
>
> Are there any pitfalls to my current approach to identifying Embedding layers? It seems to work fine for RoBERTa and for GPT-2.
>
> Commented here: #15622 (comment)

The new implementation of the override no longer depends on when the model is transferred to the GPU or when the override is registered. It takes the following signature:

GlobalOptimManager.get_instance().register_module_override(module, 'weight', {'optim_bits': 32})

where weight is the parameter name to override. In this case, one can use this on the embedding layer (skipping the positional embedding).
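
For concreteness, here is a minimal sketch of how such overrides could be registered over a model's embedding layers. It is an illustration only (assuming bitsandbytes exposes Adam8bit and GlobalOptimManager under bitsandbytes.optim, and using a plain isinstance check on nn.Embedding), not necessarily the exact code added in this PR:

import bitsandbytes as bnb
import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Use the bitsandbytes 8-bit Adam for all parameters...
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=5e-5)

# ...except the embedding weights, which keep 32-bit optimizer state via the
# per-module override described above.
manager = bnb.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})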

@TimDettmers (Contributor) commented:
> Normally, when creating the optimizer, the model has been moved to the proper device already, except in the following cases:
>
>   • model parallelism (it has been split across devices already)
>   • deepspeed and fairscale
>   • evaluation-only in fp16/bf16 full eval.

These issues should be resolved with the new parameter override which is independent of when the parameters are transferred to the device.

@HuggingFaceDocBuilderDev commented Apr 15, 2022

The documentation is not available anymore as the PR was closed or merged.

@stas00 (Contributor) left a review:

This looks great now.

Thank you for working on this, @manuelciosici!

And thank you @TimDettmers for supporting the sorting out process!

Let's ask @sgugger to have another look before we merge this.

@stas00 requested a review from sgugger April 15, 2022 03:52
@stas00 marked this pull request as ready for review April 15, 2022 03:55
@sgugger (Collaborator) left a review:

Thanks for all the work on this. It's almost ready to be merged; I just have one small request: replace is_bnb_available everywhere with is_bitsandbytes_available. Since we have a lot of these is_xxx_available helpers and not all contributors might know this library, it will make it clearer to everyone what this is :-)
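
(For illustration, an availability helper of that shape might look like the following; this is a hedged sketch rather than the exact code in utils/import_utils.py.)

import importlib.util

def is_bitsandbytes_available():
    # True if the bitsandbytes package can be imported.
    return importlib.util.find_spec("bitsandbytes") is not None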

Review threads: src/transformers/utils/import_utils.py (outdated, resolved), tests/extended/test_trainer_ext.py (outdated, resolved), tests/extended/test_trainer_ext.py (resolved)
@manuelciosici (Contributor, Author) commented:
@stas00 I caught up with your work on testing and with @sgugger's subsequent requests. Is there anything else I should do on this PR for it to be ready to merge?

@stas00 (Contributor) commented Apr 19, 2022

Just waiting for @sgugger to have one last look after moving the require_* decorators to testing_utils.py, and I think this is good to be merged.
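
(A hedged sketch of what such a require_* decorator in testing_utils.py typically looks like, assuming the is_bitsandbytes_available helper requested in the review above; the exact merged code may differ.)

import unittest

from transformers.utils.import_utils import is_bitsandbytes_available


def require_bitsandbytes(test_case):
    # Skip the decorated test unless bitsandbytes is installed.
    if not is_bitsandbytes_available():
        return unittest.skip("test requires bitsandbytes")(test_case)
    return test_case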

@sgugger (Collaborator) left a review:

All good, thanks again for all the work on this!

@sgugger merged commit 3104036 into huggingface:main Apr 19, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Overide bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
@younesbelkada (Contributor) commented Aug 11, 2022

Hi there!
Jumping in on this PR after it has been merged. It appears that two tests from this PR are not passing:


def test_bnb_adam8bit_no_bnb(self):
    args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
    # Pretend that bnb does not exist, even if installed. By setting bnb to None,
    # importing bnb will fail even if bnb is installed.
    with patch.dict("sys.modules", {"bnb.optim": None}):
        with self.assertRaises(ValueError):
            Trainer.get_optimizer_cls_and_kwargs(args)

This test should be fixed by #18584; it was failing because of a very small typo. As for the second test,

test_run_seq2seq_bnb

I suspect it has never been run on our side: it is the only test in transformers that requires bitsandbytes, and it was always skipped because we did not install bitsandbytes on our Docker image until PR #17901 (see for example this log line: tests/extended/test_trainer_ext.py::TestTrainerExt::test_run_seq2seq_bnb SKIPPED [ 21%]).
But I did not manage to reproduce the failing test: it passes on my testing VM both with the latest bitsandbytes and with the version we use in the Docker image (bitsandbytes==0.31.5).
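
A sketch of what the presumed fix amounts to (an assumption on my part: since the Trainer imports the 8-bit optimizer via "from bitsandbytes.optim import Adam8bit", the module to hide in sys.modules would be bitsandbytes.optim rather than bnb.optim):

import unittest
from unittest.mock import patch

from transformers import Trainer, TrainingArguments
from transformers.training_args import OptimizerNames


class BnbAdam8bitNoBnbTest(unittest.TestCase):
    def test_bnb_adam8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
        # Hide the module the Trainer actually imports so optimizer creation fails.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)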

cc @TimDettmers @manuelciosici @stas00

@stas00 (Contributor) commented Aug 11, 2022

Thank you, @younesbelkada

@ydshieh, do you think it'd be OK to add bitsandbytes to the nightly tests workflow so that bnb tests run?

The installation is just CUDA-version specific:

pip install bitsandbytes-cudaXXX

https://github.com/facebookresearch/bitsandbytes#requirements--installation

@ydshieh (Collaborator) commented Aug 11, 2022

@stas00 I could add it and see how things go.

But @younesbelkada added it to the scheduled CI (which means it runs on GPU) with

RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5

I am a bit confused about why there is no CUDA version specified there.

@younesbelkada (Contributor) commented:
Thanks @stas00 and @ydshieh!
Adding the CUDA version is no longer needed, and according to @TimDettmers the facebookresearch repo is not the one we should refer to from now on.

pip install bitsandbytes should be sufficient for now (I have to update the Dockerfile, though).

@younesbelkada (Contributor) commented:
Here is the repo we have to refer to: https://github.com/TimDettmers/bitsandbytes

@stas00 (Contributor) commented Aug 11, 2022

Oh, OK, I missed that you already added it; nothing to do then.

@TimDettmers, would it be possible to archive the original repo and post a link to the new repo at the top of its README? Otherwise users will have no idea that they should use the new repo instead. Thank you!

@stas00 (Contributor) commented Aug 11, 2022

Also note that we are linking to the old repo:

examples/research_projects/robust-speech-event/README.md:[bitsandbytes](https://github.com/facebookresearch/bitsandbytes) to replace the
docs/source/en/perf_train_gpu_one.mdx:- 2 bytes * number of parameters for 8-bit AdamW optimizers like [bitsandbytes](https://github.com/facebookresearch/bitsandbytes)
docs/source/en/perf_train_gpu_one.mdx:On the other hand [8bit BNB optimizer](https://github.com/facebookresearch/bitsandbytes) can save 3/4 of memory normally used by a typical AdamW optimizer if it is configured to quantize all optimizer states, but in some situations only some optimizer states are quintized and then more memory is used. XXX: update once  https://github.com/huggingface/transformers/pull/15622 is merged.
docs/source/en/perf_train_gpu_one.mdx:In contrast to the previous approaches is this one not integrated into the [`Trainer`] as a simple flag. We need to install the 8-bit optimizer and then pass it as a custom optimizer to the [`Trainer`]. Follow the installation guide in the Github [repo](https://github.com/facebookresearch/bitsandbytes) to install the `bitsandbytes` library that implements the 8-bit Adam optimizer.

@TimDettmers, should we fix those to point to the new repo instead?

@ydshieh (Collaborator) commented Aug 18, 2022

> But I did not manage to reproduce the failing test: it passes on my testing VM both with the latest bitsandbytes and with the version we use in the Docker image (bitsandbytes==0.31.5).

Hi @younesbelkada, are you running inside a Docker container on a VM similar to the CI runners (NVIDIA T4)?
Could you try to get the values of gpu_peak_mem_orig and gpu_peak_mem_bnb?
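
(Outside the test harness, one quick way to eyeball such numbers is via torch's peak-memory counters; this is a hypothetical helper for manual debugging, not the test's own code.)

import torch


def report_peak_gpu_mem(tag: str) -> None:
    # Print the peak GPU memory allocated since the last reset, e.g. around a
    # Trainer.train() call, to compare adamw_torch vs the bnb optimizer.
    peak = torch.cuda.max_memory_allocated()
    print(f"{tag}: {peak} bytes ({peak / 2**20:.0f} MiB)")


# Usage sketch:
# torch.cuda.reset_peak_memory_stats()
# trainer.train()
# report_peak_gpu_mem("gpu_peak_mem_orig")  # or "gpu_peak_mem_bnb"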

@younesbelkada (Contributor) commented Aug 18, 2022

Hi @ydshieh! I am running on a VM similar to the CI runners; let me retry to reproduce as you suggested.

@younesbelkada (Contributor) commented Aug 18, 2022

The test is passing on my VM. On the VM I get:

gpu_peak_mem_orig: 243490816
gpu_peak_mem_bnb: 8053248

But I re-ran the test with CUDA_VISIBLE_DEVICES=0 (running the test on a single GPU) and it failed. Maybe the test was designed for a multi-GPU setup. Do you think we should just add the @require_torch_multigpu decorator to this test?

EDIT: I saw that even on multi-GPU the test was failing in the Docker container.

@younesbelkada (Contributor) commented:
In a single-GPU setup:

gpu_peak_mem_bnb: 510833664
gpu_peak_mem_orig: 509707264

Labels: WIP

Successfully merging this pull request may close these issues:

RFC: Integrating bitsandbytes 8-bit optimizer / adding Embedding Norm

8 participants