Add gather and scatter_add strategies #81
@@ -660,6 +660,105 @@ def index_rule(mesh, op_schema):

```python
    return out_strat


@register_opschema_rule(torch.ops.aten.gather.default)
def gather_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList
    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 3
    single_mesh_dim_strategies.append(all_replicate)

    # input sharding: input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)

    # index sharding: input replicated, index sharded, output follows index
    # this only works when the sharding dimension is the gather dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
    single_mesh_dim_strategies.append(index_sharding)

    # same-dim sharding on any non-gather dimension for output, input and index
    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                sharding = [Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )
```
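As a sanity check on the `index_sharding` row above ([Shard(dim), Replicate(), Shard(dim)]): with the input replicated and the index sharded along the gather dim, each rank can run a plain local gather and the results concatenate along that dim. A minimal single-process sketch with plain `torch.gather`; the shapes and the two-way split are illustrative assumptions, not part of the PR:

```python
import torch

dim = 0
inp = torch.arange(12.0).reshape(4, 3)   # replicated on every shard
index = torch.tensor([[0, 1, 2],
                      [3, 0, 1],
                      [2, 2, 0],
                      [1, 3, 3]])        # sharded along the gather dim

full = torch.gather(inp, dim, index)

# Emulate two shards: each holds a slice of `index` along dim and the full `inp`.
idx_shards = torch.chunk(index, 2, dim=dim)
local_outputs = [torch.gather(inp, dim, idx) for idx in idx_shards]

# Concatenating the local outputs along dim reproduces the full result,
# i.e. the output follows the index sharding: [Shard(dim), Replicate(), Shard(dim)].
assert torch.equal(torch.cat(local_outputs, dim=dim), full)
```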
```python
@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList

    # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]
    # src_strategy = op_schema.args_schema[3]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 4
    single_mesh_dim_strategies.append(all_replicate)

    """
    # input sharding: input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)
    """
    # index sharding: input replicated, index sharded, output follows index
    # this only works when the sharding dimension is the gather dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
```
Review thread on the `index_sharding` placement:

- I feel this may not be correct. Taking the example here https://docs.pytorch.org/docs/stable/generated/torch.Tensor.scatter_add_.html#torch.Tensor.scatter_add_, the output can become Partial, as in: [Partial(), Replicate(), Shard(dim), Shard(dim)].
- Yes, thanks for the review! I had roughly copy-pasted the […] Also, do you think we could have this implemented natively in PyTorch?
- Yes! Current upstream […]
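To make the reviewer's point concrete: with index and src sharded along dim and the input replicated, index values from different shards can target the same output positions, so the local results overlap and only their element-wise sum (with the input counted once) recovers the global answer; the output placement is therefore Partial rather than Shard(dim). A minimal single-process sketch; the shapes and the two-way split are illustrative assumptions:

```python
import torch

dim = 0
inp = torch.arange(12.0).reshape(3, 4)   # replicated on every shard
src = torch.ones(4, 4)
index = torch.tensor([[0, 1, 2, 0],
                      [1, 1, 0, 2],
                      [2, 0, 1, 1],
                      [0, 2, 2, 0]])

full = inp.scatter_add(dim, index, src)

# Emulate two shards: index and src sharded along dim, input replicated.
idx_shards = torch.chunk(index, 2, dim=dim)
src_shards = torch.chunk(src, 2, dim=dim)

# Rows from different shards hit the same output positions, so the per-shard
# results cannot simply be concatenated along dim (the output is not Shard(dim)).
local = [torch.zeros_like(inp).scatter_add(dim, i, s)
         for i, s in zip(idx_shards, src_shards)]

# Summing the per-shard contributions, with the replicated input counted once,
# recovers the full result -> the output behaves as Partial.
assert torch.equal(inp + sum(local), full)
```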
```python
    single_mesh_dim_strategies.append(index_sharding)

    # same-dim sharding on any non-scatter dimension for output, input, index and src
    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
```
Review thread on this line:

- I tried more tests and noticed that with […]
- Correct me if I'm wrong, but I thought the shapes needed to match except for the dim?
- Right, the shapes need to match. I changed it to […] and the op coverage passes now. I created the PR with the update here: pytorch/pytorch#160140.
- Oh that's true, I definitely missed that case!
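For the [Shard(d)] * 4 row being discussed: when input, index and src all have the same size along a non-scatter dimension d, scatter_add decomposes cleanly along d (for dim=0, column j of the output depends only on column j of the inputs), so sharding everything along d is valid; if the sizes along d differ, chunking the tensors into equal numbers of shards no longer lines up, which is presumably the mismatch case this thread is about. A minimal single-process sketch under those assumptions:

```python
import torch

# scatter_add along dim=0; shard everything along d=1 (a non-scatter dim).
dim, d = 0, 1
inp = torch.arange(12.0).reshape(3, 4)
src = torch.ones(2, 4)
index = torch.tensor([[0, 2, 1, 0],
                      [2, 0, 0, 1]])

full = inp.scatter_add(dim, index, src)

# All three tensors have the same size along d, so splitting them along d
# decomposes the op shard-by-shard.
chunks = [torch.chunk(t, 2, dim=d) for t in (inp, index, src)]
local = [i.scatter_add(dim, idx, s) for i, idx, s in zip(*chunks)]

# Concatenating the local results along d gives the full answer,
# matching the [Shard(d), Shard(d), Shard(d), Shard(d)] row.
assert torch.equal(torch.cat(local, dim=d), full)
```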
```python
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )


def sdpa_rule(op, mesh, op_schema):
    out_strat = get_op_strategy(op, op_schema)
    # remove wrong context-parallel strategy
```
Review comment:

- nit: the placement-list comment in scatter_add_strategy should read [output, input, index, src].