
Obtaining gradient of LRP output w.r.t. network parameters #183

Closed
MikiFER opened this issue Apr 28, 2023 · 12 comments

Comments

@MikiFER

MikiFER commented Apr 28, 2023

Hi, first of all, thank you for all the hard work that was put into developing this framework and then making it available to everyone.
I was wondering whether there is a way to obtain the gradient of an explanation computed with LRP with respect to the network parameters, in order to optimize it.
I stumbled across your overview paper and would like to use the framework in my own EGL research.

@chr5tphr
Owner

chr5tphr commented May 2, 2023

Hey @MikiFER,
sorry for the delayed response.
It is possible, although it is currently a little more involved (see below for a PoC).
Also see this discussion.
I was working on supporting this in #168 at the end of last year, but unfortunately have not yet had the time to finalize the PR.
If you are using VGG or something similar, it may work, but for ResNet, the current implementation has a few issues.

Otherwise, you can try to use this proof of concept I quickly put together:

from itertools import islice

import torch
from torchvision.models import AlexNet

from zennit.core import BasicHook, ParamMod
from zennit.rules import Epsilon, Gamma, ZBox
from zennit.composites import EpsilonGammaBox
from zennit.attribution import Gradient
from zennit.types import Convolution


class ParamBasicHook(BasicHook):
    '''Hook to also get the relevance wrt. Parameters'''

    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
        original_input = self.stored_tensors['input'][0].clone()
        inputs = []
        outputs = []
        params = {key: [] for key, _ in module.named_parameters()}
        for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
            input = in_mod(original_input).requires_grad_()
            with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
                # remember the gradient state
                grad_states = [param.requires_grad for param in modified.parameters()]
                # require the gradients to compute the relevance
                for param in modified.parameters():
                    param.requires_grad_()

                output = modified.forward(input)
                output = out_mod(output)

                # keep track of the params for later gradient computation
                for key, param in modified.named_parameters():
                    params[key].append(param)
                # reset the gradient state
                for param, grad_state in zip(modified.parameters(), grad_states):
                    param.requires_grad = grad_state
            inputs.append(input)
            outputs.append(output)
        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
        if isinstance(grad_outputs, torch.Tensor):
            grad_outputs = [grad_outputs]
        # compute the gradients wrt. the inputs and wrt. all stored parameter copies in a single call
        gradients = torch.autograd.grad(
            outputs * (1 + len(params)),
            inputs + sum(params.values(), []),
            grad_outputs=grad_outputs * (1 + len(params)),
            create_graph=grad_output[0].requires_grad
        )
        # split the flat gradient tuple into one group for the inputs and one group per parameter
        grad_groups = [list(islice(elem, len(inputs))) for elem in [iter(gradients)] * (1 + len(params))]
        relevance = self.reducer(inputs, grad_groups[0])

        # set the .grad of the original parameter
        for (key, param), gradient in zip(params.items(), grad_groups[1:]):
            getattr(module, key).grad = self.reducer(param, gradient)

        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)

    @classmethod
    def inject(cls, hook_type):
        '''Create a subclass of hook_type and this class, injecting this class
        before BasicHook in order to give this class' backward a higher
        priority. May also be done manually with e.g.
        ``class EpsilonParam(Epsilon, ParamBasicHook, BasicHook): pass``.'''
        return type(f'{hook_type.__name__}Param', (hook_type, cls, BasicHook), {})


ZBoxParam = ParamBasicHook.inject(ZBox)
GammaParam = ParamBasicHook.inject(Gamma)
EpsilonParam = ParamBasicHook.inject(Epsilon)


def main():
    torch.manual_seed(0xdeadbeef)
    net = AlexNet().eval()

    layer_map = [
        (torch.nn.Linear, EpsilonParam(epsilon=1e-6)),
        (Convolution, GammaParam(gamma=0.25)),
    ]
    first_map = [
        (Convolution, ZBoxParam(low=-3., high=3.)),
    ]
    composite = EpsilonGammaBox(low=-3., high=3., layer_map=layer_map, first_map=first_map)

    data = torch.randn((1, 3, 224, 224))

    # not needed for LRP, only using this to compute the gradients
    for param in net.parameters():
        param.requires_grad = True
    net(data).sum().backward()
    weight_grad = net.features[0].weight.grad[:]
    for param in net.parameters():
        del param.grad
        param.requires_grad = False

    # compute LRP
    with Gradient(net, composite=composite) as attributor:
        out2, relevance = attributor(data)
    weight_relevance = net.features[0].weight.grad[:]

    # demonstrate that the gradient was modified
    print((weight_grad - weight_relevance).abs().sum())


if __name__ == '__main__':
    main()

Here, the trick is to create a new BasicHook which also computes the relevances of the parameters, to create new subclasses of the existing hooks by injecting our new class via multiple inheritance, and then to use those hooks wherever you would like to compute the relevance wrt. the parameters. In the example, I have used Cooperative Layer Maps, but MixedComposite or any custom composite also works. Let me know if you have more questions, or in case there is something wrong with the PoC.
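For reference, the manual variant mentioned in the docstring of inject would look roughly like this (a sketch reusing ParamBasicHook from the PoC above, with a plain LayerMapComposite instead of the cooperative layer maps):

import torch

from zennit.composites import LayerMapComposite
from zennit.core import BasicHook
from zennit.rules import Epsilon
from zennit.types import Convolution


# manual equivalent of ParamBasicHook.inject(Epsilon)
class EpsilonParam(Epsilon, ParamBasicHook, BasicHook):
    '''Epsilon rule which additionally writes the parameter relevance to .grad'''


# the injected hooks can be used in any composite, e.g. a plain LayerMapComposite
composite = LayerMapComposite(layer_map=[
    (Convolution, EpsilonParam(epsilon=1e-6)),
    (torch.nn.Linear, EpsilonParam(epsilon=1e-6)),
])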

@MaxH1996 may also be interested in this code.

@MikiFER
Author

MikiFER commented May 3, 2023

Hi @chr5tphr, thank you for the response.
I think that maybe we didn't understand each other. I do not need the relevance of the parameters; I need the gradient of the relevance map (in the input space) w.r.t. the network parameters. For example, let's say I want to compare the relevance map obtained using LRP with some ground-truth relevance map, and I want to optimize the network parameters to minimize some loss between them. In your documentation I found this page, but I am not sure whether that is what I need.

@chr5tphr
Owner

chr5tphr commented May 3, 2023

Hey @MikiFER,

sorry for the confusion!
This should indeed work without problems, as described in the documentation.
You should be able to use Tensor.backward() together with any Optimizer, just as you would when training with the unmodified gradients.
Just make sure you call .backward outside the Attributor/Composite context, or within the Attributor.inactive() context, as shown in the documentation.
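For concreteness, here is a minimal sketch of that pattern (with a toy model, random data, a made-up ground-truth mask gt_mask, and a squared-error explanation loss; not taken verbatim from the documentation):

import torch

from zennit.composites import EpsilonPlusFlat

torch.manual_seed(0)
net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 8 * 8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 10),
)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
composite = EpsilonPlusFlat()

data = torch.randn(4, 3, 8, 8, requires_grad=True)
target = torch.eye(10)[[0, 1, 2, 3]]
gt_mask = torch.rand(4, 3, 8, 8)

with composite.context(net) as modified:
    output = modified(data)
    # start the backward pass from the target logits; since this initial relevance
    # requires grad, the rules keep the backward graph (cf. create_graph=grad_output[0].requires_grad
    # in the PoC above), so the relevance stays differentiable wrt. the parameters
    relevance, = torch.autograd.grad(output, data, output * target, create_graph=True)

# call .backward() outside the composite context, as described above
optimizer.zero_grad()
explanation_loss = ((relevance - gt_mask) ** 2).mean()
explanation_loss.backward()
optimizer.step()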

Let me know in case you have issues with this, so we can try to figure it out together.

@MikiFER
Author

MikiFER commented May 3, 2023

Thank you so much. Will try it and will get back to you if there are issues :)

@MikiFER
Author

MikiFER commented May 5, 2023

Hi @chr5tphr, by diving a little deeper into the code I arrived at a question I cannot answer.
I would like to use the ResNet architecture in my experiments. To obtain valid LRP explanations, a canonized version of the network must first be obtained; during canonization, batch norm is merged into the linear layer it is attached to. If I were then to obtain an explanation and compute the gradient of the parameters w.r.t. it, would those gradients still be accurate once the network is de-canonized back to its original state? Here is pseudo-code of what I am trying to do.

composite = EpsilonPlusFlat(canonizers=[canonizer])
for input, gt_value, gt_mask in dataset:
    input.requires_grad_()
    model_out = model(input)
    classification_loss = Loss(model_out, gt_value)
    with composite.context(model) as canonized_model:
        model_out.backward(gradient=gt_value)
        explanation = input.grad
    explanation_loss = ExplanationLoss(explanation, gt_mask)
    combined_loss = classification_loss + explanation_loss
    combined_loss.backward()
    ...

I am afraid that calling combined_loss.backward() will produce gradients for the canonized network, but I want to optimize the parameters of the "normal" network; that is, the batch-norm and the corresponding linear-layer parameters would never be optimized.

Is there something I am not understanding correctly?

@chr5tphr
Owner

chr5tphr commented May 8, 2023

Hey @MikiFER,

theoretically, this should not be a problem, as the canonized parameters should be computed from the original parameters in such a way that the gradient is the same.
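For SequentialMergeBatchNorm, the merge is roughly (with $W$, $b$ the linear layer's parameters, $\gamma$, $\beta$ the batch-norm weight and bias, and $\mu$, $\sigma^2$ its running statistics):

$W'_{ji} = W_{ji}\,\frac{\gamma_j}{\sqrt{\sigma_j^2 + \varepsilon}} \qquad b'_j = (b_j - \mu_j)\,\frac{\gamma_j}{\sqrt{\sigma_j^2 + \varepsilon}} + \beta_j$

so as long as $W'$ and $b'$ are not detached from $W$, $b$, $\gamma$ and $\beta$, the gradients propagate back to the original parameters as expected.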
However, I tested this and found that it is not behaving as expected, since the current implementation detaches the gradient.
I looked into this and added #185 where the gradients seem to be computed correctly now.

You can check it out by directly installing with pip:

pip install git+https://github.com/chr5tphr/zennit.git@canonizer-merge-batchnorm-gradfix

Let me know whether it works for you.

Here's a proof-of-concept check:
import torch

from zennit.core import Composite
from zennit.canonizers import SequentialMergeBatchNorm


def main():
    torch.manual_seed(0xdeadbeef)
    net = torch.nn.Sequential(
        torch.nn.Linear(32, 32),
        torch.nn.BatchNorm1d(32),
    )
    weight = net[0].weight
    net.eval()
    # shift the running statistics so that merging the batch norm is non-trivial
    net[1].running_mean += 1.
    net[1].running_var *= 3.
    canonizers = [
        SequentialMergeBatchNorm()
    ]

    composite = Composite(canonizers=canonizers)

    data = torch.randn((1, 32))

    weight.requires_grad = True
    # gradient wrt. the weight of the original, un-canonized network
    out_base = net(data).sum()
    grad_base, = torch.autograd.grad(out_base, weight)

    # the same gradient, computed through the canonized (merged) parameters
    with composite.context(net) as modified:
        out_canon = modified(data).sum()
        grad_canon, = torch.autograd.grad(out_canon, weight)

    # both differences should be numerically close to zero
    print((out_base - out_canon).abs().sum())
    print((grad_base - grad_canon).abs().sum())


if __name__ == '__main__':
    main()

@MikiFER
Author

MikiFER commented May 10, 2023

Hi @chr5tphr thank you so much, I will try it out and get back to you if there are any more issues.

@chr5tphr
Owner

Assuming there were no more issues, closing this for now after merging #185. Feel free to reopen once something pops up.

@MikiFER
Author

MikiFER commented Aug 2, 2023

Hi @chr5tphr, I have a question regarding the explanation obtained using the ResNetCanonizer in combination with the EpsilonPlusFlat composite. I noticed that the sum of attributions over the input image is not 1, even though, when LRP is started with a relevance of 1 at the output layer, the sum of relevance in every layer should be 1. Here is the piece of code I used to replicate this behavior.

import torch
from torchvision.models import resnet18

from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

model = resnet18(weights=None)
canonizer = ResNetCanonizer()

composite = EpsilonPlusFlat(canonizers=[canonizer])

target = torch.eye(1000)[[437]]
input_data = torch.rand(1, 3, 224, 224)
input_data.requires_grad = True
output = model(input_data)
with composite.context(model) as modified_model:
    attribution, = torch.autograd.grad(output, input_data, target)

print(attribution.shape, attribution.sum())

Am I not understanding something correctly or is this an error?

@chr5tphr
Owner

chr5tphr commented Aug 9, 2023

Hey @MikiFER,

usually, the attributions will not sum to one unless you are certain that no attribution is lost to the bias, which you can ensure by passing zero_params='bias', e.g., in your case:

composite = EpsilonPlusFlat(canonizers=[canonizer], zero_params='bias')

While investigating your issue, I noticed that, although #185 increased the overall attribution stability within ResNet, it led to a negative attribution sum in the input (which can happen if some attribution is lost to biases in combination with skip-connections), for which I have opened #194. While there is at least a quick fix for EpsilonGammaBox, there does not seem to be a solution for EpsilonPlusFlat until I have fixed the problem.

@MikiFER
Author

MikiFER commented Aug 11, 2023

Hi @chr5tphr, thanks for the response.
I find it a little bit weird that almost 90% of the attribution is lost to the stability parameters (when inference is done before the composite context) and I feel like there is something more to it.
I have also noticed that a different attribution is obtained when the model inference is done inside the composite context than when it is done outside of (before) it. Is that the desired behavior? I believe the attribution should be the same, because the canonized model and the original model should be equivalent.

Also, one unrelated question: have you tried pairing your library with pytorch-lightning? I get some weird results when trying to use half-precision (fp16) training, where the model inference results in NaN when inside the composite context.

@chr5tphr
Owner

Hey @MikiFER

it's not the stability parameters, but the bias term, which silently receives attribution.
For example, in the Epsilon-Rule, we have

$R_i = \sum_j \frac{x_i w_{ji}}{\sum_{i'}w_{ji'}x_{i'} + b_j + \varepsilon} R_j$

where the denominator includes not only $\varepsilon$, but also the bias $b_j$.
Since the biases are constant inputs to the network (one could imagine a constant 1 appended to the input, with an additional column for the bias in the weight matrix), they also receive relevance, which results in a reduced relevance for the inputs.

This lost relevance can be omitted by removing the bias term from the denominator, which zero_params='bias' is for.
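As a quick check of the effect, one could compare the attribution sum with and without zero_params='bias' (a sketch using the untrained resnet18 and random input from your snippet; note that the ResNet canonizer issue mentioned above may still affect the exact numbers):

import torch
from torchvision.models import resnet18

from zennit.attribution import Gradient
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

torch.manual_seed(0)
model = resnet18(weights=None).eval()
target = torch.eye(1000)[[437]]
input_data = torch.rand(1, 3, 224, 224)

for zero_params in (None, 'bias'):
    # with zero_params='bias', the biases are treated as zero and receive no relevance
    composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()], zero_params=zero_params)
    # the attributor runs the forward pass inside the composite context
    with Gradient(model=model, composite=composite) as attributor:
        output, attribution = attributor(input_data, target)
    print(zero_params, attribution.sum().item())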

There is, however, as you also pointed out, currently something wrong with the changes introduced by #185, and my investigation so far points to the ResNet canonizer.

To keep a better overview, feel free to create a new issue when topics are not directly related.
