Basic hook #17

Merged 2 commits on Mar 26, 2021
2 changes: 1 addition & 1 deletion README.md
@@ -31,7 +31,7 @@ Currently, only feed-forward type models are supported.

At its heart, Zennit registers hooks at PyTorch's Module level, to modify the backward pass to produce LRP
attributions (instead of the usual gradient).
-All rules are implemented as hooks (`zennit/rules.py`) and most use the basic `LinearHook` (`zennit/core.py`).
+All rules are implemented as hooks (`zennit/rules.py`) and most use the LRP-specific `BasicHook` (`zennit/core.py`).
**Composites** are a way of choosing the right hook for the right layer.
In addition to the abstract **NameMapComposite**, which assigns hooks to layers by name, and **LayerMapComposite**,
which assigns hooks to layers based on their Type, there exist explicit Composites, which currently are
6 changes: 3 additions & 3 deletions setup.py
@@ -1,14 +1,14 @@
#!/usr/bin/env python3
-from setuptools import setup
+from setuptools import setup, find_packages

setup(
name="zennit",
use_scm_version=True,
-packages=['zennit'],
+packages=find_packages(include=['zennit*']),
install_requires=[
'numpy',
'Pillow',
-'torch==1.7.0',
+'torch>=1.7.0',
],
setup_requires=[
'setuptools_scm',
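As a side note on the packaging change above: `find_packages(include=['zennit*'])` picks up the `zennit` package together with any future `zennit.*` subpackages, whereas the old hard-coded list would only ever ship `zennit` itself. A small standalone check, assuming it is run from the repository root:

```python
from setuptools import find_packages

# From the repository root, this lists 'zennit' and any future 'zennit.*' subpackages,
# whereas the previous packages=['zennit'] would silently omit new subpackages.
print(find_packages(include=['zennit*']))
```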
39 changes: 22 additions & 17 deletions zennit/core.py
@@ -163,25 +163,30 @@ def register(self, module):
])


-class LinearHook(Hook):
-'''A hook to compute the layerwise attribution of the layer it is attached to.
-A `LinearHook` instance may only be registered with a single module.
+class BasicHook(Hook):
+'''A hook to compute the layerwise attribution of the module it is attached to.
+A `BasicHook` instance may only be registered with a single module.

Parameters
----------
-input_modifiers: list of function
-A list of functions to produce multiple inputs.
-param_modifiers: list of function
-A list of functions to temporarily modify the parameters of the attached linear layer for each input produced
-with `input_modifiers`.
-gradient_mapper: function
+input_modifiers: list[callable], optional
+A list of functions to produce multiple inputs. Default is a single input which is the identity.
+param_modifiers: list[callable], optional
+A list of functions to temporarily modify the parameters of the attached module for each input produced
+with `input_modifiers`. Default is unmodified parameters for each input.
+output_modifiers: list[callable], optional
+A list of functions to modify the module's output computed using the modified parameters before gradient
+computation for each input produced with `input_modifier`. Default is the identity for each output.
+gradient_mapper: callable, optional
Function to modify upper relevance. Call signature is of form `(grad_output, outputs)` and a tuple of
the same size as outputs is expected to be returned. `outputs` has the same size as `input_modifiers` and
-`param_modifiers`.
-reducer: function
+`param_modifiers`. Default is a stabilized normalization by each of the outputs, multiplied with the output
+gradient.
+reducer: callable
Function to reduce all the inputs and gradients produced through `input_modifiers` and `param_modifiers`.
Call signature is of form `(inputs, gradients)`, where `inputs` and `gradients` have the same as
-`input_modifiers` and `param_modifiers`
+`input_modifiers` and `param_modifiers`. Default is the sum of the multiplications of each input and its
+corresponding gradient.
param_keys: list[str], optional
A list of parameters that shall be modified. If `None` (default), all parameters are modified (which may be
none). If `[]`, no parameters are modified and `modifier` is ignored.
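To illustrate how these parameters fit together, here is a minimal sketch of an epsilon-style rule built only from the arguments documented above. It is not copied from `zennit/rules.py`, and it assumes that `BasicHook` accepts these parameters as keywords and that `stabilize` takes the epsilon value as its second argument:

```python
from zennit.core import BasicHook, stabilize


class MyEpsilon(BasicHook):
    '''Hypothetical epsilon-style rule relying on the documented defaults for the modifiers.'''
    def __init__(self, epsilon=1e-6):
        super().__init__(
            # upper relevance is divided by the (stabilized) output of the single default forward pass
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(outputs[0], epsilon)),
            # the attribution is the (single) input multiplied with its modified gradient
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0]),
        )
```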
@@ -206,7 +211,7 @@ def __init__(
'out': output_modifiers,
}
supplied = {key for key, val in modifiers.items() if val is not None}
-num_mods = len(next(iter(supplied), (None,)))
+num_mods = len(modifiers[next(iter(supplied))]) if supplied else 1
modifiers.update({key: (self._default_modifier,) * num_mods for key in set(modifiers) - supplied})

if gradient_mapper is None:
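A standalone illustration of why this line changed (hypothetical example values, not part of the PR): the old expression measured the length of a key string taken from `supplied`, not the number of supplied modifier functions.

```python
# Three input modifiers are supplied; parameter and output modifiers are left at their defaults.
modifiers = {
    'in': [lambda x: x, lambda x: x * 2, lambda x: x * 3],
    'param': None,
    'out': None,
}
supplied = {key for key, val in modifiers.items() if val is not None}   # {'in'}

old_num_mods = len(next(iter(supplied), (None,)))                        # len('in') == 2, wrong
new_num_mods = len(modifiers[next(iter(supplied))]) if supplied else 1   # 3, the number of modifiers

print(old_num_mods, new_num_mods)  # 2 3
```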
@@ -249,7 +254,7 @@ def copy(self):
'''Return a copy of this hook.
This is used to describe hooks of different modules by a single hook instance.
'''
-return LinearHook(
+return BasicHook(
self.input_modifiers,
self.param_modifiers,
self.output_modifiers,
@@ -321,9 +326,9 @@ class Composite:

Parameters
----------
-module_map: list[function, Hook]]
-A mapping from functions that check applicability of hooks to hook instances that shall be applied to instances
-of applicable modules.
+module_map: callable
+A function `(ctx: dict, name: str, module: torch.nn.Module) -> Hook or None` which, given a context dict, the
+module's name and the module itself, returns the hook to apply, or None if the module should not be modified.
canonizers: list[Canonizer]
List of canonizer instances to be applied before applying hooks.
'''
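For illustration, a hypothetical `module_map` matching the documented signature; the rule constructors appear in the `zennit/rules.py` diff below, but passing the function and the canonizer list to `Composite` by keyword is an assumption:

```python
import torch

from zennit.core import Composite
from zennit.rules import Epsilon, Gamma


def my_module_map(ctx, name, module):
    '''Hypothetical module_map: pick a hook per module type, or None to leave the module untouched.'''
    if isinstance(module, torch.nn.Conv2d):
        return Gamma(gamma=0.25)
    if isinstance(module, torch.nn.Linear):
        return Epsilon(epsilon=1e-6)
    return None


# Assumed construction by keyword, mirroring the parameter names in the docstring above.
composite = Composite(module_map=my_module_map, canonizers=[])
```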
18 changes: 9 additions & 9 deletions zennit/rules.py
@@ -1,10 +1,10 @@
'''Rules based on Hooks'''
import torch

-from .core import Hook, LinearHook, stabilize
+from .core import Hook, BasicHook, stabilize


-class Epsilon(LinearHook):
+class Epsilon(BasicHook):
'''Epsilon LRP rule.

Parameters
@@ -22,7 +22,7 @@ def __init__(self, epsilon=1e-6):
)


-class Gamma(LinearHook):
+class Gamma(BasicHook):
'''Gamma LRP rule.

Parameters
@@ -40,7 +40,7 @@ def __init__(self, gamma=0.25):
)


-class ZPlus(LinearHook):
+class ZPlus(BasicHook):
'''ZPlus (or alpha=1, beta=0) LRP rule.

Notes
@@ -66,7 +66,7 @@ def __init__(self):
)


-class AlphaBeta(LinearHook):
+class AlphaBeta(BasicHook):
'''AlphaBeta LRP rule.

Parameters
@@ -106,7 +106,7 @@ def __init__(self, alpha=2., beta=1.):
)


-class ZBox(LinearHook):
+class ZBox(BasicHook):
'''ZBox LRP rule for input pixel space.

Parameters
@@ -146,7 +146,7 @@ def backward(self, module, grad_input, grad_output):
return grad_output


-class Norm(LinearHook):
+class Norm(BasicHook):
'''Normalize and weigh relevance by input contribution.
This is essentially the same as the LRP Epsilon Rule with a fixed epsilon only used as a stabilizer, and without
the need of the attached layer to have parameters `weight` and `bias`.
@@ -162,7 +162,7 @@ def __init__(self):
)


-class WSquare(LinearHook):
+class WSquare(BasicHook):
'''This is the WSquare LRP rule.'''
def __init__(self):
super().__init__(
@@ -174,7 +174,7 @@ def __init__(self):
)


-class Flat(LinearHook):
+class Flat(BasicHook):
'''This is the Flat LRP rule. It is essentially the same as the WSquare Rule, but with all parameters set to ones.
'''
def __init__(self):