Gather to Slice Fusion #13599
Merged
Conversation
baijumeswani
reviewed
Nov 9, 2022
```
/*
Fuse Range->Gather to Slice.
*/
```
Contributor
Can you add a note here explaining that this fusion is primarily helpful for gradient computation, for context?
Contributor
Author
Forward is also faster, though not by as much. The model code also carries the comment below, which notes the same thing for PyTorch. That's why I also put this fusion in the transformer utils file for inference.

```
# Note: the tensor-slice form was faster in my testing than torch.index_select
# However, tracing doesn't like the nature of the slice, and if klen changes
# during the run then it'll fail, whereas index_select will be fine.
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
```
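For reference, the two forms the comment compares produce identical results. A minimal NumPy sketch of the same pattern (using `np.take` in place of `torch.index_select`, on a small hypothetical tensor standing in for the model activation):

```python
import numpy as np

# Toy 4-D tensor standing in for the activation x in the model code.
x = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)
klen = 3

# Gather form: indices built with arange, then a gather along axis 3.
gathered = np.take(x, np.arange(klen), axis=3)

# Slice form: the same result as a contiguous slice along axis 3.
sliced = x[:, :, :, :klen]

assert np.array_equal(gathered, sliced)
```

Because the indices are a contiguous `arange`, the gather degenerates to exactly the slice, which is what the fusion exploits.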
baijumeswani
previously approved these changes
Nov 9, 2022
Contributor
baijumeswani
left a comment
Just a small comment to add some context; you can address it in another PR if you like.
Looks good.
baijumeswani
previously approved these changes
Nov 10, 2022
baijumeswani
approved these changes
Nov 10, 2022
simon-moo pushed a commit to simon-moo/onnxruntime that referenced this pull request on Dec 21, 2022
This PR optimizes the running of the code below from Huggingface's XLNet model.

```
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
```
This code is exported as Range->Gather, which can be fused into a Slice op. The Slice kernel is much faster than Gather, especially for the backward run. The main reason is that Gather's indices can contain duplicates, so its backward pass needs a summation (scatter-add), whereas a Slice node can never have that case.
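To illustrate why the backward passes differ, here is a toy 1-D NumPy sketch (not ONNX Runtime's actual kernels): Gather must accumulate gradient contributions because several output elements can map to the same input element, while Slice just copies the gradient back into a contiguous region.

```python
import numpy as np

x = np.zeros(5, dtype=np.float32)
grad_out = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Gather with duplicated indices: the gradient must be scatter-ADDED,
# since indices 1 appears twice and both contributions accumulate.
idx = np.array([1, 1, 3])
grad_x_gather = np.zeros_like(x)
np.add.at(grad_x_gather, idx, grad_out)   # accumulation is required
# grad_x_gather is now [0, 3, 0, 3, 0]

# Slice x[1:4]: indices are contiguous and unique, so the gradient is
# just the incoming gradient copied into the sliced region.
grad_x_slice = np.zeros_like(x)
grad_x_slice[1:4] = grad_out              # plain copy, no summation
# grad_x_slice is now [0, 1, 2, 3, 0]
```

The scatter-add is what makes Gather's backward expensive; once the pattern is known to be a contiguous `arange`, the fused Slice avoids it entirely, which matches the large backward-time improvement reported below.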
Profiling was done with Huggingface's XLNet model.

Before the fusion:
- forward: ~753us
- backward: ~46101us

After the fusion:
- forward: ~627us
- backward: ~677us