multi_backbone.py
import copy
import torch
import warnings
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn

from mmdet.models import BACKBONES, build_backbone

@BACKBONES.register_module()
class MultiBackbone(BaseModule):
    """MultiBackbone with different configs.

    Args:
        num_streams (int): The number of backbones.
        backbones (list or dict): A list of backbone configs. If a single
            dict is given, it is deep-copied ``num_streams`` times.
        aggregation_mlp_channels (list[int]): Specify the mlp layers
            for feature aggregation.
        conv_cfg (dict): Config dict of convolutional layers.
        norm_cfg (dict): Config dict of normalization layers.
        act_cfg (dict): Config dict of activation layers.
        suffixes (list): A list of suffixes to rename the return dict
            for each backbone.
    """

    def __init__(self,
                 num_streams,
                 backbones,
                 aggregation_mlp_channels=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg=dict(type='ReLU'),
                 suffixes=('net0', 'net1'),
                 init_cfg=None,
                 pretrained=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(backbones, (dict, list))
        if isinstance(backbones, dict):
            backbones_list = []
            for ind in range(num_streams):
                backbones_list.append(copy.deepcopy(backbones))
            backbones = backbones_list

        assert len(backbones) == num_streams
        assert len(suffixes) == num_streams

        self.backbone_list = nn.ModuleList()
        # Rename the ret_dict with different suffixes.
        self.suffixes = suffixes

        out_channels = 0
        for backbone_cfg in backbones:
            out_channels += backbone_cfg['fp_channels'][-1][-1]
            self.backbone_list.append(build_backbone(backbone_cfg))

        # Feature aggregation layers
        if aggregation_mlp_channels is None:
            aggregation_mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            aggregation_mlp_channels.insert(0, out_channels)
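
        # Worked example of the default schedule (an illustration, not from
        # the original source): with two streams whose final FP layers each
        # output 256 channels, out_channels = 512 and the default
        # aggregation MLP is [512, 256, 256], i.e. concatenate, halve, then
        # reduce back to a single stream's width.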
        self.aggregation_layers = nn.Sequential()
        for i in range(len(aggregation_mlp_channels) - 1):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(
                    aggregation_mlp_channels[i],
                    aggregation_mlp_channels[i + 1],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))
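
        # Each ConvModule above uses a kernel size of 1, so the aggregation
        # stack is effectively a shared per-point MLP (Conv1d + BN1d + ReLU)
        # applied along the channel dimension of the concatenated features.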

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    @auto_fp16()
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs from multiple backbones.

                - fp_xyz[suffix] (list[torch.Tensor]): The coordinates of
                  each FP feature level.
                - fp_features[suffix] (list[torch.Tensor]): The features
                  from each feature propagation (FP) layer.
                - fp_indices[suffix] (list[torch.Tensor]): Indices of the
                  input points.
                - hd_feature (torch.Tensor): The aggregated feature
                  from multiple backbones.
        """
        ret = {}
        fp_features = []
        for ind in range(len(self.backbone_list)):
            cur_ret = self.backbone_list[ind](points)
            cur_suffix = self.suffixes[ind]
            fp_features.append(cur_ret['fp_features'][-1])
            if cur_suffix != '':
                # Iterate over a snapshot of the keys: popping and inserting
                # while iterating over the live dict view is unsafe.
                for k in list(cur_ret.keys()):
                    cur_ret[k + '_' + cur_suffix] = cur_ret.pop(k)
            ret.update(cur_ret)

        # Combine the per-stream features along the channel dimension.
        hd_feature = torch.cat(fp_features, dim=1)
        hd_feature = self.aggregation_layers(hd_feature)
        ret['hd_feature'] = hd_feature
        return ret
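
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds a two-stream
# MultiBackbone from a single shared backbone config. The PointNet2SASSG
# settings below are illustrative assumptions loosely modelled on typical
# mmdet3d indoor-detection configs; running it requires mmdet3d installed
# (so PointNet2SASSG is registered) and a CUDA build of mmcv, since the
# PointNet++ sampling/grouping ops are GPU-only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import mmdet3d.models  # noqa: F401  (registers PointNet2SASSG)

    cfg = dict(
        type='MultiBackbone',
        num_streams=2,
        suffixes=('net0', 'net1'),
        backbones=dict(
            type='PointNet2SASSG',
            in_channels=4,
            num_points=(2048, 1024, 512, 256),
            radius=(0.2, 0.4, 0.8, 1.2),
            num_samples=(64, 32, 16, 16),
            sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                         (128, 128, 256)),
            fp_channels=((256, 256), (256, 256)),
            norm_cfg=dict(type='BN2d')))

    model = build_backbone(cfg).cuda()
    points = torch.rand(2, 2048, 4).cuda()  # (B, N, 3 + 1 feature channel)
    out = model(points)
    # Expect keys like 'fp_features_net0', 'fp_features_net1', 'hd_feature'.
    print(out['hd_feature'].shape)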