
Try to train CenterPoint on kitti #486

Closed
guojingming opened this issue Apr 26, 2021 · 12 comments

Comments

@guojingming

I tried to train CenterPoint on KITTI, but when training finished, the validation results were far from expected.
I used PillarFeatureNet as the feature extractor, SECOND as the backbone, and SECOND FPN as the neck, trained on 3-class KITTI with a cyclic LR schedule for 80 epochs, and evaluated with strict 3D IoU.
The mAP I got is 69.17 for Car, 30.68 for Pedestrian, and 53.53 for Cyclist.

Are there problems with my config file?

@guojingming
Author

guojingming commented Apr 26, 2021

_base_ = [
    '../_base_/datasets/kitti-3class.py',
    '../_base_/models/centerpoint_02pillar_second_secfpn_kitti.py',
    '../_base_/schedules/cyclic_40e.py',
    '../_base_/default_runtime.py'
]

point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car']
db_sampler = dict(
    data_root=data_root,
    info_path=data_root + 'kitti_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
    classes=class_names,
    sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10))

model = dict(
    pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
    pts_voxel_encoder=dict(point_cloud_range=point_cloud_range),
    pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])),
    # model training and testing settings
    train_cfg=dict(pts=dict(point_cloud_range=point_cloud_range)),
    test_cfg=dict(pts=dict(pc_range=point_cloud_range[:2])))

train_pipeline = [
    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(type='ObjectSample', db_sampler=db_sampler),
    dict(
        type='ObjectNoise',
        num_try=100,
        translation_std=[0.25, 0.25, 0.25],
        global_rot_range=[0.0, 0.0],
        rot_range=[-0.15707963267, 0.15707963267]),
    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.95, 1.05]),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='PointShuffle'),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False),
            dict(type='Collect3D', keys=['points'])
        ])
]

data = dict(
    train=dict(dataset=dict(pipeline=train_pipeline, classes=class_names)),
    val=dict(pipeline=test_pipeline, classes=class_names),
    test=dict(pipeline=test_pipeline, classes=class_names))

lr = 0.001
optimizer = dict(lr=lr)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
evaluation = dict(interval=5)
total_epochs = 80

@guojingming
Author

The 'centerpoint_02pillar_second_secfpn_kitti.py' referenced in my config is modified from centerpoint_02pillar_second_secfpn_nus.py; I only changed some parameters (such as the class names) to adapt it to KITTI, and I didn't change the architecture.
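
For reference, here is a minimal sketch (a hypothetical example, not my exact file) of the kind of head change such an adaptation typically involves: the nuScenes base config splits the CenterHead into several task groups, while for KITTI a single task covering the three classes could look like this.

model = dict(
    pts_bbox_head=dict(
        tasks=[
            # one task head for the three KITTI classes instead of the
            # several nuScenes task groups
            dict(num_class=3, class_names=['Pedestrian', 'Cyclist', 'Car'])
        ]))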

@ZhangYu1ing

Hi, we are also working on training CenterPoint on KITTI.
May I ask what voxel_size you set? We hit the error shown below when training CenterPoint. Did you meet the same problem before?
Thanks in advance.

2021-04-27 13:54:02,527 - mmdet - INFO - Start running, host: root@b181bde9e289, work_dir: /content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/work_dirs/centerpoint_nusc2kitti_02pillar
2021-04-27 13:54:02,527 - mmdet - INFO - workflow: [('train', 1)], max: 20 epochs
2021-04-27 13:54:02.709169: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
max_voxels: 30000
max_voxels: 30000
voxel_size: [0.2, 0.2, 8]
coors_range: [0, -39.68, -3, 69.12, 39.68, 1]
voxels_size: torch.Size([30000, 20, 4])
coors: torch.Size([30000, 3])
num_points_per_voxel: torch.Size([30000])
max_points: 20
Traceback (most recent call last):
File "tools/train.py", line 212, in
main()
File "tools/train.py", line 208, in main
meta=meta)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection/mmdet/apis/train.py", line 170, in train_detector
runner.run(data_loaders, cfg.workflow)
File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
epoch_runner(data_loaders[i], **kwargs)
File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
self.run_iter(data_batch, train_mode=True, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
**kwargs)
File "/usr/local/lib/python3.7/dist-packages/mmcv/parallel/data_parallel.py", line 67, in train_step
return self.module.train_step(*inputs[0], **kwargs[0])
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection/mmdet/models/detectors/base.py", line 247, in train_step
losses = self(**data)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/fp16_utils.py", line 84, in new_func
return old_func(*args, **kwargs)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/models/detectors/base.py", line 58, in forward
return self.forward_train(**kwargs)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 273, in forward_train
points, img=img, img_metas=img_metas)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 207, in extract_feat
pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/models/detectors/centerpoint.py", line 38, in extract_pts_feat
voxels, num_points, coors = self.voxelize(pts)
File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/fp16_utils.py", line 164, in new_func
return old_func(*args, **kwargs)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 224, in voxelize
res_voxels, res_coors, res_num_points = self.pts_voxel_layer(res)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in call_impl
result = self.forward(*input, **kwargs)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/ops/voxel/voxelize.py", line 123, in forward
self.max_num_points, max_voxels)
File "/content/gdrive/My Drive/Colab_Notebooks/mmdet3d-colab/mmdetection3d/mmdet3d/ops/voxel/voxelize.py", line 63, in forward
voxel_num = hard_voxelize(points, voxels, coors, num_points_per_voxel, voxel_size,coors_range, max_points, max_voxels, 3)
RuntimeError: CUDA error: no kernel image is available for execution on the device

@guojingming
Author

The voxel size in my config file is [0.16, 0.16, 4], the point cloud range is [0, -40.96, -1, 81.92, 40.96, 3], and the pillar feature net output / grid size is [512, 512]. I'm sorry, I didn't meet this error; it may be caused by the PyTorch version, according to https://blog.csdn.net/yyhaohaoxuexi/article/details/107460836 @ZhangYu1ing
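
As a quick sanity check (using the values above), the BEV grid size follows directly from the point cloud range and the voxel size:

point_cloud_range = [0, -40.96, -1, 81.92, 40.96, 3]
voxel_size = [0.16, 0.16, 4]

# (81.92 - 0) / 0.16 = 512 and (40.96 - (-40.96)) / 0.16 = 512,
# so the pillar feature map is 512 x 512 with a single voxel along z.
grid_x = round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size[0])  # 512
grid_y = round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size[1])  # 512
grid_z = round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size[2])  # 1
print(grid_x, grid_y, grid_z)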

@tianweiy
Contributor

One implementation is here: https://github.com/tianweiy/CenterPoint-KITTI

I played with mmdet3d in the past, but I don't have a clean codebase to release at the moment. Additionally, PointPillars just doesn't work very well on KITTI, so you may consider playing with the VoxelNet backbone.

@tianweiy
Contributor

RuntimeError: CUDA error: no kernel image is available for execution on the device

That means that the hard_voxelize code is not compiled correctly for your GPU.
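
A quick way to confirm this (a diagnostic sketch, not part of mmdet3d) is to compare the GPU's compute capability with the architectures the installed PyTorch binaries and extensions were built for:

import torch

print(torch.version.cuda)                   # CUDA version PyTorch was built against
print(torch.cuda.get_device_capability(0))  # compute capability of your GPU, e.g. (7, 5)
print(torch.cuda.get_arch_list())           # compiled architectures (recent PyTorch), e.g. ['sm_60', 'sm_70', 'sm_75']

# If your device's sm_XX is missing from the arch list, the CUDA ops (including
# hard_voxelize) need to be rebuilt for that GPU, e.g. by reinstalling mmdet3d
# from source with TORCH_CUDA_ARCH_LIST set to match your device.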

@guojingming
Author

Thanks very much for the reply and for sharing the code! @tianweiy Another question: roughly how large is the mAP gap between PointPillars and VoxelNet on KITTI according to your experiments?

@tianweiy
Contributor

I only played with PointPillars in the Det3D repo. I was not able to get higher than 75 mAP. I will add a model zoo to the repo soon.

@ZhangYu1ing

@guojingming @Tai-Wang Thanks for the advice. We noticed that the problem only occurs in Google Colab; it works on a PC.
Thanks a lot for sharing.

@ZhangYu1ing

Hi, I am trying to run the code you provided. I created the KITTI data following the instructions in OpenPCDet. However, there is a bug when I run tools/train.py, shown below. Do you have any ideas? Thanks a lot.

2021-04-28 15:44:55,801 INFO Start training tools/cfgs/kitti_models/centerpoint(default)
epochs: 0%| | 0/80 [00:00<?, ?it/s]
Traceback (most recent call last): | 0/464 [00:00<?, ?it/s]
File "tools/train.py", line 198, in
main()
File "tools/train.py", line 170, in main
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch
File "/home/xymbiotec/CenterPoint-KITTI-main/tools/train_utils/train_utils.py", line 93, in train_model
dataloader_iter=dataloader_iter
File "/home/xymbiotec/CenterPoint-KITTI-main/tools/train_utils/train_utils.py", line 19, in train_one_epoch
batch = next(dataloader_iter)
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 435, in next
data = self._next_data()
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
return self._process_data(data)
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
data.reraise()
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/xymbiotec/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/xymbiotec/CenterPoint-KITTI-main/pcdet/datasets/kitti/kitti_dataset.py", line 348, in getitem
sample_idx = info['point_cloud']['lidar_idx']
KeyError: 'lidar_idx'

@tianweiy
Contributor

tianweiy commented Apr 28, 2021

Maybe there is some version compatibility issue. I will check today.

@ZhangYu1ing

Maybe there is some version compatibility issue. I will check today.

No worries, the problem has been solved. It was related to the path of the KITTI dataset: I modified DATA_PATH in tools/cfgs/dataset_configs/kitti_dataset.yaml,
changing it to DATA_PATH='/data/kitti' instead of DATA_PATH='../data/kitti'.
Then it worked successfully.

This is how I solved it on my side; I am not sure whether other issues can also lead to this problem.
Thanks for your help.
