
Conversation

@fmassa (Contributor) commented Aug 6, 2025

They were taken from #29

It would be good to have these in PyTorch, but I've seen that gather is useful for CrossEntropyLoss as well, so it's probably better to unblock first.

@fmassa requested review from wconstab and zpcore August 6, 2025 15:36
@meta-cla bot added the CLA Signed label Aug 6, 2025

single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]

Review comment (Contributor):

nit: [output, input, index, src]

"""
# index sharding, input replicated, index sharded, output follows index
# this only works when the sharding dimension is the gather dimension
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]

@zpcore (Contributor) commented Aug 6, 2025:

I feel this may not be correct. Taking the example here https://docs.pytorch.org/docs/stable/generated/torch.Tensor.scatter_add_.html#torch.Tensor.scatter_add_:

>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])

the output can become Partial as: [Partial(), Replicate(), Shard(dim), Shard(dim)]
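
A minimal standalone sketch of this point (mine, not part of the PR): emulate two "ranks" that each hold one row shard of index/src, scatter each shard into a full-size zero output, and check that the per-rank outputs sum to the unsharded result, i.e. the output carries a pending sum (Partial).

import torch

src = torch.ones((2, 5))
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
full = torch.zeros(3, 5).scatter_add_(0, index, src)

# Emulate Shard(0) on index/src across two "ranks": each rank scatters its
# slice into its own full-size, zero-initialized local output.
partials = [
    torch.zeros(3, 5).scatter_add_(0, idx, val)
    for idx, val in zip(index.chunk(2, dim=0), src.chunk(2, dim=0))
]

# Summing the local outputs (a Partial -> Replicate reduction) recovers the
# unsharded result, matching [Partial(), Replicate(), Shard(dim), Shard(dim)].
assert torch.equal(sum(partials), full)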

@fmassa (Author) replied:

Yes, thanks for the review! I had roughly copy-pasted the gather rule and didn't fix this part. Will adapt it shortly.

Also, do you think we could have this implemented natively in PyTorch?

Contributor replied:

Yes! The current upstream scatter_add strategy is just a quick workaround. We can follow up.

@zpcore (Contributor) left a review comment:

I think scatter_add may produce incorrect output.

if len(input_shape) == len(index_shape):
    for d in range(len(input_shape)):
        if d != dim:
            sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]

@zpcore (Contributor) commented Aug 7, 2025:

I tried more tests and noticed that with [Shard(d), Shard(d), Shard(d), Shard(d)] we can't simply shard the output and input. E.g., if dim = 1 and we want to shard on dim=0, the input can have many more rows than the index; then we would most likely modify only the first shard of the input, because input rows and index rows map one-to-one.
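
A quick standalone repro of this (my own hypothetical shapes, not from the PR's tests): with dim = 1 and fewer index rows than input rows, a naive Shard(0) of every operand disagrees with the unsharded op, because row i of index/src must update row i of the full input.

import torch

dim = 1
inp = torch.zeros(4, 3)                       # 4 rows
index = torch.tensor([[0, 1, 2], [2, 1, 0]])  # only 2 rows
src = torch.ones(2, 3)

full = inp.clone().scatter_add_(dim, index, src)  # only rows 0-1 get updates

# Naive Shard(0): chunk every operand by rows and scatter per "rank".
inp_shards = list(inp.clone().chunk(2, dim=0))    # rows [0, 1] and [2, 3]
for local_inp, local_idx, local_src in zip(
    inp_shards, index.chunk(2, dim=0), src.chunk(2, dim=0)
):
    local_inp.scatter_add_(dim, local_idx, local_src)
sharded = torch.cat(inp_shards, dim=0)

# The second "rank" wrote into global rows 2-3 instead of global row 1.
assert not torch.equal(full, sharded)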

@fmassa (Author) replied:

Correct me if I'm wrong, but I thought the shapes needed to match except on dim?

Contributor replied:

Right, the shapes need to match. I changed it to

 if d != dim and input_shape[d] == index_shape[d]:

and the op coverage passes now. I created the PR with the update here: pytorch/pytorch#160140.
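
For context, a sketch (mine; the append call is assumed, the rest mirrors the snippet quoted above and the fix in pytorch/pytorch#160140) of how the extra check slots into the loop:

if len(input_shape) == len(index_shape):
    for d in range(len(input_shape)):
        # Shard along d only when input and index line up on that dim;
        # otherwise the index rows cover just a prefix of the input rows and
        # an all-Shard(d) placement would touch only the first input shard.
        if d != dim and input_shape[d] == index_shape[d]:
            sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
            single_mesh_dim_strategies.append(sharding)  # assumed collection, per the diff above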

@fmassa (Author) replied:

Oh that's true, I definitely missed that case!

@fmassa (Author) commented Aug 8, 2025

Subsumed by pytorch/pytorch#160140

@fmassa closed this Aug 8, 2025
@fmassa deleted the fmassa/gather_scatter_add branch August 8, 2025 08:54
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 8, 2025:
As title.

This PR made a small fix on top of meta-pytorch/autoparallel#81.

Pull Request resolved: #160140
Approved by: https://github.com/fmassa