
TorchScript #113

Closed · erksch wants to merge 1 commit into kmkurn:master from erksch:torchscriptable

Conversation

@erksch commented May 9, 2023

Applied the changes requested in #100.

  • Added a note to the README that TorchScript support works starting from PyTorch 1.10.0
  • Set the PyTorch version used for testing to >=1.10.0
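
For reference, with these changes scripting the CRF should look roughly like this (a minimal sketch, assuming PyTorch >= 1.10.0 and that decode is exported for scripting as this PR intends; the tag count and shapes are made up):

import torch
from torchcrf import CRF

crf = CRF(num_tags=5, batch_first=True)
scripted_crf = torch.jit.script(crf)  # compile the module with TorchScript

emissions = torch.randn(2, 3, 5)      # (batch, seq_len, num_tags)
tags = torch.randint(0, 5, (2, 3))    # (batch, seq_len)
loss = -scripted_crf(emissions, tags)       # scripted log-likelihood
best_tags = scripted_crf.decode(emissions)  # scripted Viterbi decoding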

@erksch mentioned this pull request May 9, 2023

@kmkurn (Owner) left a comment

Thank you for the PR! I have left some comments and questions; I'd appreciate it if you could have a look at them.

Comment on lines +539 to +541
crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions)
crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions)
crf_bf.transitions.requires_grad_(False).copy_(crf.transitions)
@kmkurn (Owner):

Why do you need to copy the parameters? Same question for similar changes in test_scripted_decode.
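
(Not part of the PR, but for context: if the copies only exist to make crf_bf numerically identical to crf, the usual idiom is an in-place copy under no_grad rather than permanently disabling requires_grad; a sketch using the same objects as the test:)

import torch

with torch.no_grad():
    # Copy parameters without recording the copies in autograd and without
    # flipping requires_grad off for the rest of the test.
    crf_bf.start_transitions.copy_(crf.start_transitions)
    crf_bf.end_transitions.copy_(crf.end_transitions)
    crf_bf.transitions.copy_(crf.transitions)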

@@ -151,16 +153,18 @@ def _validate(
f'got {emissions.size(2)}')

 if tags is not None:
-    if emissions.shape[:2] != tags.shape:
+    if emissions.shape[0] != tags.shape[0] or emissions.shape[1] != tags.shape[1]:
@kmkurn (Owner):

Why is this changed? Same question for other similar changes below.
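
If the motivation is TorchScript compatibility, the element-wise form compares plain ints, which scripts cleanly; a minimal sketch (the assumption that this is the reason is mine, the PR does not say):

import torch

@torch.jit.script
def shapes_match(emissions: torch.Tensor, tags: torch.Tensor) -> bool:
    # Each shape[i] is a plain int inside TorchScript, so comparing the
    # leading dimensions one by one compiles without issue.
    return (emissions.shape[0] == tags.shape[0]
            and emissions.shape[1] == tags.shape[1])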


 # We trace back where the best last tag comes from, append that to our best tag
 # sequence, and trace it back again, and so on
-for hist in reversed(history[:seq_ends[idx]]):
+for hist in history[:seq_ends[idx]][::-1]:
@kmkurn (Owner):

Why is this changed?
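
A likely reason, though the PR does not state it: the reversed() builtin has not always been supported by the TorchScript compiler, while a negative-step list slice is; a small sketch:

import torch
from typing import List

@torch.jit.script
def reverse_ints(xs: List[int]) -> List[int]:
    # xs[::-1] reverses a list inside TorchScript, where reversed(xs)
    # may fail to compile on older PyTorch versions.
    return xs[::-1]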

@@ -1,6 +1,6 @@
# This only installs PyTorch with a specific CUDA version which may not be
# compatible with yours. If so, install PyTorch with the correct CUDA version
# as instructed on https://pytorch.org/get-started/locally/
-torch
+torch>=1.10.0
@kmkurn (Owner):

The PyTorch version that gets installed for testing is specified in .github/workflows/run_tests.yml. You should change the version there rather than here. That said, I'm not comfortable with testing only for v1.10.0 just to accommodate torch scripting. I think the ideal option is to run all tests except TestTorchScript for v1.2, and run only TestTorchScript on v1.10.0. That way, all the other tests are still run for <v1.10.0, ensuring that the library still supports those earlier versions. Would you be interested in implementing this?
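
One way to implement that split, sketched in the test file (the version parsing and test body are assumptions, not code from this PR; the matrix in .github/workflows/run_tests.yml would then list both PyTorch versions):

import pytest
import torch

# Parse e.g. "1.10.0+cu113" into (1, 10); enough for a coarse version gate.
TORCH_VERSION = tuple(int(p) for p in torch.__version__.split('+')[0].split('.')[:2])

@pytest.mark.skipif(TORCH_VERSION < (1, 10),
                    reason='TorchScript support requires PyTorch >= 1.10.0')
class TestTorchScript:
    def test_scripted_forward_matches_eager(self):
        ...  # scripting tests run only on a new enough PyTorch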

@coveralls

Coverage Status

Coverage: 0.0%. Remained the same when pulling 5cde02d on erksch:torchscriptable into c885fba on kmkurn:master.

@github-actions github-actions bot added the Stale label Aug 13, 2023
@github-actions github-actions bot closed this Sep 13, 2023
@leejuyuu

Hello @erksch, I would like to see this feature merged, as it seems to enable ONNX export. Are you still interested? If not, would you mind if I try to resolve @kmkurn's review comments on top of this branch (possibly in another PR, since I don't have push rights)?
