Restore TorchScript functionality (necessary for quantization) #129
Merged

Commits (11)
- e3c8fe0 Add TorchScript and deepcopy tests (danieldk)
- c2de837 TorchScript fix: `ModuleList` can only be indexed with literals (danieldk)
- c0724a2 TorchScript fix: `**kwargs` is not allowed (danieldk)
- cf0fc5d TorchScript fix: we cannot condition on global state (danieldk)
- 1350bdc TorchScript fix: TorchScript does not allow `Module` type annotation (danieldk)
- 2d70f21 TorchScript fixes: many fixes for the `Attention` class (danieldk)
- aeea0c8 Revert "Add support for Torch `scaled_dot_product_attention` (#128)" (danieldk)
- 3eb1842 Attempt to fix CI pip issues (danieldk)
- e090aeb Describe some TorchScript rules of thumb in DEVELOP.md (danieldk)
- 09dd921 Simplify TorchScript type inference (danieldk)
- ce5038f Remove unused imports (danieldk)
# Development

## TorchScript

Every `torch.nn.Module` in this project must have a TorchScript conversion test.
TorchScript only supports a subset of Python, and we want to make sure that all
models are convertible to TorchScript.
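As a rough sketch, such a conversion test can be as simple as scripting the module and comparing its output against eager mode. The `Foo` module, shapes, and test name below are illustrative placeholders, not the project's actual tests:

```python
import torch
from torch import Tensor, nn


class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        return X + 1.0


def test_foo_torchscript() -> None:
    model = Foo()
    # torch.jit.script fails loudly if the module uses unsupported Python.
    scripted = torch.jit.script(model)
    X = torch.rand(2, 3)
    assert torch.equal(scripted(X), model(X))
```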

This section gives some rules of thumb for avoiding conversion errors.

### Do not use global state

TorchScript cannot use global state. One form of global state in this
project is the `Errors` class. Consequently, we cannot use `Errors`
in `Module`s. The following is therefore invalid:
```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        # Problem: Errors fields are global state.
        raise ValueError(Errors.E042)
```

In these cases, we have to use an inline string instead:

```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        raise ValueError("This module does not do anything yet.")
```

For the same reason, we also cannot rely on `has_*` bools in a module:

```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        # Problem: conditional on global state.
        if has_torch_feature:
            ...
```
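One TorchScript-compatible workaround is to copy the global flag into an instance attribute at construction time, so the scripted `forward` only conditions on local module state. The flag and module below are illustrative, not from this project:

```python
import torch
from torch import Tensor, nn

# Illustrative global flag; evaluated once at import time.
has_torch_feature = hasattr(torch, "scaled_dot_product_attention")


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Capture the global as a plain bool attribute; TorchScript
        # can condition on attributes, but not on globals.
        self.use_torch_feature = has_torch_feature

    def forward(self, X: Tensor) -> Tensor:
        if self.use_torch_feature:
            return X * 2.0
        return X
```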

### Typing limitations

TorchScript only supports a small [subset of Python types](https://pytorch.org/docs/stable/jit_language_reference.html#supported-type).
This also applies to type annotations. For instance, the following will not work,
because TorchScript only supports fully-specified tuple types:

```python
class Foo(nn.Module):
    # Problem: underspecified tuple
    def shape(self) -> Tuple:
        ...

    # Problem: underspecified tuple
    def shape(self) -> Tuple[int, ...]:
        ...
```

The following is OK, because it is a valid TorchScript type:

```python
class Foo(nn.Module):
    def shape(self) -> Tuple[int, int]:
        ...
```
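To illustrate, a method with a fully-specified tuple return type scripts cleanly. This is a minimal sketch assuming a 2-d input:

```python
from typing import Tuple

import torch
from torch import Tensor, nn


class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tuple[int, int]:
        # Both elements are known to be ints, so Tuple[int, int]
        # is a fully-specified, valid TorchScript type.
        return X.shape[0], X.shape[1]
```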

### Do not use `**kwargs` arguments

TorchScript does not support `**kwargs` wildcards, so the following is
invalid:

```python
class Foo(nn.Module):
    ...

    def forward(self, X: Tensor, **kwargs) -> Tensor:
        hidden = self.inner1(X)
        return self.inner2(hidden, **kwargs)
```

Instead, we have to spell out all arguments, e.g.:

```python
class Foo(nn.Module):
    ...

    def forward(self, X: Tensor, attn_mask: AttentionMask) -> Tensor:
        hidden = self.inner1(X)
        return self.inner2(hidden, attn_mask=attn_mask)
```
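A minimal end-to-end sketch of the spelled-out version, using a plain `Tensor` mask in place of the project's `AttentionMask` type and made-up inner modules:

```python
import torch
from torch import Tensor, nn


class Inner(nn.Module):
    def forward(self, X: Tensor, attn_mask: Tensor) -> Tensor:
        # Zero out masked positions.
        return X * attn_mask


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner1 = nn.Linear(4, 4)
        self.inner2 = Inner()

    def forward(self, X: Tensor, attn_mask: Tensor) -> Tensor:
        hidden = self.inner1(X)
        # All arguments are spelled out, so TorchScript can type-check them.
        return self.inner2(hidden, attn_mask=attn_mask)
```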