Skip to content

Commit

Permalink
fix mmdet3d
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed Sep 12, 2023
1 parent 985a4f3 commit b291a1c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet3d/deploy/mono_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def create_input(

if data_preprocessor is not None:
collate_data = data_preprocessor(collate_data, False)
inputs = collate_data['inputs']
inputs = collate_data['inputs']['imgs']
else:
inputs = collate_data['inputs']
return collate_data, inputs
Expand Down
10 changes: 6 additions & 4 deletions mmdeploy/codebase/mmdet3d/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@
'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501
)
def basedetector__forward(self,
inputs: list,
voxels: torch.Tensor,
num_points: torch.Tensor,
coors: torch.Tensor,
data_samples=None,
**kwargs) -> Tuple[List[torch.Tensor]]:
"""Extract features of images."""

batch_inputs_dict = {
'voxels': {
'voxels': inputs[0],
'num_points': inputs[1],
'coors': inputs[2]
'voxels': voxels,
'num_points': num_points,
'coors': coors
}
}
return self._forward(batch_inputs_dict, data_samples, **kwargs)
11 changes: 7 additions & 4 deletions mmdeploy/codebase/mmdet3d/models/single_stage_mono3d.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.single_stage_mono3d.'
'SingleStageMono3DDetector.forward')
def singlestagemono3ddetector__forward(self, inputs: list, **kwargs):
"""Rewrite this func to r.
def singlestagemono3ddetector__forward(self, inputs: Tensor, **kwargs):
"""Rewrite to support feed inputs of Tensor type.
Args:
inputs (dict): Input dict comprises `imgs`
inputs (Tensor): Input image
Returns:
list: two torch.Tensor
"""
x = self.extract_feat(inputs)

x = self.extract_feat({'imgs': inputs})
results = self.bbox_head.forward(x)
return results[0], results[1]
2 changes: 1 addition & 1 deletion tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_pointpillars(backend_type: Backend):
cfg=deploy_cfg,
backend=deploy_cfg.backend_config.type,
opset=deploy_cfg.onnx_config.opset_version):
outputs = model.forward(data)
outputs = model.forward(*data)
assert len(outputs) == 3


Expand Down

0 comments on commit b291a1c

Please sign in to comment.