/
resblocks.py
301 lines (257 loc) · 10.3 KB
/
resblocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
"""
Implementation of residual blocks for discriminator and generator.
We follow the official SNGAN Chainer implementation as closely as possible:
https://github.com/pfnet-research/sngan_projection
"""
import math
import torch.nn as nn
import torch.nn.functional as F
from torch_mimicry.modules import SNConv2d, ConditionalBatchNorm2d
class GBlock(nn.Module):
    r"""
    Generator residual block.

    When upsampling, the block uses bilinear interpolation with
    ``align_corners=False`` (rather than nearest-neighbour), mirroring how
    torchvision upsamples:
    https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/_utils.py

    Attributes:
        in_channels (int): Channel count of the incoming feature map.
        out_channels (int): Channel count of the produced feature map.
        hidden_channels (int): Channel count between the two 3x3 convolutions;
            falls back to ``out_channels`` when not given.
        upsample (bool): If True, the spatial size is doubled.
        num_classes (int): When non-zero, ConditionalBatchNorm2d replaces
            plain BatchNorm2d.
        spectral_norm (bool): If True, convolutions are SNConv2d layers.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 upsample=False,
                 num_classes=0,
                 spectral_norm=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = (out_channels if hidden_channels is None
                                else hidden_channels)
        # A 1x1 projection shortcut is required whenever the identity cannot
        # be added as-is: the channel count changes or the block upsamples.
        self.learnable_sc = (in_channels != out_channels) or upsample
        self.upsample = upsample
        self.num_classes = num_classes
        self.spectral_norm = spectral_norm

        # NOTE: the original author reports that picking the conv class
        # dynamically (e.g. storing SNConv2d on self) consistently made
        # spectral norm work worse, so each branch builds its layers explicitly.
        if not self.spectral_norm:
            self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1,
                                padding=1)
            self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1,
                                padding=1)
        else:
            self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1,
                               padding=1)
            self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1,
                               padding=1)

        if self.num_classes != 0:
            self.b1 = ConditionalBatchNorm2d(self.in_channels, self.num_classes)
            self.b2 = ConditionalBatchNorm2d(self.hidden_channels,
                                             self.num_classes)
        else:
            self.b1 = nn.BatchNorm2d(self.in_channels)
            self.b2 = nn.BatchNorm2d(self.hidden_channels)

        self.activation = nn.ReLU(True)

        gain = math.sqrt(2.0)
        for conv in (self.c1, self.c2):
            nn.init.xavier_uniform_(conv.weight.data, gain)

        if self.learnable_sc:
            if self.spectral_norm:
                self.c_sc = SNConv2d(in_channels, out_channels, 1, 1,
                                     padding=0)
            else:
                self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1,
                                      padding=0)
            nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _upsample_conv(self, x, conv):
        r"""
        Upsample ``x`` by a factor of 2 (bilinear) and apply ``conv`` to it.
        """
        upsampled = F.interpolate(x,
                                  scale_factor=2,
                                  mode='bilinear',
                                  align_corners=False)
        return conv(upsampled)

    def _first_conv(self, h):
        r"""
        Apply ``c1``, upsampling beforehand when the block upsamples.
        """
        if self.upsample:
            return self._upsample_conv(h, self.c1)
        return self.c1(h)

    def _residual(self, x):
        r"""
        Main branch: BN -> ReLU -> [upsample] conv -> BN -> ReLU -> conv.
        """
        # The in-place ReLU only ever sees fresh BatchNorm outputs.
        h = self._first_conv(self.activation(self.b1(x)))
        return self.c2(self.activation(self.b2(h)))

    def _residual_conditional(self, x, y):
        r"""
        Main branch with conditional batch norm: both BN layers receive ``y``.
        """
        h = self._first_conv(self.activation(self.b1(x, y)))
        return self.c2(self.activation(self.b2(h, y)))

    def _shortcut(self, x):
        r"""
        Shortcut branch: identity, or a (possibly upsampled) 1x1 conv.
        """
        if not self.learnable_sc:
            return x
        if self.upsample:
            return self._upsample_conv(x, self.c_sc)
        return self.c_sc(x)

    def forward(self, x, y=None):
        r"""
        Sum of the main branch and the shortcut branch.

        Args:
            x (Tensor): Input feature map.
            y (Tensor, optional): Conditioning input forwarded to the
                ConditionalBatchNorm2d layers (presumably class labels --
                confirm against ConditionalBatchNorm2d). Leave as None for an
                unconditional block.
        """
        if y is None:
            main = self._residual(x)
        else:
            main = self._residual_conditional(x, y)
        return main + self._shortcut(x)
class DBlock(nn.Module):
    """
    Residual block for discriminator.

    Main branch: ReLU -> 3x3 conv -> ReLU -> 3x3 conv, followed by a 2x2
    average pool when ``downsample`` is set. Shortcut branch: identity, or a
    1x1 conv (average-pooled the same way) when the channel count changes or
    the block downsamples.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
        hidden_channels (int): The channel size of intermediate feature maps;
            defaults to ``in_channels`` when None.
        downsample (bool): If True, downsamples the input feature map.
        spectral_norm (bool): If True, uses spectral norm for convolutional layers.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 downsample=False,
                 spectral_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels if hidden_channels is not None else in_channels
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample
        self.spectral_norm = spectral_norm

        # Build the layers
        if self.spectral_norm:
            self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
            self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1,
                               1)
        else:
            self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1,
                                1)
            self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1,
                                1)

        # In-place ReLU: only safe on tensors this block created itself.
        self.activation = nn.ReLU(True)

        nn.init.xavier_uniform_(self.c1.weight.data, math.sqrt(2.0))
        nn.init.xavier_uniform_(self.c2.weight.data, math.sqrt(2.0))

        # Shortcut layer
        if self.learnable_sc:
            if self.spectral_norm:
                self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
            else:
                self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
            nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _residual(self, x):
        """
        Helper function for feedforwarding through main layers.
        """
        # The first activation must NOT run in place. self.activation is an
        # in-place ReLU, and applying it to ``x`` would overwrite the caller's
        # tensor. forward() still needs the un-activated ``x`` for the shortcut
        # branch, and autograd rejects in-place ops on leaves requiring grad.
        h = F.relu(x)
        h = self.c1(h)
        # Safe in place: ``h`` is a fresh conv output owned by this block.
        h = self.activation(h)
        h = self.c2(h)
        if self.downsample:
            h = F.avg_pool2d(h, 2)
        return h

    def _shortcut(self, x):
        """
        Helper function for feedforwarding through shortcut layers.
        """
        if self.learnable_sc:
            x = self.c_sc(x)
            return F.avg_pool2d(x, 2) if self.downsample else x
        return x

    def forward(self, x):
        """
        Residual block feedforward function: main branch plus shortcut.
        The input tensor ``x`` is left unmodified.
        """
        return self._residual(x) + self._shortcut(x)
class DBlockOptimized(nn.Module):
    """
    Optimized first residual block of the discriminator, following the
    official SNGAN reference implementation in Chainer. It always
    downsamples by a factor of 2.

    Main branch: 3x3 conv -> ReLU -> 3x3 conv -> 2x2 average pool.
    Shortcut branch: 2x2 average pool -> 1x1 conv (always learnable).

    Attributes:
        in_channels (int): Channel count of the incoming feature map.
        out_channels (int): Channel count of the produced feature map.
        spectral_norm (bool): If True, convolutions are SNConv2d layers.
    """

    def __init__(self, in_channels, out_channels, spectral_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.spectral_norm = spectral_norm

        if not self.spectral_norm:
            self.c1 = nn.Conv2d(self.in_channels, self.out_channels, 3, 1, 1)
            self.c2 = nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1)
            self.c_sc = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        else:
            self.c1 = SNConv2d(self.in_channels, self.out_channels, 3, 1, 1)
            self.c2 = SNConv2d(self.out_channels, self.out_channels, 3, 1, 1)
            self.c_sc = SNConv2d(self.in_channels, self.out_channels, 1, 1, 0)

        self.activation = nn.ReLU(True)

        gain = math.sqrt(2.0)
        nn.init.xavier_uniform_(self.c1.weight.data, gain)
        nn.init.xavier_uniform_(self.c2.weight.data, gain)
        nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _residual(self, x):
        """
        Main branch: conv -> ReLU -> conv, then 2x2 average pooling.
        """
        # The in-place ReLU acts on the fresh c1 output, never on ``x``.
        hidden = self.activation(self.c1(x))
        return F.avg_pool2d(self.c2(hidden), 2)

    def _shortcut(self, x):
        """
        Shortcut branch: 2x2 average pooling, then the 1x1 projection.
        """
        pooled = F.avg_pool2d(x, 2)
        return self.c_sc(pooled)

    def forward(self, x):
        """
        Sum of the main branch and the shortcut branch.
        """
        residual = self._residual(x)
        return residual + self._shortcut(x)