
Ort-Genai benchmark run failed at GroupQueryAttention/RunRotaryEmbedding: Exception thrown at 0x00007FFE7576352C (onnxruntime.dll) in model_benchmark.exe: 0xC0000005: Access violation reading location 0x0000018FE4BB8080. #22252

Closed
liqunfu opened this issue Sep 27, 2024 · 3 comments

liqunfu (Contributor) commented Sep 27, 2024

Describe the issue

When running the onnxruntime-genai `model_benchmark -l 128 -g 4 -i C:\example-models\phi2-int4-int8-blklen32-cpu`, the program crashes with:
Exception thrown at 0x00007FFE7576352C (onnxruntime.dll) in model_benchmark.exe: 0xC0000005: Access violation reading location 0x0000018FE4BB8080.

To reproduce

Build onnxruntime-genai from source (https://onnxruntime.ai/docs/genai/howto/build-from-source.html) against a locally built onnxruntime. Build both onnxruntime and onnxruntime-genai in Debug for easier investigation.

Run `model_benchmark -l 128 -g 4 -i C:\example-models\phi2-int4-int8-blklen32-cpu` and get the following call stack:

onnxruntime.dll!onnxruntime::contrib::RunRotaryEmbedding::__l2::(__int64 begin, __int64 end) Line 94 C++
[External Code]
onnxruntime.dll!onnxruntime::concurrency::ThreadPool::ParallelFor(__int64 n, const onnxruntime::TensorOpCost & c, const std::function<void __cdecl(__int64,__int64)> & f) Line 621 C++
onnxruntime.dll!onnxruntime::concurrency::ThreadPool::TryParallelFor(onnxruntime::concurrency::ThreadPool * tp, __int64 total, const onnxruntime::TensorOpCost & cost_per_unit, const std::function<void __cdecl(__int64,__int64)> & fn) Line 703 C++
onnxruntime.dll!onnxruntime::concurrency::ThreadPool::TryParallelFor(onnxruntime::concurrency::ThreadPool * tp, __int64 total, double cost_per_unit, const std::function<void __cdecl(__int64,__int64)> & fn) Line 251 C++
onnxruntime.dll!onnxruntime::contrib::RunRotaryEmbedding(onnxruntime::concurrency::ThreadPool * tp, onnxruntime::contrib::rotary_embedding_helper::RotaryParameters parameters, const float * input, const __int64 * position_ids, const float * cos_cache, const float * sin_cache, float * output, bool interleaved) Line 62 C++
onnxruntime.dll!onnxruntime::contrib::GroupQueryAttention::Compute(onnxruntime::OpKernelContext * context) Line 170 C++
onnxruntime.dll!onnxruntime::ExecuteKernel(onnxruntime::StreamExecutionContext & ctx, unsigned __int64 idx, unsigned __int64 stream_idx, const bool & terminate_flag, onnxruntime::SessionScope & session_scope) Line 495 C++
onnxruntime.dll!onnxruntime::LaunchKernelStep::Execute(onnxruntime::StreamExecutionContext & ctx, unsigned __int64 stream_idx, onnxruntime::SessionScope & session_scope, const bool & terminate_flag, bool & continue_flag) Line 73 C++
onnxruntime.dll!onnxruntime::RunSince(unsigned __int64 stream_idx, onnxruntime::StreamExecutionContext & ctx, onnxruntime::SessionScope & session_scope, const bool & terminate_flag, unsigned __int64 since) Line 222 C++
onnxruntime.dll!onnxruntime::ExecuteThePlan::__l23::() Line 589 C++
[External Code]
onnxruntime.dll!onnxruntime::concurrency::ThreadPool::Schedule(onnxruntime::concurrency::ThreadPool * tp, std::function<void __cdecl(void)> fn) Line 233 C++
onnxruntime.dll!onnxruntime::ExecuteThePlan(const onnxruntime::SessionState & session_state, gsl::span<int const ,-1> feed_mlvalue_idxs, gsl::span<OrtValue const ,-1> feeds, gsl::span<int const ,-1> fetch_mlvalue_idxs, std::vector<OrtValue,std::allocator> & fetches, const std::unordered_map<unsigned __int64,std::function<onnxruntime::common::Status __cdecl(onnxruntime::TensorShape const &,OrtDevice const &,OrtValue &,bool &)>,std::hash,std::equal_to,std::allocator<std::pair<unsigned __int64 const ,std::function<onnxruntime::common::Status __cdecl(onnxruntime::TensorShape const &,OrtDevice const &,OrtValue &,bool &)>>>> & fetch_allocators, const onnxruntime::logging::Logger & logger, const onnxruntime::DeviceStreamCollection * device_streams, const bool & terminate_flag, const bool only_execute_path_to_fetches, bool single_thread_mode) Line 588 C++
onnxruntime.dll!onnxruntime::utils::ExecuteGraphImpl(const onnxruntime::SessionState & session_state, const onnxruntime::FeedsFetchesManager & feeds_fetches_manager, gsl::span<OrtValue const ,-1> feeds, std::vector<OrtValue,std::allocator> & fetches, const std::unordered_map<unsigned __int64,std::function<onnxruntime::common::Status __cdecl(onnxruntime::TensorShape const &,OrtDevice const &,OrtValue &,bool &)>,std::hash,std::equal_to,std::allocator<std::pair<unsigned __int64 const ,std::function<onnxruntime::common::Status __cdecl(onnxruntime::TensorShape const &,OrtDevice const &,OrtValue &,bool &)>>>> & fetch_allocators, ExecutionMode execution_mode, const bool & terminate_flag, const onnxruntime::logging::Logger & logger, onnxruntime::DeviceStreamCollection * device_stream_collection, const bool only_execute_path_to_fetches, onnxruntime::Stream * parent_stream) Line 649 C++
onnxruntime.dll!onnxruntime::utils::ExecuteGraph(const onnxruntime::SessionState & session_state, onnxruntime::FeedsFetchesManager & feeds_fetches_manager, gsl::span<OrtValue const ,-1> feeds, std::vector<OrtValue,std::allocator> & fetches, ExecutionMode execution_mode, const bool & terminate_flag, const onnxruntime::logging::Logger & logger, onnxruntime::DeviceStreamCollectionHolder & device_stream_collection_holder, bool only_execute_path_to_fetches, onnxruntime::Stream * parent_stream) Line 752 C++
onnxruntime.dll!onnxruntime::utils::ExecuteGraph(const onnxruntime::SessionState & session_state, onnxruntime::FeedsFetchesManager & feeds_fetches_manager, gsl::span<OrtValue const ,-1> feeds, std::vector<OrtValue,std::allocator> & fetches, ExecutionMode execution_mode, const OrtRunOptions & run_options, onnxruntime::DeviceStreamCollectionHolder & device_stream_collection_holder, const onnxruntime::logging::Logger & logger) Line 774 C++
onnxruntime.dll!onnxruntime::InferenceSession::Run(const OrtRunOptions & run_options, gsl::span<std::string const ,-1> feed_names, gsl::span<OrtValue const ,-1> feeds, gsl::span<std::string const ,-1> output_names, std::vector<OrtValue,std::allocator> * p_fetches, const std::vector<OrtDevice,std::allocator> * p_fetches_device_info) Line 2608 C++
onnxruntime.dll!onnxruntime::InferenceSession::Run(const OrtRunOptions & run_options, gsl::span<char const * const,-1> feed_names, gsl::span<OrtValue const * const,-1> feeds, gsl::span<char const * const,-1> fetch_names, gsl::span<OrtValue *,-1> fetches) Line 2736 C++
onnxruntime.dll!OrtApis::Run(OrtSession * sess, const OrtRunOptions * run_options, const char * const * input_names, const OrtValue * const * input, unsigned __int64 input_len, const char * const * output_names, unsigned __int64 output_names_len, OrtValue * * output) Line 831 C++
model_benchmark.exe!OrtSession::Run(const OrtRunOptions * run_options, const char * const * input_names, const OrtValue * const * input_values, unsigned __int64 input_count, const char * const * output_names, OrtValue * * output_values, unsigned __int64 output_count) Line 819 C++
model_benchmark.exe!Generators::State::Run(OrtSession & session, OrtRunOptions & run_options, int new_batch_size) Line 70 C++
model_benchmark.exe!Generators::DecoderOnly_State::Run(int current_length, Generators::RoamingArray next_tokens, Generators::RoamingArray next_indices) Line 36 C++
model_benchmark.exe!Generators::Generator::ComputeLogits() Line 160 C++
model_benchmark.exe!Generators::Generate(const Generators::Model & model, const Generators::GeneratorParams & params) Line 236 C++
model_benchmark.exe!OgaGenerate(const OgaModel * model, const OgaGeneratorParams * generator_params, OgaSequences * * out) Line 237 C++
model_benchmark.exe!OgaModel::Generate(const OgaGeneratorParams & params) Line 71 C++
model_benchmark.exe!`anonymous namespace'::RunBenchmark(const benchmark::Options & opts) Line 163 C++
model_benchmark.exe!main(int argc, char * * argv) Line 240 C++

Urgency

yes

Platform

Windows

OS Version

2022

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

commit 7880342

ONNX Runtime API

C

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

edgchen1 (Contributor) commented Oct 1, 2024

here's the problematic code:

```c++
if (packed_qkv) {
  OrtValue RotaryQKV;
  Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
  q_input = Q.Get<Tensor>().Data<T>();
  k_input = q_input + num_heads_ * sequence_length * head_size;
  q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
  k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
  Q = RotaryQKV;
} else {
```

annotated:

```c++
if (packed_qkv) {
  // Q is an OrtValue declared in the enclosing scope.
  OrtValue RotaryQKV;
  Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
  // Save pointer to Q's data in q_input.
  q_input = Q.Get<Tensor>().Data<T>();
  k_input = q_input + num_heads_ * sequence_length * head_size;
  q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
  k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
  // Overwrite Q with RotaryQKV (OrtValues contain shared_ptr to contained value).
  // Now, q_input is pointing to freed memory.
  Q = RotaryQKV;
}
```

later on, when we use `q_input`, there is a read access violation:

```c++
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
                                          pos_ids.data(), cos_cache->Data<T>(),
                                          sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
```

this problem showed up when CPU allocator sharing between sessions was enabled. in that case, the CPU allocator's arena was disabled. I suspect that the default usage of the arena hid this issue.

though I debugged into the first branch, this appears to be a problem in both branches:

```c++
if (packed_qkv) {
  OrtValue RotaryQKV;
  Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
  q_input = Q.Get<Tensor>().Data<T>();
  k_input = q_input + num_heads_ * sequence_length * head_size;
  q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
  k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
  Q = RotaryQKV;
} else {
  OrtValue RotaryQ;
  Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ);
  OrtValue RotaryK;
  Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK);
  q_input = Q.Get<Tensor>().Data<T>();
  k_input = K.Get<Tensor>().Data<T>();
  q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
  k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
  Q = RotaryQ;
  K = RotaryK;
}
```

edgchen1 (Contributor) commented Oct 1, 2024

microsoft/onnxruntime-genai#945 should be a workaround to unblock testing. it disables the allocator sharing by default.

aciddelgado added a commit that referenced this issue Oct 8, 2024
### Description
In GQA there was a memory issue which was best described by @edgchen1
[here](#22252 (comment))

> here's the problematic code:
> 
> https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L157
> 
> annotated:
> 
> ```c++
>     if (packed_qkv) {
>       // Q is an OrtValue declared in the enclosing scope.
>       OrtValue RotaryQKV;
>       Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
>       // Save pointer to Q's data in q_input.
>       q_input = Q.Get<Tensor>().Data<T>();
>       k_input = q_input + num_heads_ * sequence_length * head_size;
>       q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
>       k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
>       // Overwrite Q with RotaryQKV (OrtValues contain shared_ptr to contained value).
>       // Now, q_input is pointing to freed memory.
>       Q = RotaryQKV;
>     }
> ```
> 
> later on, when we use `q_input`, there is a read access violation.
> 
> https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L170-L172
> 
> this problem showed up when CPU allocator sharing between sessions was enabled. in that case, the CPU allocator's arena was disabled. I suspect that the default usage of the arena hid this issue.
> 
> though I debugged into the first branch, this appears to be a problem in both branches:
> 
> https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L168

### Motivation and Context
Fixes a crucial bug. The issue was reported in #22252.
skyline75489 (Contributor) commented
Should we close this since #22290 is now merged?
