Allow for dynamic batch padding #2352
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for adding this! Left a few comments. From ke's comments, it looks like the batch issue is linked to the tracing.
src/accelerate/inference.py
Outdated
process_index=0,
num_processes=state.num_processes,
)
extra = concatenate([extra] * ((found_batch_size % state.num_processes) + 1))
See the related comment in Slack.
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
We are practically there. Left a few comments.
old_size = tensor.shape
new_size = list(old_size)
new_size[0] = batch_size + to_pad
new_tensor = tensor.new_zeros(tuple(new_size))
It's okay to just pad with 0s; we can drop them afterwards, and ideally the user won't know that padded inputs were even sent.
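A rough sketch of that zero-padding-then-drop idea (illustrative only, assuming plain PyTorch tensors; pad_to_batch_size and the example shapes are made up for this comment, not the helper added in this PR):

import torch

def pad_to_batch_size(tensor: torch.Tensor, target_batch_size: int) -> torch.Tensor:
    # Pad `tensor` with zero rows along dim 0 until it reaches `target_batch_size`.
    batch_size = tensor.shape[0]
    to_pad = target_batch_size - batch_size
    if to_pad <= 0:
        return tensor
    new_size = list(tensor.shape)
    new_size[0] = target_batch_size
    padded = tensor.new_zeros(tuple(new_size))  # zeros keep the original dtype/device
    padded[:batch_size] = tensor                # real samples first, zero rows after
    return padded

# After the forward pass, slice the padding back off so the user only ever
# sees outputs for the inputs they actually sent.
x = torch.randn(3, 8)                 # 3 real samples, but suppose 4 are needed
padded = pad_to_batch_size(x, 4)      # shape (4, 8), last row is zeros
out = (padded * 2)[:3]                # stand-in for a forward pass, padding dropped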
* Broken version
* Timing I would expect
* Working version!
* Use MethodType
* working test
* Tests
* Use no split module classes explicitly
* Put split_points in pipelien
* Store split points in hf_split_points
* fix case num_process=1
* Allow for dynamic batch padding (#2352)
  * Allow for dynamic batch paddign
  * Fix test
  * Update src/accelerate/inference.py
    Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
  * Break early after the first valid bs is found
  * Less slicy-dicy
  * Test cv model
  * Start, need to test
  * Use dataloader-like logic
  * Refactor to utils
  * With tests
  * Update the source
  * Clean
  * bs=1 case
  * Add test
  * add some failing test
  * Almost working version
  * Much cleaner implementation
  * Use pad_input_tensor
  * All tests passing!
  * Do it at tracing too
  ---------
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
  Co-authored-by: Marc Sun <marc@huggingface.co>
* Rm literal
* Allow users to pass in max_memory
* Note about recursion
* Document, document, document
* Right import check
* Fix bug, add tests to multigpu runners
* Change default to None
* Start of docs
* Try again?
* Try again x2
* Trailing comma
* Move import
* Clean
* typehint
* typo
* From code review
* Use num_chunks
* Update tests/test_utils.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Bad copy/paste
* hf_split_points
---------
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
What does this PR do?
This PR allows for dynamic batch padding by finding and duplicating the last item in the batch before concatenating everything. A rough sketch of the idea is shown below.
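Illustrative only; the PR itself goes through Accelerate's utilities (e.g. the concatenate call visible in the diff above), and the function and variable names below are made up:

import torch

def pad_batch_by_duplicating_last(batch: torch.Tensor, num_processes: int) -> torch.Tensor:
    # Repeat the last sample until the batch size divides evenly across processes.
    remainder = batch.shape[0] % num_processes
    if remainder == 0:
        return batch
    to_pad = num_processes - remainder
    extra = batch[-1:].repeat(to_pad, *([1] * (batch.dim() - 1)))  # copies of the last item
    return torch.cat([batch, extra], dim=0)

batch = torch.randn(5, 16)                        # 5 samples to split across 2 processes
padded = pad_batch_by_duplicating_last(batch, 2)  # shape (6, 16); row 5 duplicates row 4

The duplicated items can then be dropped from the outputs after gathering, so callers only see results for the samples they actually passed in.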
Current issue:
PiPPy has issues with batch inference on GPT-2; will talk to the PiPPy folks about this.
Fixes # (issue)
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@SunMarc