
Relevance of hidden layers #2

Closed
francescomalandrino opened this issue Feb 19, 2021 · 5 comments

Comments

@francescomalandrino

Is there a simple (or simple-ish) way to obtain the relevance scores of hidden layers, i.e., what is visualized in the cell starting with "for i,l in enumerate([31,21,11,1]):" here?
https://git.tu-berlin.de/gmontavon/lrp-tutorial

The library sure computes them, but I could not find a way to recover those. Thanks!

@fhvilshoj
Owner

I guess this is actually a bit cumbersome. One way would be to break the network up into multiple lrp.Sequentials, keep references to the intermediate outputs, and then read the relevance of each hidden layer from that output's .grad attribute.

It relates to this post about obtaining grads (explanations in our case) and this post about getting intermediate outputs of pretrained networks.
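
Something along these lines might work (untested sketch; I am writing the explain/rule forward keywords from memory, so double-check them against the README):

import torch
import lrp

# Untested sketch: split the model into two lrp.Sequential parts so that the
# intermediate activation is a tensor we hold a reference to. The explain=True
# and rule="epsilon" keywords are written from memory and may not match the
# current forward signature exactly.
part1 = lrp.Sequential(lrp.Linear(784, 300), torch.nn.ReLU())
part2 = lrp.Sequential(lrp.Linear(300, 10))

x = torch.rand(32, 784, requires_grad=True)

h = part1.forward(x, explain=True, rule="epsilon")
h.retain_grad()                                   # keep .grad of the non-leaf tensor
y_hat = part2.forward(h, explain=True, rule="epsilon")

# back-propagate relevance only from the predicted class of each sample
y_hat = y_hat[torch.arange(x.shape[0]), y_hat.max(1)[1]].sum()
y_hat.backward()

hidden_relevance = h.grad                         # relevance at the hidden layer
input_relevance = x.grad                          # relevance at the input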

I would be happy to discuss if there is a better way to structure this code such that similar tasks become easier.

@francescomalandrino
Author

Hey, thanks for the answer!

Yes, "cumbersome" is the right word here!

pull request #3 gets the job done, but it's kind of hack-ish... not sure if there's a better/more elegant way to do it!

@fhvilshoj
Owner

I have merged your pull request, although I have never liked global variables. In this case, however, they do make sense to some extent. Thanks for the contribution!
I will integrate your code a bit more and push an update soon.
In particular, the trace should be recorded in all backward computations, not only for the Epsilon rule.
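
Roughly what I have in mind is a small trace module that every rule's backward appends to (the names and layout below are only a sketch, not the final code):

# Sketch of a trace module; names are illustrative only.
_trace = []
_enabled = False

def enable_and_clean():
    # start recording and drop any relevances from a previous run
    global _enabled
    _enabled = True
    _trace.clear()

def do_trace(relevance):
    # called from the backward pass of every rule, not just Epsilon
    if _enabled:
        _trace.append(relevance.detach().clone())

def collect_and_disable():
    # stop recording and hand the collected per-layer relevances to the caller
    global _enabled
    _enabled = False
    return list(_trace)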

@francescomalandrino
Author

The best way would be to pass a list to be filled with the trace when calling backward -- that way, we would not need global variables:

tl = []
y_hat.backward(trace_list=tl)

However, I think backward up there calls PyTorch's backward first, which in turn calls your own backward. If that is the case, PyTorch's function will throw a TypeError upon seeing the extra parameter.
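
For what it's worth, a quick check with plain PyTorch confirms it; torch.Tensor.backward only knows gradient, retain_graph and create_graph, so the extra keyword is rejected before any custom backward runs:

import torch

x = torch.rand(3, requires_grad=True)
y_hat = (x * 2).sum()

try:
    # torch.Tensor.backward does not accept arbitrary keywords, so this fails
    # before the autograd engine (and thus any custom backward) is invoked
    y_hat.backward(trace_list=[])
except TypeError as e:
    print(e)   # backward() got an unexpected keyword argument 'trace_list'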

@fhvilshoj
Owner

Yes, this is exactly the issue.
