TorchScript #113
Conversation
Thank you for the PR! I have left comments and questions. I'd appreciate if you could have a look at them.
crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions)
crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions)
crf_bf.transitions.requires_grad_(False).copy_(crf.transitions)
Why do you need to copy the parameters? Same question for similar changes in test_scripted_decode.
@@ -151,16 +153,18 @@ def _validate(
                 f'got {emissions.size(2)}')

         if tags is not None:
-            if emissions.shape[:2] != tags.shape:
+            if emissions.shape[0] != tags.shape[0] or emissions.shape[1] != tags.shape[1]:
Why is this changed? Same question for other similar changes below.
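A plausible reason for the change (my assumption, not stated in the PR): some TorchScript versions could not compile a comparison between a slice of `shape` and another `torch.Size`, so the check was rewritten dimension by dimension. The two forms are logically equivalent when `tags` is 2-dimensional, as this plain-Python sketch with tuples standing in for `torch.Size` shows:

```python
# Shapes as plain tuples, standing in for torch.Size values
emissions_shape = (3, 2, 5)   # (seq_length, batch_size, num_tags)
tags_shape = (3, 2)           # (seq_length, batch_size)

# Original form: compare the first two dims as a slice
slice_mismatch = emissions_shape[:2] != tags_shape

# Rewritten form: compare each dim individually
elementwise_mismatch = (emissions_shape[0] != tags_shape[0]
                        or emissions_shape[1] != tags_shape[1])

# Both forms agree for 2-dim tags, whether the shapes match or not
assert slice_mismatch == elementwise_mismatch
```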
         # We trace back where the best last tag comes from, append that to our best tag
         # sequence, and trace it back again, and so on
-        for hist in reversed(history[:seq_ends[idx]]):
+        for hist in history[:seq_ends[idx]][::-1]:
Why is this changed?
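A likely motivation (again an assumption on my part): the `reversed()` builtin was not supported by TorchScript's compiler at the time, whereas `[::-1]` slicing was. The two produce the same iteration order, as this plain-Python sketch with made-up backpointer data shows:

```python
# Made-up backpointer history; each entry maps batch position to best previous tag
history = [[3, 1], [0, 2], [2, 0], [1, 1]]
seq_ends = [4]
idx = 0

# Original form: iterate the truncated history in reverse via reversed()
via_reversed = list(reversed(history[:seq_ends[idx]]))

# Rewritten form: reverse the truncated history via slicing
via_slicing = history[:seq_ends[idx]][::-1]

# Both yield the entries in the same (reversed) order
assert via_reversed == via_slicing
```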
@@ -1,6 +1,6 @@
 # This only installs PyTorch with a specific CUDA version which may not be
 # compatible with yours. If so, install PyTorch with the correct CUDA version
 # as instructed on https://pytorch.org/get-started/locally/
-torch
+torch>=1.10.0
The PyTorch version that gets installed for testing is specified in .github/workflows/run_tests.yml, so you should change the version there rather than here. That said, I'm not comfortable with testing only on v1.10.0 just to accommodate torch scripting. I think the ideal option is to run all tests except TestTorchScript on v1.2, and run only TestTorchScript on v1.10.0. That way, all the other tests still run for <v1.10.0, ensuring that the library still supports those earlier versions. Would you be interested in implementing this?
Applied the required changes to #100.