
[Performance] gpu inference is much slower than cpu #17489

Open
MinGiSa opened this issue Sep 11, 2023 · 2 comments
Labels
ep:CUDA (issues related to the CUDA execution provider), platform:windows (issues related to the Windows platform), quantization (issues related to quantization)

Comments


MinGiSa commented Sep 11, 2023

Describe the issue

Hello, I'm trying to export a CRAFT model to an ONNX file using torch.onnx.export.

When I export the model and run inference, I get the messages below, and GPU inference takes much longer than CPU inference: about 2,000 ms on the GPU versus about 50 ms on the CPU.

2023-09-11 15:07:36.5009034 [W:onnxruntime:, session_state.cc:1169 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns
2023-09-11 15:07:36.5058247 [W:onnxruntime:, session_state.cc:1171 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.

I suspect there is an issue with how the ONNX model was exported.

Is there anything wrong with the exported model, or is there something else I need to do?

The following snippet is the relevant part of the inference code:

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

prov = ['CUDAExecutionProvider'] if device == 'cuda' else ['CPUExecutionProvider']

ort_session = onnxruntime.InferenceSession(onnxModelFile, providers=prov)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
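
For completeness, here is a minimal sketch of how the session's active providers can be checked (reusing ort_session from the snippet above; onnxruntime.get_available_providers() lists what this build supports):

# Sketch: confirm CUDAExecutionProvider was actually registered for this session.
print(ort_session.get_providers())            # providers the session ended up using, in priority order
print(onnxruntime.get_available_providers())  # providers available in this onnxruntime build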

To reproduce

import os
from collections import OrderedDict, namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.backends.cudnn as cudnn  # needed for cudnn.benchmark in get_detector below
from torchvision import models
import onnx
import craft

def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

class vgg16_bn(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):         # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):     # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):     # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):     # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())     # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():      # only first conv
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
        return out

class double_conv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class CRAFT(nn.Module):
    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def forward(self, x):
        """ Base network """
        sources = self.basenet(x)

        """ U network """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0, 2, 3, 1), feature

def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    newStateDict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        newStateDict[name] = v
    return newStateDict

def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    net = CRAFT()
    if device == 'cpu':
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
        # if quantize:
        #     try:
        #         torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
        #         print('Quantized model')
        #     except:
        #         pass
    else:
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
        print('Loaded weights from {}'.format(trained_model))
        net = net.to(device)
        cudnn.benchmark = cudnn_benchmark

    net.eval()
    return net

model = get_detector(trained_model=r'customDetect.pth', device='cuda:0', quantize=False)

input_shape = (1, 3, 480, 640)
inputs = torch.ones(*input_shape)
inputs = inputs.to('cuda:0')
input_names=['input']
output_names=['output']

dynamic_axes= {'input':{0:'batch_size', 2:'height', 3:'width'}, 'output':{0:'batch_size', 2:'height', 3:'width'}}
torch.onnx.export(model, inputs, r"D:\Sa\EasyOCR_ONNX\craft.onnx", dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)

Urgency

No response

Platform

Windows

OS Version

Windows 11

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU, CUDA

Execution Provider Library Version

torch 2.0.1, cuDNN 11.7

Model File

No response

Is this a quantized model?

No

github-actions bot added the ep:CUDA, platform:windows, and quantization labels on Sep 11, 2023
skottmckay (Contributor) commented:

There are certain operations where it is more efficient to use CPU than CUDA (e.g. operations involving tensor shapes) so that's an expected message. You could run with log severity set to verbose to see exactly which nodes are executing on CPU.
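
A minimal sketch of one way to do that, assuming the same onnxModelFile as above (severity 0 is verbose):

import onnxruntime

so = onnxruntime.SessionOptions()
so.log_severity_level = 0  # 0 = VERBOSE: the log shows which execution provider each node is assigned to
sess = onnxruntime.InferenceSession(onnxModelFile, sess_options=so,
                                    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])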

How are you measuring performance? Ignore the first inference as a lot of things get initialized and cached based on that.

You probably need to use IOBinding so the input/output is on GPU otherwise the cost of copying data between CPU to GPU to run the model is included in the inferencing time. See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
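
A rough sketch of what that could look like for this model, assuming device_id 0 and the to_numpy/x/ort_session names from the issue (see the linked docs for the full API):

# Bind input and output to the GPU so ORT does not copy data to and from host memory on every run.
io_binding = ort_session.io_binding()

x_gpu = onnxruntime.OrtValue.ortvalue_from_numpy(to_numpy(x), 'cuda', 0)
io_binding.bind_ortvalue_input(ort_session.get_inputs()[0].name, x_gpu)

for out in ort_session.get_outputs():
    io_binding.bind_output(out.name, 'cuda')   # let ORT allocate the outputs on the device

ort_session.run_with_iobinding(io_binding)
ort_outs = io_binding.copy_outputs_to_cpu()    # copy back to host only when the results are needed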

MinGiSa (Author) commented Sep 12, 2023

> There are certain operations where it is more efficient to use CPU than CUDA (e.g. operations involving tensor shapes) so that's an expected message. You could run with log severity set to verbose to see exactly which nodes are executing on CPU.
>
> How are you measuring performance? Ignore the first inference as a lot of things get initialized and cached based on that.
>
> You probably need to use IOBinding so the input/output is on GPU otherwise the cost of copying data between CPU to GPU to run the model is included in the inferencing time. See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html

I measured performance with:

print('detection ONNX')
ort_session = onnxruntime.InferenceSession(onnxModelFile, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.time()
ort_outs = ort_session.run(None, ort_inputs)
finish = time.time()
total_time_ms = (finish - start) * 1000
print(f'detection total processing time: {total_time_ms:.2f} ms')
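
Following the advice above, something like this (a rough sketch, reusing the same session and inputs; the run count is arbitrary) should give a fairer number:

import time

ort_session.run(None, ort_inputs)   # warm-up: the first run pays one-time initialization and caching costs

n_runs = 20
start = time.time()
for _ in range(n_runs):
    ort_session.run(None, ort_inputs)
avg_ms = (time.time() - start) * 1000 / n_runs
print(f'detection average processing time: {avg_ms:.2f} ms over {n_runs} runs')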

Anyway, thank you for the answer. I'll look into it.
