
Gather to Slice Fusion #13599

Merged
Lafi7e merged 4 commits into main from weicwang/gather_to_slice on Nov 10, 2022
Conversation

Contributor

@Lafi7e commented Nov 9, 2022

This PR optimizes the execution of the following code from Hugging Face's XLNet model.

```
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
```

This code is exported as a Range->Gather pattern, which can be fused into a single Slice op. The Slice kernel is much faster than Gather, especially in the backward pass. The main reason: Gather's indices can contain duplicates, so its backward pass must sum gradients for repeated indices, whereas a Slice node can never have that case.
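The equivalence the fusion relies on can be sketched in PyTorch (a minimal illustration with made-up tensor shapes, not the ONNX Runtime fusion code itself):

```python
import torch

x = torch.randn(2, 3, 4, 8)
klen = 5

# Exported as Range -> Gather: gathers columns 0..klen-1 along dim 3.
gathered = torch.index_select(x, 3, torch.arange(klen, dtype=torch.long))

# Equivalent Slice: same values, but its backward pass is a simple
# zero-pad instead of a scatter-add over (possibly duplicated) indices.
sliced = x[..., :klen]

assert torch.equal(gathered, sliced)
```

Because the indices produced by Range are always a contiguous, strictly increasing run starting at 0, the Gather is guaranteed to be expressible as a Slice with starts=0, ends=klen on that axis.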

Profiling results with Hugging Face's XLNet model:

  • Before the fusion
    forward: ~753us
    ![image](https://user-images.githubusercontent.com/11661208/200758439-63f2f9b5-9610-4df8-98c8-a1ad4dc62f4e.png)
    backward: ~46101us
    ![image](https://user-images.githubusercontent.com/11661208/200758530-fe16a8ec-ea8f-4b79-b3ac-386b72ba1670.png)

  • After the fusion
    forward: ~627us
    ![image](https://user-images.githubusercontent.com/11661208/200758654-ab9a6068-c45d-40f4-9c71-3862a56732f8.png)
    backward: ~677us
    ![image](https://user-images.githubusercontent.com/11661208/200758833-aab1b8e1-1b5d-4e55-88cf-03c2a1d9d42b.png)
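The backward-pass gap above comes from Gather's gradient needing a scatter-add. A tiny PyTorch sketch (illustrative only, not the profiled model code) shows why duplicated indices force a sum:

```python
import torch

# Gather (index_select) allows duplicated indices, so its backward
# pass must sum the gradients flowing into each repeated index.
x = torch.zeros(4, requires_grad=True)
idx = torch.tensor([1, 1, 1])            # the same index, three times
y = torch.index_select(x, 0, idx)        # shape (3,)
y.sum().backward()
print(x.grad)                            # tensor([0., 3., 0., 0.]) — gradients summed

# A slice's backward is just a zero-pad: no summation needed.
x2 = torch.zeros(4, requires_grad=True)
x2[:3].sum().backward()
print(x2.grad)                           # tensor([1., 1., 1., 0.])
```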

@Lafi7e added the training label (issues related to ONNX Runtime training; typically submitted using template) Nov 9, 2022

/*
Fuse Range->Gather to Slice.
*/
Contributor

Can you add a note here explaining that this fusion is primarily helpful for gradient computation, for context?

Contributor Author

Forward is also faster, though the gain is not as big. The model code also has the comment below showing this for PyTorch. That's why I also put this fusion in the transformer utils file for inference.

        # Note: the tensor-slice form was faster in my testing than torch.index_select
        #       However, tracing doesn't like the nature of the slice, and if klen changes
        #       during the run then it'll fail, whereas index_select will be fine.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
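The trade-off the model comment describes can be sketched with a toy tensor (an illustration, not the model code itself): the slice form is faster but bakes the slice length into a trace, while index_select keeps it dynamic.

```python
import torch

x = torch.randn(1, 2, 3, 8)
klen = 5

# Slice form: fast, but torch.jit.trace records klen as a constant,
# so a model traced at klen=5 would keep slicing to 5 elements even
# if klen changes on a later run.
sliced = x[:, :, :, :klen]

# index_select form: exports as Range -> Gather, so klen stays a
# run-time value — which is exactly the pattern this PR fuses back
# into a Slice inside ONNX Runtime.
selected = torch.index_select(x, 3, torch.arange(klen, dtype=torch.long))

assert torch.equal(sliced, selected)
```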

baijumeswani previously approved these changes Nov 9, 2022
Contributor

@baijumeswani left a comment


Just a small comment to add some context. You can get it in another PR if you like.

Looks good.

baijumeswani previously approved these changes Nov 10, 2022
@Lafi7e merged commit 2bda3fd into main Nov 10, 2022
@Lafi7e deleted the weicwang/gather_to_slice branch November 10, 2022 05:03
simon-moo pushed a commit to simon-moo/onnxruntime that referenced this pull request Dec 21, 2022
