[Add] CenterPoint compatibility with KITTI #924

Open · wants to merge 9 commits into base: 1.0
80 changes: 80 additions & 0 deletions configs/_base_/models/centerpoint_005voxel_second_secfpn_kitti.py
@@ -0,0 +1,80 @@
voxel_size = [0.05, 0.05, 0.1]
model = dict(
    type='CenterPoint',
    pts_voxel_layer=dict(
        max_num_points=5, voxel_size=voxel_size, max_voxels=(16000, 40000)),
    pts_voxel_encoder=dict(type='HardSimpleVFE', num_features=4),
    pts_middle_encoder=dict(
        type='SparseEncoder',
        in_channels=4,
        sparse_shape=[41, 1600, 1408],
        output_channels=128,
        order=('conv', 'norm', 'act'),
        encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
                                                                      128)),
        encoder_paddings=((0, 0, 1), (0, 0, 1), (0, 0, [0, 1, 1]), (0, 0)),
        block_type='basicblock'),
    pts_backbone=dict(
        type='SECOND',
        in_channels=256,
        out_channels=[128, 256],
        layer_nums=[5, 5],
        layer_strides=[1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False)),
    pts_neck=dict(
        type='SECONDFPN',
        in_channels=[128, 256],
        out_channels=[256, 256],
        upsample_strides=[1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True),
    pts_bbox_head=dict(
        type='CenterHead',
        in_channels=sum([256, 256]),
        tasks=[
            dict(num_class=1, class_names=['Car']),
            dict(num_class=1, class_names=['Pedestrian']),
            dict(num_class=1, class_names=['Cyclist']),
        ],
        common_heads=dict(reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2)),
        share_conv_channel=64,
        bbox_coder=dict(
            type='CenterPointBBoxCoder',
            post_center_range=[-10, -50, -10, 80.4, 50, 10],
            max_num=100,
            score_threshold=0.1,
            out_size_factor=8,
            voxel_size=voxel_size[:2],
            code_size=7,
        ),
        separate_head=dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3),
        loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25),
        norm_bbox=True),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[1408, 1600, 40],
            voxel_size=voxel_size,
            out_size_factor=8,
            dense_reg=1,
            gaussian_overlap=0.1,
            max_objs=500,
            min_radius=2,
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
    test_cfg=dict(
        pts=dict(
            post_center_limit_range=[-10, -50, -10, 80.4, 50, 10],
            max_per_img=500,
            max_pool_nms=False,
            min_radius=[4, 12, 10, 1, 0.85, 0.175],
Contributor:

I think min_radius should be modified according to the classes.

@Xrenya (Jul 3, 2024):
@KickCellarDoor could you please explain how these radii were obtained? I computed per-class statistics on NuScenes and could not work out on what basis these values were selected.
For example, the first radius corresponds to the car class and equals 4 m.

The average car has the following dimensions:
vehicle.car [width=1.95447557, length=4.61892457, height=1.73131372]

            score_threshold=0.1,
            out_size_factor=8,
            voxel_size=voxel_size[:2],
            nms_type='rotate',
            pre_max_size=4096,
            post_max_size=512,
            nms_thr=0.2)))
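For reference, the grid and sparse shapes in this config follow directly from the KITTI point cloud range and the chosen voxel size. A minimal sanity check (a sketch; the range [0, -40, -3, 70.4, 40, 1] is the one set later in this PR):

# Sanity check: derive grid_size and sparse_shape from the point cloud
# range and voxel size used in this PR.
point_cloud_range = [0, -40, -3, 70.4, 40, 1]  # x/y/z min, x/y/z max
voxel_size = [0.05, 0.05, 0.1]

grid = [
    round((point_cloud_range[i + 3] - point_cloud_range[i]) / voxel_size[i])
    for i in range(3)
]
print(grid)                             # [1408, 1600, 40] -> train_cfg grid_size
print([grid[2] + 1, grid[1], grid[0]])  # [41, 1600, 1408] -> sparse_shape (z padded by 1)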
31 changes: 31 additions & 0 deletions configs/_base_/schedules/cyclic_80e.py
@@ -0,0 +1,31 @@
# This schedule is usually used by models trained on the KITTI dataset.

# The learning rate set in the cyclic schedule is the initial learning rate
# rather than the max learning rate. Since the target_ratio is (10, 1e-4),
# the learning rate will rise from 0.0018 to 0.018, then decay to 0.0018*1e-4.
lr = 0.0018
# The optimizer follows the setting in SECOND.Pytorch, but here we use
# the official AdamW optimizer implemented by PyTorch.
optimizer = dict(type='AdamW', lr=lr, betas=(0.95, 0.99), weight_decay=0.01)
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
# We use cyclic learning rate and momentum schedule following SECOND.Pytorch
# https://github.com/traveller59/second.pytorch/blob/3aba19c9688274f75ebb5e576f65cfe54773c021/torchplus/train/learning_schedules_fastai.py#L69 # noqa
# We implement them in mmcv, for more details, please refer to
# https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/lr_updater.py#L327 # noqa
# https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/momentum_updater.py#L130 # noqa
lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 1e-4),
    cyclic_times=1,
    step_ratio_up=0.4,
)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.85 / 0.95, 1),
    cyclic_times=1,
    step_ratio_up=0.4,
)
# Although the max_epochs is 40, this schedule is usually used with
Member:
max_epochs is 80

Member:
I just wonder whether this schedule is necessary to reproduce the CenterPoint results. We could also use cyclic_40e and repeat the dataset twice during training (although that can introduce a small difference in the effective learning rate schedule)?

# RepeatDataset with repeat ratio N, thus the actual max epoch
# number could be Nx40
runner = dict(type='EpochBasedRunner', max_epochs=80)
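A rough sketch of what this schedule yields, using only values from the config above (mmcv's cyclic updater anneals between these targets, so this shows just the envelope, not the exact curve):

# Envelope of the one-cycle schedule defined above (a sketch).
lr = 0.0018
target_ratio = (10, 1e-4)
step_ratio_up = 0.4  # fraction of the cycle spent ramping up

peak_lr = lr * target_ratio[0]   # 0.018, reached after 40% of training
final_lr = lr * target_ratio[1]  # 1.8e-07, approached by the end
print(peak_lr, final_lr)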
@@ -0,0 +1,17 @@
_base_ = [
    '../_base_/datasets/kitti-3d-3class.py',
    '../_base_/models/centerpoint_005voxel_second_secfpn_kitti.py',
    '../_base_/schedules/cyclic_80e.py', '../_base_/default_runtime.py'
]

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [0, -40, -3, 70.4, 40, 1]

# Add 'point_cloud_range' into model config according to dataset
model = dict(
    pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
    pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])),
    # model training and testing settings
    train_cfg=dict(pts=dict(point_cloud_range=point_cloud_range)),
    test_cfg=dict(pts=dict(pc_range=point_cloud_range[:2])))
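As a quick check that the _base_ merge propagates the range into the model section, one can load the merged config with mmcv (a sketch; the config path below is hypothetical, since the file name is not shown in this diff):

# Sketch: load the merged config and confirm the point cloud range
# reached the model section. The path is hypothetical.
from mmcv import Config

cfg = Config.fromfile('configs/centerpoint/<this_config>.py')  # hypothetical path
print(cfg.model.pts_voxel_layer.point_cloud_range)  # [0, -40, -3, 70.4, 40, 1]
print(cfg.model.pts_bbox_head.bbox_coder.pc_range)  # [0, -40]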
44 changes: 30 additions & 14 deletions mmdet3d/models/dense_heads/centerpoint_head.py
@@ -455,6 +455,7 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
        grid_size = torch.tensor(self.train_cfg['grid_size'])
        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
+        gt_annotation_num = len(self.train_cfg['code_weights'])

        feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']

@@ -489,7 +490,7 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
                (len(self.class_names[idx]), feature_map_size[1],
                 feature_map_size[0]))

-            anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
+            anno_box = gt_bboxes_3d.new_zeros((max_objs, gt_annotation_num),
                                              dtype=torch.float32)

            ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
@@ -546,20 +547,29 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):

                ind[new_idx] = y * feature_map_size[0] + x
                mask[new_idx] = 1
-                # TODO: support other outdoor dataset
-                vx, vy = task_boxes[idx][k][7:]
+
                rot = task_boxes[idx][k][6]
                box_dim = task_boxes[idx][k][3:6]
                if self.norm_bbox:
                    box_dim = box_dim.log()
-                anno_box[new_idx] = torch.cat([
+
+                anno_elems = [
                    center - torch.tensor([x, y], device=device),
                    z.unsqueeze(0), box_dim,
                    torch.sin(rot).unsqueeze(0),
-                    torch.cos(rot).unsqueeze(0),
-                    vx.unsqueeze(0),
-                    vy.unsqueeze(0)
-                ])
+                    torch.cos(rot).unsqueeze(0)
+                ]
+                # Datasets whose bbox annotations have 9 values are assumed
+                # to carry two velocity components (vx, vy) in addition to
+                # the standard KITTI-like 7 values (x, y, z, w, l, h, rot).
+                # NOTE: rotation is split into two (sin, cos) components,
+                # so each target row is one element wider than the raw box.
+                if gt_annotation_num == 10:
+                    vx, vy = task_boxes[idx][k][7:10]
+                    anno_elems += [vx.unsqueeze(0), vy.unsqueeze(0)]
+
+                anno_box[new_idx] = torch.cat(anno_elems)

            heatmaps.append(heatmap)
            anno_boxes.append(anno_box)
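To make the shape bookkeeping concrete, here is a minimal illustration (a sketch, not PR code) of the target row this change produces. With the KITTI config above, code_weights has 8 entries, so gt_annotation_num == 8:

import torch

gt_annotation_num = 8  # len(code_weights) in the KITTI config above
box = torch.tensor([10.0, 2.0, -1.0, 1.9, 4.6, 1.7, 0.3])  # x, y, z, w, l, h, yaw

elems = [
    box[:2],              # stands in for the (dx, dy) center offset
    box[2:3],             # z
    box[3:6].log(),       # normalized dims (norm_bbox=True)
    torch.sin(box[6:7]),  # yaw split into...
    torch.cos(box[6:7]),  # ...two components
]
if gt_annotation_num == 10:  # only for 9-value (nuScenes-style) boxes
    elems.append(box[7:9])   # (vx, vy)

assert torch.cat(elems).numel() == gt_annotation_num  # 8 for KITTI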
@@ -587,17 +597,23 @@ def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
            # heatmap focal loss
            preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
            num_pos = heatmaps[task_id].eq(1).float().sum().item()
+
            loss_heatmap = self.loss_cls(
                preds_dict[0]['heatmap'],
                heatmaps[task_id],
                avg_factor=max(num_pos, 1))
            target_box = anno_boxes[task_id]
-            # reconstruct the anno_box from multiple reg heads
-            preds_dict[0]['anno_box'] = torch.cat(
-                (preds_dict[0]['reg'], preds_dict[0]['height'],
-                 preds_dict[0]['dim'], preds_dict[0]['rot'],
-                 preds_dict[0]['vel']),
-                dim=1)
+            # Reconstruct the anno_box from multiple reg heads. These keys
+            # are assumed to exist for annotations with the standard
+            # KITTI-like 7 values.
+            anno_box = [
+                preds_dict[0]['reg'], preds_dict[0]['height'],
+                preds_dict[0]['dim'], preds_dict[0]['rot']
+            ]
+            # The 'vel' key is assumed to exist only for bbox annotations
+            # with 9 values.
+            if 'vel' in preds_dict[0]:
+                anno_box.append(preds_dict[0]['vel'])
+            preds_dict[0]['anno_box'] = torch.cat(anno_box, dim=1)

            # Regression loss for dimension, offset, height, rotation
            ind = inds[task_id]
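The prediction side mirrors that target layout. A sketch of the channel count this concatenation yields for the KITTI head (the 200x176 feature map follows from grid_size [1408, 1600] and out_size_factor 8):

import torch

# Fake per-task head outputs with the channel sizes from common_heads.
preds = {
    'reg': torch.zeros(1, 2, 200, 176),
    'height': torch.zeros(1, 1, 200, 176),
    'dim': torch.zeros(1, 3, 200, 176),
    'rot': torch.zeros(1, 2, 200, 176),
}
anno_box = [preds['reg'], preds['height'], preds['dim'], preds['rot']]
if 'vel' in preds:  # absent for KITTI, present for nuScenes-style heads
    anno_box.append(preds['vel'])
print(torch.cat(anno_box, dim=1).shape)  # torch.Size([1, 8, 200, 176])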
4 changes: 4 additions & 0 deletions mmdet3d/models/detectors/centerpoint.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
+from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
@@ -45,6 +46,9 @@ def extract_pts_feat(self, pts, img_feats, img_metas):
        x = self.pts_backbone(x)
        if self.with_pts_neck:
            x = self.pts_neck(x)
+        # Upsample output feature map spatial dimension to match target
Member:

Is this condition generalizable? I think a direct comparison between the output feature map size and the target size would be better.

+        if self.train_cfg['pts']['out_size_factor'] == 4:
@KickCellarDoor (Oct 26, 2021):

self.train_cfg['pts']['out_size_factor'] == 4 is incompatible with the default test script, since cfg.model.train_cfg is set to None. It's better to try self.test_cfg if self.train_cfg is not available.

Contributor:
It's better to try self.test_cfg when self.train_cfg is not available.

+            x[0] = F.interpolate(x[0], scale_factor=2, mode='bilinear')
        return x

    def forward_pts_train(self,
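Putting the two review suggestions together, a hypothetical revision of the upsampling block might look like the following (a sketch of the reviewers' proposal, not part of this PR):

# Sketch of the suggested fix: fall back to test_cfg when train_cfg is
# None (the default test script sets cfg.model.train_cfg to None).
cfg = self.train_cfg if self.train_cfg is not None else self.test_cfg
if cfg['pts']['out_size_factor'] == 4:
    x[0] = F.interpolate(x[0], scale_factor=2, mode='bilinear')
return x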