[Fix] Dbnet performance of trt8 (open-mmlab#278)
* compatible trt version for dbnet

* judge inside rewrite
AllentDan committed Dec 13, 2021
1 parent bd28671 commit 105acc9
Showing 3 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,2 +1,2 @@
 [settings]
-known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision
+known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision,version
8 changes: 5 additions & 3 deletions docs/benchmark.md
@@ -673,7 +673,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](doc
 <td class="tg-0lax">model config file</td>
 </tr>
 <tr>
-<td class="tg-nrix" rowspan="3">DBNet</td>
+<td class="tg-nrix" rowspan="3">DBNet*</td>
 <td class="tg-nrix" rowspan="3">TextDetection</td>
 <td class="tg-baqh">recall</td>
 <td class="tg-baqh">0.7310</td>
@@ -823,6 +823,8 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](doc


 ### Notes
-As some datasets contains images with various resolutions in codebase like MMDet. The speed benchmark is gained through static configs in MMDeploy, while the performance benchmark is gained through dynamic ones.
+- Because some datasets in codebases such as MMDet contain images of various resolutions, the speed benchmark is obtained with static configs in MMDeploy, while the performance benchmark is obtained with dynamic ones.

-Some int8 performance benchmarks of tensorrt require nvidia cards with tensor core, or the performance would drop heavily.
+- Some int8 performance benchmarks of TensorRT require NVIDIA cards with Tensor Cores; otherwise the performance drops heavily.
+
+- DBNet uses the interpolation mode `nearest` in the neck of the model, for which TensorRT-7 applies a quite different strategy from PyTorch. To make the repository compatible with TensorRT-7, we rewrite the neck to use the interpolation mode `bilinear`, which improves final detection performance. To match the PyTorch performance, TensorRT-8+ is recommended, since its interpolation methods are the same as PyTorch's.
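To make the `nearest` versus `bilinear` distinction in the DBNet note concrete, here is a minimal one-dimensional sketch in plain Python. The helper names are hypothetical. The coordinate conventions are assumptions chosen to mirror what PyTorch does by default (floor-based nearest, half-pixel centers for linear). It does not model the TensorRT-7 nearest strategy, which the note does not specify.

```python
import math


def upsample_nearest_1d(values, out_size):
    """Nearest upsampling: each output index copies exactly one input sample."""
    scale = len(values) / out_size
    last = len(values) - 1
    return [values[min(int(math.floor(i * scale)), last)] for i in range(out_size)]


def upsample_linear_1d(values, out_size):
    """Linear upsampling with half-pixel centers (align_corners=False style)."""
    scale = len(values) / out_size
    last = len(values) - 1
    out = []
    for i in range(out_size):
        # Map the output index back to a fractional input coordinate.
        src = max((i + 0.5) * scale - 0.5, 0.0)
        lo = min(int(math.floor(src)), last)
        hi = min(lo + 1, last)
        frac = src - lo
        # Blend the two neighbouring input samples.
        out.append(values[lo] * (1.0 - frac) + values[hi] * frac)
    return out


row = [0.0, 10.0]
print(upsample_nearest_1d(row, 4))  # [0.0, 0.0, 10.0, 10.0]
print(upsample_linear_1d(row, 4))   # [0.0, 2.5, 7.5, 10.0]
```

Nearest upsampling depends entirely on which input index each output position is mapped to, so two backends that round coordinates differently can produce shifted feature maps. Linear blending is less sensitive to that choice, which is consistent with the note's reason for using `bilinear` on TensorRT-7.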
12 changes: 8 additions & 4 deletions mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn.functional as F
+from packaging import version

from mmdeploy.core import FUNCTION_REWRITER

@@ -11,7 +12,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
     """Rewrite `forward` of FPNC for tensorrt backend.

     Rewrite this function to replace nearest upsampling with bilinear
-    upsampling. Tensorrt backend applies different nearest sampling strategy
+    upsampling. TensorRT-7 backend applies different nearest sampling strategy
     from pytorch, which heavily influenced the final performance.

     Args:
@@ -24,6 +25,10 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
         outs (Tensor): A feature map output from FPNC. The tensor shape
             (N, C, H, W).
     """
+    # TensorRT version 8+ matches the upsampling with pytorch
+    import tensorrt as trt
+    apply_rewrite = version.parse(trt.__version__) < version.parse('8')
+    mode = 'bilinear' if apply_rewrite else 'nearest'

     assert len(inputs) == len(self.in_channels)
     # build laterals
@@ -36,16 +41,15 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
     for i in range(used_backbone_levels - 1, 0, -1):
         prev_shape = laterals[i - 1].shape[2:]
         laterals[i - 1] += F.interpolate(
-            laterals[i], size=prev_shape, mode='bilinear')
+            laterals[i], size=prev_shape, mode=mode)
     # build outputs
     # part 1: from original levels
     outs = [
         self.smooth_convs[i](laterals[i]) for i in range(used_backbone_levels)
     ]

     for i, out in enumerate(outs):
-        outs[i] = F.interpolate(
-            outs[i], size=outs[0].shape[2:], mode='bilinear')
+        outs[i] = F.interpolate(outs[i], size=outs[0].shape[2:], mode=mode)
     out = torch.cat(outs, dim=1)

     if self.conv_after_concat:
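The version check added in `fpnc__forward__tensorrt` can be tried in isolation. The sketch below is not part of the commit: `select_interpolate_mode` is a hypothetical helper, and it takes the version string as an argument so that it runs without TensorRT installed. It assumes the `packaging` library is available for `version.parse`.

```python
from packaging import version


def select_interpolate_mode(trt_version: str) -> str:
    """Return the interpolate mode the rewritten neck would pick (sketch)."""
    # Versions below 8 get the bilinear workaround; 8 and later keep nearest.
    needs_workaround = version.parse(trt_version) < version.parse('8')
    return 'bilinear' if needs_workaround else 'nearest'


# Hypothetical version strings, used only to exercise the helper.
for candidate in ('7.2.3.4', '8.0.1.6', '10.0.1'):
    print(candidate, '->', select_interpolate_mode(candidate))
```

Comparing parsed versions rather than raw strings matters for two-digit major versions: as plain strings, `'10.0.1' < '8'` is true, which would wrongly select the TensorRT-7 workaround.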