Torchscripted crf #100
Conversation
Hi, thanks for the PR. The tests are failing because
Sure, I can help out. What would we need to do to upgrade this library?
Awesome, thank you! I think we can start simple:
Also, moving the conversation about the upgrade to a separate issue is perhaps better for tracking. Thanks again for agreeing to help out!
I have created issue #104, where work on upgrading this library to PyTorch 1.2 will take place.
Hey @kmkurn, I tried running this branch with PyTorch 1.2.0, which is where torch.jit.export was introduced, but I've been running into some issues. I also have a proposal for a fix at the end of this post. The first issue is that torch.jit doesn't support some basic operations, such as not in.
Even if we fix this by replacing that line with
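The kind of rewrite involved can be sketched as follows. This is an illustration, not the actual pytorch-crf source; the function names and the set of reduction values are stand-ins:

```python
# Hypothetical illustration of the `not in` limitation. Older TorchScript
# versions reject the membership-test form below...
def validate_reduction_eager(reduction: str) -> None:
    if reduction not in ('none', 'sum', 'mean', 'token_mean'):
        raise ValueError('invalid reduction: ' + reduction)

# ...so a TorchScript-friendly rewrite spells out the comparisons instead.
def validate_reduction_scriptable(reduction: str) -> None:
    if (reduction != 'none' and reduction != 'sum'
            and reduction != 'mean' and reduction != 'token_mean'):
        raise ValueError('invalid reduction: ' + reduction)
```

Both forms behave identically in eager mode; only the second one compiles under the affected TorchScript versions.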
We run into another issue where the Torchscripted code reports that it doesn't recognize the type torch.LongTensor.
The above error occurs in all torch versions from 1.2.0 to 1.8.0. We can fix it by changing the type annotation from torch.LongTensor to torch.Tensor, but this could mean that a FloatTensor is accepted as input, which is undesirable. In 1.9.0, the method Tensor.new_ones is not supported by Torchscripting, leading to this error:
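As a sketch, the annotation widening looks like this; the function name and the runtime dtype check are assumptions for illustration, not the actual pytorch-crf code:

```python
import torch

# Before: the argument was annotated `tags: torch.LongTensor`, which
# TorchScript does not recognize. Widening the annotation to `torch.Tensor`
# compiles, but the dtype guarantee is lost, so a runtime check stands in.
def check_tags(tags: torch.Tensor) -> None:
    if tags.dtype != torch.long:
        raise ValueError('tags must be a LongTensor, got ' + str(tags.dtype))
```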
We can change this to instead use torch.ones of the same shape, explicitly specifying dtype=emissions.dtype and device=emissions.device, but it won't be as elegant as using Tensor.new_ones. Torch 1.10.0 is where the Torchscripted code compiles successfully. My proposal is to have a utility function that takes in a CRF module and returns a scripted CRF module. The function will first assert that the Torch version is at least 1.10.0, and then return the scripted CRF module.
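The new_ones replacement amounts to something like the following; the function name and the use of the leading two dimensions of the emissions tensor are assumptions for illustration:

```python
import torch

def start_scores(emissions: torch.Tensor) -> torch.Tensor:
    # Before (fails to script on 1.9.0): emissions.new_ones(emissions.shape[:2])
    # After: dtype and device are spelled out explicitly for TorchScript.
    return torch.ones(emissions.shape[:2], dtype=emissions.dtype,
                      device=emissions.device)
```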
Thanks for the detailed write-up! Really appreciate the time you spent investigating this. It's surprising how half-baked the JIT support feels with all these unsupported operations. With your suggestion, does that mean there'd be essentially two versions of the code, one being the normal version and the other the JIT-friendly version?
I was suggesting adding a function inside __init__, after the CRF module class, that takes a CRF module as input; it would first check the torch version and then return a scripted CRF. Something like:
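A minimal sketch of such a helper, assuming the name script_crf and a simple version-string parse (both are assumptions, not the code that was ultimately merged):

```python
import torch

MIN_VERSION = (1, 10, 0)  # first release where the CRF scripts cleanly

def script_crf(crf: torch.nn.Module) -> torch.jit.ScriptModule:
    # Strip any local build suffix such as "+cu118" before parsing.
    parts = torch.__version__.split('+')[0].split('.')[:3]
    version = tuple(int(p) for p in parts)
    if version < MIN_VERSION:
        raise RuntimeError('Torchscripting the CRF requires PyTorch >= 1.10.0, '
                           'got ' + torch.__version__)
    return torch.jit.script(crf)
```

Callers on older PyTorch would get a clear error up front instead of an opaque compiler failure.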
The issue with having multiple versions is that a fix in one version might not make it into the other. Instead of having a utility function that torchscripts the CRF module, what if we merge these changes in once the PyTorch 1.2 update (#104) gets merged into main? I can update the torchscripting tests to run only if the torch version is 1.10.0 or higher, and we can also update the documentation to specify that the CRF model is torchscriptable with torch 1.10.0 and higher.
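The version-gated tests could look something like this; the test name and the stand-in Linear module are assumptions (the real tests would script the CRF itself):

```python
import pytest
import torch

# Parse major/minor, stripping local build suffixes such as "+cu118".
TORCH_VERSION = tuple(int(p) for p in torch.__version__.split('+')[0].split('.')[:2])

@pytest.mark.skipif(
    TORCH_VERSION < (1, 10),
    reason='Torchscripting the CRF requires PyTorch >= 1.10.0',
)
def test_scripted_output_matches_eager():
    # Stand-in module for illustration; the real tests would use the CRF.
    model = torch.nn.Linear(4, 2)
    scripted = torch.jit.script(model)
    x = torch.randn(3, 4)
    assert torch.allclose(model(x), scripted(x))
```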
Sorry for responding so late. I like your solution. I've merged #104, so you can implement the solution now.
Hi, what is missing here? I am interested in using this. Any help needed?
Actually, I think you could add to the README that torchscripting does not work for PyTorch < 1.10.0 and just merge this. When people try to torchscript something, they will not look for a function from the library but will just call torch.jit.script directly. We can also specify a minimum PyTorch version in the package metadata.
Description
Supports Torchscripting the CRF model. Fixes issue #93 without changing the current interface. Adds tests to verify that the Torchscripted CRF model's forward and decode outputs are equivalent to those of the non-scripted CRF model.