Add gather and scatter_add strategies #81
Conversation
They were taken from #29. Would be good to have those in PyTorch, but I've seen that gather was useful for CrossEntropyLoss as well, so probably better to unblock first.
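For reference, the CrossEntropyLoss connection is the usual pattern of using gather to pick out the target-class log-probability per row. A minimal sketch (not code from this PR; shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # (batch, num_classes)
target = torch.randint(0, 10, (4,))    # (batch,)

log_probs = F.log_softmax(logits, dim=1)
# gather along dim=1 with the target indices selects one entry per row
picked = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
loss = -picked.mean()

# matches the built-in loss (default reduction='mean')
assert torch.allclose(loss, F.cross_entropy(logits, target))
```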
single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]
nit: [output, input, index, src]
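For context, a rough sketch of what a per-mesh-dim strategy list looks like with the four-entry ordering the nit suggests. The values, the import path, and the appended entries are illustrative, not the PR's exact code:

```python
from torch.distributed.tensor import Replicate, Shard  # assumes the public DTensor placements

dim = 0  # the scatter dimension, picked arbitrarily for illustration
d = 1    # some other dimension, d != dim
single_mesh_dim_strategies = []

# each entry is a placement list for one mesh dimension,
# ordered as [output, input, index, src] for scatter_add
single_mesh_dim_strategies.append([Replicate()] * 4)                # all replicated
single_mesh_dim_strategies.append([Shard(d), Shard(d), Shard(d), Shard(d)])  # all sharded on d
```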
""" | ||
# index sharding, input replicated, index sharded, output follows index | ||
# this only works when the sharding dimension is the gather dimension | ||
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)] |
I feel this may not be correct. Taking the example here https://docs.pytorch.org/docs/stable/generated/torch.Tensor.scatter_add_.html#torch.Tensor.scatter_add_:

>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])

the output can become Partial, as in: [Partial(), Replicate(), Shard(dim), Shard(dim)]
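To see why the output ends up Partial rather than Shard(dim), the example above can be replayed by hand: shard index and src along dim 0, keep the zero-initialized self replicated, and check that the per-shard results only sum to the full answer. This is a hand-rolled illustration with plain tensors, not DTensor code:

```python
import torch

src = torch.ones((2, 5))
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
full = torch.zeros(3, 5).scatter_add_(0, index, src)

# "shard" index and src row-wise, as Shard(0) would, and give each shard
# its own zero-initialized copy of self
out0 = torch.zeros(3, 5).scatter_add_(0, index[:1], src[:1])
out1 = torch.zeros(3, 5).scatter_add_(0, index[1:], src[1:])

# each local result is only a partial contribution; summing them
# (the reduction implied by Partial()) recovers the full output
assert torch.equal(out0 + out1, full)
```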
Yes, thanks for the review! I had roughly copy-pasted the gather rule and didn't fix this part. Will adapt it shortly.
Also, do you think we could have this implemented natively in PyTorch?
Yes! The current upstream scatter_add strategy is just a quick workaround. We can follow up.
I think scatter_add may produce incorrect output.
if len(input_shape) == len(index_shape):
    for d in range(len(input_shape)):
        if d != dim:
            sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
I tried more tests and noticed that with [Shard(d), Shard(d), Shard(d), Shard(d)] we can't simply shard the output and input. E.g., if dim = 1 and we shard on d = 0, input can have many more rows than index; then we will most likely only modify the first shard of input, because input rows and index rows are in a one-to-one mapping.
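The mismatch described here can be reproduced by hand (plain tensors, no DTensor; shapes chosen only for illustration): with dim = 1 and everything sharded on dim 0, an input with more rows than index loses the row pairing between input and index/src:

```python
import torch

dim = 1
inp = torch.zeros(4, 5)               # more rows than index
src = torch.ones(2, 5)
index = torch.randint(0, 5, (2, 5))

expected = inp.clone().scatter_add_(dim, index, src)

# "Shard(0)" everything into two pieces: input rows [0,1] / [2,3],
# index and src rows [0] / [1]
out_shard0 = inp[:2].clone().scatter_add_(dim, index[:1], src[:1])
out_shard1 = inp[2:].clone().scatter_add_(dim, index[1:], src[1:])
sharded = torch.cat([out_shard0, out_shard1], dim=0)

# index row 1 should update input row 1, but it lands in the shard holding
# input rows 2-3, so the concatenated result is wrong
assert not torch.equal(sharded, expected)
```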
Correct me if I'm wrong, but I thought the shapes needed to match except for the dim?
Right, the shapes need to match. I changed it to
if d != dim and input_shape[d] == index_shape[d]:
and the op coverage tests now pass. I created the PR with the update here: pytorch/pytorch#160140.
Oh that's true, I definitely missed that case!
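A runnable sketch of the adjusted guard, with placeholder shapes; the final append into single_mesh_dim_strategies is assumed from the surrounding code, and pytorch/pytorch#160140 has the actual change:

```python
from torch.distributed.tensor import Shard  # assumes the public DTensor placements

# illustrative values; in the real strategy these come from the op's inputs
dim = 1
input_shape = (4, 5)
index_shape = (4, 5)
single_mesh_dim_strategies = []

# only emit the all-Shard(d) strategy when input and index have the same size
# on d, so the row pairing between input and index/src survives sharding
if len(input_shape) == len(index_shape):
    for d in range(len(input_shape)):
        if d != dim and input_shape[d] == index_shape[d]:
            single_mesh_dim_strategies.append([Shard(d)] * 4)
```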
Subsumed by pytorch/pytorch#160140
As title. This PR made a small fix on top of meta-pytorch/autoparallel#81. Pull Request resolved: #160140 Approved by: https://github.com/fmassa