-
Notifications
You must be signed in to change notification settings - Fork 32
/
torchvision.py
149 lines (119 loc) · 4.88 KB
/
torchvision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# This file is part of Zennit
# Copyright (C) 2019-2021 Christopher J. Anders
#
# zennit/torchvision.py
#
# Zennit is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# Zennit is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
# more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <https://www.gnu.org/licenses/>.
'''Specialized Canonizers for models from torchvision.'''
import torch
from torchvision.models.resnet import Bottleneck as ResNetBottleneck, BasicBlock as ResNetBasicBlock
from .canonizers import SequentialMergeBatchNorm, AttributeCanonizer, CompositeCanonizer
from .layer import Sum
class VGGCanonizer(SequentialMergeBatchNorm):
    '''Canonizer for torchvision.models.vgg* type models.

    No VGG-specific behavior is needed so far: this class adds nothing on top of SequentialMergeBatchNorm and is
    therefore identical to it.
    '''
class ResNetBottleneckCanonizer(AttributeCanonizer):
    '''Canonizer specifically for Bottlenecks of torchvision.models.resnet* type models.

    Every matched Bottleneck instance receives two overloaded attributes: a fresh ``canonizer_sum`` (a Sum module)
    and a replacement ``forward`` that routes the residual addition through that module.
    '''
    def __init__(self):
        # hand the class-level attribute map (bound to this canonizer's class) to AttributeCanonizer
        super().__init__(type(self)._attribute_map)

    @classmethod
    def _attribute_map(cls, name, module):
        '''Build the attributes to overload onto a Bottleneck instance.

        Parameters
        ----------
        name : string
            Name by which the module is identified. It is not inspected here.
        module : obj:`torch.nn.Module`
            Candidate module. Only instances of torchvision's Bottleneck are matched.

        Returns
        -------
        None or dict
            None when `module` is not a Bottleneck. Otherwise a dict with ``'forward'`` (the modified forward bound
            to `module`) and ``'canonizer_sum'`` (a newly created Sum module).
        '''
        if not isinstance(module, ResNetBottleneck):
            return None
        # forward is a staticmethod, so cls.forward is the plain function; __get__ binds it to `module` as `self`
        bound_forward = cls.forward.__get__(module)
        return {
            'forward': bound_forward,
            'canonizer_sum': Sum(),
        }

    @staticmethod
    def forward(self, x):
        '''Modified Bottleneck forward for ResNet.

        Runs conv1/bn1/relu, conv2/bn2/relu, conv3/bn3 on `x`, then stacks the shortcut (`x`, or
        ``self.downsample(x)`` when a downsample module is present) with the main branch along a new last dimension
        and reduces the pair with ``self.canonizer_sum`` instead of adding the tensors directly, before a final
        ReLU. `self` is the Bottleneck instance this function was bound to.
        '''
        main_branch = (
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2, self.relu,
            self.conv3, self.bn3,
        )
        hidden = x
        for layer in main_branch:
            hidden = layer(hidden)

        # the shortcut is evaluated after the main branch, keeping the module call order of the original forward
        shortcut = x if self.downsample is None else self.downsample(x)

        # stack rather than add, so the summation happens inside a module that hooks can be attached to
        paired = torch.stack((shortcut, hidden), dim=-1)
        return self.relu(self.canonizer_sum(paired))
class ResNetBasicBlockCanonizer(AttributeCanonizer):
    '''Canonizer specifically for BasicBlocks of torchvision.models.resnet* type models.

    Every matched BasicBlock instance receives two overloaded attributes: a fresh ``canonizer_sum`` (a Sum module)
    and a replacement ``forward`` that routes the residual addition through that module.
    '''
    def __init__(self):
        # hand the class-level attribute map (bound to this canonizer's class) to AttributeCanonizer
        super().__init__(type(self)._attribute_map)

    @classmethod
    def _attribute_map(cls, name, module):
        '''Build the attributes to overload onto a BasicBlock instance.

        Parameters
        ----------
        name : string
            Name by which the module is identified. It is not inspected here.
        module : obj:`torch.nn.Module`
            Candidate module. Only instances of torchvision's BasicBlock are matched.

        Returns
        -------
        None or dict
            None when `module` is not a BasicBlock. Otherwise a dict with ``'forward'`` (the modified forward bound
            to `module`) and ``'canonizer_sum'`` (a newly created Sum module).
        '''
        if not isinstance(module, ResNetBasicBlock):
            return None
        # forward is a staticmethod, so cls.forward is the plain function; __get__ binds it to `module` as `self`
        bound_forward = cls.forward.__get__(module)
        return {
            'forward': bound_forward,
            'canonizer_sum': Sum(),
        }

    @staticmethod
    def forward(self, x):
        '''Modified BasicBlock forward for ResNet.

        Runs conv1/bn1/relu, conv2/bn2 on `x`, then stacks the shortcut (`x`, or ``self.downsample(x)`` when a
        downsample module is present) with the main branch along a new last dimension and reduces the pair with
        ``self.canonizer_sum`` instead of adding the tensors directly, before a final ReLU. `self` is the BasicBlock
        instance this function was bound to.
        '''
        main_branch = (
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2,
        )
        hidden = x
        for layer in main_branch:
            hidden = layer(hidden)

        # the shortcut is evaluated after the main branch, keeping the module call order of the original forward
        shortcut = x if self.downsample is None else self.downsample(x)

        # stack rather than add, so the summation happens inside a module that hooks can be attached to
        paired = torch.stack((shortcut, hidden), dim=-1)
        return self.relu(self.canonizer_sum(paired))
class ResNetCanonizer(CompositeCanonizer):
    '''Canonizer for torchvision.models.resnet* type models.

    Composes, in this order:

    * SequentialMergeBatchNorm,
    * ResNetBottleneckCanonizer, and
    * ResNetBasicBlockCanonizer.

    The latter two add a Sum module to the Bottleneck and BasicBlock modules, respectively. They also overload those
    modules' forward methods to use the Sum module instead of simply adding two tensors, so that forward and backward
    hooks may be applied to the residual summation.
    '''
    def __init__(self):
        canonizers = (
            SequentialMergeBatchNorm(),
            ResNetBottleneckCanonizer(),
            ResNetBasicBlockCanonizer(),
        )
        super().__init__(canonizers)