Torchscript-compatible TabNet #2126
Conversation
LGTM
Amazing! Thanks for doing the regression analysis.
```python
# Avoids call to custom autograd.Function during eval to ensure torchscript compatibility
# custom autograd.Function is not scriptable: https://github.com/pytorch/pytorch/issues/22329#issuecomment-506608053
if not training:
    output, _ = _sparsemax_forward(X, dim, k)
```
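For context, a minimal sketch of the decomposition pattern is below. It mirrors the name in the snippet (`_sparsemax_forward`), but the bodies are simplified illustrations (a full-sort sparsemax that ignores `k`) rather than the PR's actual implementation, and the `@torch.jit.unused` escape hatch for the training path is an assumption about how the unscriptable `autograd.Function` call is kept out of the compiled graph:

```python
from typing import Tuple

import torch
import torch.nn as nn


def _sparsemax_forward(X: torch.Tensor, dim: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scriptable sparsemax forward (simplified: full sort; `k` is ignored here)."""
    X = X.transpose(dim, -1)
    Xsrt, _ = torch.sort(X, dim=-1, descending=True)
    cssv = Xsrt.cumsum(dim=-1) - 1.0
    rng = torch.arange(1, X.size(-1) + 1, device=X.device, dtype=X.dtype)
    support = rng * Xsrt > cssv                    # entries that stay nonzero
    supp_size = support.sum(dim=-1, keepdim=True)  # support size per row
    tau = cssv.gather(-1, supp_size - 1) / supp_size.to(X.dtype)
    output = torch.clamp(X - tau, min=0.0)
    return output.transpose(dim, -1), supp_size.transpose(dim, -1)


class SparsemaxFunction(torch.autograd.Function):
    """Thin training-time wrapper: forward just delegates to the plain function."""

    @staticmethod
    def forward(ctx, X, dim: int = -1, k: int = -1):
        output, supp_size = _sparsemax_forward(X, dim, k)
        ctx.dim = dim
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        supp_size, output = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0.0
        # subtract the mean gradient over the support from each surviving entry
        v_hat = grad_input.sum(dim=ctx.dim, keepdim=True) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None, None


class Sparsemax(nn.Module):
    def __init__(self, dim: int = -1, k: int = -1):
        super().__init__()
        self.dim = dim
        self.k = k

    @torch.jit.unused  # hides the unscriptable autograd.Function from the compiler
    def _forward_training(self, X: torch.Tensor) -> torch.Tensor:
        return SparsemaxFunction.apply(X, self.dim, self.k)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if not self.training:  # eval: plain function only, so scripting works
            output, _ = _sparsemax_forward(X, self.dim, self.k)
            return output
        return self._forward_training(X)


scripted = torch.jit.script(Sparsemax().eval())  # compiles cleanly
```

The design point: the math lives once in a plain function that Torchscript can compile, while the `autograd.Function` shrinks to a thin wrapper whose only job is supplying the custom backward used during training.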
I imagine this means we cannot use integrated gradients with torchscript, since we lose the grad info at predict time with this approach, is that right? Not the end of the world, but something we'll need to think about.
Yup that's correct.
This PR implements a Torchscript-compatible TabNet (addressing #2124). The issues in the original implementation were twofold:

1. Torchscript does not support sharing `torch.nn.Module` objects between submodules. This is a somewhat known issue (HuggingFace explicitly accounts for it through a `torchscript` flag: link). This PR works around it by first registering a module with the exact same properties as the shared module, then doing an overwriting assignment to the shared module immediately following (see the sketch after this list).
2. Torchscript cannot compile custom `autograd.Function` subclasses. This is also a known issue ("autodiff for user script functions aka torch.jit.script for autograd.Function", pytorch/pytorch#22329). This PR works around it by decomposing the `forward` functions of the custom classes into standalone functions, which are scriptable (a sketch of this pattern accompanies the review snippet above).
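To make the first workaround concrete, here is a schematic sketch. `Block` and `Step` are hypothetical stand-ins (the real change lives in Ludwig's TabNet combiner), and the code shows only the pattern the description names: register a submodule with the exact same properties, then immediately overwrite that attribute with the shared instance so the weights stay tied:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for a TabNet shared feature-transformer layer."""

    def __init__(self, size: int):
        super().__init__()
        self.fc = nn.Linear(size, size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


class Step(nn.Module):
    def __init__(self, size: int, shared_block: Block):
        super().__init__()
        # 1) Register a module with the exact same properties as the shared one...
        self.block = Block(size)
        # 2) ...then immediately overwrite the assignment with the shared module,
        #    so parameters are actually shared across steps.
        self.block = shared_block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class Steps(nn.Module):
    def __init__(self, size: int, num_steps: int):
        super().__init__()
        shared = Block(size)
        self.steps = nn.ModuleList([Step(size, shared) for _ in range(num_steps)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for step in self.steps:
            x = step(x)
        return x


scripted = torch.jit.script(Steps(8, 3))
print(scripted(torch.randn(2, 8)).shape)
```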
The following validation was run to test the new changes:

- Added `tests/integration_tests/test_torchscript.py`, which tests that a trained LudwigModel (with a TabNet combiner) has the same outputs as its torchscript equivalent (a rough sketch of the check appears at the end of this description).
- Ran a regression analysis: trained two models, one on the base branch (commit: `dd026ca2fb9e7e9dc0ef7fb9bcc73ccaae01b8a7`) and one on this branch (commit: `34fe6da455d44edb8b6cb4947f4595303cf3511d`). The models were more or less the same.

Full config here:
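For readers who want the shape of the equivalence check, a rough sketch follows. `assert_torchscript_equivalent` is a hypothetical helper; the real test drives a LudwigModel with a TabNet combiner through Ludwig's own APIs rather than a bare `nn.Module`:

```python
import torch


def assert_torchscript_equivalent(
    model: torch.nn.Module, batch: torch.Tensor, atol: float = 1e-6
) -> None:
    """Compare a trained model against its scripted counterpart on one batch."""
    model.eval()  # eval mode, so the scriptable (non-autograd) paths are taken
    scripted = torch.jit.script(model)
    with torch.no_grad():
        eager_out = model(batch)
        scripted_out = scripted(batch)
    assert torch.allclose(eager_out, scripted_out, atol=atol), "eager and scripted outputs diverge"
```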