
FEAT / Optim: Add GaLore optimizer #29588

Merged (44 commits) on Mar 19, 2024
Changes from 37 commits

Commits
b31ce79
add galore v1
younesbelkada Mar 11, 2024
58169f1
add import
younesbelkada Mar 11, 2024
9032635
add tests and doc
younesbelkada Mar 11, 2024
136f104
fix doctest
younesbelkada Mar 11, 2024
a5483b3
forward contrib credits from discussions
Mar 11, 2024
887d3ad
forward contrib credits from discussions
Mar 11, 2024
d6f119f
Apply suggestions from code review
younesbelkada Mar 11, 2024
3fae229
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Mar 11, 2024
c8c50f8
fix failing tests'
younesbelkada Mar 11, 2024
2bdda68
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 13, 2024
630bd13
switch to `optim_target_modules` and clarify docs
younesbelkada Mar 13, 2024
a871b75
more clarification
younesbelkada Mar 13, 2024
cb6cd7e
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 13, 2024
51b7b29
enhance lookup logic
younesbelkada Mar 13, 2024
3da3b90
update a test to add peak memory
younesbelkada Mar 13, 2024
9115c94
add regex, all-linear and single string support
younesbelkada Mar 13, 2024
0b4ba83
add layer-wise optimization through DummyOptimizers and LRSchedulers
younesbelkada Mar 13, 2024
3e5930e
forward contrib credits from discussions and original idea
hiyouga Mar 13, 2024
a16d3a8
add a section about DDP not supported in layerwise
younesbelkada Mar 13, 2024
29e7e94
Update src/transformers/trainer.py
younesbelkada Mar 13, 2024
18ea144
fix self
younesbelkada Mar 13, 2024
7800bf1
check only if layer_wise
younesbelkada Mar 13, 2024
e022bdd
Update src/transformers/training_args.py
younesbelkada Mar 14, 2024
830c68d
oops
younesbelkada Mar 14, 2024
b640e98
make use of intervals
younesbelkada Mar 14, 2024
14a89b2
clarify comment
younesbelkada Mar 14, 2024
6f7102d
add matching tests
younesbelkada Mar 14, 2024
c11cb63
GaLoRe -> GaLore
younesbelkada Mar 14, 2024
3678201
move to `get_scheduler`
younesbelkada Mar 14, 2024
fdc4b2a
add note on docs
younesbelkada Mar 14, 2024
e7ce9b7
add a warning
younesbelkada Mar 14, 2024
91d6436
adapt a bit the docs
younesbelkada Mar 15, 2024
b9e338a
update docstring
younesbelkada Mar 15, 2024
6ff3762
support original API
younesbelkada Mar 17, 2024
0d0440a
Update docs/source/en/trainer.md
younesbelkada Mar 17, 2024
832f2be
slightly refactor
younesbelkada Mar 18, 2024
898a3c5
Update docs/source/en/trainer.md
younesbelkada Mar 18, 2024
ed3ad4a
Update src/transformers/training_args.py
younesbelkada Mar 19, 2024
57e7096
fix args parsing and add tests
younesbelkada Mar 19, 2024
64ccfa6
remove warning for regex
younesbelkada Mar 19, 2024
4413f07
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 19, 2024
73dcabb
fix type hint
younesbelkada Mar 19, 2024
1987b7a
add note about extra args
younesbelkada Mar 19, 2024
db2bf21
make `is_regex` return optional
younesbelkada Mar 19, 2024
92 changes: 92 additions & 0 deletions docs/source/en/trainer.md
@@ -252,6 +252,98 @@ trainer = Trainer(..., args=training_args)

NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.

## GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while using less memory than common low-rank adaptation methods such as LoRA.

First make sure to install GaLore from its official repository:

```bash
pip install galore-torch
```

Then simply pass one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` as `optim`, together with `optim_target_modules`, which can be a list of strings, a regex, or a full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
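
For intuition only, here is a very rough sketch of the projection idea from the paper. This is not the `galore_torch` implementation: the projection-side choice, Adam bias correction, and weight decay are omitted, and all names are illustrative.

```python
import torch

def galore_adamw_step(weight, grad, state, rank=4, update_proj_gap=200,
                      scale=0.25, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Sketched GaLore-style AdamW step for a single 2D weight matrix."""
    if state["step"] % update_proj_gap == 0:
        # Refresh the low-rank projector from an SVD of the current gradient.
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["P"] = U[:, :rank]                      # (m, r) projector
    P = state["P"]

    R = P.T @ grad                                    # project: (r, n) instead of (m, n)
    # Adam moments are kept only for the small (r, n) matrix -> memory savings.
    state["m"] = betas[0] * state["m"] + (1 - betas[0]) * R
    state["v"] = betas[1] * state["v"] + (1 - betas[1]) * R ** 2
    low_rank_update = state["m"] / (state["v"].sqrt() + eps)

    weight -= lr * scale * (P @ low_rank_update)      # project back and apply
    state["step"] += 1

m, n, rank = 16, 8, 4
weight = torch.randn(m, n)
state = {"step": 0, "m": torch.zeros(rank, n), "v": torch.zeros(rank, n)}
for _ in range(3):
    grad = torch.randn(m, n)                          # stand-in for a real gradient
    galore_adamw_step(weight, grad, state, rank=rank)
```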

Currently only linear layers are considered GaLore layers and trained with low-rank decomposition; all remaining layers are optimized in the conventional manner.
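
For instance, assuming the same `TrainingArguments` as above, `optim_target_modules` can be passed in any of the forms below (a sketch; the module names and the regex are illustrative and depend on your model's architecture):

```python
from transformers import TrainingArguments

# 1) A list of strings matched against module names
args = TrainingArguments(
    output_dir="./test-galore",
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"],
)

# 2) A single string treated as a regex against the full module path
args = TrainingArguments(
    output_dir="./test-galore",
    optim="galore_adamw",
    optim_target_modules=r".*\.attn\..*",
)

# 3) The special value "all-linear" to target every linear layer
args = TrainingArguments(
    output_dir="./test-galore",
    optim="galore_adamw",
    optim_target_modules="all-linear",
)
```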

Note that it will take a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.
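
The GaLore-specific hyperparameters default to `rank=128`, `update_proj_gap=200`, `scale=0.25` and `proj_type="std"` (see the `Trainer` changes below). Assuming they can be overridden through `TrainingArguments.optim_args` as a comma-separated `key=value` string, as the argument-parsing commits in this PR suggest, a hedged sketch would be:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"],
    # Assumption: extra GaLore arguments are forwarded through `optim_args`.
    optim_args="rank=64, update_proj_gap=100, scale=0.10",
)
```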

You can also perform layer-wise optimization by appending `_layerwise` to the optimizer name, as below:

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw_layerwise",
optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

Note that layerwise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel); thus you can run the training script only on a single GPU. Please see [this section of the GaLore README](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.
Review comment (Collaborator): nice note :)


## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
27 changes: 27 additions & 0 deletions src/transformers/optimization.py
@@ -24,6 +24,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
from .trainer_utils import SchedulerType
from .utils import logging
from .utils.versions import require_version
@@ -362,6 +363,32 @@ def get_scheduler(
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

# If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
# recursively call `get_scheduler` to get the proper schedulers on each parameter
if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict = {}

for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
name,
optimizer=optimizer_dict[param],
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)

def scheduler_hook(param):
# Since the optimizer hook has been already attached we only need to
# attach the scheduler hook
if param.grad is not None:
scheduler_dict[param].step()

for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)

return LayerWiseDummyScheduler()

if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)

9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -70,6 +70,7 @@
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -324,6 +325,14 @@ def require_bs4(test_case):
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)


def require_galore_torch(test_case):
"""
Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
https://github.com/jiaweizzhao/GaLore
"""
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
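
For reference, the `require_galore_torch` decorator added above would be used like any other `require_*` decorator in the test suite. A minimal, hypothetical example (the test name and body are illustrative):

```python
import unittest

from transformers.testing_utils import require_galore_torch


@require_galore_torch
class GaLoreAvailabilityTest(unittest.TestCase):
    # Hypothetical test: skipped automatically when galore-torch is not installed.
    def test_import(self):
        import galore_torch  # noqa: F401
```
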
141 changes: 139 additions & 2 deletions src/transformers/trainer.py
@@ -83,6 +83,7 @@
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
@@ -111,6 +112,7 @@
RemoveColumnsCollator,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
@@ -140,6 +142,7 @@
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
@@ -1009,7 +1012,17 @@ def create_optimizer(self):
},
]

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")

# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
@@ -1032,7 +1045,9 @@ def create_optimizer(self):
return self.optimizer

@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
"""
Returns the optimizer class and optimizer parameters based on the training arguments.

@@ -1170,6 +1185,128 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
elif args.optim in [
OptimizerNames.GALORE_ADAMW,
OptimizerNames.GALORE_ADAMW_8BIT,
OptimizerNames.GALORE_ADAFACTOR,
OptimizerNames.GALORE_ADAMW_LAYERWISE,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
]:
if not is_galore_torch_available():
raise ImportError(
"You need to install `galore_torch` in order to use GaLore optimizers"
" install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
)
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
Review comment (Contributor): We need an import check here, no? 😄

Review comment (Contributor Author): yeah indeed, many things can be optimized, for now it's a really rough draft, will focus on polishing everything next!

Review comment (Contributor): Great, ping me when you're ready for a full review :)


is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")

optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
}

optimizer_cls = optimizer_mapping[args.optim]

if args.optim_target_modules is None:
raise ValueError(
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
)

if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)

if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")

logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)

all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)

galore_params = []
galore_params_names = []
Review comment (Member): Would be more efficient as a set but shouldn't matter much in the grand scheme of things.

for module_name, module in model.named_modules():
if not isinstance(module, nn.Linear):
Review comment (Member, @BenjaminBossan, Mar 11, 2024): I wonder if we should raise an error if the target module name matches but the layer type is not Linear. Let's assume that a user matches N linear layers and accidentally 1 other type like Embedding; currently the embedding layer would be ignored but the user doesn't get any error message or warning.

Review comment (Collaborator): This sounds like a good idea!

Review comment (Contributor Author): Yes! I think we should pass a warning to be able to pass simple regexes such as `.*.attn.*` - lmk wdyt!

# Warn in case we match but it's not a linear layer
if check_target_module_exists(args.optim_target_modules, module_name):
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. If you passed a regex `.*.attn.*` this is expected, otherwise please double check your `optim_target_modules`!"
)

continue

if not check_target_module_exists(args.optim_target_modules, module_name) and not all_linear:
continue

galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")

if len(galore_params) == 0:
Review comment (Contributor): @younesbelkada But it is not allowed for pure layered_adamw optimizers, right?

Review comment (Contributor Author): Hmmm, in that case users should just pass adamw instead of galore_adamw_layerwise I think.

Review comment (Contributor Author): Ahh, I think I got your point, we could indeed extend the optimizers and enable layer-wise optimization! This can be done in the scope of another follow-up PR!

Review comment (Contributor Author, @younesbelkada, Mar 13, 2024): cc @muellerzr @amyeroberts @BenjaminBossan @pacman100
The idea here would be to leverage optim_target_modules and think of a more general API to enable layer-wise optimization within Trainer, using the approach presented here with DummyOptimizer / DummyLRScheduler and post-gradient hooks. I propose to do that in a separate PR, but I can also do it here if you think it makes more sense to introduce both GaLore and per-layer optimization for all optimizers in the same PR.
I think it's wiser to do it in a separate PR, as layer-wise optimization might not be supported out of the box for many scenarios such as DeepSpeed / DDP, etc.

Review comment (Contributor): @younesbelkada Yeah, that sounds good to me. I guess I can comment out this check locally and wait for your PR.

Review comment: I guess it makes more sense to a) have a generic API for layer-wise optimization using the approach presented here, b) have GaLore, c) use the generic API to instantiate layer-wise GaLore; otherwise, after you add a generic API (and it seems like you already have a lot of code to do that, I don't see anything that's GaLore-specific), you would have to go back and do another PR just to refactor layer-wise GaLore.

Review comment (Collaborator): @kiddyboots216 I think it's fine to add GaLore first. We're just merging into the dev branch atm, so there aren't guarantees about API stability until we have a release. Adding it in a more general sense is more involved and will require more tests / hitting possible blockers. In this order, we can release the feature without being blocked by the development of the more general API.

raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)

non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]

galore_optim_kwargs = {
"rank": optim_args.pop("rank", 128),
"update_proj_gap": optim_args.pop("update_proj_gap", 200),
"scale": optim_args.pop("scale", 0.25),
"proj_type": optim_args.pop("proj_type", "std"),
}

# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
param_groups = [
{"params": non_galore_params},
{"params": galore_params, **galore_optim_kwargs},
]

if is_layerwise:
Review comment (Contributor): What about layerwise optimizers without galore? Can we have layerwise_adamw_8bit?

Review comment: Hah, we proposed the same thing at almost the same time! I think if this layerwise thing works then we should have an option for all optimizers, right?

Review comment (Contributor Author): great timing indeed!

# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")

optimizer_dict = {}
for param in non_galore_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)

def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()

for param in model.parameters():
param.register_post_accumulate_grad_hook(optimizer_hook)

optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
Review comment (Member): All of this looks super hacky. At first glance, it looks okay, but I could imagine that some assumptions will fail in certain edge cases. Not sure what can really be done, except to try to cover as many use cases as possible (lr scheduler, checkpointing, gradient clipping, ...) in tests. Those tests would have to be a bit more involved than just testing that there is no error in training. AFAICT, the referenced repo does not perform these tests.
I guess this would be a lot of extra work, so feel free to ignore my comment and then we just cross our fingers that it works :)

Review comment (Contributor Author): Yeah, I agree! For the other solution I guess we have to change the training logic, as you need to somehow guard optimizer.step() and call it only if you don't have hooks attached :/
@hiyouga has perhaps more experience with this approach; as you say, it's probably fine to keep it like this as layer-wise optimization is experimental and we can iterate in the future to see if it's stable with other features (thinking of DS, etc.). I'll add a note in the docs about it.

Review comment (Contributor): I think what @BenjaminBossan was concerned about is right. We have had some discussions with the authors of GaLore, and the current utilization of hooks is just a temporary solution. Actually, GaLore does not rely on hooks to save memory. We can expect that the PyTorch team will provide more flexible APIs to support GaLore in DDP or DS environments in the future.

Review comment (Contributor Author): I see, thanks for explaining! So I think we should just keep in mind this is temporary, and we might adapt our approach according to how things move in the future.


optimizer_kwargs.update({"params": param_groups})

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
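
For readers unfamiliar with the hook-based layer-wise trick used in the `Trainer` changes above (a dummy optimizer exposed to the training loop, plus one real optimizer per parameter stepped from `register_post_accumulate_grad_hook`), here is a minimal standalone sketch of the pattern. It uses plain `AdamW` instead of the GaLore optimizers and is only meant to illustrate the mechanism:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

# One small optimizer per parameter; the Trainer changes use GaLoreAdamW & friends here.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    # Runs right after the gradient for `param` has been accumulated, so each
    # parameter is updated (and its gradient freed) as soon as its gradient is ready.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop then only calls backward(); the optimizer it sees
# (LayerWiseDummyOptimizer in this PR) performs no-ops.
x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()  # hooks fire per parameter and apply the updates
```

The learning-rate scheduler side in `get_scheduler` follows the same pattern: one scheduler per parameter, stepped from a second post-accumulate hook, with a `LayerWiseDummyScheduler` returned to the training loop.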