
Fix batch_size sanity check logic for split_batches #2344

Merged: 5 commits into huggingface:main on Jan 27, 2024

Conversation

@izhx (Contributor) commented Jan 16, 2024

What does this PR do?

Fixes #2285 and the comment on PR #2310: "Unfortunately, this broke Megatron DL."

Some libraries may not follow PyTorch's design of BatchSampler.
If the batch_size cannot be found, a warning is emitted to notify the user.

Does anyone have suggestions for such situations?
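The behavior described above can be sketched roughly as follows. This is an illustrative stand-in, not the merged accelerate code; the function names `get_batch_size` and `check_split_batches` are hypothetical:

```python
import warnings


def get_batch_size(batch_sampler):
    # Custom samplers (e.g. Megatron's) may not expose `batch_size`
    # the way torch.utils.data.BatchSampler does, so fall back to None.
    batch_size = getattr(batch_sampler, "batch_size", None)
    if batch_size is None:
        warnings.warn(
            f"{type(batch_sampler).__name__} does not define `batch_size`; "
            "skipping the `split_batches` sanity check."
        )
    return batch_size


def check_split_batches(batch_sampler, num_processes, split_batches):
    # Only enforce divisibility when split_batches=True, which is the
    # gist of the fix this PR describes.
    if not split_batches:
        return
    batch_size = get_batch_size(batch_sampler)
    if batch_size is not None and batch_size % num_processes != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be divisible by the number of "
            f"processes ({num_processes}) when split_batches=True."
        )
```

With this shape, a sampler that lacks `batch_size` only triggers a warning instead of crashing the check.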

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@muellerzr @BenjaminBossan

@muellerzr (Collaborator) left a comment


I think this solution makes sense, but I know this originates from Megatron so cc @pacman100 for guidance.

@pacman100 (Contributor) left a comment


Thank you @izhx for the fixes in cases when the Batch Sampler doesn't conform to PyTorch. Makes sense.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Collaborator) left a comment


If Sourab is happy, I'm happy. But would definitely be happier if @stas00 could also give this an all-clear? :)

@stas00 (Contributor) commented Jan 19, 2024

  1. OK, so the key fix of this PR is adding `if split_batches:`; in my trainer I have never set it to True, so this solves my problem and doesn't require the Megatron sampler to comply with certain requirements. Thank you!

  2. Now let's zoom out to the general case where `split_batches==True` has been set by the user. While this PR makes dealing with the issue somewhat more user-friendly, this approach sweeps the problem under the rug. If anything it should be an assert and not a warning, i.e. if the user wants split_batches, this code path must not proceed without a way to get to `sampler.batch_size`. In other words, the split_batches functionality must require a sampler class that provides `self.batch_size`.

  3. While I was investigating this, I think you have this problem here as well:

```python
def total_batch_size(self):
    batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
    return (
        batch_sampler.batch_size
        if getattr(batch_sampler, "split_batches", False)
        else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
    )
```
It always accesses `batch_sampler.batch_size`, so that code probably should also check `hasattr` and assert if the attribute is missing. It's probably not in the scope of this PR though, so I'm just flagging it to the maintainers.

  4. In general, avoid using warnings as much as possible. Nobody sees them in a sea of thousands of log lines. Either something is a problem and you assert, or it's not really a problem and you do nothing. With some rare exceptions there is no in-between state.
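The assert-not-warn variant suggested in points 2 and 4 could look like the following hypothetical sketch. The free function `total_batch_size` here is a stand-in for the property quoted in point 3, not the merged accelerate code:

```python
def total_batch_size(batch_sampler):
    # Stricter variant: require `batch_size` up front instead of silently
    # assuming the attribute exists on the sampler.
    if not hasattr(batch_sampler, "batch_size"):
        raise AttributeError(
            f"{type(batch_sampler).__name__} must expose a `batch_size` "
            "attribute for the total batch size to be computed."
        )
    if getattr(batch_sampler, "split_batches", False):
        return batch_sampler.batch_size
    return batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1)
```

A non-conforming sampler then fails loudly at prepare time rather than producing a warning that may be lost in the logs.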

@stas00 (Contributor) left a comment


let's land this one, @muellerzr?

@muellerzr (Collaborator) left a comment


Thanks for iterating! (Good to go once tests pass)

@muellerzr muellerzr merged commit 7aafa25 into huggingface:main Jan 27, 2024
23 checks passed
Successfully merging this pull request may close these issues.

accelerator.prepare(dataloader) sanity check fails when batch_sampler is given and split_batches is True
5 participants