
Torchscripted crf #100

Closed

Conversation

aravindMahadevan
Contributor

@aravindMahadevan aravindMahadevan commented Jun 27, 2022

Description

Supports Torchscripting the CRF model. Fixes issue #93 without changing the current interface. Added tests to verify that the Torchscripted CRF model's outputs for forward and decode are equivalent to those of the non-scripted CRF model.

>>> import torch
>>> from torchcrf import CRF
>>> num_tags = 5  # number of tags is 5
>>> model = CRF(num_tags)
>>> script = torch.jit.script(model)
>>> script
RecursiveScriptModule(original_name=CRF)

@kmkurn
Owner

kmkurn commented Jul 9, 2022

Hi, thanks for the PR. The tests are failing because torch.jit doesn't seem to have attribute export. I suspect this is because the PyTorch version this library uses is too old. Unfortunately, I have limited availability to upgrade to newer PyTorch. Are you perhaps interested in helping out?

@aravindMahadevan
Contributor Author

Sure, I can help out. What would we need to do to upgrade this library?

@kmkurn
Owner

kmkurn commented Jul 9, 2022

Awesome, thank you!

I think we can start simple:

  1. Figure out the lowest PyTorch version that has torch.jit.export
  2. Update the PyTorch version in .github/workflows/run_tests.yml so the tests will be against the new PyTorch version
  3. Make sure all the tests are succeeding

Also, moving the conversation on the upgrade to a separate issue is perhaps better for tracking. Thanks again for agreeing to help out!

@aravindMahadevan
Contributor Author

aravindMahadevan commented Jul 9, 2022

I have created issue #104, where the work to upgrade this library to PyTorch 1.2 will take place.

@aravindMahadevan
Contributor Author

Hey @kmkurn, I tried running this branch with PyTorch 1.2.0, the version where torch.jit.export was introduced, but I've been running into some issues. I also have a proposal for a fix at the end of this post.

The first issue is that torch.jit doesn't support some basic operations, such as the not in comparison.

E               torch.jit.frontend.NotSupportedError: unsupported comparison operator: NotIn:
E                               none|sum|mean|token_mean. none: no reduction will be applied.
E                               sum: the output will be summed over batches. mean: the output will be
E                               averaged over batches. token_mean: the output will be averaged over tokens.
E               
E                       Returns:
E                           ~torch.Tensor: The log likelihood. This will have size (batch_size,) if
E                           reduction is none, () otherwise.
E                       """
E                       self._validate(emissions, tags=tags, mask=mask)
E                       if reduction not in ['none', 'sum', 'mean', 'token_mean']:
E                                   ~~~~~~~~ <--- HERE
E                           raise ValueError(f'invalid reduction: {reduction}')
E                       if mask is None:
E                           mask = torch.ones_like(tags, dtype=torch.uint8)
E               
E                       if self.batch_first:
E                           emissions = emissions.transpose(0, 1)
E                           tags = tags.transpose(0, 1)
E                           mask = mask.transpose(0, 1)

Even if we fix this by replacing the line with

if reduction != 'none' and reduction != 'sum' and reduction != 'mean' and reduction != 'token_mean':
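For illustration, a standalone check in the same spirit (the function name is hypothetical, not part of the PR) does compile under torch.jit.script:

```python
import torch

@torch.jit.script
def is_valid_reduction(reduction: str) -> bool:
    # Older TorchScript rejects `not in`, so the membership test is
    # spelled out as explicit equality comparisons.
    return (reduction == 'none' or reduction == 'sum'
            or reduction == 'mean' or reduction == 'token_mean')
```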

we run into another issue: TorchScript doesn't recognize the type torch.LongTensor.

E       RuntimeError: 
E       Unknown type name 'torch.LongTensor':
E           def forward(
E                   self,
E                   emissions: torch.Tensor,
E                   tags: torch.LongTensor,
E                         ~~~~~~~~~~~~~~~~ <--- HERE
E                   mask: Optional[torch.ByteTensor] = None,
E                   reduction: str = 'sum',
E           ) -> torch.Tensor:
E               """Compute the conditional log likelihood of a sequence of tags given emission scores.
E       
E               Args:
E                   emissions (`~torch.Tensor`): Emission score tensor of size
E                       ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
E                       ``(batch_size, seq_length, num_tags)`` otherwise.

The above error occurs in all torch versions from 1.2.0 to 1.8.0. We can fix it by changing the annotation from torch.LongTensor to torch.Tensor, but that would also accept e.g. a FloatTensor as input, which is undesirable.
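A minimal sketch of that workaround, compensating for the looser annotation with a runtime check (the function name and the check are illustrative, not from the PR):

```python
import torch

def validate_tags(tags: torch.Tensor) -> torch.Tensor:
    # TorchScript rejects the torch.LongTensor annotation, so accept a plain
    # torch.Tensor and re-impose the integer constraint at runtime instead.
    assert tags.dtype == torch.long, 'tags must have dtype torch.long'
    return tags

scripted = torch.jit.script(validate_tags)  # compiles; no "Unknown type name" error
```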

In 1.9.0, the method Tensor.new_ones is not supported by TorchScript, leading to this error:

E       RuntimeError: 
E       'Tensor' object has no attribute or method 'new_ones'.:
E               self._validate(emissions, mask=mask)
E               if mask is None:
E                   mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
E                          ~~~~~~~~~~~~~~~~~~ <--- HERE
E           
E               if self.batch_first:

We can change this to use torch.ones with the same shape, explicitly passing dtype=torch.uint8 and device=emissions.device, but it won't be as elegant as Tensor.new_ones.
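A sketch of that replacement, with illustrative shapes (the uint8 dtype matches what new_ones was given in the original code):

```python
import torch

emissions = torch.randn(4, 2, 5)  # (seq_length, batch_size, num_tags)

# TorchScript-friendly equivalent of
# emissions.new_ones(emissions.shape[:2], dtype=torch.uint8):
mask = torch.ones(emissions.shape[:2], dtype=torch.uint8, device=emissions.device)
```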

PyTorch 1.10.0 is the first version where the Torchscripted code compiles successfully.

My proposal is to add a utility function that takes a CRF module and returns a scripted CRF module. The function would first assert that the PyTorch version is at least 1.10.0, then return the scripted module.

@kmkurn
Owner

kmkurn commented Jul 17, 2022

Thanks for the detailed write-up! I really appreciate the time you spent investigating this. It's surprising how half-baked the JIT support feels, with all these unsupported operations. With your suggestion, does that mean there'd essentially be two versions of the code, one being the normal version and the other the JIT-friendly version?

@aravindMahadevan
Contributor Author

aravindMahadevan commented Jul 17, 2022

I was suggesting adding a function inside __init__.py, after the CRF module class, that takes a CRF module as input; it would first check the torch version and then return a scripted CRF. Something like:

def script_crf(crf):
    # Compare release numbers numerically; comparing version strings directly
    # would rank '1.9.0' above '1.10.0'.
    version = tuple(int(p) for p in torch.__version__.split('+')[0].split('.')[:2])
    assert version >= (1, 10), 'Torchscripting the CRF model requires PyTorch 1.10.0 or higher'
    return torch.jit.script(crf)

The issue with having multiple versions is that a fix in one version might not make it into the other. Instead of having a utility function that torchscripts the CRF module, what if we merge these changes once the PyTorch 1.2 update #104 gets merged into main? I can update the torchscripting tests to run only when the torch version is 1.10.0 or higher, and we can also update the documentation to state that the CRF model is torchscriptable with torch 1.10.0 and higher.
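A sketch of how the tests could be gated, assuming pytest is the test runner (the marker name and version parsing are mine, not from the PR):

```python
import pytest
import torch

# Parse 'major.minor' numerically so that '1.9' doesn't sort above '1.10',
# and strip local suffixes like '+cu113' first.
TORCH_VERSION = tuple(int(p) for p in torch.__version__.split('+')[0].split('.')[:2])

requires_scriptable_torch = pytest.mark.skipif(
    TORCH_VERSION < (1, 10),
    reason='Torchscripting the CRF model requires PyTorch >= 1.10.0',
)

@requires_scriptable_torch
def test_scripted_crf_matches_eager():
    ...  # script the CRF and compare forward/decode outputs against eager mode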

@github-actions github-actions bot added the Stale label Oct 16, 2022
@github-actions github-actions bot closed this Nov 16, 2022
@kmkurn kmkurn reopened this Dec 9, 2022
@kmkurn
Owner

kmkurn commented Dec 9, 2022

Sorry for responding so late. I like your solution. I've merged #104 so you can implement the solution now.

@kmkurn kmkurn removed the Stale label Dec 9, 2022
@marcelbischoff

Hi, what is still missing here? I am interested in using this. Any help needed?

@erksch

erksch commented May 9, 2023

Actually, I think you could add to the README that torchscripting does not work for PyTorch < 1.10.0 and just merge this.

When people try to torchscript something, they will not look for a function from the library; they will just call torch.jit.script on their final module, get the errors anyway, and look them up.

Furthermore, I think the script_crf function is not actually usable, because you always script your complete module that uses the CRF rather than scripting every submodule individually.

We can also specify a minimum PyTorch version in the requirements-test.txt to solve the problem with failing tests on older versions.

@erksch erksch mentioned this pull request May 9, 2023
@erksch

erksch commented May 9, 2023

@kmkurn I recreated a pull request with all the mentioned changes combined :) #113
