
Reduce memory consumption in batched_forward_pass #234

Merged (2 commits into huggingface:main) on Mar 22, 2023

Conversation

ohashi56225 (Contributor)

This PR reduces memory consumption in the batched_forward_pass method of PPOTrainer by avoiding storage of logits when they are not necessary.

Before this PR, batched_forward_pass always stored all of the model's logits, just like the other tensors such as values and logprobs. However, the logits tensor is much larger (batch_size × tokens × vocabulary_size) than the logprobs and values tensors (batch_size × tokens), so it consumes a significant amount of CUDA memory.
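To put that size difference in perspective, here is a back-of-the-envelope calculation. The batch size and sequence length below are hypothetical, chosen only for illustration; 50257 is GPT-2's vocabulary size:

```python
# Rough memory comparison between storing logits vs. logprobs (fp32).
# batch_size and seq_len are illustrative values, not TRL defaults.
batch_size, seq_len, vocab_size = 128, 512, 50257
bytes_per_float = 4  # fp32

logits_bytes = batch_size * seq_len * vocab_size * bytes_per_float
logprobs_bytes = batch_size * seq_len * bytes_per_float

print(f"logits:   {logits_bytes / 2**30:.1f} GiB")   # → 12.3 GiB
print(f"logprobs: {logprobs_bytes / 2**20:.2f} MiB")  # → 0.25 MiB
print(f"ratio:    {logits_bytes // logprobs_bytes}x")  # exactly vocab_size
```

The ratio is exactly the vocabulary size, which is why dropping the stored logits dominates the savings.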

I have modified batched_forward_pass to avoid unnecessary storage of logits; they are only required when calculating the entropy in the loss method.
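The underlying idea can be sketched as follows. This is an illustrative simplification, not TRL's actual implementation; the `gather_logprobs` helper and the tensor shapes are assumptions for the example:

```python
import torch
import torch.nn.functional as F

def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Keep only the log-probability of each sampled token.

    logits: (batch, seq, vocab); labels: (batch, seq).
    Returns a (batch, seq) tensor, vocab_size times smaller than the logits,
    so the full logits tensor can be freed immediately after this call.
    """
    logp = F.log_softmax(logits, dim=-1)
    return torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1)

# Illustrative shapes (hypothetical, not TRL's actual batch layout)
logits = torch.randn(2, 5, 100)
labels = torch.randint(0, 100, (2, 5))
logprobs = gather_logprobs(logits, labels)
print(logprobs.shape)  # torch.Size([2, 5])
```

Only the gathered logprobs need to survive the forward pass; the logits are recomputed (or passed through) only where the entropy term actually needs them.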

@ohashi56225 ohashi56225 changed the title Reduce memory consumption by avoiding logits storage in forward_pass Reduce memory consumption in batched_forward_pass Mar 21, 2023
HuggingFaceDocBuilderDev commented Mar 21, 2023

The documentation is not available anymore as the PR was closed or merged.

younesbelkada (Collaborator) left a comment

Thanks a lot for fixing this and for taking care of the memory consumption.
This looks very good to me!
Would love to hear @lvwerra's thoughts here.

trl/trainer/ppo_trainer.py (resolved)
lvwerra (Member) left a comment

Looks great, thanks!

@lvwerra lvwerra merged commit a6ebdb6 into huggingface:main Mar 22, 2023
4 participants