[Feature] Support NeRF-Det #2732

Merged
merged 48 commits on Jan 4, 2024
Changes from 38 commits
Commits
48 commits
ea49d43
first try
Yanyirong Sep 5, 2023
2a87fea
support dataset and pipeline
Yanyirong Sep 14, 2023
94f891f
modify the config
Yanyirong Sep 14, 2023
729239a
update configs
Yanyirong Oct 16, 2023
eace403
update dataset
Yanyirong Oct 16, 2023
7f4ae48
update pipeline
Yanyirong Oct 17, 2023
d71b328
update structures
Yanyirong Oct 17, 2023
d0d5599
update dense_head
Yanyirong Oct 17, 2023
2596ced
update nerfdet
Yanyirong Oct 17, 2023
b7a797a
update nerf_utils
Yanyirong Oct 17, 2023
4459ba1
fix a bug in utils
Yanyirong Oct 20, 2023
2fc3b87
refactor dataset
Yanyirong Oct 23, 2023
759ceb3
rollback dataset converter
Yanyirong Oct 24, 2023
a354cd8
delete test dataset
Yanyirong Oct 24, 2023
687241e
move configs to project
Yanyirong Oct 24, 2023
6e3c392
move dataset to projects
Yanyirong Oct 24, 2023
f1bcd26
move pipeline to projects
Yanyirong Oct 24, 2023
c2bc0ae
move uitls
Yanyirong Oct 24, 2023
18fcf2f
move bbox_head
Yanyirong Oct 24, 2023
e9ef015
move structure
Yanyirong Oct 24, 2023
26ebd97
move formating to projects
Yanyirong Oct 24, 2023
b93621c
move preprocessor to projects
Yanyirong Oct 24, 2023
0ffc05d
move model to projects
Yanyirong Oct 24, 2023
38c01c0
update configs
Yanyirong Oct 24, 2023
59c2510
move the dataset converter
Yanyirong Oct 25, 2023
a0fef75
update README
Yanyirong Oct 25, 2023
8fc5ea6
Merge branch 'open-mmlab:main' into dev
Yanyirong Oct 26, 2023
9a924e6
add res101
Yanyirong Oct 30, 2023
5eed02b
Merge branch 'dev' of https://github.com/Yanyirong/mmdetection3d into…
Yanyirong Oct 30, 2023
aa8e657
update readme and fix an error
Yanyirong Nov 1, 2023
950ed10
update README
Yanyirong Nov 23, 2023
47cf20e
fix some commits
Yanyirong Nov 28, 2023
cb97e60
fix datasample
Yanyirong Nov 29, 2023
a486651
fix lidar2cam and cam2img
Yanyirong Dec 1, 2023
c7d2bc7
fix configs
Yanyirong Dec 1, 2023
07ccdbf
update dataset converter
Yanyirong Dec 21, 2023
9f5b428
update README
Yanyirong Dec 21, 2023
9617c17
Refresh README
Yanyirong Dec 21, 2023
88a1aaa
modify dataset name
Yanyirong Dec 29, 2023
59963f8
modify nerf_mlp.py
Yanyirong Dec 29, 2023
7c9f45a
modify projection.py
Yanyirong Dec 29, 2023
7e10eee
modify save_rendered_img.py
Yanyirong Dec 29, 2023
10ed86a
modify render_ray
Yanyirong Dec 29, 2023
2066cbf
modify render_ray.py
Yanyirong Dec 29, 2023
d33480a
revert the change of infos
Yanyirong Dec 29, 2023
2db1916
revert create_data
Yanyirong Dec 29, 2023
09e1c0d
update prepare_infos.py
Yanyirong Dec 29, 2023
f4920a9
fix some typos
Yanyirong Jan 4, 2024
116 changes: 116 additions & 0 deletions projects/NeRF-Det/README.md
@@ -0,0 +1,116 @@
# NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection

> [NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection](https://arxiv.org/abs/2307.14620)

<!-- [ALGORITHM] -->

## Abstract

NeRF-Det is a novel method for indoor 3D detection that takes posed RGB images as input. Unlike existing indoor 3D detection methods that struggle to model scene geometry, NeRF-Det makes novel use of NeRF in an end-to-end manner to explicitly estimate 3D geometry, thereby improving 3D detection performance. Specifically, to avoid the significant extra latency associated with per-scene optimization of NeRF, NeRF-Det introduces sufficient geometry priors to enhance the generalizability of the NeRF-MLP. Furthermore, it subtly connects the detection and NeRF branches through a shared MLP, enabling efficient adaptation of NeRF to detection and yielding geometry-aware volumetric representations for 3D detection. NeRF-Det outperforms the state of the art by 3.9 mAP and 3.1 mAP on the ScanNet and ARKitScenes benchmarks, respectively. The authors provide extensive analysis to shed light on how NeRF-Det works. Thanks to the joint-training design, NeRF-Det generalizes well to unseen scenes for object detection, view synthesis, and depth estimation without requiring per-scene optimization. Code is available at https://github.com/facebookresearch/NeRF-Det

<div align=center>
<img src="https://chenfengxu714.github.io/nerfdet/static/images/method-cropped_1.png" width="800"/>
</div>

## Introduction

This directory contains the implementation of NeRF-Det (https://arxiv.org/abs/2307.14620), built on top of MMDetection3D. We have updated NeRF-Det to be compatible with the latest mmdet3d version; the codebase and config files have all been adapted accordingly. All previous pretrained models have been verified against the results listed below. However, newly trained models have not been uploaded yet.

<!-- Share any information you would like others to know. For example:
Author: @xxx.
This is an implementation of \[XXX\]. -->

## Dataset

The ScanNet dataset format in the latest version of mmdet3d only supports LiDAR-based tasks. For NeRF-Det, a new format of the ScanNet dataset needs to be created.

Please follow the ScanNet data preparation instructions in mmdet3d to prepare the raw ScanNet data. After that, use the following command to generate the pkl files used by NeRF-Det:

```bash
python tools/create_data.py scannet --root-path ./data/scannet \
--out-dir ./data/scannet --extra-tag scannet --version nerfdet
```

The new pkl format is organized as follows:

- scannet_infos_train.pkl: The training data infos. The detailed info of each scan is as follows:
  - info\['instances'\]: A list of dicts containing all annotations; each dict holds the annotation information of a single instance. For the i-th instance:
    - info\['instances'\]\[i\]\['bbox_3d'\]: List of 6 numbers representing the axis-aligned 3D bounding box in the depth coordinate system, in (x, y, z, l, w, h) order.
    - info\['instances'\]\[i\]\['bbox_label_3d'\]: The label of the 3D bounding box.
  - info\['cam2img'\]: The camera intrinsic matrix. Every scene has one matrix.
  - info\['lidar2cam'\]: The camera extrinsic matrices. Every scene has 300 matrices.
  - info\['img_paths'\]: The paths of the 300 RGB images.
  - info\['axis_align_matrix'\]: The axis-alignment matrix. Every scene has one matrix.

After preparing your ScanNet dataset pkls, please change the paths in the configs to fit your project.
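
A minimal sketch for inspecting a generated pkl is shown below. It assumes the usual mmdet3d info-file layout (a dict with a `data_list` entry holding one dict per scan); the file path is only an example, so adjust it and the keys to match your setup.

```python
import pickle

# Example path -- point this at the info file generated above.
pkl_path = 'data/scannet/scannet_infos_train.pkl'

with open(pkl_path, 'rb') as f:
    infos = pickle.load(f)

# Assumption: standard mmdet3d layout with a 'data_list' of per-scan dicts.
scans = infos['data_list'] if isinstance(infos, dict) else infos
info = scans[0]

print('number of scans:', len(scans))
print('images per scan:', len(info['img_paths']))          # expected: 300
print('intrinsics (cam2img):', info['cam2img'])            # one matrix per scene
print('extrinsics (lidar2cam):', len(info['lidar2cam']))   # one matrix per image
for inst in info['instances'][:3]:
    print(inst['bbox_label_3d'], inst['bbox_3d'])           # (x, y, z, l, w, h)
```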

## Train

In MMDet3D's root directory, run the following command to train the model:

```bash
python tools/train.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py --work-dir ${WORK_DIR}
```

## Results and Models

### NeRF-Det

| Backbone | mAP@25 | mAP@50 | Log |
| :-------------------------------------------------------------: | :----: | :----: | :-------: |
| [NeRF-Det-R50](./configs/nerfdet_res50_2x_low_res.py) | 53.0 | 26.8 | [log](<>) |
| [NeRF-Det-R50\*](./configs/nerfdet_res50_2x_low_res_depth.py) | 52.2 | 28.5 | [log](<>) |
| [NeRF-Det-R101\*](./configs/nerfdet_res101_2x_low_res_depth.py) | 52.3 | 28.5 | [log](<>) |

(NeRF-Det-R50\* denotes that the model uses depth information during training.)

### Notes

- The values shown in the table are the best mAP obtained during training.

- Since there is considerable randomness in the model's behavior, we conducted three experiments for each config and took the average. The mAP values shown in the table above are these averages.

- We also conducted the same experiments with the original code; the results are shown below.

| Backbone | mAP@25 | mAP@50 |
| :-------------: | :----: | :----: |
| NeRF-Det-R50 | 52.8 | 26.8 |
| NeRF-Det-R50\* | 52.4 | 27.5 |
| NeRF-Det-R101\* | 52.8 | 28.6 |

- Attention: because of randomness in the construction of the ScanNet dataset itself and in the model's behavior, training results can fluctuate considerably. In our experience, results fluctuate by roughly ±1.5 mAP.

## Evaluation using pretrained models

1. Download the pretrained checkpoints through the links in the table above.

2. Testing

To test, use:

```bash
python tools/test.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${CHECKPOINT_PATH}
```

## Citation

<!-- You may remove this section if not applicable. -->

```latex
@inproceedings{
xu2023nerfdet,
title={NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection},
author={Xu, Chenfeng and Wu, Bichen and Hou, Ji and Tsai, Sam and Li, Ruilong and Wang, Jialiang and Zhan, Wei and He, Zijian and Vajda, Peter and Keutzer, Kurt and Tomizuka, Masayoshi},
booktitle={ICCV},
year={2023},
}

@inproceedings{
park2023time,
title={Time Will Tell: New Outlooks and A Baseline for Temporal Multi-View 3D Object Detection},
author={Jinhyung Park and Chenfeng Xu and Shijia Yang and Kurt Keutzer and Kris M. Kitani and Masayoshi Tomizuka and Wei Zhan},
booktitle={The Eleventh International Conference on Learning Representations },
year={2023},
url={https://openreview.net/forum?id=H3HcEJA2Um}
}
```
198 changes: 198 additions & 0 deletions projects/NeRF-Det/configs/nerfdet_res101_2x_low_res_depth.py
@@ -0,0 +1,198 @@
_base_ = ['../../../configs/_base_/default_runtime.py']

custom_imports = dict(imports=['projects.NeRF-Det.nerfdet'])
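# Shared anchor (prior) generator: a single axis-aligned range covering a
# 6.4 m x 6.4 m x 2.56 m volume around the scene origin, with no anchor rotation.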
prior_generator = dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-3.2, -3.2, -1.28, 3.2, 3.2, 1.28]],
rotations=[.0])

model = dict(
type='NerfDet',
data_preprocessor=dict(
type='NeRFDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=10),
backbone=dict(
type='mmdet.ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet101'),
style='pytorch'),
neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
neck_3d=dict(
type='IndoorImVoxelNeck',
in_channels=256,
out_channels=128,
n_blocks=[1, 1, 1]),
bbox_head=dict(
type='NerfDetHead',
bbox_loss=dict(type='AxisAlignedIoULoss', loss_weight=1.0),
n_classes=18,
n_levels=3,
n_channels=128,
n_reg_outs=6,
pts_assign_threshold=27,
pts_center_threshold=18,
prior_generator=prior_generator),
prior_generator=prior_generator,
voxel_size=[.16, .16, .2],
n_voxels=[40, 40, 16],
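    # NeRF-related settings: `aabb` is the axis-aligned bounding box of the
    # sampled scene volume, `near_far_range` gives the near/far bounds used when
    # sampling along each ray, `N_samples` is the number of points sampled per
    # ray, and `N_rand` is the number of rays sampled per training iteration.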
aabb=([-2.7, -2.7, -0.78], [3.7, 3.7, 1.78]),
near_far_range=[0.2, 8.0],
N_samples=64,
N_rand=2048,
nerf_mode='image',
depth_supervise=True,
use_nerf_mask=True,
nerf_sample_view=20,
squeeze_scale=4,
nerf_density=True,
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01))

dataset_type = 'ScanNetMultiViewDataset'
data_root = 'data/scannet/'
class_names = [
'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain',
'toilet', 'sink', 'bathtub', 'garbagebin'
]
metainfo = dict(CLASSES=class_names)
file_client_args = dict(backend='disk')

input_modality = dict(
use_camera=True,
use_depth=True,
use_lidar=False,
use_neuralrecon_depth=False,
use_ray=True)
backend_args = None

train_collect_keys = [
'img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos', 'nerf_sizes',
'raydirs', 'gt_images', 'gt_depths', 'denorm_images'
]

test_collect_keys = [
'img',
'depth',
'lightpos',
'nerf_sizes',
'raydirs',
'gt_images',
'gt_depths',
'denorm_images',
]

train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=48,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=10),
dict(type='RandomShiftOrigin', std=(.7, .7, .0)),
dict(type='PackNeRFDetInputs', keys=train_collect_keys)
]

test_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=101,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=1),
dict(type='PackNeRFDetInputs', keys=test_collect_keys)
]

train_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=6,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_train_new.pkl',
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo)))
val_dataloader = dict(
batch_size=1,
num_workers=5,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_val_new.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo))
test_dataloader = val_dataloader

val_evaluator = dict(type='IndoorMetric')
test_evaluator = val_evaluator

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
test_cfg = dict()
val_cfg = dict()

optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
clip_grad=dict(max_norm=35., norm_type=2))
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]

# hooks
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=12))

# runtime
find_unused_parameters = True # only 1 of 4 FPN outputs is used