Remove more invalid / uneven shardings #23
Conversation
Before, we were only removing shardings which were invalid for the inputs of the ops. Now we also remove those which are invalid for the outputs. With that, we can drop the solver constraint that removed invalid views, as those don't appear anymore.
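To make the idea concrete, here is a minimal sketch of this kind of validity filter; the helper names and the strategy structure are hypothetical and simplified, not the actual autoparallel API. A sharding is treated as invalid for a tensor when a sharded dimension does not divide evenly across the corresponding mesh dimension, and a strategy is dropped if any of its input or output specs is invalid.

```python
# Hedged sketch only -- names and structures are illustrative, not autoparallel's API.

def is_even_sharding(tensor_shape, mesh_shape, placements):
    """True if every ("shard", tensor_dim) placement splits its dim evenly."""
    for mesh_dim, placement in enumerate(placements):
        if placement[0] == "shard":
            tensor_dim = placement[1]
            if tensor_shape[tensor_dim] % mesh_shape[mesh_dim] != 0:
                return False  # uneven split across this mesh dim -> invalid
    return True

def keep_strategy(strategy, input_shapes, output_shapes, mesh_shape):
    """Keep a strategy only if all of its input *and* output specs shard evenly."""
    specs = list(strategy["input_specs"]) + list(strategy["output_specs"])
    shapes = list(input_shapes) + list(output_shapes)
    return all(
        is_even_sharding(shape, mesh_shape, placements)
        for shape, placements in zip(shapes, specs)
    )
```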
Could you clarify the different types of invalid sharding that we were facing, and which ones should be fixed at the DTensor level? I think because DTensor supports uneven sharding, it is less clear what counts as invalid. I think for auto parallel we want things to be symmetric across ranks, so that's an additional constraint on our end?
Hi Will, I've commented in #22 just now with some of my thoughts. I think that ultimately we will want to support uneven sharding, but for development it is preferable to keep the setup simpler (as we can then inspect a single GPU to verify what has been output). The type of thing we want fixed is maybe to enforce
```python
output_specs = strategy.output_specs
if isinstance(output_specs, DTensorSpec):
    output_specs = [output_specs]
specs = list(strategy.input_specs) + list(output_specs)
```
nit: list(output_specs) seems redundant to the line above?
strategy.output_specs can also be a tuple of DTensorSpec, so I'm just trying to make sure we are not concatenating lists and tuples together.
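For illustration, this is the failure mode being avoided, written with plain placeholder values rather than real DTensorSpec objects:

```python
# Plain-Python illustration (placeholder strings instead of real DTensorSpec objects).
input_specs = ["spec_a", "spec_b"]   # list of specs
output_specs = ("spec_c",)           # may be a single spec or a tuple of specs

# input_specs + output_specs would raise:
#   TypeError: can only concatenate list (not "tuple") to list
# so both sides are normalized to lists before concatenating.
specs = list(input_specs) + list(output_specs)
print(specs)  # ['spec_a', 'spec_b', 'spec_c']
```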
```python
if len(orig_shape) > len(shape):
    # TODO: FIXME as I think we should also handle this case
    continue
# print("in heeeererereer", orig_shape, shape)
```
lol
There was also a slight issue with the way we were banning invalid views in the solver, and this should fix it as well.
Additionally, uneven shardings are now also removed for the parameters / buffers of the model.
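As a rough, hypothetical sketch of what removing uneven shardings for parameters / buffers could mean in practice (not the actual implementation), only dimensions whose size divides the mesh dimension evenly would remain candidates for sharding:

```python
# Hypothetical helper, not autoparallel's code: which dims of a parameter/buffer
# can be sharded across a mesh dimension without leaving a ragged (uneven) shard.
def even_shard_dims(param_shape, mesh_dim_size):
    return [d for d, size in enumerate(param_shape) if size % mesh_dim_size == 0]

# Example: a (1024, 13) weight on a mesh dimension of size 8 can only be sharded
# evenly along dim 0; sharding dim 1 would give the ranks unequal numbers of columns.
print(even_shard_dims((1024, 13), 8))  # -> [0]
```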