-
Notifications
You must be signed in to change notification settings - Fork 32
/
composites.py
292 lines (245 loc) · 10.1 KB
/
composites.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# This file is part of Zennit
# Copyright (C) 2019-2021 Christopher J. Anders
#
# zennit/composites.py
#
# Zennit is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# Zennit is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
# more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <https://www.gnu.org/licenses/>.
'''Composites, registered in a global composite dict.'''
import torch
from .core import Composite
from .layer import Sum
from .rules import Gamma, Epsilon, ZBox, ZPlus, AlphaBeta, Flat, Pass, Norm, ReLUDeconvNet, ReLUGuidedBackprop
from .types import Convolution, Linear, AvgPool, Activation
class LayerMapComposite(Composite):
    '''A Composite which selects each module's hook by looking up the module's type in an ordered mapping.

    Parameters
    ----------
    layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
        A mapping as a list of tuples, each pairing a module type (or a tuple of module types) with a Hook.
        Entries are tested in list order and the first entry whose type(s) match the module is used.
    canonizers: list or None
        Passed unchanged to the :obj:`Composite` constructor.
    '''
    def __init__(self, layer_map, canonizers=None):
        self.layer_map = layer_map
        # the bound method is given to the base Composite as the callable that selects a hook for a module
        super().__init__(self.mapping, canonizers)

    # pylint: disable=unused-argument
    def mapping(self, ctx, name, module):
        '''Return the hook of the first `layer_map` entry whose type(s) the module is an instance of.

        Parameters
        ----------
        ctx: dict
            Context dictionary supplied by the caller; not used by this lookup.
        name: str
            Name of the module; not used by this lookup.
        module: obj:`torch.nn.Module`
            Instance of the module to find a hook for.

        Returns
        -------
        obj:`Hook` or None
            The hook paired with the first matching type entry, or None if no entry matches.
        '''
        for module_types, template in self.layer_map:
            if isinstance(module, module_types):
                return template
        return None


class SpecialFirstLayerMapComposite(LayerMapComposite):
    '''A LayerMapComposite which uses a separate type-to-hook mapping for the first module matched by it.

    Once some module has matched an entry of `first_map`, all further modules are resolved through `layer_map` only.

    Parameters
    ----------
    layer_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
        A mapping as a list of tuples, each pairing a module type (or a tuple of module types) with a Hook.
    first_map: `list[tuple[tuple[torch.nn.Module, ...], Hook]]`
        Mapping in the same format as `layer_map`, consulted until its first match.
    canonizers: list or None
        Passed unchanged to the :obj:`Composite` constructor.
    '''
    def __init__(self, layer_map, first_map, canonizers=None):
        # must be set before the base constructor runs, since that constructor receives self.mapping
        self.first_map = first_map
        super().__init__(layer_map, canonizers)

    def mapping(self, ctx, name, module):
        '''Return a hook for `module`, preferring `first_map` until it has matched once.

        Parameters
        ----------
        ctx: dict
            Context dictionary; this method reads and sets its 'first_layer_visited' key. It presumably has to be the
            same dict across calls for one registration so the flag persists (confirm in :obj:`Composite`).
        name: str
            Name of the module.
        module: obj:`torch.nn.Module`
            Instance of the module to find a hook for.

        Returns
        -------
        obj:`Hook` or None
            The hook from `first_map` if this is the first module matching it, otherwise the hook found in
            `layer_map`, or None if neither mapping applies.
        '''
        if ctx.get('first_layer_visited', False):
            return super().mapping(ctx, name, module)

        # a private sentinel distinguishes "no match" from an entry whose hook is literally None
        no_match = object()
        first_hook = next(
            (candidate for module_types, candidate in self.first_map if isinstance(module, module_types)),
            no_match,
        )
        if first_hook is no_match:
            return super().mapping(ctx, name, module)

        # only the first matching module receives the special hook
        ctx['first_layer_visited'] = True
        return first_hook


class NameMapComposite(Composite):
    '''A Composite for which hooks are specified by a mapping from module names to hooks.

    Parameters
    ----------
    name_map: `list[tuple[tuple[str, ...], Hook]]`
        A mapping as a list of tuples, with a tuple of applicable module names and a Hook. Entries are checked in
        list order and the first entry containing the module's name is used. A bare string in place of the tuple
        is treated as a single name.
    canonizers: list or None
        Passed unchanged to the :obj:`Composite` constructor.
    '''
    def __init__(self, name_map, canonizers=None):
        self.name_map = name_map
        # the bound method is given to the base Composite as the callable that selects a hook for a module
        super().__init__(self.mapping, canonizers)

    # pylint: disable=unused-argument
    def mapping(self, ctx, name, module):
        '''Get the appropriate hook given a mapping from module names to hooks.

        Parameters
        ----------
        ctx: dict
            A context dictionary supplied by the caller; not used by this lookup.
        name: str
            Name of the module.
        module: obj:`torch.nn.Module`
            Instance of the module to find a hook for; not used by this lookup.

        Returns
        -------
        obj:`Hook` or None
            The hook found with the module name in the given name map, or None if no applicable hook was found.
        '''
        for names, hook in self.name_map:
            # `('conv1')` is a plain str, not a tuple, and `in` on a str is a substring test: 'conv1' would match
            # 'conv10', and the empty name '' would match every entry. Treat a bare string as one exact name.
            if isinstance(names, str):
                names = (names,)
            if name in names:
                return hook
        return None


COMPOSITES = {}


def register_composite(name):
    '''Build a decorator which stores a composite in the global COMPOSITES dict under `name`.

    Parameters
    ----------
    name: str
        Key under which the decorated composite is stored. An existing entry with the same key is replaced.

    Returns
    -------
    callable
        A decorator that registers its argument and returns that argument unchanged.
    '''
    def decorator(composite):
        '''Store `composite` in COMPOSITES under the enclosing `name` and hand it back unmodified.'''
        COMPOSITES.update({name: composite})
        return composite
    return decorator


# Type-to-hook entries shared by the epsilon-* composites below; each one extends it with its own
# Convolution / torch.nn.Linear rules via a new list, so this list itself is never modified.
# NOTE(review): the hook instances here are shared by every composite built from this list.
LAYER_MAP_BASE = [
    (Activation, Pass()),  # activation modules get the Pass hook
    (Sum, Norm()),  # Sum layers (from .layer) get a Norm hook
    (AvgPool, Norm()),  # average pooling modules get their own Norm hook
]


@register_composite('epsilon_gamma_box')
class EpsilonGammaBox(SpecialFirstLayerMapComposite):
    '''An explicit composite using the ZBox rule for the first convolutional layer, the gamma rule for all following
    convolutional layers, and the epsilon rule for all fully connected (``torch.nn.Linear``) layers.

    Parameters
    ----------
    low: obj:`torch.Tensor`
        A tensor with the same size as the input, describing the lowest possible pixel values.
    high: obj:`torch.Tensor`
        A tensor with the same size as the input, describing the highest possible pixel values.
    epsilon: float
        Epsilon parameter for the epsilon rule (default 1e-6).
    gamma: float
        Gamma parameter for the gamma rule (default 0.25).
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, low, high, epsilon=1e-6, gamma=0.25, canonizers=None):
        conv_rule = Gamma(gamma=gamma)
        dense_rule = Epsilon(epsilon=epsilon)
        input_rule = ZBox(low, high)
        super().__init__(
            [*LAYER_MAP_BASE, (Convolution, conv_rule), (torch.nn.Linear, dense_rule)],
            [(Convolution, input_rule)],
            canonizers=canonizers,
        )


@register_composite('epsilon_plus')
class EpsilonPlus(LayerMapComposite):
    '''An explicit composite using the zplus rule for all convolutional layers and the epsilon rule for all fully
    connected (``torch.nn.Linear``) layers.

    Parameters
    ----------
    epsilon: float
        Epsilon parameter for the epsilon rule (default 1e-6).
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, epsilon=1e-6, canonizers=None):
        conv_rule = ZPlus()
        dense_rule = Epsilon(epsilon=epsilon)
        super().__init__(
            [*LAYER_MAP_BASE, (Convolution, conv_rule), (torch.nn.Linear, dense_rule)],
            canonizers=canonizers,
        )


@register_composite('epsilon_alpha2_beta1')
class EpsilonAlpha2Beta1(LayerMapComposite):
    '''An explicit composite using the alpha2-beta1 rule for all convolutional layers and the epsilon rule for all
    fully connected (``torch.nn.Linear``) layers.

    Parameters
    ----------
    epsilon: float
        Epsilon parameter for the epsilon rule (default 1e-6).
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, epsilon=1e-6, canonizers=None):
        conv_rule = AlphaBeta(alpha=2, beta=1)
        dense_rule = Epsilon(epsilon=epsilon)
        super().__init__(
            [*LAYER_MAP_BASE, (Convolution, conv_rule), (torch.nn.Linear, dense_rule)],
            canonizers=canonizers,
        )


@register_composite('epsilon_plus_flat')
class EpsilonPlusFlat(SpecialFirstLayerMapComposite):
    '''An explicit composite using the flat rule for the first layer matching the `Linear` type group, the zplus
    rule for all other convolutional layers and the epsilon rule for all other fully connected
    (``torch.nn.Linear``) layers.

    Parameters
    ----------
    epsilon: float
        Epsilon parameter for the epsilon rule (default 1e-6).
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, epsilon=1e-6, canonizers=None):
        conv_rule = ZPlus()
        dense_rule = Epsilon(epsilon=epsilon)
        input_rule = Flat()
        super().__init__(
            [*LAYER_MAP_BASE, (Convolution, conv_rule), (torch.nn.Linear, dense_rule)],
            [(Linear, input_rule)],
            canonizers=canonizers,
        )


@register_composite('epsilon_alpha2_beta1_flat')
class EpsilonAlpha2Beta1Flat(SpecialFirstLayerMapComposite):
    '''An explicit composite using the flat rule for the first layer matching the `Linear` type group, the
    alpha2-beta1 rule for all other convolutional layers and the epsilon rule for all other fully connected
    (``torch.nn.Linear``) layers.

    Parameters
    ----------
    epsilon: float
        Epsilon parameter for the epsilon rule (default 1e-6).
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, epsilon=1e-6, canonizers=None):
        conv_rule = AlphaBeta(alpha=2, beta=1)
        dense_rule = Epsilon(epsilon=epsilon)
        input_rule = Flat()
        super().__init__(
            [*LAYER_MAP_BASE, (Convolution, conv_rule), (torch.nn.Linear, dense_rule)],
            [(Linear, input_rule)],
            canonizers=canonizers,
        )


@register_composite('deconvnet')
class DeconvNet(LayerMapComposite):
    '''An explicit composite modifying the gradients of all ReLUs according to DeconvNet
    :cite:p:`zeiler2014visualizing`.

    Only ``torch.nn.ReLU`` modules receive a hook; every other module is left without one.

    Parameters
    ----------
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, canonizers=None):
        relu_hook = ReLUDeconvNet()
        super().__init__([(torch.nn.ReLU, relu_hook)], canonizers=canonizers)


@register_composite('guided_backprop')
class GuidedBackprop(LayerMapComposite):
    '''An explicit composite modifying the gradients of all ReLUs according to GuidedBackprop
    :cite:p:`springenberg2015striving`.

    Only ``torch.nn.ReLU`` modules receive a hook; every other module is left without one.

    Parameters
    ----------
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, canonizers=None):
        relu_hook = ReLUGuidedBackprop()
        super().__init__([(torch.nn.ReLU, relu_hook)], canonizers=canonizers)


@register_composite('excitation_backprop')
class ExcitationBackprop(LayerMapComposite):
    '''An explicit composite implementing the ExcitationBackprop :cite:p:`zhang2016top`.

    Sum and average-pooling modules get a Norm hook each. Modules matching the `Linear` type group from `.types`
    get the zplus rule.

    Parameters
    ----------
    canonizers: list or None
        Forwarded unchanged to the parent composite.
    '''
    def __init__(self, canonizers=None):
        sum_rule, pool_rule, linear_rule = Norm(), Norm(), ZPlus()
        super().__init__(
            [(Sum, sum_rule), (AvgPool, pool_rule), (Linear, linear_rule)],
            canonizers=canonizers,
        )