
FEAT / Optim: Add GaLore optimizer #29588

Merged: 44 commits merged into huggingface:main on Mar 19, 2024

Conversation

@younesbelkada (Contributor) commented Mar 11, 2024:

What does this PR do?

As per title, adds the GaLore optimizer from https://github.com/jiaweizzhao/GaLore

Fixes: #29512

This is how I am currently testing the API:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"]
)

model_id = "mistralai/Mistral-7B-v0.1"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```

cc @pacman100 @muellerzr


@younesbelkada (Contributor, Author):

see OpenAccess-AI-Collective/axolotl#1370 (comment) for intermediate results

```python
            OptimizerNames.GALORE_ADAMW_8BIT,
            OptimizerNames.GALORE_ADAFACTOR,
        ]:
            from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
```
Contributor:

We need an import check here, no? 😄

Contributor (Author):

Yeah, indeed, many things can be optimized; for now it's a really rough draft, I will focus on polishing everything next!
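For context, a minimal sketch of the kind of import guard being requested, assuming a simple availability check; the helper name below is illustrative and modeled on transformers' existing `is_*_available` utilities, not the exact code that landed in the PR:

```python
import importlib.util

def is_galore_torch_available() -> bool:
    # Hypothetical helper: report whether the `galore_torch` package can be imported.
    return importlib.util.find_spec("galore_torch") is not None

if not is_galore_torch_available():
    raise ImportError(
        "You need to install the `galore_torch` package to use GaLore optimizers: "
        "`pip install galore-torch`."
    )

from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
```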

Contributor:

Great, ping me when you're ready for a full review :)

@younesbelkada marked this pull request as ready for review March 11, 2024 16:01
@younesbelkada changed the title from DRAFT / Optim: Add GaLore optimizer to FEAT / Optim: Add GaLore optimizer on Mar 11, 2024
@muellerzr (Contributor) left a comment:

Thanks! Overall this looks good to me, nice! Excited to see this in transformers :)

Left some suggestions.

@BenjaminBossan (Member) left a comment:

Pretty cool work, thanks for adding this Younes. After this, we should also add this to PEFT ;)

```
@@ -696,6 +699,11 @@ class TrainingArguments:
            for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
            [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
            `PeftModel` from peft.
        galore_target_modules (`List[str]`, *optional*):
```
Member:

It would be cool if we could use the same mechanism as for target_modules in PEFT:

- allow `str` for regex match
- if a list of `str`, not only match exact names, but also if the module name ends with the target module name
- passing `all-linear` to match all linear layers

But I understand that we don't want to copy the whole mechanism to transformers. Maybe we can think of a way to factor this code out in the future, so that we can re-use it in multiple places.
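As an illustration only, here is a minimal sketch of the matching behavior described above (regex for a plain string, exact-or-suffix match for a list of strings, and a special `all-linear` keyword). The function name and exact semantics are assumptions, not the code that ended up in the PR:

```python
import re

from torch import nn

def matches_target(module_name: str, module: nn.Module, optim_target_modules) -> bool:
    # Illustrative helper only; the real logic lives in the Trainer utilities.
    if optim_target_modules == "all-linear":
        # special keyword: target every linear layer, whatever its name
        return isinstance(module, nn.Linear)
    if isinstance(optim_target_modules, str):
        # a single string is treated as a regex over the full module path
        return re.fullmatch(optim_target_modules, module_name) is not None
    # a list of strings matches exact module names or name suffixes
    return any(
        module_name == target or module_name.endswith(f".{target}")
        for target in optim_target_modules
    )
```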

Contributor (Author):

This should now be supported!


```python
galore_params = []
for module_name, module in model.named_modules():
    if not isinstance(module, nn.Linear):
```
Member:

I wonder if we should raise an error if the target module name matches but the layer type is not Linear. Let's assume that a user matches N linear layers and accidentally 1 other type like Embedding, currently the embedding layer would be ignored but the user doesn't get any error message or warning.

Collaborator:

This sounds like a good idea!

Contributor (Author):

Yes! I think we should emit a warning instead, so that users can still pass simple regexes such as `.*.attn.*` - lmk wdyt!
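A minimal sketch of the warning-based approach being discussed, extending the loop quoted above. `check_target_module_exists`, as well as `model` and `args` being in scope, are assumptions used for illustration rather than the PR's final code:

```python
import logging

from torch import nn

logger = logging.getLogger(__name__)

galore_params = []
for module_name, module in model.named_modules():
    # hypothetical helper: True if the name matches `optim_target_modules`
    target_module_exists = check_target_module_exists(args.optim_target_modules, module_name)
    if not isinstance(module, nn.Linear):
        if target_module_exists:
            # the name matched, but GaLore only handles Linear layers: warn instead of
            # silently skipping, so the user knows e.g. an Embedding was left out
            logger.warning(
                f"{module_name} matches `optim_target_modules` but is not `nn.Linear`; "
                "GaLore will not be applied to it."
            )
        continue
    if target_module_exists:
        galore_params.append(module.weight)
```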

@amyeroberts (Collaborator) left a comment:

Thanks for adding!

Just a comment on the training args

@PenutChen:

Is it possible to integrate per-layer weight updates as described in the GaLore paper?

[figure from the GaLore paper illustrating per-layer weight updates]

It seems to reduce memory usage significantly:

[figure: memory usage comparison]

But the original implementation is somewhat complex.

@pacman100 (Contributor) left a comment:

Thank you @younesbelkada for adding GaLore 🔥! Others have added most of the comments that I agree with.

@hiyouga (Contributor) commented Mar 13, 2024:

> Is it possible to integrate per-layer weight updates as described in the GaLore paper?
>
> [figure from the GaLore paper illustrating per-layer weight updates]
>
> It seems to reduce memory usage significantly:
>
> [figure: memory usage comparison]
>
> But the original implementation is somewhat complex.

Sure, per-layer weight updates are crucial to the current implementation of GaLore (this may not be the case in the future). Without them, a 7B model would require an additional 14GB of GPU memory for the gradients, making full-parameter fine-tuning infeasible on a 24GB GPU.

You may refer to our discussion jiaweizzhao/GaLore#6 for more empirical results.
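For readers unfamiliar with the trick, here is a minimal sketch of the layer-wise update recipe from the GaLore repository: one optimizer per parameter, stepped from a post-accumulate-grad hook (PyTorch >= 2.1) so full-model gradients are never held in memory at once. `model` and the hyperparameters are illustrative, and the param-group keys follow the upstream README rather than this PR's final code:

```python
import torch
from galore_torch import GaLoreAdamW

lr, rank = 1e-5, 128
optimizer_dict = {}
for p in model.parameters():
    if not p.requires_grad:
        continue
    if p.ndim == 2:
        # matrix-shaped weights get the low-rank GaLore projection
        optimizer_dict[p] = GaLoreAdamW(
            [{"params": [p], "rank": rank, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"}],
            lr=lr,
        )
    else:
        # biases, norms, etc. use a plain optimizer
        optimizer_dict[p] = torch.optim.AdamW([p], lr=lr)

def optimizer_hook(p):
    # runs right after p.grad has been accumulated for the current step
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)
```

With this setup the training loop only calls `loss.backward()` and never a global `optimizer.step()`, which is also why features like DDP and gradient clipping need special care.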

First make sure to install GaLore from its official repository:

```bash
pip install git+https://github.com/jiaweizzhao/GaLore
```
Contributor:

GaLore has released an official package: pip install galore-torch

https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#install-galore-optimizer

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
@amyeroberts (Collaborator) left a comment:

Thanks for all the work adding this!

Just some small nits. It would be great if you could address the warning for the attention layers before merge

```bash
pip install galore-torch
```

Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regew or full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):
Collaborator:

Suggested change: typo, "regew" should be "regex" in the sentence above.

Contributor (Author):

Argh, sadly it has already been taken care of by 898a3c5.

```python
trainer.train()
```

Note that layer-wise optimization is a bit experimental and does not support DDP (Distributed Data Parallel); you can therefore run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.
Collaborator:

nice note :)
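For reference, a hedged sketch of what enabling the layer-wise variant from the note above might look like on a single GPU; the exact `optim` string is an assumption based on the layer-wise optimizers added in this PR, so check the `OptimizerNames` enum for the final names:

```python
from transformers import TrainingArguments

# single GPU only: the layer-wise variant does not support DDP (see the note above)
args = TrainingArguments(
    output_dir="./test-galore-layerwise",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw_layerwise",  # assumed name of the layer-wise AdamW variant
    optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
```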

```
@@ -1354,6 +1366,13 @@ class TrainingArguments:
        },
    )

    optim_target_modules: Optional[List[str]] = field(
```
Collaborator:

nit

Suggested change: use `optim_target_modules: Optional[Union[str, List[str]]] = field(` instead of `optim_target_modules: Optional[List[str]] = field(`.

Contributor (Author):

FYI, that suggestion broke the CI for some reason; 73dcabb fixed it. I think `Optional` and `Union` are not compatible somehow.

@winglian (Contributor):

> Nice catch @winglian! 6ff3762 should fix it, can you try now?

It progresses to start training now, but I've been getting an odd error and I'm not sure if it's a transformers, axolotl, or galore_torch issue.

  File "/workspace/transformers/src/transformers/trainer.py", line 1774, in train
    return inner_training_loop(
  File "/workspace/transformers/src/transformers/trainer.py", line 2170, in _inner_training_loop
    self.optimizer.step()
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/optimizer.py", line 145, in step
    self.optimizer.step(closure)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
    return wrapped(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    out = func(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/galore_torch/adamw.py", line 96, in step
    grad = state["projector"].project(grad, state["step"])
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/galore_torch/galore_projector.py", line 17, in project
    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/galore_torch/galore_projector.py", line 87, in get_orthogonal_matrix
    A = U[:, :rank] @ torch.diag(s[:rank])
TypeError: slice indices must be integers or None or have an __index__ method

@winglian (Contributor):

So setting optim_args doesn't work and results in that error. Am I setting it incorrectly?

```python
args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=1,
    optim="galore_adamw",
    optim_args="rank=128",
    optim_target_modules=["attn", "mlp"]
)
```

I believe this is the correct way according to:

```python
optim_args = {}
if args.optim_args:
    for mapping in args.optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        optim_args[key] = value
```

@winglian (Contributor):

@younesbelkada the bug is in the `optim_args` string splitting: it leaves every value as a string, and GaLore expects an int. It's probably going to need some checks to determine whether a value is an int or a float, since `rank` expects an int (it doesn't work with a float) but `scale` is a float. An easier solution might be to have `TrainingArguments.optim_args` be a `Union[str, Dict[str, Any]]` instead.
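A minimal sketch of the kind of type coercion being discussed, building on the parsing snippet quoted earlier; this is an illustration of the fix, not the exact code that was later committed:

```python
def parse_optim_args(optim_args_str: str) -> dict:
    """Parse 'rank=128, scale=0.10' into {'rank': 128, 'scale': 0.1}, coercing types."""
    parsed = {}
    if not optim_args_str:
        return parsed
    for mapping in optim_args_str.replace(" ", "").split(","):
        key, value = mapping.split("=")
        try:
            parsed[key] = int(value)        # e.g. rank, update_proj_gap expect ints
        except ValueError:
            try:
                parsed[key] = float(value)  # e.g. scale expects a float
            except ValueError:
                parsed[key] = value         # fall back to the raw string
    return parsed

print(parse_optim_args("rank=128, update_proj_gap=100, scale=0.10"))
# {'rank': 128, 'update_proj_gap': 100, 'scale': 0.1}
```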

younesbelkada and others added 2 commits March 19, 2024 09:21
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@younesbelkada (Contributor, Author) commented Mar 19, 2024:

Nice catch @winglian, in 57e7096 I should have properly taken care of everything, and I can confirm it works fine. The correct way of passing GaLore args is the following:

```python
args = TrainingArguments(
    tmpdir,
    learning_rate=1e-9,
    logging_steps=5,
    optim="galore_adamw",
    optim_args="rank=64, update_proj_gap=100, scale=0.10",
    optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
```

@younesbelkada merged commit f6261d7 into huggingface:main on Mar 19, 2024
22 checks passed
@fakerybakery:

Hi,
According to the GaLore README:

> We are working on the offical release of GaLore. In the meanwhile, please feel free to try the pre-release version and provide feedback to us. Currently, the pre-release version (e.g., GaLore optimizers) should provide a decent memory reduction and accurate simulation of GaLore algorithm.

When the official version of GaLore is released, will it be integrated into Transformers?

@NicolasMejiaPetit:

Is there any way to get this optimizer fully offloaded to the CPU, similar to how DeepSpeed does it with a large 32-bit Adam optimizer? That way the entire system can work together with little interference over the PCIe bus (compared to going into shared memory): GPU memory is optimized to the max, allowing the GPU to handle the highest batch sizes possible, while the CPU deals with the optimizer.

@kiddyboots216:

> Is there any way to get this optimizer fully offloaded to the CPU, similar to how DeepSpeed does it with a large 32-bit Adam optimizer? That way the entire system can work together with little interference over the PCIe bus (compared to going into shared memory): GPU memory is optimized to the max, allowing the GPU to handle the highest batch sizes possible, while the CPU deals with the optimizer.

If you use ZeRO stage 2/3 (offload all optimizer state to the CPU) as done in DeepSpeed/FSDP, then you will not have any VRAM usage from optimizer state. There's not much point in using GaLore if you're already offloading the optimizer state to the CPU, because GaLore only reduces the optimizer state. The layer-wise backprop isn't part of GaLore; it can be done with any optimizer.
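To make the comparison concrete, here is a hedged sketch of the CPU-offload setup being described, passing a standard DeepSpeed ZeRO stage 2 optimizer-offload config through `TrainingArguments`; the values are illustrative, DeepSpeed must be installed, and the script still needs to be launched with `accelerate launch` or `deepspeed`:

```python
from transformers import TrainingArguments

# Keep optimizer state (and its updates) on the CPU via ZeRO stage 2 offload,
# instead of shrinking it on the GPU with GaLore.
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

args = TrainingArguments(
    output_dir="./test-zero2-offload",
    per_device_train_batch_size=2,
    deepspeed=ds_config,
)
```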

@NicolasMejiaPetit commented Mar 19, 2024:

> Is there any way to get this optimizer fully offloaded to the CPU, similar to how DeepSpeed does it with a large 32-bit Adam optimizer? That way the entire system can work together with little interference over the PCIe bus (compared to going into shared memory): GPU memory is optimized to the max, allowing the GPU to handle the highest batch sizes possible, while the CPU deals with the optimizer.

> If you use ZeRO stage 2/3 (offload all optimizer state to the CPU) as done in DeepSpeed/FSDP, then you will not have any VRAM usage from optimizer state. There's not much point in using GaLore if you're already offloading the optimizer state to the CPU, because GaLore only reduces the optimizer state. The layer-wise backprop isn't part of GaLore; it can be done with any optimizer.

Wait, you can do this without having to use 32-bit Adam on the CPU? So I could do something like 8-bit Adam fully offloaded to the CPU? Last time I tried this it didn't work, but I must not have done it the way you are saying, because I'm trying to get the CPU to do the work, not just hold the memory in system RAM (the way paged Adam works).

Apologies for this being off topic. To make up for it, I started testing GaLore with Unsloth/the HF trainer (on Windows). I can confirm it works, and I'm getting great results. The loss is what I would normally get, with the added bonus that it's letting me use a batch size of 8, where previously with paged Adam 8-bit I was at a batch size of 4 training a 7B model.

@kiddyboots216 commented Mar 19, 2024:

> Is there any way to get this optimizer fully offloaded to the CPU, similar to how DeepSpeed does it with a large 32-bit Adam optimizer? That way the entire system can work together with little interference over the PCIe bus (compared to going into shared memory): GPU memory is optimized to the max, allowing the GPU to handle the highest batch sizes possible, while the CPU deals with the optimizer.

> If you use ZeRO stage 2/3 (offload all optimizer state to the CPU) as done in DeepSpeed/FSDP, then you will not have any VRAM usage from optimizer state. There's not much point in using GaLore if you're already offloading the optimizer state to the CPU, because GaLore only reduces the optimizer state. The layer-wise backprop isn't part of GaLore; it can be done with any optimizer.

> Wait, you can do this without having to use 32-bit Adam on the CPU? So I could do something like 8-bit Adam fully offloaded to the CPU? Last time I tried this it didn't work, but I must not have done it the way you are saying, because I'm trying to get the CPU to do the work, not just hold the memory in system RAM (the way paged Adam works).

Could you try doing this the way it's referenced here: TimDettmers/bitsandbytes#89 (comment)? I'm not sure why it wouldn't work.

(sorry to go off-topic on a PR with many people and commits on it)

itazap pushed a commit that referenced this pull request May 14, 2024
* add galore v1

* add import

* add tests and doc

* fix doctest

* forward contrib credits from discussions

* forward contrib credits from discussions

* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix failing tests'

* switch to `optim_target_modules` and clarify docs

* more clarification

* enhance lookup logic

* update a test to add peak memory

* add regex, all-linear and single string support

* add layer-wise optimization through DummyOptimizers and LRSchedulers

* forward contrib credits from discussions and original idea

* add a section about DDP not supported in layerwise

* Update src/transformers/trainer.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix self

* check only if layer_wise

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* oops

* make use of intervals

* clarify comment

* add matching tests

* GaLoRe -> GaLore

* move to `get_scheduler`

* add note on docs

* add a warning

* adapt a bit the docs

* update docstring

* support original API

* Update docs/source/en/trainer.md

* slightly refactor

* Update docs/source/en/trainer.md

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix args parsing and add tests

* remove warning for regex

* fix type hint

* add note about extra args

* make `is_regex` return optional

---------

Co-authored-by: Maxime <maximegmd @users.noreply.github.com>
Co-authored-by: Wing Lian <winglian @users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: hiyouga <hiyouga@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Successfully merging this pull request may close these issues.

feat: GaLore support