Module with Multiple Inputs #176

Open
rachtibat opened this issue Feb 15, 2023 · 2 comments
Labels: core (Feature/bug concerning core functionality), enhancement (New feature or request)


@rachtibat
Contributor

Hey Chris,

hope you're well. I noticed an implementation detail, and I am unsure whether it was intentional and, if so, why.

At this line, you take only the first input in the backward pass, even though all inputs were previously saved at this line.

I see that you defined the summation layer using a concat operation at this line, so I assume restricting to a single input is on purpose.

So, do you think it will be possible in the future to attribute a summation layer defined in the following way? And why did you restrict layers to having only one input?

import torch


class Sum(torch.nn.Module):

    def forward(self, input_1, input_2):
        # sum two separate input tensors instead of a single concatenated tensor
        return input_1 + input_2

Thanks a lot!

@chr5tphr
Owner

Hey Reduan,

thanks for the issue as always.

Currently, I am indeed restricting Zennit to only attribute single inputs.
I started out with single inputs because most layers that need to be attributed have only a single input, and for most cases there exists an equivalent module structure with only a single input (e.g. concatenated inputs).
See for example here that the backward hook is also only attached to the first input.
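
For illustration, a minimal single-input equivalent of a two-argument sum could look like the following sketch (the module and names are only assumptions for the example, not Zennit's own Sum layer): the operands are stacked into one tensor beforehand, so a hook only ever sees a single input.

import torch


class StackedSum(torch.nn.Module):
    """Sum over the last dimension of a single stacked input tensor."""

    def forward(self, inputs):
        # inputs has shape (..., n_operands); reducing over the last
        # dimension reproduces input_1 + input_2 + ...
        return inputs.sum(dim=-1)


# usage: stack the operands first, then call the module with one tensor
# out = StackedSum()(torch.stack([input_1, input_2], dim=-1))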

I planned from the beginning to also support multiple inputs (and, along the way, parameters as well), and am working on getting this done in #168, although I have not gotten to work on it recently.
You can see here that I define multiple gradient_sinks, which can be attributed.

The work currently remaining in the PR focuses more on the parameters, as it turned out to be somewhat tricky to reliably hook into Parameters: hooking the tensor itself will always trigger when its gradient is computed, i.e. also at the wrong time, while creating a function to hook into is a little tricky because the parameter is not passed to a function but obtained as an attribute (which is probably where I will intercept).
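
As a small illustration of the first point (a sketch, not code from the PR): a gradient hook registered directly on a Parameter tensor fires on every backward pass that reaches it, with no way to tell which pass it belongs to.

import torch

param = torch.nn.Parameter(torch.ones(3))

# Tensor.register_hook fires whenever a gradient for this tensor is computed,
# regardless of which forward/backward pass produced it.
param.register_hook(lambda grad: print("param hook fired:", grad))

# two unrelated computations both trigger the same hook
(param * 2.0).sum().backward()
(param ** 2).sum().backward()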

For the future, your proposed Sum module is intended to work, even with the BasicHook. If you are curious, you can see in the PR that the attribution will be computed differently for each specified sink, e.g. here, although the way of addressing the sinks may change.

@rachtibat
Contributor Author

Hey,

Awesome, many thanks for the detailed explanation. I am very excited about the future development and will have a look at the PR to see if I can adapt it for my purposes.
I also noticed - and this is another great strength of Zennit - that you can define PyTorch functions with a custom backward method (https://pytorch.org/docs/stable/notes/extending.html), which can be overridden to compute a complex attribution method that might not yet be supported by Zennit's Hooks, while still integrating perfectly into the Zennit workflow.
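
For illustration, a minimal sketch of that approach (the class name and the epsilon-stabilized rule are only assumptions for the example, not something Zennit prescribes): the backward of a custom torch.autograd.Function returns a relevance-like quantity instead of the plain gradient.

import torch


class AttributedSum(torch.autograd.Function):
    """Two-input sum whose backward distributes relevance to both inputs."""

    @staticmethod
    def forward(ctx, input_1, input_2):
        ctx.save_for_backward(input_1, input_2)
        return input_1 + input_2

    @staticmethod
    def backward(ctx, grad_output):
        input_1, input_2 = ctx.saved_tensors
        # distribute the incoming relevance proportionally to each input's
        # contribution to the (epsilon-stabilized) sum
        relevance = grad_output / (input_1 + input_2 + 1e-6)
        return relevance * input_1, relevance * input_2


# usage inside a module's forward: out = AttributedSum.apply(input_1, input_2)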

Best

@chr5tphr chr5tphr added the enhancement New feature or request label Aug 10, 2023
@chr5tphr chr5tphr added the core Feature/bug concerning core functionality label Aug 11, 2023