#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Common model blocks."""
import numpy as np
import torch
import torch.nn as nn
from pycls.core.config import cfg
from torch.nn import Module


# ----------------------- Shortcuts for common torch.nn layers ----------------------- #


def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False):
    """Helper for building a conv2d layer."""
    assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues."
    s, p, g, b = stride, (k - 1) // 2, groups, bias
    return nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b)


def patchify2d(w_in, w_out, k, *, bias=True):
    """Helper for building a patchify layer as used by ViT models."""
    return nn.Conv2d(w_in, w_out, k, stride=k, padding=0, bias=bias)
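

# Usage sketch (illustrative): a ViT-style patch embedding with 16x16 patches maps
# 3 input channels to one 384-d token per patch:
#   patchify2d(3, 384, 16)  # (N, 3, 224, 224) -> (N, 384, 14, 14)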


def norm2d(w_in):
    """Helper for building a norm2d layer."""
    return nn.BatchNorm2d(num_features=w_in, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)


def pool2d(_w_in, k, *, stride=1):
    """Helper for building a pool2d layer."""
    assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues."
    return nn.MaxPool2d(k, stride=stride, padding=(k - 1) // 2)


def gap2d(_w_in):
    """Helper for building a gap2d layer."""
    return nn.AdaptiveAvgPool2d((1, 1))


def layernorm(w_in):
    """Helper for building a layernorm layer."""
    return nn.LayerNorm(w_in, eps=cfg.LN.EPS)


def linear(w_in, w_out, *, bias=False):
    """Helper for building a linear layer."""
    return nn.Linear(w_in, w_out, bias=bias)


def activation(activation_fun=None):
    """Helper for building an activation layer."""
    activation_fun = (activation_fun or cfg.MODEL.ACTIVATION_FUN).lower()
    if activation_fun == "relu":
        return nn.ReLU(inplace=cfg.MODEL.ACTIVATION_INPLACE)
    elif activation_fun in ("silu", "swish"):
        try:
            return torch.nn.SiLU()
        except AttributeError:
            return SiLU()
    elif activation_fun == "gelu":
        return torch.nn.GELU()
    else:
        raise AssertionError("Unknown MODEL.ACTIVATION_FUN: " + activation_fun)
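

# A minimal usage sketch (illustrative, not part of pycls): the helpers above
# compose into conv-norm-act sequences; "basic_stem" is a hypothetical name.
# norm2d() and activation() read their settings from cfg.
def basic_stem(w_in, w_out):
    """Example conv-norm-act stem built from the helpers above."""
    return nn.Sequential(conv2d(w_in, w_out, 3, stride=2), norm2d(w_out), activation())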


# --------------------------- Complexity (cx) calculations --------------------------- #


def conv2d_cx(cx, w_in, w_out, k, *, stride=1, groups=1, bias=False):
    """Accumulates complexity of conv2d into cx = (h, w, flops, params, acts)."""
    assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues."
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h, w = (h - 1) // stride + 1, (w - 1) // stride + 1
    flops += k * k * w_in * w_out * h * w // groups + (w_out * h * w if bias else 0)
    params += k * k * w_in * w_out // groups + (w_out if bias else 0)
    acts += w_out * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def patchify2d_cx(cx, w_in, w_out, k, *, bias=True):
    """Accumulates complexity of patchify2d into cx = (h, w, flops, params, acts)."""
    err_str = "Only kernel sizes divisible by the input size are supported."
    assert cx["h"] % k == 0 and cx["w"] % k == 0, err_str
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h, w = h // k, w // k
    flops += k * k * w_in * w_out * h * w + (w_out * h * w if bias else 0)
    params += k * k * w_in * w_out + (w_out if bias else 0)
    acts += w_out * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def norm2d_cx(cx, w_in):
    """Accumulates complexity of norm2d into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    params += 2 * w_in
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def pool2d_cx(cx, w_in, k, *, stride=1):
    """Accumulates complexity of pool2d into cx = (h, w, flops, params, acts)."""
    assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues."
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h, w = (h - 1) // stride + 1, (w - 1) // stride + 1
    acts += w_in * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def gap2d_cx(cx, _w_in):
    """Accumulates complexity of gap2d into cx = (h, w, flops, params, acts)."""
    flops, params, acts = cx["flops"], cx["params"], cx["acts"]
    return {"h": 1, "w": 1, "flops": flops, "params": params, "acts": acts}


def layernorm_cx(cx, w_in):
    """Accumulates complexity of layernorm into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    params += 2 * w_in
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def linear_cx(cx, w_in, w_out, *, bias=False, num_locations=1):
    """Accumulates complexity of linear into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    flops += w_in * w_out * num_locations + (w_out * num_locations if bias else 0)
    params += w_in * w_out + (w_out if bias else 0)
    acts += w_out * num_locations
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


# ---------------------------------- Shared blocks ----------------------------------- #


class SiLU(Module):
    """SiLU activation function (also known as Swish): x * sigmoid(x)."""

    # Note: nn.SiLU ships with PyTorch 1.7+; this fallback can be removed once
    # that is the minimum supported version.

    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class SE(Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, Act, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.avg_pool = gap2d(w_in)
        self.f_ex = nn.Sequential(
            conv2d(w_in, w_se, 1, bias=True),
            activation(),
            conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))

    @staticmethod
    def complexity(cx, w_in, w_se):
        h, w = cx["h"], cx["w"]
        cx = gap2d_cx(cx, w_in)
        cx = conv2d_cx(cx, w_in, w_se, 1, bias=True)
        cx = conv2d_cx(cx, w_se, w_in, 1, bias=True)
        # Restore h, w: the SE multiply broadcasts back to the input resolution.
        cx["h"], cx["w"] = h, w
        return cx
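

# Usage sketch (illustrative; "se_ratio" is an assumption, not a pycls config key):
# w_se is typically a fraction of w_in, as in SENet / EfficientNet-style blocks.
def example_se(w_in=64, se_ratio=0.25):
    """Example: build an SE block with a reduced inner width."""
    return SE(w_in, max(1, int(w_in * se_ratio)))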


class MultiheadAttention(Module):
    """Multi-head Attention block from Transformer models."""

    def __init__(self, hidden_d, n_heads):
        super(MultiheadAttention, self).__init__()
        self.block = nn.MultiheadAttention(hidden_d, n_heads, batch_first=False)

    def forward(self, query, key, value, need_weights=False):
        return self.block(query=query, key=key, value=value, need_weights=need_weights)

    @staticmethod
    def complexity(cx, hidden_d, n_heads, seq_len):
        # See https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
        h, w = cx["h"], cx["w"]
        flops, params, acts = cx["flops"], cx["params"], cx["acts"]
        # q, k, v = linear(input).chunk(3)
        flops += seq_len * (hidden_d * hidden_d * 3 + hidden_d * 3)
        params += hidden_d * hidden_d * 3 + hidden_d * 3
        acts += hidden_d * 3 * seq_len
        # attn_output_weights = torch.bmm(q, k.transpose)
        head_d = hidden_d // n_heads
        flops += n_heads * (seq_len * head_d * seq_len)
        acts += n_heads * seq_len * seq_len
        # attn_output = torch.bmm(attn_output_weights, v)
        flops += n_heads * (seq_len * seq_len * head_d)
        acts += n_heads * seq_len * head_d
        # attn_output = linear(attn_output)
        flops += seq_len * (hidden_d * hidden_d + hidden_d)
        params += hidden_d * hidden_d + hidden_d
        acts += hidden_d * seq_len
        return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


# ---------------------------------- Miscellaneous ----------------------------------- #


def adjust_block_compatibility(ws, bs, gs):
    """Adjusts the compatibility of widths, bottlenecks, and groups."""
    assert len(ws) == len(bs) == len(gs)
    assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs))
    assert all(b < 1 or b % 1 == 0 for b in bs)
    vs = [int(max(1, w * b)) for w, b in zip(ws, bs)]
    gs = [int(min(g, v)) for g, v in zip(gs, vs)]
    ms = [np.lcm(g, int(b)) if b > 1 else g for g, b in zip(gs, bs)]
    vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)]
    ws = [int(v / b) for v, b in zip(vs, bs)]
    assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs))
    return ws, bs, gs
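

# A worked example (illustrative, hypothetical helper): for ws=[60], bs=[0.25],
# gs=[8], the bottleneck width v = 60 * 0.25 = 15 is rounded to 16, the nearest
# multiple of g, so the function returns ws=[64] and w * b stays divisible by g.
def example_adjust():
    """Example: adjust_block_compatibility([60], [0.25], [8]) -> ([64], [0.25], [8])."""
    return adjust_block_compatibility([60], [0.25], [8])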


def init_weights(m):
    """Performs ResNet-style weight initialization."""
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
        zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
        m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        m.bias.data.zero_()
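

# Usage sketch (illustrative): the initializer is applied recursively via
# Module.apply, e.g. model.apply(init_weights). BN modules tagged with a truthy
# "final_bn" attribute get zero-initialized gamma when
# cfg.BN.ZERO_INIT_FINAL_GAMMA is set.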


def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS)."""
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x
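

# Usage sketch (illustrative; "f" and "drop_ratio" here are assumptions for
# illustration): drop_connect implements stochastic depth and is applied to the
# residual branch during training only; note that it modifies its input in place.
def example_residual_forward(x, f, drop_ratio=0.1, training=True):
    """Example: stochastic depth applied to a residual branch f(x)."""
    fx = f(x)
    if training and drop_ratio > 0.0:
        fx = drop_connect(fx, drop_ratio)
    return x + fx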