Skip to content

Commit

Permalink
[Feature] Add SMOKE Head and Coder(v1.0.0.dev0) (#959)
Browse files Browse the repository at this point in the history
* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import (#839)

* [Enhance]  refactor  iou_neg_piecewise_sampler.py (#842)

* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* refactor iou_neg_piecewise_sampler.py

* add docstring

* modify docstring

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>

* [Feature] Add roipooling cuda ops (#843)

* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* add roipooling cuda ops

* add roi extractor

* add test_roi_extractor unittest

* Modify setup.py to install roipooling ops

* modify docstring

* remove enlarge bbox in roipoint pooling

* add_roipooling_ops

* modify docstring

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>

* [Refactor] Refactor code structure and docstrings (#803)

* refactor points_in_boxes

* Merge same functions of three boxes

* More docstring fixes and unify x/y/z size

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Remove None in function param type

* Fix unittest

* Add comments for NMS functions

* Merge methods of Points

* Add unittest

* Add optional and default value

* Fix box conversion and add unittest

* Fix comments

* Add unit test

* Indent

* Fix CI

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Add unit test for box bev

* More unit tests and refine docstrings in box_np_ops

* Fix comment

* Add deprecation warning

* [Feature] PointXYZWHLRBBoxCoder (#856)

* support PointBasedBoxCoder

* fix unittest bug

* support unittest in gpu

* support unittest in gpu

* modified docstring

* add args

* add args

* [Enhance] Change Groupfree3D config (#855)

* All mods

* PointSample

* PointSample

* [Doc] Add tutorials/data_pipeline Chinese version (#827)

* [Doc] Add tutorials/data_pipeline Chinese version

* refine doc

* Use the absolute link

* Use the absolute link

Co-authored-by: Tai-Wang <tab_wang@outlook.com>

* [Doc] Add Chinese doc for `scannet_det.md` (#836)

* Part

* Complete

* Fix comments

* Fix comments

* [Doc] Add Chinese doc for `waymo_det.md` (#859)

* Add complete translation

* Refinements

* Fix comments

* Fix a minor typo

Co-authored-by: Tai-Wang <tab_wang@outlook.com>

* Remove 2D annotations on Lyft (#867)

* Add header for files (#869)

* Add header for files

* Add header for files

* Add header for files

* Add header for files

* [fix] fix typos (#872)

* Fix 3 unworking configs (#882)

* [Fix] Fix `index.rst` for Chinese docs (#873)

* Fix index.rst for zh docs

* Change switch language

* [Fix] Centerpoint head nested list transpose  (#879)

* FIX Transpose nested lists without Numpy

* Removed unused Numpy import

* [Enhance] Update PointFusion (#791)

* update point fusion

* remove LIDAR hardcode

* move get_proj_mat_by_coord_type to utils

* fix lint

* remove todo

* fix lint

* [Doc] Add nuscenes_det.md Chinese version (#854)

* add nus chinese doc

* add nuScenes Chinese doc

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [Fix] Fix RegNet pretrained weight loading (#889)

* Fix regnet pretrained weight loading

* Remove unused file

* Fix centerpoint tta (#892)

* [Enhance] Add benchmark regression script (#808)

* Initial commit

* [Feature] Support DGCNN (v1.0.0.dev0) (#896)

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* fix typo

* fix typo

* fix typo

* del gf&fa registry (wo reuse pointnet module)

* fix typo

* add benchmark and add copyright header (for DGCNN only)

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* support dgcnn

* Change cam rot_3d_in_axis (#906)

* [Doc] Add coord sys tutorial pic and change links to dev branch (#912)

* Modify link branch and add pic

* Fix pic

* [Feature] add kitti AP40 evaluation metric (v1.0.0.dev0) (#927)

* Add citation (#901)

* [Feature] Add python3.9 in CI (#900)

* Add python3.0 in CI

* Add python3.0 in CI

* Bump to v0.17.0 (#898)

* Update README.md

* Update README_zh-CN.md

* Update version.py

* Update getting_started.md

* Update getting_started.md

* Update changelog.md

* Remove "recent" in the news

* Remove "recent" in the news

* Fix comments

* [Docs] Fix the version of sphinx (#902)

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* add AP40

* add unitest

* add unitest

* seperate AP11 and AP40

* fix some typos

Co-authored-by: dingchang <hudingchang.vendor@sensetime.com>
Co-authored-by: Tai-Wang <tab_wang@outlook.com>

* [Feature] add smoke backbone neck (#939)

* add smoke detecotor and it's backbone and neck

* typo fix

* fix typo

* add docstring

* fix typo

* fix comments

* fix comments

* fix comments

* fix typo

* fix typo

* fix

* fix typo

* fix docstring

* refine feature

* fix typo

* use Basemodule in Neck

* [Refactor] Refactor the transformation from image to camera coordinates (#938)

* Refactor points_img2cam

* Refine docstring

* Support array converter and add unit tests

* [Feature] FCOS3D BBox Coder (#940)

* FCOS3D BBox Coder

* Add unit tests

* Change the value from long to float/double

* Rename bbox_out as bbox

* Add comments to forward returns

* [Feature] PGD BBox Coder (#948)

* Support PGD BBox Coder

* Refine docstring

* [Feature] Support Uncertain L1 Loss (#950)

* Add uncertain l1 loss and its unit tests

* Remove mmcv.jit and refine docstrings

* [Fix] Fix visualization in KITTI dataset (#956)

* fix bug to support kitti vis

* fix

* add smoke head

* fix typo

* fix code

* fix

* fix typo

* add detector

* delete dectector

* fix test_heads

* add coder test

* fix head

* fix bugs in smoke head

* refine kitti_mono3d data config

* remove cuda is available

* fix docstring

* fix unitest

* fix typo

* refine

* add affine aug

* fix rebase typo

* fix docs

* change cam_intrinsics to cam2imgs

* fix typos

* fix lint

* fix bugs

* fix las typos

* fix typos

* add dosctrings for trans_mats

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: Xi Liu <75658786+xiliu8006@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>
Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com>
Co-authored-by: Tai-Wang <tab_wang@outlook.com>
Co-authored-by: dingchang <hudingchang.vendor@sensetime.com>
Co-authored-by: 谢恩泽 <Johnny_ez@163.com>
Co-authored-by: Robin Karlsson <34254153+robin-karlsson0@users.noreply.github.com>
Co-authored-by: Danila Rukhovich <danrukh@gmail.com>
  • Loading branch information
10 people committed Sep 29, 2021
1 parent e93a77f commit f268ba4
Show file tree
Hide file tree
Showing 6 changed files with 855 additions and 3 deletions.
3 changes: 2 additions & 1 deletion mmdet3d/core/bbox/coders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
from .pgd_bbox_coder import PGDBBoxCoder
from .point_xyzwhlr_bbox_coder import PointXYZWHLRBBoxCoder
from .smoke_bbox_coder import SMOKECoder

__all__ = [
'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder', 'GroupFree3DBBoxCoder',
'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder', 'PGDBBoxCoder'
'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder', 'PGDBBoxCoder', 'SMOKECoder'
]
206 changes: 206 additions & 0 deletions mmdet3d/core/bbox/coders/smoke_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import numpy as np
import torch

from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS


@BBOX_CODERS.register_module()
class SMOKECoder(BaseBBoxCoder):
"""Bbox Coder for SMOKE.
Args:
base_depth (tuple[float]): Depth references for decode box depth.
base_dims (tuple[tuple[float]]): Dimension references [l, h, w]
for decode box dimension for each category.
code_size (int): The dimension of boxes to be encoded.
"""

def __init__(self, base_depth, base_dims, code_size):
super(SMOKECoder, self).__init__()
self.base_depth = base_depth
self.base_dims = base_dims
self.bbox_code_size = code_size

def encode(self, locations, dimensions, orientations, input_metas):
"""Encode CameraInstance3DBoxes by locations, dimemsions, orientations.
Args:
locations (Tensor): Center location for 3D boxes.
(N, 3)
dimensions (Tensor): Dimensions for 3D boxes.
shape (N, 3)
orientations (Tensor): Orientations for 3D boxes.
shape (N, 1)
input_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Return:
:obj:`CameraInstance3DBoxes`: 3D bboxes of batch images,
shape (N, bbox_code_size).
"""

bboxes = torch.cat((locations, dimensions, orientations), dim=1)
assert bboxes.shape[1] == self.bbox_code_size, 'bboxes shape dose not'\
'match the bbox_code_size.'
batch_bboxes = input_metas[0]['box_type_3d'](
bboxes, box_dim=self.bbox_code_size)

return batch_bboxes

def decode(self,
reg,
points,
labels,
cam2imgs,
trans_mats,
locations=None):
"""Decode regression into locations, dimemsions, orientations.
Args:
reg (Tensor): Batch regression for each predict center2d point.
shape: (batch * K (max_objs), C)
points(Tensor): Batch projected bbox centers on image plane.
shape: (batch * K (max_objs) , 2)
labels (Tensor): Batch predict class label for each predict
center2d point.
shape: (batch, K (max_objs))
cam2imgs (Tensor): Batch images' camera intrinsic matrix.
shape: (batch, 4, 4)
trans_mats (Tensor): transformation matrix from original image
to feature map.
shape: (batch, 3, 3)
locations (None | Tensor): if locations is None, this function
is used to decode while inference, otherwise, it's used while
training using the ground truth 3d bbox locations.
shape: (batch * K (max_objs), 3)
Return:
tuple(Tensor): The tuple has components below:
- locations (Tensor): Centers of 3D boxes.
shape: (batch * K (max_objs), 3)
- dimensions (Tensor): Dimensions of 3D boxes.
shpae: (batch * K (max_objs), 3)
- orientations (Tensor): Orientations of 3D
boxes.
shape: (batch * K (max_objs), 1)
"""
depth_offsets = reg[:, 0]
centers2d_offsets = reg[:, 1:3]
dimensions_offsets = reg[:, 3:6]
orientations = reg[:, 6:8]
depths = self._decode_depth(depth_offsets)
# get the 3D Bounding box's center location.
pred_locations = self._decode_location(points, centers2d_offsets,
depths, cam2imgs, trans_mats)
pred_dimensions = self._decode_dimension(labels, dimensions_offsets)
if locations is None:
pred_orientations = self._decode_orientation(
orientations, pred_locations)
else:
pred_orientations = self._decode_orientation(
orientations, locations)

return pred_locations, pred_dimensions, pred_orientations

def _decode_depth(self, depth_offsets):
"""Transform depth offset to depth."""
base_depth = depth_offsets.new_tensor(self.base_depth)
depths = depth_offsets * base_depth[1] + base_depth[0]

return depths

def _decode_location(self, points, centers2d_offsets, depths, cam2imgs,
trans_mats):
"""Retrieve objects location in camera coordinate based on projected
points.
Args:
points (Tensor): Projected points on feature map in (x, y)
shape: (batch * K, 2)
centers2d_offset (Tensor): Project points offset in
(delta_x, delta_y). shape: (batch * K, 2)
depths (Tensor): Object depth z.
shape: (batch * K)
cam2imgs (Tensor): Batch camera intrinsics matrix.
shape: (batch, 4, 4)
trans_mats (Tensor): transformation matrix from original image
to feature map.
shape: (batch, 3, 3)
"""
# number of points
N = centers2d_offsets.shape[0]
# batch_size
N_batch = cam2imgs.shape[0]
batch_id = torch.arange(N_batch).unsqueeze(1)
obj_id = batch_id.repeat(1, N // N_batch).flatten()
trans_mats_inv = trans_mats.inverse()[obj_id]
cam2imgs_inv = cam2imgs.inverse()[obj_id]
centers2d = points + centers2d_offsets
centers2d_extend = torch.cat((centers2d, centers2d.new_ones(N, 1)),
dim=1)
# expand project points as [N, 3, 1]
centers2d_extend = centers2d_extend.unsqueeze(-1)
# transform project points back on original image
centers2d_img = torch.matmul(trans_mats_inv, centers2d_extend)
centers2d_img = centers2d_img * depths.view(N, -1, 1)
centers2d_img_extend = torch.cat(
(centers2d_img, centers2d.new_ones(N, 1, 1)), dim=1)
locations = torch.matmul(cam2imgs_inv, centers2d_img_extend).squeeze(2)

return locations[:, :3]

def _decode_dimension(self, labels, dims_offset):
"""Transform dimension offsets to dimension according to its category.
Args:
labels (Tensor): Each points' category id.
shape: (N, K)
dims_offset (Tensor): Dimension offsets.
shape: (N, 3)
"""
labels = labels.flatten().long()
base_dims = dims_offset.new_tensor(self.base_dims)
dims_select = base_dims[labels, :]
dimensions = dims_offset.exp() * dims_select

return dimensions

def _decode_orientation(self, ori_vector, locations):
"""Retrieve object orientation.
Args:
ori_vector (Tensor): Local orientation in [sin, cos] format.
shape: (N, 2)
locations (Tensor): Object location.
shape: (N, 3)
Return:
Tensor: yaw(Orientation). Notice that the yaw's
range is [-np.pi, np.pi].
shape:(N, 1)
"""
assert len(ori_vector) == len(locations)
locations = locations.view(-1, 3)
rays = torch.atan(locations[:, 0] / (locations[:, 2] + 1e-7))
alphas = torch.atan(ori_vector[:, 0] / (ori_vector[:, 1] + 1e-7))

# get cosine value positive and negtive index.
cos_pos_inds = (ori_vector[:, 1] >= 0).nonzero()
cos_neg_inds = (ori_vector[:, 1] < 0).nonzero()

alphas[cos_pos_inds] -= np.pi / 2
alphas[cos_neg_inds] += np.pi / 2
# retrieve object rotation y angle.
yaws = alphas + rays

larger_inds = (yaws > np.pi).nonzero()
small_inds = (yaws < -np.pi).nonzero()

if len(larger_inds) != 0:
yaws[larger_inds] -= 2 * np.pi
if len(small_inds) != 0:
yaws[small_inds] += 2 * np.pi

yaws = yaws.unsqueeze(-1)
return yaws
3 changes: 2 additions & 1 deletion mmdet3d/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from .groupfree3d_head import GroupFree3DHead
from .parta2_rpn_head import PartA2RPNHead
from .shape_aware_head import ShapeAwareHead
from .smoke_mono3d_head import SMOKEMono3DHead
from .ssd_3d_head import SSD3DHead
from .vote_head import VoteHead

__all__ = [
'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
'GroupFree3DHead'
'GroupFree3DHead', 'SMOKEMono3DHead'
]
Loading

0 comments on commit f268ba4

Please sign in to comment.