Draft: Core: Second Order Gradients
- support computing the gradient of the modified gradient via second
  order gradients
- for this, the second gradient pass must be done without hooks
- this is supported either by removing the hooks before computing the
  second order gradients, or by setting `hook.active = False` for each
  hook, e.g. using the context manager `composite.inactive()` (see the
  sketch below)
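
A minimal sketch of the intended workflow, assuming a toy model; the
`EpsilonPlus` composite, the data, and the use of `torch.ones_like` as the
output attribution are illustrative choices rather than part of this commit:

```python
import torch
from torch.nn import Linear, ReLU, Sequential

from zennit.attribution import Gradient
from zennit.composites import EpsilonPlus

model = Sequential(Linear(8, 16), ReLU(), Linear(16, 1))
data = torch.randn(4, 8, requires_grad=True)

composite = EpsilonPlus()
# create_graph=True records the modified backward pass, so that the
# resulting attribution can itself be differentiated
attributor = Gradient(model, create_graph=True)

with composite.context(model):
    output, attribution = attributor(data, torch.ones_like)
    # second gradient pass without hooks, by setting hook.active = False
    with composite.inactive():
        second_order, = torch.autograd.grad(attribution.sum(), data)
```

Leaving the `composite.context` block (i.e. removing the hooks) before the
second `torch.autograd.grad` call is the equivalent alternative named above.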

- make SmoothGrad and IntegratedGradients inherit from Gradient
- add `create_graph` and `retain_graph` arguments for
  Gradient attributors
- add `.grad` method to Gradient, which is used by its subclasses to
  compute the gradient
- fix attributor docstrings
- recognize in BasicHook.backward whether `grad_output` requires a
  gradient, in which case `create_graph=True` is used for the backward
  pass that computes the relevance
- add the ReLUBetaSmooth rule, which replaces the gradient of ReLU with
  the gradient of softplus (i.e. the sigmoid function); this is used as
  a surrogate to compute meaningful gradients of ReLU (see the sketch
  below)
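
A short sketch of the surrogate idea behind ReLUBetaSmooth; the `beta` value
and the `LayerMapComposite` mapping below are illustrative assumptions:

```python
import torch
from zennit.composites import LayerMapComposite
from zennit.rules import ReLUBetaSmooth

# the gradient of ReLU is a hard step (0 or 1), so its second order
# gradient vanishes almost everywhere; the surrogate instead uses the
# derivative of softplus_beta(x) = log(1 + exp(beta * x)) / beta,
# which is sigmoid(beta * x)
beta = 10.
x = torch.linspace(-1., 1., 5, requires_grad=True)
hard_grad, = torch.autograd.grad(torch.relu(x).sum(), x)
smooth_grad = torch.sigmoid(beta * x)

# map the rule onto all ReLU modules of a model
composite = LayerMapComposite(layer_map=[(torch.nn.ReLU, ReLUBetaSmooth())])
```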

- add test to check the effect of `hook.active` (sketched below)
- add test to check whether the second order gradient of Hook is
  computed as expected
- add second order gradient tests for gradient attributors
- add test for attributor.inactive
- add test for Composite.inactive
- add test for ReLUBetaSmooth
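
A rough sketch of what the `hook.active` test could look like; the `ZPlus`
rule, the direct `hook.register` call, and the fixtures are illustrative
assumptions, not the test code of this commit:

```python
import torch
from torch.nn import Linear

from zennit.rules import ZPlus

def test_hook_active_deactivates_modification():
    model = Linear(4, 2)
    data = torch.randn(1, 4, requires_grad=True)

    # plain gradient as reference, without any hook registered
    reference, = torch.autograd.grad(model(data).sum(), data)

    hook = ZPlus()
    handles = hook.register(model)
    # a deactivated hook must leave the gradient unmodified
    hook.active = False
    gradient, = torch.autograd.grad(model(data).sum(), data)
    handles.remove()

    assert torch.allclose(reference, gradient)
```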

TODO:
- add How-To's
chr5tphr committed Sep 21, 2022
1 parent 2e6c498 commit 9b581c3
Showing 6 changed files with 381 additions and 67 deletions.
src/zennit/attribution.py (142 additions, 55 deletions)
@@ -98,13 +98,13 @@ class Attributor(metaclass=ABCMeta):
Parameters
----------
-model: obj:`torch.nn.Module`
+model: :py:obj:`torch.nn.Module`
The model for which the attribution will be computed. If `composite` is provided, this will also be the model
to which the composite will be registered within `with` statements, or when calling the `Attributor` instance.
-composite: obj:`zennit.core.Composite`, optional
+composite: :py:obj:`zennit.core.Composite`, optional
The optional composite to, if provided, be registered to the model within `with` statements or when calling the
`Attributor` instance.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The default output attribution to be used when calling the `Attributor` instance, which is either a Tensor
compatible with any input used, or a function of the model's output. If None (default), the value will be the
identity function.
@@ -126,7 +126,7 @@ def __enter__(self):
Returns
-------
-self: obj:`Attributor`
+self: :py:obj:`Attributor`
The `Attributor` instance.
'''
@@ -155,18 +155,18 @@ def __call__(self, input, attr_output=None):
Parameters
----------
-input: obj:`torch.Tensor`
+input: :py:obj:`torch.Tensor`
Input for the model, and wrt. compute the attribution
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The output attribution, which is either a Tensor compatible `input` (i.e. has the same shape as the output
of the model), or a function of the model's output. If None (default), the default attribution will be
used, which if neither supplied, will result in the model output used as the output attribution.
Returns
-------
-output: obj:`torch.Tensor`
+output: :py:obj:`torch.Tensor`
Output of the model with argument `input`.
-attribution: obj:`torch.Tensor`
+attribution: :py:obj:`torch.Tensor`
Attribution of the model wrt. to `input`, with the same shape as `input`.
'''
if attr_output is None:
@@ -182,6 +182,11 @@ def __call__(self, input, attr_output=None):
with self:
return self.forward(input, attr_output_fn)

+@property
+def inactive(self):
+    '''Return the attributor's composite's ``.inactive`` context.'''
+    return self.composite.inactive

@abstractmethod
def forward(self, input, attr_output_fn):
'''Abstract method. Compute the attribution of the model wrt. input, by using `attr_output_fn` as the function
@@ -190,9 +195,9 @@ def forward(self, input, attr_output_fn):
Parameters
----------
-input: obj:`torch.Tensor`
+input: :py:obj:`torch.Tensor`
Input for the model, and wrt. compute the attribution
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The output attribution function of the model's output.
'''

@@ -201,49 +206,109 @@ class Gradient(Attributor):
'''The Gradient Attributor. The result is the product of the attribution output and the (possibly modified)
jacobian. With a composite, i.e. `EpsilonGammaBox`, this will compute the Layerwise Relevance Propagation
attribution values.
+Parameters
+----------
+model: :py:obj:`torch.nn.Module`
+    The model for which the attribution will be computed. If `composite` is provided, this will also be the model
+    to which the composite will be registered within `with` statements, or when calling the `Attributor` instance.
+composite: :py:obj:`zennit.core.Composite`, optional
+    The optional composite to, if provided, be registered to the model within `with` statements or when calling the
+    `Attributor` instance.
+attr_output: :py:obj:`torch.Tensor` or callable, optional
+    The default output attribution to be used when calling the `Attributor` instance, which is either a Tensor
+    compatible with any input used, or a function of the model's output. If None (default), the value will be the
+    identity function.
+create_graph: bool, optional
+    Specify whether to use ``create_graph=True`` (default is False) to compute the gradient with
+    :py:obj:`torch.autograd.grad`. This needs to be `True` to compute higher order gradients.
+retain_graph: bool, optional
+    Specify whether to use ``retain_graph=True`` (default is the value of create_graph) to compute the gradient
+    with :py:obj:`torch.autograd.grad`.
+'''
+def __init__(self, model, composite=None, attr_output=None, create_graph=False, retain_graph=None):
+    super().__init__(model=model, composite=composite, attr_output=attr_output)
+    self.create_graph = create_graph
+    self.retain_graph = retain_graph

+def grad(self, input, attr_output_fn):
+    '''Compute the gradient of the model wrt. input, by using ``attr_output_fn`` as the function of the model
+    output to provide the vector for the vector jacobian product.
+    This function is used by subclasses to compute the gradient with the same parameters.
+    Parameters
+    ----------
+    input: :py:obj:`torch.Tensor`
+        Input for the model.
+    attr_output: :py:obj:`torch.Tensor` or callable, optional
+        The output attribution function of the model's output.
+    Returns
+    -------
+    output: :py:obj:`torch.Tensor`
+        Output of the model given ``input``.
+    gradient: :py:obj:`torch.Tensor`
+        Gradient of the model wrt. to ``input``, with the same shape as ``input``.
+    '''
+    if not input.requires_grad:
+        input.requires_grad = True
+    output = self.model(input)
+    gradient, = torch.autograd.grad(
+        (output,),
+        (input,),
+        grad_outputs=(attr_output_fn(output),),
+        create_graph=self.create_graph,
+        retain_graph=self.retain_graph,
+    )
+    return output, gradient

def forward(self, input, attr_output_fn):
'''Compute the gradient of the model wrt. input, by using `attr_output_fn` as the function of the model output
to provide the vector for the vector jacobian product.
This function will not register the composite, and is wrapped in the `__call__` of `Attributor`.
Parameters
----------
-input: obj:`torch.Tensor`
+input: :py:obj:`torch.Tensor`
Input for the model, and wrt. compute the attribution.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The output attribution function of the model's output.
Returns
-------
-output: obj:`torch.Tensor`
+output: :py:obj:`torch.Tensor`
Output of the model given `input`.
-attribution: obj:`torch.Tensor`
+attribution: :py:obj:`torch.Tensor`
Attribution of the model wrt. to `input`, with the same shape as `input`.
'''
-input = input.detach().requires_grad_(True)
-output = self.model(input)
-gradient, = torch.autograd.grad((output,), (input,), grad_outputs=(attr_output_fn(output.detach()),))
-return output, gradient
+# create a view of input in case it does not already require grad
+input = input.view_as(input)
+return self.grad(input, attr_output_fn)


-class SmoothGrad(Attributor):
+class SmoothGrad(Gradient):
'''This implements SmoothGrad [1]_. The result is the average over the gradient of multiple iterations where some
normal distributed noise was added to the input. Supplying a composite will result instead in averaging over the
modified gradient.
Parameters
----------
-model: obj:`torch.nn.Module`
+model: :py:obj:`torch.nn.Module`
The model for which the attribution will be computed. If `composite` is provided, this will also be the model
to which the composite will be registered within `with` statements, or when calling the `Attributor` instance.
-composite: obj:`zennit.core.Composite`, optional
+composite: :py:obj:`zennit.core.Composite`, optional
The optional composite to, if provided, be registered to the model within `with` statements or when calling the
`Attributor` instance.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The default output attribution to be used when calling the `Attributor` instance, which is either a Tensor
compatible with any input used, or a function of the model's output. If None (default), the value will be the
identity function.
+create_graph: bool, optional
+    Specify whether to use ``create_graph=True`` (default is False) to compute the gradient with
+    :py:obj:`torch.autograd.grad`. This needs to be `True` to compute higher order gradients.
+retain_graph: bool, optional
+    Specify whether to use ``retain_graph=True`` (default is the value of create_graph) to compute the gradient
+    with :py:obj:`torch.autograd.grad`.
noise_level: float, optional
The noise level, which is :math:`\\frac{\\sigma}{x_{max} - x_{min}}` and defaults to 0.1.
n_iter: int, optional
@@ -255,8 +320,16 @@ class SmoothGrad(Attributor):
noise," CoRR, vol. abs/1706.03825, 2017.
'''
-def __init__(self, model, composite=None, attr_output=None, noise_level=0.1, n_iter=20):
-    super().__init__(model=model, composite=composite, attr_output=attr_output)
+def __init__(
+    self, model, composite=None, attr_output=None, create_graph=False, retain_graph=None, noise_level=0.1, n_iter=20
+):
+    super().__init__(
+        model=model,
+        composite=composite,
+        attr_output=attr_output,
+        create_graph=create_graph,
+        retain_graph=retain_graph
+    )
self.noise_level = noise_level
self.n_iter = n_iter

@@ -267,20 +340,18 @@ def forward(self, input, attr_output_fn):
Parameters
----------
-input: obj:`torch.Tensor`
+input: :py:obj:`torch.Tensor`
Input for the model, and wrt. compute the attribution.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The output attribution function of the model's output.
Returns
-------
-output: obj:`torch.Tensor`
+output: :py:obj:`torch.Tensor`
Output of the model given `input`.
-attribution: obj:`torch.Tensor`
+attribution: :py:obj:`torch.Tensor`
Attribution of the model wrt. to `input`, with the same shape as `input`.
'''
input = input.detach()

dims = tuple(range(1, input.ndim))
std = self.noise_level * (input.amax(dims, keepdim=True) - input.amin(dims, keepdim=True))

@@ -292,32 +363,36 @@ def forward(self, input, attr_output_fn):
epsilon = torch.zeros_like(input)
else:
epsilon = torch.randn_like(input) * std
-noisy_input = (input + epsilon).requires_grad_()
-output = self.model(noisy_input)
-gradient, = torch.autograd.grad((output,), (noisy_input,), grad_outputs=(attr_output_fn(output.detach()),))
+output, gradient = self.grad(input + epsilon, attr_output_fn)
result += gradient / self.n_iter

# output is leaking from the loop for the last epsilon (which is zero)
return output, result


-class IntegratedGradients(Attributor):
+class IntegratedGradients(Gradient):
'''This implements Integrated Gradients [2]_. The result is the path integral of the gradients, estimated over
multiple discrete iterations. Supplying a composite will result instead in the path integral over the modified
gradient.
Parameters
----------
-model: obj:`torch.nn.Module`
+model: :py:obj:`torch.nn.Module`
The model for which the attribution will be computed. If `composite` is provided, this will also be the model
to which the composite will be registered within `with` statements, or when calling the `Attributor` instance.
-composite: obj:`zennit.core.Composite`, optional
+composite: :py:obj:`zennit.core.Composite`, optional
The optional composite to, if provided, be registered to the model within `with` statements or when calling the
`Attributor` instance.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The default output attribution to be used when calling the `Attributor` instance, which is either a Tensor
compatible with any input used, or a function of the model's output. If None (default), the value will be the
identity function.
+create_graph: bool, optional
+    Specify whether to use ``create_graph=True`` (default is False) to compute the gradient with
+    :py:obj:`torch.autograd.grad`. This needs to be `True` to compute higher order gradients.
+retain_graph: bool, optional
+    Specify whether to use ``retain_graph=True`` (default is the value of create_graph) to compute the gradient
+    with :py:obj:`torch.autograd.grad`.
baseline_fn: callable, optional
The baseline for which the model output is zero, supplied as a function of the input. Defaults to
`torch.zeros_like`.
Expand All @@ -331,8 +406,23 @@ class IntegratedGradients(Attributor):
Proceedings of Machine Learning Research, D. Precup and Y. W. Teh, Eds., vol. 70. PMLR, 2017, pp. 3319–3328.
'''
-def __init__(self, model, composite=None, attr_output=None, baseline_fn=None, n_iter=20):
-    super().__init__(model=model, composite=composite, attr_output=attr_output)
+def __init__(
+    self,
+    model,
+    composite=None,
+    attr_output=None,
+    create_graph=False,
+    retain_graph=None,
+    baseline_fn=None,
+    n_iter=20
+):
+    super().__init__(
+        model=model,
+        composite=composite,
+        attr_output=attr_output,
+        create_graph=create_graph,
+        retain_graph=retain_graph
+    )
if baseline_fn is None:
baseline_fn = torch.zeros_like
self.baseline_fn = baseline_fn
@@ -345,27 +435,24 @@ def forward(self, input, attr_output_fn):
Parameters
----------
-input: obj:`torch.Tensor`
+input: :py:obj:`torch.Tensor`
Input for the model, and wrt. compute the attribution.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The output attribution function of the model's output.
Returns
-------
-output: obj:`torch.Tensor`
+output: :py:obj:`torch.Tensor`
Output of the model given `input`.
-attribution: obj:`torch.Tensor`
+attribution: :py:obj:`torch.Tensor`
Attribution of the model wrt. to `input`, with the same shape as `input`.
'''
-input = input.detach()
baseline = self.baseline_fn(input)

result = torch.zeros_like(input)
for alpha in torch.linspace(1. / self.n_iter, 1., self.n_iter):
-path_step = (baseline + alpha * (input - baseline)).requires_grad_()
-output = self.model(path_step)
-gradient, = torch.autograd.grad((output,), (path_step,), grad_outputs=(attr_output_fn(output.detach()),))
+path_step = baseline + alpha * (input - baseline)
+output, gradient = self.grad(path_step, attr_output_fn)
result += gradient / self.n_iter

result *= (input - baseline)
@@ -379,13 +466,13 @@ class Occlusion(Attributor):
Parameters
----------
-model: obj:`torch.nn.Module`
+model: :py:obj:`torch.nn.Module`
The model for which the attribution will be computed. If `composite` is provided, this will also be the model
to which the composite will be registered within `with` statements, or when calling the `Attributor` instance.
-composite: obj:`zennit.core.Composite`, optional
+composite: :py:obj:`zennit.core.Composite`, optional
The optional composite to, if provided, be registered to the model within `with` statements or when calling the
`Attributor` instance. Note that for Occlusion, this has no effect on the result.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The default output attribution to be used when calling the `Attributor` instance, which is either a Tensor
compatible with any input used, or a function of the model's output. If None (default), the value will be the
identity function.
@@ -446,16 +533,16 @@ def forward(self, input, attr_output_fn):
Parameters
----------
-input: obj:`torch.Tensor`
+input: :py:obj:`torch.Tensor`
Input for the model, and wrt. compute the attribution.
-attr_output: obj:`torch.Tensor` or callable, optional
+attr_output: :py:obj:`torch.Tensor` or callable, optional
The output attribution function of the model's output.
Returns
-------
-output: obj:`torch.Tensor`
+output: :py:obj:`torch.Tensor`
Output of the model given `input`.
-attribution: obj:`torch.Tensor`
+attribution: :py:obj:`torch.Tensor`
Attribution of the model wrt. to `input`, with the same shape as `input`.
'''
window, stride = self._resolve_window_stride(input)
