Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring Trainer, adds save_only_model arg and simplifying FSDP integration #27652

Merged
merged 13 commits into from
Nov 24, 2023

Conversation

pacman100
Copy link
Contributor

What does this PR do?

  1. Bumps up the minimum Accelerate version to 0.21.0
  2. Add save_only_model arg - This enables the feature request Add an option to decide whether to store the checkpoint and rng_state. #26706
  3. Simplifies a lot of logic in FSDP:
    a. Currently, FSDP-XLA logic is custom in Trainer and normal FSDP is using the Accelerate's integration. There were many zombie code snippets related to normal FSDP. Cleaned those.
    b. Made it easier to train with FSDP. When using FULL_STATE_DICT setting, it should now save the model in transformers format using the default safetensors sharded format. This reduces the burden on users to later load, shard and save in safetensors format.
    c. Should fix Fine tuning with examples/pytorch/language-modeling/run_clm.py on torch/XLA + FSDP produce abnormal models #27432 but don't have access to TPUs to test this.
    d. Fixes NotImplementedError: Cannot copy out of meta tensor; no data! #27166
    e. This is built upon the PR in Accelerate to simplify FSDP integration fsdp refactoring accelerate#2177. It should be merged first.

1. Refactor FSDP
2. Add `--save_only_model` option: When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
3. Bump up the minimum `accelerate` version to `0.21.0`
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Nov 22, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a million! Looks great to me, just one tiny doc nit for consistency

src/transformers/training_args.py Outdated Show resolved Hide resolved
pacman100 and others added 2 commits November 22, 2023 19:18
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice refactor and simplification of the code. I have a few comments, please take a look.

src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
@@ -2462,12 +2396,49 @@ def _save_checkpoint(self, model, trial, metrics=None):
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
def _save_optimizer_and_scheduler(self, output_dir):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it would be great if we could reuse accelerate code here, but it seems that there is no directly corresponding code. I wonder if we could achieve that with a bit of refactoring on the accelerate side @muellerzr (potentially in a future PR).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked into this the last time during migration to Accelerate backend as well and recently, I couldn't find a clean way of doing it on the Accelerate side. Maybe a discussion on this offline would help.

src/transformers/training_args.py Show resolved Hide resolved
pacman100 and others added 2 commits November 22, 2023 22:24
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the refactor - looks a lot cleaner!

Regarding the accelerate bump up we should make sure everything's compatible in the library with this change. @ydshieh I believe the CI images should have the version reflected in the setup.py and modifying that file triggers a run on everything - is that right? Is there anything else we should check before merging?

@@ -452,14 +459,18 @@ class TrainingArguments:
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- use_orig_params (`bool`, *optional*, defaults to `False`)
- use_orig_params (`bool`, *optional*, defaults to `True`)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this set to False by default before - is there an advantage to not having this enabled? Asking in case this introduces a degraded experience for some users

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, this is required now for simplifying the FSDP integration. Please find the explanation in the corresponding Accelerate PR: huggingface/accelerate#2177 (comment)

src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
src/transformers/trainer.py Show resolved Hide resolved
src/transformers/trainer.py Show resolved Hide resolved
Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry - meant to hit approve. Happy to marge after @BenjaminBossan's approval and @ydshieh confirms checking the library versions.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates. LGTM

@ydshieh
Copy link
Collaborator

ydshieh commented Nov 23, 2023

Hi,

The (Circle)CI did run all the tests: for example, in torch_job, we can see

python -m pytest --junitxml=test-results/junit.xml -n 6 --max-worker-restart=0 --dist=loadfile --make-reports=tests_torch tests/benchmark tests/bettertransformer tests/deepspeed tests/extended tests/fixtures tests/fsdp tests/generation tests/models tests/optimization tests/peft_integration tests/quantization tests/sagemaker tests/test_backbone_common.py tests/test_configuration_common.py tests/test_configuration_utils.py tests/test_feature_extraction_common.py tests/test_feature_extraction_utils.py tests/test_image_processing_common.py tests/test_image_processing_utils.py tests/test_image_transforms.py tests/test_modeling_common.py tests/test_modeling_flax_common.py tests/test_modeling_flax_utils.py tests/test_modeling_tf_common.py tests/test_modeling_tf_utils.py tests/test_modeling_utils.py tests/test_pipeline_mixin.py tests/test_sequence_feature_extraction_common.py tests/test_tokenization_common.py tests/test_tokenization_utils.py tests/tokenization tests/tools tests/trainer tests/utils || true

The CI uses the latest main branch of accelerate (so newer than the latest released version)

"pip install -U --upgrade-strategy eager -e git+https://github.com/huggingface/accelerate@main#egg=accelerate"

Nothing extra to check.

@pacman100 pacman100 merged commit a761d6e into main Nov 24, 2023
22 checks passed
@pacman100 pacman100 deleted the smangrul/FSDP-refactor branch November 24, 2023 06:10
@welsh01
Copy link

welsh01 commented Nov 28, 2023

save_only_model is a nice feature indeed, but it does not work together with load_best_model_at_end (at least with deepspeed enabled), since the final model cannot be loaded from the checkpoint.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
7 participants