
Conversation


@fmassa fmassa commented Jul 1, 2025

Before, we were only removing shardings that were invalid for the inputs of the ops. Now we also remove those that are invalid for the outputs. With that, we can drop the solver constraint that banned invalid views, as those no longer appear.

There was also a slight issue with the way we were banning invalid views in the solver; this PR fixes that as well.

Additionally, this also removes uneven shardings for the parameters / buffers of the model.

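As a rough illustration of that last point (a hypothetical sketch, not the actual PR code): a parameter or buffer sharding counts as "even" only if every sharded dimension is divisible by the size of the corresponding mesh dimension.

```python
# Minimal sketch with a made-up helper name, not the PR implementation:
# a placement list shards a tensor evenly only if every Shard(dim) splits a
# dimension whose size is divisible by the corresponding mesh dimension size.
from torch.distributed.tensor import Shard


def is_evenly_sharded(shape, placements, mesh):
    # placements[i] describes how the tensor is placed along mesh dimension i.
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard) and shape[placement.dim] % mesh.size(mesh_dim) != 0:
            return False
    return True
```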
@fmassa fmassa requested review from bdhirsh and wconstab July 1, 2025 13:05
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 1, 2025

wconstab commented Jul 1, 2025

Could you clarify the different types of invalid sharding that we were facing, and which ones should be fixed at the DTensor level?

I think because DTensor supports uneven sharding, it is less clear what counts as invalid. For auto parallel we want things to be symmetric across ranks, so that's an additional constraint on our end?


fmassa commented Jul 1, 2025

Hi Will,

I've commented in #22 just now about some of my thoughts.

I think that ultimately we will want to support uneven sharding, but for development it is preferable to keep the setup simpler (as we then only need to inspect a single GPU to verify the output).

The kind of thing we want to fix is perhaps to enforce is_tensor_shardable for all ops, so that we can be sure we can rely on the outputs of the sharding propagator.
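As a rough sketch of what that could look like (illustrative only: the import path of is_tensor_shardable is an assumption, since the helper has lived under different private DTensor modules across PyTorch versions, and keep_strategy is a hypothetical function):

```python
# Rough sketch, not the actual autoparallel code: drop strategies whose input *or*
# output DTensorSpecs fail is_tensor_shardable. The import path below is assumed.
from torch.distributed.tensor._ops.utils import is_tensor_shardable


def keep_strategy(strategy):
    output_specs = strategy.output_specs
    if not isinstance(output_specs, (list, tuple)):
        output_specs = [output_specs]
    # input_specs may be None for some strategies, so guard before concatenating.
    specs = list(strategy.input_specs or []) + list(output_specs)
    # Each spec's tensor_meta records the global shape it describes.
    return all(is_tensor_shardable(spec.tensor_meta.shape, spec) for spec in specs)
```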

output_specs = strategy.output_specs
if isinstance(output_specs, DTensorSpec):
    output_specs = [output_specs]
specs = list(strategy.input_specs) + list(output_specs)
Contributor

nit: list(output_specs) seems redundant to the line above?

Contributor Author

strategy.output_specs can also be a tuple of DTensorSpec, so I'm just trying to make sure we are not concatenating lists and tuples together.

if len(orig_shape) > len(shape):
    # TODO: FIXME as I think we should also handle this case
    continue
# print("in heeeererereer", orig_shape, shape)
Contributor

lol

@wconstab wconstab merged commit 5726d7c into main Jul 2, 2025
4 checks passed
@wconstab wconstab deleted the fmassa/remove_further_constraints branch July 2, 2025 21:38