Skip to content

Commit

Permalink
Add layer_map argument to NameMapComposite.
Browse files Browse the repository at this point in the history
This way it is possible to specify a fallback layer map if we want to apply more general rules to the remaining layers.
The parameter is optional. If None is passed, there's no change in functionality.
If a layer_map list is passed with the same format as for the LayerMapComposite, then this will be used as a fallback if there's no hook for this specific layer name.
  • Loading branch information
dkrako committed Sep 6, 2022
1 parent 2e6c498 commit 0383f9b
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/zennit/composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,14 @@ class NameMapComposite(Composite):
----------
name_map: `list[tuple[tuple[str, ...], Hook]]`
A mapping as a list of tuples, with a tuple of applicable module names and a Hook.
layer_map: list[tuple[tuple[torch.nn.Module, ...], Hook]], optional
A mapping as a list of tuples, with a tuple of applicable module types and a Hook.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
'''
def __init__(self, name_map, canonizers=None):
def __init__(self, name_map, layer_map=None, canonizers=None):
self.name_map = name_map
self.layer_map = layer_map if layer_map is not None else ()
super().__init__(module_map=self.mapping, canonizers=canonizers)

# pylint: disable=unused-argument
Expand All @@ -133,9 +136,19 @@ def mapping(self, ctx, name, module):
Returns
-------
obj:`Hook` or None
The hook found with the module type in the given name map, or None if no applicable hook was found.
The hook found with the module name in the given name map, or with
the module type in the given layer map, or None if no applicable
hook was found.
'''
return next((hook for names, hook in self.name_map if name in names), None)
return next(
(hook for names, hook in self.name_map if name in names),
next(
(hook for types, hook in self.layer_map
if isinstance(module, types)
),
None,
),
)


COMPOSITES = {}
Expand Down

0 comments on commit 0383f9b

Please sign in to comment.