
TensorRT inference slower than PyTorch for DeepLabV3+ with ResNest101 #239

Closed
Xingxu1996 opened this issue Mar 15, 2022 · 18 comments

@Xingxu1996

After converting DeepLabV3+ with a ResNest101 backbone to TensorRT, inference is more than two times slower.
Why is that?

@Xingxu1996
Author

Is it because ResNest is not supported?

@lvhan028
Collaborator

Could you please change the title to English so that overseas users can tell what it's about?

@RunningLeon
Collaborator

@Xingxu1996 Hi, could you tell us how you tested the speed of the PyTorch model and the TensorRT model?

@Xingxu1996
Author

@Xingxu1996 Hi, could you tell us how you tested the speed of the PyTorch model and the TensorRT model?

In apis/test.py, I modified the code in single_gpu_test (with import time and import torch at the top of the file):

    sum_time = 0
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            torch.cuda.synchronize()  # make sure prior GPU work is done
            start = time.time()
            result = model(return_loss=False, **data)
            torch.cuda.synchronize()  # wait for the forward pass to finish
            end = time.time()
            sum_time = sum_time + (end - start)
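
For reference, the accumulated time can then be reported as an average per-iteration latency; the two lines below are just an illustration, not part of the original test.py:

    # After the loop: average latency per data_loader iteration (batch size 1).
    avg_latency_ms = sum_time / len(data_loader) * 1000
    print(f'average latency: {avg_latency_ms:.2f} ms')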
Results:
PyTorch model on CUDA: [timing screenshot]
ONNX model on CUDA: [timing screenshot]
TRT engine: [timing screenshot]

So the ONNX model is four times slower than the PyTorch model, and TRT is two times slower than the PyTorch model!

@RunningLeon
Collaborator

It's no surprise that the ONNX model is slower, but normally TensorRT should be faster than PyTorch. Please provide the detailed script you used to convert the model in mmdeploy, including the model config, deploy config and checkpoint.

@RunningLeon RunningLeon changed the title from the Chinese original ("After converting DeepLabV3+ with a ResNest101 backbone to TensorRT, inference is more than two times slower") to TensorRT inference slower than PyTorch for DeepLabV3+ with ResNest101 Mar 18, 2022
@Xingxu1996
Author

It's no surprise that the ONNX model is slower, but normally TensorRT should be faster than PyTorch. Please provide the detailed script you used to convert the model in mmdeploy, including the model config, deploy config and checkpoint.

The checkpoint file? How do I provide it?

@RunningLeon
Collaborator

RunningLeon commented Mar 24, 2022

If you use a standard checkpoint provided by mmseg, then giving us the link is OK.

@Xingxu1996
Author

If you use a standard checkpoint provided by mmseg, then giving us the link is OK.

I have solved the problem. In the TensorRT engine, the ReduceSum operations (torch.sum(dim=n) in the ResNeSt block) cost a lot of time, so I replaced them with slice and element-wise add. Based on the torch2trt framework, my segmentation model's inference speed improved by 40%+. But I have not found a solution for the torch->onnx->trt route as used in MMDeploy!

@RunningLeon
Collaborator

RunningLeon commented Apr 7, 2022

@Xingxu1996 Hi, good to know it's solved. Could you tell us which lines of code you changed in mmseg? Let's see if this could be integrated into mmdeploy.

@Xingxu1996
Author

@Xingxu1996 Hi, good to know it's solved. Could you tell us which lines of code you changed in mmseg? Let's see if this could be integrated into mmdeploy.

In mmseg/models/backbones/resnest.py, in class SplitAttentionConv2d(nn.Module), def forward(), there are two ReduceSum operations: 1. gap = splits.sum(dim=1); 2. torch.sum(attens * splits, dim=1). Because my self.radix = 2, it's easy to replace them with slice and element-wise add; see the sketch below. Looking forward to a more robust solution from the experts! According to the TRT profiler, these two operations cost a lot of inference time. The TensorRT version is 7.2.1.6.
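
A minimal sketch of the equivalence for radix == 2 (shapes here are arbitrary, just for demonstration):

    import torch

    batch, radix, channels, h, w = 2, 2, 8, 4, 4
    x = torch.randn(batch, radix * channels, h, w)

    # ReduceSum path: view into (batch, radix, channels, h, w), sum over radix.
    splits = x.view(batch, radix, -1, h, w)
    gap_sum = splits.sum(dim=1)

    # Slice + element-wise-add path: first half of the channels plus the second half.
    gap_slice = x[:, :channels] + x[:, channels:]

    assert torch.allclose(gap_sum, gap_slice)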

@RunningLeon
Collaborator

Thanks for your detailed info. Let's see if there's any good solution here.

@RunningLeon
Collaborator

@Xingxu1996 Hi, could you run tools/check_env.py in the mmdeploy root directory to get the env info for us? Besides, please provide your GPU model and the model's input_shape and model config if possible.

@Xingxu1996
Author

@Xingxu1996 Hi, could you run tools/check_env.py in the mmdeploy root directory to get the env info for us? Besides, please provide your GPU model and the model's input_shape and model config if possible.

I am very sorry for the late reply!
2022-05-14 11:05:36,452 - mmdeploy - INFO - TorchVision: 0.9.0
2022-05-14 11:05:36,452 - mmdeploy - INFO - OpenCV: 4.5.4-dev
2022-05-14 11:05:36,452 - mmdeploy - INFO - MMCV: 1.4.0
2022-05-14 11:05:36,452 - mmdeploy - INFO - MMCV Compiler: GCC 7.5
2022-05-14 11:05:36,452 - mmdeploy - INFO - MMCV CUDA Compiler: 11.0
2022-05-14 11:05:36,452 - mmdeploy - INFO - MMDeployment: 0.3.0+776659a
2022-05-14 11:05:36,452 - mmdeploy - INFO -

2022-05-14 11:05:36,453 - mmdeploy - INFO - Backend information
2022-05-14 11:05:43,507 - mmdeploy - INFO - onnxruntime: 1.8.0 ops_is_avaliable : True
2022-05-14 11:05:43,799 - mmdeploy - INFO - tensorrt: 7.2.1.6 ops_is_avaliable : True
2022-05-14 11:05:43,987 - mmdeploy - INFO - ncnn: None ops_is_avaliable : False
2022-05-14 11:05:44,126 - mmdeploy - INFO - pplnn_is_avaliable: False
2022-05-14 11:05:44,314 - mmdeploy - INFO - openvino_is_avaliable: False
2022-05-14 11:05:44,314 - mmdeploy - INFO -

2022-05-14 11:05:44,314 - mmdeploy - INFO - Codebase information
2022-05-14 11:05:44,316 - mmdeploy - INFO - mmcls: 0.19.0
2022-05-14 11:05:44,650 - mmdeploy - INFO - mmdet: 2.22.0
2022-05-14 11:05:44,669 - mmdeploy - INFO - mmedit: None
2022-05-14 11:05:44,670 - mmdeploy - INFO - mmocr: None
2022-05-14 11:05:44,983 - mmdeploy - INFO - mmseg: 0.14.1

@Xingxu1996
Author

@Xingxu1996 Hi, could you run tools/check_env.py in the mmdeploy root directory to get the env info for us? Besides, please provide your GPU model and the model's input_shape and model config if possible.

The GPU is a P40, input_shape is [1, 3, 1080, 1920].
Model config:
    backbone=dict(
        type='ResNeSt',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=False,
        style='pytorch',
        contract_dilation=True,
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True),
    decode_head=dict(
        type='DepthwiseSeparableASPPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        dilations=(1, 12, 24, 36),
        c1_in_channels=256,
        c1_channels=48,
        dropout_ratio=0.1,
        num_classes=6,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),

@Xingxu1996
Author

Will you describe how to solve this problem? Or maybe we'd better contact each other on WeChat!

@RunningLeon
Collaborator

Will you describe how to solve this problem? Or maybe we'd better contact each other on WeChat!

@Xingxu1996 Hi, sorry for the late reply. We have not integrated what you mentioned into mmdeploy yet. You can join our WeChat group for more convenient discussion: please contact the assistant through WeChat ID OpenMMLabwx and ask to join the mmdeploy WeChat group.

@RunningLeon
Collaborator

RunningLeon commented May 16, 2022

@Xingxu1996 Hi, thanks for your patience. I've tested on my machine and the following are the results.

env

2022-05-16 16:17:07,102 - mmdeploy - INFO - **********Environmental information**********
2022-05-16 16:17:08,104 - mmdeploy - INFO - sys.platform: linux
2022-05-16 16:17:08,104 - mmdeploy - INFO - Python: 3.7.5 (default, Oct 25 2019, 15:51:11) [GCC 7.3.0]
2022-05-16 16:17:08,104 - mmdeploy - INFO - CUDA available: True
2022-05-16 16:17:08,104 - mmdeploy - INFO - GPU 0: NVIDIA GeForce RTX 2080
2022-05-16 16:17:08,104 - mmdeploy - INFO - CUDA_HOME: /usr/local/cuda
2022-05-16 16:17:08,104 - mmdeploy - INFO - NVCC: Build cuda_11.1.TC455_06.29069683_0
2022-05-16 16:17:08,104 - mmdeploy - INFO - GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
2022-05-16 16:17:08,104 - mmdeploy - INFO - PyTorch: 1.8.0

2022-05-16 16:17:08,104 - mmdeploy - INFO - TorchVision: 0.9.0
2022-05-16 16:17:08,104 - mmdeploy - INFO - OpenCV: 4.5.2
2022-05-16 16:17:08,104 - mmdeploy - INFO - MMCV: 1.4.8
2022-05-16 16:17:08,104 - mmdeploy - INFO - MMCV Compiler: GCC 7.3
2022-05-16 16:17:08,104 - mmdeploy - INFO - MMCV CUDA Compiler: 11.1
2022-05-16 16:17:08,104 - mmdeploy - INFO - MMDeploy: 0.4.0+0cd44a6
2022-05-16 16:17:08,104 - mmdeploy - INFO - 

2022-05-16 16:17:08,104 - mmdeploy - INFO - **********Backend information**********
[2022-05-16 16:17:08.311] [mmdeploy] [info] [model.cpp:95] Register 'DirectoryModel'
2022-05-16 16:17:08,361 - mmdeploy - INFO - onnxruntime: 1.11.1	ops_is_avaliable : True
2022-05-16 16:17:08,362 - mmdeploy - INFO - tensorrt: 8.2.1.8 [and 7.2.1.6]	ops_is_avaliable : True
2022-05-16 16:17:08,363 - mmdeploy - INFO - ncnn: 1.0.20220428	ops_is_avaliable : False
2022-05-16 16:17:08,363 - mmdeploy - INFO - pplnn_is_avaliable: True
2022-05-16 16:17:08,364 - mmdeploy - INFO - openvino_is_avaliable: True
2022-05-16 16:17:08,364 - mmdeploy - INFO - 

2022-05-16 16:17:08,364 - mmdeploy - INFO - **********Codebase information**********
2022-05-16 16:17:08,376 - mmdeploy - INFO - mmdet:	2.22.0
2022-05-16 16:17:08,377 - mmdeploy - INFO - mmseg:	0.24.0
2022-05-16 16:17:08,377 - mmdeploy - INFO - mmcls:	0.21.0
2022-05-16 16:17:08,377 - mmdeploy - INFO - mmocr:	None
2022-05-16 16:17:08,377 - mmdeploy - INFO - mmedit:	0.12.0
2022-05-16 16:17:08,377 - mmdeploy - INFO - mmdet3d:	1.0.0rc0
2022-05-16 16:17:08,377 - mmdeploy - INFO - mmpose:	0.24.0

Configs:

model cfg: https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes.py
pytorch ckpt: https://download.openmmlab.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes_20200807_144429-1239eb43.pth
deploy_cfg: https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmseg/segmentation_tensorrt_static-1024x2048.py
dataset: cityscapes
image_shape: 1024x2048

Original model: no change to mmseg.
New model: changed the following code in SplitAttentionConv2d.forward of mmseg/models/backbones/resnest.py:

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch = x.size(0)
        if self.radix > 1:
            ##### original lines are commented #######
            # splits = x.view(batch, self.radix, -1, *x.shape[2:])
            # gap = splits.sum(dim=1)
            # Replace view + ReduceSum with slice + element-wise add.
            assert self.radix == 2
            n = x.shape[1] // 2
            x_h1 = x[:, :n, :, :]
            x_h2 = x[:, n:, :, :]
            gap = x_h1 + x_h2
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            ##### original lines are commented #######
            # attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            # out = torch.sum(attens * splits, dim=1)
            # Same trick for the attention-weighted sum over the radix dim.
            assert self.radix == 2
            n = atten.shape[1] // 2
            atten_h1 = atten[:, :n, :, :]
            atten_h2 = atten[:, n:, :, :]
            out = (x_h1 * atten_h1) + (x_h2 * atten_h2)
        else:
            out = atten * x
        return out.contiguous()
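
The rewrite can be sanity-checked against the original formulation with a small script like this (arbitrary shapes, not taken from the model config):

    import torch

    batch, n, h, w = 1, 32, 16, 16
    x = torch.randn(batch, 2 * n, h, w)
    atten = torch.rand(batch, 2 * n, 1, 1)

    # Original formulation: view into radix groups, then ReduceSum.
    splits = x.view(batch, 2, -1, h, w)
    attens = atten.view(batch, 2, -1, 1, 1)
    out_sum = torch.sum(attens * splits, dim=1)

    # Rewritten formulation: slices + element-wise multiply-add.
    out_slice = x[:, :n] * atten[:, :n] + x[:, n:] * atten[:, n:]

    assert torch.allclose(out_sum, out_slice)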

Result

model     backend             speed (ms)
Original  PyTorch             687.48
Original  TensorRT==7.2.1.6   1754.05
Original  TensorRT==8.2.1.8   538.78
New       PyTorch             689.18
New       TensorRT==7.2.1.6   542.04
New       TensorRT==8.2.1.8   527.49

Conclusions

  1. Inference with TensorRT==7.2 is much slower than PyTorch, while TensorRT==8.2 is faster than PyTorch.
  2. The code changes improve inference latency on TensorRT==7.2 by a large margin.
  3. We suggest using TensorRT 8.x if possible.
  4. Since this is a special case, a rewrite for it will not be included in mmdeploy.

@RunningLeon
Collaborator

Closing since we have the above solution for this issue.
