autoparallel/propagation_rules.py (99 additions, 0 deletions)
@@ -660,6 +660,105 @@ def index_rule(mesh, op_schema):
return out_strat


@register_opschema_rule(torch.ops.aten.gather.default)
def gather_strategy(mesh, op_schema):
from torch.distributed.tensor._op_schema import PlacementList
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

input_strategy = op_schema.args_schema[0]
dim = op_schema.args_schema[1]
index_strategy = op_schema.args_schema[2]

input_shape = input_strategy.shape
index_shape = index_strategy.shape

single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]
# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 3
single_mesh_dim_strategies.append(all_replicate)

# input sharding, input sharded, index accepts mask partial, output follows index
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),
index_partial_placement,
]
single_mesh_dim_strategies.append(input_sharding)

# index sharding, input replicated, index sharded, output follows index
# this only works when the sharding dimension is the gather dimension
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
single_mesh_dim_strategies.append(index_sharding)

if len(input_shape) == len(index_shape):
for d in range(len(input_shape)):
if d != dim:
sharding = [Shard(d), Shard(d), Shard(d)]
single_mesh_dim_strategies.append(sharding)

return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=1
)
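
As an editorial side note (not part of the diff), a minimal sketch of the shape contract the comments above rely on: torch.gather requires index to have the same rank as input, and the output takes the index's shape, which is why the output placement follows the index placement in the strategies.

import torch

x = torch.arange(12).reshape(3, 4)            # input, shape (3, 4)
idx = torch.tensor([[0, 2], [1, 0], [2, 3]])  # same rank as input, shape (3, 2)
out = torch.gather(x, 1, idx)                 # gather along dim=1
assert out.shape == idx.shape                 # output shape follows the index shape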


@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
from torch.distributed.tensor._op_schema import PlacementList

# from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

input_strategy = op_schema.args_schema[0]
dim = op_schema.args_schema[1]
index_strategy = op_schema.args_schema[2]
# src_strategy = op_schema.args_schema[3]

input_shape = input_strategy.shape
index_shape = index_strategy.shape

single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]
Contributor:
nit: [output, input, index, src]

# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 4
single_mesh_dim_strategies.append(all_replicate)

"""
# input sharding, input sharded, index accepts mask partial, output follows index
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),
index_partial_placement,
]
single_mesh_dim_strategies.append(input_sharding)
"""
# index sharding, input replicated, index sharded, output follows index
# this only works when the sharding dimension is the gather dimension
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
Contributor (zpcore), Aug 6, 2025:
I feel this may not be correct. Taking the example from https://docs.pytorch.org/docs/stable/generated/torch.Tensor.scatter_add_.html#torch.Tensor.scatter_add_:

>> src = torch.ones((2, 5))
>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])

the output can become Partial, as in: [Partial(), Replicate(), Shard(dim), Shard(dim)]

Contributor (author):
Yes, thanks for the review! I had roughly copy-pasted the gather rule and didn't fix this part. Will adapt it shortly.

Also, do you think we could have this implemented natively in PyTorch?

Contributor (zpcore):
Yes! The current upstream scatter_add strategy is just a quick workaround. We can follow up.
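
As an editorial illustration (not part of the PR), splitting the src/index rows of the example above into per-shard pieces shows why the output ends up Partial: each shard's scatter_add_ produces a full-shape partial result that must be summed across shards.

import torch

src = torch.ones((2, 5))
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
full = torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)

# "Shard" index and src along dim=0 (one row per shard); the base stays a
# zero tensor as in the docs example. Each shard yields a full-shape output.
partials = [
    torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index[i : i + 1], src[i : i + 1])
    for i in range(2)
]
# Only the sum of the per-shard outputs equals the full result, i.e. the
# output placement is Partial() rather than Shard(dim).
assert torch.equal(full, sum(partials))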

single_mesh_dim_strategies.append(index_sharding)

if len(input_shape) == len(index_shape):
for d in range(len(input_shape)):
if d != dim:
sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
Contributor (zpcore), Aug 7, 2025:
I tried more tests and noticed that with [Shard(d), Shard(d), Shard(d), Shard(d)] we can't simply shard the output and input. E.g., if dim = 1 and we shard on dim=0, the input can have many more rows than the index; we would then most likely only modify the first shard of the input, because input rows and index rows map one to one.

Contributor (author):
Correct me if I'm wrong, but I thought the shapes needed to match except on the dim?

Contributor (zpcore):
Right, the shapes need to match. I changed it to

 if d != dim and input_shape[d] == index_shape[d]:

and the op coverage passes now. I created the PR with the update here: pytorch/pytorch#160140.

Contributor (author):
Oh that's true, I definitely missed that case!
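
For illustration (not part of the PR), the one-to-one row mapping zpcore describes can be seen directly: with dim=1 and an input that has more rows than the index, only the first index.shape[0] rows of the input are ever touched, which is why the extra input_shape[d] == index_shape[d] check is needed before adding the Shard(d)-everywhere strategy.

import torch

inp = torch.zeros(4, 5)                 # more rows than index
index = torch.tensor([[0, 1], [2, 3]])  # shape (2, 2)
src = torch.ones(2, 2)
out = inp.scatter_add_(1, index, src)   # dim=1
# Rows 2 and 3 of the input are untouched: input rows and index rows map one
# to one, so sharding input and index on dim=0 would misalign them unless
# input_shape[0] == index_shape[0].
assert torch.equal(out[2:], torch.zeros(2, 5))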

single_mesh_dim_strategies.append(sharding)

return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=1
)


def sdpa_rule(op, mesh, op_schema):
out_strat = get_op_strategy(op, op_schema)
# remove wrong context-parallel strategy
Expand Down