Fixes Explain step for tied weights #3214
Conversation
Unit Test Results: 6 files ±0, 6 suites ±0, 6h 29m 55s ⏱️ +35m 12s. Results for commit a6f1ecc; comparison against base commit e3a9416. This comment has been updated with latest results.
This is great, thanks @geoffreyangus !
LGTM. Side note: great docstrings and comments, which made this PR very easy to review!
This PR fixes the explain step for tied weights. Prior to this change, the explain step would fail with the following error:
We traced the failure to Captum, which does not support being passed duplicate references to the same torch module as target layers during the integrated gradients step.

The fix creates a deep copy of the tied torch module and updates the computation graph used during the explanation step to reference the copy, so that each target layer passed to Captum is forward-passed exactly once per batch.
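The core idea can be sketched as follows. This is an illustrative example, not Ludwig's actual code: `TiedModel` and `untie_modules` are hypothetical names, and a minimal `nn.Linear` stands in for a real tied encoder. The sketch shows how deep-copying a duplicated module yields distinct module objects (so an attribution library like Captum sees each target layer only once) while preserving identical weights:

```python
import copy

import torch
import torch.nn as nn


class TiedModel(nn.Module):
    """Hypothetical model where two attributes point at the same (tied) module."""

    def __init__(self):
        super().__init__()
        shared = nn.Linear(4, 4)
        self.encoder_a = shared
        self.encoder_b = shared  # tied: same module object as encoder_a

    def forward(self, x):
        return self.encoder_a(x) + self.encoder_b(x)


def untie_modules(model: nn.Module, attr_names: list) -> nn.Module:
    """Replace duplicate references to the same child module with deep copies,
    so every target layer handed to an attribution method is a distinct object."""
    seen = set()
    for name in attr_names:
        child = getattr(model, name)
        if id(child) in seen:
            # Deep copy keeps the same parameter values but breaks object identity.
            setattr(model, name, copy.deepcopy(child))
        else:
            seen.add(id(child))
    return model


model = TiedModel()
assert model.encoder_a is model.encoder_b  # tied before the fix

untie_modules(model, ["encoder_a", "encoder_b"])

assert model.encoder_a is not model.encoder_b  # distinct objects now
# The copy starts with identical parameters to the original.
assert torch.equal(model.encoder_a.weight, model.encoder_b.weight)
```

With distinct module objects, each target layer is forward-passed once per batch during attribution, which is the invariant the actual fix restores.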