Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(pytorch) onnx is slower than pytorch #342

Open
SoraJung opened this issue Jun 20, 2024 · 8 comments
Open

(pytorch) onnx is slower than pytorch #342

SoraJung opened this issue Jun 20, 2024 · 8 comments
Assignees

Comments

@SoraJung
Copy link

Describe the bug
I trained my custom dataset with rtdetr_r101vd_6x_coco_custom.yml. However, I found onnx is three times slower than pytorch.
I just run export_onnx.py in the github and saves model.onnx. Please review my inference code referenced from issue.

result

  1. onnx
python ./tools/predict_onnx.py -i ./images/D16030_196_Add00407.jpg

torch.Size([1, 3, 640, 640])
Inferece time = 0.421980619430542 s
FPS = 2.3697770796902677
  1. pytorch
`python ./tools/predict_pytorch.py -c ./configs/rtdetr/rtdetr_r101vd_6x_coco_custom.yml -w ../output/rtdetr_r101vd_6x_coco_custom/checkpoint0004.pth -i ./images/D16030_196_Add00407.jpg

Load PResNet101 state_dict
Inferece time = 0.15229344367980957 s
FPS = 6.566270850782369
  1. pytorch inference code
import argparse
from pathlib import Path
import time

class ImageReader:
    def __init__(self, resize=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.transform = transforms.Compose([
            # transforms.Resize((resize, resize)) if isinstance(resize, int) else transforms.Resize(
            #     (resize[0], resize[1])),
            transforms.ToTensor(),
            # transforms.Normalize(mean=mean, std=std),
        ])
        self.resize = resize
        self.pil_img = None   

    def __call__(self, image_path, *args, **kwargs):
        self.pil_img = Image.open(image_path).convert('RGB').resize((self.resize, self.resize))
        return self.transform(self.pil_img).unsqueeze(0)


class Model(nn.Module):
    def __init__(self, confg=None, ckpt="") -> None:
        super().__init__()
        self.cfg = YAMLConfig(confg, resume=ckpt)
        if ckpt:
            checkpoint = torch.load(ckpt, map_location='cpu') 
            if 'ema' in checkpoint:
                state = checkpoint['ema']['module']
            else:
                state = checkpoint['model']
        else:
            raise AttributeError('only support resume to load model.state_dict by now.')

        # NOTE load train mode state -> convert to deploy mode
        self.cfg.model.load_state_dict(state)

        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()
        # print(self.postprocessor.deploy_mode)
        
    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)



def get_argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", '-c', type=str, )
    parser.add_argument("--ckpt", '-w', type=str, ) # pth
    parser.add_argument("--image", '-i', type=str, ) 
    parser.add_argument("--device", default="cuda:1")

    return parser


def main(args):
    img_path = Path(args.image)
    device = torch.device(args.device)
    reader = ImageReader(resize=640)
    model = Model(confg=args.config, ckpt=args.ckpt)
    model.to(device=device)

    img = reader(img_path).to(device)
    size = torch.tensor([[img.shape[2], img.shape[3]]]).to(device)
    
    start_time = time.time()
    output = model(img, size)
    inf_time = time.time() - start_time
    fps = float(1/inf_time)
    print("Inferece time = {} s".format(inf_time, '.4f'))
    print("FPS = {} ".format(fps, '.1f') )
    
    labels, boxes, scores = output
    
    im = reader.pil_img
    draw = ImageDraw.Draw(im)
    thrh = 0.6

    for i in range(img.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        for b in box:
            draw.rectangle(list(b), outline='red', )
            draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )

    # save_path = Path(args.output_dir) / img_path.name
    file_dir = os.path.dirname(args.image)
    new_file_name = os.path.basename(args.image).split('.')[0] + '_torch'+ os.path.splitext(args.image)[1]
    new_file_path = file_dir + '/' + new_file_name
    print('new_file_path: ', new_file_path)
    im.save(new_file_path)
 

if __name__ == "__main__":
    main(get_argparser().parse_args())
  1. onnx inference code
mport os 
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import ToTensor
import argparse
import time

def main(args, ):
    im = Image.open(args.img).convert('RGB')
    im = im.resize((640, 640))
    im_data = ToTensor()(im)[None]
    # (width, height) = im.size
    print(im_data.shape)
    # print(width, height)
    # size = torch.tensor([[width, height]])
    size = torch.tensor([[640, 640]])
    sess = ort.InferenceSession(args.model)
    
    start_time = time.time()
    output = sess.run(
        # output_names=['labels', 'boxes', 'scores'],
        output_names=None,
        input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}        
    )
    end_time = time.time()
    # inf_time = time.time() - start_time
    inf_time = end_time - start_time
    fps = float(1/inf_time)
    print("Inferece time = {} s".format(inf_time, '.4f'))
    print("FPS = {} ".format(fps, '.1f') )
    #print(type(output))
    #print([out.shape for out in output])

    labels, boxes, scores = output
    
    draw = ImageDraw.Draw(im)  # Draw on the original image
    thrh = 0.6

    for i in range(im_data.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        #print(i, sum(scr > thrh))

        for b in box:
            draw.rectangle(list(b), outline='red',)
            # font = ImageFont.truetype("Arial.ttf", 15)
            draw.text((b[0], b[1]), text=str(lab[i]), fill='yellow', )

    # Save the original image with bounding boxes
    file_dir = os.path.dirname(args.img)
    new_file_name = os.path.basename(args.img).split('.')[0] + '_onnx'+ os.path.splitext(args.img)[1]
    new_file_path = file_dir + '/' + new_file_name
    print('new_file_path: ', new_file_path)
    im.save(new_file_path)
 

if __name__ == '__main__':
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', '-i', type=str, )
    parser.add_argument('--model', '-m', type=str, default='model.onnx')

    args = parser.parse_args()

    main(args)
@lyuwenyu
Copy link
Owner

Please check that your onnxruntime is using GPU.

# pip install onnxruntime-gpu

import onnxruntime as ort

print(ort.get_device())

@SoraJung
Copy link
Author

SoraJung commented Jun 20, 2024

Thank you for your prompt response. I tried your suggestion, but much slower than before. help me please...

  1. add my code
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    sess_options = ort.SessionOptions()
    sess = ort.InferenceSession(args.model, sess_options=sess_options, providers=providers)
    
    start_time = time.time()
    output = sess.run(
        # output_names=['labels', 'boxes', 'scores'],
        output_names=None,
        input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}        
    )
    end_time = time.time()
    # inf_time = time.time() - start_time
    inf_time = end_time - start_time
    fps = float(1/inf_time)
    print("Inferece time = {:.4f} s".format(inf_time))
    print("FPS = {:.2f} ".format(fps))
  1. result
python ./tools/predict_onnx.py -i ./images/D16030_196_Add00407.jpg

ort.get_device() GPU
torch.Size([1, 3, 640, 640])
Inferece time = 19.2355 s
FPS = 0.05

@SoraJung
Copy link
Author

Now I found different providers have an effect on FPS. FPS increased by 2. ( 6.5 -> 8.5 ). But I'm not sure if this is right.
How about paddle? paddle onnx is faster than pytorch onnx?

  1. code
providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
sess_options = ort.SessionOptions()
sess_options.enable_profiling = True
sess = ort.InferenceSession(args.model, sess_options=sess_options, providers=providers)
  1. result
python ./tools/predict_onnx.py -i ./images/D16030_196_Add00407.jpg

ort.get_device() GPU
torch.Size([1, 3, 640, 640])
Inferece time = 0.1174 s
FPS = 8.52

@lyuwenyu
Copy link
Owner

lyuwenyu commented Jun 20, 2024

    start_time = time.time()
    output = sess.run(
        # output_names=['labels', 'boxes', 'scores'],
        output_names=None,
        input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}        
    )
    end_time = time.time()

I think you can run this piece of code several times, then compute average time.

tic = time.time()
for _ in range(N)
    # code

average_time = (time.time() - tic) / N

@SoraJung
Copy link
Author

Thanks for your advice! I solved the problem. I modified the code, from one img to img directory.
pytorch average FPS 23.45, onnx average FPS 28.32 for 10 images!

  1. pytorch
python ./tools/predict_pytorch.py -c ./configs/rtdetr/rtdetr_r101vd_6x_coco_custom.yml -w ../output/rtdetr_r101vd_6x_coco_custom/checkpoint0004.pth -i ./images/input

img_path: images/input/D16030_196_Add00407.jpg, inf_time: 0.1581, FPS: 6.32
new_file_path:  images/output/D16030_196_Add00407_torch.jpg
================================================================================
Load PResNet101 state_dict
img_path: images/input/aihub3.jpg, inf_time: 0.0446, FPS: 22.40
new_file_path:  images/output/aihub3_torch.jpg
================================================================================
Load PResNet101 state_dict
img_path: images/input/D16030_196_Add00407_1.jpg, inf_time: 0.0430, FPS: 23.26
new_file_path:  images/output/D16030_196_Add00407_1_torch.jpg
================================================================================
.
.
All images count: 10
Average Inferece time = 0.0426 s
Average FPS = 23.45
  1. onnx
python ./tools/predict_onnx.py -i ./images/input/

img_path: ./images/input//D16030_196_Add00407.jpg, inf_time: 0.1257, FPS: 7.95
new_file_path:  images/output/D16030_196_Add00407_onnx.jpg
================================================================================
img_path: ./images/input//aihub3.jpg, inf_time: 0.0415, FPS: 24.09
new_file_path:  images/output/aihub3_onnx.jpg
================================================================================
img_path: ./images/input//D16030_196_Add00407_1.jpg, inf_time: 0.0414, FPS: 24.13
new_file_path:  images/output/D16030_196_Add00407_1_onnx.jpg
================================================================================
img_path: ./images/input//ytb_SterlingT_Suwon_0_000044_1.jpg, inf_time: 0.0415, FPS: 24.12
new_file_path:  images/output/ytb_SterlingT_Suwon_0_000044_1_onnx.jpg
================================================================================
img_path: ./images/input//aihub1.jpg, inf_time: 0.0354, FPS: 28.25
new_file_path:  images/output/aihub1_onnx.jpg
================================================================================
img_path: ./images/input//aihub.jpg, inf_time: 0.0352, FPS: 28.44
new_file_path:  images/output/aihub_onnx.jpg
.
.
All images count: 10
Average Inferece time = 0.0353 s
Average FPS = 28.32

@DaCheng1823
Copy link

@SoraJung Could you please send me the complete modified prediction codes again? I may need your codes. Thanks.

@sangkv
Copy link

sangkv commented Jun 21, 2024

@DaCheng1823 While waiting for @SoraJung , you can temporarily use this pytorch code, I edited it from his code to run on my CPU.

import argparse
from pathlib import Path
import sys
import time
import os
current_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(parent_dir)

from src.core import YAMLConfig 

import torch
from torch import nn
from PIL import Image, ImageDraw
from torchvision import transforms

class ImageReader:
    def __init__(self, resize=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.transform = transforms.Compose([
            # transforms.Resize((resize, resize)) if isinstance(resize, int) else transforms.Resize(
            #     (resize[0], resize[1])),
            transforms.ToTensor(),
            # transforms.Normalize(mean=mean, std=std),
        ])
        self.resize = resize
        self.pil_img = None   

    def __call__(self, image_path, *args, **kwargs):
        self.pil_img = Image.open(image_path).convert('RGB').resize((self.resize, self.resize))
        return self.transform(self.pil_img).unsqueeze(0)


class Model(nn.Module):
    def __init__(self, confg=None, ckpt="") -> None:
        super().__init__()
        self.cfg = YAMLConfig(confg, resume=ckpt)
        if ckpt:
            checkpoint = torch.load(ckpt, map_location='cpu') 
            if 'ema' in checkpoint:
                state = checkpoint['ema']['module']
            else:
                state = checkpoint['model']
        else:
            raise AttributeError('only support resume to load model.state_dict by now.')

        # NOTE load train mode state -> convert to deploy mode
        self.cfg.model.load_state_dict(state)

        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()
        # print(self.postprocessor.deploy_mode)
        
    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)



def get_argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", '-c', type=str, )
    parser.add_argument("--ckpt", '-w', type=str, ) # pth
    parser.add_argument("--image", '-i', type=str, ) 
    parser.add_argument("--device", default="cpu")

    return parser


def main(args):
    img_path = Path(args.image)
    device = torch.device(args.device)
    reader = ImageReader(resize=640)
    model = Model(confg=args.config, ckpt=args.ckpt)
    model.to(device=device)

    img = reader(img_path).to(device)
    size = torch.tensor([[img.shape[2], img.shape[3]]]).to(device)
    
    start_time = time.time()
    output = model(img, size)
    inf_time = time.time() - start_time
    fps = float(1/inf_time)
    print("Inferece time = {} s".format(inf_time, '.4f'))
    print("FPS = {} ".format(fps, '.1f') )
    
    labels, boxes, scores = output
    
    im = reader.pil_img
    draw = ImageDraw.Draw(im)
    thrh = 0.6

    for i in range(img.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        for b in box:
            draw.rectangle(list(b), outline='red', )
            draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )

    # save_path = Path(args.output_dir) / img_path.name
    file_dir = os.path.dirname(args.image)
    new_file_name = os.path.basename(args.image).split('.')[0] + '_torch'+ os.path.splitext(args.image)[1]
    new_file_path = file_dir + '/' + new_file_name
    print('new_file_path: ', new_file_path)
    im.save(new_file_path)
 

if __name__ == "__main__":
    main(get_argparser().parse_args())

And this is the result:

python ./tools/predict_pytorch.py -c ./configs/rtdetr/rtdetr_r18vd_6x_coco.yml -w rtdetr_r18vd_5x_coco_objects365_from_paddle.pth -i ./data/input/i1.jpg
Load PResNet18 state_dict
Inferece time = 0.5181112289428711 s
FPS = 1.930087487276335 
new_file_path:  ./data/input/i1_torch.jpg

This was referenced Jun 21, 2024
@SoraJung
Copy link
Author

SoraJung commented Jun 22, 2024

@SoraJung Could you please send me the complete modified prediction codes again? I may need your codes. Thanks.

It's final code. Check please ^__^

import os 
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import ToTensor
import argparse
import time
from pathlib import Path

def read_img(img_path):
    im = Image.open(img_path).convert('RGB')
    im = im.resize((640, 640))
    im_data = ToTensor()(im)[None]
    # (width, height) = im.size
    # print(im_data.shape)
    # print(width, height)
    # size = torch.tensor([[width, height]])
    size = torch.tensor([[640, 640]])
    return im, im_data, size

def createDirectory(directory):
    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
    except OSError:
        print("Error: Failed to create the directory.")


def main(args, ):
    
    print("ort.get_device()", ort.get_device())
    providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
    sess_options = ort.SessionOptions()
    sess_options.enable_profiling = True
    sess = ort.InferenceSession(args.model, sess_options=sess_options, providers=providers)
    
    img_path_list = []
    possible_img_extension = ['.jpg', '.jpeg', '.JPG', '.bmp', '.png'] # 이미지 확장자들
    for (root, dirs, files) in os.walk(args.img):
        if len(files) > 0:
            for file_name in files:
                if os.path.splitext(file_name)[1] in possible_img_extension:
                    img_path = root + '/' + file_name     
                    img_path_list.append(img_path)
    
    all_inf_time = []
    for img_path in img_path_list:
        im, im_data, size = read_img(img_path) 
        
        tic = time.time()
        output = sess.run(
            # output_names=['labels', 'boxes', 'scores'],
            output_names=None,
            input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}        
        )
        inf_time = time.time() - tic
        fps = float(1/inf_time)
        print('img_path: {}, inf_time: {:.4f}, FPS: {:.2f}'.format(img_path, inf_time, fps))
        all_inf_time.append(inf_time)
        
        #print(type(output))
        #print([out.shape for out in output])

        labels, boxes, scores = output
    
        draw = ImageDraw.Draw(im)  # Draw on the original image
        thrh = 0.6

        for i in range(im_data.shape[0]):

            scr = scores[i]
            lab = labels[i][scr > thrh]
            box = boxes[i][scr > thrh]

            #print(i, sum(scr > thrh))

            for b in box:
                draw.rectangle(list(b), outline='red',)
                # font = ImageFont.truetype("Arial.ttf", 15)
                draw.text((b[0], b[1]), text=str(lab[i]), fill='yellow', )

        # Save the original image with bounding boxes
        file_dir = Path(img_path).parent.parent / 'output'
        createDirectory(file_dir)
        new_file_name = os.path.basename(img_path).split('.')[0] + '_onnx'+ os.path.splitext(img_path)[1]
        new_file_path = file_dir / new_file_name
        print('new_file_path: ', new_file_path)
        print("================================================================================")
        im.save(new_file_path)
    
    avr_time = sum(all_inf_time) / len(img_path_list)
    avr_fps = float(1/avr_time)
    print('All images count: {}'.format(len(img_path_list)))
    print("Average Inferece time = {:.4f} s".format(inf_time))
    print("Average FPS = {:.2f} ".format(fps))
 
if __name__ == '__main__':
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', '-i', type=str, )  # dir 
    parser.add_argument('--model', '-m', type=str, default='model.onnx')

    args = parser.parse_args()

    main(args)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants