
[ROCm] Global (average) Pooling unusable. #15482

Closed
cloudhan opened this issue Apr 12, 2023 · 8 comments · Fixed by #15481
Labels
ep:ROCm questions/issues related to ROCm execution provider

Comments

@cloudhan (Member)
Describe the issue

  1. Crash on some shapes
  2. Incorrect results on some shapes

To reproduce

To reproduce a crash

Run the following single-node model:

import numpy as np
import onnx
import onnxruntime as ort

batch=1
channel=64
dim1 = 410
dim2 = 400

ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(1000)

x = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT16, [batch, channel, dim1, dim2])
y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT16, [batch, channel, 1, 1])

node = onnx.helper.make_node("GlobalAveragePool", inputs=["x"], outputs=["y"])
graph = onnx.helper.make_graph([node], "GP", [x], [y])
model = onnx.helper.make_model(graph)

sess = ort.InferenceSession(
    model.SerializeToString(), providers=[("ROCMExecutionProvider", {"miopen_conv_use_max_workspace": False})]
)

x = np.random.randn(batch, channel, dim1, dim2).astype(np.float16)
sess.run(input_feed = {"x": x}, output_names = ["y"])

It produces the following error:

MIOpen(HIP): Error [Do] 'amd_comgr_do_action(kind, handle, in.GetHandle(), out.GetHandle())' AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE: ERROR (1)
MIOpen(HIP): Error [BuildOcl] comgr status = ERROR (1)
MIOpen(HIP): Warning [BuildOcl] error: stack frame size (328004) exceeds limit (131056) in function 'mloPoolingG'
1 error generated.

MIOpen Error: /long_pathname_so_that_rpms_can_package_the_debug_info/data/driver/MLOpen/src/hipoc/hipoc_program.cpp:304: Code object build failed. Source: MIOpenPooling.cl
2023-04-12 08:27:32.942631957 [E:onnxruntime:Default, rocm_call.cc:119 RocmCall] MIOPEN failure 7: miopenStatusUnknownError ; GPU=0 ; hostname=linmif39a00000F ; file=/home/guangyunhan/onnxruntime/onnxruntime/core/providers/rocm/nn/pool.cc ; line=226 ; expr=PoolingForwardHelper(GetMiopenHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data); 
2023-04-12 08:27:32.942678735 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running GlobalAveragePool node. Name:'' Status Message: MIOPEN failure 7: miopenStatusUnknownError ; GPU=0 ; hostname=linmif39a00000F ; file=/home/guangyunhan/onnxruntime/onnxruntime/core/providers/rocm/nn/pool.cc ; line=226 ; expr=PoolingForwardHelper(GetMiopenHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data);

The maximum shape that runs successfully is

batch=1
channel=64
dim1 = 255
dim2 = 255

#15481 fixes the problem by switching global pooling to use a reduction instead.

This problem impacts the usability of the ROCm EP for our internal users.
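As a rough sketch of the reduction approach mentioned above (an assumption about the idea, not the actual code in #15481): GlobalAveragePool over an NCHW tensor is equivalent to a mean-reduction over the spatial axes, which avoids MIOpen's pooling kernel entirely.

```python
import numpy as np

def global_average_pool(x: np.ndarray) -> np.ndarray:
    """GlobalAveragePool for NCHW input, expressed as a mean-reduction
    over the spatial axes (H, W), keeping them as size-1 dims."""
    return x.mean(axis=(2, 3), keepdims=True)

# The shape that crashes the MIOpen pooling kernel is no problem for a reduction.
x = np.random.randn(1, 64, 410, 400).astype(np.float32)
y = global_average_pool(x)
assert y.shape == (1, 64, 1, 1)
```

In ONNX terms this corresponds to ReduceMean over axes [2, 3] with keepdims=1, so the per-window stack-frame limit of MIOpen's pooling kernel never comes into play.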

To reproduce an incorrect result

import numpy as np
import onnx
import onnxruntime as ort

batch=1
channel=3
dim1 = 255
dim2 = 255

ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(1000)

x = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT16, [batch, channel, dim1, dim2])
y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT16, [batch, channel, 1, 1])

node = onnx.helper.make_node("GlobalAveragePool", inputs=["x"], outputs=["y"])
graph = onnx.helper.make_graph([node], "GP", [x], [y])
model = onnx.helper.make_model(graph)

x = np.random.uniform(low=0.0, high=1.10, size=(batch, channel, dim1, dim2)).astype(np.float16)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
ref = sess.run(input_feed = {"x": x}, output_names = ["y"])[0]

sess = ort.InferenceSession(
    model.SerializeToString(), providers=[("ROCMExecutionProvider", {"miopen_conv_use_max_workspace": False})]
)
y = sess.run(input_feed = {"x": x}, output_names = ["y"])[0]

print(ref.shape)
print(y.shape)
print(ref)
print(y)
print("max relative error:", np.abs((ref-y)/ref).max())
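A plausible (unconfirmed) explanation for the wrong values: averaging 255 × 255 = 65025 half-precision elements with an fp16 accumulator loses almost all later contributions once the accumulator's ULP exceeds the element magnitude. The following is a self-contained model of that effect, not MIOpen's actual kernel:

```python
import numpy as np

# Hypothetical illustration: sequentially accumulate 255*255 fp16 values of 0.5
# in an fp16 accumulator, rounding after every addition.
n = 255 * 255
x = np.full(n, 0.5, dtype=np.float16)

acc = np.float16(0.0)
for v in x:
    acc = np.float16(acc + v)  # rounding to fp16 at every step

naive_mean = float(acc) / n                       # fp16 accumulation
exact_mean = float(x.astype(np.float32).mean())   # fp32 accumulation

# Once acc reaches 1024, the fp16 ULP is 1.0, so adding 0.5 rounds away to
# nothing and the accumulator stalls far below the true sum.
print(naive_mean, exact_mean)
```

If the kernel accumulates in fp32 (or uses a tree reduction), the error stays small, which is consistent with the large "max relative error" the script above reports only on the ROCm EP.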

Urgency

Must be fixed

Platform

Linux

OS Version

Not applicable

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

d49a8de

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Other / Unknown

Execution Provider Library Version

ROCm 5.4.2

@github-actions github-actions bot added the ep:ROCm questions/issues related to ROCm execution provider label Apr 12, 2023
@JehandadKhan

Minimal MIOpenDriver command to reproduce the issue

Crash Issue:

MIOpen(HIP): Command [Pooling_logging_cmd] ./bin/MIOpenDriver poolfp16 -M 0 --input 1x64x410x400,10496000x164000x400x1 -y 410 -x 400 -p 0 -q 0 -v 1 -u 1 -m avg -F 1 -t 1

Incorrect Issue:

MIOpen(HIP): Command [Pooling_logging_cmd] ./bin/MIOpenDriver poolfp16 -M 0 --input 1x3x255x255,195075x65025x255x1 -y 255 -x 255 -p 0 -q 0 -v 1 -u 1 -m avg -F 1 -t 1

@atamazov

atamazov commented Apr 24, 2023

@cloudhan are you interested in getting a proper fix in MIOpen?

@cloudhan
Member Author

Why not?

@atamazov

@cloudhan OK, I'll keep you informed.

@atamazov

The correctness issue in MIOpen is fixed in ROCm/MIOpen#2118.

@atamazov

atamazov commented Sep 4, 2023

@cloudhan

FYI this is fixed for Backward pooling (except workspace index mask mode for Max pooling) in ROCm/MIOpen#2372.
