
[Bug] onnxruntime-gpu 1.14.x is not thread safe #15154

Closed · mzchtx opened this issue Mar 22, 2023 · 6 comments
Labels: ep:CUDA (issues related to the CUDA execution provider)

Comments

mzchtx commented Mar 22, 2023

Describe the issue

There is a thread-safety issue in onnxruntime-gpu 1.14.x. The following configuration reproduces the bug:

  • onnxruntime-gpu 1.14.x + CUDAExecutionProvider


The following configurations do not have this issue:

  • onnxruntime-gpu 1.14.x + CUDAExecutionProvider + CUDA_DEVICE_MAX_CONNECTIONS=1 (see the sketch after this list)
  • onnxruntime-gpu 1.14.x + CPUExecutionProvider
  • onnxruntime-gpu 1.13.1 + CUDAExecutionProvider

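For reference, a minimal sketch of the CUDA_DEVICE_MAX_CONNECTIONS=1 workaround, assuming the variable is exported before CUDA is initialized, i.e. before the first InferenceSession is created (the model path is the one from the repro below):

import os

# Must be set before CUDA initializes, so before onnxruntime creates a session.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import onnxruntime

session = onnxruntime.InferenceSession(
    "./resnet18-v2-7.onnx",
    providers=["CUDAExecutionProvider"],
)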

Was this bug introduced by the feature shown in the attached screenshot?

To reproduce

The resnet18-v2-7.onnx model is from onnx/models/resnet18-v2-7.onnx

import threading

import cv2
import onnxruntime
import numpy as np


class Predictor:
    def __init__(self, model):
        options = onnxruntime.SessionOptions()
        cuda_options = {
            'device_id': 0,
            'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # cap the CUDA arena at 2 GiB
            'arena_extend_strategy': "kSameAsRequested",
        }
        # One session is created here and shared by all worker threads below.
        self.predictor = onnxruntime.InferenceSession(
            model,
            sess_options=options,
            providers=[('CUDAExecutionProvider', cuda_options)]
        )

    @staticmethod
    def preprocess(img):
        img = cv2.resize(img, (224, 224))
        img = img.transpose((2, 0, 1))[::-1]  # HWC -> CHW, then BGR -> RGB
        img = np.ascontiguousarray(img)
        img = img / np.float32(255.0)  # scale to [0, 1]
        img = img[None, :]  # add batch dimension
        return img

    def __call__(self, img):
        image = self.preprocess(img)
        input_name = self.predictor.get_inputs()[0].name
        pred = self.predictor.run(None, {input_name: image})[0]
        pred = pred.flatten().tolist()
        return pred


def worker(predictor, img, worker_idx=0, iter_num=1000):
    for i in range(iter_num):
        out = predictor(img)
        print(worker_idx, i, max(out))


def test():
    predictor = Predictor("./resnet18-v2-7.onnx")
    fname = "ILSVRC2012_val_00008640.JPEG"
    img = cv2.imread(fname)

    # Eight threads run inference concurrently on the shared session.
    worker_num = 8
    iter_num = 1000
    workers = [
        threading.Thread(
            target=worker,
            args=(predictor, img, idx, iter_num)) for idx in range(worker_num)
    ]
    for work in workers:
        work.start()

    for work in workers:
        work.join()


if __name__ == "__main__":
    test()

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04.1 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.x

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.6

github-actions bot added the ep:CUDA (issues related to the CUDA execution provider) label on Mar 22, 2023
mzchtx (Author) commented Mar 23, 2023

I tried building from source to bisect the issue and found that the bug was introduced by commit 13495: the program runs fine when built from commit 13941, but the issue reproduces when built from commit 13495.


jslhcl (Contributor) commented Mar 23, 2023

@mzchtx Thanks for reporting this, and for the script! We can reproduce the issue locally; it appears to be related to the non-default stream. We are investigating it now and will update this thread once there is any progress.

jslhcl (Contributor) commented Mar 24, 2023

@mzchtx Could you try disabling memory pattern by adding the following line after creating the SessionOptions:

options.enable_mem_pattern = False

Please let us know how it goes.
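Applied to the repro script above, the workaround is a one-line change; a minimal sketch (session options otherwise unchanged):

options = onnxruntime.SessionOptions()
options.enable_mem_pattern = False  # workaround: disable the memory pattern optimization

predictor = onnxruntime.InferenceSession(
    "./resnet18-v2-7.onnx",
    sess_options=options,
    providers=[("CUDAExecutionProvider", {"device_id": 0})],
)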

souptc (Member) commented Mar 24, 2023

There is a bug in the memory pattern feature for multi-stream. We are working on the fix; in the meantime, @mzchtx, the options.enable_mem_pattern = False trick should be able to work around it. Let us know whether it helps.

mzchtx (Author) commented Mar 24, 2023

> @mzchtx Could you try disabling memory pattern by adding the following line after creating the SessionOptions:
>
> options.enable_mem_pattern = False
>
> Please let us know how it goes.

It works; options.enable_mem_pattern = False resolves the issue.

jslhcl added a commit that referenced this issue Apr 17, 2023
### Description
Create a stream in DeviceStreamCollection for the memory pattern case to fix the thread-safety issue #15154.

### Motivation and Context
This is to fix bug #15154.
jslhcl (Contributor) commented Apr 17, 2023

Checked in fix #15426.
