autoparallel/utils.py (19 additions, 0 deletions)

@@ -119,6 +119,25 @@ def get_placement_options(mesh, op, specs, user_args):

    propagate_tensor_meta(op, user_args, out_strat)
    fill_missing_redistribute_cost(op, specs, out_strat)

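    # Keep only strategies whose Shard placements split every sharded tensor
    # dimension evenly across the corresponding mesh dimension.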
    kept = []
    for strategy in out_strat.strategies:
        is_valid = True
        for input_spec in strategy.input_specs:
            shape = list(input_spec.tensor_meta.shape)
            for mesh_shape, plc in zip(mesh.shape, input_spec.placements):
                if plc.is_shard():
                    dim = plc.dim
                    if shape[dim] % mesh_shape == 0:

Contributor:

Technically, we can support uneven sharding, can't we? But for now I am OK with this. Also, @XilunWu / @zpcore are going to put this logic into DTensor, so we can remove it from here once that's ready.

Reviewer:

Right now in DTensor sharding prop, we only check if tensor_dim_size >= mesh_size in some rules, as in https://github.com/pytorch/pytorch/blob/0decd966af9cdcb7ab4410cf475d2fc09f2dea0c/torch/distributed/tensor/_ops/_matrix_ops.py#L78-L80.

Francisco's case is a bit different: what he's doing is filtering out uneven sharding. I would like you guys to clarify this a bit.

I think there are 3 concepts:

  1. empty local tensor. This comes from some special cases of uneven sharding.
  2. uneven sharding. This doesn't necessarily lead to an empty local tensor.
  3. the current is_tensor_shardable, which means tensor_dim_size >= mesh_size.
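
To make the three cases concrete, here is a minimal standalone sketch (plain Python, not DTensor code; it assumes chunk-style sharding where each rank holds at most ceil(dim_size / mesh_size) rows, and the sizes are purely illustrative):

```python
# Contrast the three concepts for a hypothetical 1-D mesh of size 4.
mesh_size = 4

for dim_size in (8, 7, 6, 2):
    # (3) the current is_tensor_shardable-style check: dim >= mesh size
    shardable = dim_size >= mesh_size
    # (2) uneven sharding: the dim does not divide evenly across the mesh
    uneven = dim_size % mesh_size != 0
    # (1) empty local tensor: with chunk-style sharding each rank gets
    #     ceil(dim_size / mesh_size) rows, so trailing ranks may get none
    chunk = -(-dim_size // mesh_size)  # ceil division
    some_rank_empty = (mesh_size - 1) * chunk >= dim_size
    print(f"dim={dim_size}: shardable={shardable}, "
          f"uneven={uneven}, some_rank_empty={some_rank_empty}")
```

dim=7 shows that uneven sharding (2) does not imply an empty local tensor (1), while dim=6 produces both, and dim=2 additionally fails check (3).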

And there are some TODOs:

  1. Unify (2) and (3). The consequence is that we won't be able to support uneven sharding. Wanchao has a TODO on unifying (2) and (3) (see https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/utils.py#L147-L148), but I think this requires us to enable "always-on padding" (sorry, I forgot the exact term we've been using, but the meaning is to always pad DTensor's local tensor rather than only pad/unpad around collectives; see the small padding sketch after this list); otherwise this unification rejects all uneven DTensors in sharding prop.

  2. Enforce the is_tensor_shardable check in every op's sharding prop.
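
For intuition on the "always-on padding" idea, a tiny numeric sketch (not DTensor's API, just the arithmetic; the helper name is made up):

```python
import math

def padded_local_sizes(dim_size: int, mesh_size: int):
    """Hypothetical helper: pad the sharded dim up to a multiple of the
    mesh size so every rank holds the same local size."""
    padded = math.ceil(dim_size / mesh_size) * mesh_size
    local = padded // mesh_size
    pad_rows = padded - dim_size  # padding sits at the end of the global dim
    return [local] * mesh_size, pad_rows

print(padded_local_sizes(6, 4))  # ([2, 2, 2, 2], 2): even locals, 2 padded rows
print(padded_local_sizes(8, 4))  # ([2, 2, 2, 2], 0): already even, no padding
```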

My question to @fmassa is: why does AutoParallel want to avoid (1)? If this is a hard blocker, then we can consider upstreaming a property to OpStrategy that outputs only the even sharding strategies, because filtering out all uneven sharding strategies directly in sharding prop may be problematic given DTensor's current status (i.e., we would need to support always-on padding).

Contributor Author:

Yes, we can ultimately support uneven sharding. But it might be better to focus on the "simple case" first and have everything be evenly shardable.

In the end, we should filter out everything which is not possible, so something like tensor_dim_size >= mesh_size indeed. But to make things simpler to reason about, I'm filtering out a bit more.

> My question to @fmassa is: why does AutoParallel want to avoid (1)?

For now, I believe keeping only even sharding will probably be best to ensure the rest of the stack is working as expected, as it will be easier to reason about (i.e., we don't need to check the outcome of many different traces to understand what changed between them).
But when we support dynamic shapes (which will be important for Ads models), we will indeed need to re-enable uneven sharding.

                        shape[dim] //= mesh_shape
                    else:
                        is_valid = False
                        break
        if is_valid:
            kept.append(strategy)

    out_strat = OpStrategy(kept)

    return out_strat
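
For intuition, the same divisibility rule as a standalone sketch outside DTensor's data structures (plain Python; `shard_dims` is a simplification of the real Placement objects, giving the tensor dim sharded on each mesh dim, or None):

```python
def evenly_shardable(shape, mesh_shape, shard_dims):
    """shard_dims[i] is the tensor dim sharded over mesh dim i, or None."""
    shape = list(shape)
    for mesh_size, dim in zip(mesh_shape, shard_dims):
        if dim is None:           # replicated / partial on this mesh dim
            continue
        if shape[dim] % mesh_size != 0:
            return False
        shape[dim] //= mesh_size  # later mesh dims see the reduced local size
    return True

# A (8, 16) tensor on a 2x4 mesh:
print(evenly_shardable((8, 16), (2, 4), (0, 1)))     # True: 8/2=4, 16/4=4
print(evenly_shardable((8, 16), (2, 4), (None, 0)))  # True: 8/4=2
print(evenly_shardable((6, 16), (2, 4), (0, 0)))     # False: 6/2=3, 3 % 4 != 0
```

The running `shape[dim] //= mesh_size` mirrors the loop in the diff above: when several mesh dims shard the same tensor dim, each later mesh dim must divide the already-reduced local size.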

