Torchscript-compatible TabNet #2126

Merged: 6 commits into master from fix-ts-tabnet on Jun 10, 2022
Conversation

@geoffreyangus (Collaborator) commented on Jun 9, 2022:

This PR implements a Torchscript-compatible TabNet (addressing #2124). The issues in the original implementation were twofold:

  1. An incompatibility between the TorchScript compiler and weights shared between torch.nn.Module objects. This is a somewhat known issue (HuggingFace explicitly accounts for it through a torchscript flag). This PR works around it by first registering a module with the exact same properties as the shared module, then immediately overwriting that attribute with the shared module (see the first sketch after this list).
  2. Inability to script custom autograd.Function subclasses. This is also a known issue (see pytorch/pytorch#22329, "Autodiff for user script functions aka torch.jit.script for autograd.Function"). This PR works around it by decomposing the forward functions of the custom classes into standalone functions, which are scriptable (see the second sketch after this list).
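
To make the first workaround concrete, here is a minimal, hypothetical sketch (not the actual Ludwig code) of the pattern in item 1: a submodule with identical properties is registered first, and the attribute is then immediately overwritten with the shared module.

import torch

class SharedBlock(torch.nn.Module):
    """Hypothetical module illustrating the shared-weights workaround in item 1."""

    def __init__(self, shared_fc: torch.nn.Linear):
        super().__init__()
        # Register a module with the exact same properties as the shared one,
        # so the TorchScript compiler sees an ordinary registered submodule...
        self.fc = torch.nn.Linear(shared_fc.in_features, shared_fc.out_features)
        # ...then immediately overwrite the attribute with the shared module,
        # so the weights are actually shared at runtime.
        self.fc = shared_fc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

And a similarly hypothetical sketch of the decomposition in item 2, using a made-up thresholding op rather than Ludwig's sparsemax: the forward math lives in a standalone, scriptable function; the custom autograd.Function delegates to it during training; and eval-time code calls the standalone function directly, which is the shape of the snippet discussed in the review below.

import torch

def _threshold_forward(x: torch.Tensor, threshold: float) -> torch.Tensor:
    # Standalone forward computation; this function is scriptable.
    return torch.clamp(x - threshold, min=0.0)

class ThresholdFunction(torch.autograd.Function):
    # Custom autograd.Function with a hand-written backward; not scriptable,
    # so it is only invoked on the training path.
    @staticmethod
    def forward(ctx, x, threshold: float):
        out = _threshold_forward(x, threshold)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        return grad_output * (out > 0).to(grad_output.dtype), None

def threshold(x: torch.Tensor, t: float, training: bool) -> torch.Tensor:
    if not training:
        # Eval/scripting path: call the decomposed forward directly.
        return _threshold_forward(x, t)
    return ThresholdFunction.apply(x, t)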

The following validation was run in order to test the new changes:

  1. A new test was added in tests/integration_tests/test_torchscript.py which checks that a trained LudwigModel (with a TabNet combiner) produces the same outputs as its TorchScript equivalent (a generic sketch of this kind of check follows the config below).
  2. Existing tests were modified in order to ensure that the custom autograd classes remained unchanged.
  3. Two Ludwig models were trained on the Titanic dataset with a TabNet combiner: one on the master branch (commit dd026ca2fb9e7e9dc0ef7fb9bcc73ccaae01b8a7) and one on this branch (commit 34fe6da455d44edb8b6cb4947f4595303cf3511d). The two models performed comparably:
    (screenshot: tabnet performance comparison)
    Full config here:
input_features:
  - name: Pclass
    type: category
  - name: Sex
    type: category
  - name: Age
    type: number
    preprocessing:
      missing_value_strategy: fill_with_mean
  - name: SibSp
    type: number
  - name: Parch
    type: number
  - name: Fare
    type: number
    preprocessing:
      missing_value_strategy: fill_with_mean
  - name: Embarked
    type: category

output_features:
  - name: Survived
    type: binary

combiner:
  type: tabnet
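
For reference, here is a generic, hypothetical sketch of the kind of TorchScript parity check described in validation item 1. The assert_torchscript_parity helper is illustrative only, not the actual test in tests/integration_tests/test_torchscript.py, which runs against a trained LudwigModel with the TabNet combiner.

import torch

def assert_torchscript_parity(module: torch.nn.Module, example_input: torch.Tensor) -> None:
    # Script the eval-mode module and verify its outputs match eager execution.
    module.eval()
    scripted = torch.jit.script(module)
    with torch.no_grad():
        eager_out = module(example_input)
        scripted_out = scripted(example_input)
    torch.testing.assert_close(eager_out, scripted_out)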

@brightsparc (Contributor) left a comment:

LGTM

@tgaddair (Collaborator) left a comment:

Amazing! Thanks for doing the regression analysis.

ludwig/modules/tabnet_modules.py (resolved review thread)
# Avoids call to custom autograd.Function during eval to ensure torchscript compatibility
# custom autograd.Function is not scriptable: https://github.com/pytorch/pytorch/issues/22329#issuecomment-506608053
if not training:
    output, _ = _sparsemax_forward(X, dim, k)
Collaborator comment on the lines above:

I imagine this means we cannot use integrated gradients with TorchScript, since we lose the grad info at predict time with this approach. Is that right? Not the end of the world, but something we'll need to think about.

@geoffreyangus (Collaborator, PR author):

Yup, that's correct.

ludwig/utils/entmax/activations.py (resolved review thread)
@tgaddair merged commit 3967cc5 into master on Jun 10, 2022
@tgaddair deleted the fix-ts-tabnet branch on Jun 10, 2022