FEAT / Optim: Add GaLore optimizer #29588
@@ -252,6 +252,98 @@ trainer = Trainer(..., args=training_args)

NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.

## GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while being more memory-efficient than common low-rank adaptation methods such as LoRA.

First make sure to install GaLore from the official repository:

```bash
pip install git+https://github.com/jiaweizzhao/GaLore
```

Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which should be a list of strings corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

> **Review comment:** Let's add here, or further below, that ...
```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```
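As shown further down in `get_optimizer_cls_and_kwargs`, this PR reads the GaLore hyperparameters `rank`, `update_proj_gap`, `scale` and `proj_type` from `optim_args`, with defaults taken from the official repository (128, 200, 0.25 and `"std"`). Below is a minimal sketch of overriding them, assuming the usual comma-separated `key=value` string format of `TrainingArguments.optim_args` and that the parsed values end up with the right numeric types:

```python
# Sketch only: override the GaLore hyperparameters read from `optim_args`.
# Defaults in this PR: rank=128, update_proj_gap=200, scale=0.25, proj_type="std".
args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"],
    optim_args="rank=64, update_proj_gap=100, scale=0.10",  # assumed key=value format
)
```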
You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
Currently only `nn.Linear` layers are treated as GaLore layers; they are trained with low-rank decomposition, while the remaining layers are optimized in the conventional manner.
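If you are unsure which modules a pattern like `["attn", "mlp"]` will cover, a rough way to preview the matching `nn.Linear` layers is a simple substring check over `model.named_modules()`. Note this is only an approximation of the `check_target_module_exists` helper used in this PR, which also accepts regexes:

```python
# Illustration only: preview which linear layers ["attn", "mlp"] would match on gemma-2b,
# using a plain substring check as a stand-in for `check_target_module_exists`.
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_config(config)

targets = ["attn", "mlp"]
matched = [
    name
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and any(t in name for t in targets)
]
print(matched[:8])  # e.g. the attention and MLP projection layers of the first blocks
```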
Note that it will take a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.
You can also perform layer-wise optimization by appending `layerwise` to the optimizer name, as shown below:
```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw_layerwise",
    optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```
Note that layerwise optimization does not support DDP (Distributed Data Parallel), so you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details.
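For example, if the layerwise script above were saved as `train_galore.py` (a hypothetical filename), restricting the run to a single GPU could look like this:

```bash
# Expose a single GPU so the layerwise optimizer never runs under DDP
CUDA_VISIBLE_DEVICES=0 python train_galore.py
```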
## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
@@ -83,6 +83,8 @@
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LayerWiseDummyOptimizer,
    LayerWiseDummyScheduler,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
@@ -111,6 +113,7 @@
    RemoveColumnsCollator,
    TrainerMemoryTracker,
    TrainOutput,
    check_target_module_exists,
    default_compute_objective,
    denumpify_detensorize,
    enable_full_determinism,
@@ -140,6 +143,7 @@
    is_apex_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_in_notebook,
    is_ipex_available,
    is_peft_available,
@@ -1009,7 +1013,17 @@ def create_optimizer(self):
            },
        ]

-        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

        # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
        # e.g. for GaLoRe optimizer.
        if "params" in optimizer_kwargs:
            optimizer_grouped_parameters = optimizer_kwargs.pop("params")

        # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
        # to avoid arguments conflicts.
        if "optimizer_dict" in optimizer_kwargs:
            optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        if optimizer_cls.__name__ == "Adam8bit":
@@ -1032,7 +1046,9 @@ def create_optimizer(self):
        return self.optimizer

    @staticmethod
-    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.
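Since the static method now optionally takes the model, one can, as a sketch (not an officially documented usage pattern), call it directly to inspect what the `Trainer` would build for a GaLore run, reusing the `args` and `model` objects from the documentation example above:

```python
# Sketch: peek at the optimizer class and kwargs the Trainer would construct.
# Assumes `args` (with optim="galore_adamw" and optim_target_modules set) and
# `model` come from the documentation example, and that galore_torch is installed.
from transformers import Trainer

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args, model)
print(optimizer_cls.__name__)    # GaLoreAdamW
print(sorted(optimizer_kwargs))  # includes "params" holding the GaLore parameter groups
```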
@@ -1170,6 +1186,122 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
            optimizer_cls = torch.optim.Adagrad
        elif args.optim == OptimizerNames.RMSPROP:
            optimizer_cls = torch.optim.RMSprop
        elif args.optim in [
            OptimizerNames.GALORE_ADAMW,
            OptimizerNames.GALORE_ADAMW_8BIT,
            OptimizerNames.GALORE_ADAFACTOR,
            OptimizerNames.GALORE_ADAMW_LAYERWISE,
            OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
            OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
        ]:
            if not is_galore_torch_available():
                raise ImportError(
                    "You need to install `galore_torch` in order to use GaLore optimizers"
                    " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
                )
            from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
> **Review comment:** We need an import check here, no? 😄
> **Reply:** yeah indeed, many things can be optimized, for now it's a really rough draft, will focus on polishing everything next!
> **Reply:** Great, ping me when you're ready for a full review :)
            is_layerwise = args.optim.lower().endswith("layerwise")
            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
                raise NotImplementedError("GaLore does not support DDP at this time")
            optimizer_mapping = {
                OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
                OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
                OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
                OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
                OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
                OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
            }

            optimizer_cls = optimizer_mapping[args.optim]
            if args.optim_target_modules is None:
                raise ValueError(
                    "You need to define a `optim_target_modules` in order to properly use GaLoRe optimizers"
                )

            if not isinstance(args.optim_target_modules, (list, str)):
                raise ValueError(
                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
                )

            if model is None:
                raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")

            logger.warning(
                "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
            )
            all_linear = (
                isinstance(args.optim_target_modules, str)
                and args.optim_target_modules.replace("_", "-") == "all-linear"
            )
            galore_params = []
            galore_params_names = []

> **Review comment:** Would be more efficient as a ...
            for module_name, module in model.named_modules():
                if not isinstance(module, nn.Linear):
                    continue

> **Review comment:** I wonder if we should raise an error if the target module name matches but the layer type is not ...
> **Reply:** This sounds like a good idea!
> **Reply:** Yes! I think we should pass a warning to be able to pass simple regex such as ...
                if not check_target_module_exists(args.optim_target_modules, module_name) and not all_linear:
                    continue

                galore_params.append(module.weight)
                galore_params_names.append(module_name + ".weight")
            if len(galore_params) == 0:
                raise ValueError(
                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
                )

> **Review comment:** @younesbelkada But it is not allowed for pure layered_adamw optimizers, right?
> **Reply:** Hmmm in that case users should just pass ...
> **Reply:** ahh I think i got your point, we could indeed extend the optimizers and enable layer-wise optimizations ! This can be done in a scope of another follow up PR !
> **Reply:** cc @muellerzr @amyeroberts @BenjaminBossan @pacman100
> **Reply:** @younesbelkada Yeah, that sounds good for me. I guess I can comment out this check locally and wait for your pr.
> **Reply:** I guess it makes more sense to a) have generic API for layerwise optimization using the approach presented here b) have GaLoRE c) use generic API to instantiate layerwise GaLoRE; otherwise, after you add a generic API (which it seems like you already have a lot of code to do that, I don't see anything that's GaLoRE-specific), you would have to go back and do another PR just to refactor the layerwise GaLoRE?
> **Reply:** @kiddyboots216 I think it's fine to add Galore first. We're just merging into the dev branch atm, so there aren't guarantees about API stability until we have a release. Adding it in a more general sense is more involved and will require more tests / hitting possible blockers. In this order, we can release the feature without being blocked by the development of the more general API.
            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]

            galore_optim_kwargs = {
                "rank": optim_args.pop("rank", 128),
                "update_proj_gap": optim_args.pop("update_proj_gap", 200),
                "scale": optim_args.pop("scale", 0.25),
                "proj_type": optim_args.pop("proj_type", "std"),
            }
            # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
            param_groups = [
                {"params": non_galore_params},
                {"params": galore_params, **galore_optim_kwargs},
            ]
            if is_layerwise:

> **Review comment:** What about layerwise optimizers without galore? Can we have ...
> **Reply:** Hah, we proposed the same thing at almost the same time! I think if this layerwise thing works then we should have an option for all optimizers, right?
> **Reply:** great timing indeed !
                # For layer-wise optimizers, the optimization step is done through post accumulation
                # gradient hooks. The trick is to first attach these hooks to the model parameters then
                # create a dummy optimizer that will perform no-ops in the Trainer.
                # See the original implementation or the nice implementation from @hiyouga
                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
                if args.gradient_accumulation_steps != 1:
                    raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")

                optimizer_dict = {}
                for param in non_galore_params:
                    param_groups = [{"params": [param]}]
                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
                for param in galore_params:
                    param_groups = [{"params": [param], **galore_optim_kwargs}]
                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)

                def optimizer_hook(param):
                    if param.grad is not None:
                        optimizer_dict[param].step()
                        optimizer_dict[param].zero_grad()

                for param in model.parameters():
                    param.register_post_accumulate_grad_hook(optimizer_hook)

                optimizer_cls = LayerWiseDummyOptimizer
                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
> **Review comment:** All of this looks super hacky. At first glance, it looks okay, but I could imagine that some assumptions will fail in certain edge cases. Not sure what can really be done, except to try to cover as many use cases as possible (lr scheduler, checkpointing, gradient clipping, ...) in tests. Those tests would have to be a bit more involved than just testing that there is no error in training. AFAICT, the referenced repo does not perform these tests. I guess this would be a lot of extra work, so feel free to ignore my comment and then we just cross our fingers that it works :)
> **Reply:** Yeah I agree ! For the other solution I guess we have to change the training logic as you need to somehow guard ...
> **Reply:** I think what @BenjaminBossan was concerned about is right. We have had some discussions with the authors of GaLore, and the current utilization of hooks is just a temporary solution. Actually, GaLore does not rely on hooks to save memory. We can expect that the PyTorch team will provide more flexible APIs to support GaLore in DDP or DS environments in the future.
> **Reply:** I see thanks for explaining ! so I think we should just keep in mind this is temporary and we might adapt our approach according to how things will move in the future
            optimizer_kwargs.update({"params": param_groups})

            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
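The layerwise branch above leans on PyTorch's `Tensor.register_post_accumulate_grad_hook` (available since PyTorch 2.1). A minimal, self-contained sketch of the same trick outside the `Trainer`, with purely illustrative module and optimizer choices:

```python
# Standalone sketch of layer-wise optimization via post-accumulate-grad hooks:
# one optimizer per parameter, stepped as soon as that parameter's gradient is ready.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()  # each parameter is updated (and its gradient freed) during backward
```

This is why the PR hands the `Trainer` a `LayerWiseDummyOptimizer`: the real updates already happen inside `backward()`, so the optimizer the training loop sees only needs to perform no-ops.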
@@ -1182,6 +1314,30 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optim
        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
            optimizer_dict = optimizer.optimizer_dict
            scheduler_dict = {}

            for param in optimizer_dict.keys():
                scheduler_dict[param] = get_scheduler(
                    self.args.lr_scheduler_type,
                    optimizer=optimizer_dict[param],
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps) * 2,
                    num_training_steps=num_training_steps * 2,
                )

            def scheduler_hook(param):
                # Since the optimizer hook has been already attached we only need to
                # attach the scheduler hook
                if param.grad is not None:
                    scheduler_dict[param].step()

            for param in optimizer_dict.keys():
                param.register_post_accumulate_grad_hook(scheduler_hook)

            self._created_lr_scheduler = True
            self.lr_scheduler = LayerWiseDummyScheduler()

        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
> **Review comment:** GaLore has released an official package: `pip install galore-torch` (https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#install-galore-optimizer)