inductor: enable weight prepack for LSTM (pytorch#103071)
- Enabled LSTM weight prepack in inductor.
- Added an mkldnn decomposition for LSTM that does not change across different `seq_lens`. With the previous decomposition, the graph would differ in dynamic-shape use cases whenever `seq_lens` changed.
- Extended several inductor utility functions to support `List[Tensor]` as input. Previously those functions only supported `Tensor` input.

**Update 2023-07-26:**

- pytorch#103851 moved CPU weight packing to after AOTAutograd. Updated this PR to follow the same approach (mainly in pytorch@3b207f7#diff-6dffed1ade0ba3e887f9a4eafa3bfcec267ab2365b8adcb91bd391f49b3fd2e3). LSTM is decomposed into `aten.mkldnn_rnn_layer` by layer and by direction, and the weight prepack is done at the `mkldnn_rnn_layer` level.
- Added a fix in the RNN `__getstate__` function for the case where we need to recompile an `LSTM` module. When compiling the module, the weight tensors, which are among the module's `named_parameters`, are converted to `functional_tensor`s here:
  https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L125-L128
  The forward function of the LSTM is then called:
  https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3379-L3381
  Inside forward, `_flat_weights` is updated to alias the weights, so it also becomes a `functional_tensor`:
  https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/modules/rnn.py#L775-L778
  The weight tensors are converted back to the original tensors (no longer `functional_tensor`s) before exiting the `_reparametrize_module` context here:
  https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L130-L142
  But since `_flat_weights` is not among the module's `named_parameters`, it remains a `functional_tensor` ([link to the parameters that are converted to functional and reverted back](https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3695-L3698)).
  At this point, if we need to recompile the model, `deepcopy` is called:
  https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_dynamo/utils.py#L915-L917
  It reports `UnImplemented` because `_flat_weights` is still a `functional_tensor`, triggering a graph break that we do not expect:
  https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_subclasses/meta_utils.py#L514
  The fix in `__getstate__` updates `_flat_weights` whenever the weights have changed. It is covered by the `test_lstm_packed` UT.

Pull Request resolved: pytorch#103071
Approved by: https://github.com/jgong5, https://github.com/jansel
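For reference, a minimal sketch (not from this PR) of exercising the prepack path: compiling an `nn.LSTM` for CPU inference. The shapes are illustrative, and the assumption that inductor freezing must be enabled for CPU weight prepack to apply is mine, not stated in the PR.

```python
import torch
import torch._inductor.config as inductor_config

# Assumption for illustration: CPU weight prepack applies to frozen
# inference graphs, so freezing is enabled here.
inductor_config.freezing = True

lstm = torch.nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
                     batch_first=True).eval()
compiled = torch.compile(lstm)

x = torch.randn(8, 16, 64)  # (batch, seq_len, input_size)
with torch.no_grad():
    # Each layer/direction lowers to aten.mkldnn_rnn_layer, where the
    # weight prepack happens.
    out, (h, c) = compiled(x)
```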
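And a small illustrative sketch (not the PR's code) of why `_flat_weights` can go stale: the RNN modules cache direct references to their parameter tensors in `_flat_weights`, so any mechanism that swaps the parameters, as `_reparametrize_module` does during compilation, leaves the cache pointing at the old tensors unless it is rebuilt.

```python
import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=4)
# _flat_weights caches direct references to the parameter tensors.
assert lstm._flat_weights[0] is lstm.weight_ih_l0

# Simulate a parameter swap like the one _reparametrize_module performs;
# the parameter is replaced but the cache is not rebuilt.
lstm.weight_ih_l0 = torch.nn.Parameter(lstm.weight_ih_l0.detach().clone())
print(lstm._flat_weights[0] is lstm.weight_ih_l0)  # False: the cache is stale
```

The fix refreshes `_flat_weights` from the current weights in `__getstate__`, so the `deepcopy` performed during recompilation sees plain tensors rather than stale `functional_tensor`s.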
commit ae74eee · 1 parent 6781852 · 13 changed files with 801 additions and 71 deletions