-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
resnet_helper.py
725 lines (681 loc) · 24.2 KB
/
resnet_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Video models."""
import torch
import torch.nn as nn
from slowfast.models.common import drop_path
from slowfast.models.nonlocal_helper import Nonlocal
from slowfast.models.operators import SE, Swish
def get_trans_func(name):
    """
    Look up a residual transformation class by its registered name.
    Args:
        name (str): one of "bottleneck_transform", "basic_transform" or
            "x3d_transform".
    Returns:
        The transformation class registered under `name` (the class itself,
        not an instance).
    """
    registry = dict(
        bottleneck_transform=BottleneckTransform,
        basic_transform=BasicTransform,
        x3d_transform=X3DTransform,
    )
    assert name in registry, "Transformation function '{}' not supported".format(
        name
    )
    return registry[name]
class BasicTransform(nn.Module):
    """
    Basic transformation: Tx3x3, 1x3x3, where T is the size of temporal kernel.
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner=None,
        num_groups=1,
        stride_1x1=None,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
        dilation=1,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the first
                convolution in the basic block.
            stride (int): the spatial stride of the first convolution.
            dim_inner (None): accepted for interface compatibility with the
                other transforms; not used by BasicTransform.
            num_groups (int): accepted for interface compatibility; not used
                (both convolutions are ungrouped).
            stride_1x1 (None): accepted for interface compatibility; not used.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): accepted for interface compatibility; not used.
            dilation (int): spatial dilation of the second (1x3x3)
                convolution. ResBlock always forwards `dilation=` to its
                transform, so this parameter must exist for "basic_transform"
                to be usable there. The default of 1 keeps the original
                (undilated) behavior.
        """
        super(BasicTransform, self).__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._construct(dim_in, dim_out, stride, norm_module, dilation)
    def _construct(self, dim_in, dim_out, stride, norm_module, dilation=1):
        # Tx3x3, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=[self.temp_kernel_size, 3, 3],
            stride=[1, stride, stride],
            padding=[int(self.temp_kernel_size // 2), 1, 1],
            bias=False,
        )
        self.a_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)
        # 1x3x3, BN. Padding equals dilation so spatial size is preserved.
        self.b = nn.Conv3d(
            dim_out,
            dim_out,
            kernel_size=[1, 3, 3],
            stride=[1, 1, 1],
            padding=[0, dilation, dilation],
            dilation=[1, dilation, dilation],
            bias=False,
        )
        self.b_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )
        # Marker attribute on the last norm layer of the transform; presumably
        # consumed elsewhere (e.g. weight init) -- confirm against callers.
        self.b_bn.transform_final_bn = True
    def forward(self, x):
        x = self.a(x)
        x = self.a_bn(x)
        x = self.a_relu(x)
        x = self.b(x)
        x = self.b_bn(x)
        return x
class X3DTransform(nn.Module):
    """
    X3D transformation: 1x1x1, Tx3x3 (grouped; channelwise when num_groups
    equals dim_inner), 1x1x1, optionally augmented with SE
    (squeeze-excitation) on the Tx3x3 output.
    T is the temporal kernel size given by `temp_kernel_size`.
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        se_ratio=0.0625,
        swish_inner=True,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the middle
                (Tx3x3) convolution in the bottleneck.
            stride (int): the spatial stride of the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the Tx3x3 convolution.
                num_groups=1 gives a standard convolution; num_groups equal
                to dim_inner makes it channelwise.
            stride_1x1 (bool): if True, apply stride to the first 1x1x1 conv,
                otherwise apply stride to the Tx3x3 conv.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): spatial dilation of the Tx3x3 conv.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            se_ratio (float): if > 0, an SE module built as
                SE(dim_inner, se_ratio) is applied after the Tx3x3 conv's
                norm layer. It is only added on blocks with an even block_idx.
            swish_inner (bool): if True, apply swish after the Tx3x3 conv,
                otherwise apply ReLU.
            block_idx (int): index of this block within its stage; used only
                to decide whether SE is added (even indices get SE).
        """
        super(X3DTransform, self).__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._se_ratio = se_ratio
        self._swish_inner = swish_inner
        self._stride_1x1 = stride_1x1
        self._block_idx = block_idx
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            dilation,
            norm_module,
        )
    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        # NOTE: forward() runs children in registration order, so the order of
        # the submodule assignments below IS the computation order. Do not
        # reorder them.
        # The spatial stride goes on either the first 1x1x1 conv or the Tx3x3
        # conv, never both.
        (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride)
        # 1x1x1, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[1, 1, 1],
            stride=[1, str1x1, str1x1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.a_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt
        )
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)
        # Tx3x3 (grouped), BN, optional SE, then Swish or ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            [self.temp_kernel_size, 3, 3],
            stride=[1, str3x3, str3x3],
            padding=[int(self.temp_kernel_size // 2), dilation, dilation],
            groups=num_groups,
            bias=False,
            dilation=[1, dilation, dilation],
        )
        self.b_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt
        )
        # Apply SE attention or not: only on every other block
        # (block_idx 0, 2, 4, ...), and only when se_ratio > 0.
        use_se = True if (self._block_idx + 1) % 2 else False
        if self._se_ratio > 0.0 and use_se:
            # Registered before b_relu so SE runs before the activation.
            self.se = SE(dim_inner, self._se_ratio)
        if self._swish_inner:
            self.b_relu = Swish()
        else:
            self.b_relu = nn.ReLU(inplace=self._inplace_relu)
        # 1x1x1, BN.
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.c_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )
        # Marker attribute on the last norm layer of the transform; presumably
        # consumed elsewhere (e.g. weight init) -- confirm against callers.
        self.c_bn.transform_final_bn = True
    def forward(self, x):
        # Apply every child module in registration order:
        # a, a_bn, a_relu, b, b_bn, [se], b_relu, c, c_bn.
        for block in self.children():
            x = block(x)
        return x
class BottleneckTransform(nn.Module):
    """
    Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of
    temporal kernel.
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
    ):
        """
        Build the three-convolution bottleneck used by ResNet/ResNeXt blocks.
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            temp_kernel_size (int): temporal extent T of the first (Tx1x1)
                convolution.
            stride (int): spatial stride of the bottleneck.
            dim_inner (int): channel width inside the bottleneck.
            num_groups (int): groups of the 1x3x3 convolution; 1 for ResNet,
                >1 for ResNeXt-style grouped convolution.
            stride_1x1 (bool): place the spatial stride on the first conv when
                True, on the 1x3x3 conv otherwise.
            inplace_relu (bool): whether the ReLUs operate in place.
            eps (float): epsilon for every normalization layer.
            bn_mmt (float): momentum for every normalization layer (PyTorch
                momentum = 1 - Caffe2 momentum).
            dilation (int): spatial dilation of the 1x3x3 convolution.
            norm_module (nn.Module): normalization layer class; defaults to
                nn.BatchNorm3d.
            block_idx (int): accepted for interface compatibility; unused.
        """
        super().__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._stride_1x1 = stride_1x1
        self._construct(
            dim_in, dim_out, stride, dim_inner, num_groups, dilation, norm_module
        )
    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        if self._stride_1x1:
            first_stride, mid_stride = stride, 1
        else:
            first_stride, mid_stride = 1, stride
        temporal_pad = int(self.temp_kernel_size // 2)
        def make_norm(channels):
            # Every norm layer in the bottleneck shares eps and momentum.
            return norm_module(
                num_features=channels, eps=self._eps, momentum=self._bn_mmt
            )
        # Branch2a: temporal Tx1x1 conv -> norm -> ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[self.temp_kernel_size, 1, 1],
            stride=[1, first_stride, first_stride],
            padding=[temporal_pad, 0, 0],
            bias=False,
        )
        self.a_bn = make_norm(dim_inner)
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)
        # Branch2b: spatial (possibly grouped / dilated) 1x3x3 conv -> norm -> ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            kernel_size=[1, 3, 3],
            stride=[1, mid_stride, mid_stride],
            padding=[0, dilation, dilation],
            dilation=[1, dilation, dilation],
            groups=num_groups,
            bias=False,
        )
        self.b_bn = make_norm(dim_inner)
        self.b_relu = nn.ReLU(inplace=self._inplace_relu)
        # Branch2c: channel-expanding 1x1x1 conv -> norm (no activation here).
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.c_bn = make_norm(dim_out)
        self.c_bn.transform_final_bn = True
    def forward(self, x):
        # Branch2a -> Branch2b -> Branch2c, each conv followed by its norm and
        # (for a and b) its activation.
        for layer_name in (
            "a",
            "a_bn",
            "a_relu",
            "b",
            "b_bn",
            "b_relu",
            "c",
            "c_bn",
        ):
            x = getattr(self, layer_name)(x)
        return x
class ResBlock(nn.Module):
    """
    Residual block.
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups=1,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
        drop_connect_rate=0.0,
    ):
        """
        ResBlock class constructs residual blocks. More details can be found in:
            Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
            "Deep residual learning for image recognition."
            https://arxiv.org/abs/1512.03385
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel sizes of the middle
                convolution in the bottleneck.
            stride (int): the spatial stride of the bottleneck.
            trans_func (type): transformation class (e.g. the result of
                get_trans_func) used to build the residual branch (branch2).
                It is called with the positional arguments (dim_in, dim_out,
                temp_kernel_size, stride, dim_inner, num_groups) and the
                keywords stride_1x1, inplace_relu, eps, bn_mmt, dilation,
                norm_module and block_idx.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution. num_groups=1
                is for standard ResNet like networks, and num_groups>1 is for
                ResNeXt like networks.
            stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
                apply stride to the 3x3 conv.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for every batch norm in the block (projection
                shortcut and transform).
            bn_mmt (float): momentum for every batch norm in the block. Noted
                that BN momentum in PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): index of the block within its stage; forwarded to
                trans_func.
            drop_connect_rate (float): rate passed to drop_path for the
                residual branch while training; 0 disables it. The value is
                used as given (this block does not scale it).
        """
        super(ResBlock, self).__init__()
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._drop_connect_rate = drop_connect_rate
        self._construct(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            trans_func,
            dim_inner,
            num_groups,
            stride_1x1,
            inplace_relu,
            dilation,
            norm_module,
            block_idx,
        )
    def _construct(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups,
        stride_1x1,
        inplace_relu,
        dilation,
        norm_module,
        block_idx,
    ):
        # Use skip connection with projection if dim or res change.
        if (dim_in != dim_out) or (stride != 1):
            self.branch1 = nn.Conv3d(
                dim_in,
                dim_out,
                kernel_size=1,
                stride=[1, stride, stride],
                padding=0,
                bias=False,
                dilation=1,
            )
            self.branch1_bn = norm_module(
                num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
            )
        # eps / bn_mmt are forwarded so the transform's norm layers honour the
        # same settings as the projection shortcut instead of silently falling
        # back to the transform's defaults.
        self.branch2 = trans_func(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            dim_inner,
            num_groups,
            stride_1x1=stride_1x1,
            inplace_relu=inplace_relu,
            eps=self._eps,
            bn_mmt=self._bn_mmt,
            dilation=dilation,
            norm_module=norm_module,
            block_idx=block_idx,
        )
        self.relu = nn.ReLU(self._inplace_relu)
    def forward(self, x):
        f_x = self.branch2(x)
        # Stochastic depth on the residual branch, training only.
        if self.training and self._drop_connect_rate > 0.0:
            f_x = drop_path(f_x, self._drop_connect_rate)
        if hasattr(self, "branch1"):
            # Projection shortcut when channels or spatial size change.
            x = self.branch1_bn(self.branch1(x)) + f_x
        else:
            x = x + f_x
        x = self.relu(x)
        return x
class ResStage(nn.Module):
    """
    Stage of 3D ResNet. It expects to have one or more tensors as input for
    single pathway (C2D, I3D, Slow), and multi-pathway (SlowFast) cases.
    More details can be found here:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        stride,
        temp_kernel_sizes,
        num_blocks,
        dim_inner,
        num_groups,
        num_block_temp_kernel,
        nonlocal_inds,
        nonlocal_group,
        nonlocal_pool,
        dilation,
        instantiation="softmax",
        trans_func_name="bottleneck_transform",
        stride_1x1=False,
        inplace_relu=True,
        norm_module=nn.BatchNorm3d,
        drop_connect_rate=0.0,
    ):
        """
        The `__init__` method of any subclass should also contain these arguments.
        ResStage builds p streams (pathways), where p is greater than or equal
        to one. Every per-pathway argument below is a list indexed by pathway.
        Args:
            dim_in (list): list of p input channel dimensions, one per
                pathway. Only the first block of a pathway sees dim_in; later
                blocks take dim_out as input.
            dim_out (list): list of p output channel dimensions, one per
                pathway.
            stride (list): list of p spatial strides. Only the first block of
                each pathway is strided; the remaining blocks use stride 1.
            temp_kernel_sizes (list): list of p lists of temporal kernel sizes.
                For pathway i the inner list is repeated num_blocks[i] times
                and truncated to num_block_temp_kernel[i] entries.
            num_blocks (list): list of p numbers of blocks for each of the
                pathway.
            dim_inner (list): list of p inner (bottleneck) channel dimensions.
            num_groups (list): list of number of p groups for the convolution.
                num_groups=1 is for standard ResNet like networks, and
                num_groups>1 is for ResNeXt like networks.
            num_block_temp_kernel (list): for each pathway, the number of
                leading blocks that use temp_kernel_sizes; the remaining
                blocks use a temporal kernel size of 1. Must not exceed
                num_blocks.
            nonlocal_inds (list): list of p collections of block indices. A
                nonlocal layer is added after every listed block; an empty
                collection adds none.
            nonlocal_group (list): list of p nonlocal group counts. When a
                count is > 1, the temporal dimension is split into that many
                chunks which are folded into the batch dimension before the
                nonlocal layer and unfolded afterwards (the temporal size must
                be divisible by the count).
                https://github.com/facebookresearch/video-nonlocal-net.
            nonlocal_pool (list): per-pathway pooling argument forwarded to
                Nonlocal; only read for pathways that have nonlocal layers.
            dilation (list): size of spatial dilation for each pathway.
            instantiation (string): instantiation forwarded to Nonlocal, e.g.
                "dot_product" (correlation normalized with L2) or "softmax"
                (correlation normalized with Softmax).
            trans_func_name (string): name of the transformation function
                applied in every block (see get_trans_func).
            stride_1x1 (bool): if True, apply the stride to the 1x1 conv of the
                transform, otherwise to the 3x3 conv.
            inplace_relu (bool): if True, ReLUs operate in place.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            drop_connect_rate (float): drop-path rate forwarded unchanged to
                every ResBlock in the stage (it is not scaled by block depth
                here).
        """
        super(ResStage, self).__init__()
        assert all(
            num_block_temp_kernel[i] <= num_blocks[i]
            for i in range(len(temp_kernel_sizes))
        ), "num_block_temp_kernel must not exceed num_blocks for any pathway"
        self.num_blocks = num_blocks
        self.nonlocal_group = nonlocal_group
        self._drop_connect_rate = drop_connect_rate
        # Per pathway: the configured temporal kernels for the first
        # num_block_temp_kernel blocks, then 1 for every remaining block.
        self.temp_kernel_sizes = [
            (temp_kernel_sizes[i] * num_blocks[i])[: num_block_temp_kernel[i]]
            + [1] * (num_blocks[i] - num_block_temp_kernel[i])
            for i in range(len(temp_kernel_sizes))
        ]
        assert (
            len(
                {
                    len(dim_in),
                    len(dim_out),
                    len(temp_kernel_sizes),
                    len(stride),
                    len(num_blocks),
                    len(dim_inner),
                    len(num_groups),
                    len(num_block_temp_kernel),
                    len(nonlocal_inds),
                    len(nonlocal_group),
                }
            )
            == 1
        ), "All per-pathway arguments of ResStage must have the same length"
        self.num_pathways = len(self.num_blocks)
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            trans_func_name,
            stride_1x1,
            inplace_relu,
            nonlocal_inds,
            nonlocal_pool,
            instantiation,
            dilation,
            norm_module,
        )
    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        trans_func_name,
        stride_1x1,
        inplace_relu,
        nonlocal_inds,
        nonlocal_pool,
        instantiation,
        dilation,
        norm_module,
    ):
        # The transformation class is identical for every block; resolve it
        # once (this also fails fast on an unsupported name).
        trans_func = get_trans_func(trans_func_name)
        for pathway in range(self.num_pathways):
            for i in range(self.num_blocks[pathway]):
                # Construct the block. Only block 0 changes channels/stride.
                res_block = ResBlock(
                    dim_in[pathway] if i == 0 else dim_out[pathway],
                    dim_out[pathway],
                    self.temp_kernel_sizes[pathway][i],
                    stride[pathway] if i == 0 else 1,
                    trans_func,
                    dim_inner[pathway],
                    num_groups[pathway],
                    stride_1x1=stride_1x1,
                    inplace_relu=inplace_relu,
                    dilation=dilation[pathway],
                    norm_module=norm_module,
                    block_idx=i,
                    drop_connect_rate=self._drop_connect_rate,
                )
                self.add_module("pathway{}_res{}".format(pathway, i), res_block)
                if i in nonlocal_inds[pathway]:
                    nln = Nonlocal(
                        dim_out[pathway],
                        dim_out[pathway] // 2,
                        nonlocal_pool[pathway],
                        instantiation=instantiation,
                        norm_module=norm_module,
                    )
                    self.add_module(
                        "pathway{}_nonlocal{}".format(pathway, i), nln
                    )
    @staticmethod
    def _apply_nonlocal(x, nln, group):
        """
        Apply the nonlocal module `nln` to `x`, optionally on temporal chunks.
        `x` is unpacked as (b, c, t, h, w). When group > 1 the t axis is split
        into `group` chunks that are folded into the batch axis, `nln` is
        applied, and the result is unfolded back to (b, c, t, h, w). t must be
        divisible by `group`, otherwise the reshape fails.
        """
        b, c, t, h, w = x.shape
        if group > 1:
            # Fold temporal dimension into batch dimension.
            x = x.permute(0, 2, 1, 3, 4)
            x = x.reshape(b * group, t // group, c, h, w)
            x = x.permute(0, 2, 1, 3, 4)
        x = nln(x)
        if group > 1:
            # Fold back to temporal dimension.
            x = x.permute(0, 2, 1, 3, 4)
            x = x.reshape(b, t, c, h, w)
            x = x.permute(0, 2, 1, 3, 4)
        return x
    def forward(self, inputs):
        """
        Run every pathway's blocks (and nonlocal layers, if any) in order.
        Args:
            inputs (list): one tensor per pathway.
        Returns:
            list: one output tensor per pathway.
        """
        output = []
        for pathway in range(self.num_pathways):
            x = inputs[pathway]
            for i in range(self.num_blocks[pathway]):
                m = getattr(self, "pathway{}_res{}".format(pathway, i))
                x = m(x)
                nln = getattr(
                    self, "pathway{}_nonlocal{}".format(pathway, i), None
                )
                if nln is not None:
                    x = self._apply_nonlocal(
                        x, nln, self.nonlocal_group[pathway]
                    )
            output.append(x)
        return output