Support single stage rotated detector in MMRotate (open-mmlab#428)
* fix lint

* fix lint

* add mmrotate part

* update

* update

* fix

* remove init_detector

* success run with bs=1

* nms_rotated support batch

* support [batch_id, class_id, box_id]

* fix

* fix

* Create test_mmrotate_core.py

* add ut

* add ut

* Update nms_rotated.py

* fix

* Revert "fix"

This reverts commit f792387fb449ba091c1d932f29d28214805fb6e3.

* add mmrotate into requirements

* add ut

* update doc

* update

* skip test because mmcv version < 1.4.6

* update

* Update rotated-detection_static.py

* Update rotated-detection_static.py

* Update rotated-detection_static.py

* fix bug of memory leak.

* Update rotated_detection_model.py
zytx121 committed May 7, 2022
1 parent 76f6e25 commit 42dc5bc
Showing 32 changed files with 2,290 additions and 71 deletions.
17 changes: 17 additions & 0 deletions configs/mmrotate/rotated-detection_onnxruntime_dynamic.py
@@ -0,0 +1,17 @@
_base_ = ['./rotated-detection_onnxruntime_static.py']
onnx_config = dict(
    dynamic_axes={
        'input': {
            0: 'batch',
            2: 'height',
            3: 'width'
        },
        'dets': {
            0: 'batch',
            1: 'num_dets',
        },
        'labels': {
            0: 'batch',
            1: 'num_dets',
        },
    }, )
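With this config, the exported ONNX graph should expose symbolic axes named `batch`, `height` and `width` on `input`, and `batch`/`num_dets` on `dets` and `labels`. A small sketch (not part of this commit; `end2end.onnx` is whatever file `tools/deploy.py` wrote to your work dir) to confirm those symbolic dimensions:

```python
import onnx

model = onnx.load('end2end.onnx')  # the file written by tools/deploy.py

def dims_of(value_info):
    # dim_param carries the symbolic name of a dynamic axis, dim_value a fixed size
    return [d.dim_param or d.dim_value for d in value_info.type.tensor_type.shape.dim]

for vi in list(model.graph.input) + list(model.graph.output):
    print(vi.name, dims_of(vi))  # e.g. input ['batch', 3, 'height', 'width']
```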
3 changes: 3 additions & 0 deletions configs/mmrotate/rotated-detection_onnxruntime_static.py
@@ -0,0 +1,3 @@
_base_ = ['./rotated-detection_static.py', '../_base_/backends/onnxruntime.py']

onnx_config = dict(output_names=['dets', 'labels'], input_shape=None)
9 changes: 9 additions & 0 deletions configs/mmrotate/rotated-detection_static.py
@@ -0,0 +1,9 @@
_base_ = ['../_base_/onnx_config.py']
codebase_config = dict(
    type='mmrotate',
    task='RotatedDetection',
    post_processing=dict(
        score_threshold=0.05,
        iou_threshold=0.1,
        pre_top_k=3000,
        keep_top_k=2000))
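The `post_processing` values are consumed by the deployment-time post-processing, including the rotated NMS custom op changed later in this commit. If different thresholds are needed, the usual pattern is to inherit one of these configs and override only the relevant keys; the file name below is purely illustrative:

```python
# my_rotated-detection_onnxruntime_static.py (hypothetical user config)
_base_ = ['./rotated-detection_onnxruntime_static.py']

# Override only the post-processing thresholds; everything else is inherited.
codebase_config = dict(
    post_processing=dict(
        score_threshold=0.1,   # drop low-confidence boxes earlier
        iou_threshold=0.1,     # rotated-NMS IoU threshold
        pre_top_k=2000,
        keep_top_k=1000))
```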
122 changes: 67 additions & 55 deletions csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp
@@ -264,13 +264,15 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
 NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
     : api_(api), ort_(api_), info_(info) {
   iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
+  score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
 
   // create allocator
   allocator_ = Ort::AllocatorWithDefaultOptions();
 }
 
 void NMSRotatedKernel::Compute(OrtKernelContext* context) {
   const float iou_threshold = iou_threshold_;
+  const float score_threshold = score_threshold_;
 
   const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
   const float* boxes_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(boxes));
@@ -280,67 +282,77 @@ void NMSRotatedKernel::Compute(OrtKernelContext* context) {
   OrtTensorDimensions boxes_dim(ort_, boxes);
   OrtTensorDimensions scores_dim(ort_, scores);
 
-  int64_t nboxes = boxes_dim[0];
-  assert(boxes_dim[1] == 5); //(cx,cy,w,h,theta)
+  // loop over batch
+  int64_t nbatch = boxes_dim[0];
+  int64_t nboxes = boxes_dim[1];
+  int64_t nclass = scores_dim[1];
+  assert(boxes_dim[2] == 5); //(cx,cy,w,h,theta)
 
   // allocate tmp memory
-  float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nboxes * 5);
-  float* sc = (float*)allocator_.Alloc(sizeof(float) * nboxes);
-  bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nboxes);
-  for (int64_t i = 0; i < nboxes; i++) {
-    select[i] = true;
-  }
+  float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5);
+  float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes);
+  bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes);
 
-  memcpy(tmp_boxes, boxes_data, sizeof(float) * nboxes * 5);
-  memcpy(sc, scores_data, sizeof(float) * nboxes);
+  memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5);
+  memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes);
 
-  // sort scores
-  std::vector<float> tmp_sc;
-  for (int i = 0; i < nboxes; i++) {
-    tmp_sc.push_back(sc[i]);
-  }
-  std::vector<int64_t> order(tmp_sc.size());
-  std::iota(order.begin(), order.end(), 0);
-  std::sort(order.begin(), order.end(),
-            [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; });
-
-  for (int64_t _i = 0; _i < nboxes; _i++) {
-    if (select[_i] == false) continue;
-    auto i = order[_i];
-
-    for (int64_t _j = _i + 1; _j < nboxes; _j++) {
-      if (select[_j] == false) continue;
-      auto j = order[_j];
-      RotatedBox box1, box2;
-      auto center_shift_x = (tmp_boxes[i * 5] + tmp_boxes[j * 5]) / 2.0;
-      auto center_shift_y = (tmp_boxes[i * 5 + 1] + tmp_boxes[j * 5 + 1]) / 2.0;
-      box1.x_ctr = tmp_boxes[i * 5] - center_shift_x;
-      box1.y_ctr = tmp_boxes[i * 5 + 1] - center_shift_y;
-      box1.w = tmp_boxes[i * 5 + 2];
-      box1.h = tmp_boxes[i * 5 + 3];
-      box1.a = tmp_boxes[i * 5 + 4];
-      box2.x_ctr = tmp_boxes[j * 5] - center_shift_x;
-      box2.y_ctr = tmp_boxes[j * 5 + 1] - center_shift_y;
-      box2.w = tmp_boxes[j * 5 + 2];
-      box2.h = tmp_boxes[j * 5 + 3];
-      box2.a = tmp_boxes[j * 5 + 4];
-      auto area1 = box1.w * box1.h;
-      auto area2 = box2.w * box2.h;
-      auto intersection = rotated_boxes_intersection(box1, box2);
-      float baseS = 1.0;
-      baseS = (area1 + area2 - intersection);
-      auto ovr = intersection / baseS;
-      if (ovr > iou_threshold) select[_j] = false;
-    }
-  }
+  // std::vector<std::vector<int64_t>> res_order;
   std::vector<int64_t> res_order;
-  for (int i = 0; i < nboxes; i++) {
-    if (select[i]) {
-      res_order.push_back(order[i]);
-    }
-  }
+  for (int64_t k = 0; k < nbatch; k++) {
+    for (int64_t g = 0; g < nclass; g++) {
+      for (int64_t i = 0; i < nboxes; i++) {
+        select[i] = true;
+      }
+      // sort scores
+      std::vector<float> tmp_sc;
+      for (int i = 0; i < nboxes; i++) {
+        tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]);
+      }
+      std::vector<int64_t> order(tmp_sc.size());
+      std::iota(order.begin(), order.end(), 0);
+      std::sort(order.begin(), order.end(),
+                [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; });
+      for (int64_t _i = 0; _i < nboxes; _i++) {
+        if (select[_i] == false) continue;
+        auto i = order[_i];
+        for (int64_t _j = _i + 1; _j < nboxes; _j++) {
+          if (select[_j] == false) continue;
+          auto j = order[_j];
+          RotatedBox box1, box2;
+          auto center_shift_x =
+              (tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0;
+          auto center_shift_y =
+              (tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0;
+          box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x;
+          box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y;
+          box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2];
+          box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3];
+          box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4];
+          box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x;
+          box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y;
+          box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2];
+          box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3];
+          box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4];
+          auto area1 = box1.w * box1.h;
+          auto area2 = box2.w * box2.h;
+          auto intersection = rotated_boxes_intersection(box1, box2);
+          float baseS = 1.0;
+          baseS = (area1 + area2 - intersection);
+          auto ovr = intersection / baseS;
+          if (ovr > iou_threshold) select[_j] = false;
+        }
+      }
+      for (int i = 0; i < nboxes; i++) {
+        if (select[i] & (tmp_sc[order[i]] > score_threshold)) {
+          res_order.push_back(k);
+          res_order.push_back(g);
+          res_order.push_back(order[i]);
+        }
+      }
+    }  // class loop
+  }    // batch loop
 
-  std::vector<int64_t> inds_dims({(int64_t)res_order.size()});
+  std::vector<int64_t> inds_dims({(int64_t)res_order.size() / 3, 3});
 
   OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
   int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);
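With this change the custom `NMSRotated` op loops over batches and classes and returns the kept indices as a `(K, 3)` int64 tensor of `[batch_id, class_id, box_id]` triples, with low-scoring boxes filtered by `score_threshold`. A NumPy sketch (not part of the commit; the helper name is made up) of how a consumer can gather final detections from that index format, following the kernel's input layout of `boxes: (N, B, 5)` and `scores: (N, C, B)`:

```python
import numpy as np

def gather_detections(boxes, scores, keep_inds):
    """boxes: (N, B, 5) rotated boxes (cx, cy, w, h, theta),
    scores: (N, C, B) per-class scores,
    keep_inds: (K, 3) rows of [batch_id, class_id, box_id] from NMSRotated."""
    dets, labels, batch_ids = [], [], []
    for batch_id, class_id, box_id in keep_inds:
        box = boxes[batch_id, box_id]                # (5,) rotated box
        score = scores[batch_id, class_id, box_id]   # scalar confidence
        dets.append(np.concatenate([box, [score]]))  # (6,) = cx, cy, w, h, theta, score
        labels.append(class_id)
        batch_ids.append(batch_id)
    if not dets:  # nothing survived NMS / score filtering
        return np.zeros((0, 6)), np.zeros((0,), np.int64), np.zeros((0,), np.int64)
    return np.stack(dets), np.array(labels), np.array(batch_ids)
```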
1 change: 1 addition & 0 deletions csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h
@@ -22,6 +22,7 @@ struct NMSRotatedKernel {
   const OrtKernelInfo* info_;
   Ort::AllocatorWithDefaultOptions allocator_;
   float iou_threshold_;
+  float score_threshold_;
 };
 
 struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
47 changes: 47 additions & 0 deletions docs/en/benchmark.md
@@ -1819,6 +1819,53 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
</div>
</details>

<details>
<summary style="margin-left: 25px;">MMRotate</summary>
<div style="margin-left: 25px;">
<table class="docutils">
<thead>
<tr>
<th align="center" colspan="4">MMRotate</th>
<th align="center">Pytorch</th>
<th align="center">ONNXRuntime</th>
<th align="center" colspan="2">TensorRT</th>
<th align="center">PPLNN</th>
<th align="center">OpenVINO</th>
<th align="left">Model Config</th>
</tr>
</thead>
<tbody>
<tr>
<td align="center">Model</td>
<td align="center">Task</td>
<td align="center">Dataset</td>
<td align="center">Metrics</td>
<td align="center">fp32</td>
<td align="center">fp32</td>
<td align="center">fp32</td>
<td align="center">fp16</td>
<td align="center">fp16</td>
<td align="center">fp32</td>
<td>model config file</td>
</tr>
<tr>
<td align="center" rowspan="2">RotatedRetinaNet</td>
<td align="center" rowspan="2">Rotated Detection</td>
<td align="center" rowspan="2">DOTA-v1.0</td>
<td align="center">mAP</td>
<td align="center">0.698</td>
<td align="center">0.698</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td rowspan="2">$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>
</tr>
</tbody>
</table>
</div>
</details>


### Notes
- Since some codebases, such as MMDet, contain datasets with images of various resolutions, the speed benchmark is obtained with static configs in MMDeploy, while the performance benchmark is obtained with dynamic ones.
53 changes: 53 additions & 0 deletions docs/en/codebases/mmrotate.md
@@ -0,0 +1,53 @@
# MMRotate Support

[MMRotate](https://github.com/open-mmlab/mmrotate) is an open-source toolbox for rotated object detection based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project.

## MMRotate installation tutorial

Please refer to [official installation guide](https://mmrotate.readthedocs.io/en/latest/install.html) to install the codebase.

## Supported MMRotate models

| Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
|:----------------------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:|
| RotatedRetinaNet | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |

### Example

```bash
# convert ort
python tools/deploy.py \
configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \
$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \
$MMROTATE_DIR/checkpoints/rotated_retinanet_obb_r50_fpn_1x_dota_le135-e4131166.pth \
$MMROTATE_DIR/demo/demo.jpg \
--work-dir work-dirs/mmrotate/rotated_retinanet/ort \
--device cpu

# compute metric
python tools/test.py \
configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \
$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \
--model work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx \
--metrics mAP

# generate submit file
python tools/test.py \
configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \
$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \
--model work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx \
--format-only \
--metric-options submission_dir=work-dirs/mmrotate/rotated_retinanet/ort/Task1_results
```

Note

- MMRotate models usually need extra meta information about the input image that cannot be obtained from the image file alone, so when exporting the model you can simply use `$MMROTATE_DIR/demo/demo.jpg` as the input.
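For a quick sanity check that the exported `end2end.onnx` really produces the `dets`/`labels` outputs declared in the deploy config, the model can also be run directly with ONNX Runtime. The sketch below is only an assumption-laden convenience: it skips MMRotate's preprocessing pipeline, so its scores are not comparable to the metrics computed by `tools/test.py`.

```python
import cv2
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    'work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx',
    providers=['CPUExecutionProvider'])

# Resize to a multiple of 32 so the FPN levels line up; normalization is omitted,
# which is fine for a shape/graph check but not for accuracy.
img = cv2.resize(cv2.imread('demo.jpg'), (1024, 1024)).astype(np.float32)
inp = img.transpose(2, 0, 1)[None]  # HWC -> NCHW, add batch dim

dets, labels = sess.run(['dets', 'labels'], {'input': inp})
print(dets.shape, labels.shape)  # expected: (1, num_dets, 6) and (1, num_dets)
```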

## Reminder

None

## FAQs

None
48 changes: 48 additions & 0 deletions docs/zh_cn/benchmark.md
@@ -1808,6 +1808,54 @@ GPU: ncnn, TensorRT, PPLNN
</details>


<details>
<summary style="margin-left: 25px;">MMRotate</summary>
<div style="margin-left: 25px;">
<table class="docutils">
<thead>
<tr>
<th align="center" colspan="4">MMRotate</th>
<th align="center">Pytorch</th>
<th align="center">ONNXRuntime</th>
<th align="center" colspan="2">TensorRT</th>
<th align="center">PPLNN</th>
<th align="center">OpenVINO</th>
<th align="left">Model Config</th>
</tr>
</thead>
<tbody>
<tr>
<td align="center">Model</td>
<td align="center">Task</td>
<td align="center">Dataset</td>
<td align="center">Metrics</td>
<td align="center">fp32</td>
<td align="center">fp32</td>
<td align="center">fp32</td>
<td align="center">fp16</td>
<td align="center">fp16</td>
<td align="center">fp32</td>
<td>model config file</td>
</tr>
<tr>
<td align="center" rowspan="2">RotatedRetinaNet</td>
<td align="center" rowspan="2">Rotated Detection</td>
<td align="center" rowspan="2">DOTA-v1.0</td>
<td align="center">mAP</td>
<td align="center">0.698</td>
<td align="center">0.698</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td rowspan="2">$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>
</tr>
</tbody>
</table>
</div>
</details>


### Notes
- Since some codebases, such as MMDet, contain datasets with images of various resolutions, the speed benchmark is obtained with static configs in MMDeploy, while the performance benchmark is obtained with dynamic ones.

5 changes: 4 additions & 1 deletion mmdeploy/codebase/__init__.py
@@ -4,7 +4,10 @@
 from mmdeploy.utils import Codebase
 from .base import BaseTask, MMCodebase, get_codebase_class
 
-extra_dependent_library = {Codebase.MMOCR: ['mmdet']}
+extra_dependent_library = {
+    Codebase.MMOCR: ['mmdet'],
+    Codebase.MMROTATE: ['mmdet']
+}
 
 
 def import_codebase(codebase: Codebase):
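The mapping above makes sure that importing the MMRotate codebase also pulls in `mmdet`, which MMRotate builds on. A minimal usage sketch (not part of the diff) of how this registration is typically triggered before building a deployment task:

```python
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase

# Registers the mmrotate deploy modules and, via extra_dependent_library,
# also imports the mmdet ones they depend on.
import_codebase(Codebase.MMROTATE)
```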
4 changes: 4 additions & 0 deletions mmdeploy/codebase/mmrotate/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .core import * # noqa: F401,F403
from .deploy import * # noqa: F401,F403
from .models import * # noqa: F401,F403
3 changes: 3 additions & 0 deletions mmdeploy/codebase/mmrotate/core/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox import * # noqa: F401,F403
from .post_processing import * # noqa: F401,F403
2 changes: 2 additions & 0 deletions mmdeploy/codebase/mmrotate/core/bbox/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delta_xywha_rbbox_coder import * # noqa: F401,F403
