Skip to content

Commit

Permalink
bugfixed: add **kwargs to each roi_head to support variable argument (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sshaoshuai committed Jun 9, 2021
1 parent 686cf44 commit 26a1612
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pcdet/models/detectors/detector3d_template.py
Expand Up @@ -77,7 +77,8 @@ def build_backbone_3d(self, model_info_dict):
)
model_info_dict['module_list'].append(backbone_3d_module)
model_info_dict['num_point_features'] = backbone_3d_module.num_point_features
model_info_dict['backbone_channels'] = backbone_3d_module.backbone_channels
model_info_dict['backbone_channels'] = backbone_3d_module.backbone_channels \
if hasattr(backbone_3d_module, 'backbone_channels') else None
return backbone_3d_module, model_info_dict

def build_map_to_bev_module(self, model_info_dict):
Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/roi_heads/partA2_head.py
Expand Up @@ -8,7 +8,7 @@


class PartA2FCHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg

Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/roi_heads/pointrcnn_head.py
Expand Up @@ -8,7 +8,7 @@


class PointRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
use_bn = self.model_cfg.USE_BN
Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/roi_heads/pvrcnn_head.py
Expand Up @@ -6,7 +6,7 @@


class PVRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg

Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/roi_heads/roi_head_template.py
Expand Up @@ -9,7 +9,7 @@


class RoIHeadTemplate(nn.Module):
def __init__(self, num_class, model_cfg):
def __init__(self, num_class, model_cfg, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/roi_heads/second_head.py
Expand Up @@ -5,7 +5,7 @@


class SECONDHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg

Expand Down

0 comments on commit 26a1612

Please sign in to comment.