Fix DeepSpeed ZeRO-3 in PPOTrainer #730

Merged: lewtun merged 18 commits into main from fix-ppo-ds3 on Sep 5, 2023
Conversation

lewtun (Member) commented Sep 3, 2023

This PR adds ZeRO-3 support for the PPOTrainer by ensuring that the active and reference model weights are sharded in the same manner. I ran a few sentiment tuning tests with GPT-2 and find that the general trend of the mean reward is similar both with / without ZeRO-3 and the KL divergence is 0 at step 0 (as it should be):

[Screenshots: mean reward and KL divergence curves, with / without ZeRO-3]

I've also tested that this works with larger models like llama-2-7b, and it does (modulo a very small diff in the KL divergence at step 0, which is likely tied to needing bfloat16).

There are probably a few more optimisations one can do with the DeepSpeed config, but this seems like a good start for now.

I've also added some accelerate configs so it's a bit easier for people to run the examples.

Closes #600

Script commands for testing

# Baseline - no DeepSpeed
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/sentiment_tuning.py --batch_size 32 --mini_batch_size 32 --log_with wandb

# ZeRO-{1,2,3}
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml examples/scripts/sentiment_tuning.py --batch_size 32 --mini_batch_size 32 --log_with wandb
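The "KL is 0 at step 0" check mentioned above can be made concrete: PPO-style trainers estimate the KL from the difference of log-probs under the active and reference policies, so two identical models must give exactly zero. A minimal stdlib sketch of that check (`approx_kl` is an illustrative helper, not the `PPOTrainer` API):

```python
import math

def approx_kl(logprobs_active, logprobs_ref):
    """Mean per-token KL estimate used in PPO-style training:
    E[log p_active(x) - log p_ref(x)] over the sampled tokens."""
    diffs = [a - r for a, r in zip(logprobs_active, logprobs_ref)]
    return sum(diffs) / len(diffs)

# At step 0 the active model is a copy of the reference model, so the
# log-probs of the sampled tokens match exactly and the KL must be 0.
logprobs = [math.log(0.2), math.log(0.5), math.log(0.1)]
assert approx_kl(logprobs, list(logprobs)) == 0.0

# After an update the log-probs drift and the KL estimate becomes non-zero.
shifted = [lp + 0.05 for lp in logprobs]
assert approx_kl(shifted, logprobs) > 0.0
```

A non-zero value at step 0 is therefore a red flag that the two models are not actually identical - which is exactly what the sharding bug fixed here produced.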

TODO

  • Make sure we get parity without DeepSpeed on sentiment tuning
  • Validate with large models
  • Test it works with offloading
  • Decide whether setting model.train() should be unique to ZeRO-3 or not in train loop
  • Add accelerate configs

HuggingFaceDocBuilderDev commented Sep 3, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment thread trl/trainer/ppo_trainer.py Outdated
all_masks = []
all_values = []

model.eval()
lewtun (Member, Author):
Here I'm following the same logic as transformers.Trainer and putting the model in eval mode during inference - this is needed to ensure the KL divergence is 0 at step 0 with ZeRO-3.
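To illustrate why eval mode matters here: layers like dropout make two weight-identical models disagree while in train mode, which would show up as a spurious non-zero KL at step 0. A stdlib toy sketch (`ToyModel` is hypothetical, not the TRL or transformers API):

```python
import random

class ToyModel:
    """Toy stand-in for a model with dropout (for illustration only)."""

    def __init__(self, weights):
        self.weights = list(weights)
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def forward(self, x):
        # Dropout-style noise is only active in train mode, so two models
        # with identical weights can still disagree while training=True.
        out = []
        for w in self.weights:
            v = w * x
            if self.training and random.random() < 0.1:
                v = 0.0
            out.append(v)
        return out

active = ToyModel([0.5, 1.5, -2.0])
ref = ToyModel([0.5, 1.5, -2.0])

# In eval mode both models are deterministic, so their outputs (and
# hence log-probs) match exactly and the step-0 KL divergence is 0.
active.eval()
ref.eval()
assert active.forward(3.0) == ref.forward(3.0) == [1.5, 4.5, -6.0]
```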

Comment thread trl/trainer/ppo_trainer.py Outdated
text.append(" ")
print(text)

def _prepare_deepspeed_zero3(self, model):
lewtun (Member, Author):
Should we move this to a separate utils function that can also be used for e.g. sharding the reward model? In that case, the function signature would be something like _prepare_deepspeed_zero3(model, accelerator).

Comment thread trl/trainer/ppo_trainer.py Outdated

def _prepare_deepspeed_zero3(self, model):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
# TODO: figure out if any other parameters are needed for inference
lewtun (Member, Author):
The kwargs below are a best guess at what's minimally needed - we can tune them later if necessary, IMO.
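For reference, a rough sketch of the kind of inference-only ZeRO-3 config such a helper builds (the exact kwargs live in the PR diff; the hidden-size-based bucket sizes follow the accelerate code linked above, and all values here are illustrative assumptions):

```python
def build_zero3_inference_config(hidden_size: int) -> dict:
    """Sketch of a minimal ZeRO-3 config for a frozen reference model.
    Bucket sizes scale with the model's hidden size, as in the accelerate
    snippet this PR adapts; the exact multipliers are illustrative."""
    return {
        "train_micro_batch_size_per_gpu": 1,  # required key, unused for inference
        "zero_optimization": {
            "stage": 3,
            # Keep small params resident to avoid a gather per forward pass.
            "stage3_param_persistence_threshold": 10 * hidden_size,
            # Prefetch / reduce bucket sizes scale with hidden_size^2.
            "stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
            "reduce_bucket_size": hidden_size * hidden_size,
        },
    }

cfg = build_zero3_inference_config(768)
assert cfg["zero_optimization"]["stage"] == 3
assert cfg["zero_optimization"]["reduce_bucket_size"] == 768 * 768
```

A config like this would then be handed to deepspeed.initialize(model=model, config=cfg) so the reference model is sharded the same way as the active model.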

@lewtun lewtun changed the title [WIP] Fix DeepSpeed ZeRO-3 in PPOTrainer Fix DeepSpeed ZeRO-3 in PPOTrainer Sep 4, 2023
@lewtun lewtun marked this pull request as ready for review September 4, 2023 09:16

# this hack seems to be needed for DS stage 3 to work
if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
self.model.train()
lewtun (Member, Author):
This has been moved to the training loop where I think it should be done for all models (including DeepSpeed ones)

train_stats (dict[str, `torch.Tensor`]):
Dictionary of training statistics
"""
self.model.train()
lewtun (Member, Author):
Ditto as above

lewtun (Member, Author) commented Sep 4, 2023:
I've added these configs to make it easier for users to run DeepSpeed in various settings (and also for dev testing)

Comment thread examples/scripts/README.md Outdated
Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
lvwerra (Member) left a comment:
Looks good to me! It would be nice to add some info to the docs as well and link to the configs. E.g. here would be a good place:

https://huggingface.co/docs/trl/customization

younesbelkada (Contributor) left a comment:
Thanks a lot @lewtun! I see you have already added instructions on how to use the yaml files directly with accelerate. I agree with @lvwerra that we can also add this to an appropriate section of the documentation - no big deal though, we can also do it in a follow-up PR (to unblock DeepSpeed ZeRO-3 for users).

- `sentiment_tuning.py`: This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using the IMDB dataset
- `multi_adapter_rl.py`: This script shows how to use the `PPOTrainer` to train a single base model with multiple adapters. This script requires you to run the example script with the reward model training beforehand.
- `stable_diffusion_tuning_example.py`: This script shows how to use the `DDPOTrainer` to fine-tune a stable diffusion model using reinforcement learning.

Contributor:

Very nice!!

lewtun (Member, Author) commented Sep 5, 2023

Good idea about adding a section to the docs, done in 8a48d12

I'll merge once the CI is green

@lewtun lewtun merged commit c04074e into main Sep 5, 2023
@lewtun lewtun deleted the fix-ppo-ds3 branch September 5, 2023 09:00
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
* Initialise ref model with ZeRO-3

* Fix deadlock

* Refactor & fix KL div

* Refactor

* Refactor

* Fix imports

* Add types

* Add accelerate configs

* Add more DeepSpeed configs

* Fix types

* Disable debug

* Refactor

* Add docs

* Disable eval mode for peft

* Restore eval mode

* Revert ref model prep for peft

* Update examples/scripts/README.md

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

* Add docs

---------

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
andrew-zm-ml commented Sep 27, 2023

@lewtun With these changes I can now get past the old `RuntimeError: 'weight' must be 2-D` issue, but training fails shortly thereafter with the following assertion error:

AssertionError: {'id': 163, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, 'shape': (0,), 'ds_shape': (0,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {207}, 'ds_tensor.shape': torch.Size([0])}

This stems from this statement

assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()

in deepspeed/runtime/zero/partitioned_param_coordinator.py.
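For context, this assertion is DeepSpeed's check that a ZeRO-3-partitioned parameter has been all-gathered (made AVAILABLE) before a module's forward touches it; a NOT_AVAILABLE parameter only holds an empty local shard, which matches the shape (0,) in the traceback above. A rough stdlib sketch of that invariant (the class names mirror DeepSpeed's, but this implementation is purely illustrative):

```python
from enum import Enum

class ZeroParamStatus(Enum):
    NOT_AVAILABLE = 0   # partitioned across ranks; local shard is empty
    INFLIGHT = 1        # all-gather in progress
    AVAILABLE = 2       # full tensor materialised on this rank

class ShardedParam:
    """Illustrative stand-in for a ZeRO-3 partitioned parameter."""

    def __init__(self, numel):
        self.full_numel = numel
        self.ds_status = ZeroParamStatus.NOT_AVAILABLE
        self.local_shape = (0,)  # the shape (0,) seen in the traceback

    def all_gather(self):
        # Reassemble the full tensor on this rank before the forward pass.
        self.ds_status = ZeroParamStatus.AVAILABLE
        self.local_shape = (self.full_numel,)

p = ShardedParam(1024)
# Running a forward pass now would trip exactly the failing assertion:
#   assert p.ds_status == ZeroParamStatus.AVAILABLE, ...
p.all_gather()
assert p.ds_status == ZeroParamStatus.AVAILABLE and p.local_shape == (1024,)
```

So the error suggests the reference model's parameters were never gathered for its forward pass, i.e. it was not prepared for ZeRO-3 the way the active model was.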

The first forward pass through the regular model succeeds, but this error occurs when we try to run a forward pass through the reference model in ppo_trainer::step:

        with torch.no_grad():
            all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
                self.model,
                queries,
                responses,
                model_inputs,
                response_masks=response_masks,
                return_logits=full_kl_penalty,
            )
            # for when the model is a peft model
            if self.is_peft_model and hasattr(
                self.accelerator.unwrap_model(self.model).pretrained_model,
                "disable_adapter",
            ):
                with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
                    ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                        self.model, queries, responses, model_inputs, return_logits=full_kl_penalty
                    )
            elif self.is_peft_model and not hasattr(self.model.pretrained_model, "disable_adapter"):
                raise ValueError(
                    "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
                )

            else:
                ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( # <--------- * HERE *
                    self.ref_model, queries, responses, model_inputs, return_logits=full_kl_penalty
                )

I initially thought this was a separate problem (there's a discussion about a very similar error in deepspeedai/DeepSpeed#4229), but the suggested fix does not work in this situation, and the fact that this occurs when trying to use the reference model - and that the tensor in this error has shape (0,) - makes me wonder if it's actually related to the original `'weight' must be 2-D` issue.

Could you share the LLaMA run you mention in the description of this PR? Thank you!

For reference, I'm using

deepspeed==0.10.3
transformers==4.31.0
accelerate==0.23.0

lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025

Development

Successfully merging this pull request may close these issues.

DeepSpeed ZeRO-3 throws RuntimeError: 'weight' must be 2-D for sentiment_tuning.py

6 participants