From b832ca3d7a895e39a96d6938abd7a18f6800605d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 25 Jul 2025 01:10:16 +1000 Subject: [PATCH 01/33] Support read-only allocator for use with initializers (#25348) ### Description Add new allocator type of OrtReadOnlyAllocator to enable providing a separate allocator that is only used for initializers. Update the SessionState logic to support this allocator type being provided, and use it when doing device allocations for initializers. ### Motivation and Context Performance. --- .../core/session/onnxruntime_c_api.h | 3 +- .../core/session/onnxruntime_ep_c_api.h | 10 ++++- onnxruntime/core/framework/session_state.cc | 39 ++++++++++++++++--- onnxruntime/core/framework/session_state.h | 18 ++++++--- .../core/framework/session_state_utils.cc | 24 ++++++++---- .../core/framework/simple_tensor_allocator.cc | 2 +- .../core/framework/tensor_allocator.cc | 5 ++- onnxruntime/core/framework/tensor_allocator.h | 1 + .../tensor_allocator_with_mem_pattern.h | 12 ++++-- onnxruntime/core/session/abi_devices.h | 3 ++ onnxruntime/core/session/ep_api.cc | 6 ++- .../session/ep_plugin_provider_interfaces.cc | 4 ++ onnxruntime/test/autoep/library/ep_factory.cc | 19 ++++++++- onnxruntime/test/autoep/library/ep_factory.h | 1 + 14 files changed, 117 insertions(+), 30 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9fd9e376cbf0d..8adc17b44826c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -421,7 +421,8 @@ typedef struct OrtCustomOp OrtCustomOp; typedef enum OrtAllocatorType { OrtInvalidAllocator = -1, OrtDeviceAllocator = 0, - OrtArenaAllocator = 1 + OrtArenaAllocator = 1, + OrtReadOnlyAllocator = 2, } OrtAllocatorType; /** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index f7e304c98d7b5..1d9f9d00387ba 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -338,8 +338,14 @@ struct OrtEpApi { * The registered values will be used in calls to OrtEpFactory::CreateAllocator to ensure the required allocator/s * are available for EP usage. * - * At most one DEFAULT and one HOST_ACCESSIBLE entry can be added. - * Multiple calls for the same memory type will replace a previous entry. + * Multiple calls for the same entry type will replace a previous entry. + * + * Available entries: + * - OrtDeviceAllocator with type of OrtDeviceMemoryType_DEFAULT + * - OrtDeviceAllocator with type of OrtDeviceMemoryType_HOST_ACCESSIBLE + * - OrtReadOnlyAllocator with type of OrtDeviceMemoryType_DEFAULT + * - if provided this allocator will only be used to copy initializers to the device the EP uses. + * ORT will use the OrtDeviceAllocator if not provided. * * \param[in] ep_device The OrtEpDevice instance to register the OrtMemoryInfo with. * \param[in] allocator_memory_info The OrtMemoryInfo information for the allocator. diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 2cd5103b823d1..98cc2158eb0d0 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -89,7 +89,8 @@ SessionState::SessionState(Graph& graph, profiling::Profiler& profiler, const SessionOptions& sess_options, PrepackedWeightsContainer* prepacked_weights_container, - AllocatorMap* parent_allocators) + AllocatorMap* parent_allocators, + AllocatorMap* parent_initializer_allocators) : graph_(graph), execution_providers_(execution_providers), logger_(logger), @@ -109,16 +110,26 @@ SessionState::SessionState(Graph& graph, sess_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; if (parent_allocators) { allocators_ = parent_allocators; + initializer_allocators_ = parent_initializer_allocators; } else { allocators_unique_ptr_ = std::make_unique(); allocators_ = allocators_unique_ptr_.get(); + + initializer_allocators_unique_ptr_ = std::make_unique(); + initializer_allocators_ = initializer_allocators_unique_ptr_.get(); + // The allocator registration rule: // Each location (OrtDevice) will only have 1 allocator used for whole session. - // The EP which is registered first will have higher priority + // The EP which is registered first will have higher priority. + // Allocators with a OrtAllocatorType of OrtReadOnlyAllocator go in the initializer allocators for (auto& ep : execution_providers_) { auto allocators = ep->CreatePreferredAllocators(); for (auto& alloc : allocators) { - allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key + if (alloc->Info().alloc_type == OrtReadOnlyAllocator) { + initializer_allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key + } else { + allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key + } } } } @@ -130,13 +141,29 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe AllocatorPtr SessionState::GetAllocator(const OrtDevice& device) const noexcept { auto it = allocators_->find(device); - if (it != allocators_->end()) return it->second; + if (it != allocators_->end()) { + return it->second; + } + return nullptr; } +AllocatorPtr SessionState::GetInitializerAllocator(const OrtDevice& device) const noexcept { + auto it = initializer_allocators_->find(device); + if (it != initializer_allocators_->end()) { + return it->second; + } + + return GetAllocator(device); +} + void SessionState::UpdateAllocatorsWithEnvAllocators(const std::vector& env_allocators) { for (const auto& env_alloc : env_allocators) { - (*allocators_)[env_alloc->Info().device] = env_alloc; + if (env_alloc->Info().alloc_type == OrtReadOnlyAllocator) { + (*initializer_allocators_)[env_alloc->Info().device] = env_alloc; + } else { + (*allocators_)[env_alloc->Info().device] = env_alloc; + } } } @@ -1158,7 +1185,7 @@ Status SessionState::CreateSubgraphSessionState() { std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, external_data_loader_mgr_, logger_, profiler_, sess_options_, - prepacked_weights_container_, allocators_); + prepacked_weights_container_, allocators_, initializer_allocators_); // Pass fused function manager to subgraph subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 71b88cb692f6f..e2102d95e1f17 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -98,7 +98,8 @@ class SessionState { profiling::Profiler& profiler, const SessionOptions& sess_options, PrepackedWeightsContainer* prepacked_weights_container = nullptr, - AllocatorMap* parent_allocators = nullptr); + AllocatorMap* parent_allocators = nullptr, + AllocatorMap* parent_initializer_allocators = nullptr); ~SessionState() { } @@ -127,6 +128,12 @@ class SessionState { /** Get the allocator for a given OrtDevice. The first allocator that matches will be returned. */ AllocatorPtr GetAllocator(const OrtDevice& device) const noexcept; + /** + Get an allocator for the given OrtDevice that is only used for read-only initializers. + Falls back to calling GetAllocator as needed. + */ + AllocatorPtr GetInitializerAllocator(const OrtDevice& device) const noexcept; + /* * Get allocators. */ @@ -464,17 +471,18 @@ class SessionState { } }; - // using std::map as OrtDevice would need a custom hash function to be used with std::unordered_map, - // and as this isn't considered performance critical currently it's not worth the maintenance overhead of adding one. - // We do get an allocator from ExecutionFrame so this is looked up frequently, however there most likely aren't many - // entries in the map // SessionState will contain other SessionState objects for subgraph. The unique ptr will be initialized only the // SessionState object is in the parent graph, the raw pointer will be initialized when session state is in parent // graph (from the unique ptr) or in the subgraph (from the raw pointer from parent session state). The raw pointer // will be used all the way to access std::map, unique pointer is only releasing the resource // when the parent session state is releasing. std::unique_ptr allocators_unique_ptr_; + // allocators with type of OrtAllocatorType::OrtReadOnlyAllocator that are used for initializers if found. + // if not we fallback to lookup in allocators_; + std::unique_ptr initializer_allocators_unique_ptr_; + AllocatorMap* allocators_; + AllocatorMap* initializer_allocators_; OrtValueNameIdxMap ort_value_name_idx_map_; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 8f0713fcd7cb1..17e337838b091 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -37,8 +37,10 @@ namespace session_state_utils { // The following method will allocate memory directly using the device allocator. // It can handle arena-based allocators and non-arena based allocators. -static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, const DataTypeImpl* type, - const AllocatorPtr& alloc, /*out*/ void*& p_data) { +static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, + const DataTypeImpl* type, + const AllocatorPtr& alloc, + /*out*/ void*& p_data) { size_t mem_size = 0; ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(type, tensor_shape, /*alignment*/ 0, mem_size)); @@ -76,13 +78,14 @@ static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const T * data loading, allocation, or copying operation fails. */ static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* memory_buffer, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer* memory_buffer, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, PrepackedWeightsForGraph& prepacked_for_graph, bool use_device_allocator_for_initializers = false) { - if (bool(alloc) == (memory_buffer != nullptr)) { + if (alloc != nullptr && memory_buffer != nullptr) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } @@ -138,7 +141,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } else { // for internal initializer, always allocate memory on device - tensor - ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, + use_device_allocator_for_initializers, alloc)); if (device == default_cpu_device) { // deserialize directly to CPU tensor @@ -370,6 +374,9 @@ common::Status SaveInitializedTensors( AllocatorPtr alloc; // TODO: if the tensor need be copied, does it have enough room? ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc)); + + // ??? Should we ignore this session option if the EP is explictly providing the read only allocator? + // bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator; const bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault( kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; @@ -398,9 +405,10 @@ common::Status SaveInitializedTensors( // We need to deserialize the tensor proto into an OrtValue // using the preallocated buffer or allocator. - Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (memory_buffer.has_value()) ? &*memory_buffer : nullptr, alloc, - default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, - prepacked_for_graph, + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, + (memory_buffer.has_value()) ? &*memory_buffer : nullptr, + alloc, default_cpu_alloc, ort_value, data_transfer_mgr, + external_data_loader_mgr, prepacked_for_graph, use_device_allocator_for_initializers); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/simple_tensor_allocator.cc b/onnxruntime/core/framework/simple_tensor_allocator.cc index ad9e0393baa01..d919e0c3c4a13 100644 --- a/onnxruntime/core/framework/simple_tensor_allocator.cc +++ b/onnxruntime/core/framework/simple_tensor_allocator.cc @@ -14,7 +14,7 @@ common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, AllocatorPtr& alloc_out) { const struct OrtDevice& location = seq_plan_.GetLocation(ort_value_index); // just return allocator and let others handle it. - alloc_out = GetAllocator(location); + alloc_out = GetInitializerAllocator(location); return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensor_allocator.cc b/onnxruntime/core/framework/tensor_allocator.cc index 9e81e8cd4783d..84d40d60a5d47 100644 --- a/onnxruntime/core/framework/tensor_allocator.cc +++ b/onnxruntime/core/framework/tensor_allocator.cc @@ -5,11 +5,14 @@ #include "simple_tensor_allocator.h" namespace onnxruntime { - AllocatorPtr ITensorAllocator::GetAllocator(const OrtDevice& device) { return session_state_.GetAllocator(device); } +AllocatorPtr ITensorAllocator::GetInitializerAllocator(const OrtDevice& device) { + return session_state_.GetInitializerAllocator(device); +} + std::unique_ptr ITensorAllocator::Create(bool enable_mem_pattern, const ExecutionPlanBase& execution_plan, const SessionState& session_state, diff --git a/onnxruntime/core/framework/tensor_allocator.h b/onnxruntime/core/framework/tensor_allocator.h index 923320681e683..daddfc7fd3cc0 100644 --- a/onnxruntime/core/framework/tensor_allocator.h +++ b/onnxruntime/core/framework/tensor_allocator.h @@ -26,6 +26,7 @@ class ITensorAllocator { InlinedVector& weights_buffers); AllocatorPtr GetAllocator(const OrtDevice& device); + AllocatorPtr GetInitializerAllocator(const OrtDevice& device); /** * diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index ad88149c89b81..98179b96891b3 100644 --- a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -28,7 +28,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { planned_memory_sizes_in_byte.reserve(location_len); for (size_t i = 0; i < location_len; ++i) { auto& location = mem_patterns_.locations[i]; - auto alloc = GetAllocator(location); + auto alloc = GetInitializerAllocator(location); if (!alloc) return Status(common::ONNXRUNTIME, common::FAIL, "Failed to get allocator for location: " + location.ToString()); @@ -80,25 +80,28 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { if (!is_sealed_) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error."); } + const struct OrtDevice& location = seq_plan_.GetLocation(ort_value_index); auto pattern = mem_patterns_.GetPatterns(location); if (pattern == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mem pattern for initializer ", name, " is not found"); } + // if block is not found, means this ort_value is not traced // fall back to allocate separate buffer. // if it->second.get() is null, then fall back to the block not found case auto block = pattern->GetBlock(ort_value_index); if (nullptr == block) { // not traced, only return allocator - alloc_out = GetAllocator(location); + alloc_out = GetInitializerAllocator(location); return Status::OK(); } + auto it = buffers_.find(location); if (it == buffers_.end()) { if (block != nullptr && block->size_ == 0) { // Because the size is 0, this miss find is expected. we won't allocate a buffer with size of zero. - buf_out.emplace(nullptr, 0, GetAllocator(location)->Info()); + buf_out.emplace(nullptr, 0, GetInitializerAllocator(location)->Info()); return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Weight buffer for initializer '", name, "' is not found"); @@ -108,7 +111,8 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Get preallocated buffer for initializer '", name, "' failed"); } - buf_out.emplace(reinterpret_cast(it->second) + block->offset_, block->size_, GetAllocator(location)->Info()); + buf_out.emplace(static_cast(it->second) + block->offset_, block->size_, + GetInitializerAllocator(location)->Info()); return Status::OK(); } common::Status Trace(int id, const ONNX_NAMESPACE::TensorProto* value) override { diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 50469126996b2..571a9eb2a54e2 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -68,6 +68,9 @@ struct OrtEpDevice { const OrtMemoryInfo* device_memory_info{nullptr}; const OrtMemoryInfo* host_accessible_memory_info{nullptr}; + // used internally by ORT for initializers only. optional. + const OrtMemoryInfo* read_only_device_memory_info{nullptr}; + // the user provides const OrtEpDevice instances, but the OrtEpFactory API takes non-const instances for all // get/create methods to be as flexible as possible. this helper converts to a non-const factory instance. OrtEpFactory* GetMutableFactory() const { return ep_factory; } diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index c49985d74c988..6c57f95719f41 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -114,7 +114,11 @@ ORT_API_STATUS_IMPL(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device, const OrtDevice& info = allocator_memory_info->device; switch (info.MemType()) { case OrtDevice::MemType::DEFAULT: - ep_device->device_memory_info = allocator_memory_info; + if (allocator_memory_info->alloc_type == OrtReadOnlyAllocator) { + ep_device->read_only_device_memory_info = allocator_memory_info; + } else { + ep_device->device_memory_info = allocator_memory_info; + } break; case OrtDevice::MemType::HOST_ACCESSIBLE: ep_device->host_accessible_memory_info = allocator_memory_info; diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index 52cf6c62c9702..c776020b037f0 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -135,6 +135,10 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio if (ep_device->host_accessible_memory_info != nullptr) { allocator_mem_infos_.push_back(ep_device->host_accessible_memory_info); } + + if (ep_device->read_only_device_memory_info != nullptr) { + allocator_mem_infos_.push_back(ep_device->read_only_device_memory_info); + } } } diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 019cf77a66b88..1cffb72c84879 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -47,6 +47,17 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL const OrtMemoryDevice* device = ep_api.MemoryInfo_GetMemoryDevice(default_memory_info_.get()); data_transfer_impl_ = std::make_unique(apis, device); + // create read-only allocator for use with initializers. same info as DEFAULT memory apart from the allocator type. + status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU readonly", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ 0xBE57, /* device_id */ 0, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtReadOnlyAllocator, + &mem_info); + assert(status == nullptr); // should never fail. + + readonly_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + // HOST_ACCESSIBLE memory example. use the non-CPU device type so it's clear which device the memory is also // accessible from. we infer from the type of HOST_ACCESSIBLE that it's CPU accessible. mem_info = nullptr; @@ -121,7 +132,9 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // register the allocator info required by the EP. // registering OrtMemoryInfo for host accessible memory would be done in an additional call. + // OrtReadOnlyAllocator + OrtDeviceMemoryType_DEFAULT allocator for use with initializers is optional. RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_.get())); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_.get())); ep_devices[num_ep_devices++] = ep_device; } @@ -200,7 +213,10 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this auto& factory = *static_cast(this_ptr); *allocator = nullptr; - if (memory_info != factory.default_memory_info_.get()) { + bool is_default_allocator = memory_info == factory.default_memory_info_.get(); + bool is_readonly_allocator = memory_info == factory.readonly_memory_info_.get(); + + if (!is_default_allocator && !is_readonly_allocator) { return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. " "Value did not come directly from an OrtEpDevice returned by this factory."); @@ -209,6 +225,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make // ReleaseAllocatorImpl a no-op. + // auto cpu_allocator = std::make_unique(memory_info, factory); *allocator = cpu_allocator.release(); return nullptr; diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 4b286928a79eb..60c9f63b78b8c 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -68,6 +68,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. using MemoryInfoUniquePtr = std::unique_ptr>; MemoryInfoUniquePtr default_memory_info_; + MemoryInfoUniquePtr readonly_memory_info_; // used for initializers std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; From bfa2c912348f1ba8bf92dd467454dc7a975ee2d1 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 24 Jul 2025 10:33:55 -0700 Subject: [PATCH 02/33] Move Linux CUDA pipelines to H100 (#25523) --- .github/workflows/linux_cuda_ci.yml | 6 +++--- .github/workflows/linux_tensorrt_ci.yml | 6 +++--- onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 38526e7a5c00f..f4ee8a7c27cd0 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -29,7 +29,7 @@ jobs: dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1' docker_image_repo: onnxruntimecuda12manylinuxbuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -42,7 +42,7 @@ jobs: needs: build-linux-cuda-x64-release runs-on: - self-hosted - - "1ES.Pool=Onnxruntime-github-Linux-GPU-A100-WUS3" + - "1ES.Pool=Onnxruntime-github-Linux-GPU-H100" permissions: contents: read packages: read @@ -99,5 +99,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda' - extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 1df467043329a..a7d3f5ec0f5fd 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -29,7 +29,7 @@ jobs: dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 --build-arg TRT_VERSION=10.9.0.34-1.cuda12.8 --network=host' docker_image_repo: onnxruntimetensorrt86gpubuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -42,7 +42,7 @@ jobs: needs: build-linux-TensorRT-x64-release runs-on: - self-hosted - - "1ES.Pool=Onnxruntime-github-Linux-GPU-A100-WUS3" + - "1ES.Pool=Onnxruntime-github-Linux-GPU-H100" permissions: contents: read packages: read @@ -101,5 +101,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda tensorrt' - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 1aea58c8d7a10..a49f662ca1adb 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -46,7 +46,7 @@ } else if (std::is_same::value) { \ MAKE_PROVIDERS_EPS_EXT(2e-4, pad_to_nc1d) \ } else { \ - MAKE_PROVIDERS_EPS_EXT(2e-3, pad_to_nc1d) \ + MAKE_PROVIDERS_EPS_EXT(4e-3, pad_to_nc1d) \ } #define MAKE_PROVIDERS_EPS_TYPE(T) \ From 9001123f6813409bce2d8ec24888ac73e348c26e Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Thu, 24 Jul 2025 23:58:38 +0530 Subject: [PATCH 03/33] [OVEP] OpenVINO EP Features for ORT 1.23 Release Patch (#25525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This PR patches the features provided for this PR https://github.com/microsoft/onnxruntime/pull/25476, this provides a stable fix for the GPU plugin with upcoming OV toolkit v2025.2.1 --------- Signed-off-by: Jianhui Dai Signed-off-by: dependabot[bot] Signed-off-by: bfilipek Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Co-authored-by: n1harika Co-authored-by: sfatimar Co-authored-by: Jaskaran Singh Nagi Co-authored-by: Eric Crawford Co-authored-by: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Co-authored-by: Scott McKay Co-authored-by: Seungtaek Kim Co-authored-by: co63oc Co-authored-by: Jambay Kinley Co-authored-by: Hector Li Co-authored-by: Jian Chen Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Jiajia Qin Co-authored-by: Alessio Soldano Co-authored-by: Changming Sun Co-authored-by: Ashish Garg Co-authored-by: Ashish Garg Co-authored-by: Jie Chen Co-authored-by: wp Co-authored-by: Satya Kumar Jandhyala Co-authored-by: Prathik Rao Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: Jianhui Dai Co-authored-by: xhcao Co-authored-by: Wanming Lin Co-authored-by: Mark Schofield Co-authored-by: jiangzhaoming Co-authored-by: Yi-Hong Lyu Co-authored-by: vraspar Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: saurabh Co-authored-by: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Co-authored-by: Baiju Meswani Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: jatinwadhwa921 Co-authored-by: Pallavi Gupta Co-authored-by: Nikolay Proshunin Co-authored-by: Preetha Veeramalai Co-authored-by: Javier Martinez Co-authored-by: Bartlomiej Filipek Co-authored-by: bopeng1234 Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: TejalKhade28 Co-authored-by: Vishnudas Thaniel S Co-authored-by: Yaru Du Co-authored-by: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Co-authored-by: Dvoretckii, Mikhail Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Adrian Lizarraga Co-authored-by: Fei Chen Co-authored-by: qti-yuduo Co-authored-by: Akupadhye Co-authored-by: Wang Ning Co-authored-by: Maximilian Müller <44298237+gedoensmax@users.noreply.github.com> Co-authored-by: George Wu Co-authored-by: quic-calvnguy Co-authored-by: Wei-Sheng Chin Co-authored-by: quic-hungjuiw Co-authored-by: Ian Hunter Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Jeff Kilpatrick Co-authored-by: Jeff Kilpatrick Co-authored-by: Nenad Banfic <46795300+nenad1002@users.noreply.github.com> Co-authored-by: derdeljan-msft Co-authored-by: Ryan Metcalfe --- onnxruntime/core/providers/openvino/ov_factory.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 9f0871b14e92f..8860405338409 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -103,12 +103,12 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* const auto& ov_device_type = device_it->second; std::string ov_device_name; - auto get_pci_device_id = [&](const std::string& ov_device) { + auto get_gpu_device_id = [&](const std::string& ov_device) { try { - ov::device::PCIInfo pci_info = ov_core_->get_property(ov_device, ov::device::pci_info); - return pci_info.device; + auto device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); + return static_cast(std::stoul(device_id_str, nullptr, 0)); } catch (ov::Exception&) { - return 0u; // If we can't get the PCI info, we won't have a device ID. + return 0u; // If we can't get the GPU_DEVICE_ID info, we won't have a device ID. } }; @@ -118,7 +118,7 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* // If there are multiple devices of the same type, we need to match by device ID. matched_device = std::find_if(filtered_devices.begin(), filtered_devices.end(), [&](const std::string& ov_device) { uint32_t ort_device_id = ort_api.HardwareDevice_DeviceId(&device); - return ort_device_id == get_pci_device_id(ov_device); + return ort_device_id == get_gpu_device_id(ov_device); }); } From 52fd75f63c7af5917273035c75954df62012caea Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 24 Jul 2025 13:05:17 -0700 Subject: [PATCH 04/33] Fix QNN SDK download problem (#25520) Previously the machine pool had a User-assigned managed identity (UMI) which was used for accessing the blob storage. Now the UMI was removed. to improve security. Therefore we baked the data into the VM image instead. --- ...arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../templates/android-java-api-aar.yml | 2 +- .../templates/jobs/init_linux_qnn_sdk_x64.yml | 42 +++++++++++++++++++ .../templates/py-linux-qnn.yml | 2 +- 4 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index e5e2a4749ef85..91f35d2b54033 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -52,7 +52,7 @@ jobs: - script: sudo chmod go+rw /dev/kvm displayName: Update permissions to KVM - - template: templates/jobs/download_linux_qnn_sdk.yml + - template: templates/jobs/init_linux_qnn_sdk_x64.yml parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 92e862bd79008..bbb84642320fb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -103,7 +103,7 @@ jobs: - template: use-android-ndk.yml - ${{ if contains(parameters.packageName, 'qnn') }}: - - template: jobs/download_linux_qnn_sdk.yml + - template: jobs/init_linux_qnn_sdk_x64.yml parameters: QnnSDKVersion: '${{parameters.QnnSDKVersion}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml new file mode 100644 index 0000000000000..b7fb8a51f28be --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml @@ -0,0 +1,42 @@ +parameters: + - name: QnnSDKVersion + type: string + default: '2.36.1.250708' + +steps: + - bash: | + echo "##vso[task.setvariable variable=QnnSDKRootDir]/data/qnnsdk/qnn-v${{ parameters.QnnSDKVersion }}" + displayName: Set QnnSDKRootDir + + - script: | + echo $(QnnSDKRootDir) + displayName: 'Print QnnSDKRootDir after downloading QNN SDK' + + - script: | + set -x + sdk_file="$(QnnSDKRootDir)/sdk.yaml" + # Parse the sdk.yaml file to get the QNN SDK version downloaded + downloaded_qnn_sdk_version=$(grep '^version:' "$sdk_file" | head -n 1 | cut -d':' -f2 | xargs | cut -d'.' -f1-3 | tr -d '\r') + + # Extract major.minor.patch part from QnnSDKVersion passed as parameter + expected_qnn_sdk_version=$(echo ${{ parameters.QnnSDKVersion }} | cut -d'.' -f1-3) + + if [[ -z "$downloaded_qnn_sdk_version" ]]; then + echo "QNN version not found in sdk.yaml." + exit 1 + fi + + # Compare provided version with version from sdk.yaml + if [[ "$downloaded_qnn_sdk_version" == "$expected_qnn_sdk_version" ]]; then + echo "Success: QnnSDKVersion matches sdk.yaml version ($downloaded_qnn_sdk_version)." + else + echo "Error: QnnSDKVersion ($expected_qnn_sdk_version) does not match sdk.yaml version ($downloaded_qnn_sdk_version) in the QNN SDK directory" + exit 1 + fi + displayName: "Sanity Check: QnnSDKVersion vs sdk.yaml version" + + + + - script: | + ls -al $(QnnSDKRootDir) + displayName: 'Print contents of QNN SDK' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 788ceff8fd4f2..2168214527c91 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -58,7 +58,7 @@ jobs: clean: true submodules: none - - template: jobs/download_linux_qnn_sdk.yml + - template: jobs/init_linux_qnn_sdk_x64.yml parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} From d7029780256625d650cd2a237a3d1cc6ef1ed211 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 24 Jul 2025 14:47:17 -0700 Subject: [PATCH 05/33] [EP ABI] API to get external initializer info + lazy load external OrtValues (#25482) ### Description - Adds APIs to get information (file path, file offset, byte size) for initializers with data in external files. This allows EPs to do their own custom memory-mapping of initializer data. By default, EPs that don't have specific requirements can still use `ValueInfo_GetInitializerValue` to get an `OrtValue` with memory-mapped initializer data. - Updates `OrtGraph` to only load `OrtValue` for external initializers on demand. This prevents having to memory map all external initializers before the first call to `OrtEp::GetCapability`. Follow up to https://github.com/microsoft/onnxruntime/pull/25320 New API functions: | Function | Summary| |-----------|--------------| | `ValueInfo_GetExternalInitializerInfo` | Get `OrtExternalInitializerInfo` from `OrtValueInfo` (or `NULL`). Must be released with `ReleaseExternalInitializerInfo`| | `ReleaseExternalInitializerInfo` | Releases the `OrtExternalInitializerInfo` instance | | `ExternalInitializerInfo_GetFilePath` | Returns the relative path to the file that stores the initializer's data | | `ExternalInitializerInfo_GetFileOffset` | Returns the byte offset within the file where the initializer's data is stored | | `ExternalInitializerInfo_GetByteSize` | Returns the size in bytes of the initializer's data within the file | ### Motivation and Context --------- Co-authored-by: Dmitri Smirnov Co-authored-by: Scott McKay --- include/onnxruntime/core/graph/graph.h | 22 +++ .../core/providers/utils/ort_graph_to_proto.h | 6 + .../core/session/onnxruntime_c_api.h | 72 ++++++++- .../core/framework/tensorprotoutils.cc | 4 + onnxruntime/core/framework/tensorprotoutils.h | 7 + onnxruntime/core/graph/abi_graph_types.h | 15 ++ onnxruntime/core/graph/ep_api_types.cc | 76 ++++++++-- onnxruntime/core/graph/ep_api_types.h | 6 +- onnxruntime/core/graph/graph.cc | 44 +++++- .../core/graph/model_editor_api_types.h | 6 + onnxruntime/core/session/onnxruntime_c_api.cc | 34 +++++ onnxruntime/core/session/ort_apis.h | 8 + onnxruntime/test/ep_graph/test_ep_graph.cc | 141 ++++++++++++++++-- .../test/ep_graph/test_ep_graph_utils.h | 24 +++ 14 files changed, 436 insertions(+), 29 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index ea9cbbfc6ca73..e164f23b8fc35 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -44,6 +44,7 @@ struct OrtGraph; namespace onnxruntime { +class ExternalDataInfo; class Graph; struct IndexedSubGraph; class Model; @@ -788,6 +789,27 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetOrtValueInitializer(const std::string& name, OrtValue& value, bool check_outer_scope = false) const; + /// + /// Loads an initializer with data in an external file into an OrtValue. Does NOT cache the OrtValue + /// in this Graph. + /// + /// The name of the initializer. + /// Output parameter set to the loaded OrtValue. Set to an existing OrtValue if + /// it is already loaded. + /// A status indicating an error or success. An error occurs if `name` is not an initializer + /// with external data. + Status LoadExternalInitializerAsOrtValue(const std::string& name, OrtValue& value) const; + + /// + /// Gets information (external filepath, file offset, num bytes) for an initializer with data in an external file. + /// + /// The initializer's name. + /// Output parameter set to the location information of the external data. + /// Set to true if parent graphs should be checked. + /// True if `name` refers to an initializer with data in an external file. Otherwise, returns false + bool GetExternalInitializerInfo(const std::string& name, std::unique_ptr& ext_info, + bool check_outer_scope = false) const; + /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index ce0f134002d8e..b7311f70cd179 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied +// to other projects. + /* SUMMARY: Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider @@ -494,11 +497,14 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, auto* ext_data_entries = tensor_proto->mutable_external_data(); onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); location_entry->set_key("location"); location_entry->set_value(ext_location); offset_entry->set_key("offset"); offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); } else { // User wants to store data inline the TensorProto's raw_data tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8adc17b44826c..d70806e1a5a87 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -324,6 +324,7 @@ ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. +ORT_RUNTIME_CLASS(ExternalInitializerInfo); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -5484,10 +5485,13 @@ struct OrtApi { * * Supports initializers defined in an outer scope (i.e., a parent graph). * + * Supports initializers stored in an external file. For external initializers, ORT memory maps + * the initializer data on the first call to this function. If caller needs custom memory mapping, + * use ValueInfo_GetExternalInitializerInfo to get the location of the initializer data. + * * \param[in] value_info The OrtValueInfo instance. - * \param[out] initializer_value Output parameter set to the initializer value or NULL. The OrtValue data pointer - * (obtained via GetTensorData) is stable during the lifetime of the OrtSession - * that owns the OrtGraph. + * \param[out] initializer_value Output parameter set to the initializer value or NULL. Do not cache the OrtValue + * as it is released when the owning OrtGraph is released. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5496,6 +5500,24 @@ struct OrtApi { ORT_API2_STATUS(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtValue** initializer_value); + /** \brief Get information about an external initializer (e.g., filepath, file offset, byte size). + * + * Sets the output parameter `info` to NULL if the given OrtValueInfo does not represent an initializer + * with external data. In this case, a NULL status (non-error) is returned. + * + * \param[in] value_info The OrtValueInfo instance. + * \param[out] info Output parameter set to an OrtExternalInitializerInfo instance that can be used to query + * file path, file offset, etc. ORT sets this to NULL if the OrtValueInfo does not represent + * an external initializer. + * Must release with ReleaseExternalInitializerInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_GetExternalInitializerInfo, _In_ const OrtValueInfo* value_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** info); + /** \brief Returns a boolean indicating if the given value is a required graph input. * * For ONNX IR version < 4, all graph inputs without a matching initializer are required. @@ -6085,6 +6107,50 @@ struct OrtApi { /// @} + /// \name OrtExternalInitializerInfo + /// @{ + + /** \brief Release an OrtExternalInitializerInfo instance. + * + * \param[in] input OrtExternalInitializerInfo instance to be released. + * + * \since Version 1.23. + */ + ORT_CLASS_RELEASE(ExternalInitializerInfo); + + /** \brief Get the relative path to the file that stores the initializer's data. + * + * \note The path is relative to the filesystem directory where the ONNX model was stored. + * Caller can use Graph_GetModelPath to get the model's full path and construct the absolute path to the + * external initializer file if necessary. + * + * \param[in] info The OrtExternalInitializerInfo instance. + * \return The relative path to the file that stores the initializer's data. Do NOT free this pointer. + * + * \since Version 1.23. + */ + ORT_API_T(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); + + /** \brief Get the byte offset within the file where the initializer's data is stored. + * + * \param[in] info The OrtExternalInitializerInfo instance. + * \return The byte offset where the initializer's data is stored within the file. + * + * \since Version 1.23. + */ + ORT_API_T(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); + + /** \brief Get the size in bytes of the initializer's data within the file. + * + * \param[in] info The OrtExternalInitializerInfo instance. + * \return The size in bytes of the initializer's data within the file. + * + * \since Version 1.23. + */ + ORT_API_T(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); + + /// @} + /// \name OrtRunOptions /// @{ diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 5ba0b908edaf5..ff440b595e499 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -252,6 +252,10 @@ bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) { return false; // No external data in memory } +bool HasExternalDataInFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + return HasExternalData(tensor_proto) && !HasExternalDataInMemory(tensor_proto); +} + Status TensorProtoWithExternalDataToTensorProto( const ONNX_NAMESPACE::TensorProto& ten_proto, const std::filesystem::path& model_path, diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index e9148243f98b1..01086f38c8823 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -502,6 +502,13 @@ inline bool HasName(const ONNX_NAMESPACE::TypeProto_Opaque& op_proto) { /// true if ten_proto has external data and it is in memory [[nodiscard]] bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto); +/// +/// Check if the given tensor proto has external data store in a file (not in memory). +/// +/// +/// +[[nodiscard]] bool HasExternalDataInFile(const ONNX_NAMESPACE::TensorProto& tensor_proto); + /// /// This function converts TensorProto with external data to TensorProto with inline data. /// diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 0e939f7986aac..6383d29d7a2bc 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -7,6 +7,7 @@ #include #include #include "core/common/inlined_containers_fwd.h" +#include "core/framework/tensor_external_data_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" @@ -29,6 +30,9 @@ enum class OrtGraphIrApi { kEpApi, }; +// Alias OrtExternalInitializerInfo to the internal type. +struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {}; + /// /// Public type that represents an ONNX value info. /// @@ -94,6 +98,17 @@ struct OrtValueInfo { /// A status indicating success or an error. virtual onnxruntime::Status GetInitializerValue(const OrtValue*& value) const = 0; + /// + /// Get information (file path, file offset, byte size) if this OrtValueInfo represents an initializer with + /// data in an external file. + /// + /// Output parameter set to the external initializer info or NULL if this is not an external + /// initializer. + /// A status indicating an error or success. Calling this function on an OrtValueInfo that does not represent + /// an external initializer is NOT an error. + virtual onnxruntime::Status GetExternalInitializerInfo( + std::unique_ptr& ext_info) const = 0; + /// /// Determine if the value is a required graph input. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index b7e5351556c61..4ceadb6191a9b 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -15,6 +15,7 @@ #include #include "core/framework/allocator.h" +#include "core/framework/tensor_external_data_info.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/graph_viewer.h" @@ -452,11 +453,29 @@ Status EpValueInfo::GetInitializerValue(const OrtValue*& result) const { // This gets an initializer value defined in this graph or in a parent graph (as long as the value // is used in this graph). - result = graph_->GetInitializerValue(name_); + ORT_RETURN_IF_ERROR(graph_->GetInitializerValue(name_, result)); ORT_RETURN_IF(result == nullptr, "Unable to find initializer value named '", name_, "'."); return Status::OK(); } +Status EpValueInfo::GetExternalInitializerInfo(std::unique_ptr& result) const { + if (!IsFlagSet(kIsConstantInitializer) && !IsFlagSet(kIsOptionalGraphInput)) { + result = nullptr; + return Status::OK(); + } + + ORT_RETURN_IF(graph_ == nullptr, "Unable to get external initializer information for value named '", + name_, "': parent graph is NULL"); + + const onnxruntime::Graph& graph = graph_->GetGraphViewer().GetGraph(); + + if (!graph.GetExternalInitializerInfo(name_, result, /*check_outer_scope*/ true)) { + result = nullptr; + } + + return Status::OK(); +} + Status EpValueInfo::IsRequiredGraphInput(bool& is_required_graph_input) const { is_required_graph_input = IsFlagSet(Flags::kIsRequiredGraphInput); return Status::OK(); @@ -593,15 +612,18 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& initializer_value_infos.push_back(value_info); // Initialize OrtValue for the initializer. + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. auto initializer_value = std::make_unique(); bool graph_has_ortvalue = graph_viewer.GetGraph().GetOrtValueInitializer(initializer_name, *initializer_value, /*check_outer_scope*/ false); if (!graph_has_ortvalue) { - // onnxruntime::Graph does not have an OrtValue for this initializer, so create one from the TensorProto. - // This should only happen for small initializers that are needed for ONNX shape inferencing. - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, - initializer_allocator, *initializer_value)); + // Copy to OrtValue if not external. This should only happen for small initializers. + // Do nothing for external initializers, as we will load/mmap into an OrtValue on demand. + if (!utils::HasExternalDataInFile(*tensor_proto)) { + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, + initializer_allocator, *initializer_value)); + } } initializer_values.emplace(value_info->GetName(), std::move(initializer_value)); @@ -650,8 +672,10 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& } EpValueInfo* outer_value_info = value_info_iter->second.get(); - bool is_constant = false; + + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. auto outer_initializer_value = std::make_unique(); + bool is_constant = false; const ONNX_NAMESPACE::TensorProto* outer_initializer = parent_graph->GetInitializer(implicit_name, *outer_initializer_value, is_constant, @@ -665,11 +689,13 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& // Add the OrtValue if this is an initializer. if (outer_initializer != nullptr) { if (!outer_initializer_value->IsAllocated()) { - // onnxruntime::Graph does not have an OrtValue for this initializer, so create one from the TensorProto. - // This should only happen for small initializers that are needed for ONNX shape inferencing. - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), - *outer_initializer, initializer_allocator, - *outer_initializer_value)); + // Copy to OrtValue if not external. This should only happen for small initializers. + // Do nothing for external initializers. Will load/mmap into an OrtValue on demand. + if (!utils::HasExternalDataInFile(*outer_initializer)) { + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), + *outer_initializer, initializer_allocator, + *outer_initializer_value)); + } } outer_scope_initializer_values.emplace(outer_value_info->GetName(), std::move(outer_initializer_value)); } @@ -808,20 +834,40 @@ const EpNode* EpGraph::GetNode(NodeIndex node_index) const { return index_to_ep_node_.GetEpNode(node_index); } -const OrtValue* EpGraph::GetInitializerValue(std::string_view name) const { +Status EpGraph::GetInitializerValue(std::string_view name, const OrtValue*& result) const { + auto ensure_ort_value_loaded = [&](const std::unique_ptr& ort_value) -> Status { + if (!ort_value->IsAllocated()) { + // Lazy load the OrtValue. This happens for external initializers. + const Graph& graph = graph_viewer_.GetGraph(); + ORT_RETURN_IF_ERROR(graph.LoadExternalInitializerAsOrtValue(std::string(name), + const_cast(*ort_value))); + } + + return Status::OK(); + }; + // Check for initializer value in the graph's scope. if (auto iter = initializer_values_.find(name); iter != initializer_values_.end()) { - return iter->second.get(); + const std::unique_ptr& ort_value = iter->second; + ORT_RETURN_IF_ERROR(ensure_ort_value_loaded(ort_value)); + + result = ort_value.get(); + return Status::OK(); } // Check for the initializer value in an outer scope. // Only finds a value if the outer initializer value is used within this graph. if (auto iter = outer_scope_initializer_values_.find(name); iter != outer_scope_initializer_values_.end()) { - return iter->second.get(); + const std::unique_ptr& ort_value = iter->second; + ORT_RETURN_IF_ERROR(ensure_ort_value_loaded(ort_value)); + + result = ort_value.get(); + return Status::OK(); } - return nullptr; + result = nullptr; + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index b9a494364a12e..243bdc2944ffb 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -65,6 +65,10 @@ struct EpValueInfo : public OrtValueInfo { // represent an initializer (either constant or non-constant). Status GetInitializerValue(const OrtValue*& value) const override; + // Gets external initializer information (file path, file offset, byte size) if this is an external initializer. + // Otherwise, sets the output parameter `ext_info` to nullptr (not an error). + Status GetExternalInitializerInfo(std::unique_ptr& ext_info) const override; + // Check if this value is a required graph input. Status IsRequiredGraphInput(bool& is_required_graph_input) const override; @@ -351,7 +355,7 @@ struct EpGraph : public OrtGraph { // Considers both constant and non-constant initializers. // Supports initializers defined in an outer scope as long as that initializer is used // within this graph. - const OrtValue* GetInitializerValue(std::string_view name) const; + Status GetInitializerValue(std::string_view name, const OrtValue*& value) const; private: /// diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index b929c27b21ec3..de6776b0e0df1 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3815,6 +3815,48 @@ bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value, boo return false; } +Status Graph::LoadExternalInitializerAsOrtValue(const std::string& name, OrtValue& value) const { + auto tensor_proto_iter = name_to_initial_tensor_.find(name); + ORT_RETURN_IF(tensor_proto_iter == name_to_initial_tensor_.end(), "Cannot load unknown initializer named '", + name, "'."); + const ONNX_NAMESPACE::TensorProto& tensor_proto = *tensor_proto_iter->second; + + // This only supports TensorProtos that currently have their external data in an actual file. + // This doesn't cache the new OrtValue because it would overwrite TensorProto.external_data and plugin EPs require + // every call to Graph::GetExternalInitializerInfo to return the actual external file info (path, offset, length). + ORT_RETURN_IF(!utils::HasExternalDataInFile(tensor_proto), "Initializer '", name, + "' does not have external data in a file."); + + // Create OrtValue that memory maps external initializer from file. + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), ModelPath(), tensor_proto, value)); + assert(value.IsAllocated()); + + return Status::OK(); +} + +bool Graph::GetExternalInitializerInfo(const std::string& name, std::unique_ptr& ext_info, + bool check_outer_scope) const { + // We want to make sure that the external data info is found on the same level as its tensor_proto + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (GetInitializedTensor(name, initializer)) { + if (utils::HasExternalDataInFile(*initializer)) { + std::unique_ptr external_data_info; + ORT_THROW_IF_ERROR(ExternalDataInfo::Create(initializer->external_data(), external_data_info)); + + ext_info = std::move(external_data_info); + return true; + } + } + + if (check_outer_scope && IsSubgraph()) { + if (IsOuterScopeValue(name)) { + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. + return parent_graph_->GetExternalInitializerInfo(name, ext_info, check_outer_scope); + } + } + return false; +} + void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); #if !defined(DISABLE_SPARSE_TENSORS) @@ -5202,7 +5244,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod // In the constant node, we won't have symbolic dims. const auto tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - const size_t size_in_bytes = SafeInt(ml_data->Size()) * tensor_shape.Size(); + const size_t size_in_bytes = Tensor::CalculateTensorStorageSize(ml_data, tensor_shape); if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { OrtValue ort_value; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 07c7080d74c7c..5d84e48182bfe 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -49,6 +49,12 @@ struct ModelEditorValueInfo : public OrtValueInfo { "OrtModelEditorApi does not support getting the initializer value for a OrtValueInfo"); } + Status GetExternalInitializerInfo(std::unique_ptr& /*ext_info*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the external initializer information ", + "for a OrtValueInfo"); + } + Status IsRequiredGraphInput(bool& /*is_required_graph_input*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support querying if a graph input is required for OrtValueInfo"); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 6ada5df5976df..f6b6335dd29c0 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2514,6 +2514,35 @@ ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetInitializerValue, _In_ const OrtValueI API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetExternalInitializerInfo, _In_ const OrtValueInfo* value_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** info) { + API_IMPL_BEGIN + std::unique_ptr ext_data_info = nullptr; + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetExternalInitializerInfo(ext_data_info)); + + // Note: ext_data_info can be nullptr if this OrtValueInfo does not represent an external initializer. + // std::unique_ptr::release() handles both cases. + *info = static_cast(ext_data_info.release()); + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info) { + delete static_cast(info); +} + +ORT_API(const ORTCHAR_T*, OrtApis::ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info) { + return info->GetRelPath().c_str(); +} + +ORT_API(int64_t, OrtApis::ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info) { + return static_cast(info->GetOffset()); +} + +ORT_API(size_t, OrtApis::ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info) { + return info->GetLength(); +} + ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, _Out_ bool* is_required_graph_input) { API_IMPL_BEGIN @@ -3966,6 +3995,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_GetValueNumConsumers, &OrtApis::ValueInfo_GetValueConsumers, &OrtApis::ValueInfo_GetInitializerValue, + &OrtApis::ValueInfo_GetExternalInitializerInfo, &OrtApis::ValueInfo_IsRequiredGraphInput, &OrtApis::ValueInfo_IsOptionalGraphInput, &OrtApis::ValueInfo_IsGraphOutput, @@ -4006,6 +4036,10 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, &OrtApis::Node_GetEpName, + &OrtApis::ReleaseExternalInitializerInfo, + &OrtApis::ExternalInitializerInfo_GetFilePath, + &OrtApis::ExternalInitializerInfo_GetFileOffset, + &OrtApis::ExternalInitializerInfo_GetByteSize, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 772de5e312ffb..d2f22397bf82c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -621,6 +621,8 @@ ORT_API_STATUS_IMPL(ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_ _In_ size_t num_consumers); ORT_API_STATUS_IMPL(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtValue** initializer_value); +ORT_API_STATUS_IMPL(ValueInfo_GetExternalInitializerInfo, _In_ const OrtValueInfo* value_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** info); ORT_API_STATUS_IMPL(ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, _Out_ bool* is_required_graph_input); ORT_API_STATUS_IMPL(ValueInfo_IsOptionalGraphInput, _In_ const OrtValueInfo* value_info, @@ -686,6 +688,12 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); +// OrtExternalInitializerInfo +ORT_API(void, ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info); +ORT_API(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); +ORT_API(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); +ORT_API(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); + ORT_API(const char*, GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key); ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 2a2d69a1c2e47..d0f682491e4f9 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,105 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +// Check correctness of an OrtGraph that has external initializers. +TEST(EpGraphTest, CheckModelExternalInitializers) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +static void RunConvQDQExtIni(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1, 3, 24, 24}; + std::vector input_data(3 * 24 * 24, 0.5f); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'input' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("input"); + + // Run session and get outputs + std::array output_names{"output"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 32 * 26 * 26); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph with external initializers to GraphProto. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_InputModelHasExternalIni) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/conv_qdq_external_ini.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("conv_qdq_ext_ini_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "conv_qdq_ext_ini_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunConvQDQExtIni(original_model_path, output_original); + RunConvQDQExtIni(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -442,17 +541,40 @@ static void CheckValueInfoConsumers(const GraphViewer& graph_viewer, const OrtVa } static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, - const ONNX_NAMESPACE::TensorProto* tensor_proto) { + const ONNX_NAMESPACE::TensorProto* tensor_proto, + const GraphViewer& graph_viewer) { const OrtApi& ort_api = Ort::GetApi(); - const OrtValue* api_initializer_value = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); - ASSERT_NE(api_initializer_value, nullptr); - const char* api_initializer_name = nullptr; ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); ASSERT_NE(api_initializer_name, nullptr); + // Check external initializer info (if any). + OrtExternalInitializerInfo* api_ext_info = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetExternalInitializerInfo(api_value_info, &api_ext_info)); + DeferOrtRelease defer_release_info(&api_ext_info, ort_api.ReleaseExternalInitializerInfo); + + std::unique_ptr ext_info = nullptr; + bool has_ext_info = graph_viewer.GetGraph().GetExternalInitializerInfo(api_initializer_name, ext_info, true); + + if (has_ext_info) { + ASSERT_NE(api_ext_info, nullptr); + const ORTCHAR_T* api_ext_file_path = ort_api.ExternalInitializerInfo_GetFilePath(api_ext_info); + int64_t api_ext_file_offset = ort_api.ExternalInitializerInfo_GetFileOffset(api_ext_info); + size_t api_ext_byte_size = ort_api.ExternalInitializerInfo_GetByteSize(api_ext_info); + + ASSERT_EQ(PathString(api_ext_file_path), ext_info->GetRelPath()); + ASSERT_EQ(api_ext_file_offset, static_cast(ext_info->GetOffset())); + ASSERT_EQ(api_ext_byte_size, ext_info->GetLength()); + } else { + ASSERT_EQ(api_ext_info, nullptr); + ASSERT_FALSE(utils::HasExternalDataInFile(*tensor_proto)); + } + + const OrtValue* api_initializer_value = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); + ASSERT_NE(api_initializer_value, nullptr); + // Check initializer type. const ONNX_NAMESPACE::TypeProto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); auto type_info = OrtTypeInfo::FromTypeProto(type_proto); @@ -463,7 +585,8 @@ static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, } static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, - const InitializedTensorSet& initializer_tensor_protos) { + const InitializedTensorSet& initializer_tensor_protos, + const GraphViewer& graph_viewer) { const OrtApi& ort_api = Ort::GetApi(); for (size_t i = 0; i < initializer_value_infos.size(); i++) { @@ -479,7 +602,7 @@ static void CheckInitializerValueInfosCApi(gsl::span const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_proto_iter->second; ASSERT_NE(tensor_proto, nullptr); - CheckInitializerValueInfo(api_value_info, tensor_proto); + CheckInitializerValueInfo(api_value_info, tensor_proto, graph_viewer); } } @@ -543,7 +666,7 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::spanTypeAsProto()); const OrtTypeInfo* api_type_info = nullptr; @@ -643,7 +766,7 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ std::vector api_initializers(api_num_initializers); ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, api_initializers.data(), api_initializers.size())); - CheckInitializerValueInfosCApi(api_initializers, graph_initializers); + CheckInitializerValueInfosCApi(api_initializers, graph_initializers, graph_viewer); // Check if it has a parent node. const Node* parent_node = graph_viewer.ParentNode(); diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index 2ce107cf734c6..2aebd75e0aaac 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -42,6 +42,30 @@ struct NodeArgConsumer { int64_t input_index = -1; }; +// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. +template +struct DeferOrtRelease { + DeferOrtRelease(T** object_ptr, std::function release_func) + : objects_(object_ptr), count_(1), release_func_(release_func) {} + + DeferOrtRelease(T** objects, size_t count, std::function release_func) + : objects_(objects), count_(count), release_func_(release_func) {} + + ~DeferOrtRelease() { + if (objects_ != nullptr && count_ > 0) { + for (size_t i = 0; i < count_; ++i) { + if (objects_[i] != nullptr) { + release_func_(objects_[i]); + objects_[i] = nullptr; + } + } + } + } + T** objects_ = nullptr; + size_t count_ = 0; + std::function release_func_ = nullptr; +}; + // Returns consumers (i.e., consumer node + input index) of a NodeArg from the original graph. Status GetNodeArgConsumers(const GraphViewer& graph_viewer, const NodeArg& node_arg, /*out*/ std::vector& consumers); From e0ad8050d04c55318d1336874cc6d640509ddbc0 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 24 Jul 2025 16:52:46 -0700 Subject: [PATCH 06/33] Qnn license file update (#25158) ### Description Use the license file from QNN SDK to make sure it's up to date. --------- Co-authored-by: adrianlizarraga --- cmake/onnxruntime_providers_qnn.cmake | 8 ++++---- cmake/onnxruntime_python.cmake | 6 ++---- setup.py | 2 +- .../templates/jobs/download_linux_qnn_sdk.yml | 4 ---- .../templates/jobs/download_win_qnn_sdk.yml | 4 ---- tools/nuget/generate_nuspec_for_native_nuget.py | 4 ++-- 6 files changed, 9 insertions(+), 19 deletions(-) diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index 748e3de843bab..f499c83d5f6c0 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -66,10 +66,10 @@ COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ ) endif() - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + if (EXISTS "${onnxruntime_QNN_HOME}/LICENSE.pdf") add_custom_command( TARGET ${onnxruntime_providers_qnn_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/LICENSE.pdf" $/Qualcomm_LICENSE.pdf ) endif() else() @@ -154,10 +154,10 @@ COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ ) endif() - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + if (EXISTS "${onnxruntime_QNN_HOME}/LICENSE.pdf") add_custom_command( TARGET ${onnxruntime_providers_qnn_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/LICENSE.pdf" $/Qualcomm_LICENSE.pdf ) endif() endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index b177074a1bc02..c5c85dff96411 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1068,12 +1068,10 @@ if (onnxruntime_USE_QNN) ${QNN_LIB_FILES} $/onnxruntime/capi/ ) - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + if (EXISTS "${onnxruntime_QNN_HOME}/LICENSE.pdf") add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy - "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" - $/onnxruntime/ + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/LICENSE.pdf" $/onnxruntime/Qualcomm_LICENSE.pdf ) endif() endif() diff --git a/setup.py b/setup.py index 1893e18b8aab6..5ab1ac5b840d4 100644 --- a/setup.py +++ b/setup.py @@ -478,7 +478,7 @@ def finalize_options(self): examples = [path.join("datasets", x) for x in examples_names] # Extra files such as EULA and ThirdPartyNotices (and Qualcomm License, only for QNN release packages) -extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md", "Qualcomm AI Hub Proprietary License.pdf"] +extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md", "Qualcomm_LICENSE.pdf"] # Description readme_file = "docs/python/ReadMeOV.rst" if is_openvino else "docs/python/README.rst" diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 930dc83b73460..57703239fc594 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -39,10 +39,6 @@ steps: fi displayName: "Sanity Check: QnnSDKVersion vs sdk.yaml version" - - script: | - azcopy cp --recursive 'https://lotusscus.blob.core.windows.net/models/qnnsdk/Qualcomm AI Hub Proprietary License.pdf' $(QnnSDKRootDir) - displayName: 'Download Qualcomm AI Hub license' - - script: | ls -al $(QnnSDKRootDir) displayName: 'Print contents of QNN SDK' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 96eea6cd6d2fb..d2e401f3f6ab4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -18,10 +18,6 @@ steps: echo $(QnnSDKRootDir) displayName: 'Print QnnSDKRootDir after downloading QNN SDK' - - powershell: | - azcopy.exe cp --recursive 'https://lotusscus.blob.core.windows.net/models/qnnsdk/Qualcomm AI Hub Proprietary License.pdf' $(QnnSDKRootDir) - displayName: 'Download Qualcomm AI Hub license' - - task: CmdLine@2 displayName: 'Print contents of QNN SDK' inputs: diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index c5a204b6cb958..211cb7a2a8a75 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -1081,8 +1081,8 @@ def generate_files(line_list, args): files_list.append( "' + + os.path.join(args.native_build_path, "Qualcomm_LICENSE.pdf") + + '" target="Qualcomm_LICENSE.pdf" />' ) files_list.append("") From 8152168b8aa251c07a25e4ebb0797bed3f9d5864 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 24 Jul 2025 20:45:21 -0700 Subject: [PATCH 07/33] Update .config/1espt/PipelineAutobaseliningConfig.yml (#25450) --- .../1espt/PipelineAutobaseliningConfig.yml | 21 ++++++- .config/guardian/.gdnbaselines | 60 +++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml index 18315e55e854d..68bf98b3a2781 100644 --- a/.config/1espt/PipelineAutobaseliningConfig.yml +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -151,13 +151,16 @@ pipelines: lastModifiedDate: 2025-04-25 armory: lastModifiedDate: 2025-04-25 + policheck: + lastModifiedDate: 2025-07-23 binary: credscan: lastModifiedDate: 2025-04-25 binskim: - lastModifiedDate: 2025-04-25 + lastModifiedDate: 2025-07-24 spotbugs: lastModifiedDate: 2025-04-25 + usedBinskimScanAllExtensions: true 1234: retail: source: @@ -169,10 +172,24 @@ pipelines: lastModifiedDate: 2025-04-25 armory: lastModifiedDate: 2025-04-25 + policheck: + lastModifiedDate: 2025-07-23 binary: credscan: lastModifiedDate: 2025-04-25 binskim: - lastModifiedDate: 2025-04-25 + lastModifiedDate: 2025-07-24 spotbugs: lastModifiedDate: 2025-04-25 + usedBinskimScanAllExtensions: true + 1311: + retail: + source: + credscan: + lastModifiedDate: 2025-07-18 + eslint: + lastModifiedDate: 2025-07-18 + psscriptanalyzer: + lastModifiedDate: 2025-07-18 + armory: + lastModifiedDate: 2025-07-18 diff --git a/.config/guardian/.gdnbaselines b/.config/guardian/.gdnbaselines index 7246ad6ba36df..18a9250059134 100644 --- a/.config/guardian/.gdnbaselines +++ b/.config/guardian/.gdnbaselines @@ -409,6 +409,66 @@ "createdDate": "2025-04-25 22:25:55Z", "expirationDate": "2025-10-12 23:01:19Z", "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 23:01:19Z" + }, + "67acaef0adebeee9ddcb2ff2630fa3051c0c8e7083f36f64ac0040d9a22b73b5": { + "signature": "67acaef0adebeee9ddcb2ff2630fa3051c0c8e7083f36f64ac0040d9a22b73b5", + "alternativeSignatures": [ + "2f5e8344c6d8ffa32a8a54c363d0c480380320a6c0a3fd3e4ca1ff2aafe6dbcf" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/dxcompiler.dll", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" + }, + "7a02b29f8870bfd4cc770281dd860421523a3aac51ea96332e25696ca1f5570e": { + "signature": "7a02b29f8870bfd4cc770281dd860421523a3aac51ea96332e25696ca1f5570e", + "alternativeSignatures": [ + "1229412e0db78558feac3bc51ea9eed6ae2311e60298dc1f2d3366bd12544c88" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_perf_test.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" + }, + "68b657f6d9dd9386bd43f0716cb424c196f7da4559eb1c3f3f26a1297b211239": { + "signature": "68b657f6d9dd9386bd43f0716cb424c196f7da4559eb1c3f3f26a1297b211239", + "alternativeSignatures": [ + "e026012915cda24b9e85a1d1fa38607d09effa532b40a1c0f0740eb3855f9599" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_test_all.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" + }, + "8222fbb019a791fa0541084854a0bf7bf723ce7ffaa4a0e1e5ca5cb76acb48bb": { + "signature": "8222fbb019a791fa0541084854a0bf7bf723ce7ffaa4a0e1e5ca5cb76acb48bb", + "alternativeSignatures": [ + "5e0bdc06af73864bdb480aceaf154a35e0774ab7f8490e7d9f8b5a36b7c19619" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnx_test_runner.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" } } } \ No newline at end of file From 2a0c36abc98b654fb391df1235d9cb615884d6b3 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 25 Jul 2025 17:41:55 +1000 Subject: [PATCH 08/33] Add arena that uses EP API so that an EP library can be self-sufficient. (#25465) ### Description Add arena that uses EP API so that an EP library can be self-sufficient. Remove cross stream sharing from BFCArena. Nothing is using it and it creates a dependency on synchronizing streams inside the arena implementation. Tried to simplify the Stream/Notification usage. Current setup adds an AllocOnStream to OrtAllocator. There's no stream aware Free at this point as ORT does not attach the Stream to the memory usage so can't pass it in to the Free call. ### Motivation and Context If ORT adds BFCArena to an OrtAllocator from the EP we have OrtAllocator -> IAllocator wrapper -> BFCArena IAllocator [-> OrtAllocator wrapper for external usage]. The EP managing its own arena is much simpler. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../onnxruntime/core/framework/allocator.h | 46 +- .../core/framework/stream_handles.h | 134 ++- .../core/session/onnxruntime_c_api.h | 43 +- .../core/session/onnxruntime_ep_c_api.h | 35 + .../bert/cudnn_fmha/cudnn_flash_attention.cc | 3 +- .../transformers/generation_device_helper.cc | 2 +- onnxruntime/core/framework/allocator.cc | 19 +- onnxruntime/core/framework/allocator_utils.cc | 1 - onnxruntime/core/framework/allocator_utils.h | 7 +- onnxruntime/core/framework/bfc_arena.cc | 100 +-- onnxruntime/core/framework/bfc_arena.h | 33 +- onnxruntime/core/framework/execution_frame.cc | 41 +- onnxruntime/core/framework/execution_steps.cc | 16 +- onnxruntime/core/framework/execution_steps.h | 2 +- onnxruntime/core/framework/plugin_ep_stream.h | 4 +- onnxruntime/core/framework/stream_handles.cc | 20 + .../tensor_allocator_with_mem_pattern.h | 10 +- onnxruntime/core/framework/utils.cc | 35 +- .../providers/cann/cann_execution_provider.cc | 3 +- .../core/providers/cann/cann_stream_handle.cc | 2 +- .../core/providers/cann/cann_stream_handle.h | 2 - .../providers/cuda/cuda_execution_provider.cc | 4 +- onnxruntime/core/providers/cuda/cuda_kernel.h | 8 +- .../providers/cuda/cuda_provider_factory.cc | 6 +- .../core/providers/cuda/cuda_stream_handle.cc | 2 +- .../core/providers/cuda/cuda_stream_handle.h | 2 - .../providers/cuda/reduction/reduction_ops.cc | 43 +- .../cuda/tunable/cuda_tuning_context.cc | 2 +- .../migraphx/migraphx_execution_provider.cc | 4 +- .../migraphx/migraphx_stream_handle.cc | 5 +- .../migraphx/migraphx_stream_handle.h | 2 - .../core/session/allocator_adapters.cc | 52 +- onnxruntime/core/session/allocator_adapters.h | 4 + onnxruntime/core/session/custom_ops.cc | 2 +- onnxruntime/core/session/environment.cc | 52 +- onnxruntime/core/session/ep_api.cc | 29 + onnxruntime/core/session/ep_api.h | 5 + .../session/ep_plugin_provider_interfaces.cc | 27 +- onnxruntime/core/session/inference_session.cc | 2 +- .../test/autoep/library/ep_allocator.h | 40 +- onnxruntime/test/autoep/library/ep_arena.cc | 778 ++++++++++++++++++ onnxruntime/test/autoep/library/ep_arena.h | 629 ++++++++++++++ onnxruntime/test/autoep/library/ep_factory.cc | 58 +- onnxruntime/test/autoep/library/ep_factory.h | 13 + .../test/autoep/library/ep_stream_support.cc | 10 +- .../test/autoep/library/ep_stream_support.h | 8 +- .../autoep/library/example_plugin_ep_utils.h | 41 + onnxruntime/test/autoep/test_allocators.cc | 13 +- onnxruntime/test/framework/bfc_arena_test.cc | 101 +-- .../test/shared_lib/test_model_builder_api.cc | 3 + 50 files changed, 2092 insertions(+), 411 deletions(-) create mode 100644 onnxruntime/core/framework/stream_handles.cc create mode 100644 onnxruntime/test/autoep/library/ep_arena.cc create mode 100644 onnxruntime/test/autoep/library/ep_arena.h diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 609386fd1f081..24cc460a17fa9 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -86,8 +86,11 @@ class Stream; namespace synchronize { class Notification; } + using WaitNotificationFn = std::function; -void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn); +void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, + // wait fn is for backwards compat with provider bridge + WaitNotificationFn ignored = nullptr); template using IAllocatorUniquePtr = std::unique_ptr>; @@ -105,6 +108,43 @@ class IAllocator { */ virtual void* Alloc(size_t size) = 0; + /** Return true if the allocator implements Stream handling in AllocOnStream. + */ + virtual bool IsStreamAware() const { return false; } + + /** Allocate memory, handling usage across different Streams + * + * A device Stream may be available when executing a model on non-CPU devices. In this case operations are queued + * asynchronously and the allocation/free call is made when the operation is queued rather than executed. + * Due to this it is not safe to use the memory on another stream or with no stream unless synchronization has + * occurred. + * + * ORT currently handles the synchronization when executing the model using streams. + * + * When two streams are synchronized the event used is identified by the producer stream's latest sync id. + * This pair is copied into the sync information of the consumer stream. + * Each new event creates a new sync id. + * + * It is safe to re-use an allocated piece of memory if: + * - the stream that currently owns the memory and the stream that wants to use the memory have been synchronized, + * - and the sync id from when the memory was assigned to the stream that currently owns it is less than the + * sync id from the last synchronization between the two streams. + * - e.g. stream0 is assigned the memory when its sync id is 1. + * stream0 (producer) and stream1 (consumer) are synchronized. + * stream0 sync id will be incremented to 2 when creating the event used in the synchronization. + * stream1 will copy this information into its sync info and now contains an entry for stream0 + * with a sync id of 2. + * stream0 frees the memory + * the memory is marked as not in use, but still assigned to stream0 + * stream1 is now able to use the memory as it is not in use, and the sync id from the allocation (1) + * is less than the sync id (2) that is has for stream0. + * or + * - the inference session that owned the Stream has completed inferencing + * - Stream::CleanUpOnRunEnd is called when this occurs + * - any memory assigned to the Stream is now able to be used by other streams when it is not longer in use. + */ + virtual void* AllocOnStream(size_t size, Stream* /*stream*/) { return Alloc(size); } + /** * Free memory at p. * If p is nullptr, do nothing. @@ -192,7 +232,7 @@ class IAllocator { template static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes, bool use_reserve = false, - Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) { + Stream* stream = nullptr) { ValidateAllocator(allocator); // for now limit to fundamental types. we could support others, but to do so either we or the caller @@ -210,7 +250,7 @@ class IAllocator { } // allocate - T* p = static_cast(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn))); + T* p = static_cast(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, nullptr)); ValidateAllocation(p, alloc_size); return IAllocatorUniquePtr{p, diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index 441e3ebda1502..7d27c7471d71f 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -2,9 +2,11 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include + #include "core/framework/allocator.h" #include "core/framework/ortdevice.h" #include "core/common/status.h" @@ -21,9 +23,9 @@ namespace synchronize { class Notification; } -// a stream abstraction which hold an opaque handle, and a reference to which OrtDevice instance this stream belong to. -// it need to be OrtDevice instance as we might have different stream on different OrtDevice with same type. -// i.e. different cuda stream on different GPU. +/// +/// Class to represent a stream on the OrtDevice. +/// class Stream { public: Stream(StreamHandle h, const OrtDevice& d) @@ -31,123 +33,113 @@ class Stream { } virtual ~Stream() = default; + virtual std::unique_ptr CreateNotification(size_t /*num_consumers*/) { return {}; }; + // block the host thread until all the tasks in the stream finished. virtual void Flush() {}; + // The framework may reuse the stream instance for multiple iterations. // This is the API that provide a chance to let the device stream cleanup // resource at the end of a iteration. virtual Status CleanUpOnRunEnd() { return Status::OK(); }; + // Get the native stream handle. nullptr if the OrtDevice doesn't support streams. StreamHandle GetHandle() const { return handle_; } const OrtDevice& GetDevice() const { return device_; } - // We use the timestamp based vector clocks to optimize the resource sharing - // between different streams. - // Each stream maintain following data structure: - // 1. Current timestamp - // 2. A lookup table that for a given stream, what is its timestamp when the - // last synchronization happened with current stream. - // 3. When a notification is activated, it take a snapshot of current stream's - // lookup table. - // 4. When synchronization happened (current stream wait on a notification), - // update its lookup table with the table snapshot in notification. - // The memory reusing strategy is: - // A kernel in current stream is safe to reuse another stream's memory chunk - // as long as the reused chunk's timestamp is less than the last synchronized - // timestamp recorded in the lookup table. - - // Get the current timestamp - uint64_t GetCurrentTimestamp() const { return timestamp_; } - - // return the timestamp when the last synchronization happened between target stream and current stream. - // return 0 if no synchronization happened. - // if target_stream is nullptr, it means it is a sequence running on device doesn't support Stream (i.e. CPU) - // we can safely return 0 in that case to save a lookup. - uint64_t GetLastSyncTimestampWithTargetStream(Stream* target_stream) const { - if (!target_stream) - return 0; - auto it = other_stream_clock_.find(target_stream); - return it == other_stream_clock_.end() ? 0 : it->second; + // Get the current synchronization ID. + // The value is 0 until this stream records an event. + // The sync id is incremented before each new event that is recorded in our stream via Notification::Activate. + uint64_t GetSyncId() const { return sync_id_; } + + // Return the sync id from when the last synchronization happened between producer_stream and this stream. + // i.e. the id for the event that the producer stream recorded and we waited on + // + // Returns 0 if the streams have not previously been synchronized. + uint64_t GetSyncIdForLastWaitOnStream(const Stream& producer_stream) const { + auto it = producer_stream_sync_info_.find(&producer_stream); + return it == producer_stream_sync_info_.end() ? 0 : it->second; } - // make a copy of the current stream lookup table. - // this is used to create a snapshot of the stream lookup table in notification. - void CloneCurrentStreamSyncTable(std::unordered_map& output) const { - output.reserve(other_stream_clock_.size()); - output.insert(other_stream_clock_.begin(), other_stream_clock_.end()); - } + // Get the sync information that is added to a notification when it is activated. + // This contains sync ids for all streams we have waited on, and the new sync id for our stream. + std::unordered_map OnNotificationActivation() { + // copy our sync info so the notification can pass it on to any waiting streams + auto sync_info = producer_stream_sync_info_; + // and add our info to the copy, incrementing the sync_id + sync_info[this] = ++sync_id_; - // bump the current timestamp - // When a notification get activated, bump the snapshot in its owner. - // Stream is not shared across threads, BumpTimeStampAndReturn will only be invoked on the current thread - // where the stream is executed on, so there is no race condition. - uint64_t BumpTimeStampAndReturn() { - return ++timestamp_; + return sync_info; } - // update the stream lookup table with the snapshot saved in notification. - void UpdateStreamClock(const std::unordered_map& clock) { - for (const auto& kv : clock) { - auto ret = other_stream_clock_.insert(kv); - if (!ret.second) { - ret.first->second = std::max(ret.first->second, kv.second); - } - } - } + // Record information from a Notification we waited on. + // - copies the producer stream info from the notification. + void UpdateWithAwaitedNotification(const synchronize::Notification& notification); + // used in custom ops. doesn't really belong here. virtual void* GetResource(int /*version*/, int /*id*/) const { return nullptr; } - virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; } - private: StreamHandle handle_; const OrtDevice& device_; - uint64_t timestamp_{0}; + + // current sync id. equivalent to a counter for the number of events we have recorded in the underlying stream. + // 0 == no events recorded. sync_id_ is updated prior to recording a new event. + std::atomic sync_id_{0}; + + // This is a map to track synchronization points between streams. When we wait on another stream (the producer) + // we add an entry to the map for that stream. + // The entry has the sync id from the producer stream for the event we waited on. + // // TODO: use inline container. // currently this class is header only, but abseil doesn't compile with nvcc // we need to add new symbol to provider_bridge and hide abseil from the header. - std::unordered_map other_stream_clock_{}; + std::unordered_map producer_stream_sync_info_{}; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Stream); }; namespace synchronize { -// An abstraction used for synchronization between streams. See its concrete subclass (CudaNotification, etc.) how the activate -// and wait works for a specific stream +// An abstraction used for synchronization between streams. +// See derived classes (CudaNotification, etc.) for implementation examples. class Notification { public: explicit Notification(Stream& s) : stream_(s) {} virtual ~Notification() = default; - // this api will perform three operations: - // 1. activate the notification on device, for example, record an event on GPU. - // 2. take a snapshot of the timestamp lookup table in current stream. - // 3. bump the timestamp for current stream. + // Activate the notification. This records an event in the Stream that created the Notification that other streams can wait on. void ActivateAndUpdate() { Activate(); - stream_.CloneCurrentStreamSyncTable(stream_clock_); - stream_clock_[&stream_] = stream_.BumpTimeStampAndReturn(); + + // copy the sync info. this includes an entry for stream_ with the next sync id. + stream_sync_info_ = stream_.OnNotificationActivation(); } - // return the timestamp lookup table saved in the notification. - const std::unordered_map& GetStreamSyncTable() { - return stream_clock_; + // Get the sync history for the producer stream that created this Notification. + // The notification must have be activated previously. + const std::unordered_map& GetStreamSyncInfo() const { + return stream_sync_info_; } protected: virtual void Activate() = 0; - // which stream create this notification. + + Stream& GetStream() { + return stream_; + } + + private: + // Stream that created the notification (producer stream). Stream& stream_; - // TODO: use inline container. - // currently this class is header only, but abseil doesn't compile with nvcc - // we need to add new symbol to provider_bridge and hide abseil from the header. - std::unordered_map stream_clock_{}; + + // This is a snapshot of the sync history for the stream that created the Notification. + std::unordered_map stream_sync_info_{}; }; } // namespace synchronize diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d70806e1a5a87..a4cf17845a494 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -352,16 +352,22 @@ typedef struct OrtAllocator { /** * @brief Optional allocation function to use for memory allocations made during session initialization. * Use this function if you want to separate allocations made by ORT during Run() calls from - * those made during session initialization. This allows for separate memory management strategies for these allocations. + * those made during session initialization. This allows for separate memory management strategies for these + * allocations. + * + * \return pointer to an allocated block of `size` bytes. nullptr if size was 0 or allocation failed. + * + * \since 1.18 */ - void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes + void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); /** * @brief Function used to get the statistics of the allocator. * - * Return a pointer to the OrtKeyValuePairs structure that contains the statistics of the allocator - * and the user should call OrtApi::ReleaseKeyValuePairs. - * Supported keys are: + * Return a pointer to the OrtKeyValuePairs structure that contains the statistics of the allocator. + * The user should call OrtApi::ReleaseKeyValuePairs when done. + * + * Current known keys are: * - Limit: Bytes limit of the allocator. -1 if no limit is set. * - InUse: Number of bytes in use. * - TotalAllocated: The total number of allocated bytes by the allocator. @@ -372,9 +378,32 @@ typedef struct OrtAllocator { * - NumArenaShrinkages: Number of arena shrinkages (Relevant only for arena based allocators) * - MaxAllocSize: The max single allocation seen. * - * NOTE: If the allocator does not implement this function, the OrtKeyValuePairs instance will be empty. + * The allocator is free to add other entries as appropriate. + * + * \note Implementation of this function is optional and GetStats may be set to a nullptr. + * If the OrtAllocator is wrapping an internal ORT allocator that does not implement GetStats + * the returned OrtKeyValuePairs instance will be empty. + * + * \since 1.23 */ ORT_API2_STATUS(GetStats, _In_ const struct OrtAllocator* this_, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Allocate using a stream. + * + * If the allocator is stream aware this performs allocation using a stream. + * + * Alloc will be used if this is nullptr. + * + * \param[in] this_ OrtAllocator instance + * \param[in] size Size of the allocation in bytes. nullptr if size was 0 or allocation failed. + * \param[in] stream The stream to allocate on. + * + * \return pointer to an allocated block of `size` bytes + * + * \note Implementation of this function is optional and AllocOnStream may be set to a nullptr. + * \since 1.23 + */ + void*(ORT_API_CALL* AllocOnStream)(struct OrtAllocator* this_, size_t size, OrtSyncStream* stream); } OrtAllocator; typedef void(ORT_API_CALL* OrtLoggingFunction)( @@ -6198,7 +6227,7 @@ struct OrtApi { * \param[in] env The OrtEnv instance to create the shared allocator in. * \param[in] ep_device The OrtEpDevice instance to create the shared allocator for. * \param[in] mem_type The memory type to use for the shared allocator. - * \param[in] allocator_type The type of allocator to create (e.g. OrtAllocatorType::OrtArenaAllocator). + * \param[in] allocator_type The type of allocator to create. Only OrtDeviceAllocator is valid currently. * \param[in] allocator_options Optional key-value pairs to configure the allocator. If arena based, see * include/onnxruntime/core/framework/allocator.h for the keys and values that can be * used. diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 1d9f9d00387ba..620cb5fcf13cc 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -430,6 +430,41 @@ struct OrtEpApi { * \since Version 1.23. */ ORT_API_T(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device); + + /** \brief Get the OrtSyncStreamImpl associated with an OrtSyncStream instance. + * + * This allows an the plugin library to connect its OrtSyncStreamImpl instance with an OrtSyncStream if needed. + * + * \param[in] stream The OrtSyncStream instance to find an OrtSyncStreamImpl for. + * \return The associated OrtSyncStreamImpl if found. nullptr otherwise. + * + * \since Version 1.23. + * + * \remarks There should always be an OrtSyncStreamImpl associated with an OrtSyncStream instance that the EP gets. + */ + ORT_API_T(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* stream); + + /** \brief Get the current sync ID for a stream. + * + * \param[in] stream The OrtSyncStream to get the sync ID for. + * \return Current sync ID. + * + * \since Version 1.23. + */ + ORT_API_T(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); + + /** \brief Get the sync ID for the last time the consumer_stream waited on the producer_stream. + * + * When two streams are synchronized, the sync id represents the event used in that synchronization. + * + * \param[in] producer_stream The OrtSyncStream that produced the data. + * \param[in] consumer_stream The OrtSyncStream that waited on the producer_stream. + * \return ID for last sync. 0 if no sync has occurred between the two streams. + * + * \since Version 1.23. + */ + ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, + _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); }; /** diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc index ec5deccf655ff..ba786931bb39a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc @@ -391,8 +391,7 @@ void run( // Allocate workspace. auto bytes = mha_graph->get_workspace_size(); - IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr( - allocator, bytes, false, stream, WaitCudaNotificationOnDevice); + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); CUDNN_FE_CALL_THROW(mha_graph->execute(handle, variant_pack, buffer.get())); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index f7ed758aedbb2..d20d0b4218bd3 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -232,7 +232,7 @@ Status AddToFeeds(Stream* ort_stream, } } if (!buffer) { - buffer = IAllocator::MakeUniquePtr(device_allocator, total_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + buffer = IAllocator::MakeUniquePtr(device_allocator, total_bytes, false, ort_stream); } char* gpu_data = buffer.get(); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream)); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index f089761e0643b..e1b9d1294fb9e 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -164,19 +164,16 @@ void CPUAllocator::Free(void* p) { AllocatorDefaultFreeAligned(p, alignment); } -void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn) { - if (use_reserve) +void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, + WaitNotificationFn /*ignored*/) { + if (use_reserve) { return alloc.Reserve(size); - if (stream && alloc.Info().alloc_type == OrtArenaAllocator) { -#ifdef ORT_ENABLE_STREAM - auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(static_cast(alloc)); - if (stream_aware_alloc) { - return stream_aware_alloc->AllocOnStream(size, stream, wait_fn); - } -#else - ORT_UNUSED_PARAMETER(wait_fn); -#endif // ORT_ENABLE_STREAM } + + if (stream && alloc.IsStreamAware()) { + return alloc.AllocOnStream(size, stream); + } + return alloc.Alloc(size); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc index edf965d3835b5..8c4e74c4b1cc7 100644 --- a/onnxruntime/core/framework/allocator_utils.cc +++ b/onnxruntime/core/framework/allocator_utils.cc @@ -54,7 +54,6 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { return AllocatorPtr( std::make_unique(std::move(device_allocator), max_mem, - info.enable_cross_stream_reusing, arena_extend_str, initial_chunk_size_bytes, max_dead_bytes_per_chunk, diff --git a/onnxruntime/core/framework/allocator_utils.h b/onnxruntime/core/framework/allocator_utils.h index bef0b7057a7f8..076d4dbcc16c5 100644 --- a/onnxruntime/core/framework/allocator_utils.h +++ b/onnxruntime/core/framework/allocator_utils.h @@ -19,14 +19,12 @@ struct AllocatorCreationInfo { OrtDevice::DeviceId device_id = 0, bool use_arena = true, OrtArenaCfg arena_cfg = {0, -1, -1, -1, -1, -1L}, - bool stream_aware_arena = false, - bool cross_stream_reusing = false) + bool stream_aware_arena = false) : device_alloc_factory(device_alloc_factory), device_id(device_id), use_arena(use_arena), arena_cfg(arena_cfg), - use_stream_aware_arena(stream_aware_arena), - enable_cross_stream_reusing(cross_stream_reusing) { + use_stream_aware_arena(stream_aware_arena) { } AllocatorFactory device_alloc_factory; @@ -34,7 +32,6 @@ struct AllocatorCreationInfo { bool use_arena; OrtArenaCfg arena_cfg; bool use_stream_aware_arena; - bool enable_cross_stream_reusing; }; // Returns an allocator (an instance of IAllocator) based on the creation info provided. diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index ed64769d13fcc..e0b50cd04173e 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -224,6 +224,7 @@ Status BFCArena::Extend(size_t rounded_bytes) { c->next = kInvalidChunkHandle; // assign the new created chunk to default stream, so it can be pick up by any stream c->stream = nullptr; + c->stream_sync_id = 0; region_manager_.set_handle(c->ptr, h); @@ -253,7 +254,7 @@ void BFCArena::DeallocateChunk(ChunkHandle h) { Chunk* c = ChunkFromHandle(h); // clean the stream / timestamp when deallocate chunk c->stream = nullptr; - c->stream_timestamp = 0; + c->stream_sync_id = 0; c->next = free_chunks_list_; free_chunks_list_ = h; } @@ -268,7 +269,7 @@ size_t BFCArena::RoundedBytes(size_t bytes) { } void* BFCArena::Alloc(size_t size) { - return AllocateRawInternal(size, false, nullptr, false, nullptr); + return AllocateRawInternal(size, false, nullptr); } void* BFCArena::Reserve(size_t size) { @@ -309,13 +310,11 @@ size_t BFCArena::AllocatedSize(const void* ptr) { void* BFCArena::AllocateRawInternal(size_t num_bytes, bool dump_log_on_failure, - Stream* stream, - bool enable_cross_stream_reusing, - WaitNotificationFn wait_fn) { + Stream* stream) { if (num_bytes == 0) { - LOGS_DEFAULT(VERBOSE) << "tried to allocate 0 bytes"; return nullptr; } + // First, always allocate memory of at least kMinAllocationSize // bytes, and always allocate multiples of kMinAllocationSize bytes // so all memory addresses are nicely byte aligned. @@ -326,20 +325,9 @@ void* BFCArena::AllocateRawInternal(size_t num_bytes, std::lock_guard lock(lock_); // search for a valid chunk - auto* chunk = FindChunkPtr(bin_num, - rounded_bytes, - num_bytes, - stream, - enable_cross_stream_reusing, - wait_fn); + auto* chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); if (chunk != nullptr) { - // if it is on default stream (the new allocate chunk), assign to current stream - if (chunk->stream == nullptr) { - chunk->stream = stream; - if (stream) - chunk->stream_timestamp = stream->GetCurrentTimestamp(); - } return chunk->ptr; } @@ -349,12 +337,8 @@ void* BFCArena::AllocateRawInternal(size_t num_bytes, // Try to extend auto status = Extend(rounded_bytes); if (status.IsOK()) { - chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream, false); + chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); if (chunk != nullptr) { - // if it is on default stream (the new allocate chunk), assign to current stream - if (chunk->stream == nullptr && stream) { - chunk->stream = stream; - } return chunk->ptr; } else { status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, @@ -413,11 +397,8 @@ BFCArena::Chunk* BFCArena::SplitFreeChunkFromBin(BFCArena::Bin::FreeChunkSet* fr } BFCArena::Chunk* BFCArena::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, - size_t num_bytes, Stream* stream, - bool allow_chunk_from_different_stream, - WaitNotificationFn wait_fn) { - BFCArena::Chunk* other_stream_candidate = nullptr; - // First identify the first bin that could satisfy rounded_bytes. + size_t num_bytes, Stream* stream) { + // First identify the first bin that could satisfy rounded_bytes. for (; bin_num < kNumBins; bin_num++) { // Start searching from the first bin for the smallest chunk that fits // rounded_bytes. @@ -427,29 +408,27 @@ BFCArena::Chunk* BFCArena::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, BFCArena::Chunk* chunk = ChunkFromHandle(h); ORT_ENFORCE(!chunk->in_use()); if (chunk->size >= rounded_bytes) { - // We found an existing chunk that fits us that wasn't in use, now check the stream + // We found an existing chunk that fits us that wasn't in use, now check the stream. bool safe_to_use = chunk->stream == stream || !chunk->stream || (stream && chunk->stream && - chunk->stream_timestamp < stream->GetLastSyncTimestampWithTargetStream(chunk->stream)); + chunk->stream_sync_id < stream->GetSyncIdForLastWaitOnStream(*chunk->stream)); if (safe_to_use) { // the chunk with same stream has higher priority. - return SplitFreeChunkFromBin(&b->free_chunks, citer, rounded_bytes, num_bytes); - } else if (allow_chunk_from_different_stream && !other_stream_candidate) { - other_stream_candidate = chunk; + chunk = SplitFreeChunkFromBin(&b->free_chunks, citer, rounded_bytes, num_bytes); + + if (stream) { + chunk->stream = stream; + chunk->stream_sync_id = stream->GetSyncId(); + } + + return chunk; } } } } - // if trying to use an unsafe chunk from other streams, secure it. - if (other_stream_candidate) { - SecureTheChunk(other_stream_candidate->stream, stream, wait_fn); - // if find some available chunk, make sure mark it as "being used" before return - other_stream_candidate->allocation_id = next_allocation_id_++; - other_stream_candidate->bin_num = kInvalidBinNum; - } - return other_stream_candidate; + return nullptr; } void BFCArena::SplitChunk(BFCArena::ChunkHandle h, size_t num_bytes) { @@ -463,7 +442,7 @@ void BFCArena::SplitChunk(BFCArena::ChunkHandle h, size_t num_bytes) { BFCArena::Chunk* new_chunk = ChunkFromHandle(h_new_chunk); // set the new chunk's stream and timestamp new_chunk->stream = c->stream; - new_chunk->stream_timestamp = c->stream_timestamp; + new_chunk->stream_sync_id = c->stream_sync_id; new_chunk->ptr = static_cast(static_cast(c->ptr) + num_bytes); region_manager_.set_handle(new_chunk->ptr, h_new_chunk); @@ -608,7 +587,7 @@ void BFCArena::Merge(BFCArena::ChunkHandle h1, // Set the new size c1->size += c2->size; - c1->stream_timestamp = std::max(c1->stream_timestamp, c2->stream_timestamp); + c1->stream_sync_id = std::max(c1->stream_sync_id, c2->stream_sync_id); DeleteChunk(h2); } @@ -815,7 +794,7 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla Chunk* c = ChunkFromHandle(h); if (c->stream == target_stream) { c->stream = nullptr; - c->stream_timestamp = 0; + c->stream_sync_id = 0; } h = c->next; } @@ -850,24 +829,23 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocator, size_t total_memory, - bool enable_cross_stream_sharing, ArenaExtendStrategy arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes, - int64_t max_power_of_two_extend_bytes) : BFCArena(std::move(resource_allocator), - total_memory, - arena_extend_strategy, - initial_chunk_size_bytes, - max_dead_bytes_per_chunk, - initial_growth_chunk_size_bytes, - max_power_of_two_extend_bytes), - enable_cross_stream_reusing_(enable_cross_stream_sharing) { + int64_t max_power_of_two_extend_bytes) + : BFCArena(std::move(resource_allocator), + total_memory, + arena_extend_strategy, + initial_chunk_size_bytes, + max_dead_bytes_per_chunk, + initial_growth_chunk_size_bytes, + max_power_of_two_extend_bytes) { arena_type_ = ArenaType::StreamAwareArena; } -void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream, WaitNotificationFn wait_fn) { - return AllocateRawInternal(size, false, current_stream, enable_cross_stream_reusing_, wait_fn); +void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream) { + return AllocateRawInternal(size, false, current_stream); } void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) { @@ -875,17 +853,5 @@ void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) { ResetChunkOnTargetStream(stream, true); } -void StreamAwareArena::SecureTheChunk(Stream* chunk_stream, Stream* target_stream, WaitNotificationFn wait_fn) const { - if (chunk_stream && target_stream && chunk_stream != target_stream) { - auto notification = chunk_stream->CreateNotification(1); - notification->ActivateAndUpdate(); - if (wait_fn) { - wait_fn(target_stream, *notification); - } - - target_stream->UpdateStreamClock(notification->GetStreamSyncTable()); - // it should be ok to release the notification now, as the wait is already launch to stream. - } -} #endif } // namespace onnxruntime diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index 8081738f2a5dc..f3c0544124241 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -27,7 +27,6 @@ limitations under the License. #include "core/common/logging/severity.h" #include "core/common/safeint.h" -#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/allocator.h" @@ -103,18 +102,13 @@ class BFCArena : public IAllocator { ArenaType GetArenaType() const { return arena_type_; } - virtual void SecureTheChunk(Stream* /*chunk_stream*/, - Stream* /*target_stream*/, - WaitNotificationFn /*wait_fn*/) const {} - protected: void* AllocateRawInternal(size_t num_bytes, bool dump_log_on_failure, - Stream* stream, - bool enable_cross_stream_reusing, - WaitNotificationFn wait_fn); + Stream* stream); + #ifdef ORT_ENABLE_STREAM - // for any chunk that associated with target stream, reset it to default (nullptr in stream, timestamp 0) + // for any chunk that associated with target stream, reset it to default (nullptr in stream, sync id 0) // perform coalesce if coalesce_flag is true void ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_flag); #endif @@ -168,7 +162,7 @@ class BFCArena : public IAllocator { Stream* stream = nullptr; - uint64_t stream_timestamp = 0; + uint64_t stream_sync_id = 0; bool in_use() const { return allocation_id != -1; } @@ -374,9 +368,7 @@ class BFCArena : public IAllocator { BFCArena::Chunk* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, - Stream* stream, - bool allow_chunk_from_different_stream, - WaitNotificationFn wait_fn = nullptr); + Stream* stream); // Splits the chunk specified by 'h' into two chunks, one at least // of size 'num_bytes'. @@ -516,33 +508,28 @@ class BFCArena : public IAllocator { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BFCArena); }; + #ifdef ORT_ENABLE_STREAM class StreamAwareArena : public BFCArena { public: StreamAwareArena(std::unique_ptr resource_allocator, size_t total_memory, - bool enable_dynamic_cross_stream_sharing, ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); - // If size is 0, then this function returns either NULL, - // or a unique pointer value that can later be successfully - // passed to free(). Whatever, do not dereference that pointer - void* AllocOnStream(size_t size, Stream* current_stream_id, WaitNotificationFn wait_fn); + bool IsStreamAware() const override { return true; } + + // Standard alloc behavior. Returns valid pointer if size > 0 and memory was available. Otherwise returns nullptr. + void* AllocOnStream(size_t size, Stream* current_stream_id) override; void ReleaseStreamBuffers(Stream* stream); static StreamAwareArena* FromBFCArena(BFCArena& arena) { return arena.GetArenaType() == ArenaType::StreamAwareArena ? reinterpret_cast(&arena) : nullptr; } - - virtual void SecureTheChunk(Stream* chunk_stream, Stream* target_stream, WaitNotificationFn wait_fn) const override; - - private: - bool enable_cross_stream_reusing_; }; #endif #ifdef __GNUC__ diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index cfaee309527a6..8030690e7c92d 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -26,17 +26,6 @@ using namespace onnxruntime::common; namespace onnxruntime { -#ifdef ORT_ENABLE_STREAM -static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) { - ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr"); - if (allocator->Info().alloc_type == OrtArenaAllocator) { - BFCArena* arena_ptr = static_cast(allocator.get()); - return StreamAwareArena::FromBFCArena(*arena_ptr); - } - return nullptr; -} -#endif - IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map, const NodeIndexInfo& node_index_info, gsl::span fetch_mlvalue_idxs) @@ -441,13 +430,23 @@ ExecutionFrame::ExecutionFrame(gsl::span feed_mlvalue_idxs, gsl::span #endif // the memory pattern buffer will leave in the whole execution. #ifdef ORT_ENABLE_STREAM - StreamAwareArena* stream_aware_alloc = AsStreamBasedAllocator(alloc); - if (stream_aware_alloc && device_streams_) { + if (alloc->IsStreamAware() && device_streams_) { Stream* mem_pattern_stream = device_streams_->GetRootStream(); - buffer = stream_aware_alloc->AllocOnStream(peak_size, mem_pattern_stream, nullptr); - for (size_t j = 0; j < device_streams_->NumStreams(); j++) { - stream_aware_alloc->SecureTheChunk(mem_pattern_stream, device_streams_->GetStream(j), nullptr); - } + + buffer = alloc->AllocOnStream(peak_size, mem_pattern_stream); + + // this seems unnecessary. any memory pattern buffer would be in use for the entire inference, so + // there's no point at which another stream (as streams are per-inference) would be able to use it. + // given that, it's unclear why we need to update the sync id in all other streams to allow them + // to take this buffer if it was free. + // + // device_stream_collection calls ReleaseStreamBuffers for all streams including the root stream in + // CleanUp, so the chunk will become available to other streams at that point. + // + // Commenting out to verify. + // for (size_t j = 0; j < device_streams_->NumStreams(); j++) { + // stream_aware_arena->WaitOnChunk(mem_pattern_stream, device_streams_->GetStream(j)); + //} } else { buffer = alloc->Alloc(peak_size); } @@ -581,13 +580,9 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va Stream* current_stream = GetValueStream(ort_value_index); if (current_stream) { #ifdef ORT_ENABLE_STREAM - auto stream_aware_alloc = AsStreamBasedAllocator(alloc); - if (stream_aware_alloc) { + if (alloc->IsStreamAware()) { size_t buffer_size = Tensor::CalculateTensorStorageSize(element_type, shape); - // the reused memory must from same EP - auto wait_handle = this->session_state_.GetStreamHandleRegistryInstance().GetWaitHandle( - current_stream->GetDevice(), current_stream->GetDevice()); - void* p_data = stream_aware_alloc->AllocOnStream(buffer_size, current_stream, wait_handle); + void* p_data = alloc->AllocOnStream(buffer_size, current_stream); Tensor::InitOrtValue(element_type, shape, p_data, std::move(alloc), ort_value); } else { Tensor::InitOrtValue(element_type, shape, std::move(alloc), ort_value); diff --git a/onnxruntime/core/framework/execution_steps.cc b/onnxruntime/core/framework/execution_steps.cc index 36f663699be4f..61e26416f2321 100644 --- a/onnxruntime/core/framework/execution_steps.cc +++ b/onnxruntime/core/framework/execution_steps.cc @@ -23,25 +23,25 @@ std::string BarrierStep::ToString() const { return MakeString("Barrier - BarrierId: ", barrier_id_, ", Count: ", 2); } -WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle, - NotificationIndex idx, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index), - wait_handle_(handle), - notification_idx_(idx) {} +WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle, NotificationIndex idx, NodeIndex node_index) + : SequentialExecutionPlan::ExecutionStep(node_index), + wait_fn_(handle), + notification_idx_(idx) { + ORT_ENFORCE(wait_fn_, "WaitNoficationFn must be provided."); +} Status WaitOnEPStep::Execute(StreamExecutionContext& ctx, size_t stream_idx, SessionScope& /*session_scope*/, const bool& /*terminate_flag*/, bool& continue_flag) { - ORT_ENFORCE(wait_handle_, "WaitOnEPStep.wait_handle is null"); - auto* stream = ctx.GetDeviceStream(stream_idx); auto& notification = *ctx.GetNotification(notification_idx_); - wait_handle_(stream, notification); + wait_fn_(stream, notification); // update the stream's clock status if (stream != nullptr) { - stream->UpdateStreamClock(notification.GetStreamSyncTable()); + stream->UpdateWithAwaitedNotification(notification); } LOGS(ctx.GetLogger(), VERBOSE) << "stream " << stream_idx << " wait on Notification with id: " << notification_idx_; diff --git a/onnxruntime/core/framework/execution_steps.h b/onnxruntime/core/framework/execution_steps.h index 545dabc56b272..b3b3ee6c3ce63 100644 --- a/onnxruntime/core/framework/execution_steps.h +++ b/onnxruntime/core/framework/execution_steps.h @@ -38,7 +38,7 @@ class WaitOnEPStep : public SequentialExecutionPlan::ExecutionStep { std::string ToString() const override; private: - WaitNotificationFn wait_handle_; + WaitNotificationFn wait_fn_; NotificationIndex notification_idx_; }; diff --git a/onnxruntime/core/framework/plugin_ep_stream.h b/onnxruntime/core/framework/plugin_ep_stream.h index 2b89e76e16b76..09938403ad9b5 100644 --- a/onnxruntime/core/framework/plugin_ep_stream.h +++ b/onnxruntime/core/framework/plugin_ep_stream.h @@ -87,8 +87,8 @@ class Stream : public onnxruntime::Stream { return ToStatusAndRelease(ort_status); } - WaitNotificationFn GetWaitNotificationFn() const override { - return Notification::WaitNotificationOnDevice; + const OrtSyncStreamImpl& GetImpl() const { + return impl_; } ~Stream() override { diff --git a/onnxruntime/core/framework/stream_handles.cc b/onnxruntime/core/framework/stream_handles.cc new file mode 100644 index 0000000000000..ab608cdda87c4 --- /dev/null +++ b/onnxruntime/core/framework/stream_handles.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/stream_handles.h" + +#include + +namespace onnxruntime { + +void Stream::UpdateWithAwaitedNotification(const synchronize::Notification& notification) { + const std::unordered_map& stream_sync_info = notification.GetStreamSyncInfo(); + for (const auto& kv : stream_sync_info) { + auto ret = producer_stream_sync_info_.insert(kv); + if (!ret.second) { + // we already have an entry. use the highest value for the producer stream. + ret.first->second = std::max(ret.first->second, kv.second); + } + } +} +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index 98179b96891b3..414bc1c08adf4 100644 --- a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -39,14 +39,8 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { } const auto peak_size = mem_patterns_.patterns[i].PeakSize(); - void* buffer; - if (alloc->Info().alloc_type == OrtArenaAllocator) { - // Arena has a specific way to store static memory. - // Arena does not reuse static memory allocated by Reserve. - buffer = static_cast(alloc.get())->Reserve(peak_size); - } else { - buffer = alloc->Alloc(peak_size); - } + // use Reserve for initializers so they don't affect arena growth patterns if an arena is involved. + void* buffer = alloc->Reserve(peak_size); auto buffer_ptr = BufferUniquePtr(buffer, BufferDeleter(std::move(alloc))); auto kvp = buffers_.insert(std::make_pair(location, buffer)); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index c6bb5d931cbe6..2c0a51f0bfdbc 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -74,28 +74,21 @@ static common::Status AllocateHelper(const AllocatorPtr& allocator, if (source_mlvalue.IsTensor()) { const Tensor& source_tensor = source_mlvalue.Get(); - if (allocator->Info().alloc_type == OrtArenaAllocator) { - void* p_data = nullptr; -#ifdef ORT_ENABLE_STREAM - BFCArena* arena_ptr = static_cast(allocator.get()); - auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(*arena_ptr); - if (stream_aware_alloc && target_stream) { - size_t len = Tensor::CalculateTensorStorageSize(source_tensor.DataType(), source_tensor.Shape()); - p_data = stream_aware_alloc->AllocOnStream(len, target_stream, nullptr); - } -#else - ORT_UNUSED_PARAMETER(target_stream); -#endif // ORT_ENABLE_STREAM - if (p_data == nullptr) { - Tensor::InitOrtValue(source_tensor.DataType(), - source_tensor.Shape(), - allocator, target_mlvalue); - } else { - Tensor::InitOrtValue(source_tensor.DataType(), - source_tensor.Shape(), - p_data, - allocator, target_mlvalue); + void* p_data = nullptr; + if (target_stream && allocator->IsStreamAware()) { + size_t len = Tensor::CalculateTensorStorageSize(source_tensor.DataType(), source_tensor.Shape()); + p_data = allocator->AllocOnStream(len, target_stream); + if (p_data == nullptr && len > 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Allocation failed."); } + } + + if (p_data) { + Tensor::InitOrtValue(source_tensor.DataType(), + source_tensor.Shape(), + p_data, + allocator, target_mlvalue); + } else { Tensor::InitOrtValue(source_tensor.DataType(), source_tensor.Shape(), diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index a691faaffd2a0..4bcf71335d15e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1488,8 +1488,7 @@ AllocatorPtr CANNExecutionProvider::CreateCannAllocator(OrtDevice::DeviceId devi -1, -1, -1L)}, - true, - false); + true); return CreateAllocator(default_memory_info); } diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.cc b/onnxruntime/core/providers/cann/cann_stream_handle.cc index 041fc54a725a9..cdb727f263480 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.cc +++ b/onnxruntime/core/providers/cann/cann_stream_handle.cc @@ -18,7 +18,7 @@ struct CannNotification : public synchronize::Notification { } void Activate() override { - CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast(stream_.GetHandle()))); + CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast(GetStream().GetHandle()))); } void wait_on_device(Stream& device_stream) { diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h index f20eafb2b4b35..e7a352298b2bd 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.h +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -24,8 +24,6 @@ struct CannStream : Stream { void Flush() override; bool own_stream_{true}; - - WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; } }; void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1f4c9fcdbc073..e036c7764d041 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -161,9 +161,7 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi {default_memory_arena_cfg ? *default_memory_arena_cfg : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware - true, - // enable cross stream sharing? - false); + true); // CUDA malloc/free is expensive so always use an arena return CreateAllocator(default_memory_info); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 054dd9f9da9f3..bcbf1d4a1c800 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -41,8 +41,12 @@ class CudaKernel : public OpKernel { template inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, onnxruntime::Stream* stream) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, false, stream, WaitCudaNotificationOnDevice); + if (count_or_bytes == 0) { + return nullptr; + } + + return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, false, + stream); } // Different from GetScratchBuffer which use IAllocator::Alloc() to allocate memory, diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 021a6f1e7e350..e8d133779f33c 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -355,8 +355,9 @@ struct CudaOrtAllocator : OrtAllocator { Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; - Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl - GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl + GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + AllocOnStream = nullptr; // TODO. Plugin EP arena to provide this. const OrtEpApi& ep_api = *api.GetEpApi(); const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info); @@ -679,7 +680,6 @@ struct CudaEpFactory : OrtEpFactory { CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; - CreateDataTransfer = CreateDataTransferImpl; IsStreamAware = IsStreamAwareImpl; diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index b6cbffb073774..fbee1841ae8d5 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -38,7 +38,7 @@ struct CudaNotification : public synchronize::Notification { void Activate() override { // record event with cudaEventBlockingSync so we can support sync on host without busy wait. - CUDA_CALL_THROW(cudaEventRecord(event_, static_cast(stream_.GetHandle()))); + CUDA_CALL_THROW(cudaEventRecord(event_, static_cast(GetStream().GetHandle()))); } void wait_on_device(Stream& device_stream) { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index c75cf15f7c2f8..1be7a3d510082 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -48,8 +48,6 @@ struct CudaStream : Stream { onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } - WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; } - private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 4f8e6605ce151..b232124dc6b00 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -367,7 +367,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); const CudaT* input_data = reinterpret_cast(input.Data()); if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, @@ -384,7 +384,9 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } break; case ApplicableMatrixReduction::Columns: { const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); - auto buffer = buffer_size_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto buffer = buffer_size_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream); ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, reinterpret_cast(output.MutableData()), m, n, buffer.get(), buffer_size_bytes)); @@ -421,7 +423,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if ((ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same::value) || (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES && std::is_same::value)) { // ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn - temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); Impl_Cast(stream, reinterpret_cast(input.Data()), temp_X.get(), input_shape.Size()); } else { cudnn_type_X = CudnnTensor::GetDataType(); @@ -444,18 +446,22 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, CudaStream* cuda_stream = static_cast(ort_stream); CUDNN_RETURN_IF_ERROR(cudnnGetReductionWorkspaceSize(CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); - auto workspace_cuda = workspace_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto workspace_cuda = workspace_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream); size_t indices_bytes = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, input_tensor, output_tensor, &indices_bytes)); - auto indices_cuda = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto indices_cuda = indices_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream); if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES) { IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); CudaT* input_data = nullptr; if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, @@ -482,7 +488,9 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, size_t indices_bytes_max = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(CudaKernel::GetCudnnHandle(cuda_stream), reduce_max_desc, input_tensor, output_tensor, &indices_bytes_max)); - auto indices_cuda_max = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto indices_cuda_max = indices_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream); auto* p_output = reinterpret_cast(output.template MutableData()); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_max_desc, indices_cuda_max.get(), indices_bytes_max, @@ -493,9 +501,11 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, // Exp(X-ReduceMax) const TensorShape output_shape(output_dims); - auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); auto exp_result = exp_result_buffer.get(); - auto log_sum_result_buffer = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto log_sum_result_buffer = output_count == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); auto log_sum_result = log_sum_result_buffer.get(); BinaryElementwisePreparation prepare; ORT_RETURN_IF_ERROR(prepare.BinaryElementwiseBroadcastPrepareHelper(input_shape, output_shape, input_shape)); @@ -563,7 +573,9 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } } else { if (temp_X) { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto temp_output = output_count == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -589,14 +601,18 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableData(), static_cast(0), output_count * sizeof(int64_t), stream)); } else { if (temp_X) { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto temp_output = output_count == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), &zero, output_tensor, temp_output.get())); } else { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto temp_output = output_count == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -605,7 +621,8 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } // CUDA reduction index is uint32_t for now, cast it to int64_t according to ONNX spec - Impl_Cast(stream, reinterpret_cast(indices_cuda.get()), output.MutableData(), output_count); + Impl_Cast(stream, reinterpret_cast(indices_cuda.get()), + output.MutableData(), output_count); } } diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index 2df995d6e62ac..f4a33a128608a 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -120,7 +120,7 @@ IAllocatorUniquePtr CudaTuningContext::GetScratchBuffer( return nullptr; } - return IAllocator::MakeUniquePtr(it->second, num_bytes, false, stream, WaitCudaNotificationOnDevice); + return IAllocator::MakeUniquePtr(it->second, num_bytes, false, stream); } } // namespace tunable diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index aa8b21ea3fe52..41b55e3baf508 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -323,9 +323,7 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware - true, - // enable cross stream sharing? - false); + true); // ROCM malloc/free is expensive so always use an arena return CreateAllocator(default_memory_info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 8ed4e4a45a8c4..6e492327a73a3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -18,11 +18,12 @@ struct MIGraphXNotification : public synchronize::Notification { void Activate() override { // record event with hipEventBlockingSync so we can support sync on host without busy wait. - HIP_CALL_THROW(hipEventRecord(event_, static_cast(stream_.GetHandle()))); + HIP_CALL_THROW(hipEventRecord(event_, static_cast(GetStream().GetHandle()))); } void wait_on_device(Stream& device_stream) { - ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); + ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", + device_stream.GetDevice().ToString()); // launch a wait command to the migraphx stream HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index d0ef3334b38c9..886103690c661 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -29,8 +29,6 @@ struct MIGraphXStream : Stream { virtual void* GetResource(int version, int id) const; - virtual WaitNotificationFn GetWaitNotificationFn() const { return WaitMIGraphXNotificationOnDevice; } - private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index c6eff29a0bd4f..008d54c44ff70 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -3,6 +3,7 @@ #include "allocator_adapters.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_ep_stream.h" #include "core/session/abi_devices.h" #include "core/session/abi_key_value_pairs.h" #include "core/session/environment.h" @@ -21,24 +22,33 @@ namespace { // `IAllocatorImplWrappingOrtAllocator` to ensure compatibility. constexpr uint32_t kOrtAllocatorReserveMinVersion = 18; constexpr uint32_t kOrtAllocatorStatsMinVersion = 23; +constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; } // namespace OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxruntime::AllocatorPtr&& i_allocator) : i_allocator_(std::move(i_allocator)) { OrtAllocator::version = ORT_API_VERSION; - OrtAllocator::Alloc = - [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; - OrtAllocator::Free = - [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; - OrtAllocator::Info = - [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { + return static_cast(this_)->Alloc(size); + }; + + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { + static_cast(this_)->Free(p); + }; + + OrtAllocator::Info = [](const OrtAllocator* this_) { + return static_cast(this_)->Info(); + }; + if (OrtAllocator::version >= kOrtAllocatorReserveMinVersion) { - OrtAllocator::Reserve = - [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; + OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { + return static_cast(this_)->Reserve(size); + }; } + if (OrtAllocator::version >= kOrtAllocatorStatsMinVersion) { - OrtAllocator::GetStats = - [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { + OrtAllocator::GetStats = [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { API_IMPL_BEGIN auto kvp = std::make_unique(); const auto& stats_map = static_cast(this_)->Stats(); @@ -48,12 +58,22 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti API_IMPL_END }; } + + if (OrtAllocator::version >= kOrtAllocatorAllocOnStreamMinVersion) { + OrtAllocator::AllocOnStream = [](OrtAllocator* this_, size_t size, OrtSyncStream* stream) { + return static_cast(this_)->AllocOnStream(size, stream); + }; + } } void* OrtAllocatorImplWrappingIAllocator::Alloc(size_t size) { return i_allocator_->Alloc(size); } +void* OrtAllocatorImplWrappingIAllocator::AllocOnStream(size_t size, OrtSyncStream* stream) { + return i_allocator_->AllocOnStream(size, static_cast(stream)); +} + void* OrtAllocatorImplWrappingIAllocator::Reserve(size_t size) { return i_allocator_->Reserve(size); } @@ -105,6 +125,18 @@ void* IAllocatorImplWrappingOrtAllocator::Alloc(size_t size) { return ort_allocator_->Alloc(ort_allocator_.get(), size); } +bool IAllocatorImplWrappingOrtAllocator::IsStreamAware() const { + return ort_allocator_->version >= kOrtAllocatorAllocOnStreamMinVersion && ort_allocator_->AllocOnStream != nullptr; +} + +void* IAllocatorImplWrappingOrtAllocator::AllocOnStream(size_t size, Stream* stream) { + if (ort_allocator_->version >= kOrtAllocatorAllocOnStreamMinVersion && ort_allocator_->AllocOnStream) { + return ort_allocator_->AllocOnStream(ort_allocator_.get(), size, static_cast(stream)); + } + + return ort_allocator_->Alloc(ort_allocator_.get(), size); +} + void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) { if (ort_allocator_->version >= kOrtAllocatorReserveMinVersion && ort_allocator_->Reserve) { return ort_allocator_->Reserve(ort_allocator_.get(), size); diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index 544c7828e46f8..d67eae90985bf 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -25,6 +25,7 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { ~OrtAllocatorImplWrappingIAllocator() override = default; void* Alloc(size_t size); + void* AllocOnStream(size_t size, OrtSyncStream* stream); void Free(void* p); void* Reserve(size_t size); @@ -56,6 +57,9 @@ class IAllocatorImplWrappingOrtAllocator final : public IAllocator { void Free(void* p) override; void* Reserve(size_t size) override; + bool IsStreamAware() const override; + void* AllocOnStream(size_t size, Stream* stream) override; + const OrtAllocator* GetWrappedOrtAllocator() const { return ort_allocator_.get(); } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index c5fc4e7ccf76f..2a898a2b0bf9f 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -757,7 +757,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKerne return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); } onnxruntime::Stream* stream = reinterpret_cast(context)->GetComputeStream(); - *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn()); + *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream); return nullptr; }; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 493c0a106074c..450a8bad09392 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -639,6 +639,13 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, bool replace_existing) { // NOTE: memory_info is guaranteed to come from the OrtEpDevice when this is called + if (allocator_type == OrtAllocatorType::OrtArenaAllocator) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "OrtAllocatorType::OrtArenaAllocator is reserved for ONNX Runtime internal usage only. " + "The EP implements arena support internally so please use OrtDeviceAllocator and provide " + "any arena options via the allocator options."); + } + // we need to remove from shared_ort_allocators_ first in case the entry in shared_allocators_ owns the pointer in // shared_ort_allocators_. if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); @@ -669,48 +676,21 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, return ToStatusAndRelease(ort_status); } + if (allocator->Info(allocator)->alloc_type == OrtAllocatorType::OrtArenaAllocator) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "OrtEpFactory returned an allocator with OrtAllocatorType of OrtArenaAllocator. " + "This type is reserved for ONNX Runtime internal usage only, as any arena usage by the " + "EP library should be opaque to ORT"); + } + auto ort_allocator = OrtAllocatorUniquePtr(allocator, [&ep_device](OrtAllocator* allocator) { ep_device.ep_factory->ReleaseAllocator(ep_device.ep_factory, allocator); }); - AllocatorPtr shared_allocator; - - if (allocator_type == OrtArenaAllocator) { - // wrap with ORT arena - OrtArenaCfg arena_cfg; - if (allocator_options != nullptr) { - auto status = OrtArenaCfg::FromKeyValuePairs(*allocator_options, arena_cfg); - } - - bool stream_aware_arena = ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); - - AllocatorCreationInfo alloc_creation_info{ - [&ort_allocator](int) -> std::unique_ptr { - return std::make_unique(std::move(ort_allocator)); - }, - /*unused*/ -1, // arg to the lambda above that is ignored as the device id comes from the allocator - /*create_arena*/ true, - arena_cfg, - stream_aware_arena, - }; - - shared_allocator = CreateAllocator(alloc_creation_info); - - // need an OrtAllocator to return to the user so we need yet another layer. - // we pass in a copy of the AllocatorPtr (which is a shared_ptr) in order to maintain the overall condition that - // shared_allocators_ is the main owner of the allocator and the last place we delete from when removing - // from shared_ort_allocators_, arena_ort_allocators_ and shared_allocators_. - auto arena_ort_allocator = std::make_unique(AllocatorPtr(shared_allocator)); - allocator = arena_ort_allocator.get(); - - // store the entry using the EPs memory info for easier lookup when removing - arena_ort_allocators_.insert({&memory_info, std::move(arena_ort_allocator)}); - } else { - shared_ort_allocators_.insert(allocator); - shared_allocator = std::make_shared(std::move(ort_allocator)); - } + shared_ort_allocators_.insert(allocator); + AllocatorPtr shared_allocator = std::make_shared(std::move(ort_allocator)); shared_allocators_.push_back(std::move(shared_allocator)); if (allocator_out != nullptr) { diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index 6c57f95719f41..8fd1fc198374f 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -180,6 +180,31 @@ ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_d return memory_device->Id(); } +ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* ort_stream) { + // the EP API should only ever see plugin_ep::Stream instances + const auto& stream = *reinterpret_cast(ort_stream); + return &stream.GetImpl(); +} + +ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream) { + return static_cast(stream)->GetSyncId(); +} + +ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, + _In_ const OrtSyncStream* consumer_stream) { + uint64_t id{0}; + if (producer_stream && consumer_stream) { + const auto& producer = *static_cast(producer_stream); + const auto& consumer = *static_cast(consumer_stream); + + // If both streams are valid, we can return the sync id for the last wait on the producer stream. + // This is useful for synchronizing operations between different streams. + id = consumer.GetSyncIdForLastWaitOnStream(producer); + } + + return id; +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -201,6 +226,10 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::MemoryDevice_GetMemoryType, &OrtExecutionProviderApi::MemoryDevice_GetVendorId, &OrtExecutionProviderApi::MemoryDevice_GetDeviceId, + + &OrtExecutionProviderApi::SyncStream_GetImpl, + &OrtExecutionProviderApi::SyncStream_GetSyncId, + &OrtExecutionProviderApi::GetSyncIdForLastWaitOnSyncStream, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/ep_api.h index 1af23664f71eb..c0dc79f3fb333 100644 --- a/onnxruntime/core/session/ep_api.h +++ b/onnxruntime/core/session/ep_api.h @@ -35,4 +35,9 @@ ORT_API(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemor ORT_API(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device); ORT_API(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device); ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device); + +ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* stream); +ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); +ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, + _In_ const OrtSyncStream* consumer_stream); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index c776020b037f0..c7d7ea2e8a4ec 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -550,9 +550,12 @@ Status PluginExecutionProvider::SetEpDynamicOptions(gsl::span } std::unique_ptr PluginExecutionProvider::GetDataTransfer() const { OrtDataTransferImpl* data_transfer_impl = nullptr; - OrtStatus* status = ep_factory_.CreateDataTransfer(&ep_factory_, &data_transfer_impl); - if (status != nullptr) { - ORT_THROW("Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + + if (ep_factory_.CreateDataTransfer != nullptr) { + OrtStatus* status = ep_factory_.CreateDataTransfer(&ep_factory_, &data_transfer_impl); + if (status != nullptr) { + ORT_THROW("Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + } } if (data_transfer_impl == nullptr) { @@ -568,6 +571,11 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { for (const auto* memory_info : allocator_mem_infos_) { OrtAllocator* ort_allocator_ptr = nullptr; + + if (!ort_ep_->CreateAllocator && !ep_factory_.CreateAllocator) { + ORT_THROW("The OrtEpDevice requires the EP library to implement an allocator, but none were found."); + } + // prefer OrtEp function if available, otherwise fall back to using the OrtEpFactory implementation. OrtStatus* ort_status = ort_ep_->CreateAllocator ? ort_ep_->CreateAllocator(ort_ep_.get(), memory_info, &ort_allocator_ptr) @@ -579,6 +587,13 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { ORT_THROW("Error creating allocator: ", ToStatusAndRelease(ort_status).ToString()); } + if (ort_allocator_ptr->Info(ort_allocator_ptr)->alloc_type == OrtAllocatorType::OrtArenaAllocator) { + ORT_THROW( + "OrtEpFactory returned an allocator with OrtAllocatorType of OrtArenaAllocator. " + "This type is reserved for ONNX Runtime internal usage only, as any arena usage by the " + "EP library should be opaque to ORT"); + } + auto ort_allocator = OrtAllocatorUniquePtr( ort_allocator_ptr, [this](OrtAllocator* allocator) { @@ -592,7 +607,7 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& registry, AllocatorMap& /*allocators*/) const { - if (!ep_factory_.IsStreamAware(&ep_factory_)) { + if (ep_factory_.IsStreamAware == nullptr || !ep_factory_.IsStreamAware(&ep_factory_)) { return; } @@ -602,6 +617,10 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr continue; } + if (!ort_ep_->CreateSyncStreamForDevice && !ep_factory_.CreateSyncStreamForDevice) { + ORT_THROW("The OrtEpFactory is stream aware, but did not provide CreateSyncStreamForDevice."); + } + auto device_type = mem_info->device.Type(); registry.RegisterCreateStreamFn( diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 25cabd256e318..f4f76a389030e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3571,7 +3571,7 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st ++iter; } - // Shrink if it is an arena based allocator + // Shrink if it is a BFCArena allocator // Iterate through the registered allocators as we could have multiple allocators for the device+type // if they differ by vendor_id. for (const auto& [device, allocator_ptr] : session_state_->GetAllocators()) { diff --git a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/ep_allocator.h index 624b4fcb484cd..e46c03dfc8f14 100644 --- a/onnxruntime/test/autoep/library/ep_allocator.h +++ b/onnxruntime/test/autoep/library/ep_allocator.h @@ -5,17 +5,50 @@ #include "example_plugin_ep_utils.h" +#include + // from onnxruntime/core/framework/allocator_stats.h +// copied from onnxruntime::AllocatorStats struct AllocatorStats { int64_t num_allocs; // Number of allocations. int64_t num_reserves; // Number of reserves. (Number of calls to Reserve() in arena-based allocators) + int64_t num_arena_extensions; // Number of arena extensions (Relevant only for arena based allocators) + int64_t num_arena_shrinkages; // Number of arena shrinkages (Relevant only for arena based allocators) int64_t bytes_in_use; // Number of bytes in use. int64_t total_allocated_bytes; // The total number of allocated bytes by the allocator. int64_t max_bytes_in_use; // The maximum bytes in use. int64_t max_alloc_size; // The max single allocation seen. - int64_t bytes_limit; // The upper limit what the allocator can allocate, if such a limit - // is known. Certain allocator may return 0 to indicate the limit is - // unknown. + // The upper limit what the allocator can allocate, if such a limit + // is known. Certain allocator may return 0 to indicate the limit is unknown. + int64_t bytes_limit; + + void ToKeyValuePairs(const OrtApi& api, OrtKeyValuePairs* kvps) const { + if (num_allocs > 0 || bytes_limit != 0) { + api.AddKeyValuePair(kvps, "Limit", std::to_string(bytes_limit).c_str()); + api.AddKeyValuePair(kvps, "InUse", std::to_string(bytes_in_use).c_str()); + api.AddKeyValuePair(kvps, "TotalAllocated", std::to_string(total_allocated_bytes).c_str()); + api.AddKeyValuePair(kvps, "MaxInUse", std::to_string(max_bytes_in_use).c_str()); + api.AddKeyValuePair(kvps, "NumAllocs", std::to_string(num_allocs).c_str()); + api.AddKeyValuePair(kvps, "NumReserves", std::to_string(num_reserves).c_str()); + api.AddKeyValuePair(kvps, "NumArenaExtensions", std::to_string(num_arena_extensions).c_str()); + api.AddKeyValuePair(kvps, "NumArenaShrinkages", std::to_string(num_arena_shrinkages).c_str()); + api.AddKeyValuePair(kvps, "MaxAllocSize", std::to_string(max_alloc_size).c_str()); + } + } + + std::string DebugString() const { + std::ostringstream ss; + ss << "Limit: " << this->bytes_limit << "\n" + << "InUse: " << this->bytes_in_use << "\n" + << "TotalAllocated: " << this->total_allocated_bytes << "\n" + << "MaxInUse: " << this->max_bytes_in_use << "\n" + << "NumAllocs: " << this->num_allocs << "\n" + << "NumReserves: " << this->num_reserves << "\n" + << "NumArenaExtensions: " << this->num_arena_extensions << "\n" + << "NumArenaShrinkages: " << this->num_arena_shrinkages << "\n" + << "MaxAllocSize: " << this->max_alloc_size << "\n"; + return ss.str(); + } }; struct CustomAllocator : OrtAllocator { @@ -27,6 +60,7 @@ struct CustomAllocator : OrtAllocator { Info = InfoImpl; Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena GetStats = GetStatsImpl; // this can be set to nullptr if you don't want to implement it + AllocOnStream = nullptr; } static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { diff --git a/onnxruntime/test/autoep/library/ep_arena.cc b/onnxruntime/test/autoep/library/ep_arena.cc new file mode 100644 index 0000000000000..aa0db71e97925 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_arena.cc @@ -0,0 +1,778 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_arena.h" + +#include +#include + +namespace { +std::string GetAllocatorName(const OrtApi& api, OrtAllocator& allocator) { + const OrtMemoryInfo* mem_info = allocator.Info(&allocator); + const char* allocator_name; + auto* status = api.MemoryInfoGetName(mem_info, &allocator_name); // never fails + static_cast(status); + return allocator_name; +} +} // namespace + +ArenaImpl::ArenaImpl(AllocatorUniquePtr allocator, const ArenaConfig& config, const OrtApi& api, + const OrtLogger& logger) + : device_allocator_{std::move(allocator)}, + allocator_name_{GetAllocatorName(api, *device_allocator_)}, + config_{config}, + next_allocation_id_(1), + free_chunks_list_(kInvalidChunkHandle), + api_{api}, + ep_api_{*api_.GetEpApi()}, + logger_{logger} { + LOG(INFO, "Creating ArenaImpl for " + << allocator_name_ + << " with following configs: initial_chunk_size_bytes: " << config_.initial_chunk_size_bytes + << " max_dead_bytes_per_chunk: " << config_.max_dead_bytes_per_chunk + << " initial_growth_chunk_size_bytes: " << config_.initial_growth_chunk_size_bytes + << " max_power_of_two_extend_bytes: " << config_.max_power_of_two_extend_bytes + << " memory limit: " << config_.max_mem + << " arena_extend_strategy: " << config_.arena_extend_strategy); + + curr_region_allocation_bytes_ = RoundedBytes( + std::min(config_.max_mem, static_cast(config_.initial_chunk_size_bytes))); + + stats_.bytes_limit = static_cast(config.max_mem); + + // Create a bunch of bins of various good sizes. + + // We create bins to fit all possible ranges that cover the + // config_.max_mem starting from allocations up to 256 bytes to + // allocations up to (and including) the memory limit. + LOG(VERBOSE, "Creating " << kNumBins << " bins of max chunk size " + << BinNumToSize(0) << " to " << BinNumToSize(kNumBins - 1)); + + for (BinNum b = 0; b < kNumBins; b++) { + size_t bin_size = BinNumToSize(b); + new (BinFromIndex(b)) Bin(this, bin_size); + EP_ENFORCE((BinForSize(bin_size) == BinFromIndex(b) && + BinForSize(bin_size + 255) == BinFromIndex(b) && + BinForSize(bin_size * 2 - 1) == BinFromIndex(b)), + "Invalid bin size for bin " << b); + + if (b + 1 < kNumBins) { + EP_ENFORCE(BinForSize(bin_size * 2) != BinFromIndex(b), "Invalid bin size for " << b); + } + } +} + +ArenaImpl::~ArenaImpl() { + for (const auto& region : region_manager_.regions()) { + device_allocator_->Free(device_allocator_.get(), region.ptr()); + } + + for (const auto& reserve_chunk : reserved_chunks_) { + device_allocator_->Free(device_allocator_.get(), reserve_chunk.first); + } + + for (BinNum b = 0; b < kNumBins; b++) { + BinFromIndex(b)->~Bin(); + } +} + +ArenaImpl::Chunk* ArenaImpl::ChunkFromHandle(ChunkHandle h) { + EP_ENFORCE(h < chunks_.size(), "ChunkFromHandle"); + return &(chunks_[h]); +} + +OrtStatus* ArenaImpl::Extend(size_t rounded_bytes) { + size_t available_bytes = config_.max_mem - static_cast(stats_.total_allocated_bytes); + // Rounds available_bytes down to the nearest multiple of kMinAllocationSize. + available_bytes = (available_bytes / kMinAllocationSize) * kMinAllocationSize; + + // Do we have enough space to handle the client's request? + // If not, fail immediately. + if (rounded_bytes > available_bytes) { + RETURN_ERROR(ORT_EP_FAIL, "Available memory of " << available_bytes << " is smaller than requested bytes of " + << rounded_bytes); + } + + auto safe_alloc = [this](size_t alloc_bytes) { + void* new_mem = nullptr; + try { + new_mem = device_allocator_->Alloc(device_allocator_.get(), alloc_bytes); + } catch (const std::bad_alloc&) { + // attempted allocation can throw std::bad_alloc. we want to treat this the same as if it returned nullptr + // so swallow the exception + } + // catch (const MyException& exception) { + // if your implementation threw, consider swallowing the exception to enable attempting a smaller allocation + // if possible + //} + return new_mem; + }; + + auto get_extend_bytes = [this, available_bytes](const size_t bytes, size_t& extend_bytes) -> OrtStatus* { + extend_bytes = 0; + if (config_.arena_extend_strategy == ArenaExtendStrategy::kNextPowerOfTwo) { + // If curr_region_allocation_bytes_ is not enough to satisfy the + // allocation, keep multiplying by a power of two until that is + // sufficient. + bool increased_allocation = false; + while (bytes > curr_region_allocation_bytes_) { + curr_region_allocation_bytes_ *= 2; + increased_allocation = true; + } + + extend_bytes = std::min(static_cast(curr_region_allocation_bytes_), available_bytes); + + // we allocated the same number of bytes as the current region + // the 2x is to double the minimum size of the next amount we'll allocate + if (!increased_allocation) { + if (config_.arena_extend_strategy == ArenaExtendStrategy::kNextPowerOfTwo && + static_cast(curr_region_allocation_bytes_) * 2 < config_.max_power_of_two_extend_bytes) { + curr_region_allocation_bytes_ *= 2; + } else { + curr_region_allocation_bytes_ = config_.max_power_of_two_extend_bytes; + } + } + } else if (config_.arena_extend_strategy == ArenaExtendStrategy::kSameAsRequested) { + // BFC Arena could cause internal and external fragmentation. But, running training with + // big batch size will be very sensitive to fragmentation. So, to avoid fragmentation, + // just extend arena with actual requested size. + extend_bytes = bytes; + } else { + RETURN_ERROR(ORT_INVALID_ARGUMENT, "Invalid arena extend strategy." << config_.arena_extend_strategy); + } + + return nullptr; + }; + + size_t bytes; + RETURN_IF_ERROR(get_extend_bytes(rounded_bytes, bytes)); + + // Try allocating. + void* mem_addr = safe_alloc(bytes); + + static constexpr float kBackpedalFactor = 0.9f; + // Try allocating less memory. + while (mem_addr == nullptr) { + // kBackpedalFactor is float, bytes is size_t. The result of bytes * kBackpedalFactor is float. When we cast it to + // size_t, which is a smaller type, it could loss data. This is what C4244 complains. The "static_cast" here + // is to suppress the warning. C26451 suggest we may change kBackpedalFactor to double to get better accuary. But if + // we do that, AMD GPU CI build pipeline will have an "out-of-memory" error. So I choose to keep this piece of code + // untouched and disable the warning first. +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26451) +#endif + bytes = RoundedBytes(static_cast(bytes * kBackpedalFactor)); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + // give up if we can't satisfy the requested size, or we're attempting an allocation of less than 8K. + // + // the latter protects against an infinite loop that occurs when bytes is less than 2560. at that point the 10% + // reduction to 2304 bytes is undone by rounding to a 256 boundary in RoundedBytes, leading to an infinite loop. + // the 8K value is just to give up a little earlier vs. getting all the way down to 2560 bytes. + // If we can't allocate 8K, we're pretty much dead. + if (bytes < rounded_bytes || bytes < 8 * 1024) + break; + + mem_addr = safe_alloc(bytes); + } + + if (mem_addr == nullptr) { + RETURN_ERROR(ORT_EP_FAIL, "Failed to allocate memory for requested buffer of size " << rounded_bytes); + } + + LOG(INFO, "Extended allocation by " << bytes << " bytes."); + + stats_.total_allocated_bytes += bytes; + LOG(INFO, "Total allocated bytes: " << stats_.total_allocated_bytes); + + LOG(INFO, "Allocated memory at " << mem_addr << " to " << static_cast(static_cast(mem_addr) + bytes)); + + region_manager_.AddAllocationRegion(mem_addr, bytes, stats_.num_arena_extensions); + stats_.num_arena_extensions += 1; + + // Create one large chunk for the whole memory space that will + // be chunked later. + ChunkHandle h = AllocateChunk(); + ArenaImpl::Chunk* c = ChunkFromHandle(h); + c->ptr = mem_addr; + c->size = bytes; + c->allocation_id = -1; + c->prev = kInvalidChunkHandle; + c->next = kInvalidChunkHandle; + // assign the new created chunk to default stream, so it can be pick up by any stream + c->stream = nullptr; + + region_manager_.set_handle(c->ptr, h); + + // TODO(vrv): Try to merge this new region with an existing region, + // if the address space is contiguous, to avoid fragmentation + // across regions. + + // Insert the chunk into the right bin. + InsertFreeChunkIntoBin(h); + + return nullptr; +} + +ArenaImpl::ChunkHandle +ArenaImpl::AllocateChunk() { + if (free_chunks_list_ != kInvalidChunkHandle) { + ChunkHandle h = free_chunks_list_; + Chunk* c = ChunkFromHandle(h); + free_chunks_list_ = c->next; + return h; + } + ChunkHandle h = chunks_.size(); + chunks_.resize(h + 1); + return h; +} + +void ArenaImpl::DeallocateChunk(ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + + if (c->stream) { + if (auto it = stream_to_chunks_.find(c->stream); it != stream_to_chunks_.end()) { + size_t result = it->second.erase(h); + static_cast(result); // should always be found + + if (it->second.empty()) { + stream_to_chunks_.erase(it); + impl_to_stream_.erase(ep_api_.SyncStream_GetImpl(c->stream)); + } + } + + c->stream = nullptr; + c->stream_sync_id = 0; + } + + c->next = free_chunks_list_; + free_chunks_list_ = h; +} + +// static +size_t ArenaImpl::RoundedBytes(size_t bytes) { + return (kMinAllocationSize * ((bytes + kMinAllocationSize - 1) / kMinAllocationSize)); +} + +void* ArenaImpl::Alloc(size_t size) { + return AllocateRawInternal(size, nullptr, false); +} + +void* ArenaImpl::AllocOnStream(size_t size, OrtSyncStream* stream) { + return AllocateRawInternal(size, stream, false); +} + +void* ArenaImpl::Reserve(size_t size) { + if (size == 0) + return nullptr; + + std::lock_guard lock(lock_); + + LOG(INFO, "Reserving memory in ArenaImpl for " << allocator_name_ << " size: " << size); + + void* ptr = device_allocator_->Alloc(device_allocator_.get(), size); + EP_ENFORCE(reserved_chunks_.find(ptr) == reserved_chunks_.end(), __FUNCTION__); + reserved_chunks_.insert(std::pair(ptr, size)); + stats_.bytes_in_use += size; + stats_.num_reserves += 1; + stats_.num_allocs += 1; + stats_.max_alloc_size = std::max(static_cast(stats_.max_alloc_size), size); + stats_.max_bytes_in_use = std::max(static_cast(stats_.max_bytes_in_use), stats_.bytes_in_use); + stats_.total_allocated_bytes += size; + return ptr; +} + +size_t ArenaImpl::RequestedSize(const void* ptr) { + std::lock_guard lock(lock_); + ArenaImpl::ChunkHandle h = region_manager_.get_handle(ptr); + EP_ENFORCE(h != kInvalidChunkHandle, __FUNCTION__); + ArenaImpl::Chunk* c = ChunkFromHandle(h); + return c->requested_size; +} + +size_t ArenaImpl::AllocatedSize(const void* ptr) { + std::lock_guard lock(lock_); + ArenaImpl::ChunkHandle h = region_manager_.get_handle(ptr); + EP_ENFORCE(h != kInvalidChunkHandle, __FUNCTION__); + ArenaImpl::Chunk* c = ChunkFromHandle(h); + return c->size; +} + +void* ArenaImpl::AllocateRawInternal(size_t num_bytes, OrtSyncStream* stream, bool dump_log_on_failure) { + if (num_bytes == 0) { + return nullptr; + } + + // Round to multiple of kMinAllocationSize + size_t rounded_bytes = RoundedBytes(num_bytes); + + // The BFC allocator tries to find the best fit first. + BinNum bin_num = BinNumForSize(rounded_bytes); + + std::lock_guard lock(lock_); + + if (stream && stream_to_chunks_.find(stream) == stream_to_chunks_.end()) { + stream_to_chunks_.insert({stream, std::set{}}); + const OrtSyncStreamImpl* stream_impl = ep_api_.SyncStream_GetImpl(stream); + assert(stream_impl); + impl_to_stream_.insert({stream_impl, stream}); + } + + // search for a valid chunk + auto* chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); + + if (chunk != nullptr) { + return chunk->ptr; + } + + LOG(INFO, "Extending arena for " << allocator_name_ + << ". bin_num:" << bin_num << " (requested) num_bytes: " << num_bytes + << " (actual) rounded_bytes:" << rounded_bytes); + + // Try to extend + auto status = Extend(rounded_bytes); + if (status == nullptr) { + chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); + if (chunk != nullptr) { + return chunk->ptr; + } else { + status = api_.CreateStatus(ORT_EP_FAIL, + ("Failed to find a free memory block despite calling Extend. rounded_bytes=" + + std::to_string(rounded_bytes)) + .c_str()); + } + } + + // We searched all bins for an existing free chunk to use and couldn't find one. Dump the memory log for analysis. + if (dump_log_on_failure) { + LOG(ERROR, "BFC Arena ran out of memory trying to allocate " << num_bytes); + DumpMemoryLog(rounded_bytes); + } + + throw std::runtime_error(api_.GetErrorMessage(status)); +} + +OrtStatus* ArenaImpl::GetStats(OrtKeyValuePairs** stats) { + std::lock_guard lock(lock_); + + api_.CreateKeyValuePairs(stats); + stats_.ToKeyValuePairs(api_, *stats); + + return nullptr; +} + +ArenaImpl::Chunk* ArenaImpl::SplitFreeChunkFromBin(ArenaImpl::Bin::FreeChunkSet* free_chunks, + const ArenaImpl::Bin::FreeChunkSet::iterator& citer, + size_t rounded_bytes, + size_t num_bytes) { + const ArenaImpl::ChunkHandle h = (*citer); + RemoveFreeChunkIterFromBin(free_chunks, citer); + ArenaImpl::Chunk* chunk = ChunkFromHandle(h); + + // If we can break the size of the chunk into two reasonably large pieces, do so. + // In any case don't waste more than max_dead_bytes_per_chunk bytes on padding this alloc. + if (chunk->size >= rounded_bytes * 2 || + static_cast(chunk->size - rounded_bytes) >= config_.max_dead_bytes_per_chunk) { + SplitChunk(h, rounded_bytes); + chunk = ChunkFromHandle(h); // Update chunk pointer in case it moved + } + + // The requested size of the returned chunk is what the user has allocated. + chunk->requested_size = num_bytes; + // Assign a unique id and increment the id counter, marking the chunk as being in use. + chunk->allocation_id = next_allocation_id_++; + + ++stats_.num_allocs; + stats_.bytes_in_use += chunk->size; + stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + stats_.max_alloc_size = std::max(stats_.max_alloc_size, static_cast(chunk->size)); + + return chunk; +} + +ArenaImpl::Chunk* ArenaImpl::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, + OrtSyncStream* stream) { + // First identify the first bin that could satisfy rounded_bytes. + for (; bin_num < kNumBins; bin_num++) { + // Start searching from the first bin for the smallest chunk that fits rounded_bytes. + Bin* b = BinFromIndex(bin_num); + for (auto citer = b->free_chunks.begin(); citer != b->free_chunks.end(); ++citer) { + const ArenaImpl::ChunkHandle h = (*citer); + ArenaImpl::Chunk* chunk = ChunkFromHandle(h); + EP_ENFORCE(!chunk->in_use(), __FUNCTION__); + + if (chunk->size >= rounded_bytes) { + // We found an existing chunk that fits us that wasn't in use. + // If it's assigned to another stream, and we have synchronized with that stream more recently than it + // was assigned, we can take the chunk. + bool safe_to_use = chunk->stream == stream || + !chunk->stream || + (stream && chunk->stream && + chunk->stream_sync_id < ep_api_.GetSyncIdForLastWaitOnSyncStream(chunk->stream, stream)); + + if (safe_to_use) { + chunk = SplitFreeChunkFromBin(&b->free_chunks, citer, rounded_bytes, num_bytes); + + if (stream) { + chunk->stream = stream; + chunk->stream_sync_id = ep_api_.SyncStream_GetSyncId(stream); + stream_to_chunks_[stream].insert(h); + } + + return chunk; + } + } + } + } + + return nullptr; +} + +void ArenaImpl::SplitChunk(ArenaImpl::ChunkHandle h, size_t num_bytes) { + // Allocate the new chunk before we do any ChunkFromHandle + ChunkHandle h_new_chunk = AllocateChunk(); + + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num == kInvalidBinNum), __FUNCTION__); + + // Create a new chunk starting num_bytes after c + ArenaImpl::Chunk* new_chunk = ChunkFromHandle(h_new_chunk); + new_chunk->stream = c->stream; + new_chunk->stream_sync_id = c->stream_sync_id; + + new_chunk->ptr = static_cast(static_cast(c->ptr) + num_bytes); + region_manager_.set_handle(new_chunk->ptr, h_new_chunk); + + // Set the new sizes of the chunks. + new_chunk->size = c->size - num_bytes; + c->size = num_bytes; + + // The new chunk is not in use. + new_chunk->allocation_id = -1; + + // Maintain the pointers. + // c <-> c_neighbor becomes + // c <-> new_chunk <-> c_neighbor + ArenaImpl::ChunkHandle h_neighbor = c->next; + new_chunk->prev = h; + new_chunk->next = h_neighbor; + c->next = h_new_chunk; + if (h_neighbor != kInvalidChunkHandle) { + Chunk* c_neighbor = ChunkFromHandle(h_neighbor); + c_neighbor->prev = h_new_chunk; + } + + // Add the newly free chunk to the free bin. + InsertFreeChunkIntoBin(h_new_chunk); +} + +void ArenaImpl::Free(void* p) { + if (p == nullptr) { + return; + } + + std::lock_guard lock(lock_); + auto it = reserved_chunks_.find(p); + if (it != reserved_chunks_.end()) { + device_allocator_->Free(device_allocator_.get(), it->first); + stats_.bytes_in_use -= it->second; + stats_.total_allocated_bytes -= it->second; + reserved_chunks_.erase(it); + } else { + DeallocateRawInternal(p); + } +} + +void ArenaImpl::DeallocateRawInternal(void* ptr) { + // Find the chunk from the ptr. + ArenaImpl::ChunkHandle h = region_manager_.get_handle(ptr); + EP_ENFORCE(h != kInvalidChunkHandle, __FUNCTION__); + + // Consider coalescing it. + FreeAndMaybeCoalesce(h); +} + +// Merges Chunk(h2) into Chunk(h1) when Chunk(h1)->next is h2 and Chunk(h2)->prev is h1. +void ArenaImpl::Merge(ArenaImpl::ChunkHandle h1, + ArenaImpl::ChunkHandle h2) { + Chunk* c1 = ChunkFromHandle(h1); + Chunk* c2 = ChunkFromHandle(h2); + // We can only merge chunks that are not in use. + EP_ENFORCE(!c1->in_use() && !c2->in_use() && c1->stream == c2->stream, __FUNCTION__); + + // c1's prev doesn't change, still points to the same ptr, and is + // still not in use. + + // Fix up neighbor pointers + // + // c1 <-> c2 <-> c3 should become + // c1 <-> c3 + + ArenaImpl::ChunkHandle h3 = c2->next; + c1->next = h3; + EP_ENFORCE(c2->prev == h1, __FUNCTION__); + if (h3 != kInvalidChunkHandle) { + ArenaImpl::Chunk* c3 = ChunkFromHandle(h3); + c3->prev = h1; + } + + // Set the new size + c1->size += c2->size; + + // we only merge chunks that have the same stream + assert(c1->stream == c2->stream); + c1->stream_sync_id = std::max(c1->stream_sync_id, c2->stream_sync_id); + + DeleteChunk(h2); +} + +void ArenaImpl::DeleteChunk(ChunkHandle h) { + // Delete h and cleanup all state + Chunk* c = ChunkFromHandle(h); + region_manager_.erase(c->ptr); + DeallocateChunk(h); +} + +void ArenaImpl::InsertFreeChunkIntoBin(ArenaImpl::ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num == kInvalidBinNum), __FUNCTION__); + BinNum bin_num = BinNumForSize(c->size); + Bin* new_bin = BinFromIndex(bin_num); + c->bin_num = bin_num; + new_bin->free_chunks.insert(h); +} + +void ArenaImpl::RemoveFreeChunkIterFromBin(ArenaImpl::Bin::FreeChunkSet* free_chunks, + const ArenaImpl::Bin::FreeChunkSet::iterator& citer) { + ChunkHandle h = *citer; + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num != kInvalidBinNum), __FUNCTION__); + free_chunks->erase(citer); + c->bin_num = kInvalidBinNum; +} + +void ArenaImpl::RemoveFreeChunkFromBin(ArenaImpl::ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num != kInvalidBinNum), __FUNCTION__); + EP_ENFORCE(BinFromIndex(c->bin_num)->free_chunks.erase(h) > 0, "Could not find chunk in bin"); + c->bin_num = kInvalidBinNum; +} + +void ArenaImpl::FreeAndMaybeCoalesce(ArenaImpl::ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(c->in_use() && (c->bin_num == kInvalidBinNum), __FUNCTION__); + + // Mark the chunk as no longer in use + c->allocation_id = -1; + + // Updates the stats. + stats_.bytes_in_use -= c->size; + + // This chunk is no longer in-use, consider coalescing the chunk + // with adjacent chunks. + ChunkHandle chunk_to_reassign = Coalesce(h); + InsertFreeChunkIntoBin(chunk_to_reassign); +} + +ArenaImpl::ChunkHandle ArenaImpl::Coalesce(ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use(), __FUNCTION__); + + // This chunk is no longer in-use, consider coalescing the chunk with adjacent chunks. + ChunkHandle chunk_to_reassign = h; + + // If the next chunk is free, coalesce the two + if (c->next != kInvalidChunkHandle) { + Chunk* cnext = ChunkFromHandle(c->next); + // only merge the chunks belong to the same stream + if (!cnext->in_use() && cnext->stream == c->stream) { + chunk_to_reassign = h; + + // Deletes c->next + RemoveFreeChunkFromBin(c->next); + Merge(h, ChunkFromHandle(h)->next); + } + } + + // If the previous chunk is free, coalesce the two + c = ChunkFromHandle(h); + if (c->prev != kInvalidChunkHandle) { + Chunk* cprev = ChunkFromHandle(c->prev); + // only merge the chunks belong to the same stream + if (!cprev->in_use() && cprev->stream == c->stream) { + chunk_to_reassign = c->prev; + + RemoveFreeChunkFromBin(c->prev); // this deletes c + Merge(ChunkFromHandle(h)->prev, h); + } + } + + return chunk_to_reassign; +} + +std::array ArenaImpl::GetBinDebugInfo() { + std::array bin_infos; + + for (const auto& region : region_manager_.regions()) { + ChunkHandle h = region_manager_.get_handle(region.ptr()); + while (h != kInvalidChunkHandle) { + const Chunk* c = ChunkFromHandle(h); + BinNum bin_num = BinNumForSize(c->size); + BinDebugInfo& bin_info = bin_infos[bin_num]; + bin_info.total_bytes_in_bin += c->size; + bin_info.total_chunks_in_bin++; + + if (c->in_use()) { + bin_info.total_bytes_in_use += c->size; + bin_info.total_requested_bytes_in_use += c->requested_size; + bin_info.total_chunks_in_use++; + } else { + Bin* bin = BinFromIndex(bin_num); + EP_ENFORCE(bin->free_chunks.count(h) == 1 && c->bin_num == bin_num, __FUNCTION__); + } + + h = c->next; + } + } + return bin_infos; +} + +void ArenaImpl::DumpMemoryLog(size_t num_bytes) { + const std::array bin_infos = GetBinDebugInfo(); + LOG(INFO, "Allocator:" << allocator_name_); + LOG(INFO, "Bin size: Chunks in_use/total (if not zero). Allocated bytes in_use/total. Requested bytes."); + + size_t waste = 0; + for (BinNum bin_num = 0; bin_num < kNumBins; bin_num++) { + Bin* b = BinFromIndex(bin_num); + const BinDebugInfo& bin_info = bin_infos[bin_num]; + EP_ENFORCE(b->free_chunks.size() == bin_info.total_chunks_in_bin - bin_info.total_chunks_in_use, __FUNCTION__); + + if (bin_info.total_chunks_in_bin > 0) { + LOG(INFO, b->bin_size + << ": Chunks " << bin_info.total_chunks_in_use << "/" << bin_info.total_chunks_in_bin + << ". Bytes " + << bin_info.total_bytes_in_use << "/" << bin_info.total_bytes_in_bin << ". " + << "Requested " << bin_info.total_requested_bytes_in_use << "."); + + waste += bin_info.total_bytes_in_use - bin_info.total_requested_bytes_in_use; + } + } + + if (waste > 0) { + LOG(INFO, "Diff between in-use and requested bytes is " << waste); + } + + // Find the bin that we would have liked to allocate in, so we can get some further analysis about fragmentation. + Bin* b = BinForSize(num_bytes); + + LOG(INFO, "Bin for " << num_bytes + << " bytes has max bytes of " << b->bin_size + << ", Chunk State: "); + + for (ChunkHandle h : b->free_chunks) { + Chunk* c = ChunkFromHandle(h); + LOG(INFO, " " << c->DebugString(this, true)); + } + + // Next show the chunks that are in use, and also summarize their number by size. + LOG(INFO, "Overall chunks summary:"); + std::map in_use_by_size; + for (const auto& region : region_manager_.regions()) { + ChunkHandle h = region_manager_.get_handle(region.ptr()); + while (h != kInvalidChunkHandle) { + const Chunk* c = ChunkFromHandle(h); + if (c->in_use()) { + in_use_by_size[c->size]++; + } + LOG(INFO, (c->in_use() ? " Chunk" : " Free ") << " at " << c->ptr + << " of size " << c->size); + h = c->next; + } + } + + LOG(INFO, "Summary of in-use chunks by size: "); + size_t total_bytes = 0; + for (auto& it : in_use_by_size) { + LOG(INFO, " " << it.second << " chunks of size " << it.first + << ". Total " << it.first * it.second); + total_bytes += (it.first * it.second); + } + + LOG(INFO, "Sum Total of in-use chunks: " << total_bytes); + LOG(INFO, "Stats: \n" + << stats_.DebugString()); +} + +OrtStatus* ArenaImpl::ResetChunksUsingStream(const OrtSyncStreamImpl* stream_impl) { + std::lock_guard lock(lock_); + + auto impl_it = impl_to_stream_.find(stream_impl); + if (impl_it == impl_to_stream_.end()) { + return nullptr; // stream hasn't been used with this arena + } + + const OrtSyncStream* stream = impl_it->second; + + auto it = stream_to_chunks_.find(stream); + if (it != stream_to_chunks_.end()) { + const auto& chunk_handles = it->second; + for (size_t handle : chunk_handles) { + Chunk* c = ChunkFromHandle(handle); + assert(c->stream == stream); // something is out of sync if this is not the case + c->stream = nullptr; + } + + stream_to_chunks_.erase(it); + impl_to_stream_.erase(stream_impl); + } + + // It's also possible to find the chunks this way, but that requires iterating every single in-use allocation. + // We also repeat this for every single stream used in a session. + // OTOH there's a cost to create/update keep streams_to_chunks_. + // Using streams_to_chunks_ for now. It also simplifies debugging to have that info. If you're unsure about this + // choice feel free to perf test the two approaches. + // + // for (const auto& region : region_manager_.regions()) { + // ChunkHandle region_begin_chunk = region_manager_.get_handle(region.ptr()); + // ChunkHandle h = region_begin_chunk; + // while (h != kInvalidChunkHandle) { + // Chunk* c = ChunkFromHandle(h); + // if (c->stream == target_stream) { + // c->stream = nullptr; + // c->stream_sync_id = 0; + // } + // h = c->next; + // } + // } + + // coalesce + for (const auto& region : region_manager_.regions()) { + ChunkHandle region_begin_chunk = region_manager_.get_handle(region.ptr()); + ChunkHandle h = region_begin_chunk; + while (h != kInvalidChunkHandle) { + Chunk* c = ChunkFromHandle(h); + if (!c->in_use()) { + RemoveFreeChunkFromBin(h); + ChunkHandle h_next = c->next; + Chunk* c_next = h_next != kInvalidChunkHandle ? ChunkFromHandle(h_next) : nullptr; + + // merge until next chunk is different stream + while (c_next && !c_next->in_use() && c_next->stream == c->stream) { + Coalesce(h); + h_next = c->next; + c_next = h_next != kInvalidChunkHandle ? ChunkFromHandle(h_next) : nullptr; + } + + if (c->bin_num == kInvalidBinNum) { + InsertFreeChunkIntoBin(h); + } + } + h = c->next; + } + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/ep_arena.h new file mode 100644 index 0000000000000..641f3ce3f7b17 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_arena.h @@ -0,0 +1,629 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Portions Copyright (c) Microsoft Corporation + +#pragma once +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" +#include "ep_allocator.h" +#include "example_plugin_ep_utils.h" + +#if defined(PLATFORM_WINDOWS) +#include +#endif + +enum ArenaExtendStrategy { + kDefault = -1, + kNextPowerOfTwo = 0, + kSameAsRequested = 1, +}; + +// copied from onnxruntime::OrtArenaCfg so the values and config key names match +struct ArenaConfig { + static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; + static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024; + static const int DEFAULT_MAX_DEAD_BYTES_PER_CHUNK = 128 * 1024 * 1024; + static const int DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES = 2 * 1024 * 1024; + static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB + static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); + + ArenaConfig(size_t max_mem = std::numeric_limits::max(), + ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, + int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, + int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, + int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, + int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES) + : max_mem(max_mem), + arena_extend_strategy(arena_extend_strategy), + initial_chunk_size_bytes(initial_chunk_size_bytes), + max_dead_bytes_per_chunk(max_dead_bytes_per_chunk), + initial_growth_chunk_size_bytes(initial_growth_chunk_size_bytes), + max_power_of_two_extend_bytes(max_power_of_two_extend_bytes) { + if (arena_extend_strategy == ArenaExtendStrategy::kDefault) { + arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo; + } + } + + size_t max_mem; + ArenaExtendStrategy arena_extend_strategy; + int initial_chunk_size_bytes; + int max_dead_bytes_per_chunk; + int initial_growth_chunk_size_bytes; + int64_t max_power_of_two_extend_bytes; + + bool IsValid() { + return initial_chunk_size_bytes > 0 && + max_dead_bytes_per_chunk > 0 && + initial_growth_chunk_size_bytes > 0 && + max_power_of_two_extend_bytes > 0; + } + + // config key names that we parse in FromKeyValuePairs + struct ConfigKeyNames { + static constexpr const char* ArenaExtendStrategy = "arena.extend_strategy"; + static constexpr const char* InitialChunkSizeBytes = "arena.initial_chunk_size_bytes"; + static constexpr const char* MaxDeadBytesPerChunk = "arena.max_dead_bytes_per_chunk"; + static constexpr const char* InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes"; + static constexpr const char* MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes"; + static constexpr const char* MaxMem = "arena.max_mem"; + }; + + static ArenaConfig FromKeyValuePairs(const OrtApi& api, const OrtKeyValuePairs& kvps) { + ArenaConfig config{}; + const char* value = nullptr; + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::ArenaExtendStrategy); value) { + config.arena_extend_strategy = std::string(value) == "1" ? kSameAsRequested : kNextPowerOfTwo; + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::InitialChunkSizeBytes); value) { + config.initial_chunk_size_bytes = std::stoi(std::string(value)); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::MaxDeadBytesPerChunk); value) { + config.max_dead_bytes_per_chunk = std::stoi(std::string(value)); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::InitialGrowthChunkSizeBytes); value) { + config.initial_growth_chunk_size_bytes = std::stoi(std::string(value)); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::MaxPowerOfTwoExtendBytes); value) { + config.max_power_of_two_extend_bytes = std::stoll(value); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::MaxMem); value) { + config.max_mem = static_cast(std::stoull(std::string(value))); + } + + return config; + } +}; + +// A memory allocator that implements a 'best-fit with coalescing' algorithm. +// This is essentially a very simple version of Doug Lea's malloc (dlmalloc). +// +// The goal of this allocator is to support defragmentation via coalescing. +// One assumption we make is that the process using this allocator owns pretty much all of the memory, and that nearly +// all requests to allocate memory go through this interface. +class ArenaImpl { + public: + static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; + static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024; + static const int DEFAULT_MAX_DEAD_BYTES_PER_CHUNK = 128 * 1024 * 1024; + static const int DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES = 2 * 1024 * 1024; + static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB + static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); + + ArenaImpl(AllocatorUniquePtr allocator, const ArenaConfig& config, const OrtApi& api, + const OrtLogger& logger); + + ~ArenaImpl(); + + void* Alloc(size_t size); + void* AllocOnStream(size_t size, OrtSyncStream* stream); + void Free(void* p); + + // allocate memory directly. this is used for initializers so they don't affect the arena growth patterns + void* Reserve(size_t size); + + OrtStatus* GetStats(OrtKeyValuePairs** stats); + + size_t RequestedSize(const void* ptr); + size_t AllocatedSize(const void* ptr); + + // Un-assign chunks that are currently assigned to the stream. + // + // This should be called from OrtSyncStreamImpl::OnSessionRunEnd. + // A stream is used in one session at a time. When called from OnSessionRunEnd we know that the stream is done and + // will not be performing any more operations on the data. + // + // We don't have a better way to know when it's safe to re-use a chunk in another stream given the actual memory + // usage is asynchronous on the GPU side, and the code assigning memory is running on CPU prior to that. + OrtStatus* ResetChunksUsingStream(const OrtSyncStreamImpl* stream_impl); + + private: + void* AllocateRawInternal(size_t num_bytes, OrtSyncStream* stream, bool dump_log_on_failure); + void DeallocateRawInternal(void* ptr); + + // A ChunkHandle is an index into the chunks_ vector in BFCAllocator + // kInvalidChunkHandle means an invalid chunk + using ChunkHandle = size_t; + static const size_t kInvalidChunkHandle = static_cast(-1); + + using BinNum = int; + static const int kInvalidBinNum = -1; + static const int kNumBins = 21; + + // Chunks point to memory. Their prev/next pointers form a + // doubly-linked list of addresses sorted by base address that + // must be contiguous. Chunks contain information about whether + // they are in use or whether they are free, and contain a pointer + // to the bin they are in. + struct Chunk { + size_t size = 0; // Full size of buffer. + + // We sometimes give chunks that are larger than needed to reduce + // fragmentation. requested_size keeps track of what the client + // actually wanted so we can understand whether our splitting + // strategy is efficient. + size_t requested_size = 0; + + // allocation_id is set to -1 when the chunk is not in use. It is assigned a + // value greater than zero before the chunk is returned from + // AllocateRaw, and this value is unique among values assigned by + // the parent allocator. + int64_t allocation_id = -1; + void* ptr = nullptr; // pointer to granted subbuffer. + + // If not kInvalidChunkHandle, the memory referred to by 'prev' is directly + // preceding the memory used by this chunk. E.g., It should start + // at 'ptr - prev->size' + ChunkHandle prev = kInvalidChunkHandle; + + // If not kInvalidChunkHandle, the memory referred to by 'next' is directly + // following the memory used by this chunk. E.g., It should be at + // 'ptr + next->size' + ChunkHandle next = kInvalidChunkHandle; + + // What bin are we in? + BinNum bin_num = kInvalidBinNum; + + OrtSyncStream* stream = nullptr; + // Current sync id of `stream` when it was assigned the Chunk. + // If the chunk is assigned to a stream and is free, and another Stream wants to use it, that Stream must have + // synchronized with `stream` at a sync id > to stream_sync_id. + // stream_sync_id is set when the chunk is first assigned to `stream`. + // The sync id is incremented at the start of sync, so any chunk with a previous sync id is safe to re-assign. + uint64_t stream_sync_id = 0; + + bool in_use() const { return allocation_id != -1; } + + std::string DebugString(ArenaImpl* a, bool recurse) { + std::ostringstream ss; + ss << " Size: " << size << " | Requested Size: " << requested_size << " | in_use: " << in_use(); + if (recurse && prev != ArenaImpl::kInvalidChunkHandle) { + Chunk* p = a->ChunkFromHandle(prev); + ss << ", prev: " << p->DebugString(a, false); + } + + if (recurse && next != ArenaImpl::kInvalidChunkHandle) { + Chunk* n = a->ChunkFromHandle(next); + ss << ", next: " << n->DebugString(a, false); + } + return ss.str(); + } + }; + + // A Bin is a collection of similar-sized free chunks. + struct Bin { + // All chunks in this bin have >= bin_size memory. + size_t bin_size = 0; + + struct ChunkComparator { + explicit ChunkComparator(ArenaImpl* allocator) + : allocator_(allocator) {} + + // Sort first by size and then use pointer address as a tie breaker. + bool operator()(const ChunkHandle ha, + const ChunkHandle hb) const { + const Chunk* a = allocator_->ChunkFromHandle(ha); + const Chunk* b = allocator_->ChunkFromHandle(hb); + if (a->size != b->size) { + return a->size < b->size; + } + return a->ptr < b->ptr; + } + + private: + ArenaImpl* allocator_; // The parent allocator + }; + + typedef std::set FreeChunkSet; + // List of free chunks within the bin, sorted by chunk size. + // Chunk * not owned. + FreeChunkSet free_chunks; + Bin(ArenaImpl* allocator, size_t bs) + : bin_size(bs), free_chunks(ChunkComparator(allocator)) {} + }; + + static const size_t kMinAllocationBits = 8; + static const size_t kMinAllocationSize = 1 << kMinAllocationBits; + + // AllocationRegion maps pointers to ChunkHandles for a single + // contiguous memory region. + // + // This class is thread-compatible. + class AllocationRegion { + public: + AllocationRegion(void* ptr, size_t memory_size, int64_t id) + : ptr_(ptr), + memory_size_(memory_size), + end_ptr_(static_cast(static_cast(ptr_) + memory_size_)), + id_(id) { + EP_ENFORCE(0 == memory_size % kMinAllocationSize, __FUNCTION__); + + const size_t n_handles = (memory_size + kMinAllocationSize - 1) / kMinAllocationSize; + handles_ = std::make_unique(n_handles); + for (size_t i = 0; i < n_handles; i++) { + handles_[i] = kInvalidChunkHandle; + } + } + + AllocationRegion(AllocationRegion&& other) noexcept { Swap(other); } + AllocationRegion() = default; + ~AllocationRegion() = default; + + AllocationRegion& operator=(AllocationRegion&& other) noexcept { + Swap(other); + return *this; + } + + void* ptr() const { return ptr_; } + void* end_ptr() const { return end_ptr_; } + size_t memory_size() const { return memory_size_; } + int64_t id() const { return id_; } + + ChunkHandle get_handle(const void* p) const { + return handles_[IndexFor(p)]; + } + + void set_handle(const void* p, ChunkHandle h) { + handles_[IndexFor(p)] = h; + } + + void erase(const void* p) { + set_handle(p, kInvalidChunkHandle); + } + + private: + void Swap(AllocationRegion& other) { + std::swap(ptr_, other.ptr_); + std::swap(memory_size_, other.memory_size_); + std::swap(end_ptr_, other.end_ptr_); + std::swap(id_, other.id_); + std::swap(handles_, other.handles_); + } + + int IndexFor(const void* p) const { + std::uintptr_t p_int = reinterpret_cast(p); + std::uintptr_t base_int = reinterpret_cast(ptr_); + EP_ENFORCE(p_int >= base_int, "AllocationRegion::IndexFor"); + EP_ENFORCE(p_int < base_int + memory_size_, "AllocationRegion::IndexFor"); + return static_cast(((p_int - base_int) >> kMinAllocationBits)); + } + + // metadata about the allocation region. + void* ptr_ = nullptr; + size_t memory_size_ = 0; + void* end_ptr_ = nullptr; + // A unique identifier for this allocation region + // (May be used by the client to track which allocation region was allocated first, second, and so on) + int64_t id_ = -1; + + // Array of size "memory_size / kMinAllocationSize". It is + // indexed by (p-base) / kMinAllocationSize, contains ChunkHandle + // for the memory allocation represented by "p" + std::unique_ptr handles_; + + AllocationRegion& operator=(const AllocationRegion&) = delete; + }; + + // RegionManager aggregates one or more "AllocationRegions" and provides + // a layer of indirection from pointers to the underlying ChunkHandle, + // allowing allocation across multiple discontiguous memory regions. + // + // This class is thread-compatible. + class RegionManager { + public: + RegionManager() = default; + ~RegionManager() = default; + + void AddAllocationRegion(void* ptr, size_t memory_size, int64_t id) { + // Insert sorted by end_ptr + auto entry = std::upper_bound(regions_.begin(), regions_.end(), ptr, &Comparator); + regions_.insert(entry, AllocationRegion(ptr, memory_size, id)); + } + + void RemoveAllocationRegion(void* ptr) { + auto entry = std::upper_bound(regions_.begin(), regions_.end(), ptr, &Comparator); + EP_ENFORCE(entry != regions_.end(), "RegionManager::RemoveAllocationRegion Could not find Region for: " << ptr); + regions_.erase(entry); + } + + ChunkHandle get_handle(const void* p) const { + return RegionFor(p)->get_handle(p); + } + + void set_handle(const void* p, ChunkHandle h) { + return MutableRegionFor(p)->set_handle(p, h); + } + void erase(const void* p) { return MutableRegionFor(p)->erase(p); } + + const std::vector& regions() const { return regions_; } + + private: + RegionManager(const RegionManager&) = delete; + RegionManager& operator=(const RegionManager&) = delete; + RegionManager(RegionManager&&) = delete; + RegionManager& operator=(RegionManager&&) = delete; + + static bool Comparator(const void* ptr, const AllocationRegion& other) { + return ptr < other.end_ptr(); + } + + AllocationRegion* MutableRegionFor(const void* p) { + return const_cast(RegionFor(p)); + } + + const AllocationRegion* RegionFor(const void* p) const { + auto entry = std::upper_bound(regions_.begin(), regions_.end(), p, &Comparator); + + if (entry != regions_.end()) { + return &(*entry); + } + + EP_ENFORCE(entry != regions_.end(), "RegionManager::RegionFor Could not find Region for: " << p); + return nullptr; + } + + private: + std::vector regions_; + }; + + // Returns 'bytes' rounded up to the next highest kMinAllocationSize. + size_t RoundedBytes(size_t bytes); + + // Try to add a new memory region that can satisfy an allocation of + // 'rounded_bytes' bytes. + OrtStatus* Extend(size_t rounded_bytes); + + // Returns an underlying allocated chunk of size + // 'rounded_bytes'. + ArenaImpl::Chunk* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, OrtSyncStream* stream); + + // Splits the chunk specified by 'h' into two chunks, one at least + // of size 'num_bytes'. + void SplitChunk(ChunkHandle h, size_t num_bytes); + + // Merges the two chunk handles. Requires that the chunks are + // contiguous in their allocation. + void Merge(ChunkHandle h, ChunkHandle h2); + + // Frees the memory represented by 'h', coalescing the chunk if + // possible. + void FreeAndMaybeCoalesce(ChunkHandle h); + + ArenaImpl::ChunkHandle Coalesce(ChunkHandle h); + + // Adds the chunk 'h' to the proper free bin. + void InsertFreeChunkIntoBin(ChunkHandle h); + + // Removes the free chunk pointed to by 'c' from the set free_chunks. + void RemoveFreeChunkIterFromBin(Bin::FreeChunkSet* free_chunks, + const Bin::FreeChunkSet::iterator& c); + + // Removes a free chunk from the bin. + void RemoveFreeChunkFromBin(ChunkHandle h); + + ArenaImpl::Chunk* SplitFreeChunkFromBin(ArenaImpl::Bin::FreeChunkSet* free_chunks, + const ArenaImpl::Bin::FreeChunkSet::iterator& citer, + size_t rounded_bytes, + size_t num_bytes); + + // Removes the chunk metadata represented by 'h'. + void DeleteChunk(ChunkHandle h); + + void DumpMemoryLog(size_t num_bytes); + + ChunkHandle AllocateChunk(); + void DeallocateChunk(ChunkHandle h); + + Chunk* ChunkFromHandle(ChunkHandle h); + + // Information about a Bin that is useful for debugging. + struct BinDebugInfo { + size_t total_bytes_in_use = 0; + size_t total_bytes_in_bin = 0; + size_t total_requested_bytes_in_use = 0; + size_t total_chunks_in_use = 0; + size_t total_chunks_in_bin = 0; + }; + + // Computes and returns a BinDebugInfo for each Bin. + std::array GetBinDebugInfo(); + + int Log2FloorNonZeroSlow(uint64_t n) { + int r = 0; + while (n > 0) { + r++; + n >>= 1; + } + return r - 1; + } + + // Returns floor(log2(n)). + int Log2FloorNonZero(uint64_t n) { +#if defined(__GNUC__) + return 63 ^ __builtin_clzll(n); +#elif defined(PLATFORM_WINDOWS) + unsigned long index; +#if defined(_WIN64) + _BitScanReverse64(&index, n); +#else + auto high = static_cast(n >> 32); + if (_BitScanReverse(&index, high) > 0) { + index += 32; + } else { + auto low = static_cast((n << 32) >> 32); + _BitScanReverse(&index, low); + } +#endif + return index; +#else + return Log2FloorNonZeroSlow(n); +#endif + } + + // Map from bin size to Bin + Bin* BinFromIndex(BinNum index) { + return reinterpret_cast(&(bins_space_[index * sizeof(Bin)])); + } + + size_t BinNumToSize(BinNum index) { + return static_cast(256) << index; + } + + BinNum BinNumForSize(size_t bytes) { + uint64_t v = std::max(bytes, 256) >> kMinAllocationBits; + int b = std::min(kNumBins - 1, Log2FloorNonZero(v)); + return b; + } + + Bin* BinForSize(size_t bytes) { + return BinFromIndex(BinNumForSize(bytes)); + } + + alignas(Bin) char bins_space_[sizeof(Bin) * kNumBins]; + + mutable std::mutex lock_; + + AllocatorUniquePtr device_allocator_; + const std::string allocator_name_; + const ArenaConfig config_; + + RegionManager region_manager_; + size_t curr_region_allocation_bytes_; + + // Counter containing the next unique identifier to assign to a newly-created chunk. + int64_t next_allocation_id_; + + std::vector chunks_; + ChunkHandle free_chunks_list_; // Pointer to head of linked list of free Chunks + std::unordered_map reserved_chunks_; + + // chunks being used by a stream + std::unordered_map> stream_to_chunks_; + + // map to connect the OrtSyncStreamImpl the EP library creates to the OrtSyncStream that ORT uses. + // we don't know that it's safe to re-use a chunk until the stream is done with, which is via the call to + // OrtSyncStreamImpl::OnSessionRunEnd. the allocations see OrtSyncStream, so we need to connect things up to + // un-assign chunks when StreamImpl::OnSessionRunEnd is called. + std::unordered_map impl_to_stream_; + + AllocatorStats stats_; + + const OrtApi& api_; + const OrtEpApi& ep_api_; + const OrtLogger& logger_; + + ArenaImpl(const ArenaImpl&) = delete; + ArenaImpl& operator=(const ArenaImpl&) = delete; + ArenaImpl(ArenaImpl&&) = delete; + ArenaImpl& operator=(ArenaImpl&&) = delete; +}; + +struct ArenaAllocator : OrtAllocator { + static OrtStatus* CreateOrtArenaAllocator(AllocatorUniquePtr allocator, + const OrtKeyValuePairs* options, + const OrtApi& api, + const OrtLogger& logger, + std::unique_ptr& arena_allocator) { + ArenaConfig config = options ? ArenaConfig::FromKeyValuePairs(api, *options) : ArenaConfig{}; + const OrtMemoryInfo* mem_info = allocator->Info(allocator.get()); + auto impl = std::make_unique(std::move(allocator), config, api, logger); + + arena_allocator = std::make_unique(std::move(impl), *mem_info); + + return nullptr; + } + + ArenaAllocator(std::unique_ptr implementation, const OrtMemoryInfo& memory_info) + : impl_{std::move(implementation)}, + memory_info_{memory_info} { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Reserve = ReserveImpl; + Free = FreeImpl; + Info = InfoImpl; + GetStats = GetStatsImpl; + AllocOnStream = AllocOnStreamImpl; + } + + // remove the OrtSyncStream* from any chunks that were using the stream + OrtStatus* ResetChunksUsingStream(const OrtSyncStreamImpl* stream_impl) { + impl_->ResetChunksUsingStream(stream_impl); + return nullptr; + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& arena = *static_cast(this_); + return arena.impl_->Alloc(size); + } + + static void* ORT_API_CALL AllocOnStreamImpl(struct OrtAllocator* this_, size_t size, OrtSyncStream* stream) { + auto& arena = *static_cast(this_); + return arena.impl_->AllocOnStream(size, stream); + } + + static void* ORT_API_CALL ReserveImpl(struct OrtAllocator* this_, size_t size) { + auto& arena = *static_cast(this_); + return arena.impl_->Reserve(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + auto& arena = *static_cast(this_); + arena.impl_->Free(p); + } + + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const auto& arena = *static_cast(this_); + return &arena.memory_info_; + } + + static OrtStatus* ORT_API_CALL GetStatsImpl(const struct OrtAllocator* this_, OrtKeyValuePairs** out) noexcept { + const auto& arena = *static_cast(this_); + return arena.impl_->GetStats(out); + }; + + private: + std::unique_ptr impl_; + const OrtMemoryInfo& memory_info_; +}; diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 1cffb72c84879..4da7d722a5e0b 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -7,6 +7,7 @@ #include "ep.h" #include "ep_allocator.h" +#include "ep_arena.h" #include "ep_data_transfer.h" #include "ep_stream_support.h" @@ -38,6 +39,8 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL /*vendor*/ 0xBE57, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, /*alignment*/ 0, + // it is invalid to use OrtArenaAllocator as that is reserved for the + // internal ORT Arena implementation OrtAllocatorType::OrtDeviceAllocator, &mem_info); assert(status == nullptr); // should never fail. @@ -208,7 +211,7 @@ void ORT_API_CALL ExampleEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, Or /*static*/ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* /*allocator_options*/, + const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator) noexcept { auto& factory = *static_cast(this_ptr); *allocator = nullptr; @@ -226,14 +229,57 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make // ReleaseAllocatorImpl a no-op. // - auto cpu_allocator = std::make_unique(memory_info, factory); - *allocator = cpu_allocator.release(); + // NOTE: EP should implement its own arena logic. ep_arena.cc/h is provided as a reference and we use it here for + // device memory. `allocator_options` can be used for arena configuration and there is a helper in ep_arena.h + // to convert from OrtKeyValuePairs to the same arena config settings that ORT uses. + // You are of course free to have completely different settings. + + // the read-only allocator is used for initializers. we don't need an arena for that. + if (is_readonly_allocator) { + auto read_only_allocator = std::make_unique(memory_info, factory); + *allocator = read_only_allocator.release(); + return nullptr; + } + + // create/use the shared arena based allocator + std::lock_guard lock{factory.mutex_}; + + if (!factory.arena_allocator_) { + std::unique_ptr ep_allocator = std::make_unique(memory_info, factory); + + // initial shared allocator in environment does not have allocator options. + // if the user calls CreateSharedAllocator they can provide options to configure the arena differently. + factory.arena_allocator_using_default_settings_ = allocator_options == nullptr; + RETURN_IF_ERROR(ArenaAllocator::CreateOrtArenaAllocator(std::move(ep_allocator), allocator_options, + factory.ort_api, + factory.default_logger_, factory.arena_allocator_)); + + } else { + if (factory.arena_allocator_using_default_settings_ && allocator_options) { + // potential change in arena settings. up to EP author to determine how to handle this. + // we should not get here if replacing the shared allocator in the environment, as we free the existing one + // before replacing it. i.e. ReleaseAllocatorImpl should have been called, and arena_allocator_ should be null. + } + } + + ++factory.num_arena_users_; + *allocator = factory.arena_allocator_.get(); + return nullptr; } /*static*/ -void ORT_API_CALL ExampleEpFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { - delete static_cast(allocator); +void ORT_API_CALL ExampleEpFactory::ReleaseAllocatorImpl(OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept { + auto& factory = *static_cast(this_ptr); + std::lock_guard lock{factory.mutex_}; + + if (allocator == factory.arena_allocator_.get()) { + if (--factory.num_arena_users_ == 0) { + factory.arena_allocator_ = nullptr; + } + } else { + delete static_cast(allocator); + } } /*static*/ @@ -255,7 +301,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac const OrtMemoryDevice* memory_device, const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept { - auto& factory = *static_cast(this_ptr); + auto& factory = *static_cast(this_ptr); *stream = nullptr; // we only need stream synchronization on the device stream diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 60c9f63b78b8c..088deda1fe9d2 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -3,6 +3,9 @@ #pragma once +#include + +#include "ep_arena.h" #include "ep_data_transfer.h" #include "example_plugin_ep_utils.h" @@ -17,6 +20,11 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return data_transfer_impl_.get(); } + // Get the shared arena allocator if created. + ArenaAllocator* GetArenaAllocator() const { + return arena_allocator_.get(); + } + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -70,5 +78,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { MemoryInfoUniquePtr default_memory_info_; MemoryInfoUniquePtr readonly_memory_info_; // used for initializers + bool arena_allocator_using_default_settings_{true}; + std::unique_ptr arena_allocator_; // shared device allocator that uses an arena + uint32_t num_arena_users_{0}; + std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ + std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc index a948fe1bfce1e..1f6c16a8cb358 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.cc +++ b/onnxruntime/test/autoep/library/ep_stream_support.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "ep_stream_support.h" - +#include "ep_factory.h" // // StreamImpl implementation // @@ -27,7 +27,13 @@ OrtStatus* ORT_API_CALL StreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* /*this_ptr } /*static*/ -OrtStatus* ORT_API_CALL StreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* /*this_ptr*/) noexcept { +OrtStatus* ORT_API_CALL StreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + auto* arena = impl.factory_->GetArenaAllocator(); + if (arena) { + arena->ResetChunksUsingStream(this_ptr); + } + return nullptr; } diff --git a/onnxruntime/test/autoep/library/ep_stream_support.h b/onnxruntime/test/autoep/library/ep_stream_support.h index 10c4804722f8b..a825e5afd2250 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.h +++ b/onnxruntime/test/autoep/library/ep_stream_support.h @@ -4,15 +4,18 @@ #pragma once #include "onnxruntime_c_api.h" +#include "ep_factory.h" #include "example_plugin_ep_utils.h" +class ExampleEpFactory; + // // Class implementing Stream support for synchronization. // class StreamImpl : public OrtSyncStreamImpl, public ApiPtrs { public: - StreamImpl(ApiPtrs apis, const OrtEp* ep, const OrtKeyValuePairs* /*stream_options*/) - : ApiPtrs(apis), ep_{ep} { + StreamImpl(ExampleEpFactory& factory, const OrtEp* ep, const OrtKeyValuePairs* /*stream_options*/) + : ApiPtrs(factory), ep_{ep}, factory_{&factory} { ort_version_supported = ORT_API_VERSION; CreateNotification = CreateNotificationImpl; GetHandle = GetHandleImpl; @@ -34,6 +37,7 @@ class StreamImpl : public OrtSyncStreamImpl, public ApiPtrs { // EP instance if the stream is being created internally for inferencing. // nullptr when the stream is created outside of an inference session for data copies. const OrtEp* ep_; + ExampleEpFactory* factory_{nullptr}; }; // diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index e107a94410dba..99ebee9ff64de 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -25,12 +25,53 @@ } \ } while (0) +// see ORT_ENFORCE for implementations that also capture a stack trace and work in builds with exceptions disabled +// NOTE: In this simplistic implementation you must provide an argument, even it if's an empty string +#define EP_ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + std::ostringstream oss; \ + oss << "EP_ENFORCE failed: " << #condition << " "; \ + oss << __VA_ARGS__; \ + throw std::runtime_error(oss.str()); \ + } \ + } while (false) + +#ifdef _WIN32 +#define EP_WSTR(x) L##x +#define EP_FILE_INTERNAL(x) EP_WSTR(x) +#define EP_FILE EP_FILE_INTERNAL(__FILE__) +#else +#define EP_FILE __FILE__ +#endif + +#define LOG(level, ...) \ + do { \ + std::ostringstream ss; \ + ss << __VA_ARGS__; \ + api_.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_##level, ss.str().c_str(), EP_FILE, __LINE__, __FUNCTION__); \ + } while (false) + +#define RETURN_ERROR(code, ...) \ + do { \ + std::ostringstream ss; \ + ss << __VA_ARGS__; \ + return api_.CreateStatus(code, ss.str().c_str()); \ + } while (false) + +#define THROW(...) \ + std::ostringstream ss; \ + ss << __VA_ARGS__; \ + throw std::runtime_error(ss.str()) + struct ApiPtrs { const OrtApi& ort_api; const OrtEpApi& ep_api; const OrtModelEditorApi& model_editor_api; }; +using AllocatorUniquePtr = std::unique_ptr>; + // Helper to release Ort one or more objects obtained from the public C API at the end of their scope. template struct DeferOrtRelease { diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 84b6e284ccb8e..77d2bb24b7d35 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -30,8 +30,9 @@ struct DummyAllocator : OrtAllocator { Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; - Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena - GetStats = nullptr; // this can be set to nullptr if not implemented + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = nullptr; // this can be set to nullptr if not implemented + AllocOnStream = nullptr; // optional } static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { @@ -75,9 +76,11 @@ TEST(SharedAllocators, AddArenaToSharedAllocator) { auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); - ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), - OrtDeviceMemoryType_DEFAULT, OrtArenaAllocator, &allocator_options, - &allocator)); + ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT, + // allocator is internally added by EP. + // OrtArenaAllocator can only be used for the internal BFCArena + OrtDeviceAllocator, + &allocator_options, &allocator)); // first allocation should init the arena to the initial chunk size void* mem = allocator->Alloc(allocator, 16); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 670447f2804dc..9ded9d2bfeac0 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -339,81 +339,82 @@ struct StreamMock : public Stream { #ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { - StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, false); + StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30); CheckStats(&a, 0, 0, 0, 0); OrtDevice tmp; StreamMock stream1(tmp), stream2(tmp); - auto* stream1_chunk_a = a.AllocOnStream(4096, &stream1, nullptr); - auto* stream2_chunk_a = a.AllocOnStream(4096, &stream2, nullptr); - a.Free(stream1_chunk_a); - auto* stream2_chunk_b = a.AllocOnStream(4096, &stream2, nullptr); + auto* stream1_chunk_a = a.AllocOnStream(4096, &stream1); // 4K chunk on stream 1 + auto* stream2_chunk_a = a.AllocOnStream(4096, &stream2); // 4K chunk on stream 2 + a.Free(stream1_chunk_a); // free but assigned to stream1 + // stream2 can't reuse stream1's chunk + auto* stream2_chunk_b = a.AllocOnStream(4096, &stream2); // 4K chunk on stream 2 EXPECT_NE(stream2_chunk_b, stream1_chunk_a); - a.Free(stream2_chunk_a); - auto* stream1_chunk_c = a.AllocOnStream(4096, &stream1, nullptr); - // it should pick the first chunk + + a.Free(stream2_chunk_a); // free but assigned to stream2 + + // it should pick the first chunk. + auto* stream1_chunk_c = a.AllocOnStream(4096, &stream1); EXPECT_EQ(stream1_chunk_c, stream1_chunk_a); - auto* stream1_chunk_d = a.AllocOnStream(4096, &stream1, nullptr); - // it shouldn't pick stream2_chunk_a's buffer + // it shouldn't pick stream2_chunk_a due to stream mismatch + auto* stream1_chunk_d = a.AllocOnStream(4096, &stream1); EXPECT_NE(stream1_chunk_d, stream2_chunk_a); - a.Free(stream2_chunk_b); + + a.Free(stream2_chunk_b); // still assigned to stream 2. should coalesce with stream1_chunk_a to create 8K buffer + // test clean stream2 - a.ReleaseStreamBuffers(&stream2); - auto stream1_chunk_e = a.AllocOnStream(8192, &stream1, nullptr); - // now it should pick the stream2_chunk_a's buffer - EXPECT_EQ(stream1_chunk_e, stream2_chunk_a); + a.ReleaseStreamBuffers(&stream2); // all stream 2 buffers are now available + + // now it should pick stream2_chunk_a as it is no longer assigned to stream 2 + auto stream1_chunk_e = a.AllocOnStream(8192, &stream1); + EXPECT_EQ(stream1_chunk_e, stream2_chunk_a); // stream1_chunk_e and stream2_chunk_a are assigned to stream1 + a.Free(stream1_chunk_c); a.Free(stream1_chunk_d); - // add stream2 to stream 1 depenency + + // stream 2 wait on stream 1 auto stream1_notification_a = stream1.CreateNotification(1); - stream1_notification_a->ActivateAndUpdate(); - stream2.UpdateStreamClock(stream1_notification_a->GetStreamSyncTable()); - auto* stream2_chunk_c = a.AllocOnStream(4096, &stream2, nullptr); - // it should pick the first chunk - EXPECT_EQ(stream2_chunk_c, stream1_chunk_c); - auto* stream2_chunk_d = a.AllocOnStream(4096, &stream2, nullptr); - // it should pick the third slot - EXPECT_EQ(stream2_chunk_d, stream1_chunk_d); - // continue allocate on stream1 - auto* stream1_chunk_f = a.AllocOnStream(4096, &stream1, nullptr); + stream1_notification_a->ActivateAndUpdate(); // stream 1 sync id 0 -> 1 + stream2.UpdateWithAwaitedNotification(*stream1_notification_a); // stream 2 now has sync id info of stream1:1 + + // stream 2 can now take stream 1 buffers with sync id of 0 + auto* stream2_chunk_c = a.AllocOnStream(4096, &stream2); + EXPECT_EQ(stream2_chunk_c, stream1_chunk_c); // stream2 took a buffer from stream1 with sync id 0 + + // stream 2 can take the remaining free buffer from stream 1 with sync id of 0 + auto* stream2_chunk_d = a.AllocOnStream(4096, &stream2); + EXPECT_EQ(stream2_chunk_d, stream1_chunk_d); // stream2 took the other buffer from stream 1 + + // new buffer required + auto* stream1_chunk_f = a.AllocOnStream(4096, &stream1); // new buffer on stream 1. sync id = 1 a.Free(stream1_chunk_f); - auto* stream2_chunk_e = a.AllocOnStream(4096, &stream2, nullptr); + + // new buffer required + auto* stream2_chunk_e = a.AllocOnStream(4096, &stream2); // new buffer on stream 2 EXPECT_NE(stream2_chunk_e, stream1_chunk_f); + + // free 8K buffer on stream 1 a.Free(stream1_chunk_e); - // test clean stream1 - a.ReleaseStreamBuffers(&stream1); - auto* stream2_chunk_f = a.AllocOnStream(8192, &stream2, nullptr); - // now it should pick stream1_chunk_e + + // can use 8K stream1_chunk_e as it has sync id = 0 and stream 2 has sync id of 1 for stream 1 + auto* stream2_chunk_f = a.AllocOnStream(8192, &stream2); EXPECT_EQ(stream2_chunk_f, stream1_chunk_e); + // remove assignment to stream 1 for free buffers. stream1_chunk_f will become available to stream 2 + a.ReleaseStreamBuffers(&stream1); // stream1 buffers are new available + + auto* stream2_chunk_g = a.AllocOnStream(4096, &stream2); + EXPECT_EQ(stream2_chunk_g, stream1_chunk_f); + // cleanup a.Free(stream2_chunk_d); a.Free(stream2_chunk_e); a.Free(stream2_chunk_f); } - -TEST(StreamAwareArenaTest, TestSecureTheChunk) { - StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, true); - OrtDevice tmp; - StreamMock stream1(tmp), stream2(tmp); - - void* p1 = a.AllocOnStream(BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, &stream1, nullptr); - a.Free(p1); - - bool waitFunctionInvoked = false; - void* p2 = a.AllocOnStream(BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, &stream2, - [&waitFunctionInvoked](Stream*, synchronize::Notification&) { waitFunctionInvoked = true; }); - - std::unordered_map syncTable; - stream2.CloneCurrentStreamSyncTable(syncTable); - EXPECT_EQ(syncTable.size(), 1u) << "stream2 has been updated with stream1's nofitication on the clock"; - EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked"; - a.Free(p2); -} #endif TEST(BFCArenaTest, TestExtendStrategy) { diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 9807fcca06ed4..0fe747cdd84e5 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -118,6 +118,9 @@ struct TestAllocator : public OrtAllocator { Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { throw std::runtime_error("This should not be used"); }; + + GetStats = nullptr; + AllocOnStream = nullptr; } // initializers that are used directly by the model. as there's no copy they must remain valid. From 7e59947dcf8f489f02e1d9dfb7ca0f3de0d677d0 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Fri, 25 Jul 2025 01:02:31 -0700 Subject: [PATCH 09/33] [WebGPU EP] allow concat operator to handle large number of inputs (#25390) ### Description Adjusts concat operator to batch inputs based on maxStorageBuffersPerShaderStage to allow unlimited number of inputs. ### Motivation and Context Fixes patchtst model for transformers.js {31C75CD1-7A7D-48E3-A090-FB153925D165} --- .../core/providers/webgpu/tensor/concat.cc | 114 ++++++++++-------- .../core/providers/webgpu/tensor/concat.h | 5 +- .../providers/cpu/tensor/concat_op_test.cc | 105 ++++++++++++++++ 3 files changed, 171 insertions(+), 53 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 5cfd6c78f8929..283a9e5fe8262 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,19 +38,19 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) -void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { - os << "fn calculate_input_index(index: u32) -> u32 {\n" - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" - << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" - << " return i;\n" +void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(global_idx: u32) -> u32 {\n" + << " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n" + << " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n" + << " return i - 1;\n" << " }\n" << " }\n" - << " return " << input_count << ";\n" + << " return " << input_count - 1 << ";\n" << "}\n"; } -void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { - os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) { + os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n"; for (size_t i = 0; i < inputs.size(); ++i) { if (i == 0) { os << " if (input_index == 0u) {\n"; @@ -59,7 +59,12 @@ void AppendAssignOutputDataFunction(std::ostream& os, gsl::spanGetByIndices("indices")) << ";\n"; + std::string offset = GetElementAt("uniforms.offsets", "input_index", input_count); + std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count); + std::string output_indices_axis = "output_indices" + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis) + "]" : ""); + os << " var output_indices = " << inputs[i]->OffsetToIndices("global_idx - " + offset) << ";\n" + << " " << output_indices_axis << " += " << concat_axis_offset << ";\n" + << " " << output.SetByIndices("output_indices", inputs[i]->GetByOffset("global_idx - " + offset)) << "\n"; } os << " }\n" "}\n"; @@ -74,27 +79,21 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - // add implementation of fn calculate_input_index - AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); - // add implementation of fn assign_output_data - AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); - const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); + AppendCalculateInputIndexFunction(shader.AdditionalImplementation(), input_count); + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output, axis_, input_count); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" - << " let input_index = calculate_input_index(indices_axis);\n" - << " if (input_index != 0u) {\n" - << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" - << " }\n" - " assign_output_data(global_idx, input_index, indices);\n"; + << "let input_index = calculate_input_index(global_idx);\n" + << "assign_output_data(global_idx, input_index);\n"; + return Status::OK(); } Status Concat::ComputeInternal(ComputeContext& context) const { - int input_count = context.InputCount(); + uint32_t input_count = context.InputCount(); InlinedTensorsVector input_tensors; input_tensors.reserve(input_count); - for (int i = 0; i < input_count; ++i) { + for (uint32_t i = 0; i < input_count; ++i) { input_tensors.push_back(context.Input(i)); } @@ -104,42 +103,55 @@ Status Concat::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); + uint32_t axis = static_cast(prepare.axis); + uint32_t max_inputs_per_concat = context.DeviceLimits().maxStorageBuffersPerShaderStage - 1; + + uint32_t input_index = 0; + uint32_t cumulative_size_in_concat_axis = 0; + + while (input_index < input_count) { + ConcatProgram program{axis}; + uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); + + std::vector offsets; + offsets.reserve(num_inputs_this_concat + 1); + offsets.push_back(0); - size_t axis = static_cast(prepare.axis); - ConcatProgram program{axis}; + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); + sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); - std::vector sizes_in_concat_axis; - sizes_in_concat_axis.reserve(input_count); - uint32_t sum = 0; - for (int i = 0; i < input_count; ++i) { - const auto& input = prepare.inputs[i]; - if (input.tensor->Shape().Size() == 0) { - continue; + uint32_t output_size = 0; + for (uint32_t i = 0; i < num_inputs_this_concat; i++) { + auto& input = prepare.inputs[input_index + i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + uint32_t size = onnxruntime::narrow(input.tensor->Shape().Size()); + uint32_t axis_size = static_cast(input.tensor->Shape()[axis]); + + output_size += size; + offsets.push_back(output_size); + cumulative_size_in_concat_axis += axis_size; + sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } - program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); - auto axis_size = input.tensor->Shape()[axis]; - sum += static_cast(axis_size); - sizes_in_concat_axis.push_back(sum); - } + offsets.pop_back(); + sizes_in_concat_axis.pop_back(); - size_t non_empty_input_count = sizes_in_concat_axis.size(); + program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(offsets.data(), offsets.size()), gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size}); + ORT_RETURN_IF_ERROR(context.RunProgram(program)); - if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { - // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", - input_count, ", output=1) exceeds the limit (", - context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + input_index += num_inputs_this_concat; } - program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) - .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), - output_size}); - return context.RunProgram(program); + return Status::OK(); } } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h index 0f6e6dd327e33..7980556e0a1f4 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.h +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -17,7 +17,8 @@ class ConcatProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"offsets", ProgramUniformVariableDataType::Uint32}, + {"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32}, {"output_size", ProgramUniformVariableDataType::Uint32}); private: @@ -33,4 +34,4 @@ class Concat final : public WebGpuKernel, public ConcatBase { }; } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 9e0fb81cbb0fc..b5e13c6377ccb 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -434,5 +434,110 @@ TEST(ConcatOpTest, Concat4D_2) { test.Run(); } +#ifdef USE_WEBGPU +TEST(ConcatOpTest, Concat1D_int32_4inputs) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1}); + test.AddInput("input2", {2}, {2, 3}); + test.AddInput("input3", {4}, {4, 5, 6, 7}); + test.AddInput("input4", {2}, {8, 9}); + test.AddOutput("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1}); + test.AddInput("input2", {1}, {2}); + test.AddInput("input3", {1}, {3}); + test.AddInput("input4", {1}, {4}); + test.AddInput("input5", {1}, {5}); + test.AddInput("input6", {1}, {6}); + test.AddInput("input7", {1}, {7}); + test.AddInput("input8", {1}, {8}); + test.AddInput("input9", {1}, {9}); + test.AddOutput("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1, 2}, {1, 2}); + test.AddInput("input2", {1, 2}, {3, 4}); + test.AddInput("input3", {1, 2}, {5, 6}); + test.AddInput("input4", {1, 2}, {7, 8}); + test.AddInput("input5", {1, 2}, {9, 10}); + test.AddInput("input6", {1, 2}, {11, 12}); + test.AddInput("input7", {1, 2}, {13, 14}); + test.AddInput("input8", {1, 2}, {15, 16}); + test.AddInput("input9", {1, 2}, {17, 18}); + test.AddOutput("concat_result", {9, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {1, 2}, {1, 2}); + test.AddInput("input2", {1, 2}, {3, 4}); + test.AddInput("input3", {1, 2}, {5, 6}); + test.AddInput("input4", {1, 2}, {7, 8}); + test.AddInput("input5", {1, 2}, {9, 10}); + test.AddInput("input6", {1, 2}, {11, 12}); + test.AddInput("input7", {1, 2}, {13, 14}); + test.AddInput("input8", {1, 2}, {15, 16}); + test.AddInput("input9", {1, 2}, {17, 18}); + test.AddOutput("concat_result", {1, 18}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 1, 1}, {3, 4}); + test.AddInput("input3", {2, 1, 1}, {5, 6}); + test.AddInput("input4", {2, 1, 1}, {7, 8}); + test.AddInput("input5", {2, 1, 1}, {9, 10}); + test.AddInput("input6", {2, 1, 1}, {11, 12}); + test.AddInput("input7", {2, 1, 1}, {13, 14}); + test.AddInput("input8", {2, 1, 1}, {15, 16}); + test.AddInput("input9", {2, 1, 1}, {17, 18}); + test.AddOutput("concat_result", {2, 9, 1}, {// batch 0 + 1, 3, 5, 7, 9, 11, 13, 15, 17, + // batch 1 + 2, 4, 6, 8, 10, 12, 14, 16, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_mixed_sizes) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8}); + test.AddInput("input3", {2, 2, 1}, {9, 10, 11, 12}); + test.AddInput("input4", {2, 1, 1}, {13, 14}); + test.AddOutput("concat_result", {2, 7, 1}, {// batch 0 + 1, 3, 4, 5, 9, 10, 13, + // batch 1 + 2, 6, 7, 8, 11, 12, 14}); + test.Run(); +} +#endif // USE_WEBGPU + } // namespace test } // namespace onnxruntime From c3499d783635a990c487ba46240bbe0f936893dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 25 Jul 2025 14:58:52 +0200 Subject: [PATCH 10/33] Attention Operator (CPU) (#25156) ### Description Implementation Attention(23) for CPU. The backend tests from onnx were wrong for Attention (see https://github.com/onnx/onnx/pull/7142). The onnx version needs to be updated to make all tests pass. The implementation matches the reference implementation after onnx was fixed. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Ti-Tai Wang Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- docs/OperatorKernels.md | 1 + .../providers/cpu/cpu_execution_provider.cc | 4 + .../core/providers/cpu/llm/attention.cc | 736 ++++++++++ .../core/providers/cpu/llm/attention.h | 88 ++ .../providers/cpu/llm/attention_helper.cc | 156 +++ .../core/providers/cpu/llm/attention_helper.h | 73 + onnxruntime/core/providers/cpu/math/gemm.cc | 29 +- onnxruntime/core/providers/cpu/math/gemm.h | 9 + onnxruntime/core/util/math_cpu.cc | 69 + onnxruntime/core/util/math_cpuonly.h | 6 + .../test/contrib_ops/attention_op_test.cc | 88 +- onnxruntime/test/onnx/TestCase.cc | 24 + onnxruntime/test/providers/base_tester.cc | 3 + onnxruntime/test/providers/base_tester.h | 4 + .../providers/cpu/llm/attention_op_test.cc | 1206 +++++++++++++++++ .../onnx_backend_test_series_filters.jsonc | 1 + 16 files changed, 2444 insertions(+), 53 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/llm/attention.cc create mode 100644 onnxruntime/core/providers/cpu/llm/attention.h create mode 100644 onnxruntime/core/providers/cpu/llm/attention_helper.cc create mode 100644 onnxruntime/core/providers/cpu/llm/attention_helper.h create mode 100644 onnxruntime/test/providers/cpu/llm/attention_op_test.cc diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8659c96b540c8..3b70e5da8b3e4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -43,6 +43,7 @@ Do not modify directly.* |||[7, 21]|**T** = tensor(float)| |Atanh|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(float)| |||[9, 21]|**T** = tensor(float)| +|Attention|*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**|23+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)
**U** = tensor(bool), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[19, 21]|**T** = tensor(float)| |||[11, 18]|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 6a89fc6234f0f..5eac0523d953a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1290,6 +1290,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, #endif // Opset 23 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, ConstantOfShape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int32_t, DequantizeLinear); @@ -3254,6 +3256,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", BuildKernelDefConstraints()), \ + Attention); + +REGISTER_ONNX_KERNEL_TYPED(float) +REGISTER_ONNX_KERNEL_TYPED(MLFloat16) + +template +void make_copy(T* mask_data, const U* mask_index, size_t size); + +template <> +void make_copy(float* mask_data, const float* mask_index, size_t size) { + memcpy(mask_data, mask_index, size * sizeof(float)); +} + +template <> +void make_copy(MLFloat16* mask_data, const MLFloat16* mask_index, size_t size) { + memcpy(mask_data, mask_index, size * sizeof(MLFloat16)); +} + +template <> +void make_copy(float* mask_data, const bool* mask_index, size_t size) { + for (size_t i = 0; i < size; ++i) { + mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits::lowest(); + } +} + +template <> +void make_copy(MLFloat16* mask_data, const bool* mask_index, size_t size) { + for (size_t i = 0; i < size; ++i) { + mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits::lowest(); + } +} + +template +inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp, AllocatorPtr) { + MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); +} + +template <> +inline void ComputeAttentionSoftmaxInplace(MLFloat16* score, int N, int D, ThreadPool* tp, AllocatorPtr allocator) { + ORT_ENFORCE(tp == nullptr, "No parallelized version of softmax for float16."); + // Mlas Lacks kernels for fp16 softmax, we convert into float32 and call the float32 version. + void* allocated_ptr = allocator->Alloc(static_cast(N * D * sizeof(float))); + BufferUniquePtr float_buffer(allocated_ptr, BufferDeleter(allocator)); + float* ptr = reinterpret_cast(allocated_ptr); + MlasConvertHalfToFloatBuffer(score, ptr, N * D); + MlasComputeSoftmax(ptr, ptr, N, D, false, false, 0.0f, tp); + MlasConvertFloatToHalfBuffer(ptr, score, N * D); +} + +template +inline void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) { + MlasComputeSoftcap(scores, scores, sequence_length, softcap); +} + +template <> +inline void ComputeAttentionSoftcapInplace(MLFloat16* scores, int sequence_length, MLFloat16 softcap) { + // Mlas Lacks kernels for fp16 softcap. The code is similar to the softcap implementation in mlas. + float x; + float cap = softcap.ToFloat(); + for (size_t i = 0; i < static_cast(sequence_length); i++) { + x = std::tanh(scores[i].ToFloat() / cap) * cap; + scores[i] = MLFloat16(x); + } +} + +template +Attention::Attention(const OpKernelInfo& info) : AttentionBase(info) { + is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1; + // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs. + // The dimension is not yet known. If not specified, the inputs is assumed to be 4D. + kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); + q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); + int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); + qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() + ? static_cast(mode) + : QKMatMulOutputMode::kNone; + ORT_ENFORCE(qk_matmul_output_mode_ == QKMatMulOutputMode::kNone || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQK || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKMask || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftCap || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftMax, + "qk_matmul_output_mode must be 0, 1, 2, or 3."); + // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed. + scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN()); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + softmax_precision_ = static_cast(info.GetAttrOrDefault("softmax_precision", 0)); + ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified"); +} + +template +Status Attention::Compute(OpKernelContext* context) const { + const Tensor* Q = context->Input(0); + const Tensor* K = context->Input(1); + const Tensor* V = context->Input(2); + const Tensor* attn_mask = context->Input(3); + const Tensor* past_key = context->Input(4); + const Tensor* past_value = context->Input(5); + + AttentionParameters parameters; + std::vector y_shape; + std::vector present_key_shape; + std::vector present_value_shape; + std::vector output_qk_shape; + + ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( + Q, + K, + V, + attn_mask, + past_key, + past_value, + is_causal_, + softcap_, + softmax_precision_, + qk_matmul_output_mode_, + kv_num_heads_, + q_num_heads_, + scale_, + parameters, + y_shape, + present_key_shape, + present_value_shape, + output_qk_shape) + .IsOK(), + "Output shapes for Attention could not be computed."); + + Tensor* Y = context->Output(0, y_shape); + Tensor* present_key = context->Output(1, present_key_shape); + Tensor* present_value = context->Output(2, present_value_shape); + Tensor* output_qk = parameters.qk_matmul_output_mode == QKMatMulOutputMode::kNone + ? nullptr + : context->Output(3, output_qk_shape); + return this->ApplyAttention(context, + Q->Data(), // Q + K->Data(), // K + V->Data(), // V + attn_mask, // const Tensor* mask_index, // mask, nullptr if no mask + past_key, // past K input tensor (if not using past state) + past_value, // past V input tensor (if not using past state) + Y, // first output + present_key, // present K output tensor (if separating present KV) + present_value, // present V output tensor (if separating present KV) + output_qk, // Q*K output tensor (if returning Q*K value) + parameters // attention parameters + ); +} + +template +void AttentionBase::ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const Tensor* mask_index, // mask + const AttentionParameters& parameters, // attention parameters + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + T* output_qk, // Q*K output + ThreadPool* tp, + AllocatorPtr allocator) const { + // The case past_key != nullptr and present_key == nullptr is not supported. + // We use the fact present_key is requested to avoid any extra allocation. + // However, if present_key is not requested, we should avoid allocated more memory than needed but that mean + // allocating one buffer per thread. That's why the implementation is not done. + // The user should define a model with a present_key even if not used if past_key is not null. + ORT_ENFORCE((past_key == nullptr) == (present_key == nullptr), + "The implementation only supports past_key and present_key both null or both not null."); + const size_t past_chunk_length = static_cast(parameters.past_sequence_length) * parameters.head_size; // P x H + const size_t q_input_chunk_length = static_cast(parameters.q_sequence_length) * parameters.head_size; // S x H + const size_t k_input_chunk_length = static_cast(parameters.kv_sequence_length) * parameters.head_size; // L x H + const size_t present_chunk_length = past_chunk_length + k_input_chunk_length; // T x H + + TensorOpCost unit_cost; + const ptrdiff_t probs_matrix_size = SafeInt(parameters.q_sequence_length) * + parameters.total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * parameters.head_size * probs_matrix_size); + unit_cost.bytes_loaded = static_cast((parameters.q_sequence_length + + parameters.total_sequence_length) * + parameters.head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); + + if (present_key) { + double bytes_to_copy_key = present_chunk_length * static_cast(sizeof(T)); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + } + + // Prepare mask + // Merge causal mask with padding mask, and convert values from 0/1 to -inf/0. + int mask_batch_size = static_cast(mask_index == nullptr || mask_index->Shape().NumDimensions() < 4 + ? 1 + : mask_index->Shape().GetDims()[0]); + int mask_num_heads = static_cast(mask_index == nullptr || mask_index->Shape().NumDimensions() < 3 + ? 1 + : (mask_index->Shape().NumDimensions() < 4 + ? mask_index->Shape().GetDims()[0] + : mask_index->Shape().GetDims()[1])); + + T* mask_data = nullptr; + bool delete_mask_data = false; + bool causal = parameters.is_causal && parameters.q_sequence_length > 1; + if (mask_index == nullptr) { + // No mask = null mask. + if (causal) { + size_t mask_data_bytes = SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T); + void* allocated_ptr = allocator->Alloc(mask_data_bytes); + memset(allocated_ptr, 0, mask_data_bytes); + mask_data = static_cast(allocated_ptr); + for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { + for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { + mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits::lowest(); + } + } + delete_mask_data = true; + } + } else if (mask_index->IsDataType() || causal) { + // We need a copy. + size_t mask_data_bytes = SafeInt(mask_index->Shape().Size()) * sizeof(T); + mask_data = static_cast(allocator->Alloc(mask_data_bytes)); + delete_mask_data = true; + + if (mask_index->IsDataType()) { + // Convert bool mask to 0/1 + make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); + } else if (mask_index != nullptr) { + // We make a copy because causal is True. + make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); + } + if (causal) { + // This loop could be parallelized. + // According to the specifications, this configuration is not supported + // as is_causal=1 or mask is not None (exclusive or). + int n_iter = mask_batch_size * mask_num_heads; + for (int i = 0; i < n_iter; ++i) { + for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { + for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { + mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits::lowest(); + } + } + } + } + } else { + // Nothing to do, no necessary copy. + mask_data = const_cast(mask_index->Data()); + } + + bool transposed_k = parameters.transpose_output && nullptr == present_key; + if (nullptr != present_key && parameters.kv_num_heads != parameters.q_num_heads) { + // This is not part of the main loop because it is not needed at every iteration and + // we cannot ensure the inner body is executed first before getting used in another iteration. + // parameters.batch_size * parameters.q_num_heads + for (std::ptrdiff_t batch_i = 0; batch_i < parameters.batch_size; ++batch_i) { + for (std::ptrdiff_t head_i = 0; head_i < parameters.kv_num_heads; ++head_i) { + ConcatStateChunk(past_key, K, present_key, + past_chunk_length, k_input_chunk_length, present_chunk_length, + parameters.kv_num_heads, parameters.head_size, batch_i, head_i, + parameters.transpose_output); + } + } + } + + // If present_key is not null, it is already initialized to zero. + // Main loop + // With 3D inputs, both Q and K are transposed with permutations (0, 2, 1, 3). + // To avoid expressing the transposition, we use GemmEx with different values for lda, ldb. + // If past_key is not null, then we need to concatenate it with K, the concatenation is not transposed. + const int loop_len = parameters.batch_size * parameters.q_num_heads; + const float alpha = parameters.scale; + + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const ptrdiff_t output_offset = SafeInt(i) * probs_matrix_size; + std::ptrdiff_t batch_i = i / parameters.q_num_heads; + std::ptrdiff_t head_i = i % parameters.q_num_heads; + const ptrdiff_t mask_data_offset = probs_matrix_size * + (head_i % mask_num_heads + (batch_i % mask_batch_size) * mask_num_heads); + + T* output = attention_probs + output_offset; + T* out_qk = output_qk == nullptr ? nullptr : output_qk + output_offset; + float beta; + + if (mask_data != nullptr && + (out_qk == nullptr || parameters.qk_matmul_output_mode != attention_helper::QKMatMulOutputMode::kQK)) { + // Broadcast mask data: SxT -> SxT + memcpy(output, mask_data + mask_data_offset, probs_matrix_bytes); + beta = 1; + } else { + beta = 0; + } + + // handling GQA + std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads; + const T* k = K + k_input_chunk_length * ki; + + if (nullptr != present_key) { + if (parameters.kv_num_heads != parameters.q_num_heads) { + // Already done in a loop before this one. + k = present_key + ki * present_chunk_length; + } else { + k = ConcatStateChunk(past_key, K, present_key, + past_chunk_length, k_input_chunk_length, present_chunk_length, + parameters.kv_num_heads, parameters.head_size, batch_i, head_i, + parameters.transpose_output); + } + } + + // Compute Q*K' + AttentionMask + // original transposed each iteration + // A: Q (B x N x) S x H (B x N x) S x H S x H + // B: K' (B x N x) T x H (B x N x) H x T H x T + // C: attention_probs (B x N x) S x T (B x N x) S x T S x T + if constexpr (std::is_same::value) { + if (parameters.transpose_output) { + math::GemmEx(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + alpha, + Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, + parameters.head_size * parameters.q_num_heads, // lda + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k, + transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb + beta, + output, + parameters.total_sequence_length, // ldc + nullptr); + } else { + math::Gemm(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + alpha, + Q + q_input_chunk_length * i, + k, + beta, + output, + nullptr); + } + } else if constexpr (std::is_same::value) { + if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) { + MlasGemm(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + parameters.transpose_output + ? Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size + : Q + q_input_chunk_length * i, + parameters.transpose_output + ? parameters.head_size * parameters.q_num_heads + : static_cast(parameters.head_size), // lda + transposed_k + ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size + : k, + transposed_k + ? parameters.head_size * parameters.kv_num_heads + : static_cast(parameters.head_size), // ldb + output, + static_cast(parameters.past_sequence_length + parameters.kv_sequence_length), // ldc + MLFloat16(alpha).val, MLFloat16(beta).val, nullptr); + } else { + if (parameters.transpose_output) { + math::GemmEx(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + MLFloat16(alpha), + Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, + parameters.head_size * parameters.q_num_heads, // lda + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k, + transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb + MLFloat16(beta), + output, + parameters.total_sequence_length, // ldc + nullptr); + } else { + TensorShape c_shape({parameters.q_sequence_length, parameters.total_sequence_length}); + Gemm_MLFloat16(CblasNoTrans, CblasTrans, + static_cast(parameters.q_sequence_length), // M + static_cast(parameters.total_sequence_length), // N + static_cast(parameters.head_size), // K + MLFloat16(alpha), + Q + q_input_chunk_length * i, + k, + MLFloat16(beta), + output, + &c_shape, + output, + nullptr); + } + } + } else { + ORT_THROW("Unsupported data type for attention Q*K multiplication: ", DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + if (out_qk != nullptr && + (parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKMask || + parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK)) { + memcpy(out_qk, output, SafeInt(probs_matrix_size) * sizeof(T)); + if (mask_data != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK) { + // We need to add the bias we could not add because out_qk was requested without the mask. + // This can be optimized with vectorized add using MlasAddFloat32x4. + MlasEltwiseAdd(output, mask_data + mask_data_offset, output, probs_matrix_size); + } + } + if (parameters.softcap > 0.0f) { + if constexpr (std::is_same::value) { + ComputeAttentionSoftcapInplace(output, static_cast(probs_matrix_size), parameters.softcap); + } else if constexpr (std::is_same::value) { + ComputeAttentionSoftcapInplace(output, static_cast(probs_matrix_size), MLFloat16(parameters.softcap)); + } else { + ORT_THROW("Unsupported data type for ComputeAttentionSoftcapInplace: ", + DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + } + if (out_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftCap) { + memcpy(out_qk, output, SafeInt(probs_matrix_size) * sizeof(T)); + } + ComputeAttentionSoftmaxInplace(output, parameters.q_sequence_length, parameters.total_sequence_length, nullptr, allocator); + + if (output_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftMax) { + memcpy(output_qk + output_offset, output, + SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T)); + } + } + }); + if (delete_mask_data) { + allocator->Free(mask_data); + } +} + +template +T* AttentionBase::ConcatStateChunk(const T* past, + const T* base_chunk, // chunk is K or V, it can be transposed or not + T* present, + size_t past_chunk_length, + size_t input_chunk_length, // chunk length of K or V + size_t present_chunk_length, + size_t num_heads, + size_t head_size, + std::ptrdiff_t batch_i, + std::ptrdiff_t head_i, + bool transposed) const { + std::ptrdiff_t i = batch_i * num_heads + head_i % num_heads; + + T* start = present + i * present_chunk_length; + + T* p = start; + if (nullptr != past) { + const T* src_past = past + i * past_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); + p += past_chunk_length; + } + + if (transposed) { + ORT_ENFORCE(head_size > 0 && num_heads > 0 && batch_i >= 0 && head_i >= 0, + "Invalid parameters for ConcatStateChunk: head_size=", head_size, ", batch_i=", batch_i, ", head_i=", head_i); + size_t sequence_length = SafeInt(input_chunk_length / head_size); + const T* chunk = base_chunk + head_i * head_size + input_chunk_length * num_heads * batch_i; + for (size_t j = 0; j < sequence_length; ++j) { + memcpy(p, chunk, head_size * sizeof(T)); + p += head_size; + chunk += num_heads * head_size; + } + } else { + const T* chunk = base_chunk + input_chunk_length * i; + memcpy(p, chunk, (present_chunk_length - past_chunk_length) * sizeof(T)); + } + return start; +} + +template +void AttentionBase::ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH_v + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxNxLxH_v + int batch_size, // batch size + int sequence_length, // sequence length + int kv_sequence_length, // sequence length of K or V + int past_sequence_length, // sequence length in past state + int total_sequence_length, // total sequence length = past_sequence_length + kv_sequence_length + int v_head_size, // head size of V (H_v) + int num_heads, // number of attention heads + int kv_num_heads, // number of KV heads + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool transpose_output, // whether to transpose the output (0, 2, 1, 3) + ThreadPool* tp) const { + ORT_ENFORCE((past_value == nullptr) == (present_value == nullptr), + "The implementation only supports past_value and present_value both null or both not null."); + const ptrdiff_t past_chunk_length = SafeInt(past_sequence_length) * v_head_size; // P x H_v + const ptrdiff_t v_input_chunk_length = SafeInt(kv_sequence_length) * v_head_size; // L x H_v + const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length; // T x H_v + + // The cost of Gemm + TensorOpCost unit_cost; + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = + static_cast(SafeInt(sequence_length + v_head_size) * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); + + const size_t bytes_to_copy_trans = SafeInt(v_head_size) * sizeof(T); + double bytes_to_copy_trans_all = static_cast(sequence_length * bytes_to_copy_trans); + unit_cost.bytes_loaded += bytes_to_copy_trans_all; + unit_cost.bytes_stored += bytes_to_copy_trans_all; + + bool transposed_v = transpose_output && nullptr == present_value; + if (nullptr != present_value && kv_num_heads != num_heads) { + // This is not part of the main loop because it is not needed at every iteration and + // we cannot ensure the inner body is executed first before getting used in another iteration. + // parameters.batch_size * parameters.q_num_heads + for (std::ptrdiff_t batch_i = 0; batch_i < batch_size; ++batch_i) { + for (std::ptrdiff_t head_i = 0; head_i < kv_num_heads; ++head_i) { + ConcatStateChunk(past_value, V, present_value, + past_chunk_length, v_input_chunk_length, present_chunk_length, + kv_num_heads, v_head_size, batch_i, head_i, + transpose_output); + } + } + } + + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + // handling GQA + std::ptrdiff_t batch_i = i / num_heads; + std::ptrdiff_t head_i = i % num_heads; + std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads; + const T* v = V + v_input_chunk_length * vi; + + if (nullptr != present_value) { + if (kv_num_heads != num_heads) { + // Already done in a loop before this one. + v = present_value + vi * present_chunk_length; + } else { + // transposed_v is false here. + v = ConcatStateChunk(past_value, V, present_value, + past_chunk_length, v_input_chunk_length, present_chunk_length, + kv_num_heads, v_head_size, batch_i, head_i, + transpose_output); + } + } + + if (transpose_output) { + // transpose_output is false + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + + if constexpr (std::is_same::value) { + // V is transposed but not QK. We use GemmEx with a different value for ldb. + math::GemmEx(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + 1.f, // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + 0.f, // beta + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), + v_head_size * num_heads, // ldc + nullptr); + } else if constexpr (std::is_same::value) { + // This switch should probably be moved to math_cpu.h. + if (MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { + MlasGemm(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + attention_probs + attention_probs_offset, + total_sequence_length, // lda + transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, + transposed_v ? static_cast(v_head_size * kv_num_heads) : static_cast(v_head_size), // ldb + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), + v_head_size * num_heads, // ldc + MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr); + } else { + math::GemmEx(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + MLFloat16(1.f), // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + MLFloat16(0.f), // beta + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), + v_head_size * num_heads, // ldc + nullptr); + } + } else { + ORT_THROW("Unsupported data type for attention QK*V multiplication: ", + DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + } else { + // transpose_output is false + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + ptrdiff_t dest_offset = SafeInt(sequence_length) * v_head_size * i; + T* dest = output + dest_offset; + + if constexpr (std::is_same::value) { + math::MatMul(sequence_length, v_head_size, total_sequence_length, + attention_probs + attention_probs_offset, v, dest, nullptr); + } else if constexpr (std::is_same::value) { + if (MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { + MlasGemm(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + attention_probs + attention_probs_offset, + total_sequence_length, // lda + v, + static_cast(v_head_size), // ldb + dest, + static_cast(v_head_size), // ldc + MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr); + } else { + Gemm_MLFloat16(CblasNoTrans, + CblasNoTrans, + static_cast(sequence_length), // M + static_cast(v_head_size), // N + static_cast(total_sequence_length), // K + MLFloat16(1.f), // alpha + attention_probs + attention_probs_offset, + v, + MLFloat16(0.f), // beta + nullptr, + nullptr, + dest, + nullptr); + } + } else { + ORT_THROW("Unsupported data type for attention QK*V multiplication: ", + DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + } + } + }); +} + +template +Status AttentionBase::ApplyAttention(OpKernelContext* context, + const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // Q*K output tensor (if returning Q*K value) + const AttentionParameters& parameters // attention parameters +) const { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto* tp = context->GetOperatorThreadPool(); + + const T* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr; + const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + T* output_qk_data = output_qk != nullptr ? output_qk->MutableData() : nullptr; + + // Compute the attention score. + size_t bytes = SafeInt(parameters.batch_size) * parameters.q_num_heads * + parameters.q_sequence_length * parameters.total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + this->ComputeAttentionProbs(static_cast(attention_probs), + Q, + K, + mask_index, + parameters, + past_key_data, + present_key_data, + output_qk_data, + tp, + allocator); + + this->ComputeVxAttentionScore(output->MutableData(), + static_cast(attention_probs), + V, + parameters.batch_size, + parameters.q_sequence_length, + parameters.kv_sequence_length, + parameters.past_sequence_length, + parameters.total_sequence_length, + parameters.v_head_size, + parameters.q_num_heads, + parameters.kv_num_heads, + past_value_data, + present_value_data, + parameters.transpose_output, + tp); + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h new file mode 100644 index 0000000000000..78889e48afb29 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/llm/attention_helper.h" + +namespace onnxruntime { + +template +class AttentionBase : public OpKernel { + public: + AttentionBase(const OpKernelInfo& info) : OpKernel(info) {} + + Status ApplyAttention(OpKernelContext* context, + const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // Q*K output tensor (if returning Q*K value) + const attention_helper::AttentionParameters& parameters // attention parameters + ) const; + + protected: + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH_v + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxNxLxH_v + int batch_size, // batch size + int sequence_length, // sequence length + int kv_sequence_length, // sequence length of K or V + int past_sequence_length, // sequence length in past state + int total_sequence_length, // total sequence length = past_sequence_length + kv_sequence_length + int v_head_size, // head size of V (H_v) + int num_heads, // number of attention heads + int kv_num_heads, // number of KV heads + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool transpose_output, // whether to transpose the output from BxNxSxH to BxSxNxH + concurrency::ThreadPool* tp) const; + + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const Tensor* mask_index, // mask_index + const attention_helper::AttentionParameters& parameters, // attention parameters + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + T* output_qk, // Q*K output + concurrency::ThreadPool* tp, + AllocatorPtr allocator) const; + + T* ConcatStateChunk(const T* past, + const T* chunk, + T* present, + size_t past_chunk_length, + size_t input_chunk_length, + size_t present_chunk_length, + size_t num_heads, + size_t head_size, + std::ptrdiff_t batch_i, + std::ptrdiff_t head_i, + bool transposed) const; +}; + +template +class Attention final : public AttentionBase { + public: + Attention(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + bool is_causal_; + int kv_num_heads_; + int q_num_heads_; + attention_helper::QKMatMulOutputMode qk_matmul_output_mode_; + float scale_; + float softcap_; + int softmax_precision_; +}; + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.cc b/onnxruntime/core/providers/cpu/llm/attention_helper.cc new file mode 100644 index 0000000000000..9bd954f128454 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.cc @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/llm/attention_helper.h" +#include "core/util/shape_checker.h" + +namespace onnxruntime { +namespace attention_helper { + +void AttentionParameters::checkParameters() const { + ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0"); + ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0"); + ORT_ENFORCE(head_size > 0, "Head size must be greater than 0"); + ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0"); + ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative"); + ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0"); + ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0"); + ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0"); + ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length, + "Total sequence length must be equal to past sequence length plus KV sequence length"); +} + +Status ComputeOutputShapeForAttention( + const Tensor* Q, + const Tensor* K, + const Tensor* V, + const Tensor* attn_mask, + const Tensor* past_key, + const Tensor* past_value, + bool is_causal, + float softcap, + int softmax_precision, + attention_helper::QKMatMulOutputMode qk_matmul_output_mode, + int kv_num_heads, + int q_num_heads, + float scale, + AttentionParameters& parameters, + std::vector& y_shape, + std::vector& present_key_shape, + std::vector& present_value_shape, + std::vector& output_qk_shape) { + ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, + "Q, K, and V inputs must not be null"); + int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); + int k_dims = onnxruntime::narrow(K->Shape().NumDimensions()); + int v_dims = onnxruntime::narrow(V->Shape().NumDimensions()); + ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor"); + ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank."); + ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank."); + + ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null"); + ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)"); + ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)"); + ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)"); + ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)"); + ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)"); + + parameters.is_causal = is_causal; + parameters.softcap = softcap; + parameters.softmax_precision = softmax_precision; + parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul + parameters.batch_size = onnxruntime::narrow(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D) + + ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor"); + + if (q_dims == 4) { + // 4D + parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D) + parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow(Q->Shape()[1]); // Q.shape[1] (4D) + + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(K->Shape()[1]), "kv_num_heads different from K.shape[1]"); + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(V->Shape()[1]), "kv_num_heads different from V.shape[1]"); + ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow(Q->Shape()[1]), "q_num_heads different from Q.shape[1]"); + ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size"); + ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)"); + + // From shapes + parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[2]); // Q.shape[2] (4D) + parameters.head_size = onnxruntime::narrow(Q->Shape()[3]); // Q.shape[3] (4D) + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D) + parameters.v_head_size = onnxruntime::narrow(V->Shape()[3]); // V.shape[3] (4D) + parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask + ? 0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.v_head_size)}; + } else { + // 3D + parameters.kv_num_heads = kv_num_heads; + parameters.q_num_heads = q_num_heads; + + // From shapes + ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads"); + ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads"); + + parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); + parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); + parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; + parameters.past_sequence_length = past_key == nullptr + ? 0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_sequence_length), + static_cast(parameters.q_num_heads * parameters.v_head_size)}; + } + parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length; + + ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads % kv_num_heads == 0 is not verified"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, + "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 3 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads, + "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 4 || + attn_mask->Shape()[0] == 1 || + attn_mask->Shape()[0] == parameters.batch_size, + "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size"); + ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size); + ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size); + + parameters.scale = std::isnan(scale) ? static_cast(1.0 / sqrt(parameters.head_size)) : scale; + parameters.checkParameters(); + + present_key_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.head_size)}; + present_value_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.v_head_size)}; + if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) { + output_qk_shape.clear(); + } else { + output_qk_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.total_sequence_length)}; + } + return Status::OK(); +} +} // namespace attention_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h new file mode 100644 index 0000000000000..1cea27760408f --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace attention_helper { + +// enum equivalent to the onnx defintion of qk_matmul_output_mode +enum QKMatMulOutputMode { + kNone = -1, // No output Q*K + kQK = 0, // Output Q*K + kQKMask = 1, // Output Q*K + Mask + kQKSoftCap = 2, // Output SoftCap(Q*K + Mask) + kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask)) +}; + +// Parameters deduced from node attributes and inputs/outputs. +struct AttentionParameters { + /* + * Attention Parameters + * MHA: q_num_heads == kv_num_heads -> MHA + * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0 + * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1 + */ + bool is_causal; + int kv_num_heads; // K.shape[1] or V.shape[1] (4D) + int q_num_heads; // Q.shape[1] (4D) + float scale; + float softcap; + int softmax_precision; + QKMatMulOutputMode qk_matmul_output_mode; + + // From shapes + int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D) + int q_sequence_length; // Q.shape[2] (4D) + int head_size; // Q.shape[3] or K.shape[3 (4D) + int kv_sequence_length; // K.shape[2] or V.shape[2] (4D) + int v_head_size; // V.shape[4] (4D) + int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D) + int total_sequence_length; // past_sequence_length + kv_sequence_length + bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH + // This covers the case where the inputs are 3D. + + // Checks the consistency of the parameters. + void checkParameters() const; +}; + +// Computes the output shape for attention based on the input tensors and parameters. +Status ComputeOutputShapeForAttention( + const Tensor* Q, + const Tensor* K, + const Tensor* V, + const Tensor* attn_mask, + const Tensor* past_key, + const Tensor* past_value, + bool is_causal, + float softcap, + int softmax_precision, + attention_helper::QKMatMulOutputMode qk_matmul_output_mode, + int kv_num_heads, + int q_num_heads, + float scale, + AttentionParameters& parameters, + std::vector& y_shape, + std::vector& present_key_shape, + std::vector& present_value_shape, + std::vector& output_qk_shape); + +} // namespace attention_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5406dd1a40446..65b169355c793 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -174,15 +174,14 @@ void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, thread_pool); } -template <> -void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, - ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, - MLFloat16 alpha, - const MLFloat16* a_data, const MLFloat16* b_data, - MLFloat16 beta, - const MLFloat16* c_data, const TensorShape* c_shape, - MLFloat16* y_data, - concurrency::ThreadPool* thread_pool) { +void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, + ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + MLFloat16 alpha, + const MLFloat16* a_data, const MLFloat16* b_data, + MLFloat16 beta, + const MLFloat16* c_data, const TensorShape* c_shape, + MLFloat16* y_data, + concurrency::ThreadPool* thread_pool) { // if input is empty tensor, return directly as nothing need to be calculated. if (M == 0 || N == 0) return; @@ -237,6 +236,18 @@ void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans #endif } +template <> +void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, + ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + MLFloat16 alpha, + const MLFloat16* a_data, const MLFloat16* b_data, + MLFloat16 beta, + const MLFloat16* c_data, const TensorShape* c_shape, + MLFloat16* y_data, + concurrency::ThreadPool* thread_pool) { + Gemm_MLFloat16(trans_a, trans_b, M, N, K, alpha, a_data, b_data, beta, c_data, c_shape, y_data, thread_pool); +} + template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, float alpha, diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 953949732560d..9876109c42df1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -12,6 +12,15 @@ namespace onnxruntime { +void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, // 0, 1 + ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, // 2, 3, 4 + MLFloat16 alpha, // 5 + const MLFloat16* a_data, const MLFloat16* b_data, // 6, 7 + MLFloat16 beta, // 8 + const MLFloat16* c_data, const TensorShape* c_shape, // 9, 10 + MLFloat16* y_data, // 11 + concurrency::ThreadPool* thread_pool); // 12 + template class Gemm : protected GemmBase, public OpKernel { public: diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 983321593a92b..63b647060df3c 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -201,6 +201,75 @@ void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, p MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); } +template <> +void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + MLFloat16 alpha, const MLFloat16* A, int lda, const MLFloat16* B, int ldb, MLFloat16 beta, + MLFloat16* C, int ldc, ThreadPool*) { + // The following function is not implemented for MLFloat16 in Mlas. + // MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); + // Threadpool is not used. + auto C_mat = EigenMatrixMapWithStrides(reinterpret_cast(C), N, M, Eigen::Stride(ldc, 1)); + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + if (beta == MLFloat16(0.f)) { + C_mat.setZero(); + } else { + C_mat *= *reinterpret_cast(&beta); + } + Eigen::half alpha_half = *reinterpret_cast(&alpha); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + + switch (TransA) { + case CblasNoTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), N, K, Eigen::Stride(ldb, 1)) * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), K, M, Eigen::Stride(lda, 1))); + return; + case CblasTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), K, N, Eigen::Stride(ldb, 1)) + .transpose() * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), K, M, Eigen::Stride(lda, 1))); + return; + default: + ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + case CblasTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), N, K, Eigen::Stride(ldb, 1)) * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), M, K, Eigen::Stride(lda, 1)) + .transpose()); + return; + case CblasTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), K, N, Eigen::Stride(ldb, 1)) + .transpose() * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), M, K, Eigen::Stride(lda, 1)) + .transpose()); + return; + default: + ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + default: + ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); + } +} + template void Gemv(CBLAS_TRANSPOSE TransA, int M, diff --git a/onnxruntime/core/util/math_cpuonly.h b/onnxruntime/core/util/math_cpuonly.h index 73caf9f86180d..1b80bfb02c706 100644 --- a/onnxruntime/core/util/math_cpuonly.h +++ b/onnxruntime/core/util/math_cpuonly.h @@ -80,6 +80,12 @@ namespace onnxruntime { template using EigenMatrixMap = Eigen::Map>; +template +using EigenMatrixMapWithStrides = Eigen::Map, 0, Eigen::Stride>; + +template +using ConstEigenMatrixMapWithStrides = Eigen::Map, 0, Eigen::Stride>; + template using EigenArrayMap = Eigen::Map>; diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 61e5fa05c66c1..4245c4bbb1b0a 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -311,7 +311,7 @@ static void RunAttentionTest( kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); } -TEST(AttentionTest, AttentionBatch1) { +TEST(ContribOpAttentionTest, AttentionBatch1) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -340,7 +340,7 @@ TEST(AttentionTest, AttentionBatch1) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionBatch1WithQKVAttr1) { +TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr1) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -381,7 +381,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr1) { 0, false, false, disable_rocm, false, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { +TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr2) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -419,7 +419,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, false, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1AttentionBias) { +TEST(ContribOpAttentionTest, AttentionBatch1AttentionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -460,7 +460,7 @@ TEST(AttentionTest, AttentionBatch1AttentionBias) { 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } -TEST(AttentionTest, AttentionBatch2AttentionBias) { +TEST(ContribOpAttentionTest, AttentionBatch2AttentionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -506,7 +506,7 @@ TEST(AttentionTest, AttentionBatch2AttentionBias) { 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } -TEST(AttentionTest, AttentionBatch1_Float16) { +TEST(ContribOpAttentionTest, AttentionBatch1_Float16) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -535,7 +535,7 @@ TEST(AttentionTest, AttentionBatch1_Float16) { batch_size, sequence_length, hidden_size, number_of_heads, true); } -TEST(AttentionTest, AttentionBatch2) { +TEST(ContribOpAttentionTest, AttentionBatch2) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -568,7 +568,7 @@ TEST(AttentionTest, AttentionBatch2) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionMaskPartialSequence) { +TEST(ContribOpAttentionTest, AttentionMaskPartialSequence) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -598,7 +598,7 @@ TEST(AttentionTest, AttentionMaskPartialSequence) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionMaskExceedSequence) { +TEST(ContribOpAttentionTest, AttentionMaskExceedSequence) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -628,7 +628,7 @@ TEST(AttentionTest, AttentionMaskExceedSequence) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionNoMaskIndex) { +TEST(ContribOpAttentionTest, AttentionNoMaskIndex) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -658,7 +658,7 @@ TEST(AttentionTest, AttentionNoMaskIndex) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionUnidirectional) { +TEST(ContribOpAttentionTest, AttentionUnidirectional) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -855,7 +855,7 @@ void RawAttentionEmptyPastState(bool past_present_share_buffer) { } } -TEST(AttentionTest, Causal_EmptyPastState) { +TEST(ContribOpAttentionTest, Causal_EmptyPastState) { int batch_size = 1; int sequence_length = 2; int hidden_size = 64; @@ -918,11 +918,11 @@ TEST(AttentionTest, Causal_EmptyPastState) { } } -TEST(AttentionTest, AttentionEmptyPastState) { +TEST(ContribOpAttentionTest, AttentionEmptyPastState) { RawAttentionEmptyPastState(false); } -TEST(AttentionTest, AttentionEmptyPastState_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionEmptyPastState_SharedPastPresent) { RawAttentionEmptyPastState(true); } @@ -1037,11 +1037,11 @@ void RawAttentionPastStateBatch1(bool past_present_share_buffer) { } } -TEST(AttentionTest, AttentionPastStateBatch1) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch1) { RawAttentionPastStateBatch1(false); } -TEST(AttentionTest, AttentionPastStateBatch1_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch1_SharedPastPresent) { RawAttentionPastStateBatch1(true); } @@ -1170,11 +1170,11 @@ void RawAttentionPastStateBatch2(bool past_present_share_buffer) { } } -TEST(AttentionTest, AttentionPastStateBatch2) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2) { RawAttentionPastStateBatch2(false); } -TEST(AttentionTest, AttentionPastStateBatch2_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2_SharedPastPresent) { RawAttentionPastStateBatch2(true); } @@ -1295,15 +1295,15 @@ void RawAttentionPastStateBatch2WithPadding(bool past_present_share_buffer) { } } -TEST(AttentionTest, AttentionPastStateBatch2WithPadding) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2WithPadding) { RawAttentionPastStateBatch2WithPadding(false); } -TEST(AttentionTest, AttentionPastStateBatch2WithPadding_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2WithPadding_SharedPastPresent) { RawAttentionPastStateBatch2WithPadding(true); } -TEST(AttentionTest, AttentionBatch2MaskIndex2) { +TEST(ContribOpAttentionTest, AttentionBatch2MaskIndex2) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1344,7 +1344,7 @@ TEST(AttentionTest, AttentionBatch2MaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { +TEST(ContribOpAttentionTest, AttentionRightPaddingMaskIndex2) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -1382,7 +1382,7 @@ TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { +TEST(ContribOpAttentionTest, AttentionLeftPaddingMaskIndex2) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -1420,7 +1420,7 @@ TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { +TEST(ContribOpAttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1462,7 +1462,7 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, Attention3DMask) { +TEST(ContribOpAttentionTest, Attention3DMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1508,7 +1508,7 @@ TEST(AttentionTest, Attention3DMask) { AttentionMaskType::MASK_3D_ATTENTION); } -TEST(AttentionTest, AttentionBatch2AttentionMask) { +TEST(ContribOpAttentionTest, AttentionBatch2AttentionMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1550,7 +1550,7 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) { AttentionMaskType::MASK_2D_KEY_PADDING); } -TEST(AttentionTest, AttentionUnidirectional3DMask) { +TEST(ContribOpAttentionTest, AttentionUnidirectional3DMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1596,7 +1596,7 @@ TEST(AttentionTest, AttentionUnidirectional3DMask) { AttentionMaskType::MASK_3D_ATTENTION); } -TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { +TEST(ContribOpAttentionTest, AttentionUnidirectionalAttentionMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1638,7 +1638,7 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { AttentionMaskType::MASK_2D_KEY_PADDING); } -TEST(AttentionTest, AttentionWithNormFactor) { +TEST(ContribOpAttentionTest, AttentionWithNormFactor) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1683,7 +1683,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { true /*use_scale*/); } -TEST(AttentionTest, AttentionWithNeoXRotaryEmbedding) { +TEST(ContribOpAttentionTest, AttentionWithNeoXRotaryEmbedding) { int batch_size = 2; int sequence_length = 2; int hidden_size = 64; @@ -1717,7 +1717,7 @@ TEST(AttentionTest, AttentionWithNeoXRotaryEmbedding) { true /*use_scale*/, true /*use_neox_rotary_embedding*/); } -TEST(AttentionTest, AttentionMask1DEndNoWord) { +TEST(ContribOpAttentionTest, AttentionMask1DEndNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1760,7 +1760,7 @@ TEST(AttentionTest, AttentionMask1DEndNoWord) { AttentionMaskType::MASK_1D_KEY_SEQ_LEN); } -TEST(AttentionTest, AttentionMask1DNoWord) { +TEST(ContribOpAttentionTest, AttentionMask1DNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1803,7 +1803,7 @@ TEST(AttentionTest, AttentionMask1DNoWord) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionMask2DNoWord) { +TEST(ContribOpAttentionTest, AttentionMask2DNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1846,7 +1846,7 @@ TEST(AttentionTest, AttentionMask2DNoWord) { AttentionMaskType::MASK_2D_KEY_PADDING); } -TEST(AttentionTest, AttentionMask3DNoWord) { +TEST(ContribOpAttentionTest, AttentionMask3DNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1889,7 +1889,7 @@ TEST(AttentionTest, AttentionMask3DNoWord) { AttentionMaskType::MASK_3D_ATTENTION); } -TEST(AttentionTest, AttentionDummyMask2D) { +TEST(ContribOpAttentionTest, AttentionDummyMask2D) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1931,7 +1931,7 @@ TEST(AttentionTest, AttentionDummyMask2D) { AttentionMaskType::MASK_2D_DUMMY); } -TEST(AttentionTest, Attention4DMask) { +TEST(ContribOpAttentionTest, Attention4DMask) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -1977,7 +1977,7 @@ TEST(AttentionTest, Attention4DMask) { disable_cpu); } -TEST(AttentionTest, AttentionMaskIndexOutOfRange) { +TEST(ContribOpAttentionTest, AttentionMaskIndexOutOfRange) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -2021,7 +2021,7 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) { #if !defined(__wasm__) // TODO: fix in web assembly -TEST(AttentionTest, AttentionPastState_dynamic) { +TEST(ContribOpAttentionTest, AttentionPastState_dynamic) { // create rand inputs RandomValueGenerator random{}; @@ -2051,7 +2051,7 @@ TEST(AttentionTest, AttentionPastState_dynamic) { } #endif //! defined(__wasm__) -TEST(AttentionTest, AttentionPrunedModel) { +TEST(ContribOpAttentionTest, AttentionPrunedModel) { int batch_size = 2; int sequence_length = 2; // test input_hidden_size > hidden_size @@ -2174,7 +2174,7 @@ static void RunModelWithRandomInput( } } -TEST(AttentionTest, Attention_Mask2D_Fp32_B2_S32) { +TEST(ContribOpAttentionTest, Attention_Mask2D_Fp32_B2_S32) { constexpr int batch_size = 2; constexpr int sequence_length = 32; @@ -2196,7 +2196,7 @@ TEST(AttentionTest, Attention_Mask2D_Fp32_B2_S32) { false); } -TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) { +TEST(ContribOpAttentionTest, Attention_Mask1D_Fp32_B2_S64) { constexpr int batch_size = 2; constexpr int sequence_length = 64; @@ -2217,7 +2217,7 @@ TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) { } // This case can be used to test flash attention using Ampere GPU -TEST(AttentionTest, Attention_NoMask_Fp16) { +TEST(ContribOpAttentionTest, Attention_NoMask_Fp16) { constexpr int batch_size = 2; std::vector sequence_lengths{1, 7, 8}; for (const auto& sequence_length : sequence_lengths) { @@ -2236,7 +2236,7 @@ TEST(AttentionTest, Attention_NoMask_Fp16) { } // This test is disabled since it is flaky. -TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { +TEST(ContribOpAttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { constexpr int batch_size = 2; // Sequence lengths used in TRT fused attention fp16 v2 kernels. @@ -2263,7 +2263,7 @@ TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { #ifndef ENABLE_TRAINING // Prepacking is disabled in full training build so no need to test the feature in a training build. -TEST(AttentionTest, SharedPrepackedWeights) { +TEST(ContribOpAttentionTest, SharedPrepackedWeights) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 20d81824412fa..d279d1fae418a 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1112,6 +1112,30 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"qlinearmatmul_3D_int8_float32", "result diff", {}}, {"qlinearmatmul_3D_uint8_float16", "fp16 type ont supported by CPU EP", {}}}); + // Attention3D examples are wrong with onnx==1.18.0 + broken_tests->insert({"attention_3d", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_attn_mask", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_causal", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_attn_mask", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_causal", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_scaled", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_with_past_and_present", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_attn_mask", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_causal", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_scaled", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_with_past_and_present", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_scaled", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul_bias", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul_softmax", "wrong expected values (fixed in onnx==1.19.0)"}); + // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. // Instead of list all these testcases, we can use following keyword set to filter out testcases wchich contain diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 7263c435a6a2e..4b37b6c9438aa 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -629,6 +629,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, std::unordered_map feeds; std::vector output_names; FillFeedsAndOutputNames(feeds, output_names); + number_of_nodes_ = model.MainGraph().NumberOfNodes(); // Run the model if (ctx_.run_with_specified_eps) { @@ -794,6 +795,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, } } +int BaseTester::GetNumberOfNodesAfterRun() const { return number_of_nodes_; } + void BaseTester::ExecuteModelForEps( std::vector>&& execution_providers, onnxruntime::Model& model, diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index d39cc3c750dec..182ee4a9550fe 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -39,6 +39,7 @@ class BaseTester { ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange().Map().at(ONNX_NAMESPACE::ONNX_DOMAIN).second; opset_version_ = latest_onnx_version; } + number_of_nodes_ = 0; } // Derived class to implement to provide the model to test. @@ -621,6 +622,8 @@ class BaseTester { test_allow_released_onnx_opset_only_ = false; } + int GetNumberOfNodesAfterRun() const; + protected: //// if the derived class is caching the model this helper can be called in CreateModelToTest to reset the nodes // static void ClearEpsForAllNodes(Graph& graph); @@ -767,6 +770,7 @@ class BaseTester { std::vector input_data_; std::vector output_data_; std::vector fetches_; + int number_of_nodes_; bool testing_function_called_{}; // has the function that performs the actual testing been called yet? diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc new file mode 100644 index 0000000000000..b4f6d328cacf7 --- /dev/null +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -0,0 +1,1206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TensorType { + kFloat, + kFloat16, + kBFloat16 +}; +} // anonymous namespace + +static void AddInputs(OpTester& test, + const std::vector& q, + const std::vector& k, + const std::vector& v, + const std::vector& attn_mask, + const std::initializer_list& attn_mask_bool, + const std::vector& past_key, + const std::vector& past_value, + int is_causal, + const std::vector& q_shape, + const std::vector& k_shape, + const std::vector& v_shape, + const std::vector& attn_mask_shape, + const std::vector& past_key_shape, + const std::vector& past_value_shape, + // outputs + const std::vector& y_shape, + const std::vector& present_key_shape, + const std::vector& present_value_shape, + const std::vector& qk_matmul_output_shape, + int kv_num_heads, + int q_num_heads, + int qk_matmul_output_mode, + float scale, + float softcap, + int softmax_precision, + TensorType tensor_type, + const std::vector& y, + const std::vector& present_key, + const std::vector& present_value, + const std::vector& qk_matmul_output) { + if (is_causal >= 0) + test.AddAttribute("is_causal", is_causal); + if (q_shape.size() == 3) { + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + } + if (qk_matmul_output_mode >= 0) + test.AddAttribute("qk_matmul_output_mode", qk_matmul_output_mode); + if (!std::isnan(scale)) + test.AddAttribute("scale", scale); + if (!std::isnan(softcap)) + test.AddAttribute("softcap", softcap); + if (softmax_precision >= 0) + test.AddAttribute("softmax_precision", softmax_precision); + + if (tensor_type == TensorType::kFloat) { + // inputs + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + if (!attn_mask.empty()) + test.AddInput("attn_mask", attn_mask_shape, attn_mask); + else if (attn_mask_bool.size() > 0) + test.AddInput("attn_mask", attn_mask_shape, attn_mask_bool); + else + test.AddOptionalInputEdge(); + + if (!past_key.empty()) + test.AddInput("past_key", past_key_shape, past_key); + else + test.AddOptionalInputEdge(); + + if (!past_value.empty()) + test.AddInput("past_value", past_value_shape, past_value); + else + test.AddOptionalInputEdge(); + // outputs + test.AddOutput("Y", y_shape, y, false, 0, 3e-5f); + if (!present_key.empty()) + test.AddOutput("present_key", present_key_shape, present_key); + if (!present_value.empty()) + test.AddOutput("present_value", present_value_shape, present_value); + if (!qk_matmul_output.empty()) + test.AddOutput("qk_matmul_output", qk_matmul_output_shape, qk_matmul_output); + } else if (tensor_type == TensorType::kFloat16) { + // inputs + test.AddInput("Q", q_shape, ToFloat16(q)); + test.AddInput("K", k_shape, ToFloat16(k)); + test.AddInput("V", v_shape, ToFloat16(v)); + if (!attn_mask.empty()) + test.AddInput("attn_mask", attn_mask_shape, ToFloat16(attn_mask)); + else if (attn_mask_bool.size() > 0) + test.AddInput("attn_mask", attn_mask_shape, attn_mask_bool); + else + test.AddOptionalInputEdge(); + + if (!past_key.empty()) + test.AddInput("past_key", past_key_shape, ToFloat16(past_key)); + else + test.AddOptionalInputEdge(); + + if (!past_value.empty()) + test.AddInput("past_value", past_value_shape, ToFloat16(past_value)); + else + test.AddOptionalInputEdge(); + // outputs + test.AddOutput("Y", y_shape, ToFloat16(y), false, 0, 3e-3f); + if (!present_key.empty()) + test.AddOutput("present_key", present_key_shape, ToFloat16(present_key)); + if (!present_value.empty()) + test.AddOutput("present_value", present_value_shape, ToFloat16(present_value)); + if (!qk_matmul_output.empty()) + test.AddOutput("qk_matmul_output", qk_matmul_output_shape, ToFloat16(qk_matmul_output)); + } else { + // inputs + test.AddInput("Q", q_shape, FloatsToBFloat16s(q)); + test.AddInput("K", k_shape, FloatsToBFloat16s(k)); + test.AddInput("V", v_shape, FloatsToBFloat16s(v)); + if (!attn_mask.empty()) + test.AddInput("attn_mask", attn_mask_shape, FloatsToBFloat16s(attn_mask)); + else if (attn_mask_bool.size() > 0) + test.AddInput("attn_mask", attn_mask_shape, attn_mask_bool); + else + test.AddOptionalInputEdge(); + + if (!past_key.empty()) + test.AddInput("past_key", past_key_shape, FloatsToBFloat16s(past_key)); + else + test.AddOptionalInputEdge(); + + if (!past_value.empty()) + test.AddInput("past_value", past_value_shape, FloatsToBFloat16s(past_value)); + else + test.AddOptionalInputEdge(); + // outputs + test.AddOutput("Y", y_shape, FloatsToBFloat16s(y), false, 0, 3e-3f); + if (!present_key.empty()) + test.AddOutput("present_key", present_key_shape, FloatsToBFloat16s(present_key)); + if (!present_value.empty()) + test.AddOutput("present_value", present_value_shape, FloatsToBFloat16s(present_value)); + if (!qk_matmul_output.empty()) + test.AddOutput("qk_matmul_output", qk_matmul_output_shape, FloatsToBFloat16s(qk_matmul_output)); + } +} + +static void SetProviders(std::vector>& execution_providers, bool disable_cpu, bool disable_cuda, bool disable_dml, TensorType tensor_type) { + int min_cuda_architecture = (tensor_type == TensorType::kBFloat16) + ? 800 + : (tensor_type == TensorType::kFloat16) ? 530 + : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (enable_dml && !disable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } +} + +static void RunTest3D( + int batch_size, + int q_num_heads, + int q_sequence_length, + int head_size, + int kv_sequence_length, + int kv_num_heads, + int v_head_size, + int past_sequence_length, + const std::vector& q, + const std::vector& k, + const std::vector& v, + const std::vector& attn_mask, + const std::initializer_list& attn_mask_bool, + const std::vector& past_key, + const std::vector& past_value, + int is_causal, // 0 + // int kv_num_heads, // not needed for 3D + // int q_num_heads, // not needed for 3D + int qk_matmul_output_mode, // 0 + float scale, // 1.0 + float softcap, // 0.0, + int softmax_precision, + TensorType tensor_type, + const std::vector& y, + const std::vector& present_key, + const std::vector& present_value, + const std::vector& qk_matmul_output, + bool disable_cpu, + bool disable_cuda, + bool disable_dml) { + int total_sequence_length = past_sequence_length + kv_sequence_length; + // inputs + int q_hidden_size = q_num_heads * head_size; + int k_hidden_size = kv_num_heads * head_size; + int v_hidden_size = kv_num_heads * v_head_size; + int hidden_size = q_num_heads * v_head_size; + std::vector q_shape = {batch_size, q_sequence_length, q_hidden_size}; + std::vector k_shape = {batch_size, kv_sequence_length, k_hidden_size}; + std::vector v_shape = {batch_size, kv_sequence_length, v_hidden_size}; + + std::vector attn_mask_shape = {q_sequence_length, total_sequence_length}; + if (q_sequence_length * total_sequence_length != attn_mask.size() && attn_mask.size() > 0) { + if (batch_size * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, 1, q_sequence_length, total_sequence_length}; + } else if (batch_size * q_num_heads * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + } else { + ORT_THROW("Invalid attention mask size: ", attn_mask.size(), + " expected ", q_sequence_length, "*", total_sequence_length, " or ", + batch_size, "*", q_sequence_length, "*", total_sequence_length); + } + } + + std::vector past_key_shape = {batch_size, kv_num_heads, past_sequence_length, head_size}; + std::vector past_value_shape = {batch_size, kv_num_heads, past_sequence_length, head_size}; + // outputs + std::vector y_shape = {batch_size, q_sequence_length, hidden_size}; + std::vector present_key_shape = {batch_size, kv_num_heads, total_sequence_length, head_size}; + std::vector present_value_shape = {batch_size, kv_num_heads, total_sequence_length, v_head_size}; + std::vector qk_matmul_output_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + + std::vector> execution_providers; + SetProviders(execution_providers, disable_cpu, disable_cuda, disable_dml, tensor_type); + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + for (auto& ep : execution_providers) { + OpTester test("Attention", 23, onnxruntime::kOnnxDomain); + AddInputs(test, q, k, v, attn_mask, attn_mask_bool, past_key, past_value, is_causal, + q_shape, k_shape, v_shape, attn_mask_shape, past_key_shape, past_value_shape, y_shape, present_key_shape, present_value_shape, qk_matmul_output_shape, + kv_num_heads, q_num_heads, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type, y, present_key, present_value, qk_matmul_output); + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + ASSERT_EQ(test.GetNumberOfNodesAfterRun(), 1); // This checks the operator was not inlined. + } +} + +static void RunTest4D( + int batch_size, + int q_num_heads, + int q_sequence_length, + int head_size, + int kv_sequence_length, + int kv_num_heads, + int v_head_size, + int past_sequence_length, + const std::vector& q, + const std::vector& k, + const std::vector& v, + const std::vector& attn_mask, + const std::initializer_list& attn_mask_bool, + const std::vector& past_key, + const std::vector& past_value, + int is_causal, // 0 + // int kv_num_heads, // not needed for 3D + // int q_num_heads, // not needed for 3D + int qk_matmul_output_mode, // 0 + float scale, // 1.0 + float softcap, // 0.0, + int softmax_precision, + TensorType tensor_type, + const std::vector& y, + const std::vector& present_key, + const std::vector& present_value, + const std::vector& qk_matmul_output, + bool disable_cpu, + bool disable_cuda, + bool disable_dml) { + int total_sequence_length = past_sequence_length + kv_sequence_length; + // inputs + std::vector q_shape = {batch_size, q_num_heads, q_sequence_length, head_size}; + std::vector k_shape = {batch_size, kv_num_heads, kv_sequence_length, head_size}; + std::vector v_shape = {batch_size, kv_num_heads, kv_sequence_length, v_head_size}; + + std::vector attn_mask_shape = {q_sequence_length, total_sequence_length}; + if (q_sequence_length * total_sequence_length != attn_mask.size() && attn_mask.size() > 0) { + if (batch_size * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, 1, q_sequence_length, total_sequence_length}; + } else if (batch_size * q_num_heads * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + } else { + ORT_THROW("Invalid attention mask size: ", attn_mask.size(), + " expected ", q_sequence_length, "*", total_sequence_length, " or ", + batch_size, "*", q_sequence_length, "*", total_sequence_length); + } + } + + std::vector past_key_shape = {batch_size, kv_num_heads, past_sequence_length, head_size}; + std::vector past_value_shape = {batch_size, kv_num_heads, past_sequence_length, v_head_size}; + // outputs + std::vector y_shape = {batch_size, q_num_heads, q_sequence_length, v_head_size}; + std::vector present_key_shape = {batch_size, kv_num_heads, total_sequence_length, head_size}; + std::vector present_value_shape = {batch_size, kv_num_heads, total_sequence_length, v_head_size}; + std::vector qk_matmul_output_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + + std::vector> execution_providers; + SetProviders(execution_providers, disable_cpu, disable_cuda, disable_dml, tensor_type); + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + for (auto& ep : execution_providers) { + OpTester test("Attention", 23, onnxruntime::kOnnxDomain); + AddInputs(test, q, k, v, attn_mask, attn_mask_bool, past_key, past_value, is_causal, + q_shape, k_shape, v_shape, attn_mask_shape, past_key_shape, past_value_shape, y_shape, present_key_shape, present_value_shape, qk_matmul_output_shape, + kv_num_heads, q_num_heads, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type, y, present_key, present_value, qk_matmul_output); + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + ASSERT_EQ(test.GetNumberOfNodesAfterRun(), 1); // This checks the operator was not inlined. + } +} + +TEST(AttentionTest, Attention3DDefault) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector y = {0.231425f, 0.572015f, 0.512671f, 0.279597f, 0.323671f, 0.474956f, 0.344308f, 0.454604f, 0.677763f, 0.427182f, 0.518734f, 0.586593f, 0.366221f, 0.617469f, 0.568592f, 0.711734f, 0.669865f, 0.477629f, 0.443902f, 0.657931f, 0.294461f, 0.444926f, 0.646996f, 0.624016f, 0.230982f, 0.577089f, 0.515905f, 0.281810f, 0.318254f, 0.478419f, 0.341943f, 0.456036f, 0.671153f, 0.419443f, 0.553783f, 0.617598f, 0.405113f, 0.612246f, 0.546371f, 0.691976f, 0.673135f, 0.474435f, 0.440636f, 0.656117f, 0.290562f, 0.437461f, 0.641583f, 0.628633f, 0.213246f, 0.573821f, 0.481404f, 0.314601f, 0.331198f, 0.479336f, 0.334377f, 0.416422f, 0.683961f, 0.438780f, 0.515832f, 0.594131f, 0.421298f, 0.581216f, 0.544020f, 0.665089f, 0.680353f, 0.496091f, 0.458597f, 0.644262f, 0.290254f, 0.439397f, 0.648748f, 0.622587f, 0.215077f, 0.561958f, 0.470216f, 0.315574f, 0.330295f, 0.476255f, 0.346486f, 0.433062f, 0.675563f, 0.430004f, 0.531206f, 0.603125f, 0.392384f, 0.606396f, 0.553218f, 0.688558f, 0.672218f, 0.481904f, 0.442930f, 0.664552f, 0.291008f, 0.447983f, 0.646510f, 0.629446f, 0.684469f, 0.333075f, 0.591230f, 0.723174f, 0.527550f, 0.429390f, 0.379490f, 0.407681f, 0.549282f, 0.325072f, 0.396408f, 0.659680f, 0.252716f, 0.438976f, 0.383743f, 0.537200f, 0.679028f, 0.472077f, 0.522267f, 0.258646f, 0.543009f, 0.648117f, 0.524809f, 0.455668f, 0.679968f, 0.320914f, 0.603929f, 0.720663f, 0.535420f, 0.427747f, 0.365637f, 0.402336f, 0.555204f, 0.329413f, 0.403408f, 0.674143f, 0.257068f, 0.430207f, 0.384353f, 0.534996f, 0.682781f, 0.472336f, 0.532518f, 0.255054f, 0.533888f, 0.631695f, 0.517009f, 0.460408f, 0.676468f, 0.310125f, 0.594133f, 0.720721f, 0.531343f, 0.428411f, 0.383201f, 0.400798f, 0.520066f, 0.313406f, 0.378438f, 0.660871f, 0.236947f, 0.471855f, 0.380046f, 0.533181f, 0.692040f, 0.460203f, 0.533379f, 0.249623f, 0.540433f, 0.638632f, 0.525843f, 0.453184f, 0.678596f, 0.343161f, 0.587705f, 0.727194f, 0.516850f, 0.421908f, 0.366269f, 0.400319f, 0.550307f, 0.323773f, 0.406273f, 0.671064f, 0.258597f, 0.441523f, 0.386403f, 0.537742f, 0.671703f, 0.464797f, 0.523623f, 0.248851f, 0.522889f, 0.644907f, 0.502470f, 0.446048f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention3DDefaultFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector y = {0.231425f, 0.572015f, 0.512671f, 0.279597f, 0.323671f, 0.474956f, 0.344308f, 0.454604f, 0.677763f, 0.427182f, 0.518734f, 0.586593f, 0.366221f, 0.617469f, 0.568592f, 0.711734f, 0.669865f, 0.477629f, 0.443902f, 0.657931f, 0.294461f, 0.444926f, 0.646996f, 0.624016f, 0.230982f, 0.577089f, 0.515905f, 0.281810f, 0.318254f, 0.478419f, 0.341943f, 0.456036f, 0.671153f, 0.419443f, 0.553783f, 0.617598f, 0.405113f, 0.612246f, 0.546371f, 0.691976f, 0.673135f, 0.474435f, 0.440636f, 0.656117f, 0.290562f, 0.437461f, 0.641583f, 0.628633f, 0.213246f, 0.573821f, 0.481404f, 0.314601f, 0.331198f, 0.479336f, 0.334377f, 0.416422f, 0.683961f, 0.438780f, 0.515832f, 0.594131f, 0.421298f, 0.581216f, 0.544020f, 0.665089f, 0.680353f, 0.496091f, 0.458597f, 0.644262f, 0.290254f, 0.439397f, 0.648748f, 0.622587f, 0.215077f, 0.561958f, 0.470216f, 0.315574f, 0.330295f, 0.476255f, 0.346486f, 0.433062f, 0.675563f, 0.430004f, 0.531206f, 0.603125f, 0.392384f, 0.606396f, 0.553218f, 0.688558f, 0.672218f, 0.481904f, 0.442930f, 0.664552f, 0.291008f, 0.447983f, 0.646510f, 0.629446f, 0.684469f, 0.333075f, 0.591230f, 0.723174f, 0.527550f, 0.429390f, 0.379490f, 0.407681f, 0.549282f, 0.325072f, 0.396408f, 0.659680f, 0.252716f, 0.438976f, 0.383743f, 0.537200f, 0.679028f, 0.472077f, 0.522267f, 0.258646f, 0.543009f, 0.648117f, 0.524809f, 0.455668f, 0.679968f, 0.320914f, 0.603929f, 0.720663f, 0.535420f, 0.427747f, 0.365637f, 0.402336f, 0.555204f, 0.329413f, 0.403408f, 0.674143f, 0.257068f, 0.430207f, 0.384353f, 0.534996f, 0.682781f, 0.472336f, 0.532518f, 0.255054f, 0.533888f, 0.631695f, 0.517009f, 0.460408f, 0.676468f, 0.310125f, 0.594133f, 0.720721f, 0.531343f, 0.428411f, 0.383201f, 0.400798f, 0.520066f, 0.313406f, 0.378438f, 0.660871f, 0.236947f, 0.471855f, 0.380046f, 0.533181f, 0.692040f, 0.460203f, 0.533379f, 0.249623f, 0.540433f, 0.638632f, 0.525843f, 0.453184f, 0.678596f, 0.343161f, 0.587705f, 0.727194f, 0.516850f, 0.421908f, 0.366269f, 0.400319f, 0.550307f, 0.323773f, 0.406273f, 0.671064f, 0.258597f, 0.441523f, 0.386403f, 0.537742f, 0.671703f, 0.464797f, 0.523623f, 0.248851f, 0.522889f, 0.644907f, 0.502470f, 0.446048f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDefaultBasic) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + std::vector k = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + std::vector v = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + std::vector y = {0.221683f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDefault) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + std::vector y = {0.501465f, 0.543511f, 0.398088f, 0.474061f, 0.290507f, 0.423018f, 0.447999f, 0.672390f, 0.500878f, 0.545140f, 0.402253f, 0.478354f, 0.278711f, 0.420929f, 0.451124f, 0.682613f, 0.496502f, 0.557356f, 0.419293f, 0.467867f, 0.280946f, 0.422295f, 0.445183f, 0.675748f, 0.498804f, 0.545264f, 0.399543f, 0.471287f, 0.287601f, 0.424845f, 0.443877f, 0.670841f, 0.580098f, 0.450536f, 0.702941f, 0.538382f, 0.329768f, 0.543394f, 0.613723f, 0.562010f, 0.584549f, 0.447129f, 0.673676f, 0.537643f, 0.342950f, 0.515742f, 0.613437f, 0.502951f, 0.585248f, 0.443070f, 0.676620f, 0.549025f, 0.343112f, 0.522440f, 0.611621f, 0.507324f, 0.580745f, 0.461632f, 0.668496f, 0.507376f, 0.336816f, 0.500750f, 0.618162f, 0.500909f, 0.464240f, 0.493342f, 0.380525f, 0.530712f, 0.397056f, 0.582067f, 0.443341f, 0.559227f, 0.467916f, 0.503694f, 0.373170f, 0.549178f, 0.387171f, 0.587037f, 0.448581f, 0.561591f, 0.478681f, 0.496704f, 0.369457f, 0.545459f, 0.392339f, 0.587842f, 0.452645f, 0.576330f, 0.483897f, 0.491793f, 0.360676f, 0.530990f, 0.380686f, 0.603393f, 0.467172f, 0.583590f, 0.642787f, 0.470883f, 0.686034f, 0.642719f, 0.386365f, 0.366454f, 0.467120f, 0.405736f, 0.644347f, 0.466390f, 0.684379f, 0.640710f, 0.385963f, 0.366271f, 0.472645f, 0.403025f, 0.631421f, 0.453237f, 0.677676f, 0.643979f, 0.390879f, 0.377663f, 0.467158f, 0.401772f, 0.637457f, 0.459313f, 0.677889f, 0.659685f, 0.383362f, 0.379251f, 0.453763f, 0.401437f, 0.555998f, 0.186013f, 0.455395f, 0.406430f, 0.395553f, 0.526708f, 0.320193f, 0.484448f, 0.577368f, 0.190770f, 0.462801f, 0.384114f, 0.403607f, 0.534057f, 0.326255f, 0.496504f, 0.563586f, 0.180264f, 0.464196f, 0.384055f, 0.385514f, 0.537212f, 0.338047f, 0.485235f, 0.555800f, 0.177971f, 0.457827f, 0.377928f, 0.372441f, 0.541035f, 0.343750f, 0.483692f, 0.705313f, 0.467049f, 0.389698f, 0.530555f, 0.548003f, 0.637789f, 0.501241f, 0.493046f, 0.692096f, 0.474284f, 0.375588f, 0.530258f, 0.507811f, 0.618987f, 0.468782f, 0.502795f, 0.703758f, 0.479856f, 0.374269f, 0.518477f, 0.518286f, 0.631821f, 0.502535f, 0.509264f, 0.689539f, 0.474638f, 0.374363f, 0.519131f, 0.519441f, 0.644891f, 0.480984f, 0.490645f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDefaultFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + std::vector y = {0.501465f, 0.543511f, 0.398088f, 0.474061f, 0.290507f, 0.423018f, 0.447999f, 0.672390f, 0.500878f, 0.545140f, 0.402253f, 0.478354f, 0.278711f, 0.420929f, 0.451124f, 0.682613f, 0.496502f, 0.557356f, 0.419293f, 0.467867f, 0.280946f, 0.422295f, 0.445183f, 0.675748f, 0.498804f, 0.545264f, 0.399543f, 0.471287f, 0.287601f, 0.424845f, 0.443877f, 0.670841f, 0.580098f, 0.450536f, 0.702941f, 0.538382f, 0.329768f, 0.543394f, 0.613723f, 0.562010f, 0.584549f, 0.447129f, 0.673676f, 0.537643f, 0.342950f, 0.515742f, 0.613437f, 0.502951f, 0.585248f, 0.443070f, 0.676620f, 0.549025f, 0.343112f, 0.522440f, 0.611621f, 0.507324f, 0.580745f, 0.461632f, 0.668496f, 0.507376f, 0.336816f, 0.500750f, 0.618162f, 0.500909f, 0.464240f, 0.493342f, 0.380525f, 0.530712f, 0.397056f, 0.582067f, 0.443341f, 0.559227f, 0.467916f, 0.503694f, 0.373170f, 0.549178f, 0.387171f, 0.587037f, 0.448581f, 0.561591f, 0.478681f, 0.496704f, 0.369457f, 0.545459f, 0.392339f, 0.587842f, 0.452645f, 0.576330f, 0.483897f, 0.491793f, 0.360676f, 0.530990f, 0.380686f, 0.603393f, 0.467172f, 0.583590f, 0.642787f, 0.470883f, 0.686034f, 0.642719f, 0.386365f, 0.366454f, 0.467120f, 0.405736f, 0.644347f, 0.466390f, 0.684379f, 0.640710f, 0.385963f, 0.366271f, 0.472645f, 0.403025f, 0.631421f, 0.453237f, 0.677676f, 0.643979f, 0.390879f, 0.377663f, 0.467158f, 0.401772f, 0.637457f, 0.459313f, 0.677889f, 0.659685f, 0.383362f, 0.379251f, 0.453763f, 0.401437f, 0.555998f, 0.186013f, 0.455395f, 0.406430f, 0.395553f, 0.526708f, 0.320193f, 0.484448f, 0.577368f, 0.190770f, 0.462801f, 0.384114f, 0.403607f, 0.534057f, 0.326255f, 0.496504f, 0.563586f, 0.180264f, 0.464196f, 0.384055f, 0.385514f, 0.537212f, 0.338047f, 0.485235f, 0.555800f, 0.177971f, 0.457827f, 0.377928f, 0.372441f, 0.541035f, 0.343750f, 0.483692f, 0.705313f, 0.467049f, 0.389698f, 0.530555f, 0.548003f, 0.637789f, 0.501241f, 0.493046f, 0.692096f, 0.474284f, 0.375588f, 0.530258f, 0.507811f, 0.618987f, 0.468782f, 0.502795f, 0.703758f, 0.479856f, 0.374269f, 0.518477f, 0.518286f, 0.631821f, 0.502535f, 0.509264f, 0.689539f, 0.474638f, 0.374363f, 0.519131f, 0.519441f, 0.644891f, 0.480984f, 0.490645f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DSoftCap) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 10; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + // with softcap=2 + std::vector ys = {0.227656f, 0.365938f, 0.487233f, 0.563168f, 0.314693f, 0.531065f, 0.502050f, 0.532911f, 0.479305f, 0.619133f, 0.230719f, 0.361396f, 0.476682f, 0.566474f, 0.307008f, 0.529635f, 0.503316f, 0.540530f, 0.476847f, 0.620507f, 0.233811f, 0.361041f, 0.472995f, 0.571894f, 0.309176f, 0.536943f, 0.498525f, 0.540409f, 0.475846f, 0.615972f, 0.223131f, 0.365223f, 0.488599f, 0.559249f, 0.315942f, 0.525688f, 0.494637f, 0.539772f, 0.488000f, 0.625606f, 0.539676f, 0.409601f, 0.515692f, 0.453467f, 0.697314f, 0.396105f, 0.298034f, 0.552743f, 0.440534f, 0.843839f, 0.525229f, 0.418362f, 0.546100f, 0.481009f, 0.687614f, 0.414847f, 0.327302f, 0.572564f, 0.461664f, 0.831423f, 0.521430f, 0.418181f, 0.545782f, 0.477744f, 0.687580f, 0.409896f, 0.324292f, 0.565326f, 0.459461f, 0.832106f, 0.542037f, 0.412166f, 0.539834f, 0.486373f, 0.691028f, 0.421836f, 0.330124f, 0.590678f, 0.466584f, 0.831750f, 0.382651f, 0.501226f, 0.660685f, 0.342294f, 0.602060f, 0.492331f, 0.474420f, 0.409177f, 0.518175f, 0.581219f, 0.387046f, 0.503621f, 0.666169f, 0.332572f, 0.596846f, 0.479979f, 0.479994f, 0.413598f, 0.515513f, 0.577655f, 0.398240f, 0.510706f, 0.663548f, 0.331466f, 0.594592f, 0.465828f, 0.485982f, 0.414944f, 0.516808f, 0.588646f, 0.401608f, 0.503138f, 0.664086f, 0.314710f, 0.579984f, 0.448406f, 0.482952f, 0.410394f, 0.515656f, 0.614177f, 0.430626f, 0.390476f, 0.382732f, 0.345745f, 0.361913f, 0.378760f, 0.487068f, 0.359749f, 0.440638f, 0.611671f, 0.434161f, 0.384956f, 0.382824f, 0.347990f, 0.361064f, 0.378348f, 0.483768f, 0.357084f, 0.441993f, 0.612507f, 0.430795f, 0.387191f, 0.392464f, 0.339543f, 0.365489f, 0.373725f, 0.480792f, 0.354801f, 0.428210f, 0.621415f, 0.430196f, 0.387751f, 0.374630f, 0.333935f, 0.363445f, 0.372619f, 0.482465f, 0.350530f, 0.427172f, 0.618986f, 0.529767f, 0.595815f, 0.301624f, 0.397276f, 0.605455f, 0.607591f, 0.617002f, 0.544150f, 0.662428f, 0.510301f, 0.533071f, 0.602211f, 0.278156f, 0.392687f, 0.617217f, 0.593104f, 0.629293f, 0.563362f, 0.682795f, 0.519542f, 0.520110f, 0.607374f, 0.289463f, 0.386297f, 0.609416f, 0.600651f, 0.634780f, 0.553284f, 0.672042f, 0.506020f, 0.514322f, 0.606722f, 0.293574f, 0.377031f, 0.612149f, 0.599634f, 0.640889f, 0.546806f, 0.672437f, 0.505487f, 0.380489f, 0.334473f, 0.554343f, 0.499727f, 0.526942f, 0.558871f, 0.530154f, 0.309413f, 0.555978f, 0.488827f, 0.371393f, 0.341934f, 0.552609f, 0.481362f, 0.537837f, 0.574948f, 0.524870f, 0.312968f, 0.558314f, 0.484292f, 0.382443f, 0.330414f, 0.567252f, 0.481373f, 0.557600f, 0.575927f, 0.536800f, 0.295057f, 0.535626f, 0.488409f, 0.369831f, 0.343157f, 0.554056f, 0.492472f, 0.539300f, 0.565926f, 0.540317f, 0.307066f, 0.560539f, 0.493642f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + ys, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DSoftCapFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 10; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + // with softcap=2 + std::vector ys = {0.227656f, 0.365938f, 0.487233f, 0.563168f, 0.314693f, 0.531065f, 0.502050f, 0.532911f, 0.479305f, 0.619133f, 0.230719f, 0.361396f, 0.476682f, 0.566474f, 0.307008f, 0.529635f, 0.503316f, 0.540530f, 0.476847f, 0.620507f, 0.233811f, 0.361041f, 0.472995f, 0.571894f, 0.309176f, 0.536943f, 0.498525f, 0.540409f, 0.475846f, 0.615972f, 0.223131f, 0.365223f, 0.488599f, 0.559249f, 0.315942f, 0.525688f, 0.494637f, 0.539772f, 0.488000f, 0.625606f, 0.539676f, 0.409601f, 0.515692f, 0.453467f, 0.697314f, 0.396105f, 0.298034f, 0.552743f, 0.440534f, 0.843839f, 0.525229f, 0.418362f, 0.546100f, 0.481009f, 0.687614f, 0.414847f, 0.327302f, 0.572564f, 0.461664f, 0.831423f, 0.521430f, 0.418181f, 0.545782f, 0.477744f, 0.687580f, 0.409896f, 0.324292f, 0.565326f, 0.459461f, 0.832106f, 0.542037f, 0.412166f, 0.539834f, 0.486373f, 0.691028f, 0.421836f, 0.330124f, 0.590678f, 0.466584f, 0.831750f, 0.382651f, 0.501226f, 0.660685f, 0.342294f, 0.602060f, 0.492331f, 0.474420f, 0.409177f, 0.518175f, 0.581219f, 0.387046f, 0.503621f, 0.666169f, 0.332572f, 0.596846f, 0.479979f, 0.479994f, 0.413598f, 0.515513f, 0.577655f, 0.398240f, 0.510706f, 0.663548f, 0.331466f, 0.594592f, 0.465828f, 0.485982f, 0.414944f, 0.516808f, 0.588646f, 0.401608f, 0.503138f, 0.664086f, 0.314710f, 0.579984f, 0.448406f, 0.482952f, 0.410394f, 0.515656f, 0.614177f, 0.430626f, 0.390476f, 0.382732f, 0.345745f, 0.361913f, 0.378760f, 0.487068f, 0.359749f, 0.440638f, 0.611671f, 0.434161f, 0.384956f, 0.382824f, 0.347990f, 0.361064f, 0.378348f, 0.483768f, 0.357084f, 0.441993f, 0.612507f, 0.430795f, 0.387191f, 0.392464f, 0.339543f, 0.365489f, 0.373725f, 0.480792f, 0.354801f, 0.428210f, 0.621415f, 0.430196f, 0.387751f, 0.374630f, 0.333935f, 0.363445f, 0.372619f, 0.482465f, 0.350530f, 0.427172f, 0.618986f, 0.529767f, 0.595815f, 0.301624f, 0.397276f, 0.605455f, 0.607591f, 0.617002f, 0.544150f, 0.662428f, 0.510301f, 0.533071f, 0.602211f, 0.278156f, 0.392687f, 0.617217f, 0.593104f, 0.629293f, 0.563362f, 0.682795f, 0.519542f, 0.520110f, 0.607374f, 0.289463f, 0.386297f, 0.609416f, 0.600651f, 0.634780f, 0.553284f, 0.672042f, 0.506020f, 0.514322f, 0.606722f, 0.293574f, 0.377031f, 0.612149f, 0.599634f, 0.640889f, 0.546806f, 0.672437f, 0.505487f, 0.380489f, 0.334473f, 0.554343f, 0.499727f, 0.526942f, 0.558871f, 0.530154f, 0.309413f, 0.555978f, 0.488827f, 0.371393f, 0.341934f, 0.552609f, 0.481362f, 0.537837f, 0.574948f, 0.524870f, 0.312968f, 0.558314f, 0.484292f, 0.382443f, 0.330414f, 0.567252f, 0.481373f, 0.557600f, 0.575927f, 0.536800f, 0.295057f, 0.535626f, 0.488409f, 0.369831f, 0.343157f, 0.554056f, 0.492472f, 0.539300f, 0.565926f, 0.540317f, 0.307066f, 0.560539f, 0.493642f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + ys, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnMask) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f}; + std::vector y = {0.478040f, 0.503674f, 0.349552f, 0.475550f, 0.319086f, 0.440731f, 0.452109f, 0.673914f, 0.477799f, 0.522510f, 0.381228f, 0.496104f, 0.239154f, 0.427475f, 0.460164f, 0.727212f, 0.478457f, 0.589145f, 0.456094f, 0.413665f, 0.297445f, 0.419073f, 0.407575f, 0.626054f, 0.503276f, 0.536857f, 0.396718f, 0.495176f, 0.270464f, 0.419459f, 0.466892f, 0.704668f, 0.544710f, 0.446025f, 0.625069f, 0.574330f, 0.337465f, 0.515011f, 0.576166f, 0.495398f, 0.561775f, 0.451492f, 0.656295f, 0.501454f, 0.371102f, 0.511117f, 0.597942f, 0.486135f, 0.613719f, 0.415552f, 0.679385f, 0.545510f, 0.334013f, 0.491561f, 0.634246f, 0.501191f, 0.592514f, 0.421301f, 0.682063f, 0.535644f, 0.365155f, 0.518639f, 0.614815f, 0.501439f, 0.460727f, 0.519269f, 0.348532f, 0.554692f, 0.328284f, 0.619616f, 0.469338f, 0.556237f, 0.442274f, 0.547421f, 0.394879f, 0.609402f, 0.399426f, 0.573414f, 0.435733f, 0.513013f, 0.478210f, 0.470028f, 0.379309f, 0.520524f, 0.393439f, 0.580848f, 0.442115f, 0.602217f, 0.485329f, 0.501646f, 0.370504f, 0.561198f, 0.416058f, 0.567774f, 0.439229f, 0.571259f, 0.674824f, 0.550989f, 0.722801f, 0.662394f, 0.352779f, 0.301575f, 0.454417f, 0.436797f, 0.640218f, 0.464017f, 0.673274f, 0.631072f, 0.416194f, 0.405371f, 0.424135f, 0.380459f, 0.676026f, 0.466017f, 0.693624f, 0.619528f, 0.361035f, 0.314311f, 0.546125f, 0.401422f, 0.634731f, 0.457909f, 0.673249f, 0.669035f, 0.395002f, 0.414838f, 0.422935f, 0.397171f, 0.578772f, 0.171263f, 0.507806f, 0.446147f, 0.431901f, 0.525101f, 0.333084f, 0.473000f, 0.581295f, 0.193171f, 0.470985f, 0.376522f, 0.425847f, 0.546483f, 0.292789f, 0.509355f, 0.590731f, 0.161755f, 0.514375f, 0.380830f, 0.398416f, 0.492429f, 0.361418f, 0.440428f, 0.559340f, 0.167691f, 0.474461f, 0.331081f, 0.368636f, 0.558841f, 0.331704f, 0.485050f, 0.683438f, 0.514064f, 0.339780f, 0.536424f, 0.478815f, 0.654453f, 0.482692f, 0.544422f, 0.718284f, 0.508385f, 0.350896f, 0.561493f, 0.527900f, 0.642672f, 0.514512f, 0.516495f, 0.644405f, 0.441945f, 0.397069f, 0.484688f, 0.496761f, 0.647967f, 0.423362f, 0.480241f, 0.686930f, 0.492126f, 0.344961f, 0.526120f, 0.489709f, 0.638597f, 0.457665f, 0.469929f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * kv_sequence_length); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnMaskBool) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::initializer_list m = {true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true}; + std::vector y = {0.501465f, 0.543511f, 0.398088f, 0.474061f, 0.290507f, 0.423018f, 0.447999f, 0.672390f, 0.500878f, 0.545140f, 0.402253f, 0.478354f, 0.278711f, 0.420929f, 0.451124f, 0.682613f, 0.496502f, 0.557356f, 0.419293f, 0.467867f, 0.280946f, 0.422295f, 0.445183f, 0.675748f, 0.498804f, 0.545264f, 0.399543f, 0.471287f, 0.287601f, 0.424845f, 0.443877f, 0.670841f, 0.580098f, 0.450536f, 0.702941f, 0.538382f, 0.329768f, 0.543394f, 0.613723f, 0.562010f, 0.584549f, 0.447129f, 0.673676f, 0.537643f, 0.342950f, 0.515742f, 0.613437f, 0.502951f, 0.585248f, 0.443070f, 0.676620f, 0.549025f, 0.343112f, 0.522440f, 0.611621f, 0.507324f, 0.580745f, 0.461632f, 0.668496f, 0.507376f, 0.336816f, 0.500750f, 0.618162f, 0.500909f, 0.464240f, 0.493342f, 0.380525f, 0.530712f, 0.397056f, 0.582067f, 0.443341f, 0.559227f, 0.467916f, 0.503694f, 0.373170f, 0.549178f, 0.387171f, 0.587037f, 0.448581f, 0.561591f, 0.478681f, 0.496704f, 0.369457f, 0.545459f, 0.392339f, 0.587842f, 0.452645f, 0.576330f, 0.483897f, 0.491793f, 0.360676f, 0.530990f, 0.380686f, 0.603393f, 0.467172f, 0.583590f, 0.642787f, 0.470883f, 0.686034f, 0.642719f, 0.386365f, 0.366454f, 0.467120f, 0.405736f, 0.644347f, 0.466390f, 0.684379f, 0.640710f, 0.385963f, 0.366271f, 0.472645f, 0.403025f, 0.631421f, 0.453237f, 0.677676f, 0.643979f, 0.390879f, 0.377663f, 0.467158f, 0.401772f, 0.637457f, 0.459313f, 0.677889f, 0.659685f, 0.383362f, 0.379251f, 0.453763f, 0.401437f, 0.555998f, 0.186013f, 0.455395f, 0.406430f, 0.395553f, 0.526708f, 0.320193f, 0.484448f, 0.577368f, 0.190770f, 0.462801f, 0.384114f, 0.403607f, 0.534057f, 0.326255f, 0.496504f, 0.563586f, 0.180264f, 0.464196f, 0.384055f, 0.385514f, 0.537212f, 0.338047f, 0.485235f, 0.555800f, 0.177971f, 0.457827f, 0.377928f, 0.372441f, 0.541035f, 0.343750f, 0.483692f, 0.705313f, 0.467049f, 0.389698f, 0.530555f, 0.548003f, 0.637789f, 0.501241f, 0.493046f, 0.692096f, 0.474284f, 0.375588f, 0.530258f, 0.507811f, 0.618987f, 0.468782f, 0.502795f, 0.703758f, 0.479856f, 0.374269f, 0.518477f, 0.518286f, 0.631821f, 0.502535f, 0.509264f, 0.689539f, 0.474638f, 0.374363f, 0.519131f, 0.519441f, 0.644891f, 0.480984f, 0.490645f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * kv_sequence_length); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), m, std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnPastPresentBasic) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 4; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 1; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1}; + std::vector k = {1, 0, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 1, 2}; + std::vector v = {0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 2}; + std::vector m = {1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1}; + std::vector past_key = {1, 2, 1, 1}; + std::vector past_value = {1, 1, 2, 1}; + std::vector y = {1.2691493034362793, 1.0, 1.0774023532867432, 1.0, 0.9539920091629028, 1.0, 0.4988941252231598, 1.6121423244476318, 0.8137872219085693, 1.3673334121704102, 0.8579846620559692, 1.2801470756530762}; + std::vector present_key = {1.0, 2.0, 1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0}; + std::vector present_value = {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 2.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnPastPresent) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + std::vector past_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f}; + std::vector past_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f}; + std::vector y = {0.457694f, 0.455757f, 0.445489f, 0.526766f, 0.477853f, 0.608758f, 0.406654f, 0.519316f, 0.444463f, 0.465842f, 0.428262f, 0.540428f, 0.477282f, 0.638667f, 0.474591f, 0.547811f, 0.457420f, 0.470657f, 0.487116f, 0.542242f, 0.482364f, 0.617841f, 0.476829f, 0.557317f, 0.463370f, 0.432599f, 0.412642f, 0.520960f, 0.479831f, 0.589828f, 0.446331f, 0.612812f, 0.585487f, 0.538315f, 0.504264f, 0.615235f, 0.527800f, 0.515899f, 0.536401f, 0.541573f, 0.578147f, 0.544553f, 0.531175f, 0.583502f, 0.528233f, 0.518028f, 0.562917f, 0.588512f, 0.599006f, 0.525119f, 0.535656f, 0.623945f, 0.521523f, 0.515306f, 0.544257f, 0.592741f, 0.600172f, 0.529797f, 0.490615f, 0.601856f, 0.495671f, 0.500725f, 0.555493f, 0.482300f, 0.538304f, 0.469695f, 0.555198f, 0.489711f, 0.521836f, 0.485628f, 0.493937f, 0.562992f, 0.521894f, 0.489056f, 0.584299f, 0.474376f, 0.493005f, 0.475963f, 0.460919f, 0.567615f, 0.547787f, 0.466202f, 0.536014f, 0.473239f, 0.485554f, 0.498408f, 0.501733f, 0.586437f, 0.517314f, 0.440046f, 0.514271f, 0.545266f, 0.487437f, 0.481043f, 0.518498f, 0.568266f, 0.514357f, 0.572526f, 0.423650f, 0.474643f, 0.492550f, 0.533325f, 0.512998f, 0.452411f, 0.526065f, 0.535346f, 0.407074f, 0.502433f, 0.501283f, 0.528505f, 0.510491f, 0.402870f, 0.516862f, 0.596280f, 0.397160f, 0.469242f, 0.458194f, 0.537358f, 0.510243f, 0.439715f, 0.530736f, 0.580630f, 0.437646f, 0.462414f, 0.484492f, 0.477003f, 0.476393f, 0.431391f, 0.481805f, 0.420751f, 0.544359f, 0.440140f, 0.533953f, 0.453877f, 0.460864f, 0.446440f, 0.454282f, 0.416850f, 0.494072f, 0.462208f, 0.524801f, 0.453293f, 0.493179f, 0.462526f, 0.489181f, 0.452340f, 0.570383f, 0.422193f, 0.524420f, 0.468229f, 0.489729f, 0.444768f, 0.534646f, 0.457197f, 0.522207f, 0.400594f, 0.538509f, 0.489581f, 0.457599f, 0.488340f, 0.549355f, 0.482543f, 0.431908f, 0.352921f, 0.633369f, 0.690998f, 0.314418f, 0.542520f, 0.580878f, 0.489810f, 0.451832f, 0.346453f, 0.599024f, 0.630982f, 0.310195f, 0.532405f, 0.568864f, 0.486514f, 0.432211f, 0.345150f, 0.586195f, 0.659745f, 0.269926f, 0.528033f, 0.509392f, 0.511314f, 0.378251f, 0.319656f, 0.601292f, 0.726670f, 0.338636f, 0.564731f}; + std::vector present_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector present_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} +TEST(AttentionTest, Attention4DAttnIsCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector y = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, + 0.466662f, 0.404919f, 0.206397f, 0.494597f, 0.469075f, 0.517016f, 0.457503f, 0.620147f, + 0.455868f, 0.401850f, 0.222910f, 0.498051f, 0.398273f, 0.458905f, 0.484206f, 0.678309f, + 0.428625f, 0.565862f, 0.420294f, 0.361176f, 0.366713f, 0.456673f, 0.367244f, 0.565962f, + 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, + 0.340486f, 0.554859f, 0.357655f, 0.654648f, 0.303360f, 0.468544f, 0.410813f, 0.359175f, 0.539688f, 0.388773f, 0.469414f, 0.709710f, 0.362709f, 0.429548f, 0.533266f, 0.281177f, 0.507994f, 0.419524f, 0.523713f, 0.531125f, 0.334381f, 0.418885f, 0.553995f, 0.441341f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.440199f, 0.552865f, 0.234100f, 0.465348f, 0.108484f, 0.789824f, 0.596633f, 0.505260f, 0.521296f, 0.529090f, 0.243612f, 0.596347f, 0.178938f, 0.704410f, 0.541649f, 0.663573f, 0.447473f, 0.471171f, 0.330193f, 0.440955f, 0.264086f, 0.669717f, 0.497800f, 0.570196f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.666526f, 0.680385f, 0.769414f, 0.846562f, 0.211277f, 0.124523f, 0.362721f, 0.528572f, 0.722160f, 0.763995f, 0.843738f, 0.695165f, 0.266952f, 0.132048f, 0.481567f, 0.579821f, 0.766651f, 0.587935f, 0.750237f, 0.660460f, 0.262872f, 0.142580f, 0.578552f, 0.432957f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.528620f, 0.173138f, 0.496913f, 0.687855f, 0.473097f, 0.565422f, 0.353939f, 0.499403f, 0.683711f, 0.156556f, 0.606089f, 0.441246f, 0.472192f, 0.507007f, 0.441957f, 0.457522f, 0.599108f, 0.136602f, 0.579971f, 0.504480f, 0.443634f, 0.456725f, 0.392707f, 0.395364f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.755483f, 0.623352f, 0.283909f, 0.615250f, 0.377633f, 0.544918f, 0.585578f, 0.822309f, 0.598965f, 0.584465f, 0.234792f, 0.460114f, 0.268955f, 0.677291f, 0.392800f, 0.607946f, 0.577946f, 0.470810f, 0.371437f, 0.510227f, 0.419904f, 0.671214f, 0.345365f, 0.567849f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnIsCausalBasic) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 3; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector k = {1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector v = {0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector y = {0.0, 1.0, 0.6697615385055542, 1.0, 0.8022241592407227, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 3; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector k = {1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector v = {0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector y = {0.0, 1.0, 0.6697615385055542, 1.0, 0.8022241592407227, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 4; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.f, 1.f, 0.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector k = {1.f, 0.f, 1.f, 1.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 2}; + std::vector v = {0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 2}; + std::vector y = {0.0, 1.0, 0.6697615385055542, 1.0, 0.85997074842453, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 10; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 6, 10} + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + // {4, 18} + std::vector m = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f}; + // {2, 3, 12, 8} + std::vector past_key = {0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f}; + // {2, 3, 12, 10} + std::vector past_value = {0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f}; + // {2, 3, 4, 10} + std::vector y = {0.484245f, 0.491594f, 0.566765f, 0.698646f, 0.412717f, 0.529403f, 0.546576f, 0.477395f, 0.458289f, 0.526034f, 0.507523f, 0.501791f, 0.516438f, 0.666451f, 0.374304f, 0.541111f, 0.568747f, 0.520548f, 0.473141f, 0.519258f, 0.498172f, 0.514510f, 0.527296f, 0.682262f, 0.396020f, 0.501123f, 0.530399f, 0.488510f, 0.446185f, 0.542778f, 0.511414f, 0.485035f, 0.517123f, 0.684857f, 0.389196f, 0.515658f, 0.556560f, 0.526948f, 0.446624f, 0.513224f, 0.518960f, 0.522651f, 0.541202f, 0.520867f, 0.515921f, 0.390582f, 0.438142f, 0.557164f, 0.504964f, 0.579576f, 0.465363f, 0.569218f, 0.532317f, 0.551877f, 0.490628f, 0.361162f, 0.458657f, 0.568250f, 0.511133f, 0.519196f, 0.508355f, 0.532992f, 0.540742f, 0.536218f, 0.491775f, 0.346055f, 0.430588f, 0.545529f, 0.508855f, 0.534426f, 0.477742f, 0.559174f, 0.522186f, 0.518533f, 0.461976f, 0.366468f, 0.455339f, 0.541203f, 0.513318f, 0.516310f, 0.417490f, 0.509893f, 0.590295f, 0.518703f, 0.497346f, 0.569950f, 0.531036f, 0.515108f, 0.551188f, 0.511368f, 0.428004f, 0.470681f, 0.584422f, 0.481287f, 0.526080f, 0.523233f, 0.457405f, 0.481407f, 0.573666f, 0.505292f, 0.455096f, 0.488968f, 0.602769f, 0.494229f, 0.506703f, 0.531687f, 0.494376f, 0.500014f, 0.557185f, 0.516992f, 0.456706f, 0.474918f, 0.604858f, 0.507587f, 0.469668f, 0.505480f, 0.509594f, 0.501727f, 0.579587f, 0.520784f, 0.493654f, 0.421248f, 0.447569f, 0.512260f, 0.385047f, 0.415280f, 0.512025f, 0.438027f, 0.412472f, 0.566399f, 0.521616f, 0.425188f, 0.438491f, 0.497757f, 0.359007f, 0.354674f, 0.526893f, 0.436536f, 0.365545f, 0.598360f, 0.539148f, 0.414424f, 0.449425f, 0.469435f, 0.387864f, 0.398897f, 0.495746f, 0.442739f, 0.325650f, 0.565445f, 0.528260f, 0.427462f, 0.414675f, 0.471898f, 0.383976f, 0.365848f, 0.492247f, 0.412142f, 0.346633f, 0.594105f, 0.607776f, 0.533772f, 0.468197f, 0.372208f, 0.489865f, 0.443200f, 0.545535f, 0.493389f, 0.551969f, 0.423333f, 0.646158f, 0.558704f, 0.439156f, 0.446620f, 0.451905f, 0.487079f, 0.528236f, 0.561621f, 0.598777f, 0.437840f, 0.621812f, 0.514033f, 0.477342f, 0.401848f, 0.471414f, 0.463881f, 0.530019f, 0.506494f, 0.559079f, 0.454743f, 0.645883f, 0.532612f, 0.484295f, 0.429611f, 0.471412f, 0.470437f, 0.545854f, 0.509529f, 0.591309f, 0.463628f, 0.463473f, 0.428821f, 0.487303f, 0.522334f, 0.486353f, 0.659896f, 0.556700f, 0.410148f, 0.569697f, 0.495767f, 0.437882f, 0.420329f, 0.503654f, 0.527284f, 0.465816f, 0.623204f, 0.569190f, 0.413123f, 0.554353f, 0.518062f, 0.492239f, 0.410378f, 0.461884f, 0.498402f, 0.509016f, 0.682983f, 0.535407f, 0.412562f, 0.551318f, 0.498037f, 0.470375f, 0.407394f, 0.460899f, 0.496268f, 0.464923f, 0.672767f, 0.533764f, 0.427543f, 0.577909f, 0.506939f}; + // {2, 3, 18, 8} + std::vector present_key = {0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 10} + std::vector present_value = {0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DGqaAttnMask) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + // {2, 9, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 3, 6, 8} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 3, 6, 8} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {4, 6} + std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f}; + // {2, 9, 4, 8} + std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.508992f, 0.253478f, 0.553979f, 0.466355f, 0.398637f, 0.412493f, 0.495810f, 0.677675f, 0.521609f, 0.278997f, 0.564189f, 0.434417f, 0.448085f, 0.467205f, 0.567856f, 0.664713f, 0.490146f, 0.261321f, 0.560582f, 0.424598f, 0.450318f, 0.467336f, 0.520983f, 0.720798f, 0.516095f, 0.264495f, 0.577940f, 0.475340f, 0.444145f, 0.477909f, 0.485663f, 0.672846f, 0.499389f, 0.402198f, 0.520218f, 0.550550f, 0.481065f, 0.730488f, 0.492535f, 0.392315f, 0.436722f, 0.398514f, 0.497457f, 0.502270f, 0.520993f, 0.730472f, 0.565429f, 0.380282f, 0.461226f, 0.392968f, 0.536035f, 0.505191f, 0.446570f, 0.751253f, 0.478584f, 0.389036f, 0.423738f, 0.443828f, 0.554323f, 0.462607f, 0.476656f, 0.733228f, 0.482219f, 0.411910f, 0.620556f, 0.662948f, 0.349409f, 0.482541f, 0.537250f, 0.351544f, 0.734285f, 0.397172f, 0.689500f, 0.637077f, 0.320710f, 0.470914f, 0.526307f, 0.312878f, 0.775762f, 0.384457f, 0.696615f, 0.681034f, 0.324383f, 0.459632f, 0.539497f, 0.317950f, 0.709736f, 0.320698f, 0.671696f, 0.676830f, 0.332387f, 0.453234f, 0.578648f, 0.345084f, 0.685369f, 0.328092f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.500042f, 0.410507f, 0.521381f, 0.553244f, 0.459062f, 0.719706f, 0.476571f, 0.395052f, 0.429926f, 0.408857f, 0.507006f, 0.493937f, 0.529878f, 0.728873f, 0.571495f, 0.376256f, 0.453676f, 0.380482f, 0.526100f, 0.496696f, 0.457383f, 0.761933f, 0.486657f, 0.396608f, 0.435748f, 0.432822f, 0.531763f, 0.482255f, 0.477046f, 0.726381f, 0.487480f, 0.416572f, 0.626676f, 0.683736f, 0.340657f, 0.475002f, 0.549981f, 0.353311f, 0.740157f, 0.378827f, 0.681403f, 0.636622f, 0.324593f, 0.469088f, 0.537323f, 0.321344f, 0.762506f, 0.384239f, 0.693108f, 0.683351f, 0.329873f, 0.460504f, 0.555115f, 0.325379f, 0.694659f, 0.316422f, 0.677285f, 0.670298f, 0.329724f, 0.456327f, 0.567533f, 0.337560f, 0.701396f, 0.336191f, 0.515940f, 0.251020f, 0.562035f, 0.442479f, 0.405802f, 0.410828f, 0.519841f, 0.686781f, 0.522057f, 0.285013f, 0.562761f, 0.453472f, 0.451971f, 0.481286f, 0.558322f, 0.649971f, 0.486787f, 0.258011f, 0.557963f, 0.426743f, 0.442028f, 0.457034f, 0.510534f, 0.724945f, 0.498901f, 0.272090f, 0.572650f, 0.467930f, 0.465335f, 0.506181f, 0.484559f, 0.690090f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.379088f, 0.582413f, 0.414383f, 0.571800f, 0.613176f, 0.687631f, 0.185596f, 0.656867f, 0.390452f, 0.532452f, 0.407547f, 0.564799f, 0.606499f, 0.653258f, 0.176547f, 0.698038f, 0.410398f, 0.604586f, 0.442972f, 0.497533f, 0.595085f, 0.732265f, 0.187201f, 0.663169f, 0.448716f, 0.590302f, 0.411879f, 0.518449f, 0.636722f, 0.695827f, 0.154292f, 0.666828f, 0.458054f, 0.608582f, 0.430376f, 0.316371f, 0.547620f, 0.542559f, 0.542043f, 0.556297f, 0.468371f, 0.559154f, 0.465195f, 0.344099f, 0.482571f, 0.527115f, 0.527529f, 0.616254f, 0.494566f, 0.605555f, 0.432360f, 0.382197f, 0.466678f, 0.556031f, 0.459313f, 0.588575f, 0.532798f, 0.597684f, 0.412305f, 0.393400f, 0.462773f, 0.491821f, 0.483189f, 0.593919f, 0.569241f, 0.793791f, 0.532988f, 0.300026f, 0.393843f, 0.327085f, 0.448199f, 0.457416f, 0.493302f, 0.725336f, 0.512066f, 0.327500f, 0.404238f, 0.351704f, 0.507818f, 0.477990f, 0.479548f, 0.756083f, 0.511730f, 0.309729f, 0.366024f, 0.338031f, 0.503335f, 0.472352f, 0.473026f, 0.696816f, 0.543129f, 0.374608f, 0.335432f, 0.360978f, 0.486364f, 0.531799f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.453034f, 0.596627f, 0.417365f, 0.314318f, 0.554269f, 0.518967f, 0.550250f, 0.556252f, 0.494918f, 0.587774f, 0.467566f, 0.350222f, 0.481994f, 0.538857f, 0.525631f, 0.605359f, 0.497486f, 0.608472f, 0.429145f, 0.384532f, 0.466790f, 0.554752f, 0.457698f, 0.586510f, 0.548577f, 0.604359f, 0.398097f, 0.414429f, 0.448200f, 0.485158f, 0.461395f, 0.593015f, 0.563470f, 0.796184f, 0.532783f, 0.293209f, 0.408910f, 0.327450f, 0.438028f, 0.447011f, 0.493041f, 0.739603f, 0.496957f, 0.311881f, 0.389768f, 0.352503f, 0.530113f, 0.476738f, 0.484897f, 0.752985f, 0.511921f, 0.312174f, 0.370408f, 0.339775f, 0.504061f, 0.473793f, 0.487978f, 0.714687f, 0.538817f, 0.358426f, 0.348908f, 0.355820f, 0.481380f, 0.516214f, 0.370872f, 0.602034f, 0.400225f, 0.611090f, 0.630508f, 0.662527f, 0.162489f, 0.658299f, 0.378734f, 0.537283f, 0.412214f, 0.570032f, 0.601452f, 0.653569f, 0.179932f, 0.693105f, 0.411981f, 0.605715f, 0.448022f, 0.481469f, 0.585099f, 0.748463f, 0.195177f, 0.671915f, 0.442141f, 0.581881f, 0.393362f, 0.555388f, 0.650764f, 0.665937f, 0.141141f, 0.675100f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 9, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 3, 6, 8} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 3, 6, 8} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {4, 18} + std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f}; + // {2, 3, 12, 8} + std::vector past_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f}; + // {2, 3, 12, 8} + std::vector past_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f}; + // {2, 9, 4, 8} + std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.440810f, 0.437705f, 0.476508f, 0.320820f, 0.605191f, 0.640150f, 0.306216f, 0.610947f, 0.485794f, 0.448216f, 0.485639f, 0.323744f, 0.594446f, 0.646597f, 0.321742f, 0.605751f, 0.501858f, 0.445502f, 0.487899f, 0.384660f, 0.597134f, 0.616430f, 0.331401f, 0.566459f, 0.502522f, 0.409965f, 0.526639f, 0.348601f, 0.565200f, 0.586558f, 0.325044f, 0.603422f, 0.450250f, 0.368009f, 0.550911f, 0.460338f, 0.523907f, 0.508816f, 0.575624f, 0.426601f, 0.472310f, 0.372844f, 0.517852f, 0.431688f, 0.551555f, 0.527657f, 0.600578f, 0.473069f, 0.456633f, 0.442035f, 0.539875f, 0.437863f, 0.540202f, 0.499608f, 0.556470f, 0.419831f, 0.463081f, 0.416724f, 0.526389f, 0.458654f, 0.540120f, 0.551554f, 0.569399f, 0.447102f, 0.534296f, 0.597655f, 0.509699f, 0.487167f, 0.607438f, 0.426383f, 0.522794f, 0.458435f, 0.510147f, 0.622761f, 0.501724f, 0.453386f, 0.629671f, 0.434103f, 0.582477f, 0.437681f, 0.520031f, 0.568543f, 0.525216f, 0.490370f, 0.571745f, 0.428629f, 0.572995f, 0.460086f, 0.533607f, 0.614962f, 0.474130f, 0.456345f, 0.576467f, 0.448127f, 0.599211f, 0.432252f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.452598f, 0.361594f, 0.550919f, 0.455099f, 0.530404f, 0.519313f, 0.588655f, 0.431890f, 0.464325f, 0.389636f, 0.515359f, 0.429087f, 0.540767f, 0.518376f, 0.586627f, 0.471074f, 0.458527f, 0.422216f, 0.537762f, 0.434123f, 0.550956f, 0.507704f, 0.564828f, 0.421548f, 0.463044f, 0.407985f, 0.523093f, 0.473684f, 0.542663f, 0.551348f, 0.576783f, 0.448743f, 0.546208f, 0.621128f, 0.501647f, 0.468191f, 0.612298f, 0.425183f, 0.549241f, 0.447622f, 0.519355f, 0.619636f, 0.487775f, 0.444259f, 0.625749f, 0.430264f, 0.584338f, 0.436887f, 0.521021f, 0.572716f, 0.522539f, 0.486440f, 0.581317f, 0.429079f, 0.579691f, 0.455426f, 0.526431f, 0.604615f, 0.476481f, 0.469814f, 0.588766f, 0.445640f, 0.609160f, 0.437785f, 0.443498f, 0.439338f, 0.487424f, 0.310942f, 0.607341f, 0.630362f, 0.312591f, 0.621999f, 0.483917f, 0.446308f, 0.477454f, 0.331028f, 0.592608f, 0.653297f, 0.322368f, 0.599377f, 0.497354f, 0.443447f, 0.477781f, 0.384002f, 0.591587f, 0.610287f, 0.328537f, 0.567630f, 0.499369f, 0.421961f, 0.536492f, 0.345379f, 0.586450f, 0.600541f, 0.312965f, 0.609437f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.557802f, 0.585925f, 0.426858f, 0.464044f, 0.585251f, 0.557395f, 0.433327f, 0.615342f, 0.534368f, 0.573723f, 0.426393f, 0.518102f, 0.586735f, 0.513129f, 0.371969f, 0.636735f, 0.544166f, 0.588469f, 0.433470f, 0.481894f, 0.595019f, 0.533156f, 0.396519f, 0.608115f, 0.547125f, 0.604473f, 0.441984f, 0.469765f, 0.599107f, 0.561685f, 0.347618f, 0.563457f, 0.507550f, 0.485293f, 0.545846f, 0.408434f, 0.482538f, 0.532314f, 0.498883f, 0.525126f, 0.514603f, 0.471457f, 0.539705f, 0.362410f, 0.490158f, 0.513690f, 0.494170f, 0.496909f, 0.492936f, 0.506153f, 0.565865f, 0.364727f, 0.508899f, 0.516217f, 0.558362f, 0.556920f, 0.530472f, 0.521715f, 0.554673f, 0.363830f, 0.509086f, 0.511590f, 0.552396f, 0.541486f, 0.572145f, 0.551531f, 0.471964f, 0.485188f, 0.555030f, 0.493247f, 0.376875f, 0.429387f, 0.580540f, 0.550944f, 0.435664f, 0.480675f, 0.544997f, 0.488698f, 0.344985f, 0.464878f, 0.593774f, 0.541202f, 0.484834f, 0.497316f, 0.509364f, 0.500045f, 0.357235f, 0.448933f, 0.565242f, 0.546653f, 0.459790f, 0.481954f, 0.514950f, 0.516297f, 0.344285f, 0.454476f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.507041f, 0.473640f, 0.547768f, 0.413960f, 0.490513f, 0.534377f, 0.497277f, 0.517772f, 0.531394f, 0.489105f, 0.531671f, 0.369343f, 0.486462f, 0.501787f, 0.494220f, 0.493498f, 0.485968f, 0.510301f, 0.559766f, 0.361474f, 0.507888f, 0.518858f, 0.564300f, 0.561990f, 0.537984f, 0.527982f, 0.539571f, 0.366920f, 0.498313f, 0.505709f, 0.538027f, 0.541246f, 0.585733f, 0.565800f, 0.441346f, 0.476255f, 0.556453f, 0.497693f, 0.363246f, 0.426799f, 0.578484f, 0.556489f, 0.436699f, 0.481177f, 0.549473f, 0.484153f, 0.355910f, 0.462010f, 0.590951f, 0.542803f, 0.470954f, 0.488994f, 0.512707f, 0.511876f, 0.358555f, 0.455953f, 0.559449f, 0.546003f, 0.462900f, 0.471080f, 0.517298f, 0.519225f, 0.345016f, 0.449149f, 0.526624f, 0.606761f, 0.427660f, 0.480775f, 0.577420f, 0.538850f, 0.426959f, 0.625509f, 0.530502f, 0.585784f, 0.432234f, 0.516800f, 0.584937f, 0.514154f, 0.373726f, 0.623740f, 0.550470f, 0.585577f, 0.436483f, 0.474799f, 0.594100f, 0.540052f, 0.402520f, 0.607686f, 0.537556f, 0.609680f, 0.439490f, 0.477886f, 0.602656f, 0.542957f, 0.350394f, 0.574553f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; + // {2, 3, 18, 8} + std::vector present_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 3, 18, 8} + std::vector present_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 6, 8} + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + // {2, 3, 12, 8} + std::vector past_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f}; + // {2, 3, 12, 8} + std::vector past_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 8} + std::vector y = {0.457694f, 0.455757f, 0.445489f, 0.526766f, 0.477853f, 0.608758f, 0.406654f, 0.519316f, 0.444463f, 0.465842f, 0.428262f, 0.540428f, 0.477282f, 0.638667f, 0.474591f, 0.547811f, 0.457420f, 0.470657f, 0.487116f, 0.542242f, 0.482364f, 0.617841f, 0.476829f, 0.557317f, 0.463370f, 0.432599f, 0.412642f, 0.520960f, 0.479831f, 0.589828f, 0.446331f, 0.612812f, 0.585487f, 0.538315f, 0.504264f, 0.615235f, 0.527800f, 0.515899f, 0.536401f, 0.541573f, 0.578147f, 0.544553f, 0.531175f, 0.583502f, 0.528233f, 0.518028f, 0.562917f, 0.588512f, 0.599006f, 0.525119f, 0.535656f, 0.623945f, 0.521523f, 0.515306f, 0.544257f, 0.592741f, 0.600172f, 0.529797f, 0.490615f, 0.601856f, 0.495671f, 0.500725f, 0.555493f, 0.482300f, 0.538304f, 0.469695f, 0.555198f, 0.489711f, 0.521836f, 0.485628f, 0.493937f, 0.562992f, 0.521894f, 0.489056f, 0.584299f, 0.474376f, 0.493005f, 0.475963f, 0.460919f, 0.567615f, 0.547787f, 0.466202f, 0.536014f, 0.473239f, 0.485554f, 0.498408f, 0.501733f, 0.586437f, 0.517314f, 0.440046f, 0.514271f, 0.545266f, 0.487437f, 0.481043f, 0.518498f, 0.568266f, 0.514357f, 0.572526f, 0.423650f, 0.474643f, 0.492550f, 0.533325f, 0.512998f, 0.452411f, 0.526065f, 0.535346f, 0.407074f, 0.502433f, 0.501283f, 0.528505f, 0.510491f, 0.402870f, 0.516862f, 0.596280f, 0.397160f, 0.469242f, 0.458194f, 0.537358f, 0.510243f, 0.439715f, 0.530736f, 0.580630f, 0.437646f, 0.462414f, 0.484492f, 0.477003f, 0.476393f, 0.431391f, 0.481805f, 0.420751f, 0.544359f, 0.440140f, 0.533953f, 0.453877f, 0.460864f, 0.446440f, 0.454282f, 0.416850f, 0.494072f, 0.462208f, 0.524801f, 0.453293f, 0.493179f, 0.462526f, 0.489181f, 0.452340f, 0.570383f, 0.422193f, 0.524420f, 0.468229f, 0.489729f, 0.444768f, 0.534646f, 0.457197f, 0.522207f, 0.400594f, 0.538509f, 0.489581f, 0.457599f, 0.488340f, 0.549355f, 0.482543f, 0.431908f, 0.352921f, 0.633369f, 0.690998f, 0.314418f, 0.542520f, 0.580878f, 0.489810f, 0.451832f, 0.346453f, 0.599024f, 0.630982f, 0.310195f, 0.532405f, 0.568864f, 0.486514f, 0.432211f, 0.345150f, 0.586195f, 0.659745f, 0.269926f, 0.528033f, 0.509392f, 0.511314f, 0.378251f, 0.319656f, 0.601292f, 0.726670f, 0.338636f, 0.564731f}; + // {2, 3, 18, 8} + std::vector present_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector present_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector qk_matmul = {0.820140f, 1.059902f, 0.757718f, 0.881749f, 0.858141f, 1.036822f, 0.884175f, 0.745137f, 0.702161f, 0.857424f, 0.931616f, 0.810373f, 0.765101f, 0.618886f, 0.985434f, 1.031954f, 0.880308f, 0.622930f, 0.763532f, 0.857065f, 0.740183f, 0.789191f, 0.647322f, 0.909152f, 0.686916f, 0.854634f, 0.616661f, 0.909399f, 0.999737f, 0.690372f, 0.633938f, 0.397958f, 0.865367f, 0.924445f, 0.867537f, 0.569419f, 0.980506f, 1.169838f, 1.017614f, 1.046616f, 0.926423f, 1.190621f, 1.081360f, 0.859412f, 0.668530f, 0.881618f, 1.122157f, 0.778354f, 0.913560f, 0.629977f, 1.123444f, 1.261700f, 1.171818f, 0.666636f, 0.732417f, 0.806783f, 0.671492f, 0.704470f, 0.679564f, 0.856373f, 0.747101f, 0.574466f, 0.511335f, 0.570812f, 0.772065f, 0.486530f, 0.626328f, 0.451866f, 0.718409f, 0.895540f, 0.694231f, 0.503419f, 0.531406f, 0.847033f, 0.878291f, 0.737390f, 0.926101f, 1.027148f, 0.731989f, 0.720755f, 0.637853f, 0.523248f, 0.924757f, 0.757182f, 0.669580f, 0.979738f, 0.580251f, 1.052969f, 1.255782f, 0.775240f, 0.284305f, 0.708099f, 0.458294f, 0.381689f, 0.754442f, 0.688000f, 0.675486f, 0.683084f, 0.468356f, 0.518191f, 0.554623f, 0.658507f, 0.571695f, 0.630510f, 0.528123f, 0.531325f, 0.767081f, 0.532916f, 0.348042f, 0.636357f, 0.445687f, 0.399611f, 0.727809f, 0.686446f, 0.593512f, 0.523768f, 0.360500f, 0.423699f, 0.527520f, 0.714839f, 0.553231f, 0.662379f, 0.517964f, 0.485448f, 0.809493f, 0.494930f, 0.274371f, 0.437410f, 0.411925f, 0.342756f, 0.545288f, 0.529269f, 0.533905f, 0.380022f, 0.436475f, 0.301469f, 0.529214f, 0.526297f, 0.502613f, 0.503063f, 0.430358f, 0.614318f, 0.557536f, 0.523195f, 0.627666f, 0.646350f, 0.711912f, 0.578261f, 0.510271f, 0.666607f, 0.609787f, 0.652893f, 0.673018f, 0.618551f, 0.787326f, 1.094408f, 0.693321f, 0.857913f, 0.604598f, 0.781784f, 0.506659f, 0.587050f, 0.797275f, 0.415388f, 0.596291f, 0.560429f, 0.353030f, 0.474825f, 0.499545f, 0.677266f, 0.512789f, 0.749157f, 0.460399f, 0.860298f, 0.559970f, 0.647591f, 0.385551f, 0.412029f, 0.286456f, 0.386895f, 0.466306f, 0.448868f, 0.485777f, 0.485511f, 0.524956f, 0.380963f, 0.659871f, 0.495008f, 0.515935f, 0.440779f, 0.441189f, 0.658574f, 0.476000f, 0.713140f, 0.389744f, 0.417265f, 0.369560f, 0.531347f, 0.798962f, 0.607254f, 0.635098f, 0.675595f, 0.504633f, 0.579773f, 0.825966f, 0.745334f, 0.850824f, 0.713222f, 0.417185f, 0.949167f, 0.538440f, 0.917125f, 0.311825f, 0.475121f, 0.418353f, 0.698230f, 0.553783f, 0.653118f, 0.479333f, 0.683333f, 0.611400f, 0.926136f, 0.937356f, 1.079461f, 0.500571f, 0.941776f, 0.571910f, 0.891547f, 0.471507f, 0.728790f, 0.757396f, 0.784496f, 0.757036f, 0.999690f, 0.542418f, 0.841219f, 0.709393f, 0.945488f, 0.605568f, 1.000231f, 0.913339f, 1.138695f, 0.564313f, 1.077245f, 0.676031f, 0.922692f, 0.458828f, 0.738062f, 0.805418f, 0.864807f, 0.792745f, 1.025324f, 0.755005f, 0.867548f, 0.634732f, 0.905661f, 0.776584f, 1.184950f, 1.140206f, 1.327115f, 0.665969f, 1.196436f, 0.815515f, 1.206247f, 0.621079f, 0.985172f, 0.879408f, 1.054329f, 1.023972f, 1.311348f, 0.430584f, 0.838594f, 0.577089f, 0.887826f, 0.637326f, 0.838023f, 0.852760f, 0.930619f, 0.596678f, 1.004560f, 0.556861f, 0.837758f, 0.499217f, 0.764351f, 0.711010f, 0.774022f, 0.933743f, 0.958043f, 0.587815f, 0.233866f, 0.638163f, 0.785593f, 0.772991f, 0.770025f, 0.862170f, 0.414778f, 0.518855f, 0.729107f, 0.683017f, 0.903488f, 0.660502f, 0.396731f, 0.558027f, 0.342514f, 0.418391f, 0.680441f, 0.667967f, 0.467863f, 0.921835f, 0.926976f, 0.997494f, 1.115404f, 1.154781f, 0.618698f, 0.888651f, 1.045274f, 1.019208f, 1.253905f, 0.983391f, 0.622483f, 0.921609f, 0.369652f, 0.702290f, 1.012872f, 0.884131f, 0.593858f, 0.802401f, 1.081408f, 1.169599f, 1.146572f, 1.132834f, 0.866719f, 1.021105f, 0.884109f, 1.029369f, 1.321895f, 0.973822f, 0.871383f, 1.125121f, 0.518882f, 0.912889f, 0.876105f, 0.555648f, 0.496401f, 0.582726f, 0.730206f, 0.806009f, 0.858020f, 0.827912f, 0.515117f, 0.715055f, 0.533599f, 0.810529f, 0.887599f, 0.607516f, 0.668702f, 0.905358f, 0.279895f, 0.740854f, 0.538839f, 0.824322f, 0.920016f, 0.791579f, 0.844334f, 0.618349f, 0.989377f, 1.120477f, 0.554956f, 0.683589f, 1.280705f, 0.957804f, 0.833027f, 0.763301f, 0.786487f, 0.915324f, 0.941565f, 0.777569f, 1.361176f, 0.508790f, 0.424516f, 0.573465f, 0.405641f, 0.526471f, 0.626492f, 0.534790f, 0.428795f, 0.388423f, 0.689702f, 0.260757f, 0.438301f, 0.479575f, 0.640056f, 0.682344f, 0.519170f, 0.436916f, 0.774498f, 0.534469f, 0.702171f, 0.684503f, 0.648164f, 0.754539f, 0.828688f, 0.623366f, 0.500542f, 0.560133f, 1.098588f, 0.498203f, 0.465793f, 0.656601f, 0.886137f, 0.751770f, 0.533794f, 0.483658f, 1.098963f, 0.733365f, 0.808374f, 0.764603f, 0.755506f, 0.638693f, 0.946285f, 1.001370f, 0.578989f, 0.603487f, 1.074992f, 0.697424f, 0.812599f, 0.717330f, 0.770067f, 1.006811f, 0.783151f, 0.647946f, 1.193171f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 0, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + + qk_matmul = std::vector{1.786287f, 1.851782f, 1.433406f, 1.126638f, 1.074598f, 1.202869f, 1.806932f, 1.039214f, 1.155254f, 1.351381f, 1.709788f, 1.654608f, 0.904174f, 1.045790f, 1.828289f, 1.849986f, 0.982722f, 0.779313f, 1.067731f, 0.932425f, 1.164846f, 0.896809f, 1.215540f, 1.155709f, 1.283348f, 0.972161f, 1.592545f, 1.841960f, 1.391534f, 0.932551f, 0.884336f, 0.881353f, 0.905360f, 1.564150f, 1.275840f, 0.946826f, 1.789871f, 1.878873f, 1.971947f, 1.398552f, 1.823965f, 1.960587f, 1.438784f, 1.481077f, 0.957099f, 1.756017f, 1.234584f, 0.990787f, 1.096593f, 1.033003f, 1.868677f, 1.788607f, 1.659495f, 0.667182f, 1.157819f, 0.870338f, 0.879745f, 1.636864f, 0.894962f, 1.714711f, 1.549994f, 0.733612f, 1.117046f, 0.686474f, 1.499953f, 1.123992f, 1.438267f, 0.931251f, 1.633272f, 0.944889f, 0.987120f, 1.218472f, 1.497553f, 1.638913f, 1.553980f, 0.982279f, 1.142558f, 1.193196f, 1.654746f, 1.014832f, 1.090946f, 1.017206f, 1.702928f, 1.601417f, 0.808653f, 1.406642f, 1.423106f, 1.871002f, 1.358196f, 0.931623f, 0.588504f, 0.783458f, 0.882957f, 0.489307f, 1.322660f, 0.934557f, 1.271919f, 0.800610f, 1.444240f, 1.450752f, 0.946420f, 0.900686f, 0.822093f, 1.113904f, 0.568116f, 1.171030f, 1.175384f, 0.910323f, 1.157407f, 1.345392f, 1.400021f, 0.751548f, 1.625352f, 1.456414f, 0.950937f, 1.145433f, 0.649070f, 1.298100f, 0.639947f, 0.927273f, 0.736265f, 1.065406f, 1.263197f, 1.012355f, 1.297169f, 0.495477f, 0.699773f, 0.500964f, 0.620178f, 1.275150f, 0.760687f, 1.387608f, 1.336798f, 0.539168f, 1.042187f, 0.417132f, 1.257103f, 1.163759f, 1.314552f, 0.982448f, 1.345221f, 0.663667f, 0.850426f, 1.238248f, 1.593812f, 1.438230f, 1.387601f, 0.823150f, 0.726727f, 0.832655f, 1.532544f, 0.946970f, 1.126112f, 1.112509f, 1.565497f, 1.938642f, 0.832394f, 1.284816f, 1.447452f, 1.599816f, 0.609072f, 0.743433f, 1.101475f, 0.490747f, 1.020954f, 0.668047f, 0.921248f, 0.721382f, 1.095978f, 0.794792f, 1.488673f, 1.681718f, 0.852196f, 1.102478f, 0.810369f, 1.130985f, 0.425544f, 1.051735f, 0.694759f, 0.764302f, 1.275671f, 1.157903f, 1.440112f, 0.837447f, 1.422500f, 1.150930f, 1.017296f, 1.116673f, 0.804505f, 1.315179f, 0.553615f, 0.871008f, 0.659033f, 1.116166f, 1.134977f, 0.944172f, 0.857236f, 0.531893f, 1.224364f, 0.670808f, 0.843351f, 1.607988f, 0.720031f, 1.438111f, 1.628858f, 0.904480f, 1.456536f, 0.828884f, 1.145072f, 1.586629f, 1.350379f, 1.396510f, 1.226688f, 0.524469f, 0.711242f, 1.413283f, 1.519931f, 1.444998f, 1.155023f, 0.928222f, 0.827857f, 1.092185f, 1.860113f, 1.373539f, 0.953664f, 1.435734f, 1.350082f, 1.735783f, 0.610580f, 1.155694f, 1.600251f, 1.602529f, 0.859450f, 1.156073f, 0.846617f, 0.916578f, 1.134056f, 1.053106f, 1.173786f, 1.246788f, 1.509772f, 1.256221f, 1.540197f, 2.009806f, 1.067828f, 1.164871f, 0.709226f, 1.221456f, 0.845411f, 1.504512f, 1.201048f, 1.402731f, 1.564370f, 1.576583f, 1.589067f, 1.257597f, 1.674126f, 1.954917f, 1.497631f, 1.948780f, 0.954539f, 2.070836f, 0.927942f, 1.418681f, 0.804113f, 1.388198f, 1.624642f, 1.581236f, 1.511648f, 1.311894f, 0.855986f, 0.902148f, 0.785342f, 1.820220f, 0.852723f, 1.696361f, 1.655653f, 1.089764f, 1.202390f, 1.120222f, 1.284748f, 1.475221f, 1.311156f, 1.243736f, 1.625873f, 0.823371f, 1.226631f, 1.673096f, 1.553962f, 1.025746f, 1.313852f, 1.030482f, 0.989448f, 0.936074f, 1.784927f, 0.708855f, 0.971949f, 1.223065f, 1.461189f, 1.747723f, 0.799575f, 0.823636f, 1.400882f, 1.160547f, 0.520804f, 0.836825f, 0.972166f, 0.543222f, 1.346498f, 1.034594f, 1.565712f, 1.361961f, 1.751214f, 0.736224f, 1.864534f, 1.977835f, 1.411005f, 1.496084f, 1.233789f, 1.105877f, 0.961602f, 1.009357f, 1.110593f, 1.390279f, 1.693497f, 1.302893f, 1.756735f, 1.433344f, 2.067142f, 1.916540f, 1.490259f, 1.488384f, 1.309675f, 1.758509f, 1.141796f, 1.534330f, 1.156855f, 1.274409f, 1.870354f, 1.045789f, 1.400564f, 0.876651f, 0.981051f, 0.559955f, 0.790979f, 1.662600f, 1.021407f, 1.716358f, 1.630805f, 0.674263f, 1.320767f, 0.649261f, 1.538417f, 1.525061f, 1.419455f, 1.148088f, 1.820221f, 0.329244f, 1.033743f, 1.253892f, 1.790469f, 1.711897f, 1.467268f, 1.089224f, 0.834806f, 1.155425f, 2.043234f, 0.849033f, 1.136683f, 1.774663f, 1.735976f, 1.677263f, 0.902375f, 1.213391f, 1.758179f, 1.759598f, 0.879983f, 1.517559f, 0.812989f, 0.499876f, 0.998129f, 0.513259f, 1.094689f, 0.873050f, 1.131224f, 0.546321f, 1.364307f, 1.622263f, 0.652555f, 0.680481f, 0.729973f, 1.123450f, 0.722337f, 1.158875f, 0.845219f, 1.151906f, 1.343835f, 1.411206f, 1.638837f, 1.000100f, 1.652081f, 1.598655f, 0.980791f, 1.122207f, 0.848703f, 1.972988f, 0.610630f, 0.678227f, 0.839634f, 1.289163f, 1.497003f, 1.060701f, 0.971334f, 1.099509f, 1.158767f, 0.871929f, 0.972856f, 1.687900f, 0.854091f, 1.804623f, 1.804263f, 0.738135f, 1.209199f, 1.190654f, 1.425313f, 1.450061f, 1.529269f, 1.249452f, 1.921674f, 0.832500f, 0.940835f, 1.908224f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 2, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + + qk_matmul = std::vector{0.079204f, 0.084565f, 0.055653f, 0.040951f, 0.038874f, 0.044195f, 0.080856f, 0.037523f, 0.042140f, 0.051271f, 0.073371f, 0.069432f, 0.032783f, 0.037770f, 0.082601f, 0.084413f, 0.035462f, 0.028935f, 0.048528f, 0.042386f, 0.053477f, 0.040903f, 0.056258f, 0.052990f, 0.060205f, 0.044104f, 0.082018f, 0.105252f, 0.067083f, 0.042392f, 0.040396f, 0.040276f, 0.041254f, 0.079722f, 0.059754f, 0.043001f, 0.069900f, 0.076406f, 0.083859f, 0.047264f, 0.072324f, 0.082912f, 0.049204f, 0.051330f, 0.030395f, 0.067573f, 0.040116f, 0.031437f, 0.034945f, 0.032792f, 0.075631f, 0.069811f, 0.061356f, 0.022746f, 0.052157f, 0.039125f, 0.039495f, 0.084209f, 0.040101f, 0.091026f, 0.077202f, 0.034126f, 0.050073f, 0.032554f, 0.073434f, 0.050422f, 0.069041f, 0.041583f, 0.083907f, 0.042154f, 0.043972f, 0.055418f, 0.062936f, 0.072492f, 0.066589f, 0.037594f, 0.044129f, 0.046421f, 0.073649f, 0.038838f, 0.041909f, 0.038930f, 0.077284f, 0.069824f, 0.031602f, 0.057467f, 0.058421f, 0.091429f, 0.054749f, 0.035737f, 0.036234f, 0.044034f, 0.048640f, 0.032812f, 0.075502f, 0.051216f, 0.071766f, 0.044795f, 0.085263f, 0.085820f, 0.051827f, 0.049510f, 0.045768f, 0.061277f, 0.035503f, 0.064879f, 0.065162f, 0.049990f, 0.057976f, 0.069967f, 0.073895f, 0.038636f, 0.092571f, 0.078182f, 0.047161f, 0.057286f, 0.034872f, 0.066735f, 0.034556f, 0.046058f, 0.038050f, 0.052880f, 0.064446f, 0.050148f, 0.066673f, 0.029907f, 0.040424f, 0.033136f, 0.037332f, 0.071867f, 0.042963f, 0.080421f, 0.076436f, 0.034427f, 0.056931f, 0.030472f, 0.070581f, 0.064291f, 0.074755f, 0.053630f, 0.077083f, 0.038991f, 0.046997f, 0.069263f, 0.077018f, 0.065921f, 0.062667f, 0.035637f, 0.032361f, 0.035977f, 0.072441f, 0.040334f, 0.048247f, 0.047595f, 0.074868f, 0.108730f, 0.035968f, 0.056545f, 0.066532f, 0.077482f, 0.028769f, 0.032906f, 0.062422f, 0.033892f, 0.057593f, 0.040467f, 0.052127f, 0.042684f, 0.062080f, 0.045935f, 0.091938f, 0.111515f, 0.048649f, 0.062485f, 0.046656f, 0.064291f, 0.031753f, 0.059393f, 0.041563f, 0.044556f, 0.069887f, 0.062123f, 0.082378f, 0.045090f, 0.080940f, 0.061691f, 0.053974f, 0.059613f, 0.043629f, 0.072703f, 0.033948f, 0.046629f, 0.037722f, 0.059583f, 0.060715f, 0.050168f, 0.045991f, 0.033218f, 0.056448f, 0.032452f, 0.038564f, 0.082843f, 0.034089f, 0.069900f, 0.084590f, 0.040994f, 0.071200f, 0.038010f, 0.052145f, 0.081092f, 0.064029f, 0.067052f, 0.056579f, 0.028034f, 0.033791f, 0.068186f, 0.068271f, 0.063343f, 0.047398f, 0.037780f, 0.034172f, 0.044511f, 0.095935f, 0.058974f, 0.038754f, 0.062758f, 0.057607f, 0.084719f, 0.027499f, 0.047430f, 0.073981f, 0.074150f, 0.035269f, 0.047448f, 0.036752f, 0.039415f, 0.048991f, 0.045181f, 0.050976f, 0.054837f, 0.071332f, 0.055356f, 0.073536f, 0.117610f, 0.045851f, 0.050524f, 0.032034f, 0.053465f, 0.036708f, 0.070958f, 0.052385f, 0.064091f, 0.057214f, 0.057917f, 0.058645f, 0.042099f, 0.063851f, 0.084550f, 0.053520f, 0.084033f, 0.031093f, 0.094942f, 0.030276f, 0.049457f, 0.026750f, 0.047972f, 0.060768f, 0.058187f, 0.054276f, 0.044448f, 0.035207f, 0.036870f, 0.032806f, 0.092340f, 0.035092f, 0.081583f, 0.078329f, 0.044479f, 0.049782f, 0.045855f, 0.054055f, 0.065397f, 0.055502f, 0.051883f, 0.076030f, 0.034077f, 0.051003f, 0.079707f, 0.080020f, 0.047184f, 0.062939f, 0.047408f, 0.045502f, 0.043137f, 0.100811f, 0.034370f, 0.044713f, 0.057477f, 0.072930f, 0.097129f, 0.037633f, 0.038550f, 0.068662f, 0.053994f, 0.028478f, 0.039062f, 0.038495f, 0.025068f, 0.055973f, 0.040975f, 0.069692f, 0.056845f, 0.083897f, 0.030405f, 0.093963f, 0.105236f, 0.059703f, 0.065004f, 0.050007f, 0.044003f, 0.038091f, 0.039954f, 0.044211f, 0.058478f, 0.065917f, 0.044603f, 0.070220f, 0.050818f, 0.095779f, 0.082388f, 0.053794f, 0.053693f, 0.044906f, 0.070345f, 0.037966f, 0.056218f, 0.038542f, 0.043350f, 0.078669f, 0.034491f, 0.049179f, 0.029124f, 0.042079f, 0.027618f, 0.034795f, 0.083187f, 0.043812f, 0.087782f, 0.080584f, 0.030962f, 0.059102f, 0.030197f, 0.073473f, 0.072498f, 0.065232f, 0.049729f, 0.097389f, 0.021927f, 0.044356f, 0.055279f, 0.076017f, 0.070273f, 0.055023f, 0.037702f, 0.029233f, 0.040282f, 0.097878f, 0.029652f, 0.039534f, 0.074825f, 0.071985f, 0.067881f, 0.031276f, 0.042686f, 0.073602f, 0.073706f, 0.030584f, 0.057861f, 0.047710f, 0.034884f, 0.057413f, 0.035354f, 0.063233f, 0.050663f, 0.065586f, 0.036542f, 0.082802f, 0.107169f, 0.040638f, 0.041789f, 0.043909f, 0.065079f, 0.043575f, 0.067425f, 0.049272f, 0.066957f, 0.059910f, 0.064085f, 0.080467f, 0.042483f, 0.081539f, 0.077297f, 0.041671f, 0.048000f, 0.036514f, 0.112392f, 0.028779f, 0.030791f, 0.036185f, 0.056722f, 0.069826f, 0.045137f, 0.041278f, 0.046923f, 0.044357f, 0.033296f, 0.036832f, 0.075295f, 0.032707f, 0.084617f, 0.084586f, 0.029126f, 0.046652f, 0.045794f, 0.057906f, 0.059357f, 0.064250f, 0.048568f, 0.095124f, 0.032009f, 0.035671f, 0.093853f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 3, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + + y = std::vector{0.466021f, 0.458662f, 0.433769f, 0.544055f, 0.483743f, 0.601701f, 0.452252f, 0.558874f, 0.462717f, 0.462769f, 0.429452f, 0.544879f, 0.480609f, 0.607708f, 0.462766f, 0.570020f, 0.465546f, 0.464215f, 0.442318f, 0.544785f, 0.481242f, 0.599103f, 0.465833f, 0.567976f, 0.466527f, 0.450295f, 0.420681f, 0.541622f, 0.478068f, 0.592818f, 0.453533f, 0.586057f, 0.586788f, 0.542723f, 0.521934f, 0.605385f, 0.523076f, 0.515204f, 0.538008f, 0.539990f, 0.580554f, 0.544345f, 0.524057f, 0.593493f, 0.520281f, 0.513084f, 0.549197f, 0.556567f, 0.590750f, 0.536522f, 0.528383f, 0.608365f, 0.523467f, 0.511267f, 0.533588f, 0.556113f, 0.589547f, 0.537869f, 0.512585f, 0.601047f, 0.507374f, 0.511124f, 0.547465f, 0.512627f, 0.537318f, 0.460441f, 0.540844f, 0.491120f, 0.495359f, 0.476360f, 0.487767f, 0.575867f, 0.522542f, 0.469555f, 0.552479f, 0.488850f, 0.498227f, 0.480921f, 0.484224f, 0.563258f, 0.536463f, 0.455656f, 0.529199f, 0.484251f, 0.487531f, 0.482517f, 0.496116f, 0.576080f, 0.527226f, 0.455449f, 0.525402f, 0.516090f, 0.487896f, 0.477256f, 0.499739f, 0.574474f, 0.520127f, 0.578615f, 0.430572f, 0.471035f, 0.475543f, 0.515079f, 0.488231f, 0.438589f, 0.525065f, 0.569547f, 0.430350f, 0.477609f, 0.478081f, 0.515330f, 0.479993f, 0.427992f, 0.520505f, 0.584227f, 0.430333f, 0.470616f, 0.468772f, 0.517313f, 0.478180f, 0.435562f, 0.527655f, 0.580609f, 0.440415f, 0.475648f, 0.474939f, 0.501466f, 0.474016f, 0.433277f, 0.489508f, 0.425301f, 0.542249f, 0.446878f, 0.532601f, 0.462732f, 0.460696f, 0.462333f, 0.480973f, 0.421038f, 0.522864f, 0.446350f, 0.525882f, 0.466933f, 0.459678f, 0.470179f, 0.485580f, 0.431242f, 0.545418f, 0.440407f, 0.527849f, 0.471587f, 0.464982f, 0.464551f, 0.502461f, 0.437563f, 0.528884f, 0.426691f, 0.531206f, 0.480744f, 0.460218f, 0.480733f, 0.543597f, 0.506559f, 0.419551f, 0.372524f, 0.622818f, 0.678228f, 0.309035f, 0.543150f, 0.561392f, 0.501923f, 0.420097f, 0.368626f, 0.607674f, 0.661294f, 0.315077f, 0.540017f, 0.552392f, 0.506226f, 0.409681f, 0.376208f, 0.608944f, 0.674258f, 0.301188f, 0.537046f, 0.536986f, 0.515894f, 0.402735f, 0.364314f, 0.612694f, 0.684161f, 0.315733f, 0.553979f}; + qk_matmul = std::vector{0.945367f, 0.951913f, 0.892363f, 0.809865f, 0.791187f, 0.834528f, 0.947519f, 0.777578f, 0.819487f, 0.874379f, 0.936622f, 0.929487f, 0.718324f, 0.780164f, 0.949658f, 0.951745f, 0.754242f, 0.652312f, 0.788605f, 0.731722f, 0.822613f, 0.714741f, 0.838334f, 0.819636f, 0.857374f, 0.749652f, 0.920539f, 0.950983f, 0.883508f, 0.731781f, 0.708585f, 0.707096f, 0.718898f, 0.916090f, 0.855373f, 0.738343f, 0.945747f, 0.954392f, 0.961991f, 0.885038f, 0.949232f, 0.961135f, 0.893453f, 0.901670f, 0.742980f, 0.942057f, 0.843904f, 0.757698f, 0.799272f, 0.775110f, 0.953474f, 0.945613f, 0.930149f, 0.583123f, 0.820328f, 0.701546f, 0.706292f, 0.927033f, 0.713836f, 0.937223f, 0.913785f, 0.625270f, 0.806539f, 0.595712f, 0.905140f, 0.808953f, 0.893348f, 0.731177f, 0.926526f, 0.737460f, 0.756132f, 0.839203f, 0.904705f, 0.927320f, 0.914440f, 0.754051f, 0.815274f, 0.831567f, 0.929506f, 0.767753f, 0.797223f, 0.768726f, 0.935774f, 0.921882f, 0.668846f, 0.886779f, 0.890245f, 0.953685f, 0.875974f, 0.731350f, 0.528819f, 0.654687f, 0.707898f, 0.453666f, 0.867444f, 0.732712f, 0.854317f, 0.664378f, 0.894548f, 0.895841f, 0.738158f, 0.716632f, 0.676207f, 0.805438f, 0.513974f, 0.824602f, 0.825990f, 0.721287f, 0.820193f, 0.872961f, 0.885356f, 0.636072f, 0.925397f, 0.896954f, 0.740207f, 0.816236f, 0.571043f, 0.861233f, 0.564864f, 0.729320f, 0.626883f, 0.787724f, 0.851943f, 0.766734f, 0.860993f, 0.458553f, 0.604224f, 0.462875f, 0.551252f, 0.855187f, 0.641481f, 0.882643f, 0.870901f, 0.492358f, 0.778750f, 0.394511f, 0.850263f, 0.822261f, 0.865423f, 0.754124f, 0.872921f, 0.580799f, 0.691292f, 0.844955f, 0.920732f, 0.893341f, 0.882642f, 0.676781f, 0.621059f, 0.681899f, 0.910859f, 0.738408f, 0.809684f, 0.804947f, 0.916307f, 0.959426f, 0.681760f, 0.857763f, 0.895188f, 0.921641f, 0.543474f, 0.631215f, 0.801028f, 0.454809f, 0.770255f, 0.583694f, 0.726487f, 0.617765f, 0.799050f, 0.661115f, 0.903080f, 0.933084f, 0.692215f, 0.801387f, 0.669793f, 0.811356f, 0.401591f, 0.782480f, 0.601031f, 0.643604f, 0.855327f, 0.820355f, 0.893720f, 0.684454f, 0.890119f, 0.818062f, 0.768763f, 0.806408f, 0.666548f, 0.865580f, 0.503225f, 0.701886f, 0.577719f, 0.806231f, 0.812716f, 0.737133f, 0.694831f, 0.486827f, 0.840937f, 0.585511f, 0.687580f, 0.922862f, 0.616929f, 0.893317f, 0.925899f, 0.718472f, 0.896978f, 0.679876f, 0.816115f, 0.919631f, 0.874143f, 0.884595f, 0.841616f, 0.481142f, 0.611455f, 0.888189f, 0.908686f, 0.894699f, 0.819411f, 0.729764f, 0.679323f, 0.797674f, 0.952689f, 0.879496f, 0.741438f, 0.892836f, 0.874073f, 0.939736f, 0.544535f, 0.819632f, 0.921706f, 0.922048f, 0.695974f, 0.819756f, 0.689298f, 0.724275f, 0.812403f, 0.783011f, 0.825482f, 0.847380f, 0.906899f, 0.850019f, 0.912154f, 0.964714f, 0.788641f, 0.822621f, 0.610191f, 0.840083f, 0.688664f, 0.905960f, 0.833974f, 0.885940f, 0.916126f, 0.918067f, 0.920006f, 0.850400f, 0.932095f, 0.960700f, 0.904719f, 0.960224f, 0.741831f, 0.968705f, 0.729633f, 0.889323f, 0.666330f, 0.882774f, 0.925295f, 0.918795f, 0.907231f, 0.864754f, 0.694184f, 0.717342f, 0.655762f, 0.948860f, 0.692490f, 0.934953f, 0.929629f, 0.796792f, 0.834382f, 0.807646f, 0.857745f, 0.900569f, 0.864568f, 0.846518f, 0.925472f, 0.676900f, 0.841599f, 0.931960f, 0.914437f, 0.772197f, 0.865247f, 0.774102f, 0.757127f, 0.733413f, 0.945223f, 0.609958f, 0.749560f, 0.840556f, 0.897883f, 0.941116f, 0.663799f, 0.677044f, 0.885542f, 0.821218f, 0.478321f, 0.684124f, 0.749655f, 0.495423f, 0.873224f, 0.775744f, 0.916341f, 0.876847f, 0.941513f, 0.626858f, 0.953096f, 0.962428f, 0.887707f, 0.904438f, 0.843675f, 0.802600f, 0.744991f, 0.765496f, 0.804272f, 0.883232f, 0.934591f, 0.862466f, 0.942137f, 0.892350f, 0.968477f, 0.957631f, 0.903372f, 0.903027f, 0.864193f, 0.942336f, 0.815018f, 0.911163f, 0.820012f, 0.854988f, 0.953626f, 0.780164f, 0.885474f, 0.704738f, 0.753520f, 0.507944f, 0.658964f, 0.930567f, 0.770439f, 0.937423f, 0.926176f, 0.587777f, 0.866974f, 0.571172f, 0.911854f, 0.909576f, 0.889485f, 0.817120f, 0.948860f, 0.317842f, 0.775405f, 0.849371f, 0.945810f, 0.936880f, 0.899055f, 0.796595f, 0.683048f, 0.819543f, 0.966958f, 0.690564f, 0.813294f, 0.944118f, 0.939758f, 0.932505f, 0.717452f, 0.837694f, 0.942299f, 0.942458f, 0.706411f, 0.908271f, 0.671236f, 0.462019f, 0.760807f, 0.472481f, 0.798583f, 0.702920f, 0.811438f, 0.497758f, 0.877388f, 0.924952f, 0.573387f, 0.591832f, 0.623049f, 0.808766f, 0.618355f, 0.820673f, 0.688564f, 0.818385f, 0.872590f, 0.887750f, 0.927310f, 0.761636f, 0.929143f, 0.921466f, 0.753408f, 0.808335f, 0.690391f, 0.962069f, 0.544571f, 0.590366f, 0.685615f, 0.858907f, 0.904605f, 0.785932f, 0.749290f, 0.800322f, 0.820638f, 0.702353f, 0.749957f, 0.933879f, 0.693201f, 0.947283f, 0.947246f, 0.628017f, 0.836439f, 0.830782f, 0.890702f, 0.895705f, 0.910299f, 0.848130f, 0.958055f, 0.681816f, 0.735606f, 0.956936f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 2, std::numeric_limits::quiet_NaN(), 1.f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention3DWithPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 4, 24} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 6, 24} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 6, 24} + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + // {2, 3, 12, 8} + std::vector past_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f}; + // {2, 3, 12, 8} + std::vector past_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 4, 24} + std::vector y = {0.387434f, 0.451660f, 0.466422f, 0.473844f, 0.487732f, 0.616663f, 0.389945f, 0.474446f, 0.610035f, 0.540721f, 0.465339f, 0.659275f, 0.542400f, 0.558199f, 0.496998f, 0.580479f, 0.608613f, 0.454357f, 0.591427f, 0.539400f, 0.491600f, 0.439752f, 0.574766f, 0.534788f, 0.369295f, 0.476453f, 0.472667f, 0.474934f, 0.484975f, 0.653894f, 0.434421f, 0.507237f, 0.606547f, 0.512561f, 0.492485f, 0.627438f, 0.547220f, 0.559142f, 0.549041f, 0.650326f, 0.576993f, 0.484612f, 0.597630f, 0.527508f, 0.458643f, 0.432526f, 0.522555f, 0.581898f, 0.375984f, 0.479550f, 0.484624f, 0.506722f, 0.499591f, 0.628391f, 0.457767f, 0.484544f, 0.612554f, 0.547468f, 0.485806f, 0.634928f, 0.524544f, 0.542711f, 0.529978f, 0.645564f, 0.613958f, 0.471193f, 0.571000f, 0.499555f, 0.454844f, 0.456024f, 0.567122f, 0.580956f, 0.367353f, 0.449829f, 0.439545f, 0.467891f, 0.516863f, 0.600392f, 0.405625f, 0.505181f, 0.632177f, 0.541634f, 0.449302f, 0.641351f, 0.504706f, 0.533341f, 0.527675f, 0.566799f, 0.572756f, 0.403738f, 0.539009f, 0.570743f, 0.478912f, 0.426711f, 0.567812f, 0.569001f, 0.495478f, 0.510849f, 0.388839f, 0.497814f, 0.545673f, 0.571958f, 0.453011f, 0.440750f, 0.458974f, 0.457386f, 0.506820f, 0.500591f, 0.499766f, 0.469500f, 0.465457f, 0.482146f, 0.581360f, 0.481272f, 0.463336f, 0.277110f, 0.627647f, 0.672684f, 0.342731f, 0.533800f, 0.530251f, 0.504140f, 0.385565f, 0.520337f, 0.548283f, 0.549735f, 0.473426f, 0.404586f, 0.463533f, 0.448576f, 0.497032f, 0.524322f, 0.474570f, 0.430653f, 0.498514f, 0.465629f, 0.578306f, 0.489042f, 0.491176f, 0.239511f, 0.588495f, 0.640517f, 0.319799f, 0.521414f, 0.510868f, 0.564625f, 0.348291f, 0.465071f, 0.498481f, 0.557391f, 0.469662f, 0.433203f, 0.471745f, 0.483765f, 0.520633f, 0.501991f, 0.485003f, 0.471836f, 0.500727f, 0.477256f, 0.574286f, 0.472931f, 0.487446f, 0.259796f, 0.603843f, 0.658305f, 0.303291f, 0.520652f, 0.560815f, 0.513931f, 0.418469f, 0.482361f, 0.535024f, 0.506256f, 0.440027f, 0.428132f, 0.519530f, 0.520400f, 0.482710f, 0.517258f, 0.479400f, 0.442196f, 0.466145f, 0.508808f, 0.534070f, 0.488154f, 0.483878f, 0.234783f, 0.628834f, 0.685886f, 0.369073f, 0.545753f}; + // {2, 3, 18, 8} + std::vector present_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector present_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector qk_matmul = {0.820140f, 1.059902f, 0.757718f, 0.881749f, 0.858141f, 1.036822f, 0.884175f, 0.745137f, 0.702161f, 0.857424f, 0.931616f, 0.810373f, 0.765101f, 1.031954f, 0.676118f, 1.049585f, 0.679454f, 0.781211f, 0.732417f, 0.806783f, 0.671492f, 0.704470f, 0.679564f, 0.856373f, 0.747101f, 0.574466f, 0.511335f, 0.570812f, 0.772065f, 0.486530f, 0.626328f, 0.895540f, 0.426428f, 0.830139f, 0.518625f, 0.578420f, 0.491913f, 0.536788f, 0.566909f, 0.660403f, 0.508000f, 0.745048f, 0.542980f, 0.637834f, 0.427056f, 0.598455f, 0.656768f, 0.504709f, 0.485053f, 0.649462f, 0.553231f, 0.485448f, 0.577920f, 0.466000f, 0.399496f, 0.637952f, 0.382979f, 0.665599f, 0.527650f, 0.680828f, 0.511044f, 0.664769f, 0.654046f, 0.736594f, 0.645048f, 0.671768f, 0.524199f, 0.519912f, 0.615914f, 0.647178f, 0.559970f, 0.412029f, 0.492759f, 0.889178f, 0.525811f, 0.479380f, 0.766941f, 0.901303f, 1.087107f, 0.808560f, 0.779749f, 0.609254f, 0.801121f, 0.808370f, 0.397958f, 0.867537f, 0.814879f, 0.981307f, 1.048465f, 0.422327f, 0.531406f, 0.847033f, 0.878291f, 0.737390f, 0.926101f, 1.027148f, 0.731989f, 0.720755f, 0.637853f, 0.523248f, 0.924757f, 0.757182f, 0.588026f, 0.773634f, 0.979738f, 1.255782f, 0.901064f, 0.688140f, 0.274371f, 0.437410f, 0.411925f, 0.342756f, 0.545288f, 0.529269f, 0.533905f, 0.380022f, 0.436475f, 0.301469f, 0.529214f, 0.526297f, 0.395983f, 0.411271f, 0.503063f, 0.557536f, 0.505664f, 0.334459f, 0.348011f, 0.483405f, 0.482135f, 0.438657f, 0.623578f, 0.666952f, 0.527974f, 0.396662f, 0.441010f, 0.322428f, 0.543776f, 0.569352f, 0.341589f, 0.541193f, 0.719589f, 0.825763f, 0.713140f, 0.369560f, 0.925217f, 0.962246f, 0.804315f, 0.969734f, 0.939348f, 0.895554f, 1.240035f, 1.032457f, 1.260824f, 0.838023f, 0.816715f, 1.381388f, 1.123444f, 0.666636f, 0.901369f, 0.880265f, 0.544716f, 0.964444f, 0.610261f, 0.432138f, 0.522623f, 0.616368f, 0.392524f, 0.601866f, 0.610201f, 0.716924f, 0.662694f, 0.625345f, 0.421250f, 0.927903f, 0.710488f, 0.375567f, 0.528123f, 0.532916f, 0.359236f, 0.428232f, 0.627666f, 0.646350f, 0.711912f, 0.578261f, 0.510271f, 0.666607f, 0.609787f, 0.652893f, 0.673018f, 0.618551f, 0.787326f, 1.094408f, 0.787271f, 0.433836f, 0.638263f, 0.836964f, 0.604598f, 0.587050f, 0.798962f, 0.607254f, 0.635098f, 0.675595f, 0.504633f, 0.579773f, 0.825966f, 0.745334f, 0.850824f, 0.713222f, 0.417185f, 0.949167f, 0.715411f, 0.438783f, 0.580263f, 0.596451f, 0.311825f, 0.698230f, 0.553783f, 0.653118f, 0.479333f, 0.683333f, 0.611400f, 0.926136f, 0.937356f, 1.079461f, 0.500571f, 0.941776f, 0.571910f, 0.891547f, 0.471507f, 0.784496f, 0.765230f, 0.316921f, 0.693191f, 0.812555f, 0.430584f, 0.838594f, 0.577089f, 0.887826f, 0.637326f, 0.838023f, 0.852760f, 0.930619f, 0.596678f, 1.004560f, 0.556861f, 0.837758f, 0.499217f, 0.774022f, 0.908813f, 0.359039f, 0.646230f, 0.839435f, 0.724433f, 1.107947f, 0.836124f, 1.043592f, 0.755617f, 1.190845f, 0.927864f, 1.247710f, 0.759936f, 1.199264f, 0.903627f, 0.981243f, 0.477713f, 0.991537f, 0.973822f, 0.518882f, 0.798147f, 0.975918f, 0.343779f, 0.491195f, 0.197678f, 0.348761f, 0.506575f, 0.694266f, 0.570159f, 0.588826f, 0.260686f, 0.583943f, 0.370536f, 0.570071f, 0.363210f, 0.512280f, 0.518522f, 0.260276f, 0.479575f, 0.519170f, 0.649026f, 0.390051f, 0.795750f, 0.920073f, 1.046746f, 0.900276f, 0.940614f, 0.679509f, 0.778774f, 0.792281f, 0.857889f, 1.197963f, 0.738062f, 0.792745f, 0.602892f, 0.687147f, 0.962916f, 0.719326f, 0.587815f, 0.233866f, 0.638163f, 0.785593f, 0.772991f, 0.770025f, 0.862170f, 0.414778f, 0.518855f, 0.729107f, 0.683017f, 0.903488f, 0.620768f, 0.669556f, 0.396731f, 0.418391f, 0.796217f, 0.580872f, 0.555648f, 0.496401f, 0.582726f, 0.730206f, 0.806009f, 0.858020f, 0.827912f, 0.515117f, 0.715055f, 0.533599f, 0.810529f, 0.887599f, 0.629091f, 0.713460f, 0.668702f, 0.740854f, 0.533289f, 0.544756f, 0.500474f, 0.287242f, 0.666506f, 0.805604f, 0.814325f, 0.939329f, 0.784865f, 0.575117f, 0.413632f, 0.650744f, 0.916553f, 0.821434f, 0.634740f, 0.761039f, 0.447249f, 0.427194f, 0.886137f, 0.483658f, 0.957992f, 0.967132f, 0.993273f, 0.791302f, 0.858239f, 1.102870f, 1.073905f, 0.782627f, 0.700627f, 1.402989f, 0.781228f, 0.752175f, 0.879408f, 1.311348f, 0.881165f, 1.044089f, 1.012252f, 1.461238f, 0.731050f, 0.967882f, 0.932687f, 0.778944f, 0.812401f, 0.974234f, 1.130671f, 0.729870f, 0.702872f, 1.304851f, 0.727443f, 0.734453f, 0.899574f, 1.238530f, 0.921609f, 1.012872f, 0.938401f, 1.303568f, 0.824322f, 0.920016f, 0.791579f, 0.844334f, 0.618349f, 0.989377f, 1.120477f, 0.554956f, 0.683589f, 1.280705f, 0.957804f, 0.833027f, 0.791589f, 1.159548f, 1.031220f, 0.951427f, 0.915324f, 1.361176f, 0.733365f, 0.808374f, 0.764603f, 0.755506f, 0.638693f, 0.946285f, 1.001370f, 0.578989f, 0.603487f, 1.074992f, 0.697424f, 0.812599f, 0.708634f, 1.129837f, 0.888077f, 0.835530f, 1.006811f, 1.193171f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 4; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 4; // V.shape[3] + int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 4} + std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; + // {2, 3, 6, 4} + std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 6, 4} + std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 1, 4, 13} + std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; + // {2, 3, 12, 4} + std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + // {2, 3, 12, 4} + std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), batch_size * 1 * q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 4} + std::vector y = {-0.385197f, -0.378771f, -0.372345f, -0.365919f, -0.385008f, -0.378583f, -0.372157f, -0.365731f, -0.384820f, -0.378394f, -0.371968f, -0.365543f, -0.384632f, -0.378206f, -0.371780f, -0.365354f, -0.217777f, -0.211351f, -0.204925f, -0.198499f, -0.217588f, -0.211163f, -0.204737f, -0.198311f, -0.217400f, -0.210974f, -0.204549f, -0.198123f, -0.217212f, -0.210786f, -0.204360f, -0.197935f, -0.050357f, -0.043931f, -0.037505f, -0.031080f, -0.050169f, -0.043743f, -0.037317f, -0.030891f, -0.049980f, -0.043555f, -0.037129f, -0.030703f, -0.049792f, -0.043366f, -0.036941f, -0.030515f, 0.117063f, 0.123489f, 0.129914f, 0.136340f, 0.117251f, 0.123677f, 0.130102f, 0.136528f, 0.117439f, 0.123865f, 0.130291f, 0.136716f, 0.117628f, 0.124053f, 0.130479f, 0.136904f, 0.284482f, 0.290908f, 0.297334f, 0.303759f, 0.284670f, 0.291096f, 0.297522f, 0.303947f, 0.284859f, 0.291284f, 0.297710f, 0.304135f, 0.285047f, 0.291472f, 0.297898f, 0.304323f, 0.451901f, 0.458327f, 0.464752f, 0.471178f, 0.452089f, 0.458515f, 0.464940f, 0.471366f, 0.452277f, 0.458703f, 0.465128f, 0.471554f, 0.452465f, 0.458890f, 0.465316f, 0.471741f}; + // {2, 3, 13, 4} + std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 18, 8} + std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 4; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 4; // V.shape[3] + int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 4} + std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; + // {2, 3, 6, 4} + std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 6, 4} + std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 1, 4, 13} + std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; + // {2, 3, 12, 4} + std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + // {2, 3, 12, 4} + std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), batch_size * 1 * q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 4} + std::vector y = {-0.393782f, -0.387694f, -0.381606f, -0.375519f, -0.397492f, -0.391304f, -0.385116f, -0.378928f, -0.397474f, -0.391207f, -0.384941f, -0.378674f, -0.394849f, -0.388519f, -0.382190f, -0.375860f, -0.226271f, -0.220186f, -0.214101f, -0.208016f, -0.230042f, -0.223857f, -0.217672f, -0.211488f, -0.230104f, -0.223841f, -0.217577f, -0.211314f, -0.227525f, -0.221197f, -0.214870f, -0.208543f, -0.058757f, -0.052674f, -0.046592f, -0.040510f, -0.062587f, -0.056406f, -0.050224f, -0.044042f, -0.062730f, -0.056470f, -0.050209f, -0.043949f, -0.060198f, -0.053873f, -0.047548f, -0.041223f, 0.108760f, 0.114840f, 0.120919f, 0.126999f, 0.104873f, 0.111051f, 0.117229f, 0.123408f, 0.104648f, 0.110906f, 0.117163f, 0.123421f, 0.107131f, 0.113454f, 0.119777f, 0.126099f, 0.276279f, 0.282356f, 0.288433f, 0.294510f, 0.272337f, 0.278512f, 0.284687f, 0.290862f, 0.272031f, 0.278286f, 0.284540f, 0.290794f, 0.274463f, 0.280783f, 0.287104f, 0.293424f, 0.443800f, 0.449874f, 0.455949f, 0.462023f, 0.439807f, 0.445978f, 0.452150f, 0.458321f, 0.439418f, 0.445669f, 0.451921f, 0.458172f, 0.441797f, 0.448115f, 0.454433f, 0.460751f}; + // {2, 3, 13, 4} + std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 18, 8} + std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 4; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 4; // V.shape[3] + int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 4} + std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; + // {2, 3, 6, 4} + std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 6, 4} + std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector m = {-0.454545f, -0.451340f, -0.448135f, -0.444930f, -0.441725f, -0.438520f, -0.435315f, -0.432110f, -0.428904f, -0.425699f, -0.422494f, -0.419289f, -0.416084f, -0.412879f, -0.409674f, -0.406469f, -0.403263f, -0.400058f, -0.396853f, -0.393648f, -0.390443f, -0.387238f, -0.384033f, -0.380828f, -0.377622f, -0.374417f, -0.371212f, -0.368007f, -0.364802f, -0.361597f, -0.358392f, -0.355186f, -0.351981f, -0.348776f, -0.345571f, -0.342366f, -0.339161f, -0.335956f, -0.332751f, -0.329545f, -0.326340f, -0.323135f, -0.319930f, -0.316725f, -0.313520f, -0.310315f, -0.307110f, -0.303904f, -0.300699f, -0.297494f, -0.294289f, -0.291084f, -0.287879f, -0.284674f, -0.281469f, -0.278263f, -0.275058f, -0.271853f, -0.268648f, -0.265443f, -0.262238f, -0.259033f, -0.255828f, -0.252622f, -0.249417f, -0.246212f, -0.243007f, -0.239802f, -0.236597f, -0.233392f, -0.230186f, -0.226981f, -0.223776f, -0.220571f, -0.217366f, -0.214161f, -0.210956f, -0.207751f, -0.204545f, -0.201340f, -0.198135f, -0.194930f, -0.191725f, -0.188520f, -0.185315f, -0.182110f, -0.178904f, -0.175699f, -0.172494f, -0.169289f, -0.166084f, -0.162879f, -0.159674f, -0.156469f, -0.153263f, -0.150058f, -0.146853f, -0.143648f, -0.140443f, -0.137238f, -0.134033f, -0.130828f, -0.127622f, -0.124417f, -0.121212f, -0.118007f, -0.114802f, -0.111597f, -0.108392f, -0.105186f, -0.101981f, -0.098776f, -0.095571f, -0.092366f, -0.089161f, -0.085956f, -0.082751f, -0.079545f, -0.076340f, -0.073135f, -0.069930f, -0.066725f, -0.063520f, -0.060315f, -0.057110f, -0.053904f, -0.050699f, -0.047494f, -0.044289f, -0.041084f, -0.037879f, -0.034674f, -0.031469f, -0.028263f, -0.025058f, -0.021853f, -0.018648f, -0.015443f, -0.012238f, -0.009033f, -0.005828f, -0.002622f, 0.000583f, 0.003788f, 0.006993f, 0.010198f, 0.013403f, 0.016608f, 0.019814f, 0.023019f, 0.026224f, 0.029429f, 0.032634f, 0.035839f, 0.039044f, 0.042249f, 0.045455f, 0.048660f, 0.051865f, 0.055070f, 0.058275f, 0.061480f, 0.064685f, 0.067890f, 0.071096f, 0.074301f, 0.077506f, 0.080711f, 0.083916f, 0.087121f, 0.090326f, 0.093531f, 0.096737f, 0.099942f, 0.103147f, 0.106352f, 0.109557f, 0.112762f, 0.115967f, 0.119172f, 0.122378f, 0.125583f, 0.128788f, 0.131993f, 0.135198f, 0.138403f, 0.141608f, 0.144814f, 0.148019f, 0.151224f, 0.154429f, 0.157634f, 0.160839f, 0.164044f, 0.167249f, 0.170455f, 0.173660f, 0.176865f, 0.180070f, 0.183275f, 0.186480f, 0.189685f, 0.192890f, 0.196096f, 0.199301f, 0.202506f, 0.205711f, 0.208916f, 0.212121f, 0.215326f, 0.218531f, 0.221737f, 0.224942f, 0.228147f, 0.231352f, 0.234557f, 0.237762f, 0.240967f, 0.244172f, 0.247378f, 0.250583f, 0.253788f, 0.256993f, 0.260198f, 0.263403f, 0.266608f, 0.269814f, 0.273019f, 0.276224f, 0.279429f, 0.282634f, 0.285839f, 0.289044f, 0.292249f, 0.295455f, 0.298660f, 0.301865f, 0.305070f, 0.308275f, 0.311480f, 0.314685f, 0.317890f, 0.321096f, 0.324301f, 0.327506f, 0.330711f, 0.333916f, 0.337121f, 0.340326f, 0.343531f, 0.346737f, 0.349942f, 0.353147f, 0.356352f, 0.359557f, 0.362762f, 0.365967f, 0.369172f, 0.372378f, 0.375583f, 0.378788f, 0.381993f, 0.385198f, 0.388403f, 0.391608f, 0.394814f, 0.398019f, 0.401224f, 0.404429f, 0.407634f, 0.410839f, 0.414044f, 0.417249f, 0.420455f, 0.423660f, 0.426865f, 0.430070f, 0.433275f, 0.436480f, 0.439685f, 0.442890f, 0.446096f, 0.449301f, 0.452506f, 0.455711f, 0.458916f, 0.462121f, 0.465326f, 0.468531f, 0.471737f, 0.474942f, 0.478147f, 0.481352f, 0.484557f, 0.487762f, 0.490967f, 0.494172f, 0.497378f, 0.500583f, 0.503788f, 0.506993f, 0.510198f, 0.513403f, 0.516608f, 0.519814f, 0.523019f, 0.526224f, 0.529429f, 0.532634f, 0.535839f, 0.539044f, 0.542249f}; + // {2, 3, 12, 4} + std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + // {2, 3, 12, 4} + std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), batch_size * q_num_heads * q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 4} + std::vector y = {-0.385742f, -0.379327f, -0.372911f, -0.366496f, -0.385554f, -0.379139f, -0.372723f, -0.366308f, -0.385366f, -0.378950f, -0.372535f, -0.366119f, -0.385178f, -0.378762f, -0.372347f, -0.365931f, -0.218323f, -0.211907f, -0.205492f, -0.199076f, -0.218134f, -0.211719f, -0.205304f, -0.198888f, -0.217946f, -0.211531f, -0.205115f, -0.198700f, -0.217758f, -0.211342f, -0.204927f, -0.198512f, -0.050903f, -0.044487f, -0.038072f, -0.031657f, -0.050715f, -0.044299f, -0.037884f, -0.031468f, -0.050526f, -0.044111f, -0.037695f, -0.031280f, -0.050338f, -0.043922f, -0.037507f, -0.031092f, 0.116517f, 0.122932f, 0.129348f, 0.135763f, 0.116705f, 0.123121f, 0.129536f, 0.135952f, 0.116894f, 0.123309f, 0.129724f, 0.136140f, 0.117082f, 0.123497f, 0.129913f, 0.136328f, 0.283937f, 0.290352f, 0.296768f, 0.303183f, 0.284125f, 0.290540f, 0.296956f, 0.303371f, 0.284313f, 0.290729f, 0.297144f, 0.303559f, 0.284501f, 0.290917f, 0.297332f, 0.303747f, 0.451356f, 0.457772f, 0.464187f, 0.470602f, 0.451544f, 0.457960f, 0.464375f, 0.470790f, 0.451732f, 0.458148f, 0.464563f, 0.470978f, 0.451920f, 0.458336f, 0.464751f, 0.471166f}; + // {2, 3, 13, 4} + std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 18, 8} + std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index aa89ab80bc4e5..23c3a922326cb 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -31,6 +31,7 @@ "current_failing_tests": [ "^test_adagrad", "^test_adagrad_multiple", + "^test_attention_3d.*", // wrong expected values in onnx==1.18.0, fixed in 1.19.0 "^test_batchnorm_epsilon_training_mode", "^test_batchnorm_example_training_mode", "^test_col2im_pads", // still one wrong value coming from the backtest example From 0905c56ee26d1903c27cea1963940cacdb3d541c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Fri, 25 Jul 2025 17:19:51 +0200 Subject: [PATCH 11/33] [TRT EP] Fix `trt_load_user_initializer` for large models where weight are not correctly excluded (#25502) ### Description This change respects initializers that are external but already loaded in memory. This is required due to an optimization that leaves it to the backend to read a mapped memory area. @chilo-ms can you help run the CI and merge this change ? --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../core/graph/graph_proto_serializer.cc | 2 +- .../tensorrt/tensorrt_execution_provider.cc | 140 ++++++++++-------- .../tensorrt/tensorrt_execution_provider.h | 23 ++- .../providers/tensorrt/tensorrt_basic_test.cc | 2 + 4 files changed, 97 insertions(+), 70 deletions(-) diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 0fbcea2719ce8..9a67796254231 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -95,7 +95,7 @@ void GraphViewerToProto(const GraphViewer& graph_view, auto* p_initializer = graph_proto.add_initializer(); // Do not save raw into the graph, only the metadata - if (!include_initializer_data && init->has_raw_data()) { + if (!include_initializer_data && (init->has_raw_data() || utils::HasExternalDataInMemory(*init))) { // Set datatype if (init->has_data_type()) { p_initializer->set_data_type(init->data_type()); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 64be445b4c15c..b60f64db1734d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2337,11 +2337,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (load_user_initializer_) { auto allInitializers = graph_viewer->GetAllInitializedTensors(); - for (auto entry : allInitializers) { + for (auto& entry : allInitializers) { auto* tp = entry.second; if (tp->has_raw_data()) { - userWeights.push_back( - TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()}); + userWeights.emplace_back(tp->name(), tp->raw_data()); + } else if (utils::HasExternalDataInMemory(*tp)) { + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(full_init->name(), full_init->raw_data()); } } } @@ -2378,7 +2381,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (load_user_initializer_) { trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); for (auto const& userWeight : userWeights) { - trt_parser->loadInitializer(userWeight.name.c_str(), static_cast(userWeight.data.c_str()), userWeight.size); + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); } is_model_supported = trt_parser->parseModelProto(); } else { @@ -2862,7 +2865,8 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil if (onnx_model_path.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The ONNX model was not provided as path. " - "Please use provide an ONNX bytestream to enable refitting the weightless engine."); + "Please use provide an ONNX bytestream to enable refitting the weightless engine." + "When providing a bytestream during session initialization, it should also be set as trt_onnx_bytes_stream"); } else { // check if file path to ONNX is legal if (path_check && IsAbsolutePath(onnx_model_path.string())) { @@ -2909,6 +2913,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil int required_weights = refitter->getAllWeights(0, nullptr); std::vector refit_names(required_weights); refitter->getAllWeights(required_weights, refit_names.data()); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitter requires " << required_weights << " weights"; // Vectors to keep track of data pointers. std::vector names; @@ -2918,67 +2923,69 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil std::vector sizes; sizes.reserve(required_weights); - if (refit_with_external_data) { - auto onnx_model = ModelProto::Create(); - TensorProtos* allInitializers_byte_stream; + auto onnx_model = ModelProto::Create(); + TensorProtos* allInitializers_byte_stream; - // Reconstruct onnx model view. - const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, - onnx_model_bytestream_size); - if (!onnx_model->ParseFromString(onnx_model_view)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "The provided ONNX bytestream to refit could not be parsed."); - } - - // Extract graph and initializer information. - auto const& graph = onnx_model->mutable_graph(); - allInitializers_byte_stream = graph->mutable_initializer(); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size(); - - // Loop through all initializers - for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { - auto& proto = allInitializers_byte_stream->at(initializer_idx); - auto& proto_name = proto.name(); - bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); - if (weight_is_refittable) { - if (proto.has_data_location()) { - if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { - // Default values for reading into external_data blob. - int64_t offset = 0; - size_t length = 0; - auto external_data = proto.mutable_external_data(); - const std::string kOffset = "offset", kLength = "length"; - for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { - auto current_key = external_data->at(entry_idx).mutable_key(); - auto current_value = external_data->at(entry_idx).mutable_value(); - if (*current_key == kOffset && !current_value->empty()) { - offset = std::stoll(*current_value); - } else if (*current_key == kLength && !current_value->empty()) { - length = std::stoul(*current_value); - } + // Reconstruct onnx model view. + const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The provided ONNX bytestream to refit could not be parsed."); + } + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size(); + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); + if (weight_is_refittable) { + if (proto.has_data_location()) { + if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. + int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); } - names.push_back(proto.name()); - bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); - sizes.push_back(length); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); } - } else { - if (!proto.has_raw_data()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "[TensorRT EP] Proto: " + proto_name + " has no raw data"); - } - auto& raw_data = proto.raw_data(); names.push_back(proto.name()); - bytes.push_back(raw_data.c_str()); - sizes.push_back(raw_data.size()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); } else { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable"; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data."; + ++missing_initializer_data; } + } else { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable"; } } + if (missing_initializer_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."); + } // Load extracted initializers into the parser if (!names.empty()) { @@ -3093,12 +3100,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (load_user_initializer_) { auto allInitializers = graph_body_viewer.GetAllInitializedTensors(); - for (auto entry : allInitializers) { + for (auto& entry : allInitializers) { auto name = entry.first; auto* tp = entry.second; if (tp->has_raw_data()) { - userWeights->push_back( - TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()}); + userWeights->emplace_back( + TensorrtUserWeights(tp->name(), tp->raw_data())); + } else if (utils::HasExternalDataInMemory(*tp)) { + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights->emplace_back( + TensorrtUserWeights(full_init->name(), full_init->raw_data())); } } } @@ -3134,7 +3146,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (load_user_initializer_) { trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); for (auto const& userWeight : *userWeights) { - trt_parser->loadInitializer(userWeight.name.c_str(), static_cast(userWeight.data.c_str()), userWeight.size); + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); } trt_parser->parseModelProto(); } else { @@ -3671,14 +3683,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (weight_stripped_engine_refit_) { LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; - char* onnx = string_buf.data(); - size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, - onnx, - onnx_size, + onnx_model_bytestream_, + onnx_model_bytestream_size_, onnx_external_data_bytestream_, onnx_external_data_bytestream_size_, trt_engine.get(), diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index dba17f7822eac..e817fc51237c0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -158,10 +158,25 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { using ShapeRangesMap = std::unordered_map>>>; // Struct to hold user weights when ModelProtos are serialized with data. -struct TensorrtUserWeights { - std::string name{}; - std::string data{}; - int64_t size{}; +class TensorrtUserWeights { + public: + TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), data_(data) {}; + + const char* Name() const { + return name_.c_str(); + }; + + const void* Data() const { + return static_cast(data_.data()); + } + + int64_t Size() const { + return static_cast(data_.size()); + } + + private: + std::string name_{}; + std::string data_{}; }; // Information to construct kernel function state. diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 553059932db90..706bd3c0fce62 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -571,6 +571,8 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { params7.trt_dump_ep_context_model = 1; params7.trt_ep_context_embed_mode = 1; params7.trt_weight_stripped_engine_enable = 1; + params7.trt_onnx_bytestream = model_bytes.data(); + params7.trt_onnx_bytestream_size = model_bytes.size(); params7.trt_ep_context_file_path = ctx_model_name_str.c_str(); execution_provider = TensorrtExecutionProviderWithOptions(¶ms7); EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); From ab95c1505c8af7a53cf0d97a82cf9a7ae840e850 Mon Sep 17 00:00:00 2001 From: Ishwar Raut Date: Fri, 25 Jul 2025 21:14:56 +0530 Subject: [PATCH 12/33] [NvTensorRTRTX EP] Add EP factory to Nv TRT RTX EP (#25511) ### Description 1. Implemented the required changes for the EP factory. ### Motivation and Context These changes are required for WinML GA. --- .../nv_tensorrt_rtx/nv_provider_factory.cc | 488 ++++++++++++++++-- .../nv_tensorrt_rtx/nv_basic_test.cc | 157 +++++- .../test_nv_trt_rtx_ep_util.cc | 58 +++ .../nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h | 38 ++ 4 files changed, 708 insertions(+), 33 deletions(-) create mode 100644 onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc create mode 100644 onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 428d24f2f3df8..e236cccaaaa77 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -7,10 +7,14 @@ #include #include "nv_execution_provider.h" #include "nv_provider_factory_creator.h" +#include "nv_data_transfer.h" +#include "nv_allocator.h" #include "core/framework/provider_options.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" #include "core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h" #include +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/cuda_stream_handle.h" using namespace onnxruntime; @@ -151,28 +155,385 @@ ORT_API(onnxruntime::Provider*, GetProvider) { } } -#include "core/framework/error_code_helper.h" +// +// Plug-in EP infrastructure +// + +#include "core/session/abi_devices.h" +#include "onnxruntime_config.h" // for ORT_VERSION + +struct ErrorHelper { + static const OrtApi* ort_api; + + static OrtStatus* ToOrtStatus(const Status& status) { + if (status.IsOK()) { + return nullptr; // no error + } + + return ort_api->CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } +}; + +const OrtApi* ErrorHelper::ort_api = nullptr; + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF_STATUS_NOTOK(fn) \ + do { \ + Status _status = (fn); \ + if (!_status.IsOK()) { \ + return ErrorHelper::ToOrtStatus(_status); \ + } \ + } while (0) + +#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_STATUS_NOTOK(CUDA_CALL(expr)) + +struct NvTrtRtxOrtAllocator : OrtAllocator { + NvTrtRtxOrtAllocator(const OrtMemoryInfo* mem_info, const OrtApi& api) : memory_info_{mem_info} { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl + GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + + const OrtEpApi& ep_api = *api.GetEpApi(); + const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info); + uint32_t device_id = ep_api.MemoryDevice_GetDeviceId(mem_device); + const char* name = nullptr; + auto* status = api.MemoryInfoGetName(mem_info, &name); + static_cast(status); // GetName never fails + + if (ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + allocator_ = std::make_unique(device_id, name); + } else { + allocator_ = std::make_unique(device_id, name); + } + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& impl = *static_cast(this_); + return impl.allocator_->Alloc(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + auto& impl = *static_cast(this_); + impl.allocator_->Free(p); + } + + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const NvTrtRtxOrtAllocator& impl = *static_cast(this_); + return impl.memory_info_; + } + + private: + const OrtMemoryInfo* memory_info_; + std::unique_ptr allocator_; +}; + +struct NvTrtRtxDataTransferImpl : OrtDataTransferImpl { + NvTrtRtxDataTransferImpl(const OrtApi& ort_api_in) + : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + const auto& impl = *static_cast(this_ptr); + + // logic copied from GPUDataTransfer::CanCopy + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); + auto src_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); + auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); + + if ((src_type == OrtDevice::GPU && src_vendor_id != OrtDevice::VendorIds::NVIDIA) || + (dst_type == OrtDevice::GPU && dst_vendor_id != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // copy must be GPU to GPU or between GPU and CPU + return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) || + (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) || + (src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU); + } + + static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + bool need_stream_sync = false; + + for (size_t idx = 0; idx < num_tensors; ++idx) { + const OrtValue* src_tensor = src_tensors[idx]; + OrtValue* dst_tensor = dst_tensors[idx]; + OrtSyncStream* stream = streams ? streams[idx] : nullptr; + + const OrtMemoryDevice* src_device = impl.ep_api.Value_GetMemoryDevice(src_tensor); + const OrtMemoryDevice* dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensor); + + size_t bytes; + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensor, &bytes)); + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensor, &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensor, &dst_data)); + + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + + const bool src_is_gpu_default = src_type == OrtMemoryInfoDeviceType_GPU && + src_mem_type == OrtDeviceMemoryType_DEFAULT; + const bool dst_is_gpu_default = dst_type == OrtMemoryInfoDeviceType_GPU && + dst_mem_type == OrtDeviceMemoryType_DEFAULT; + + cudaStream_t cuda_stream = nullptr; + if (stream) { + cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); + } + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } else { + // copy from pinned or non-pinned CPU memory to GPU + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, cuda_stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + + if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not + // have completed. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } + } else if (src_is_gpu_default) { + // copying from GPU to CPU memory, this is blocking + + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } + } else { + // copying between CPU accessible memory + + if (dst_data != src_data) { + if (cuda_stream) { + if (src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // sync the stream first to make sure the data arrived + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + } + + memcpy(dst_data, src_data, bytes); + } + } + } + + if (need_stream_sync) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + + return nullptr; + } + + static void ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { + // no-op as we have a single shared instance in OrtEpFactory which is returned from CreateDataTransferImpl, and is + // owned by and freed by the factory. + } + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct NvTrtRtxSyncNotificationImpl : OrtSyncNotificationImpl { + static OrtStatus* Create(cudaStream_t stream, const OrtApi& ort_api, + std::unique_ptr& notification) { + notification.reset(new NvTrtRtxSyncNotificationImpl(stream, ort_api)); // can't use make_unique with private ctor + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(¬ification->event_, cudaEventDisableTiming)); + + return nullptr; + } + + static void ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static OrtStatus* ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_)); + + return nullptr; + } + + static OrtStatus* WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* consumer_stream) noexcept { + auto& impl = *static_cast(this_ptr); + + // setup the consumer stream to wait on our event. + void* consumer_handle = impl.ort_api.SyncStream_GetHandle(consumer_stream); + CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast(consumer_handle), impl.event_)); + + return nullptr; + } + + static OrtStatus* WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_)); + + return nullptr; + } + + ~NvTrtRtxSyncNotificationImpl() { + cudaEventDestroy(event_); + } + + private: + NvTrtRtxSyncNotificationImpl(cudaStream_t stream, const OrtApi& ort_api_in) + : stream_{stream}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + Release = ReleaseImpl; + } + + cudaStream_t& stream_; + cudaEvent_t event_; + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { + NvTrtRtxSyncStreamImpl(cudaStream_t&& stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_cuda_stream, + const OrtApi& ort_api_in) + : stream_{ + stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, /*own*/ true, + /*external_cudnn_handle*/ nullptr, + /*external_cublas_handle*/ nullptr, + // ep_info is used by GetResource which seems to be a somewhat ugly way to make arbitrary info that is + // unrelated to the stream available to a custom op. + // avoiding adding GetResource to OrtSyncStreamImpl as we should have a cleaner setup for custom ops, + // so this argument value isn't used and doesn't matter. + /*ep_info*/ CUDAExecutionProviderInfo{}}, + ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; + GetHandle = GetHandleImpl; + CreateNotification = CreateNotificationImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + } + + static void ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static void* GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return impl.stream_.GetHandle(); + } + + static OrtStatus* CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification_impl) noexcept { + auto& impl = *static_cast(this_ptr); + *notification_impl = nullptr; + + std::unique_ptr notification; + cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + + RETURN_IF_ERROR(NvTrtRtxSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + *notification_impl = notification.release(); + + return nullptr; + } + + static OrtStatus* FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + impl.stream_.Flush(); + + return nullptr; + } + + static OrtStatus* OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + RETURN_IF_STATUS_NOTOK(impl.stream_.CleanUpOnRunEnd()); + + return nullptr; + } + + private: + // this is a little onion-ish as CudaStream is a onnxruntime::Stream and this is an OrtSyncStreamImpl that will be + // used via plugin_ep::Stream, which is also an onnxruntime::Stream. in a 'real' plugin EP implementation + // CudaStream would go away and the logic it has would be implemented directly here. + CudaStream stream_; + const OrtApi& ort_api; +}; // OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection. struct NvTensorRtRtxEpFactory : OrtEpFactory { + using MemoryInfoUniquePtr = std::unique_ptr>; + NvTensorRtRtxEpFactory(const OrtApi& ort_api_in, - const OrtLogger& default_logger_in, - OrtHardwareDeviceType hw_type) - : ort_api{ort_api_in}, default_logger{default_logger_in}, ort_hw_device_type{hw_type} { + const OrtLogger& default_logger_in) : ort_api{ort_api_in}, + ep_api{*ort_api_in.GetEpApi()}, + default_logger{default_logger_in}, + data_transfer_impl{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; + GetVendorId = GetVendorIdImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. } // Returns the name for the EP. Each unique factory configuration must have a unique name. @@ -211,18 +572,36 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); + int num_cuda_devices = 0; + cudaGetDeviceCount(&num_cuda_devices); + RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); + + int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { OrtKeyValuePairs* ep_options = nullptr; + OrtKeyValuePairs* ep_metadata = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_options); - ORT_API_RETURN_IF_ERROR( - factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, - &ep_devices[num_ep_devices++])); + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str()); + + RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices])); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + + const OrtMemoryInfo* gpu_mem_info = factory->gpu_memory_infos[device_id].get(); + const OrtMemoryInfo* host_accessible_mem_info = factory->host_accessible_memory_infos[device_id].get(); + + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], gpu_mem_info)); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], host_accessible_mem_info)); + num_ep_devices++; + device_id++; } } - return nullptr; } @@ -241,50 +620,99 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { } static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, - const OrtMemoryInfo* /*memory_info*/, + const OrtMemoryInfo* memory_info, const OrtKeyValuePairs* /*allocator_options*/, OrtAllocator** allocator) noexcept { - auto* factory = static_cast(this_ptr); - - *allocator = nullptr; - return factory->ort_api.CreateStatus( - ORT_INVALID_ARGUMENT, - "CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice."); + auto& factory = *static_cast(this_ptr); + auto allocator_ = std::make_unique(memory_info, factory.ort_api); + *allocator = allocator_.release(); + return nullptr; } - static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept { - // should never be called as we don't implement CreateAllocator + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { + delete static_cast(allocator); } - static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { - *data_transfer = nullptr; // not implemented + auto& factory = *static_cast(this_ptr); + *data_transfer = &factory.data_transfer_impl; return nullptr; } static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { - return false; + return true; } static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, - const OrtMemoryDevice* /*memory_device*/, + const OrtMemoryDevice* memory_device, const OrtKeyValuePairs* /*stream_options*/, - OrtSyncStreamImpl** stream) noexcept { - auto* factory = static_cast(this_ptr); + OrtSyncStreamImpl** ort_stream) noexcept { + auto& factory = *static_cast(this_ptr); + + auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); + cudaStream_t stream = nullptr; + CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - *stream = nullptr; - return factory->ort_api.CreateStatus( - ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false."); + const OrtDevice* ort_device = static_cast(memory_device); + + auto impl = std::make_unique(std::move(stream), *ort_device, nullptr, + /*release_cpu_buffer_on_cuda_stream*/ true, + factory.ort_api); + *ort_stream = impl.release(); + return nullptr; } + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { + gpu_memory_infos.reserve(num_devices); + host_accessible_memory_infos.reserve(num_devices); + + for (int device_id = 0; device_id < num_devices; ++device_id) { + OrtMemoryInfo* mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("NvTensorRTRTX", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + + mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("NvTensorRTRTX host accessible", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + host_accessible_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + } + return nullptr; + } + + private: const OrtApi& ort_api; + const OrtEpApi& ep_api; const OrtLogger& default_logger; const std::string ep_name{kNvTensorRTRTXExecutionProvider}; const std::string vendor{"NVIDIA"}; // NVIDIA vendor ID. Refer to the ACPI ID registry (search NVIDIA): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{0x10de}; - const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice + + std::vector gpu_memory_infos; + std::vector host_accessible_memory_infos; + + // we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to + NvTrtRtxDataTransferImpl data_transfer_impl; + + NvTensorRtRtxEpFactory(const NvTensorRtRtxEpFactory&) = delete; + NvTensorRtRtxEpFactory& operator=(const NvTensorRtRtxEpFactory&) = delete; + + NvTensorRtRtxEpFactory(NvTensorRtRtxEpFactory&&) = default; + NvTensorRtRtxEpFactory& operator=(NvTensorRtRtxEpFactory&&) = default; }; extern "C" { @@ -297,14 +725,14 @@ OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); // Factory could use registration_name or define its own EP name. - auto factory_gpu = std::make_unique(*ort_api, *default_logger, OrtHardwareDeviceType_GPU); + auto factory = std::make_unique(*ort_api, *default_logger); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, "Not enough space to return EP factory. Need at least one."); } - factories[0] = factory_gpu.release(); + factories[0] = factory.release(); *num_factories = 1; return nullptr; diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 8858ae75fb39a..0559699670c4a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -8,8 +8,13 @@ #include "gtest/gtest.h" #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" +#include "test/common/random_generator.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" #include +#include #include #include #include @@ -20,7 +25,7 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; - +extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { @@ -410,9 +415,10 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) { TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx"); std::string graph_name = "test"; - std::vector dims = {1, -1, -1}; - CreateBaseModel(model_name, graph_name, dims, true); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model_name, graph_name, dims); auto env = Ort::Env(); auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; @@ -429,6 +435,151 @@ TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); } + +TEST(NvExecutionProviderTest, GetSharedAllocator) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + + const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(nv_tensorrt_rtx_ep.get(), OrtDeviceMemoryType_DEFAULT); + + // validate there is a shared allocator + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + + const auto* ep_host_accessible_memory_info = c_api.EpDevice_MemoryInfo(nv_tensorrt_rtx_ep.get(), OrtDeviceMemoryType_HOST_ACCESSIBLE); + OrtAllocator* host_accessible_allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_host_accessible_memory_info, &host_accessible_allocator)); + ASSERT_NE(host_accessible_allocator, nullptr); +} + +TEST(NvExecutionProviderTest, LoadUnloadPluginLibrary) { + const std::filesystem::path& library_path = Utils::nv_tensorrt_rtx_ep_info.library_path; + const std::string& registration_name = Utils::nv_tensorrt_rtx_ep_info.registration_name; + + const OrtApi* c_api = &Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, registration_name.c_str(), + library_path.c_str())); + + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices = 0; + + ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); + // should be one device for the example EP + auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, + [®istration_name, &c_api](const OrtEpDevice* device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. + return c_api->EpDevice_EpName(device) == registration_name; + }); + ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; + + // and this should unload it + ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, + registration_name.c_str())); +} + +TEST(NvExecutionProviderTest, LoadUnloadPluginLibraryCxxApi) { + const std::filesystem::path& library_path = Utils::nv_tensorrt_rtx_ep_info.library_path; + const std::string& registration_name = Utils::nv_tensorrt_rtx_ep_info.registration_name; + const OrtApi* c_api = &Ort::GetApi(); + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [®istration_name, &c_api](const Ort::ConstEpDevice& device) { + return device.EpName() == registration_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; + + // test all the C++ getters. expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "NVIDIA"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), ORT_VERSION); + + // the GPU device info will vary by machine so check for the lowest common denominator values + Ort::ConstHardwareDevice device = test_ep_device->Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_GPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + Ort::ConstKeyValuePairs device_metadata = device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + +TEST(NvExecutionProviderTest, DataTransfer) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + const OrtEpDevice* ep_device = nv_tensorrt_rtx_ep.get(); + + const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + + // create a tensor using the default CPU allocator + Ort::AllocatorWithDefaultOptions cpu_allocator; + std::vector shape{2, 3, 4}; // shape doesn't matter + const size_t num_elements = 2 * 3 * 4; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(shape, 0.0f, 2.f); + Ort::Value cpu_tensor = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), + input_data.data(), input_data.size(), + shape.data(), shape.size()); + + // create an on-device Tensor using the NV TensorRT RTX EP GPU allocator. + + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + std::vector src_tensor_ptrs{cpu_tensor}; + std::vector dst_tensor_ptrs{device_tensor}; + + ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, + src_tensor_ptrs.size())); + + // Copy data back from device_tensor to a new CPU tensor and verify the contents + + // Create a new CPU tensor to receive the data + Ort::Value cpu_tensor_copy = Ort::Value::CreateTensor(cpu_allocator, shape.data(), shape.size()); + + std::vector src_tensor_ptrs_back{device_tensor}; + std::vector dst_tensor_ptrs_back{cpu_tensor_copy}; + + ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs_back.data(), dst_tensor_ptrs_back.data(), nullptr, + src_tensor_ptrs_back.size())); + + const float* src_data = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); + + const float* cpu_copy_data = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor_copy, reinterpret_cast(&cpu_copy_data))); + + ASSERT_NE(src_data, cpu_copy_data) << "Should have copied between two different memory locations"; + + size_t bytes; + ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + ASSERT_EQ(bytes, num_elements * sizeof(float)); + + auto src_span = gsl::make_span(src_data, num_elements); + auto cpu_copy_span = gsl::make_span(cpu_copy_data, num_elements); + + EXPECT_THAT(cpu_copy_span, ::testing::ContainerEq(src_span)); + + // must release this before we unload the EP and the allocator is deleted + device_tensor = Ort::Value(); +} + #endif // defined(WIN32) } // namespace test diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc new file mode 100644 index 0000000000000..f0ce5c0b296ca --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/api_asserts.h" + +namespace onnxruntime { +namespace test { + +Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; + +void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { + const OrtApi& c_api = Ort::GetApi(); + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices; + ASSERT_ORTSTATUS_OK(c_api.GetEpDevices(env, &ep_devices, &num_devices)); + + auto it = std::find_if(ep_devices, ep_devices + num_devices, + [&c_api, &ep_name](const OrtEpDevice* ep_device) { + // NV TensorRT RTX EP uses registration name as ep name + return c_api.EpDevice_EpName(ep_device) == ep_name; + }); + + if (it == ep_devices + num_devices) { + ep_device = nullptr; + } else { + ep_device = *it; + } +} + +void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& registered_ep) { + const OrtApi& c_api = Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(env, + nv_tensorrt_rtx_ep_info.registration_name.c_str(), + nv_tensorrt_rtx_ep_info.library_path.c_str())); + const OrtEpDevice* nv_tensorrt_rtx_ep = nullptr; + GetEp(env, nv_tensorrt_rtx_ep_info.registration_name, nv_tensorrt_rtx_ep); + ASSERT_NE(nv_tensorrt_rtx_ep, nullptr); + + registered_ep = RegisteredEpDeviceUniquePtr(nv_tensorrt_rtx_ep, [&env, c_api](const OrtEpDevice* /*ep*/) { + c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); + }); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h new file mode 100644 index 0000000000000..ef14d3cb382c0 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/constants.h" + +namespace onnxruntime { +namespace test { + +using RegisteredEpDeviceUniquePtr = std::unique_ptr>; + +struct Utils { + struct NvTensorRtRtxEpInfo { + const std::filesystem::path library_path = +#if _WIN32 + "onnxruntime_providers_nv_tensorrt_rtx.dll"; +#else + "libonnxruntime_providers_nv_tensorrt_rtx.so"; +#endif + const std::string registration_name = kNvTensorRTRTXExecutionProvider; + }; + + static NvTensorRtRtxEpInfo nv_tensorrt_rtx_ep_info; + + // get the OrtEpDevice for the NV TensorRT RTX EP from the environment + static void GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device); + + // Register the NV TensorRT RTX EP library, get the OrtEpDevice for it, and return a unique pointer that will + // automatically unregister the EP library. + static void RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& nv_tensorrt_rtx_ep); +}; +} // namespace test +} // namespace onnxruntime From ea1a44ed7d4f7ecbdd6c0fdd3c6b58750f987fe9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 25 Jul 2025 08:52:45 -0700 Subject: [PATCH 13/33] [EP ABI] OrtGraphToProto utils fixes (#25531) ### Description Fixes for the OrtGraphToProto utilities that EPs can copy and modify: - When serializing `OrtGraph` to ONNX protobuf, do not set an `onnx::TensorShapeProto` for `onnx::ValueInfo` if the shape has no dimension entries. Otherwise, the shape incorrectly looks like a scalar. - Add `ORT_OP_ATTR_GRAPH` to the enum values returned by the `OpAttr_GetType` C API function. This allows the OrtGraphToProto utilities to skip processing subgraph attributes, which can be retrieved via a different API, but return an error on any unsupported attribute type. ### Motivation and Context --- .../core/providers/utils/ort_graph_to_proto.h | 37 ++++++++++++------- .../core/session/onnxruntime_c_api.h | 1 + onnxruntime/core/session/onnxruntime_c_api.cc | 4 ++ onnxruntime/test/ep_graph/test_ep_graph.cc | 16 ++++++-- onnxruntime/test/util/include/api_asserts.h | 6 +++ 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index b7311f70cd179..0d920ab7dac89 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -366,13 +366,18 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, for (const OrtOpAttr* ort_attr : ort_attrs) { OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; - if (!status.IsOK()) { - // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { + // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } + if (!attr_type_status.IsOK()) { + // Unsupported attribute type. + return attr_type_status; + } + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); } @@ -622,20 +627,24 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks + // like a scalar value. + if (!ort_dims.empty()) { + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); - if (ort_dims[dim_idx] >= 0) { - dim_proto->set_dim_value(ort_dims[dim_idx]); - } else { - const std::string& dim_param = ort_dim_syms[dim_idx]; + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, - // which represents an unknown dimension. - if (!dim_param.empty()) { - dim_proto->set_dim_param(dim_param); + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } } } } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a4cf17845a494..2f0e4aa7ce108 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -274,6 +274,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_FLOATS, ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, + ORT_OP_ATTR_GRAPH, } OrtOpAttrType; //! @} diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f6b6335dd29c0..27f81b18be0c9 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3048,6 +3048,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_STRINGS; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH: { + *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index d0f682491e4f9..45314f8f39eea 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -166,7 +166,8 @@ TEST(EpGraphTest, SerializeToProto_InputModelHasExternalIni) { }; ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); std::ofstream ofs(serialized_model_path, std::ios::binary); model_proto.SerializeToOstream(&ofs); @@ -257,7 +258,8 @@ TEST(EpGraphTest, SerializeToProto_Mnist) { }; ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); std::ofstream ofs(serialized_model_path, std::ios::binary); model_proto.SerializeToOstream(&ofs); @@ -301,7 +303,7 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { }; ONNX_NAMESPACE::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(ort_graph, graph_proto, handle_initializer_data); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(ort_graph, graph_proto, handle_initializer_data)); // Verify that TensorProto objects within GraphProto point to memory owned by OrtValues in the OrtGraph. const OrtApi& ort_api = Ort::GetApi(); @@ -393,7 +395,7 @@ TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto)); std::ofstream ofs(serialized_model_path, std::ios::binary); model_proto.SerializeToOstream(&ofs); @@ -848,6 +850,8 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping. // In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here. + // TODO: Once we add support for ORT_OP_ATTR_TENSOR, we should be able to just fail if OpAttr_GetType + // returns an error. OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); @@ -884,6 +888,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRINGS); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); diff --git a/onnxruntime/test/util/include/api_asserts.h b/onnxruntime/test/util/include/api_asserts.h index 9d34be24d5012..423135f96fbcd 100644 --- a/onnxruntime/test/util/include/api_asserts.h +++ b/onnxruntime/test/util/include/api_asserts.h @@ -37,3 +37,9 @@ EXPECT_NE(_tmp_status, nullptr); \ if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ } while (false) + +#define ASSERT_CXX_ORTSTATUS_OK(function) \ + do { \ + Ort::Status _tmp_status = (function); \ + ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ + } while (false) From 01912cebe886d8a2aa24c730a163460c126072d6 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 25 Jul 2025 09:37:06 -0700 Subject: [PATCH 14/33] Upgrade onnxruntime-Ubuntu2204-AMD-CPU machine pool to 24.04 (#25534) ### Description 1. Upgrade onnxruntime-Ubuntu2204-AMD-CPU machine pool to Ubuntu 24.04, which can fix some vulnerability management issues. 2. Fix some packaging pipeline issues and remove some unused code blocks from dml-vs-2022.yml --- .../build-perf-test-binaries-pipeline.yml | 12 ++-- .../c-api-noopenmp-test-pipelines.yml | 10 +++ .../azure-pipelines/dml-nuget-packaging.yml | 2 - .../mac-react-native-ci-pipeline.yml | 2 +- .../azure-pipelines/nodejs/templates/test.yml | 16 ++--- .../nodejs/templates/test_linux.yml | 1 - .../nodejs/templates/test_macos.yml | 2 +- .../npm-packaging-pipeline.yml | 4 +- .../azure-pipelines/nuget-windows-ai.yml | 2 +- .../nuget/templates/dml-vs-2022.yml | 68 ++----------------- .../nuget/templates/test_android.yml | 14 ++-- .../nuget/templates/test_linux.yml | 31 ++++----- .../nuget/templates/test_macos.yml | 16 ++--- .../azure-pipelines/post-merge-jobs.yml | 6 +- .../py-cuda-package-test-pipeline.yml | 2 +- .../py-package-test-pipeline.yml | 4 +- .../stages/c-api-linux-cpu-stage.yml | 2 +- .../stages/download-java-tools-stage.yml | 2 +- .../stages/nodejs-linux-packaging-stage.yml | 2 +- .../nuget-linux-cuda-packaging-stage.yml | 4 +- .../stages/nuget_dml_packaging_stage.yml | 2 +- .../stages/py-cpu-packaging-stage.yml | 4 +- .../stages/py-cuda-publishing-stage.yml | 2 +- .../stages/py-gpu-packaging-stage.yml | 2 +- .../stages/set_packaging_variables_stage.yml | 2 +- .../templates/android-java-api-aar-test.yml | 2 +- .../templates/android-java-api-aar.yml | 2 +- .../azure-pipelines/templates/c-api-cpu.yml | 9 --- .../templates/c-api-linux-cpu.yml | 2 +- .../linux-cpu-packaging-pipeline.yml | 2 +- .../templates/linux-wasm-ci.yml | 2 +- .../azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../azure-pipelines/web-ci-pipeline.yml | 2 +- 33 files changed, 83 insertions(+), 154 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 0ce4227c9ef9f..5cf5cd8c936fa 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -20,13 +20,17 @@ stages: artifactName: 'onnxruntime-android-full-aar' job_name_suffix: 'Full' publish_executables: '1' - pool_name: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool_name: 'onnxruntime-Ubuntu2404-AMD-CPU' enable_code_sign: false # build Python packages # Linux GPU only - ${{ if parameters.BuildPythonPackages }}: - - template: stages/py-gpu-packaging-stage.yml + - template: stages/py-linux-gpu-stage.yml parameters: - enable_linux_cuda: true - cuda_version: 12.2 + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' + extra_build_arg: '' + cmake_build_type: Release + cuda_version: 12.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index ed6183b3fa6da..64e5661eaf6fe 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -75,6 +75,16 @@ stages: artifactName: 'onnxruntime-android-full-aar' ReleaseVersionSuffix: $(ReleaseVersionSuffix) +- stage: Final_AAR_Testing_Android_QNN + dependsOn: Setup + jobs: + - template: templates/android-java-api-aar-test.yml + parameters: + artifactName: 'onnxruntime-android-qnn-aar' + packageName: 'onnxruntime-android-qnn' + #TODO: get this information from the setup stage + QnnSDKVersion: '2.36.1.250708' + - template: nuget/templates/test_win.yml parameters: AgentPool: 'onnxruntime-Win-CPU-2022' diff --git a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml index 0e0a0632b9b6c..6e196e1f8ffd3 100644 --- a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml @@ -68,7 +68,6 @@ extends: ArtifactName: 'drop-nuget-dml' StageName: 'Windows_CI_GPU_DML_Dev' BuildCommand: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache - BuildArch: 'x64' msbuildArchitecture: 'amd64' EnvSetupScript: 'setup_env.bat' sln_platform: 'x64' @@ -88,7 +87,6 @@ extends: ArtifactName: 'drop-win-dml-arm64-zip' StageName: 'Windows_CI_GPU_DML_Dev_arm64' BuildCommand: --build_dir $(Build.BinariesDirectory) --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache - BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'arm64' DoDebugBuild: 'false' diff --git a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml index e72f088cfeb55..c71cd95150aa6 100644 --- a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml @@ -55,5 +55,5 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' enable_code_sign: false diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml index 28ece85428287..6c998f9c3da13 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml @@ -1,10 +1,12 @@ steps: - checkout: none -- task: DownloadPipelineArtifact@0 - displayName: 'Download NPM packages' - inputs: - artifactName: NPM_packages - targetPath: '$(Build.BinariesDirectory)/nodejs-artifact' +- download: build + displayName: 'Download NPM_packages' + artifact: 'NPM_packages' + +- script: | + mv $(Pipeline.Workspace)/build/NPM_packages '$(Build.BinariesDirectory)/nodejs-artifact' + - script: mkdir e2e_test workingDirectory: '$(Build.BinariesDirectory)' @@ -31,6 +33,4 @@ steps: npm init -y npm install $(NpmPackageFilesForTest) --onnxruntime-node-install-cuda=skip node -p "require('onnxruntime-node')" - workingDirectory: '$(Build.BinariesDirectory)/e2e_test' - -- template: ../../templates/clean-agent-build-directory-step.yml + workingDirectory: '$(Build.BinariesDirectory)/e2e_test' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 13516b93db4e0..50121595aed54 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -13,7 +13,6 @@ stages: timeoutInMinutes: 120 pool: name: ${{ parameters.AgentPool }} - os: 'linux' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 6f51abb761c51..bb4f600395ac9 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -12,7 +12,7 @@ stages: timeoutInMinutes: 120 pool: name: 'Azure Pipelines' - image: 'macOS-14' + image: 'macOS-15' os: 'macOS' variables: diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index f6d404c3bde62..3615f9f7c0960 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -55,7 +55,7 @@ extends: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: true - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' PackageName: 'onnxruntime-web' ExtraBuildArgs: '' UseWebPoolName: true @@ -69,7 +69,7 @@ extends: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' PackageName: 'onnxruntime-react-native' InitialStageDependsOn: 'Precheck_and_extract_commit' enable_code_sign: false diff --git a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml index feffd6b268c17..8e29381bc7eb4 100644 --- a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml +++ b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml @@ -100,7 +100,7 @@ extends: - output: pipelineArtifact path: '$(Build.ArtifactStagingDirectory)/merged' artifact: drop_Windows_Build_NuGet_Packaging - - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/users/snnn/')))}}: + - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')))}}: - output: nuget useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. packagesToPush: '$(Build.ArtifactStagingDirectory)/merged/*.nupkg;!$(Build.ArtifactStagingDirectory)/merged/*.symbols.nupkg' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 71ff44ebb2ae5..757b8ac6e9a16 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -6,11 +6,8 @@ parameters: DoNugetPack: 'false' NuPackScript : '' ArtifactName: 'drop-nuget' - DoNodejsPack: 'false' - BuildNodejs: 'true' DoEsrp: 'false' DoTestCoverage: 'false' - BuildArch: 'x64' # Optional. Options: x86, x64 sln_platform: 'x64' # Options: Win32, x64, arm, arm64 EnvSetupScript: 'setup_env.bat' AgentDemands: [] @@ -40,7 +37,6 @@ stages: variables: buildDirectory: '$(Build.BinariesDirectory)' OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' - runCodesignValidationInjection: and(${{ parameters.DoNodejsPack }},${{ parameters. DoEsrp}}) #For the others, code sign is in a separated job DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] @@ -63,7 +59,7 @@ stages: inputs: versionSpec: '3.12' addToPath: true - architecture: ${{ parameters.BuildArch }} + architecture: x64 - task: PipAuthenticate@1 displayName: 'Pip Authenticate' inputs: @@ -74,13 +70,13 @@ stages: inputs: version: 8.x env: - PROCESSOR_ARCHITECTURE: ${{ parameters.BuildArch }} + PROCESSOR_ARCHITECTURE: x64 - task: BatchScript@1 displayName: 'Setup VS2022 env vars' inputs: filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' - arguments: ${{ parameters.BuildArch }} + arguments: x64 modifyEnvironment: true - ${{ if notIn(parameters['sln_platform'], 'Win32', 'x64') }}: @@ -114,7 +110,7 @@ stages: inputs: version: 6.x env: - PROCESSOR_ARCHITECTURE: ${{ parameters.BuildArch }} + PROCESSOR_ARCHITECTURE: x64 - template: ../../templates/win-esrp-dll.yml parameters: @@ -148,64 +144,10 @@ stages: ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: symbolExpiryTime: 60 includePublicSymbolServer: true - symbolsArtifactName: onnxruntime-dml-nuget-${{ parameters.BuildArch }} + symbolsArtifactName: onnxruntime-dml-nuget-${{ parameters.sln_platform }} symbolsVersion: $(Build.BuildId) symbolProject: 'ONNX Runtime' subscription: 'OnnxrunTimeCodeSign_20240611' searchPattern: | $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime.pdb $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime_providers_*.pdb - - # Node.js Publish - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - task: BatchScript@1 - displayName: 'Setup VS env vars' - inputs: - filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' - arguments: ${{ parameters.BuildArch }} - modifyEnvironment: true - - template: ../../templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\x64' - DisplayName: 'ESRP - Sign Node.js binding binaries' - DoEsrp: ${{ parameters.DoEsrp }} - Pattern: '*.dll,*.node' - - - script: | - del /Q $(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\x64\CodeSignSummary-*.* - call npm pack - copy $(Build.SourcesDirectory)\js\node\onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - xcopy /E /I $(Build.SourcesDirectory)\js\node\prebuilds $(Build.ArtifactStagingDirectory)\prebuilds - workingDirectory: '$(Build.SourcesDirectory)\js\node' - displayName: 'Create NPM Package' - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - # Put an unzipped version there to check if all the binaries are signed. - - script: | - 7z x $(Build.ArtifactStagingDirectory)\prebuilds\onnxruntime-*.tar.gz - 7z x $(Build.ArtifactStagingDirectory)\onnxruntime-*.tar - displayName: 'Unzip package to test' - workingDirectory: '$(Build.ArtifactStagingDirectory)' - - - ${{ if eq(parameters.BuildNodejs, 'true') }}: - - task: CopyFiles@2 - displayName: 'Copy DirectML binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' - Contents: 'DirectML.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - - template: ../../templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - DisplayName: 'ESRP - Sign Node.js binding binaries' - DoEsrp: ${{ parameters.DoEsrp }} - Pattern: '*.node' - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.sln_platform }}-dml' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml index 17ea414152be8..e75804f0b35cb 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml @@ -24,17 +24,13 @@ stages: inputs: versionSpec: 6.10.x - - template: ../../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact' - ArtifactName: drop-signed-nuget-${{ parameters.ArtifactSuffix }} - TargetPath: '$(Build.BinariesDirectory)\nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - download: build + displayName: 'Download Nuget' + artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' - template: get-nuget-package-version-as-variable.yml parameters: - packageFolder: '$(Build.BinariesDirectory)\nuget-artifact' + packageFolder: '$(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }}' - task: PowerShell@2 displayName: Install MAUI workloads @@ -49,7 +45,7 @@ stages: inputs: targetType: 'inline' script: | - dotnet nuget add source $(Build.BinariesDirectory)\nuget-artifact --name local-nuget + dotnet nuget add source $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} --name local-nuget dotnet publish -c Release --property:UsePrebuiltNativePackage=true --property:CurrentOnnxRuntimeVersion=$(NuGetPackageVersionNumber) -f net8.0-android workingDirectory: '$(Build.SourcesDirectory)\csharp\test\Microsoft.ML.OnnxRuntime.Tests.MAUI' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index 2f4f480eeb122..89ce3f3c86727 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -1,5 +1,5 @@ parameters: - AgentPool: 'onnxruntime-Ubuntu2204-AMD-CPU' + AgentPool: 'onnxruntime-Ubuntu2404-AMD-CPU' ArtifactSuffix: '' NugetPackageName: '' StageSuffix: 'CPU' @@ -30,21 +30,18 @@ stages: value: '$(Build.BinariesDirectory)' steps: - - template: ../../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Signed NuGet' - ArtifactName: drop-signed-nuget-${{ parameters.ArtifactSuffix }} - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - download: build + displayName: 'Download Nuget' + artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' + - download: build + displayName: 'Download Linux CustomOp TestData' + artifact: ${{ parameters.CustomOpArtifactName }} - - template: ../../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Linux CustomOp TestData' - ArtifactName: ${{ parameters.CustomOpArtifactName }} - TargetPath: '$(Build.BinariesDirectory)/testdata' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} + + - script: | + mv $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} $(Build.BinariesDirectory)/nuget-artifact + mv $(Pipeline.Workspace)/build/${{ parameters.CustomOpArtifactName }} $(Build.BinariesDirectory)/testdata + - template: get-nuget-package-version-as-variable.yml parameters: @@ -110,6 +107,4 @@ stages: DisableContribOps: $(DisableContribOps) DisableMlOps: $(DisableMlOps) IsReleaseBuild: $(IsReleaseBuild) - PACKAGENAME: ${{ parameters.NugetPackageName }} - - - template: ../../templates/clean-agent-build-directory-step.yml + PACKAGENAME: ${{ parameters.NugetPackageName }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml index dcaa8f9381ad4..1d122d64b1211 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml @@ -21,23 +21,19 @@ stages: displayName: 'Download Nuget' artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' + - download: build + displayName: 'Download Nuget' + artifact: 'onnxruntime-osx' - script: | mv $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} $(Build.BinariesDirectory)/nuget-artifact - - - - task: DownloadPipelineArtifact@0 - displayName: 'Download OsX CustomOp test data' - inputs: - artifactName: 'onnxruntime-osx' - targetPath: '$(Build.BinariesDirectory)/testdata' + mv $(Pipeline.Workspace)/build/onnxruntime-osx $(Build.BinariesDirectory)/testdata - template: get-nuget-package-version-as-variable.yml parameters: packageFolder: '$(Build.BinariesDirectory)/nuget-artifact' - script: | - echo "TODO: Enable this test once fix this nuget test issue" $(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh \ $(Build.BinariesDirectory)/nuget-artifact \ $(NuGetPackageVersionNumber) \ @@ -52,6 +48,4 @@ stages: OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) DisableContribOps: $(DisableContribOps) DisableMlOps: $(DisableMlOps) - IsReleaseBuild: $(IsReleaseBuild) - - - template: ../../templates/clean-agent-build-directory-step.yml + IsReleaseBuild: $(IsReleaseBuild) \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 11beafb7c05e1..8647b32962165 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -19,7 +19,7 @@ stages: parameters: NpmPackagingMode: 'dev' IsReleasePipeline: true - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: '' UseWebPoolName: true @@ -340,7 +340,7 @@ stages: timeoutInMinutes: 150 variables: skipComponentGovernanceDetection: true - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool: 'onnxruntime-Ubuntu2404-AMD-CPU' steps: - template: templates/set-version-number-variables-step.yml @@ -383,7 +383,7 @@ stages: - job: AndroidCustomBuildScript workspace: clean: all - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool: 'onnxruntime-Ubuntu2404-AMD-CPU' variables: dockerImageTag: onnxruntime-android-custom-build steps: diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index b10d15432ed5b..a21c72f5278c0 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -25,7 +25,7 @@ stages: dependsOn: jobs: - job: Python_Publishing_GPU - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool: 'onnxruntime-Ubuntu2404-AMD-CPU' steps: - checkout: none - download: build diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 01c1366107292..379b20ce8a0c4 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -11,7 +11,7 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' - stage: Linux_Test_CPU_aarch64_stage @@ -38,7 +38,7 @@ stages: itemPattern: '*/*manylinux*x86_64.whl' arch: 'x86_64' machine_pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' # ****The following Stage depend on all previous tags. *** diff --git a/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml index ee46d5dac2ff8..ea706a65fb4c9 100644 --- a/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml @@ -6,6 +6,6 @@ stages: parameters: OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' PackageJava: false PackageNodeJS: false \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml b/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml index 67fa5dba029b1..949d29d27da9d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml @@ -4,7 +4,7 @@ stages: jobs: - job: Download_Java_Tools pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml index e1247565d8f5b..bca95a4a2fd02 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml @@ -13,7 +13,7 @@ stages: clean: all timeoutInMinutes: 180 pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux variables: - template: ../templates/common-variables.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index e36fe98fe0ac2..4175a339535e4 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -16,7 +16,7 @@ stages: clean: all timeoutInMinutes: 150 pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux variables: - name: CUDA_VERSION_MAJOR @@ -65,7 +65,7 @@ stages: clean: all timeoutInMinutes: 180 pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux variables: - template: ../templates/common-variables.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml index 06b52173b236c..33d656d18928d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml @@ -29,7 +29,7 @@ stages: artifactName: drop-win-dml-arm64-zip targetPath: '$(Build.BinariesDirectory)/nuget-artifact-dml' outputs: - - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/users/snnn/')))}}: + - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')))}}: - output: nuget useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. packagesToPush: '$(Build.ArtifactStagingDirectory)/*.nupkg' diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index e366dd147b118..c1b83c5e579dc 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -302,7 +302,7 @@ stages: - template: ../templates/py-linux.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true @@ -346,7 +346,7 @@ stages: jobs: - template: ../templates/py-linux-qnn.yml parameters: - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml index fbfbc69bce0a8..25645044c30c3 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml @@ -8,7 +8,7 @@ stages: jobs: - job: Python_Publishing_GPU pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index 4058ddfe089c8..202856cddbcd4 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -45,7 +45,7 @@ stages: - template: py-linux-gpu-stage.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} diff --git a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml index 869fe05cb1756..396d37ca9710a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml @@ -24,7 +24,7 @@ stages: jobs: - job: Set_Variables pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: 'linux' templateContext: sdl: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 5b95e6ff9c89a..6e6fb98e6e68c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -29,7 +29,7 @@ parameters: jobs: - job: Final_AAR_Testing_Android pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux workspace: clean: all diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index bbb84642320fb..e4bfe20238770 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -33,7 +33,7 @@ parameters: - name: pool_name displayName: Pool name type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: packageName displayName: Package Name diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index f4335df1530cf..bf65b0c54cf27 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -100,15 +100,6 @@ stages: QnnSDKVersion: ${{ parameters.QnnSDKVersion }} is1ES: ${{ parameters.is1ES }} -- stage: Final_AAR_Testing_Android_QNN - dependsOn: Android_Java_API_AAR_Packaging_QNN - jobs: - - template: android-java-api-aar-test.yml - parameters: - artifactName: 'onnxruntime-android-qnn-aar' - packageName: 'onnxruntime-android-qnn' - QnnSDKVersion: ${{ parameters.QnnSDKVersion }} - - stage: iOS_Full_xcframework dependsOn: [] jobs: diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index cd2997cc389e9..aa1e38f8b0159 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -16,7 +16,7 @@ parameters: - name: PoolName type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: ArtifactNamePrefix type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index fb1c63e1f8a24..986a384d5197d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -31,7 +31,7 @@ stages: AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} PackageNodeJS: false diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 9f76c150ca2a4..ef0f4c6e0883c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -13,7 +13,7 @@ parameters: - name: PoolName type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: SkipPublish type: boolean diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 7375318f3722e..52d9eb139fab7 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -29,7 +29,7 @@ stages: enabled: true scanOutputDirectoryOnly: true outputs: - - ${{if and(and(eq(parameters.PublishNugetToFeed, true), eq(parameters.IsReleaseBuild, false)), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/users/snnn/')))}}: + - ${{if and(and(eq(parameters.PublishNugetToFeed, true), eq(parameters.IsReleaseBuild, false)), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')))}}: - output: nuget # condition: and(succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) # Optional condition useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. diff --git a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml index 01920ad1f7fbb..4399219f3f7d5 100644 --- a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml @@ -54,7 +54,7 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: false - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: $(ExtraBuildArgs) WASMTemplate: linux-wasm-ci.yml From f4156640bda5ed465767bad6d639a8bf78231510 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 26 Jul 2025 01:00:46 +0800 Subject: [PATCH 15/33] [WebNN] Respect to the shape of zero_point for ConvInteger (#25484) WebNN requires the shapes of zeroPoint and scale for a qdq op to be same. However the ONNX allows [1] as scalar shape and some models may use [1] as the shape for x_zero_point. We should explicitly set the shape of scale to x_zero_point. --- js/web/test/suite-test-list.jsonc | 6 +++--- .../webnn/builders/impl/conv_op_builder.cc | 16 +++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index d08a72b922142..3f1face2a043c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1598,7 +1598,7 @@ // "test_averagepool_3d_default", "test_basic_conv_with_padding", "test_basic_conv_without_padding", - // "test_basic_convinteger", + "test_basic_convinteger", "test_batchnorm_epsilon_training_mode", "test_batchnorm_epsilon", "test_batchnorm_example_training_mode", @@ -1686,8 +1686,8 @@ "test_conv_with_strides_and_asymmetric_padding", "test_conv_with_strides_no_padding", "test_conv_with_strides_padding", - // // "test_convinteger_with_padding", - // // "test_convinteger_without_padding", + "test_convinteger_with_padding", + "test_convinteger_without_padding", "test_convtranspose_1d", // // "test_convtranspose_3d", // "test_convtranspose_autopad_same", diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index f75b6f41f7f9c..109228cc60d7d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -317,32 +317,34 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N int32_t x_type; ORT_RETURN_IF_NOT(GetType(*input_defs[0], x_type, logger), "Cannot get data type of input x"); - emscripten::val x_zero_point, w_zero_point, x_scale, w_scale; + emscripten::val x_zero_point, w_zero_point; + std::vector x_zero_point_shape; if (TensorExists(input_defs, 2)) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + ORT_RETURN_IF_NOT(GetShape(*input_defs[2], x_zero_point_shape, logger), "Cannot get shape of x_zero_point"); } else { x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. - x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); + // ONNX allows 1D tensor of size 1 as scalar. So explicitly set the shape of x_scale to x_zero_point_shape. + emscripten::val x_scale = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, GetNarrowedIntFromInt64(x_zero_point_shape)); // Dequantize x to Float32 common_options.set("label", node.Name() + "_dequantized_x"); input = model_builder.GetBuilder().call("dequantizeLinear", input, x_scale, x_zero_point, common_options); + std::vector w_zero_point_shape; if (TensorExists(input_defs, 3)) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); - std::vector w_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[3], w_zero_point_shape, logger), "Cannot get shape of w_zero_point"); - w_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, - GetNarrowedIntFromInt64(w_zero_point_shape)); } else { w_zero_point = model_builder.CreateOrGetConstant(x_type, 0); - w_scale = x_scale; } + emscripten::val w_scale = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, GetNarrowedIntFromInt64(w_zero_point_shape)); // Dequantize w to Float32 common_options.set("label", node.Name() + "_dequantized_w"); filter = model_builder.GetBuilder().call("dequantizeLinear", filter, w_scale, w_zero_point, From 11aebeb281950741a0350d0354b0aac9ad032511 Mon Sep 17 00:00:00 2001 From: "microsoft-github-policy-service[bot]" <77245923+microsoft-github-policy-service[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 10:11:49 -0700 Subject: [PATCH 16/33] Auto-generated baselines by 1ES Pipeline Templates (#25536) Co-authored-by: microsoft-github-policy-service[bot] <77245923+microsoft-github-policy-service[bot]@users.noreply.github.com> --- .config/1espt/PipelineAutobaseliningConfig.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml index 68bf98b3a2781..f59528797405e 100644 --- a/.config/1espt/PipelineAutobaseliningConfig.yml +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -133,13 +133,16 @@ pipelines: lastModifiedDate: 2025-04-24 armory: lastModifiedDate: 2025-04-24 + policheck: + lastModifiedDate: 2025-07-25 binary: credscan: lastModifiedDate: 2025-04-25 binskim: - lastModifiedDate: 2025-04-25 + lastModifiedDate: 2025-07-25 spotbugs: lastModifiedDate: 2025-04-25 + usedBinskimScanAllExtensions: true 1757: retail: source: From cd450d1563d65fcf8d1748daad894bc036e9efad Mon Sep 17 00:00:00 2001 From: Damien Dooley Date: Fri, 25 Jul 2025 18:23:19 +0100 Subject: [PATCH 17/33] KleidiAI SGEMM/IGEMM/Quantized MatMul - Modular MLAS API Changes for KleidiAI (#25187) This PR introduces the initial integration of KleidiAI-optimized microkernels into ONNX Runtime's MLAS backend, focusing on support for: - SGEMM - IGEMM - Dynamic Quantized MatMuls Key changes: Implements overrides for MlasGemmBatch, MlasGemmPackBSize, and MlasGemmPackB using KleidiAI where applicable. Applies dispatch logic based on TransA == CblasNoTrans and SME2 availability. Supports float32 and int8 GEMM workloads with conditionally invoked SME2 paths. Maintains fallback paths to default MLAS implementations to ensure coverage and stability. **Known Issues / Next Steps:** Requesting feedback specifically on the API structure: Does the new MLAS interface design align with long-term extensibility? Are the dispatch points and override boundaries well-structured? Indicative Performance figures: The kernels added are particularly effective for Conv2D operators: * Based on KleidiAI SME running mobilenet_v1_ssd_f32 on Mac Mini M4 on a single thread image --------- Signed-off-by: Damien Dooley Co-authored-by: Jonathan Clohessy Co-authored-by: Declan Flavin Co-authored-by: Colm Donelan Co-authored-by: Damien Dooley --- cmake/CMakeLists.txt | 41 +- cmake/deps.txt | 2 +- .../external/onnxruntime_external_deps.cmake | 11 + cmake/onnxruntime_mlas.cmake | 24 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 4 +- .../quantization/dynamic_quantize_matmul.cc | 271 +++++-- onnxruntime/core/common/cpuid_info.cc | 2 + onnxruntime/core/common/cpuid_info.h | 2 + onnxruntime/core/mlas/inc/mlas.h | 80 ++ onnxruntime/core/mlas/lib/convolve.cpp | 15 +- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 720 ++++++++++++++++++ .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 114 +++ .../core/mlas/lib/kleidiai/qgemm_kleidiai.cpp | 116 +++ .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 348 +++++++++ onnxruntime/core/mlas/lib/mlasi.h | 127 ++- onnxruntime/core/mlas/lib/platform.cpp | 13 + onnxruntime/core/mlas/lib/qgemm.cpp | 87 ++- onnxruntime/core/mlas/lib/sgemm.cpp | 38 +- .../core/optimizer/matmul_integer_to_float.cc | 15 +- onnxruntime/core/providers/cpu/math/gemm.cc | 8 +- .../providers/cpu/math/gemm_matmul_common.h | 2 +- onnxruntime/core/providers/cpu/math/matmul.cc | 2 +- .../core/providers/cpu/rnn/deep_cpu_gru.cc | 16 +- .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 4 +- onnxruntime/test/mlas/bench/bench_sgemm.cpp | 7 +- .../test/mlas/unittest/test_dynamic_qgemm.cpp | 165 ++++ onnxruntime/test/mlas/unittest/test_fgemm.h | 4 +- .../test/mlas/unittest/test_fgemm_fixture.h | 1 + .../test/optimizer/graph_transform_test.cc | 72 ++ .../test/testdata/matmul_integer_to_float.py | 63 +- ...to_float_int8_bias_initializer_index0.onnx | Bin 0 -> 472 bytes ...to_float_int8_bias_initializer_index1.onnx | Bin 0 -> 472 bytes tools/ci_build/build.py | 17 +- 33 files changed, 2265 insertions(+), 126 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp create mode 100644 onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h create mode 100644 onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp create mode 100644 onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp create mode 100644 onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx create mode 100644 onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index1.onnx diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 611203f0b3f72..b0941b4d0c922 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -83,6 +83,11 @@ option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) +# iOS simulator build explicitly builds targets with USE_KLEIDIAI=ON so attempting to force override if so +if(APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") + message(WARNING "Disabling KleidiAI: not supported on Apple x86_64 platforms") + set(onnxruntime_USE_KLEIDIAI OFF CACHE BOOL "" FORCE) +endif() option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) @@ -275,8 +280,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() - - # Single output director for all binaries set(RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE PATH "Single output directory for all binaries.") @@ -648,17 +651,25 @@ else() endif() endif() -if (onnxruntime_USE_KLEIDIAI AND NOT MSVC AND ( - (onnxruntime_target_platform STREQUAL "aarch64") OR - (onnxruntime_target_platform STREQUAL "ARM64") OR - (APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))) - check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD) - check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM) - if (NOT HAS_ARM64_DOTPROD) - message(FATAL_ERROR "The compiler doesn't support dotprod") - endif() - if (NOT HAS_ARM64_I8MM) - message(FATAL_ERROR "The compiler doesn't support i8mm") +if (onnxruntime_USE_KLEIDIAI AND ( + (onnxruntime_target_platform STREQUAL "aarch64") OR + (onnxruntime_target_platform STREQUAL "ARM64") OR + (APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))) + + # TODO Add checks for MSVC Compilation + if(NOT MSVC) + check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD) + check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM) + if (NOT HAS_ARM64_DOTPROD) + message(FATAL_ERROR "The compiler doesn't support dotprod") + endif() + if (NOT HAS_ARM64_I8MM) + message(FATAL_ERROR "The compiler doesn't support i8mm") + endif() + else() + message(STATUS "Skipping -march= checks on MSVC (not supported), assuming dotprod/i8mm support manually.") + set(HAS_ARM64_DOTPROD TRUE) + set(HAS_ARM64_I8MM TRUE) endif() endif() @@ -1008,6 +1019,10 @@ function(onnxruntime_set_compile_flags target_name) if (onnxruntime_ENABLE_ATEN) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() + # TODO: Narrow scope for Kleidiai compile + if (onnxruntime_USE_KLEIDIAI) + target_compile_definitions(${target_name} PRIVATE USE_KLEIDIAI) + endif() set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) diff --git a/cmake/deps.txt b/cmake/deps.txt index 7089012a65f26..01e5c809640f9 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 dawn;https://github.com/google/dawn/archive/9733be39e18186961d503e064874afe3e9ceb8d1.zip;2a4017c32892b90d072a9102eba90ae691fae36d -kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244 +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.9.0.tar.gz;a2765979f64efb173a4b8ba4de39dcba9c655786 duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 228906030d14c..f76ad642447ba 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + message(STATUS "Loading Dependencies URLs ...") include(external/helper_functions.cmake) @@ -819,6 +822,14 @@ if(onnxruntime_USE_COREML) endif() +if(onnxruntime_USE_KLEIDIAI) + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) + onnxruntime_fetchcontent_makeavailable(kleidiai) +endif() + set(onnxruntime_LINK_DIRS) if (onnxruntime_USE_CUDA) find_package(CUDAToolkit REQUIRED) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 47e7779d93b33..24cecf07e8e36 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -267,24 +267,23 @@ function(setup_mlas_source_for_windows) endfunction() function(setup_kleidiai) - target_compile_definitions(onnxruntime_mlas PRIVATE USE_KLEIDIAI) - - # Disable the KleidiAI tests - set(KLEIDIAI_BUILD_TESTS OFF) - - # Fetch KleidiAI sources: - if (NOT TARGET kleidiai) - onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) - endif() - onnxruntime_fetchcontent_makeavailable(kleidiai) - target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp + ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp ) target_link_libraries(onnxruntime_mlas PRIVATE kleidiai) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai) set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS kleidiai EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() endfunction() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") @@ -311,7 +310,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") elseif(MSVC) setup_mlas_source_for_windows() else() - if(APPLE) get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index de23444e95778..d16c55695772b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -71,7 +71,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, const T* weights_data, size_t weight_matrix_col_size, /*out*/ PrePackedWeights* prepacked_weights) { - size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size); + size_t packb_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size); if (packb_size == 0) { return false; } @@ -87,7 +87,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, memset(packed_weights_data, 0, packed_weights_data_size); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); packed_weights_data += packb_size; weights_data += head_size; } diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 69eabcfe2654a..e2bb3b508ca7c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -11,6 +11,7 @@ #include "core/util/qmath.h" #include +#include namespace onnxruntime { namespace contrib { @@ -65,13 +66,13 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, bool a_is_signed, const Tensor* b_tensor, const Tensor* b_scale_tensor, - const Tensor* b_zp_tensor, + const Tensor* b_zp_constant_tensor, const Tensor* bias_tensor) const { MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a_shape, b_tensor ? b_tensor->Shape() : b_shape_, b_scale_tensor ? &b_scale_tensor->Shape() : nullptr, - b_zp_tensor ? &b_zp_tensor->Shape() : nullptr)); + b_zp_constant_tensor ? &b_zp_constant_tensor->Shape() : nullptr)); Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); // Bail out early if the output is going to be empty @@ -85,12 +86,12 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, bool is_b_zp_per_column = false; uint8_t b_zp_default = 0; const uint8_t* b_zp_ptr = &b_zp_default; - if (nullptr != b_zp_tensor) { - ORT_ENFORCE(IsBQuantParamSupported(b_zp_tensor->Shape(), b_tensor ? b_tensor->Shape() : b_shape_), + if (nullptr != b_zp_constant_tensor) { + ORT_ENFORCE(IsBQuantParamSupported(b_zp_constant_tensor->Shape(), b_tensor ? b_tensor->Shape() : b_shape_), "MatmulInteger : b zero point is not valid"); - is_b_zp_per_column = !IsScalarOr1ElementVector(b_zp_tensor); - b_zp_ptr = static_cast(b_zp_tensor->DataRaw()); + is_b_zp_per_column = !IsScalarOr1ElementVector(b_zp_constant_tensor); + b_zp_ptr = static_cast(b_zp_constant_tensor->DataRaw()); } // process scale of b @@ -161,6 +162,122 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { Status Compute(OpKernelContext* context) const override; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override { + // only pack Matrix B + if (input_idx == GetBIdx()) { + const Tensor* b_zp_constant_tensor{nullptr}; + bool b_quantization_is_asymmetric = false; + + // zero point tensor could be provided as a direct input to the kernel and not as a constant so this + // test is not sufficient + const OrtValue* b_zp; + if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { + b_zp_constant_tensor = &b_zp->Get(); + } + + // MlasDynamicQgemm requires symmetric quantization for B, so no zero point should exist or it should + // have a zero value + if (b_zp_constant_tensor != nullptr) { // Covers the case where tensor is not a constant + const auto& shape = b_zp_constant_tensor->Shape(); + const auto* zp_data = static_cast(b_zp_constant_tensor->DataRaw()); + size_t zp_size = static_cast(shape.Size()); + // MlasDynamicQgemm requires symmetric quantization: zp must be scalar 0 or 1D all-zero + if ((shape.NumDimensions() == 0) && (zp_data[0] == 0)) { + b_quantization_is_asymmetric = false; + } else if (shape.NumDimensions() == 1) { + b_quantization_is_asymmetric = false; + for (size_t i = 0; i < zp_size; ++i) { + if (zp_data[i] != 0) { + b_quantization_is_asymmetric = true; + break; + } + } + } else { + // Unsupported higher-rank zp tensor + b_quantization_is_asymmetric = true; + } + } + + // MlasDynamicQgemm requires scale data to be available at packing stage + const Tensor* b_scale_tensor = nullptr; + const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); + + can_use_dynamic_quant_mlas_ = (!b_quantization_is_asymmetric && b_scale_available); + + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + b_shape_ = tensor.Shape(); + // TO DO: handle b_shape_.NumDimensions() > 2 and all dimension values but the last two being 1. + if (!(b_shape_.NumDimensions() == 2 || (b_shape_.NumDimensions() == 3 && b_shape_[0] == 1))) { + can_use_dynamic_quant_mlas_ = false; + } + + // Can we use the mlas dynamic Q gemm interface supported with float output ? + if (!can_use_dynamic_quant_mlas_) { + // default to piece wise mlas interface with separate int matmul, quantize and float conversion + return MatMulIntegerToFloatBase::PrePack(tensor, input_idx, alloc, is_packed, prepacked_weights); + } + is_packed = false; + + // Default to all zeros for bias + const Tensor* bias_tensor{nullptr}; + const OrtValue* bias; + if (Info().TryGetConstantInput(IN_BIAS, &bias)) { + bias_tensor = &bias->Get(); + dynamic_quant_mlas_bias_data_was_packed_ = true; + } + size_t K = static_cast(b_shape_[0]); + size_t N = static_cast(b_shape_[1]); + + const auto* b_data = static_cast(tensor.DataRaw()); + + std::optional b_trans_buffer; + if (IsBTransposed()) { + std::swap(K, N); + b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K); + } + + const size_t packed_b_size = MlasDynamicQgemmPackBSize(N, K); + if (packed_b_size == 0) { + return Status::OK(); + } + + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we do not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset(packed_b_.get(), 0, packed_b_size); + + const auto scales = static_cast(b_scale_tensor->Shape().Size()) == N ? std::vector(&b_scale_tensor->Data()[0], + &b_scale_tensor->Data()[N]) + : + // Broadcast matrix scale to all channels + std::vector(N, b_scale_tensor->Data()[0]); + + const auto biases = bias_tensor != nullptr ? std::vector(&bias_tensor->Data()[0], + &bias_tensor->Data()[N]) + : + // Broadcast zero to all channels - no bias data is available + std::vector(N, 0.f); + + MlasDynamicQgemmPackB(N, K, reinterpret_cast(b_data), scales.data(), biases.data(), + packed_b_.get()); + + bool share_prepacked_weights = (prepacked_weights != nullptr); + if (share_prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size); + } + + is_packed = true; + } + return Status::OK(); + } +#endif + enum InputTensors : int { IN_A = 0, IN_B = 1, @@ -171,6 +288,12 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { protected: int GetBIdx() const override { return IN_B; } + + private: + bool can_use_dynamic_quant_mlas_{false}; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + bool dynamic_quant_mlas_bias_data_was_packed_{false}; +#endif }; class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase { @@ -199,44 +322,104 @@ class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase { }; Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { - const Tensor* a = ctx->Input(IN_A); - const Tensor* b = packed_b_ ? nullptr : ctx->Input(IN_B); - - const Tensor* b_scale_tensor = ctx->Input(IN_B_SCALE); - const Tensor* b_zp_tensor = ctx->Input(IN_B_ZERO_POINT); - - // calculate quantization parameter of a - const float* a_data = a->Data(); - int64_t num_of_elements = a->Shape().Size(); - - float a_scale; - uint8_t a_zero_point; - GetQuantizationParameter(a_data, num_of_elements, a_scale, a_zero_point, ctx->GetOperatorThreadPool()); - - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); - uint8_t* a_data_quant = static_cast(allocator->Alloc(SafeInt(num_of_elements) * sizeof(uint8_t))); - BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(std::move(allocator))); - - ParQuantizeLinearStd(a_data, a_data_quant, narrow(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool()); - - bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_); - ORT_RETURN_IF_ERROR(ComputeCommon( - ctx, - a_data_quant, - a->Shape(), - a_scale, - a_zero_point, - false /*a_is_signed*/, - b, - is_b_scale_supported ? b_scale_tensor : nullptr, - b_zp_tensor, - ctx->Input(IN_BIAS))); - - if (!is_b_scale_supported) { - ScaleOutput(*b_scale_tensor, *ctx->Output(0)); + // Can this operation be offloaded to a MLAS specific dynamic quantization matmul ? + if (!can_use_dynamic_quant_mlas_) { + const Tensor* a = ctx->Input(IN_A); + const Tensor* b = packed_b_ ? nullptr : ctx->Input(IN_B); + + const Tensor* b_scale_tensor = ctx->Input(IN_B_SCALE); + const Tensor* b_zp_constant_tensor = ctx->Input(IN_B_ZERO_POINT); + + // calculate quantization parameter of a + const float* a_data = a->Data(); + int64_t num_of_elements = a->Shape().Size(); + + float a_scale; + uint8_t a_zero_point; + GetQuantizationParameter(a_data, num_of_elements, a_scale, a_zero_point, ctx->GetOperatorThreadPool()); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + uint8_t* a_data_quant = static_cast(allocator->Alloc(SafeInt(num_of_elements) * sizeof(uint8_t))); + BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(std::move(allocator))); + + ParQuantizeLinearStd(a_data, a_data_quant, narrow(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool()); + + bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_); + const bool is_a_signed = false; + ORT_RETURN_IF_ERROR(ComputeCommon( + ctx, + a_data_quant, + a->Shape(), + a_scale, + a_zero_point, + is_a_signed, + b, + is_b_scale_supported ? b_scale_tensor : nullptr, + b_zp_constant_tensor, + ctx->Input(IN_BIAS))); + + if (!is_b_scale_supported) { + ScaleOutput(*b_scale_tensor, *ctx->Output(0)); + } } + // Guard against KleidiAI functions being called in non kleidi builds + // TODO: migrate to a suitable override function call for kleidi dynamic qgemm function calls +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + else { + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(ctx->Input(IN_A)->Shape(), + b_shape_, // ctx->Input(IN_B)->Shape(), this is not available now constant data is + // deleted during session init post prepacking + nullptr, + nullptr)); + + Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto a_data = static_cast(ctx->Input(IN_A)->DataRaw()); + auto* y_data = y->MutableData(); + + // batch gemm + MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS gemm_shape; + gemm_shape.M = static_cast(helper.M()); + gemm_shape.N = static_cast(helper.N()); + gemm_shape.K = static_cast(helper.K()); + + const size_t num_gemms = helper.OutputOffsets().size(); + std::vector gemm_data_vec(num_gemms); + + for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { + auto& params = gemm_data_vec[gemm_idx]; + params.A = reinterpret_cast(a_data + helper.LeftOffsets()[gemm_idx]); + params.lda = gemm_shape.K; + params.PackedB = packed_b_.get(); + params.C = y_data + helper.OutputOffsets()[gemm_idx]; + params.ldc = gemm_shape.N; + } + MlasDynamicQGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool()); + // This evaluates to true if bias data was not provided as constant data for prepacking stage + if (!dynamic_quant_mlas_bias_data_was_packed_) { + if (ctx->Input(IN_BIAS) != nullptr) { + const auto biases = std::vector(&ctx->Input(IN_BIAS)->Data()[0], + &ctx->Input(IN_BIAS)->Data()[gemm_shape.N]); + + // deferred adding of bias + for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { + float* MxN = y_data + helper.OutputOffsets()[gemm_idx]; + for (auto l = gemm_shape.M; l > 0; --l) { + MlasEltwiseAdd(MxN, biases.data(), MxN, gemm_shape.N); + MxN += gemm_shape.N; + } + } + } + } + } +#endif return Status::OK(); } @@ -275,7 +458,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const { a_zero_point = *(static_cast(a_zero_point_tensor->DataRaw())); } - const Tensor* b_zp_tensor = ctx->Input(IN_B_ZERO_POINT); + const Tensor* b_zp_constant_tensor = ctx->Input(IN_B_ZERO_POINT); ORT_RETURN_IF_ERROR(ComputeCommon( ctx, static_cast(a->DataRaw()), @@ -285,7 +468,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const { a->IsDataType(), b, is_b_scale_supported ? b_scale_tensor : nullptr, - b_zp_tensor, + b_zp_constant_tensor, ctx->Input(IN_BIAS))); if (!is_a_scale_scalar) { diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index c4667d53c0674..dccfdbda8971b 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -190,6 +190,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -342,6 +343,7 @@ void CPUIDInfo::ArmAppleInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); // Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect // to encounter on Apple platforms. diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 9c67ebbffa260..84571fa12e6ea 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -40,6 +40,7 @@ class CPUIDInfo { bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + bool HasArm_SME() const { return has_arm_sme_; } uint32_t GetCurrentCoreIdx() const; @@ -127,6 +128,7 @@ class CPUIDInfo { bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; + bool has_arm_sme_{false}; std::string vendor_; uint32_t vendor_id_; diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 4d85c35461825..22bddf58997bc 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -631,6 +631,52 @@ MlasGemm( { MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); } +/** + * @brief Parameters that define the shape of a dynamically quantized GEMM operation. + * + * The structure holds the dimensions of the matrices involved in the GEMM + * computation: + * C = A * B + */ +struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS { + size_t M = 0; /**< Row size of matrix A */ + size_t N = 0; /**< Column size of matrix B */ + size_t K = 0; /**< Column size of matrix A and Row size of matrix B */ +}; +/** + * @brief Parameters that define the data buffers and layout for a dynamic quant GEMM. + * + * This structure provides the memory pointers and strides for matrices + * involved in a dynamically quantized GEMM operation, along with the packed B format. + */ +struct MLAS_GEMM_DYN_QUANT_DATA_PARAMS { + const float* A = nullptr; /**< Pointer to input matrix A in FP32 format**/ + size_t lda = 0; /**< Number of elements between adjecent rows in A*/ + const void* PackedB = 0; /**< Points to packed weight matrix B */ + float *C = nullptr; /**< Points to output Matric C */ + size_t ldc = 0; /**< Number of elements between adjecent rows in Matrix C*/ + void* Workspace = nullptr; /**< Workspace buffer for LHS Packing Allocation */ + size_t WorkspaceSize = 0; /**< Workspace buffer size */ +}; + +void +MLASCALL +MlasDynamicQGemmBatch ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +); + +inline void +MlasDynamicQGemm ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) { + MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool); +} + // // Symmetric QGEMM has limited buffer overrun. @@ -685,6 +731,8 @@ MlasSymmQgemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ); @@ -692,6 +740,7 @@ MlasGemmPackBSize( void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -750,6 +799,26 @@ MlasSymmQgemmPackB( void* PackedB ); + +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +); + +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +); + + // // Convolution routines. // @@ -2024,3 +2093,14 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); + +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +/** + * @brief Function to override the packing mechanism decision if kleidi ai is included + * @param enable enable kleidiai packing (allow or disallow depending on true/false) + * @return +*/ +void +MLASCALL +MlasGemmBatchPackUseKleidi(bool enable); +#endif diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index ec79641559c6b..bc1221475fd90 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -861,6 +861,12 @@ Return Value: --*/ { + // Override + if(GetMlasPlatform().MlasConvOverride != nullptr && + GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){ + return; + } + const size_t FilterCount = Parameters->FilterCount; const size_t OutputSize = Parameters->OutputSize; const size_t K = Parameters->K; @@ -1094,6 +1100,13 @@ Return Value: --*/ { + // Override + if (GetMlasPlatform().MlasConvPrepareOverride != nullptr && + GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels, + InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount, + Activation, WorkingBufferSize, Beta, ThreadPool)){ + return; + } // // Save the convolution parameters. // @@ -1299,4 +1312,4 @@ Return Value: } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp new file mode 100644 index 0000000000000..9eaf4902f536a --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -0,0 +1,720 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include "mlasi_kleidiai.h" +#include +#include + +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" + +// Right-hand-side (weights) cache key +struct RhsCacheKey { + size_t co, ci, kh, kw, dilationh, dilationw; + size_t weights_hash; + + bool operator==(const RhsCacheKey& other) const { + return co == other.co && ci == other.ci && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + weights_hash == other.weights_hash; + } +}; + + +// Left-hand-side (input indirection) cache key +struct LhsCacheKey { + size_t ci, ih, iw; + size_t padding, sh, sw; + size_t kh, kw; + size_t dilationh, dilationw; + size_t data_hash; + + bool operator==(const LhsCacheKey& other) const { + return ci == other.ci && ih == other.ih && iw == other.iw && + padding == other.padding && sh == other.sh && sw == other.sw && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + data_hash == other.data_hash; + } +}; + +// Derived from 2^32 * (sqrt(5) - 1) / 2 ≈ 0.6180339887 (reciprocal of the golden ratio) +// Based on Knuth's multiplicative hashing method +constexpr size_t HASH_GOLDEN_RATIO_CONST = 0x9e3779b9; + +size_t HashWeights(const float* data, size_t count = 16) { + size_t h = 0; + for (size_t i = 0; i < count; ++i) { + h ^= std::hash()(data[i]) + HASH_GOLDEN_RATIO_CONST + (h << 6) + (h >> 2); + } + return h; +} + +namespace std { + // Specialize hash type for cache keys and do it within namespace std. + // Doing this allows standard containers like std::unordered_map to find + // the appropriate hash function via template specialization, as ADL + // (argument-dependent lookup) does not apply to std::hash. + template<> + struct hash { + size_t operator()(const RhsCacheKey& k) const { + return k.weights_hash ^ + (std::hash()(k.co) << 1) ^ + (std::hash()(k.ci) << 2) ^ + (std::hash()(k.kh) << 3) ^ + (std::hash()(k.kw) << 4) ^ + (std::hash()(k.dilationh) << 5) ^ + (std::hash()(k.dilationw) << 6); + } + }; + + template<> + struct hash { + size_t operator()(const LhsCacheKey& k) const { + return k.data_hash ^ + (std::hash()(k.ci) << 1) ^ + (std::hash()(k.ih) << 2) ^ + (std::hash()(k.iw) << 3) ^ + (std::hash()(k.padding) << 4) ^ + (std::hash()(k.sh) << 5) ^ + (std::hash()(k.sw) << 6) ^ + (std::hash()(k.kh) << 7) ^ + (std::hash()(k.kw) << 8) ^ + (std::hash()(k.dilationh) << 9) ^ + (std::hash()(k.dilationw) << 10); + } + }; + +} + + +static constexpr size_t ComputeKernelSize(const size_t D, const size_t K) { + // D - dilation size + // K - kernel size + + // D*S scale 1D kernel dimension by dilation factor + // (D-1) remove affect of dilation scaling at kernel end + return (D*K) - (D - 1); +} + +static constexpr size_t ComputeConvOutSize(const size_t L, const size_t K, const size_t P, const size_t S) { + + //With start + end padding + + //L - input size + //K - kernel size + //P - Padding size + //S - stride size + + //Does the convolution compute one value or less ? + if ( S > 0 && (L + 2*P) >= K) { + // L-(K-1) standard convolution output size is L-(K-1) for a step size of 1 with no padding + // (2*P) 1D start and end padding + // (L+2*P)-(K-1) the 1D length of convolution result for a kernel step size of 1 + // /S apply the kernel step + return (((L - K) + (2 * P)) / S) + 1; + } + return 0; +} + +static size_t ComputeMlasWorkingBufferSize(const size_t co, + const size_t ih, const size_t iw, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const size_t sh, const size_t sw, + const size_t padding) { + // dimensions of dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + return m * co; +} + +static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) { + + //functional checks - logically can the conv be performed + if ((Parameters->Dimensions != 2) || + (Parameters->BatchCount != 1) || + (Parameters->Beta != 0.f) || + (Parameters->Padding[0] != Parameters->Padding[1]) || + (Parameters->Padding[0] != Parameters->Padding[2]) || + (Parameters->Padding[0] != Parameters->Padding[3]) || + (ComputeConvOutSize(Parameters->InputShape[0], + ComputeKernelSize(Parameters->DilationShape[0],Parameters->KernelShape[0]), + Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], + ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]), + Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) { + return false; + } + + //optimization checks - is the implementation optimal for the conv request + + const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + auto M = ComputeConvOutSize(Parameters->InputShape[0], ComputeKernelSize(Parameters->DilationShape[0], + Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1], + Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]); + auto N = Parameters->FilterCount; + auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1]; + + //Can use these variables to add other conditions as required + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(m_step); + MLAS_UNREFERENCED_PARAMETER(n_step); + + if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) { + return false; + } + return true; +} + +//General purpose axis swapping +static auto Transpose4D(std::array shape_in, + const float* in, + std::array permute) { + + std::array shape_out{shape_in[permute[0]], + shape_in[permute[1]], + shape_in[permute[2]], + shape_in[permute[3]]}; + + assert((shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]) == + (shape_out[0] * shape_out[1] * shape_out[2] * shape_out[3])); + assert(permute[0] < 4 && permute[1] < 4 && permute[2] < 4 && permute[3] < 4); + + const size_t get_stride[] {shape_in[1] * shape_in[2] * shape_in[3], shape_in[2] * shape_in[3], shape_in[3]}; + auto get = [get_stride,in](const std::array& el) { + return in[el[0] * get_stride[0] + + el[1] * get_stride[1] + + el[2] * get_stride[2] + + el[3]]; + }; + + auto out_ = std::make_unique(shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]); + auto out = out_.get(); + + const size_t set_stride[]{shape_out[1] * shape_out[2] * shape_out[3], shape_out[2] * shape_out[3], shape_out[3]}; + auto set = [set_stride,out](const std::array& el, float v) { + out[el[0] * set_stride[0] + + el[1] * set_stride[1] + + el[2] * set_stride[2] + + el[3]] = v; + }; + + std::array shape; + for (shape[0] = 0; shape[0] < shape_in[0]; ++shape[0]) { + for (shape[1] = 0; shape[1] < shape_in[1]; ++shape[1]) { + for (shape[2] = 0; shape[2] < shape_in[2]; ++shape[2]) { + for (shape[3] = 0; shape[3] < shape_in[3]; ++shape[3]) { + set({shape[permute[0]], shape[permute[1]], shape[permute[2]], shape[permute[3]]}, get(shape)); + } + } + } + } + + return out_; +} + +//nchw to nhwc specific axis swapping +static std::unique_ptr NChwToNhwc(const size_t n, + const size_t c, + const size_t h, + const size_t w, + const float* RESTRICT in, + const size_t dilationh=1, + const size_t dilationw=1, + const bool zero_fill=false, + MLAS_THREADPOOL* ThreadPool=nullptr) { + + const auto d_h = ComputeKernelSize(dilationh, h); + const auto d_w = ComputeKernelSize(dilationw, w); + + auto t = std::make_unique(n*d_h*d_w*c); + if (zero_fill) { + std::fill(&t.get()[0], &t.get()[n*d_h*d_w*c], 0.f); + } + + if (dilationh > 1 || dilationw > 1 || n > 1) { + const size_t get_strides[] {c*h*w,h*w,w}; + auto get = [get_strides,in](const std::array& el) { + return in[el[0]*get_strides[0] + + el[1]*get_strides[1] + + el[2]*get_strides[2] + + el[3]]; + }; + + const size_t set_strides[] {d_h*d_w*c,dilationh*d_w*c,dilationw*c}; + auto set = [set_strides](const std::array& el, float v, float* out) { + out[el[0]*set_strides[0] + + el[1]*set_strides[1] + + el[2]*set_strides[2] + + el[3]] = v; + }; + + MLAS_UNREFERENCED_PARAMETER(set); + MLAS_UNREFERENCED_PARAMETER(get); + + auto out0 = t.get(); + for (size_t s0 = n; s0 > 0; --s0) { + auto out1 = out0; + for (size_t s1 = c; s1 > 0; --s1) { + auto out2 = out1; + for (size_t s2 = h; s2 > 0; --s2) { + float* RESTRICT out3 = out2; + size_t s3 = w; + for (; s3 > 4; s3 -= 4) { + auto vf32 = MlasLoadFloat32x4(in); + in += 4; + MlasStoreLaneFloat32x4<0>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<1>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<2>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<3>(out3, vf32); + out3 += set_strides[2]; + } + for (; s3 > 0; --s3) { + //set({s0,s2,s3,s1}, get({s0,s1,s2,s3}),t.get()); + *out3 = *in++; + out3 += set_strides[2]; + } + out2 += set_strides[1]; + } + out1++; + } + out0 += set_strides[0]; + } + } else { + MlasTranspose(in, t.get(), c, d_h*d_w, ThreadPool); + } + + return t; +} + +static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci, const size_t m, const size_t kh, + const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data, + const float* in_data, + const float* pad_ptr) { + + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = MlasDivRoundup(m, m_step); + auto MaxTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), RequiredTiles); + m_step *= MlasDivRoundup(RequiredTiles, MaxTiles); + RequiredTiles = MlasDivRoundup(m, m_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(RequiredTiles), [&](ptrdiff_t tid) { + + auto m_idx = static_cast(tid) * m_step; + auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci); + + kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + m < (m_idx + m_step) ? m - m_idx : m_step, kh * kw, ci, + lhs_ptrs + m_idx * kh * kw, + reinterpret_cast(in_data), + reinterpret_cast(pad_ptr), + lhs_data + offset + ); + }); +} + +static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const size_t ci, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const float* weights, const float* bias, + MLAS_THREADPOOL* ThreadPool) +{ + //cache of prepacked kai rhs weights and biases + static std::unordered_map> rhs_cache; + + RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) }; + + auto found = rhs_cache.find(key); + if (found != rhs_cache.end()) { + return found->second; + } else { + // prepare mlas filter weights for kai rhs packing + // dilated nhwc format + auto nhwc = NChwToNhwc(co, ci, kh, kw, weights, dilationh, dilationw, true, ThreadPool); + + + //dilation, axis swap (n x k -> k x n) where n == co, k == d_kh x d_kw x ci + const auto d_kh = ComputeKernelSize(dilationh,kh); + const auto d_kw = ComputeKernelSize(dilationw,kw); + + //t_weights[d_kh][d_kw][ci][co] = nhwc[co][d_kh][d_kw][ci] + auto t_weights = Transpose4D({co,d_kh,d_kw,ci},&nhwc[0],{1,2,3,0}); + + const auto packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(co,d_kh*d_kw,ci); + auto packed = std::shared_ptr(new std::byte[packed_size], std::default_delete()); + + rhs_cache[key] = packed; + + std::vector bias_copy; + if (bias) { + bias_copy.assign(bias, bias + co); + } else { + bias_copy.resize(co, 0.0f); + } + + kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get() + ); + + return packed; + } +} + +static std::shared_ptr LhsPtrFill(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, size_t sh, size_t sw, + const size_t padding, + const float* pad_ptr) { + size_t check_filled{0}; + + const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw); + + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto lhs_ptrs_k = kh * kw; + const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step); + auto lhs_ptrs = std::shared_ptr(new const void*[lhs_ptrs_k * lhs_ptrs_m], + std::default_delete()); + + + auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1); + auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1); + + auto ptrs_offset = [lhs_ptrs_m,lhs_ptrs_k, m_step](size_t k, size_t m) { + //(m/m_step,transpose(m_step,k) + auto offset {((m/m_step) * lhs_ptrs_k * m_step) + (k*m_step) + (m%m_step)}; + assert(offset < (lhs_ptrs_k * lhs_ptrs_m)); + + MLAS_UNREFERENCED_PARAMETER(lhs_ptrs_m); + + return offset; + }; + + auto pixel_offset = [ih, iw, ci, pad_ptr, padding](size_t h, size_t w) { + if (h < padding) { + return reinterpret_cast(&pad_ptr[0]); + } + h -= padding; + + if (w < padding) { + return reinterpret_cast(&pad_ptr[0]); + } + w -= padding; + + if ((h >= ih) || (w >= iw)) { + return reinterpret_cast(&pad_ptr[0]); + } + + auto offset{h * iw * ci + w * ci}; + assert(offset < (ih*iw*ci)); + return offset*sizeof(float); + }; + + size_t m_{0}; + auto lhs_ptrs_ = lhs_ptrs.get(); + for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) { + for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) { + size_t k_{0}; + for (size_t kh_ = 0; kh_ < kh; ++kh_) { + for (size_t kw_ = 0; kw_ < kw; ++kw_) { + lhs_ptrs_[ptrs_offset(k_, m_)] = reinterpret_cast(pixel_offset(ih_+kh_, iw_+kw_)); + k_++; check_filled++; + } + } + } + } + + assert(check_filled == (lhs_ptrs_k * m)); + MLAS_UNREFERENCED_PARAMETER(check_filled); + + return lhs_ptrs; +} + +static std::unique_ptr LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, const size_t sh, + const size_t sw, const size_t padding, const float* in, + MLAS_THREADPOOL* ThreadPool) +{ + size_t padsize = 256; + if(ci > padsize) + { + // figure out how many blocks needed to correctly fill padding + padsize = ((ci + padsize - 1) / padsize) * padsize; + } + static std::vectorpad_ptr(padsize, 0.f); + + LhsCacheKey key = { + ci, ih, iw, + padding, sh, sw, + kh, kw, + 1, 1, + HashWeights(in) + }; + + //create lhs in format required for imatmul + const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw); + + const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci); + auto lhs = std::make_unique(lhs_size); + + auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); + + //cache of computed lhs ptr offsets + static std::unordered_map> lhs_ptrs_cache; + + std::shared_ptr lhs_ptrs; + if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { + lhs_ptrs = found->second; + } else { + lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]); + lhs_ptrs_cache[key] = lhs_ptrs; + } + + MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], &nhwc[0], &pad_ptr[0]); + + return lhs; +} + +static void ConvolveSme(const size_t co, //channels out + const size_t ci, //channels in + const size_t ih, //image height + const size_t iw, //image width + const size_t kh, //kernel height + const size_t kw, //kernel width + const size_t sh, //kernel stride height + const size_t sw, //kernel stride width + const size_t dilationh, //kernel dilation stride + const size_t dilationw, //kernel dilation stride + const size_t padding, //padding size + const size_t groups, //number of filter groups + const float* weights, //kernel weights [co,ci,ih,iw] + const float* bias, //kernel biases + const float* in, //in image data + float* out, //out image data + float* tmp_mlas_aligned, //intermediate buffer if we need to perform a transpose + MLAS_THREADPOOL* ThreadPool) { + + //RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights + //compute corrected dimensions of dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + //run igemm based convolution + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + //tile iteration dimensions + std::array dim; + dim[0] = 1; // B + dim[1] = MlasDivRoundup(m, m_step); // M + dim[2] = MlasDivRoundup(co, n_step); // N + + //Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0]*dim[1]*dim[2]); + + //scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + //compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step); + + //update tile iterations + dim[1] = MlasDivRoundup(m, m_step); + dim[2] = MlasDivRoundup(co, n_step); + + for (size_t g = 0; g < groups; ++g) { + + auto result{out}; + //do we require a post matmul transpose ? + //output is m x n or image_data x co or hw x co + //MLAS require it as n x m (or co x hw), transpose required + if (co > 1) { + //intermediate buffer required, pre-transpose + //Note: because we are calling MlasTranspose() need to ensure we use a MLAS aligned buffer + result = tmp_mlas_aligned; + } + + auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool); + auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool); + + + MlasTrySimpleParallel(ThreadPool, + static_cast(dim[0]*dim[1]*dim[2]), + [&](ptrdiff_t tid) + { + //compute B,M,N index from iteration index + //ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step, + d_kh*d_kw,ci); + + auto BTile = reinterpret_cast( + reinterpret_cast(rhs.get()) + rhs_packed_offset + ); + + // Get lhs tile, A + const size_t lhs_packed_offset = + kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step, + d_kh*d_kw,ci); + + auto ATile = reinterpret_cast( + reinterpret_cast(lhs.get()) + lhs_packed_offset + ); + + auto TileSizeM = (MIdx + 1) * m_step > m ? (m - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > co ? (co - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = &reinterpret_cast(result)[ + MIdx * m_step * co * sizeof(float) + + NIdx * n_step * sizeof(float)]; + + kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + }); + + if (result == tmp_mlas_aligned) { + //Note: this could be absorbed into post conv activation + MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool); + } + + in += ci * ih * iw; + out += m * co; + weights += co * ci * kh * kw; + if(bias){ + bias += co; + } + } +} + +bool MLASCALL +ArmKleidiAI::MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool) +{ + //Check dimensions before accessing + if (Dimensions < 2) { + return false; + } + + Parameters->Activation = Activation; + Parameters->Dimensions = Dimensions; + Parameters->BatchCount = BatchCount; + Parameters->GroupCount = GroupCount; + Parameters->InputChannels = InputChannels; + Parameters->FilterCount = FilterCount; + Parameters->Beta = Beta; + + size_t InputSize = 1; + size_t OutputSize = 1; + size_t K = InputChannels; + + for (size_t dim = 0; dim < Dimensions; dim++) { + + Parameters->InputShape[dim] = size_t(InputShape[dim]); + Parameters->OutputShape[dim] = size_t(OutputShape[dim]); + Parameters->KernelShape[dim] = size_t(KernelShape[dim]); + Parameters->DilationShape[dim] = size_t(DilationShape[dim]); + Parameters->Padding[dim] = size_t(Padding[dim]); + Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]); + Parameters->StrideShape[dim] = size_t(StrideShape[dim]); + + InputSize *= Parameters->InputShape[dim]; + OutputSize *= Parameters->OutputShape[dim]; + K *= Parameters->KernelShape[dim]; + } + + Parameters->InputSize = InputSize; + Parameters->OutputSize = OutputSize; + Parameters->K = K; + + Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if(!CheckCapabilitiesSme(Parameters)){ + return false; + } + + //Allocate an aligned buffer for MlasTranspose() + *WorkingBufferSize = ComputeMlasWorkingBufferSize(Parameters->FilterCount, + Parameters->InputShape[0], Parameters->InputShape[1], + Parameters->KernelShape[0], Parameters->KernelShape[1], + Parameters->DilationShape[0], Parameters->DilationShape[1], + Parameters->StrideShape[0], Parameters->StrideShape[1], + Parameters->Padding[0]); + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasConv( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if(!CheckCapabilitiesSme(Parameters)){ + //Fallback to Default Mlas + return false; + }; + ConvolveSme(Parameters->FilterCount, Parameters->InputChannels, // channel out, in + Parameters->InputShape[0], Parameters->InputShape[1], // image dimensions + Parameters->KernelShape[0], Parameters->KernelShape[1], // kernel dimensions + Parameters->StrideShape[0], Parameters->StrideShape[1], // kernel stride dimensions + Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation + Parameters->Padding[0], // image padding + Parameters->GroupCount, // filter groups + Filter, Bias, Input, Output, WorkingBuffer, ThreadPool); + + MlasActivation(Parameters->Activation, Output, nullptr, Parameters->FilterCount, Parameters->OutputSize, + Parameters->OutputSize); + return true; +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h new file mode 100644 index 0000000000000..11fd78c261834 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -0,0 +1,114 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "mlasi.h" + +// Fix to ensure compatibility with MSVC build +#if defined(_MSC_VER) + #define RESTRICT __restrict +#else + #define RESTRICT __restrict__ +#endif +namespace ArmKleidiAI { +// +// Buffer packing routines. +// + +size_t +MLASCALL +MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +bool +MLASCALL +MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +); + +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +); + +//pack symmetric quantized B and dynamic quantized A +void +MLASCALL +MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool + ); + +bool +MLASCALL +MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool); + +bool +MLASCALL +MlasConv( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp new file mode 100644 index 0000000000000..fb38f2cef9bf6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -0,0 +1,116 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" + +#include "mlasi_kleidiai.h" + +//Matmul with float output of dynamic quantized A and symmetric quantized B. + +size_t +MLASCALL +ArmKleidiAI::MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) { + //Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use + auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + //regardless of kernel variant use neon packing variant + return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); +} + +void +MLASCALL +ArmKleidiAI::MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) { + // Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use + auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + // y - float output + // scale_factor_lhs - lhs scaling factor + // scale_factor_rhs - rhs scaling factor + // lhs_q - lhs quantized (asymmetric, so has zero point) + // rhs_q - rhs quantized (symmetric so no zero point) + // lhs_zp - lhs zero point + // y = (1/(scale_factor_lhs * scale_factor_rhs) * sum( (lhs_q + lhs_zp)*rhs_q )) + bias + + // rhs packing requires lhs_zp because it will perform lhs_zp*rhs_q during rhs packing + // because lhs quantization is hidden from us, by lhs quant packing, we don't have a value for lhs_zp it is + // lhs dynamic quantization + + kai_rhs_pack_qsi8cx_params params{ + 1, // lhs_zp - set to 1 so it becomes sum((lhs_q + 1)*rhs_q )), + // the actual lhs_zp is applied during the matmul + 1.f // it is not used + }; + + //regardless of kernel variant use neon packing variant + kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(1, N, K, nr, kr, sr, B, + // N bias values + Bias, + // N scale values + Scales, PackedB, 0, ¶ms); +} + +void +MLASCALL +ArmKleidiAI::MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { + for (auto b = BatchN; b > 0; --b,++DataParams) { + auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + + //TODO enable multi-threading for lhs packing and matmul + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + //Dynamic Quantize A - lhs + auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* lhs = nullptr; + std::unique_ptr fallback; + + if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { + lhs = static_cast(DataParams->Workspace); + } else { + fallback = std::make_unique(lhs_size); + lhs = fallback.get(); + } + + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, + Shape.K*sizeof(float), lhs); + + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( + Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, + DataParams->C, + Shape.N * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp new file mode 100644 index 0000000000000..caa445b71e2a5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -0,0 +1,348 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" +#include "mlasi_kleidiai.h" + +size_t +MLASCALL +ArmKleidiAI::MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) +/*++ + +Routine Description: + + This routine computes the length in bytes for the packed matrix B buffer. + +Arguments: + + TransA - Supplies the transpose operation on A matrix + + TransB - Supplies the transpose operation on B matrix + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + +Return Value: + + Returns the size in bytes for the packed matrix B buffer. + +--*/ +{ + if (TransA != CblasNoTrans || N == 0 || K == 0) { + return 0; + } + // + // Compute the number of bytes required to hold the packed buffer. + // + size_t bytes = 0; + + if (TransA == CblasNoTrans) { + switch (TransB) { + case CblasNoTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + case CblasTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + default: + return 0; + } + } else { + return 0; + } + + return bytes; +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB +) +/*++ + +Routine Description: + + This routine packs the contents of matrix B to the destination buffer. The + destination buffer should be sized based on MlasGemmPackBSize(). For best + performance, the destination buffer should be aligned to the value returned + from MlasGetPreferredBufferAlignment(). + +Arguments: + + TransA - Supplies the transpose operation for matrix A. + + TransB - Supplies the transpose operation for matrix B. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + + B - Supplies the address of matrix B. + + ldb - Supplies the first dimension of matrix B. + + PackedB - Supplies the address of packed matrix B. + +Return Value: + + None. + +--*/ +{ + if (N == 0 || K == 0) { + return false; + } + + if (TransA == CblasNoTrans) { + const size_t nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + // pass zeroed bias values + const std::vector bias(N); + + switch (TransB) { + case CblasNoTrans: + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr); + break; + case CblasTrans: + kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr); + break; + default: + return false; + } + return true; + } + else{ + return false; + } +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) +{ + if(TransA == CblasTrans) + { + return false; + } + if (TransA == CblasNoTrans && K == 0) { + if (Data->beta != 1.0f) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + Data->C[i * Data->ldc + j] *= Data->beta; + } + } + } + } + if (Data->beta == 0.0f){ + std::fill_n(Data->C, M * Data->ldc, 0.0f); + } + //Fallback in the case of unsupported cases + if (M == 0 || N == 0 || K == 0 || + TransA != CblasNoTrans || + (TransB != CblasNoTrans && !Data[0].BIsPacked)) + { + return false; + } + + if (TransA == CblasNoTrans) { + const size_t mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + if (M < m_step || N < n_step) { + if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){ + //Fallback to MLAS + return false; + } + } + + std::vector KaiPackedData; + KaiPackedData.resize(BatchSize); + + size_t LhsPackedStride = 0; + std::byte* LhsPackedData = nullptr; + + LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr); + auto LhsPacked = std::make_unique(LhsPackedStride * BatchSize); + LhsPackedData = LhsPacked.get(); + + std::unique_ptr RhsPacked{nullptr}; + + // It is assumed all B batches require packing or not + if (Data[0].BIsPacked) { + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + + kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + + KaiPackedData[batch_idx].A = reinterpret_cast(LhsPackedPtr); + KaiPackedData[batch_idx].B = Data[batch_idx].B; + }); + } else { + // Multithread pack lhs and rhs + size_t RhsPackedStride = 0; + std::byte* RhsPackedData = nullptr; + + RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K); + RhsPacked = std::make_unique(RhsPackedStride * BatchSize); + RhsPackedData = RhsPacked.get(); + + MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) { + // lhs odd, rhs even + if (batch_idx & 0x1) { + batch_idx >>= 1; + + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + + kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + + KaiPackedData[batch_idx].A = reinterpret_cast(LhsPackedPtr); + } else { + batch_idx >>= 1; + + std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]); + + ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr); + + KaiPackedData[batch_idx].B = reinterpret_cast(RhsPackedPtr); + } + }); + } + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(M, m_step); // M + dim[2] = MlasDivRoundup(N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(M, m_step); + dim[2] = MlasDivRoundup(N, n_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K); + + auto BTile = reinterpret_cast( + reinterpret_cast(KaiPackedData[BIdx].B) + rhs_packed_offset + ); + + // Get lhs tile, A + const size_t lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K); + + auto ATile = reinterpret_cast( + reinterpret_cast(KaiPackedData[BIdx].A) + lhs_packed_offset + ); + + auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = reinterpret_cast( + reinterpret_cast(Data[BIdx].C) + + MIdx * m_step * Data[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) + ); + // Allocate temporary buffer for raw A*B result + std::vector OutputTile(TileSizeM * TileSizeN, 0.0f); + float* temp_tile = OutputTile.data(); + + + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + TileSizeM, + TileSizeN, + K, + ATile, BTile, temp_tile, + TileSizeN * sizeof(float), sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + + // Final output tile pointer + float* dst_tile = reinterpret_cast(CTile); + + // quick copy of data in cases where we are not scaling or accumulating anything + // with bounds checking on tile sizing to ensure the data fits in the memory block + bool can_memcpy = ( + Data[BIdx].alpha == 1.0f && + Data[BIdx].beta == 0.0f && + Data[BIdx].ldc == TileSizeN && + MIdx * m_step + TileSizeM <= M && + NIdx * n_step + TileSizeN <= N && + TileSizeM != 0 && + TileSizeN != 0); + + if (can_memcpy) { + std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); + }else { + // apply alpha scaling and beta to output files + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t idx = i * TileSizeN + j; + const size_t dst_idx = i * Data[BIdx].ldc + j; + + float ab = temp_tile[idx]; + float c_orig = dst_tile[dst_idx]; + + dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig; + } + } + } + }); + } + return true; +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0879d1b0ba510..a099bcf8438fe 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -35,7 +35,7 @@ Module Name: #endif #endif // MLAS_NO_EXCEPTION -#include "mlas.h" +#include "core/mlas/inc/mlas.h" #if defined(_WIN32) #ifndef WIN32_LEAN_AND_MEAN @@ -118,9 +118,8 @@ Module Name: #ifdef MLAS_NO_EXCEPTION -MLAS_FORCEINLINE -void -MlasPrintFinalMessage(const std::string& msg) +MLAS_FORCEINLINE void + MlasPrintFinalMessage(const std::string& msg) { #if defined(__ANDROID__) __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str()); @@ -134,6 +133,7 @@ MlasPrintFinalMessage(const std::string& msg) #endif } + #define MLAS_THROW_EX(ex, what) \ do { \ std::string msg = #ex; \ @@ -781,6 +781,119 @@ struct MLAS_QUANT_KERNEL size_t KernelSize ); }; +typedef +void +(MLASCALL MLAS_CONV_FLOAT_FN)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_FLOAT_OVERRIDE)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +// TODO: Investigate if overridden typedefs can be removed +typedef +void +(MLASCALL MLAS_CONV_PREPARE_FLOAT_FN)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_PREPARE_FLOAT_OVERRIDE)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); + +typedef void (MLASCALL MLAS_GEMM_BATCH)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); + +typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); extern "C" { @@ -1184,6 +1297,12 @@ struct MLAS_PLATFORM { // TODO: move to cpuinfo bool Avx2Supported_ = false; bool Avx512Supported_ = false; + // Mlas overrides initialisation + MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; + MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; + MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* MlasGemmPackBOverride = nullptr; + MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; + MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45bba5363d4f2..3256dadb856d3 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -17,6 +17,10 @@ Module Name: #include "mlasi.h" +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif + #include #include @@ -579,6 +583,15 @@ Return Value: } this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions); +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; + this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; + this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; + this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; + this->MlasConvOverride = ArmKleidiAI::MlasConv; + } +#endif #if defined(__linux__) // diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index f5b33d2a9ad34..4e9a0e27099dc 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -14,10 +14,16 @@ Module Name: operation (QGEMM). --*/ - -#include "mlasi.h" +#include +#include "core/mlas/lib/mlasi.h" #include "qgemm.h" +// TODO: When overrides are implemented, remove this +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif + + // // Define the parameters to execute segments of a QGEMM operation on worker // threads. @@ -195,6 +201,26 @@ MlasGemmBatch( }); } +void +MLASCALL +MlasDynamicQGemmBatch ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback and putting in guards + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(Shape); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} int32_t MlasSymmQgemmGetKernelOutputCnt() @@ -293,10 +319,35 @@ MlasSymmQgemmBatch( }); } + + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) +{ + size_t bytes = 0; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback available + //TODO: Insert Override + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + return bytes; +} + + size_t MLASCALL MlasGemmPackBSize( @@ -354,10 +405,38 @@ Return Value: const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + //If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for + //this use case. Concat both packed representations for later decision. + return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K); +} - return AlignedBytesRequired; +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) +{ +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(Scales); + MLAS_UNREFERENCED_PARAMETER(Bias); + MLAS_UNREFERENCED_PARAMETER(PackedB); } + void MLASCALL MlasGemmPackB( @@ -400,7 +479,6 @@ Return Value: // // Retrieve the packing parameters. // - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); size_t PackedK = GemmQuantDispatch->PackedK; @@ -515,7 +593,6 @@ MlasSymmQgemmPackBSize( #pragma warning(pop) #endif - void MLASCALL MlasSymmQgemmPackB( diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 616622a8c1f53..65c1ccbadad38 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -1572,7 +1572,13 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - + // Override + if(GetMlasPlatform().MlasGemmBatchOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchSize, ThreadPool)){ + return; + } // // Compute the number of target threads given the complexity of the SGEMM // operation. Small requests should run using the single threaded path. @@ -1637,6 +1643,8 @@ MlasGemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ) @@ -1661,6 +1669,22 @@ Return Value: // // Compute the number of bytes required to hold the packed buffer. // + // KleidiAI or other override + #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBSizeOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans) { + size_t bytes_required; + //TODO pass status by reference to indicate success/fail + bytes_required = GetMlasPlatform().MlasGemmPackBSizeOverride(TransA, TransB, N, K); + if (bytes_required != 0){// If ArmKleidiAI::MlasGemmPackBSize ran to completion + return bytes_required; + } + } + #endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); @@ -1676,6 +1700,7 @@ Return Value: void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -1712,6 +1737,17 @@ Return Value: --*/ { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index b619efb2f751e..7abd375cda896 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -170,13 +170,18 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g // Find bias node Node* p_add_node = nullptr; + int idx = 0; if (optimizer_utils::CheckOutputEdges(graph, mul_node, 1)) { const Node* tmp_add_node = graph_utils::FirstChildByType(mul_node, "Add"); if (nullptr != tmp_add_node) { - const NodeArg& tmp_add_node_B = *(tmp_add_node->InputDefs()[1]); - if (graph_utils::IsConstantInitializer(graph, tmp_add_node_B.Name(), true) && - CheckBiasShape(tmp_add_node_B.Shape())) { - p_add_node = graph.GetNode(tmp_add_node->Index()); + // check both "inputs" to find bias, caters for edge case where bias index in InputDefs is not what is expected + for (idx = 0; idx < 2; ++idx) { + const NodeArg& candidate = *(tmp_add_node->InputDefs()[idx]); + if (graph_utils::IsConstantInitializer(graph, candidate.Name(), true) && + CheckBiasShape(candidate.Shape())) { + p_add_node = graph.GetNode(tmp_add_node->Index()); + break; + } } } } @@ -203,7 +208,7 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g } if (p_add_node != nullptr) { - input_defs.push_back(p_add_node->MutableInputDefs()[1]); + input_defs.push_back(p_add_node->MutableInputDefs()[idx]); } std::string op_type = "MatMulIntegerToFloat"; diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 65b169355c793..181d0c5e98dd1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -102,6 +102,7 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( bool GemmPackBFp32(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, @@ -116,7 +117,7 @@ bool GemmPackBFp32(AllocatorPtr& alloc, const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); - packed_b_size = MlasGemmPackBSize(N, K); + packed_b_size = MlasGemmPackBSize(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, N, K); if (packed_b_size == 0) { return false; } @@ -129,7 +130,8 @@ bool GemmPackBFp32(AllocatorPtr& alloc, // if and when we try to cache this pre-packed buffer for sharing between sessions. memset(packed_b_data, 0, packed_b_size); - MlasGemmPackB(trans_b ? CblasTrans : CblasNoTrans, + MlasGemmPackB(trans_a ? CblasTrans : CblasNoTrans, + trans_b ? CblasTrans : CblasNoTrans, N, K, tensor_b.Data(), @@ -274,7 +276,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_B_ != CblasNoTrans, packed_b_, packed_b_size, b_shape_); + is_packed = GemmPackBFp32(alloc, tensor, trans_A_ != CblasNoTrans, trans_B_ != CblasNoTrans, packed_b_, packed_b_size, b_shape_); bool share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); diff --git a/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h b/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h index 599847e61a54f..0189edb23dddb 100644 --- a/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h +++ b/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h @@ -9,9 +9,9 @@ namespace onnxruntime { bool GemmPackBFp32(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, TensorShape& b_shape); - }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de908..530218db31e3d 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -195,7 +195,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc } else #endif { - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + is_packed = GemmPackBFp32(alloc, tensor, trans_a_attr_, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); } bool share_prepacked_weights = (prepacked_weights != nullptr); diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index c0171f7728ea8..d781de2eb5541 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -194,7 +194,7 @@ bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& allo const size_t N = static_cast(shape[1]); const size_t K = static_cast(shape[2]); - const size_t packed_weights_size = MlasGemmPackBSize(N, K); + const size_t packed_weights_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, N, K); if (packed_weights_size == 0) { return false; } @@ -215,7 +215,7 @@ bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& allo const size_t N_x_K = N * K; const auto* weights_data = weights.Data(); for (int64_t dir = 0; dir < num_directions; ++dir) { - MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasTrans, N, K, weights_data, K, packed_weights_data); weights_data += N_x_K; packed_weights_data += packed_weights_size; } @@ -244,12 +244,12 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& const auto hidden_size_x_2 = N - hidden_size_; // We are making two packed buffers, one for ZR weights and another for H weights. - const size_t ZR_packed_size = MlasGemmPackBSize(narrow(hidden_size_x_2), narrow(K)); + const size_t ZR_packed_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K)); if (ZR_packed_size == 0) { return false; } - const size_t H_packed_size = MlasGemmPackBSize(narrow(hidden_size_), narrow(K)); + const size_t H_packed_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K)); if (H_packed_size == 0) { return false; } @@ -275,18 +275,18 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& const auto hidden_2_step = hidden_size_x_2 * K; const auto hidden_1_step = hidden_size_ * K; // square const auto* weights_data = weights.Data(); - MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); weights_data += hidden_2_step; - MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); if (num_directions == 2) { weights_data += hidden_1_step; buffer_ZR = static_cast(buffer_ZR) + ZR_packed_size; - MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); weights_data += hidden_2_step; buffer_H = static_cast(buffer_H) + H_packed_size; - MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); } return true; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index e95ad707cf2b0..b38e271fdbe4a 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -196,7 +196,7 @@ Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packe return Status::OK(); } - const size_t packed_weights_size = MlasGemmPackBSize(N, K); + const size_t packed_weights_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, N, K); if (packed_weights_size == 0) { return Status::OK(); } @@ -217,7 +217,7 @@ Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packe const auto* weights_data = weights.Data(); for (int i = 0; i < num_directions_; i++) { - MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasTrans, N, K, weights_data, K, packed_weights_data); packed_weights_data = static_cast(packed_weights_data) + packed_weights_size; weights_data += N * K; } diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index a94d33cd77f63..422fc6fbcadbf 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -30,9 +30,12 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); if (pack_b) { - size_t pack_b_size = MlasGemmPackBSize(N, K); + CBLAS_TRANSPOSE transB_enum = trans_b ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transA_enum = trans_a ? CblasTrans : CblasNoTrans; + + size_t pack_b_size = MlasGemmPackBSize(transA_enum, transB_enum, N, K); std::vector B_packed(pack_b_size); - MlasGemmPackB(CblasNoTrans, N, K, B.data(), N, B_packed.data()); + MlasGemmPackB(transA_enum, transB_enum, N, K, B.data(), N, B_packed.data()); MlasGemm( trans_a ? CblasTrans : CblasNoTrans, diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp new file mode 100644 index 0000000000000..a048ded8349b8 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "test_util.h" +// Currently this test only applies to KleidiAI Guard against it running in any other situation +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + +class MlasDynamicQgemmTest { + private: + MatrixGuardBuffer buffer_a; + MatrixGuardBuffer buffer_bf; + MatrixGuardBuffer buffer_bq; + MatrixGuardBuffer buffer_c; + MatrixGuardBuffer buffer_c_ref; + + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Setup buffers for holding various data + + float* A = buffer_a.GetBuffer(M * K * BatchSize); + // Buffer for holding floating point version of weight matrix + float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); + // Buffer for holding quantized version of weight matrix + int8_t* Bq = buffer_bq.GetBuffer(K * N * BatchSize); + float* C = buffer_c.GetBuffer(M * N * BatchSize); + float* CRef = buffer_c_ref.GetBuffer(M * N * BatchSize); + + // Initialize A and Bf + for (size_t i = 0; i < M * K * BatchSize; ++i) + A[i] = static_cast((rand() % 255 - 128) / 16.0f); + for (size_t i = 0; i < K * N * BatchSize; ++i) + Bf[i] = static_cast((rand() % 255 - 128) / 16.0f); + + // Quantize Bf → Bq and compute per-column scale and bias per batch + std::vector> b_scale_batches(BatchSize, std::vector(N)); + std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t n = 0; n < N; ++n) { + float min_val = Bf[b * K * N + n]; + float max_val = min_val; + for (size_t k = 1; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + min_val = std::min(min_val, v); + max_val = std::max(max_val, v); + } + float scale = (max_val - min_val) / 255.0f; + if (scale < 1e-8f) scale = 1.0f; + b_scale_batches[b][n] = scale; + + for (size_t k = 0; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + int q = static_cast(std::round(v / scale)); + Bq[b * K * N + k * N + n] = static_cast(std::clamp(q, -128, 127)); + } + } + } + + // Prepare kernel parameters + MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; + std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); + std::vector params(BatchSize); + + for (size_t b = 0; b < BatchSize; ++b) { + params[b].A = A + b * M * K; + params[b].lda = K; + params[b].C = C + b * M * N; + params[b].ldc = N; + // Pack b matrix using MlasDynamicQgemmPackBSize & MlasDynamicQgemmPackB + void* packed_b = packed_b_storage.data() + b * MlasDynamicQgemmPackBSize(N, K); + MlasDynamicQgemmPackB(N, K, + Bq + b * K * N, + b_scale_batches[b].data(), + b_bias_batches[b].data(), + packed_b); + params[b].PackedB = packed_b; + } + + // call MlasDynamicQGemmBatch Function + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); + + // Compute reference result + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[b * M * K + m * K + k]; + float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; + sum += a * bval; + } + CRef[b * M * N + m * N + n] = sum; + } + } + } + + // Validate results + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_c_ref = std::abs(CRef[i]); + float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f; + float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); + float abs_tol = 3.0f; + float allowed = std::max(rel_tol, abs_tol); + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, allowed); + } + } + + static const char* GetTestSuiteName() { + return "DynamicQgemm"; + } +}; + +class DynamicQgemmExecuteTest : public MlasTestFixture { + public: + DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) + : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} + + void TestBody() override { + this->mlas_tester->Test(M_, N_, K_, BatchSize_); + } + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { + std::stringstream ss; + ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; + + std::string test_name = ss.str(); + + testing::RegisterTest( + MlasDynamicQgemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new DynamicQgemmExecuteTest(M, N, K, BatchSize); + }); + + return 1; + } + + static size_t RegisterAll(bool is_short_execute) { + const std::vector batch_size = is_short_execute ? std::vector{1UL, 2UL, 4UL} + : std::vector{1UL, 2UL, 4UL, 8UL, 16UL, 32UL, 64UL}; + size_t count = 0; + const size_t sizes[] = {1, 4, 8, 16, 32, 64}; + for (size_t M : sizes) + for (size_t N : sizes) + for (size_t K : sizes) + for (size_t B : batch_size) + count += RegisterSingleTest(M, N, K, B); + return count; + } + + private: + size_t M_, N_, K_, BatchSize_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); +#endif diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.h b/onnxruntime/test/mlas/unittest/test_fgemm.h index 2bd094152d6f0..e7741fba1c3fb 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm.h @@ -112,11 +112,11 @@ class FgemmPackedContext { float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { - size_t PackedBSize = MlasGemmPackBSize(N, K); + size_t PackedBSize = MlasGemmPackBSize(TransA, TransB, N, K); void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); std::vector data(BatchSize); for (size_t i = 0; i < BatchSize; i++) { - MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); + MlasGemmPackB(TransA, TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); data[i].BIsPacked = true; data[i].A = A + M * K * i; data[i].lda = lda; diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index 53b3edafdf84f..c832ca69dbb31 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -70,6 +70,7 @@ class FgemmShortExecuteTest : public MlasTestFixture p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + // check graph structure before applying transformations + const Node* add_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Add") { + add_node = &node; + break; + } + } + + ASSERT_NE(add_node, nullptr) << "Expected Add node not found."; + + const auto& inputs = add_node->InputDefs(); + ASSERT_EQ(inputs.size(), 2u); + + // Assert bias is in position 1 + EXPECT_EQ(inputs[1]->Name(), "bias") << "Expected bias in input 1 but found in input 0."; + + // Apply the transformer + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["MatMulInteger"], 0); + EXPECT_EQ(op_to_count["Cast"], 0); + EXPECT_EQ(op_to_count["Mul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["Add"], 0); +} + +TEST_F(GraphTransformationTests, MatMulIntegerToFloatFusion_Int8Bias_Input1) { + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx"); + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + // check graph structure before applying transformations + const Node* add_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Add") { + add_node = &node; + break; + } + } + + ASSERT_NE(add_node, nullptr) << "Expected Add node not found."; + + const auto& inputs = add_node->InputDefs(); + ASSERT_EQ(inputs.size(), 2u); + + // Assert bias is in position 0 + EXPECT_EQ(inputs[0]->Name(), "bias") << "Expected bias in input 0 but found in input 1."; + + // Apply the transformer + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["MatMulInteger"], 0); + EXPECT_EQ(op_to_count["Cast"], 0); + EXPECT_EQ(op_to_count["Mul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["Add"], 0); +} + #ifdef USE_DML TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float16_int8.onnx"; diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index 0c1ea47fff5b1..5e9a1778198ef 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -1,8 +1,11 @@ +import numpy as np import onnx -from onnx import TensorProto, helper +from onnx import TensorProto, helper, numpy_helper -def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False): # noqa: N802 +def generate_model( + model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False, bias_initializer=False, bias_flip=False +): nodes = [ # subgraph helper.make_node( "MatMulInteger", @@ -50,15 +53,22 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia ) if bias: - nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) + if bias_flip: + nodes.extend([helper.make_node("Add", ["bias", "mul_bottom_output"], ["Y"], "add")]) + else: + nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend( - [ - helper.make_tensor_value_info( - "bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"] - ) - ] + if bias_initializer: + # Use a constant initializer + bias_vals = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16 if output_type_fp16 else np.float32) + bias_tensor = numpy_helper.from_array(bias_vals, name="bias") + initializers = [bias_tensor] + else: + # Use runtime input + inputs.append( + helper.make_tensor_value_info("bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"]) ) + initializers = [] graph = helper.make_graph( nodes, @@ -69,6 +79,7 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia "Y", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["M", "N"] ), ], + initializer=initializers, ) model = helper.make_model(graph) @@ -76,10 +87,10 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia if __name__ == "__main__": - GenerateModel("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) - GenerateModel("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) - GenerateModel("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) - GenerateModel( + generate_model("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) + generate_model("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) + generate_model("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) + generate_model( "matmul_integer_to_float_int8_bias.onnx", sign_i=False, sign_w=True, @@ -87,7 +98,7 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia has_zp=False, bias=True, ) - GenerateModel( + generate_model( "matmul_integer_to_float_uint8_bias.onnx", sign_i=False, sign_w=False, @@ -95,9 +106,27 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia has_zp=False, bias=True, ) - - GenerateModel("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) - GenerateModel( + generate_model( + "matmul_integer_to_float_int8_bias_initializer_index1.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + bias_initializer=True, + ) + generate_model( + "matmul_integer_to_float_int8_bias_initializer_index0.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + bias_flip=True, + bias_initializer=True, + ) + generate_model("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) + generate_model( "matmul_integer_to_float_int8_int8_bias.onnx", sign_i=True, sign_w=True, diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx new file mode 100644 index 0000000000000000000000000000000000000000..841e61cef8fb223aa31cf25f56b4e7b3e3383bb8 GIT binary patch literal 472 zcmZvYyH3L}6o%s@sLrWC!HA0@#E^lZ0|QL3(4;Uxl}e255;>O^OG#WcCj(4<6yAte z!cK}P1}f7-wMwBQy<03>8bCC9QUV%gxb!B|vybM5U%9HXXqsxV*VEC2Tq zuAs-`I^{(Uy`fyrOR@3k2|AlikkqLUQ!%l-KCnBef13Fj1b7 z4`RaOuA;4mt*-1)cTFUsEH|I=^z{aHFS!Ie5xj=KwTw(Xi)Wc{1zE<<*2L_m6p-EwYDwP=9C31dREG2Q(+zc@FQFtR> z2|FpG6eRn~{(OCoeQ5vg(}G(d0g#X#l^n~ah-HJz24hj-&9%25ag3Vcslt4bul(Oz zxPl&!>y#Hc^pbADKxv>@79wAHj@Da=#Vh)jQh@;eGud^X7m}lAr@Thr9;rHTz(jdQ zK8OjAyAo|TY`L;S?V3nLS#Cb#>EjLRUUCV4yTDEauNj$o7SA%@3bKryKfG1E6zNOD zD-`c}7e_pAY9Xt^1trvWN!VU|b`4{FZy3f<4K>%p*|Cb2xo|*D24FD~=`z+hY_Zkh pvBndHztdEhQo?L7DVd%8WZMTv;XU0xN|eg9RUupAlhAZezW^&(eo+7b literal 0 HcmV?d00001 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 87e0ac6a42ea6..561a76be5fa89 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -887,7 +887,22 @@ def generate_build_tree( if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] - cmake_args += ["-Donnxruntime_USE_KLEIDIAI=" + ("OFF" if args.no_kleidiai else "ON")] + # Set onnxruntime_USE_KLEIDIAI based on: + # * Default value above is NO. + # * Leave disabled if "no_kleidiai" argument was specified. + # * Enable if the target is Android and args.android_abi contains arm64* + # * Enable for a Windows cross compile build if compile target is an Arm one. + # * Finally enable if platform.machine contains "arm64". This should cover the following cases: + # * Linux on Arm + # * MacOs (case must be ignored) + # * TODO Delegate responsibility for Onnxruntime_USE_KLEIDIAI = ON to CMake logic + if not args.no_kleidiai: + if ( + (args.android and "arm64" in args.android_abi.lower()) + or (is_windows() and (args.arm64 or args.arm64ec or args.arm) and platform.architecture()[0] != "AMD64") + or ("arm64" in platform.machine().lower()) + ): + cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"] if is_macOS() and (args.macos or args.ios or args.visionos or args.tvos): # Note: Xcode CMake generator doesn't have a good support for Mac Catalyst yet. From ff83f53486e2b9f94b2ed7039efe22f80a25e009 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:17:00 -0700 Subject: [PATCH 18/33] [build] upgrade to use Node.js in docker image (#25529) ### Description ### Motivation and Context --- .../c-api-noopenmp-test-pipelines.yml | 2 +- ...-gpu-tensorrt-cuda-minimal-ci-pipeline.yml | 4 +- .../py-cuda-package-test-pipeline.yml | 4 +- .../jobs/py-linux-cuda-package-test-job.yml | 4 +- .../stages/py-gpu-packaging-stage.yml | 2 +- .../linux/docker/Dockerfile.manylinux2_28_cpu | 2 +- .../docker/Dockerfile.manylinux2_28_cuda | 2 +- .../docker/Dockerfile.manylinux2_28_rocm | 3 +- .../docker/Dockerfile.manylinux2_28_webgpu | 2 +- .../inference/aarch64/default/cpu/Dockerfile | 5 +- .../default/cpu/scripts/install_deps.sh | 56 ------------------- .../inference/aarch64/python/cpu/Dockerfile | 3 +- .../inference/x86_64/default/cpu/Dockerfile | 5 +- .../default/cpu/scripts/install_deps.sh | 56 ------------------- .../x86_64/default/cuda12/Dockerfile | 5 +- .../default/cuda12/scripts/install_deps.sh | 55 ------------------ .../inference/x86_64/python/cpu/Dockerfile | 3 +- .../x86_64/python/openvino/Dockerfile | 2 +- 18 files changed, 21 insertions(+), 194 deletions(-) delete mode 100755 tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh delete mode 100755 tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh delete mode 100755 tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 64e5661eaf6fe..3772b5e9c4c20 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -190,7 +190,7 @@ stages: - name: runCodesignValidationInjection value: false - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 timeoutInMinutes: 60 steps: - checkout: self diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index da2f7b5e01e5f..b304ccdb4c533 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -39,9 +39,9 @@ variables: - template: templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250724.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: ${{ variables.linux_trt_version_cuda11 }} diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index a21c72f5278c0..f1d578b9c86a4 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -18,7 +18,7 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 cuda_version: '12.2' - stage: Republish_Wheels @@ -54,4 +54,4 @@ stages: - publish: $(Pipeline.Workspace)/build/onnxruntime_gpu artifact: whl - displayName: Republish artifacts \ No newline at end of file + displayName: Republish artifacts diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 890b97cbf889a..858de4d173484 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -44,9 +44,9 @@ jobs: - template: ../../templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250724.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: ${{ variables.linux_trt_version_cuda11 }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index 202856cddbcd4..f3d3b2a8ecbf2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -49,4 +49,4 @@ stages: extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index db8668fa9eafe..177df14d6eaee 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 6552c423617b5..489e4ce9f3913 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -43,4 +43,4 @@ RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER ENV PATH=/usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING="LAZY" \ No newline at end of file +ENV CUDA_MODULE_LOADING="LAZY" diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index d20da1867926b..957eef8046eaf 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ARG ROCM_VERSION=6.2.3 #Add our own dependencies @@ -23,4 +23,3 @@ RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER ENV PATH=/usr/local/dotnet:$PATH - diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu index bd3872b4e88e5..56d67599f0bce 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index e6e362ade897d..c8e164282a2f0 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,13 +2,12 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250724.1 ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && python3 -m pip install flatbuffers && rm -rf /tmp/scripts +RUN python3 -m pip install flatbuffers ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh deleted file mode 100755 index 39d7dcfcb70b8..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -set -e -x - -# Download a file from internet -function GetFile { - local uri=$1 - local path=$2 - local force=${3:-false} - local download_retries=${4:-5} - local retry_wait_time_seconds=${5:-30} - - if [[ -f $path ]]; then - if [[ $force = false ]]; then - echo "File '$path' already exists. Skipping download" - return 0 - else - rm -rf "$path" - fi - fi - - if [[ -f $uri ]]; then - echo "'$uri' is a file path, copying file to '$path'" - cp "$uri" "$path" - return $? - fi - - echo "Downloading $uri" - # Use aria2c if available, otherwise use curl - if command -v aria2c > /dev/null; then - aria2c -q -d "$(dirname $path)" -o "$(basename $path)" "$uri" - else - curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail - fi - - return $? -} -mkdir -p /tmp/src - -cd /tmp/src - -CPU_ARCH=$(uname -m) - -echo "Installing Node.js" - -if [[ "$CPU_ARCH" = "x86_64" ]]; then - NODEJS_ARCH=x64 -elif [[ "$CPU_ARCH" = "aarch64" ]]; then - NODEJS_ARCH=arm64 -else - NODEJS_ARCH=$CPU_ARCH -fi -GetFile https://nodejs.org/dist/v22.17.1/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -tar --strip 1 -xf /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 267fc1e661242..31bd41226263f 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250724.1 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts @@ -8,4 +8,3 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER - diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index 7981210af14a1..461464093688a 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,13 +2,12 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250724.1 ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && python3 -m pip install flatbuffers && rm -rf /tmp/scripts +RUN python3 -m pip install flatbuffers ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh deleted file mode 100755 index 8a5348f3ef995..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -set -e -x - -# Download a file from internet -function GetFile { - local uri=$1 - local path=$2 - local force=${3:-false} - local download_retries=${4:-5} - local retry_wait_time_seconds=${5:-30} - - if [[ -f $path ]]; then - if [[ $force = false ]]; then - echo "File '$path' already exists. Skipping download" - return 0 - else - rm -rf $path - fi - fi - - if [[ -f $uri ]]; then - echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path - return $? - fi - - echo "Downloading $uri" - # Use aria2c if available, otherwise use curl - if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" - else - curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail - fi - - return $? -} -mkdir -p /tmp/src - -cd /tmp/src -CPU_ARCH=$(uname -m) - - -echo "Installing Node.js" -CPU_ARCH=`uname -m` -if [[ "$CPU_ARCH" = "x86_64" ]]; then - NODEJS_ARCH=x64 -elif [[ "$CPU_ARCH" = "aarch64" ]]; then - NODEJS_ARCH=arm64 -else - NODEJS_ARCH=$CPU_ARCH -fi -GetFile https://nodejs.org/dist/v22.17.1/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -tar --strip 1 -xf /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 894802dfc8675..043291065736d 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250724.1 ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty @@ -37,8 +37,7 @@ ENV LC_ALL=en_US.UTF-8 ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-12/root/usr/bin/g++ ADD scripts /tmp/scripts -RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/almalinux.repo && \ - cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/almalinux.repo ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ARG BUILD_UID=1001 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh deleted file mode 100755 index f55c017eb8393..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -set -e -x - -# Download a file from internet -function GetFile { - local uri=$1 - local path=$2 - local force=${3:-false} - local download_retries=${4:-5} - local retry_wait_time_seconds=${5:-30} - - if [[ -f $path ]]; then - if [[ $force = false ]]; then - echo "File '$path' already exists. Skipping download" - return 0 - else - rm -rf $path - fi - fi - - if [[ -f $uri ]]; then - echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path - return $? - fi - - echo "Downloading $uri" - # Use aria2c if available, otherwise use curl - if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" - else - curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail - fi - - return $? -} -mkdir -p /tmp/src - -cd /tmp/src - - -echo "Installing Node.js" -CPU_ARCH=`uname -m` -if [[ "$CPU_ARCH" = "x86_64" ]]; then - NODEJS_ARCH=x64 -elif [[ "$CPU_ARCH" = "aarch64" ]]; then - NODEJS_ARCH=arm64 -else - NODEJS_ARCH=$CPU_ARCH -fi -GetFile https://nodejs.org/dist/v22.17.1/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -tar --strip 1 -xf /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index fc376e33d6d10..43da13df2fe8b 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts @@ -8,4 +8,3 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER - diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile index fe6c00f99323f..f3341f32a768d 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile @@ -1,5 +1,5 @@ # Use the specified UBI8 base image with GCC 14 -ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2" +ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1" FROM ${BASEIMAGE} ARG BUILD_UID=1000 From ecc358f069488a79c5abc16c5ddfbc4bd5b3c771 Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Fri, 25 Jul 2025 14:05:58 -0700 Subject: [PATCH 19/33] [QNN EP] Add LPBQ encoding support for MatMul operator (#25539) ### Description - LPBQ encoding is Qualcomm's alternative quantization encoding format for Block Quantization - Add translation logic to read LPBQ pattern on MatMul weights in an QDQ ONNX model exported by AIMET Quantizer - Prepare the corresponding QNN Quantization param for applying LowPowerBlockQuantization on MatMul weights - Apply LPBQ Fusions only for NPU Backend as currently only NPU backend supports LPBQ encoding format ### Motivation and Context - This requires accelerate accuracy sensitive large language models like Phi-3.5 efficiently on Qualcomm's NPU accelerator. --- .../builder/qnn_node_group/lpbqgemm_fusion.cc | 13 +- .../builder/qnn_node_group/lpbqgemm_fusion.h | 2 +- .../qnn_node_group/lpbqmatmul_fusion.cc | 365 ++++++++++++++++++ .../qnn_node_group/lpbqmatmul_fusion.h | 49 +++ .../builder/qnn_node_group/qnn_node_group.cc | 6 +- .../qnn/builder/qnn_node_group/utils.cc | 21 +- 6 files changed, 448 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc index ca15f861f4596..99ea79e028b0c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc @@ -32,8 +32,9 @@ std::unique_ptr LowPowerBlockQuantizedGemmFusion::TryFusion( const logging::Logger& logger) { ORT_UNUSED_PARAMETER(logger); + // Only HTP supports LPBQ encoding format // Looking for a Gemm to start search for Gemm w/ LPBQ encodings pattern. - if (gemm_node_unit.OpType() != "Gemm") { + if (!IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) || gemm_node_unit.OpType() != "Gemm") { return nullptr; } @@ -236,18 +237,22 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); } + std::vector weight_shape; + std::string weight_tensor_name = w_ql_input_1_def.node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(w_ql_input_1_def.node_arg, weight_shape), "Failed to get weight shape"); + // Get attributes like axis, block_size from QuantizeLinear NodeAttrHelper helper(w_ql_node_unit.GetNode()); auto input_channel_axis = helper.Get("axis", static_cast(0)); + if (input_channel_axis < 0) { + input_channel_axis = weight_shape.size() + input_channel_axis; + } auto block_size = helper.Get("block_size", static_cast(0)); size_t output_channel_axis = 0; // Current LowPowerBlockQuantize() support output_channel_axis at index=0; weight_qparams = QnnQuantParamsWrapper(per_channel_float_scale, per_block_int_scale, weight_offset, output_channel_axis, block_size, is_int4_type); - std::vector weight_shape; std::vector unpacked_tensor; - std::string weight_tensor_name = w_ql_input_1_def.node_arg.Name(); - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(w_ql_input_1_def.node_arg, weight_shape), "Failed to get weight shape"); Qnn_DataType_t weight_data_type = is_int4_type ? QNN_DATATYPE_SFIXED_POINT_4 : QNN_DATATYPE_SFIXED_POINT_8; const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name); ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h index 9dcf07fa863d2..374df8b346e8d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h @@ -16,7 +16,7 @@ namespace qnn { class QnnModelWrapper; /// -/// Represents a fusion of a {DQ, DQ->Q->DQ} -> Gemm -> DQ sequence. +/// Represents a fusion of a {DQ, DQ->Q->DQ} -> Gemm -> Q sequence. /// This is translated into a QNN's FC w/ LPBQ encodings. /// The contained NodeUnits are of type SingleNode since they are not part of a QDQ node unit. /// diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc new file mode 100644 index 0000000000000..92e0f28b0307c --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc @@ -0,0 +1,365 @@ +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" + +namespace onnxruntime { +namespace qnn { + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& scale_dql_node_unit, + const NodeUnit& w_ql_node_unit, + const NodeUnit& matmul_node_unit, + const logging::Logger& logger, + bool validate); + +std::unique_ptr LowPowerBlockQuantizedMatMulFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& matmul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + + // Only HTP supports LPBQ encoding format + // Looking for a MatMul to start search for MatMul w/ LPBQ encodings pattern. + if (!IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) || matmul_node_unit.OpType() != "MatMul") { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Get QuantizeLinear on Weight (input 1) of MatMul node + const NodeUnit* p_w_ql_node_unit = GetParentOfInput(graph_viewer, + matmul_node_unit, + matmul_node_unit.Inputs()[1], + node_to_node_unit, + node_unit_to_qnn_node_group); + if (p_w_ql_node_unit == nullptr || p_w_ql_node_unit->OpType() != "QuantizeLinear") { + return nullptr; + } + + // Check if input of QuantizeLinear is constant initializer + if (!qnn_model_wrapper.IsConstantInput(p_w_ql_node_unit->Inputs()[0].node_arg.Name())) { + return nullptr; + } + + // Get DequantizeLinear node unit contains per-block int scales and per-channel float scales + const std::array w_ql_parent_types = {"DequantizeLinear"}; + const NodeUnit* p_scale_dql_node_unit = GetParentOfType(graph_viewer, + *p_w_ql_node_unit, + w_ql_parent_types, + node_to_node_unit, + node_unit_to_qnn_node_group); + if (p_scale_dql_node_unit == nullptr) { + return nullptr; + } + + TensorInfo pc_scales_tensor_info = {}; + if (Status status = qnn_model_wrapper.GetTensorInfo(p_scale_dql_node_unit->Inputs()[0], pc_scales_tensor_info); + !status.IsOK()) { + return nullptr; + } + // Check if input 0 of DequantizeLinear is constant initializer and has per-channel float scales + if (!pc_scales_tensor_info.is_initializer || !pc_scales_tensor_info.quant_param.IsPerChannel()) { + return nullptr; + } + + if (Status status = CreateOrValidateOnQnn(qnn_model_wrapper, + *p_scale_dql_node_unit, + *p_w_ql_node_unit, + matmul_node_unit, + logger, + true); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(*p_scale_dql_node_unit, + *p_w_ql_node_unit, + matmul_node_unit); +} + +LowPowerBlockQuantizedMatMulFusion::LowPowerBlockQuantizedMatMulFusion(const NodeUnit& Scale_DQL_node_unit, + const NodeUnit& W_QL_node_unit, + const NodeUnit& MatMul_node_unit) + : node_units_{&Scale_DQL_node_unit, + &W_QL_node_unit, + &MatMul_node_unit} { +} + +Status LowPowerBlockQuantizedMatMulFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], logger, true); +} + +Status LowPowerBlockQuantizedMatMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], logger, false); +} + +gsl::span LowPowerBlockQuantizedMatMulFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* LowPowerBlockQuantizedMatMulFusion::GetTargetNodeUnit() const { + return node_units_[2]; +} + +namespace { +// Process input[0] for ONNX MatMul that can be translated to either a QNN MatMul. +Status ProcessInput0(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& input_def, + const std::string& original_input_0_name, + std::vector& input_names, + const logging::Logger& logger, + bool do_op_validation) { + TensorInfo input_0_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def, input_0_info)); + bool reshape_input_0 = input_0_info.shape.size() == 1; + std::string actual_input_0_name = original_input_0_name; + + if (reshape_input_0) { + actual_input_0_name = original_input_0_name + "_ort_qnn_ep_reshape"; + std::vector shape_2d{1, input_0_info.shape[0]}; + QnnQuantParamsWrapper quant_param_2d = input_0_info.quant_param.Copy(); + ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_0_info.shape, shape_2d)); + + // If input_0 is initializer, unpack it and add the tensor with new quantization parameter and shape. + // Otherwise, add a Reshape node. + if (input_0_info.is_initializer) { + std::vector unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_0_info.initializer_tensor, unpacked_tensor)); + QnnTensorWrapper input_tensorwrapper(actual_input_0_name, QNN_TENSOR_TYPE_STATIC, input_0_info.qnn_data_type, + std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } else { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(original_input_0_name, actual_input_0_name, + input_0_info.shape, shape_2d, + input_0_info.qnn_data_type, input_0_info.quant_param, + quant_param_2d, do_op_validation, + qnn_model_wrapper.IsGraphInput(original_input_0_name), false)); + } + } else { + if (qnn_model_wrapper.IsQnnTensorWrapperExist(actual_input_0_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << actual_input_0_name; + } else { + QnnTensorWrapper input_0_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_0_info, actual_input_0_name, input_0_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_0_tensor)), "Failed to add tensor."); + } + } + input_names.emplace_back(actual_input_0_name); + + return Status::OK(); +} + +// Utility function to unpack weight tensor and transpose to shape [out_channels][in_channels] +Status UnpackWeightTensorData(const QnnModelWrapper& qnn_model_wrapper, + const onnx::TensorProto* weight_tensor_proto, + std::vector& weight_shape, + int64_t& input_channel_axis, + std::vector& unpacked_tensor) { + ORT_RETURN_IF_NOT(weight_tensor_proto != nullptr, "Weight tensor proto is null"); + + if (input_channel_axis == 0) { + // Transpose to keep output_channel at index 0; + // The current logic that quantizes with LPBQ encodings requires out_channels at index 0 + input_channel_axis = weight_shape.size() - 1; + return utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor); + } else { + // No transpose needed, just unpack the initializer data + return qnn_model_wrapper.UnpackInitializerData(*weight_tensor_proto, unpacked_tensor); + } +} + +// A utility function to transpose a 2D data +Status TwoDimensionTranspose(std::vector& data, + std::vector& data_shape, + const Qnn_DataType_t element_type) { + ORT_RETURN_IF_NOT(data_shape.size() == 2, "Expected shape of rank 2"); + + std::array perm = {1, 0}; + std::vector output_shape(data_shape.size()); + ORT_RETURN_IF_ERROR((qnn::utils::PermuteShape(data_shape, perm, output_shape))); + + const size_t elem_byte_size = qnn::utils::GetElementSizeByType(element_type); + ORT_RETURN_IF_NOT(elem_byte_size != 0, "Can't get element byte size from given QNN type"); + + std::vector transposed_data(data.size()); + + for (size_t row = 0; row < data_shape[0]; row++) { + for (size_t col = 0; col < data_shape[1]; col++) { + const size_t src_elem_index = (row * data_shape[1] + col); + const size_t dst_elem_index = (col * output_shape[1] + row); + const size_t src_byte_index = src_elem_index * elem_byte_size; + const size_t dst_byte_index = dst_elem_index * elem_byte_size; + assert(src_byte_index < data.size()); + assert(dst_byte_index < transposed_data.size()); + + std::memcpy(&transposed_data[dst_byte_index], &data[src_byte_index], elem_byte_size); + } + } + + data = std::move(transposed_data); // Update data with transposed data + data_shape = std::move(output_shape); // Update parameter with final transposed shape + return Status::OK(); +} + +// Process LPBQWeight for ONNX MatMul that can be translated to either a QNN MatMul. +Status ProcessLPBQWeight(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& scale_dql_node_unit, + const NodeUnit& w_ql_node_unit, + const NodeUnit& matmul_node_unit, + std::vector& input_names, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + const NodeUnitIODef& mm_input_1_def = matmul_node_unit.Inputs()[1]; + const NodeUnitIODef& w_ql_input_1_def = w_ql_node_unit.Inputs()[0]; + + // get per_channel_float_scale value from Quant param of input[0] of DequantizeLinear + std::vector per_channel_float_scale; + const NodeUnitIODef& per_channel_float_def = scale_dql_node_unit.Inputs()[0]; + const std::optional& scale_dql_quant_param = per_channel_float_def.quant_param; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(scale_dql_quant_param->scale.Name(), per_channel_float_scale)); + + // get per_block_int_scale value from input[0] of DequantizeLinear + std::vector per_block_int_scale; + const NodeUnitIODef& per_block_int_def = scale_dql_node_unit.Inputs()[0]; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(per_block_int_def.node_arg.Name(), per_block_int_scale)); + std::vector weight_offset(per_channel_float_scale.size(), 0); + std::vector block_scales_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(per_block_int_def.node_arg, block_scales_shape), "Failed to get block_scales shape"); + + // Read axis of channels in per-block-int-scales data + NodeAttrHelper scales_node_helper(scale_dql_node_unit.GetNode()); + auto block_scales_axis = scales_node_helper.Get("axis", static_cast(0)); + + // Transpose per-block-int-scales to keep channels at index-0 (QNN LPBQ format requires shape [axis_size][blocks-per-axis]) + if (block_scales_axis == 1) { + ORT_RETURN_IF_ERROR(TwoDimensionTranspose(per_block_int_scale, block_scales_shape, QNN_DATATYPE_UFIXED_POINT_8)); + block_scales_axis = 0; + } + + // Extract weight datatype from zeropoint (aka offset) of Input1 Quant param + const std::optional& mm_input_1_quant_param = mm_input_1_def.quant_param; + bool is_int4_type = false; + if (mm_input_1_quant_param->zero_point != nullptr) { + int32_t elem_data_type = 0; + ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(*mm_input_1_quant_param->zero_point, elem_data_type)); + is_int4_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || + (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); + } + + std::vector weight_shape; + std::string weight_tensor_name = w_ql_input_1_def.node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(w_ql_input_1_def.node_arg, weight_shape), "Failed to get weight shape"); + + // Get attributes like weight data axis, block_size from QuantizeLinear + NodeAttrHelper helper(w_ql_node_unit.GetNode()); + auto input_channel_axis = helper.Get("axis", static_cast(0)); + if (input_channel_axis < 0) { + input_channel_axis = weight_shape.size() + input_channel_axis; // QNN requires positive axis value + } + auto block_size = helper.Get("block_size", static_cast(0)); + + std::vector unpacked_tensor; + const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name); + // if input_channel_axis = 0, UnpackWeightTensorData will transpose and keep output_channel at 0 + ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor)); + + // Quantize weight tensor + size_t weight_elements = unpacked_tensor.size() / sizeof(float); + auto float_data = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), weight_elements); + std::vector quant_data(weight_elements); + + // weight_data_type = 4 but store in int8 buffer + size_t output_channel_axis = 0; // MatMul requires axis to be rank-1 + Qnn_DataType_t weight_data_type = is_int4_type ? QNN_DATATYPE_SFIXED_POINT_4 : QNN_DATATYPE_SFIXED_POINT_8; + ORT_RETURN_IF_ERROR(qnn::utils::LowPowerBlockQuantizeData(float_data, + weight_shape, + per_channel_float_scale, + per_block_int_scale, + weight_offset, + quant_data, + weight_data_type, + output_channel_axis, + block_scales_axis, + block_size, + block_scales_shape)); + + // MatMul w/ LPBQ requies MatMul(MxK, KxN) and axis = rank-1 (out channels) + // Transpose Weight to KxN, output_channel_axis is modified to rank-1; + if (input_channel_axis == 1) { + ORT_RETURN_IF_ERROR(TwoDimensionTranspose(quant_data, weight_shape, QNN_DATATYPE_SFIXED_POINT_8)); + input_channel_axis = 0; + output_channel_axis = weight_shape.size() - 1; + } + + // Construct Quant params for Weight + QnnQuantParamsWrapper weight_qparams; + weight_qparams = QnnQuantParamsWrapper(per_channel_float_scale, per_block_int_scale, weight_offset, output_channel_axis, block_size, is_int4_type); + + // Get weight tensor type from input of w_dql_tensor or output_dql_tensor + Qnn_TensorType_t weight_tensor_type = qnn_model_wrapper.GetTensorType(weight_tensor_name); + QnnTensorWrapper weight_tensor(weight_tensor_name, weight_tensor_type, QNN_DATATYPE_SFIXED_POINT_8, + std::move(weight_qparams), std::move(weight_shape), + std::move(quant_data)); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor)), "Failed to add weight"); + input_names.emplace_back(weight_tensor_name); + return Status::OK(); +} +} // namespace + +Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& scale_dql_node_unit, + const NodeUnit& w_ql_node_unit, + const NodeUnit& matmul_node_unit, + const logging::Logger& logger, + bool validate) { + assert(scale_dql_node_unit.OpType() == "DequantizeLinear" && + w_ql_node_unit.OpType() == "QuantizeLinear" && + matmul_node_unit.OpType() == "MatMul"); + + const auto& node_name = utils::GetNodeName(matmul_node_unit); + + std::vector input_names; + + // prepare input tensor + const NodeUnitIODef& input_def = matmul_node_unit.Inputs()[0]; + const std::string& input_tensor_name = input_def.node_arg.Name(); + ORT_RETURN_IF_ERROR(ProcessInput0(qnn_model_wrapper, input_def, input_tensor_name, input_names, + logger, validate)); + + // Prepare LowPowerBlockQuantized(LPBQ) Weight + ORT_RETURN_IF_ERROR(ProcessLPBQWeight(qnn_model_wrapper, scale_dql_node_unit, w_ql_node_unit, + matmul_node_unit, input_names, logger)); + + // Prepare Output + const NodeUnitIODef& output_def = matmul_node_unit.Outputs()[0]; + const std::string& op_output_name = output_def.node_arg.Name(); + QnnTensorWrapper output_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + + // Create QNN Node and Validate if require. + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_MAT_MUL, + std::move(input_names), + {op_output_name}, + {}, + validate), + "Failed to add fused Matmul node."); + + return Status(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h new file mode 100644 index 0000000000000..0d8967de5ace3 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h @@ -0,0 +1,49 @@ +// Copyright (c) Qualcomm. All rights reserved. +// Licensed under the MIT License + +#pragma once + +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ -> Q -> MatMul (+ DQ, DQ, Q). +/// This is translated into a QNN's MatMul w/ LPBQ encodings. +/// The contained NodeUnits are of type SingleNode since they are not part of a QDQ node unit. +/// + +class LowPowerBlockQuantizedMatMulFusion : public IQnnNodeGroup { + public: + LowPowerBlockQuantizedMatMulFusion(const NodeUnit& Scale_DQL_node_unit, + const NodeUnit& W_QL_node_unit, + const NodeUnit& MatMul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(LowPowerBlockQuantizedMatMulFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "LowPowerBlockQuantizedMatMulFusion"; } + + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& matmul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index b0f0b4c0ff48a..5f33b639ce613 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -19,6 +19,7 @@ #include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -76,6 +77,7 @@ using FusionFunc = std::function(QnnModelWrapper& static std::unordered_map> fusions = { {"DequantizeLinear", {DQQFusion::TryFusion}}, {"HardSigmoid", {HardSigmoidMulFusion::TryFusion}}, + {"MatMul", {LowPowerBlockQuantizedMatMulFusion::TryFusion}}, {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}}, {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Transpose", {ChannelShuffleFusion::TryFusion}}}; @@ -113,8 +115,8 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). - if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 92478e0db7795..10e1633e4b57d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -177,7 +177,26 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, const NodeUnitIODef& input, const std::unordered_map& node_unit_map, const std::unordered_map& qnn_node_group_map) { - const Node& child_node = node_unit.GetNode(); + const Node* p_child_node = nullptr; + + for (auto node : node_unit.GetAllNodesInGroup()) { + for (auto node_input : node->InputDefs()) { + if (node_input->Name() == input.node_arg.Name()) { + p_child_node = node; + break; + } + + if (p_child_node != nullptr) { + break; + } + } + } + + if (p_child_node == nullptr) { + return nullptr; + } + + const Node& child_node = *p_child_node; for (auto edge = child_node.InputEdgesBegin(); edge != child_node.InputEdgesEnd(); ++edge) { const Node& parent_node = edge->GetNode(); From 6f4bb5156c417add39ee3bca37d5bad1fe6d9bbf Mon Sep 17 00:00:00 2001 From: Joe Yearsley Date: Fri, 25 Jul 2025 22:16:00 +0100 Subject: [PATCH 20/33] Update nv_basic_test.cc (#24983) ### Description Corrected dtype_name for the respective float16 implementations, previously MLFloat16 would return bf16 rather than fp16, and vice-versa. ### Motivation and Context It looked wrong but passed the tests, I don't fully comprehend what the test suite is doing to try and improve it. I'd be willing to implement any pointers. From 29c20cb78d86f52117c7cfe431d3fc6e1b5d4c18 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 25 Jul 2025 14:36:40 -0700 Subject: [PATCH 21/33] Split windows_tensorrt.yml to two parts (#25528) It reduces the pipeline time for about 30 minutes. The tests still take about 1 hour, which should be reduced. --- .github/workflows/windows_tensorrt.yml | 217 +++++++++++++++++++++---- 1 file changed, 187 insertions(+), 30 deletions(-) diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index e65d23069ad32..dbc138e57a3ec 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -15,14 +15,15 @@ concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true +#TODO: enable --build_nodejs jobs: - Windows_GPU_TensorRT_CI_Pipeline: + build: name: Windows GPU TensorRT CI Pipeline - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - uses: actions/checkout@v4 with: - fetch-depth: 0 # Fetch all history for all tags and branches + fetch-depth: 0 submodules: 'none' - uses: actions/setup-python@v5 @@ -36,29 +37,20 @@ jobs: architecture: x64 - name: Install python modules - run: python -m pip install -r ${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt + run: python -m pip install -r .\tools\ci_build\github\windows\python\requirements.txt + working-directory: ${{ github.workspace }} shell: cmd - - name: Download Primary CUDA SDK v12.2 - run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" ${{ runner.temp }}' + - name: Download CUDA SDK v12.2 + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" . + dir shell: pwsh - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' shell: pwsh - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Add CUDA to PATH shell: powershell @@ -69,33 +61,198 @@ jobs: Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" - - name: Generate sln - working-directory: ${{ runner.temp }} + - uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - uses: actions/cache@v4 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: API Documentation Check and generate run: | - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} shell: cmd - - name: Build + - uses: actions/setup-dotnet@v4 + env: + PROCESSOR_ARCHITECTURE: x64 + with: + dotnet-version: '8.x' + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + run: nuget restore ${{ github.workspace }}\packages.config -ConfigFile ${{ github.workspace }}\NuGet.config -PackagesDirectory ${{ runner.temp }}\build\RelWithDebInfo + shell: cmd + + - name: Set OnnxRuntimeBuildDirectory + shell: pwsh + run: | + $buildDir = Join-Path ${{ runner.temp }} "build" + echo "OnnxRuntimeBuildDirectory=$buildDir" >> $env:GITHUB_ENV + + - name: Build and Clean Binaries working-directory: ${{ runner.temp }} run: | - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --build --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + npm install -g typescript + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + # Execute the build process + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --build --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + # Clean up the output directory before uploading artifacts + $outputDir = "${{ runner.temp }}\build\RelWithDebInfo" + Write-Host "Cleaning up files from $outputDir..." + + Remove-Item -Path "$outputDir\onnxruntime" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\pybind11" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\models" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\vcpkg_installed" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\_deps" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\CMakeCache.txt" -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\CMakeFiles" -Recurse -Force -ErrorAction SilentlyContinue + # Remove intermediate object files as in the original script + Remove-Item -Path $outputDir -Include "*.obj" -Recurse + shell: pwsh + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: build-artifacts + path: ${{ runner.temp }}\build + env: + OrtPackageId: Microsoft.ML.OnnxRuntime.Gpu + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: false + ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + + test: + name: Windows GPU TensorRT CI Pipeline Test Job + needs: build + timeout-minutes: 300 + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: 'none' + + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: build-artifacts + path: ${{ runner.temp }}\build + + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + run: python -m pip install -r .\tools\ci_build\github\windows\python\requirements.txt + working-directory: ${{ github.workspace }} shell: cmd - - name: Add build dir to PATH + - name: Download CUDA SDK v12.2 + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" . + dir + shell: pwsh + + - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 + run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' + shell: pwsh + + - name: Add CUDA to PATH shell: powershell run: | Write-Host "Adding CUDA to PATH" - Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\build\RelWithDebInfo\RelWithDebInfo" + Write-Host "CUDA Path: $env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" + + - name: Set OnnxRuntimeBuildDirectory + shell: pwsh + run: | + $buildDir = Join-Path ${{ runner.temp }} "build" + echo "OnnxRuntimeBuildDirectory=$buildDir" >> $env:GITHUB_ENV - name: Install ONNX Runtime Wheel uses: ./.github/actions/install-onnxruntime-wheel with: whl-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo\dist - - name: Run tests + - name: Run Tests working-directory: ${{ runner.temp }} run: | - mklink /D /J ${{ github.workspace }}\RelWithDebInfo\models ${{ github.workspace }}\models - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + npm install -g typescript + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + python.exe ${{ github.workspace }}\tools\python\update_ctest_path.py "${{ runner.temp }}\build\RelWithDebInfo\CTestTestfile.cmake" "${{ runner.temp }}\build\RelWithDebInfo" + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + shell: pwsh + + - name: Validate C# native delegates + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\csharp shell: cmd - timeout-minutes: 180 + env: + OrtPackageId: Microsoft.ML.OnnxRuntime.Gpu + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: false + ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 From 9aad21c7f84d56eb7464fc9100f437b140193f92 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 25 Jul 2025 21:39:40 -0700 Subject: [PATCH 22/33] DynamicQuantizeMatMul - handle case where B zero point input is provided but not constant. (#25544) ### Description In DynamicQuantizeMatMul KleidiAI-specific prepacking logic, handle case where B zero point input is provided but not constant. In this case, we should not prepack. Add some unit tests that test the prepacking code path. Add check for ARM SME instructions in DynamicQuantizeMatMul before calling `MlasDynamicQGemmBatch()` and associated functions. ### Motivation and Context Follow up to #25187 --- .../quantization/dynamic_quantize_matmul.cc | 49 +++-- .../dynamic_quantize_matmul_test.cc | 182 +++++++++++------- 2 files changed, 138 insertions(+), 93 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index e2bb3b508ca7c..85a2cbaea0e44 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME() #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" @@ -10,6 +11,7 @@ #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" +#include #include #include @@ -169,43 +171,40 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { // only pack Matrix B if (input_idx == GetBIdx()) { const Tensor* b_zp_constant_tensor{nullptr}; - bool b_quantization_is_asymmetric = false; + bool b_quantization_might_be_asymmetric = false; - // zero point tensor could be provided as a direct input to the kernel and not as a constant so this - // test is not sufficient const OrtValue* b_zp; if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { b_zp_constant_tensor = &b_zp->Get(); } - // MlasDynamicQgemm requires symmetric quantization for B, so no zero point should exist or it should - // have a zero value - if (b_zp_constant_tensor != nullptr) { // Covers the case where tensor is not a constant - const auto& shape = b_zp_constant_tensor->Shape(); - const auto* zp_data = static_cast(b_zp_constant_tensor->DataRaw()); - size_t zp_size = static_cast(shape.Size()); - // MlasDynamicQgemm requires symmetric quantization: zp must be scalar 0 or 1D all-zero - if ((shape.NumDimensions() == 0) && (zp_data[0] == 0)) { - b_quantization_is_asymmetric = false; - } else if (shape.NumDimensions() == 1) { - b_quantization_is_asymmetric = false; - for (size_t i = 0; i < zp_size; ++i) { - if (zp_data[i] != 0) { - b_quantization_is_asymmetric = true; - break; - } - } - } else { - // Unsupported higher-rank zp tensor - b_quantization_is_asymmetric = true; - } + // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros + // or not provided. + if (b_zp_constant_tensor != nullptr) { + // B zero point is constant. Check if it is all zeros. + assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); + const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + } else { + // B zero point input is not constant. If it exists, we can't assume symmetric quantization. + const auto input_defs = Info().node().InputDefs(); + const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); + b_quantization_might_be_asymmetric = b_zp_input_exists; } // MlasDynamicQgemm requires scale data to be available at packing stage const Tensor* b_scale_tensor = nullptr; const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); - can_use_dynamic_quant_mlas_ = (!b_quantization_is_asymmetric && b_scale_available); + can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + // We check that here too before attempting to use them. + if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + can_use_dynamic_quant_mlas_ = false; + } // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index c9a7116bf8052..2918e4baf86a4 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -82,21 +82,48 @@ static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, con } } +struct TestDynamicQuantizeMatMulOptions { + bool is_matrix_b_constant = true; + + bool per_column = false; + + bool is_scale_constant = false; + + bool has_zp = true; + bool is_zp_constant = false; + bool is_zp_zero = false; + + bool has_bias = false; + bool is_bias_constant = false; + + bool empty_input = false; +}; + template -void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, - bool per_column = false, - bool has_zp = true, - bool has_bias = false, - bool empty_input = false) { +void TestDynamicQuantizeMatMul(const TestDynamicQuantizeMatMulOptions& opts) { + static_assert(std::is_same_v || std::is_same_v); + + SCOPED_TRACE(MakeString( + "b data type:", (std::is_same_v ? "uint8" : "int8"), + ", is_matrix_b_constant:", opts.is_matrix_b_constant, + ", per_column:", opts.per_column, + ", is_scale_constant:", opts.is_scale_constant, + ", has_zp:", opts.has_zp, + ", is_zp_constant:", opts.is_zp_constant, + ", is_zp_zero:", opts.is_zp_zero, + ", has_bias:", opts.has_bias, + ", is_bias_constant:", opts.is_bias_constant, + ", empty_input:", opts.empty_input)); + // create rand inputs RandomValueGenerator random{1668426375}; - int64_t M = empty_input ? 1 : 4; + int64_t M = opts.empty_input ? 1 : 4; int64_t N = 128; int64_t K = 128; - std::vector A_dims{empty_input ? 0 : M, K}; + std::vector A_dims{opts.empty_input ? 0 : M, K}; std::vector B_dims{K, N}; - std::vector Y_dims{empty_input ? 0 : M, K}; + std::vector Y_dims{opts.empty_input ? 0 : M, K}; std::vector A_data = random.Uniform(A_dims, -1.0f, 1.0f); std::vector B_data; std::vector tmp_B_data = random.Uniform(B_dims, @@ -106,101 +133,120 @@ void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, return static_cast(v); }); - int64_t b_scale_zp_size = per_column ? B_dims.back() : 1; + int64_t b_scale_zp_size = opts.per_column ? B_dims.back() : 1; std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); - std::for_each(B_zero_point.begin(), - B_zero_point.end(), - [&random](T& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::min(), - std::numeric_limits::max())[0]); - }); + if (!opts.is_zp_zero) { + std::for_each(B_zero_point.begin(), + B_zero_point.end(), + [&random](T& zp) { + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::min(), + std::numeric_limits::max())[0]); + }); + } std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); - test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("B", B_dims, B_data, opts.is_matrix_b_constant); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale, opts.is_scale_constant); - if (has_zp) { - test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); + if (opts.has_zp) { + test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point, opts.is_zp_constant); } else { test.AddOptionalInputEdge(); } - if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + if (opts.has_bias) { + test.AddInput("bias", {B_dims.back()}, Bias, opts.is_bias_constant); } else { test.AddOptionalInputEdge(); } std::vector Y_data(M * N); CalculateDynamicQuantizeMatMul(M, N, K, A_data, B_data, B_scale, B_zero_point, Bias, Y_data, - per_column, has_zp, has_bias); + opts.per_column, opts.has_zp, opts.has_bias); test.AddOutput("Y", Y_dims, Y_data); test.SetOutputRelErr("Y", 0.02f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } -template -void RunDynamicQuantizeMatMulTest() { - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); +template +void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, + bool per_column = false, + bool has_zp = true, + bool has_bias = false, + bool empty_input = false) { + TestDynamicQuantizeMatMulOptions opts{}; + opts.is_matrix_b_constant = is_matrix_b_constant; + opts.per_column = per_column; + opts.has_zp = has_zp; + opts.has_bias = has_bias; + opts.empty_input = empty_input; + + TestDynamicQuantizeMatMul(opts); } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +template +void RunDynamicQuantizeMatMulTest() { + for (bool is_matrix_b_constant : {false, true}) { + for (bool per_column : {false, true}) { + for (bool has_zp : {false, true}) { + for (bool has_bias : {false, true}) { + TestDynamicQuantizeMatMul(is_matrix_b_constant, + per_column, + has_zp, + has_bias); + } + } + } + } } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, Int8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, UInt8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} +TEST(DynamicQuantizeMatMul, WithConstantBInputs) { + TestDynamicQuantizeMatMulOptions base_opts{}; + base_opts.is_matrix_b_constant = true; + base_opts.is_scale_constant = true; + base_opts.is_zp_constant = true; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // no zp + auto opts = base_opts; + opts.has_zp = false; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // zp that is zero (symmetric quantization) + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = true; -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } + + { + // zp that is non-zero + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = false; + + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } } TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) { From b214da55634d08cf61796c56be288878f4755586 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 26 Jul 2025 14:25:57 -0700 Subject: [PATCH 23/33] upgrade emsdk to v4.0.11 (#25477) ### Description ### Motivation and Context Fix the build break on Windows+Ninja --- .gitmodules | 2 +- cmake/CMakeLists.txt | 1 + cmake/external/emsdk | 2 +- cmake/onnxruntime_unittests.cmake | 6 ++ cmake/onnxruntime_webassembly.cmake | 9 ++- onnxruntime/wasm/pre-async.js | 70 ++----------------- tools/ci_build/build_args.py | 2 +- .../templates/linux-wasm-ci.yml | 8 +-- 8 files changed, 22 insertions(+), 78 deletions(-) diff --git a/.gitmodules b/.gitmodules index b5bff01d89850..a48c4062a90fe 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 4.0.8 + branch = 4.0.11 diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b0941b4d0c922..a76be16572a03 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -562,6 +562,7 @@ else() check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) + check_cxx_compiler_flag(-Wcharacter-conversion HAS_CHARACTER_CONVERSION) check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 419021fa04042..d49219d03a41c 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 419021fa040428bc69ef1559b325addb8e10211f +Subproject commit d49219d03a41cd12f95a33ba84273c20d41fd350 diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 96e513c8a7bc9..c3bebba3bab54 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -120,6 +120,9 @@ function(AddTest) if (${HAS_NOERROR}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") endif() + if (${HAS_CHARACTER_CONVERSION}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=character-conversion>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -787,6 +790,9 @@ if(MSVC) "$<$>:/wd6326>") else() target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + if (HAS_CHARACTER_CONVERSION) + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Wno-error=character-conversion>") + endif() endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index ffe866164a411..e2d04843d858e 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -175,9 +175,9 @@ else() "${ONNXRUNTIME_ROOT}/wasm/api.cc" "${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc" ) - set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") - set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS "-sDISABLE_EXCEPTION_CATCHING=0") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_CATCHING=0") endif() target_link_libraries(onnxruntime_webassembly PRIVATE @@ -241,11 +241,10 @@ else() "SHELL:-s FILESYSTEM=0" "SHELL:-s INCOMING_MODULE_JS_API=[locateFile,instantiateWasm,wasmBinary]" "SHELL:-s WASM_BIGINT=1" - ${WASM_API_EXCEPTION_CATCHING} --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) - + if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU # This flag allows async functions to be called from sync functions, in the cost of binary size and @@ -256,7 +255,7 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" ) list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") - + endif() if (onnxruntime_USE_WEBGPU) diff --git a/onnxruntime/wasm/pre-async.js b/onnxruntime/wasm/pre-async.js index 8c75dc7c5cf1e..1f8f17535e7d4 100644 --- a/onnxruntime/wasm/pre-async.js +++ b/onnxruntime/wasm/pre-async.js @@ -15,78 +15,20 @@ let initAsyncImpl = () => { // It removes some overhead in cwarp() and ccall() that we don't need. // // Currently in ASYNCIFY build, we only use this for the following functions: + // - OrtAppendExecutionProvider() // - OrtCreateSession() // - OrtRun() // - OrtRunWithBinding() // - OrtBindInput() // - // Note: about parameters "getFunc" and "setFunc": - // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // We need to wrap these functions with an async wrapper so that they can be called in an async context. // - // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a - // wrapper for OrtRun() like this (minified): - // ``` - // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); - // ``` - // - // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates - // a wrapper for OrtRun() like this (minified): - // ``` - // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // - // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once - // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will - // reset d._OrtRun to J.ka when the first time it is called. - // - // The difference is important because we need to design the async wrapper in a way that it can handle both cases. - // - // Now, let's look at how the async wrapper is designed to work for both cases: - // - // - Debug build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // - Release build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this - // function: - // ``` - // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). - // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored - // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release - // build. - // - // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an - // exported function and set the new value of an exported function. - // - const wrapAsync = (func, getFunc, setFunc) => { + const wrapAsync = (func) => { return (...args) => { // cache the async data before calling the function. const previousAsync = Asyncify.currData; - const previousFunc = getFunc?.(); const ret = func(...args); - const newFunc = getFunc?.(); - if (previousFunc !== newFunc) { - // The exported function has been updated. - // Set the sync function reference to the new function. - func = newFunc; - // Set the exported function back to the async wrapper. - setFunc(previousFunc); - // Remove getFunc and setFunc. They are no longer needed. - setFunc = null; - getFunc = null; - } // If the async data has been changed, it means that the function started an async operation. if (Asyncify.currData != previousAsync) { @@ -101,11 +43,7 @@ let initAsyncImpl = () => { // replace the original functions with asyncified versions const wrapAsyncAPIs = (funcNames) => { for (const funcName of funcNames) { - Module[funcName] = wrapAsync( - Module[funcName], - () => Module[funcName], - (v) => (Module[funcName] = v) - ); + Module[funcName] = wrapAsync(Module[funcName]); } }; diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index c42f8e3219da4..82118148d35f9 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -342,7 +342,7 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for WebAssembly (WASM) platform builds.""" parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.") parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.") - parser.add_argument("--emsdk_version", default="4.0.8", help="Specify version of emsdk.") + parser.add_argument("--emsdk_version", default="4.0.11", help="Specify version of emsdk.") parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index ef0f4c6e0883c..e08de4be17574 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -88,15 +88,15 @@ jobs: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.8 ccache-git-emscripten-64bit - ./emsdk activate 4.0.8 ccache-git-emscripten-64bit + ./emsdk install 4.0.11 ccache-git-emscripten-64bit + ./emsdk activate 4.0.11 ccache-git-emscripten-64bit displayName: 'emsdk install and activate ccache for emscripten' - ${{if eq(parameters.WithCache, false)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.8 - ./emsdk activate 4.0.8 + ./emsdk install 4.0.11 + ./emsdk activate 4.0.11 displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml From 7c0c29d6de1957360a8efbaa766ff3961a49be98 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 26 Jul 2025 16:39:49 -0700 Subject: [PATCH 24/33] [build] Fix the file copy in get_docker_image.py (#25548) ### Description Fixes the packaging pipeline. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tools/ci_build/get_docker_image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py index e656cedae5916..90947e534918d 100755 --- a/tools/ci_build/get_docker_image.py +++ b/tools/ci_build/get_docker_image.py @@ -71,11 +71,14 @@ def main(): log.info(f"Image: {full_image_name}") - dst_deps_file = Path(args.context) / "scripts" / "deps.txt" + dst_scripts_dir = Path(args.context) / "scripts" + dst_deps_file = dst_scripts_dir / "deps.txt" # The docker file may provide a special deps.txt in its docker context dir and uses that one. # Otherwise, copy a generic one from this repo's cmake dir. if not dst_deps_file.exists(): log.info(f"Copy deps.txt to : {dst_deps_file}") + if not dst_scripts_dir.exists(): + dst_scripts_dir.mkdir(parents=True, exist_ok=True) shutil.copyfile(Path(REPO_DIR) / "cmake" / "deps.txt", str(dst_deps_file)) if "manylinux" in args.dockerfile and args.multiple_repos: From 1b584c192cc7c016c8af652f7d525c221f0b019a Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sun, 27 Jul 2025 08:57:42 +0800 Subject: [PATCH 25/33] [webgpu] Enable per-run control for graph capture (#25367) This PR uses the existed RunOption `gpu_graph_id` to control whether to skip the graph capture. When the webgpu ep option `enableGraphCapture` is enabled, in RunOption, gpu_graph_id = -1 means skipping graph capture. Otherwise, go to the graph capture path for each session.run. If gpu_graph_id is not specified in RunOption, it will respect `enableGraphCapture `'s value to see whether to go to graph capture path. --- .../webgpu/webgpu_execution_provider.cc | 36 +++++++++++++------ .../webgpu/webgpu_execution_provider.h | 1 + 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6e09f494f4a8d..bca41b7851c28 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -18,8 +18,11 @@ #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" #include "core/framework/kernel_registry.h" +#include "core/framework/run_options.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/common/parse_string.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" @@ -692,7 +695,6 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -907,7 +909,7 @@ Status WebGpuExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); } -Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { if (context_.ValidationMode() >= ValidationMode::Basic) { context_.PushErrorScope(); } @@ -916,20 +918,32 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_ context_.StartProfiling(); } - if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + if (IsGraphCaptureEnabled()) { + auto graph_annotation_str = run_options.config_options.GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + int graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(onnxruntime::TryParseStringWithClassicLocale(*graph_annotation_str, graph_annotation_id), + "Failed to parse the graph annotation id: ", + *graph_annotation_str); + } + + if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { + context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + } + m_current_graph_annotation_id = graph_annotation_id; } return Status::OK(); } -Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) { context_.Flush(BufferManager()); - if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) { + if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) { context_.CaptureEnd(); is_graph_captured_ = true; + ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -952,12 +966,12 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { return enable_graph_capture_; } -bool WebGpuExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; +bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return is_graph_captured_ && graph_annotation_id != -1; } -Status WebGpuExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); +Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); context_.Replay(captured_commands_, *graph_buffer_mgr_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 2567be2a1eb18..3bbec164a0190 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -99,6 +99,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + int m_current_graph_annotation_id = 0; webgpu::GpuBufferAllocator* allocator_ = nullptr; // Buffer manager specifically for graph capture mode From 51d3198c4cd3e8dcabb3990df60033031128dca2 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 28 Jul 2025 16:27:06 +1000 Subject: [PATCH 26/33] Refactor plugin EP support (#25541) ### Description Refactor to split out classes and make things easier to find. ### Motivation and Context Cleanup --- cmake/onnxruntime_session.cmake | 4 +- .../onnxruntime/core/session/environment.h | 2 +- onnxruntime/core/session/environment.cc | 8 +- .../core/session/ep_library_internal.cc | 281 ------------------ .../session/ep_library_provider_bridge.cc | 140 --------- onnxruntime/core/session/onnxruntime_c_api.cc | 4 +- .../core/session/{ => plugin_ep}/ep_api.cc | 2 +- .../core/session/{ => plugin_ep}/ep_api.h | 0 .../core/session/plugin_ep/ep_factory_cpu.cc | 56 ++++ .../core/session/plugin_ep/ep_factory_cpu.h | 31 ++ .../core/session/plugin_ep/ep_factory_dml.cc | 113 +++++++ .../core/session/plugin_ep/ep_factory_dml.h | 40 +++ .../{ => plugin_ep}/ep_factory_internal.cc | 40 +-- .../{ => plugin_ep}/ep_factory_internal.h | 101 ++----- .../plugin_ep/ep_factory_internal_impl.cc | 32 ++ .../plugin_ep/ep_factory_internal_impl.h | 86 ++++++ .../plugin_ep/ep_factory_provider_bridge.cc | 44 +++ .../plugin_ep/ep_factory_provider_bridge.h | 66 ++++ .../session/plugin_ep/ep_factory_webgpu.cc | 76 +++++ .../session/plugin_ep/ep_factory_webgpu.h | 35 +++ .../core/session/{ => plugin_ep}/ep_library.h | 0 .../session/plugin_ep/ep_library_internal.cc | 52 ++++ .../{ => plugin_ep}/ep_library_internal.h | 4 +- .../{ => plugin_ep}/ep_library_plugin.cc | 2 +- .../{ => plugin_ep}/ep_library_plugin.h | 2 +- .../plugin_ep/ep_library_provider_bridge.cc | 58 ++++ .../ep_library_provider_bridge.h | 4 +- .../ep_plugin_provider_interfaces.cc | 2 +- .../ep_plugin_provider_interfaces.h | 0 .../forward_to_factory_impl.h} | 2 +- .../core/session/provider_policy_context.cc | 4 +- onnxruntime/core/session/utils.cc | 8 +- .../python/onnxruntime_pybind_state.cc | 2 +- .../test/framework/ep_plugin_provider_test.cc | 2 +- 34 files changed, 736 insertions(+), 567 deletions(-) delete mode 100644 onnxruntime/core/session/ep_library_internal.cc delete mode 100644 onnxruntime/core/session/ep_library_provider_bridge.cc rename onnxruntime/core/session/{ => plugin_ep}/ep_api.cc (99%) rename onnxruntime/core/session/{ => plugin_ep}/ep_api.h (100%) create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_cpu.h create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_dml.cc create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_dml.h rename onnxruntime/core/session/{ => plugin_ep}/ep_factory_internal.cc (58%) rename onnxruntime/core/session/{ => plugin_ep}/ep_factory_internal.h (50%) create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc create mode 100644 onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h rename onnxruntime/core/session/{ => plugin_ep}/ep_library.h (100%) create mode 100644 onnxruntime/core/session/plugin_ep/ep_library_internal.cc rename onnxruntime/core/session/{ => plugin_ep}/ep_library_internal.h (94%) rename onnxruntime/core/session/{ => plugin_ep}/ep_library_plugin.cc (98%) rename onnxruntime/core/session/{ => plugin_ep}/ep_library_plugin.h (96%) create mode 100644 onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc rename onnxruntime/core/session/{ => plugin_ep}/ep_library_provider_bridge.h (95%) rename onnxruntime/core/session/{ => plugin_ep}/ep_plugin_provider_interfaces.cc (99%) rename onnxruntime/core/session/{ => plugin_ep}/ep_plugin_provider_interfaces.h (100%) rename onnxruntime/core/session/{ep_api_utils.h => plugin_ep/forward_to_factory_impl.h} (99%) diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3ec3c6ee1d5ae..f81a7a9726b76 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -5,6 +5,8 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_INCLUDE_DIR}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.cc" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.h" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.cc" ) if (onnxruntime_ENABLE_TRAINING_APIS) @@ -22,7 +24,7 @@ endif() # which is not enabled for any minimal builds. if (onnxruntime_MINIMAL_BUILD) file(GLOB autoep_srcs - "${ONNXRUNTIME_ROOT}/core/session/ep_*.*" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.*" ) set(onnxruntime_session_src_exclude diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 7e49275e59b8b..306f81df38e48 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -20,7 +20,7 @@ #include "core/platform/threadpool.h" #include "core/session/abi_devices.h" -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" #include "core/session/onnxruntime_c_api.h" struct OrtThreadingOptions; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 450a8bad09392..2b553aecbca6c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,10 +16,10 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_library_internal.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/utils.h" diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc deleted file mode 100644 index 986ccb1fa17fc..0000000000000 --- a/onnxruntime/core/session/ep_library_internal.cc +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library_internal.h" - -#include "core/framework/error_code_helper.h" -#include "core/framework/ortmemoryinfo.h" -#include "core/framework/session_options.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_logger.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api.h" -#include "core/session/ort_apis.h" - -#if defined(USE_DML) -#include "core/providers/dml/dml_provider_factory_creator.h" -#endif - -#if defined(USE_WEBGPU) -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" -#endif - -namespace onnxruntime { - -class CpuEpFactory : public EpFactoryInternalImpl { - public: - CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "CPU EP factory currently only supports one device at a time."); - } - - CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; - *ep = std::make_unique(epi); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } -}; - -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - auto cpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} - -#if defined(USE_DML) -class DmlEpFactory : public EpFactoryInternalImpl { - public: - DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - std::unique_ptr ep_options; - - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is - // associated with a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? - int32_t device_id = 0; // If no device_id was found default to 0 - if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { - ep_options = std::make_unique(); - device_id = std::stoi(it->second); - } - - ep_options->Add("device_id", std::to_string(device_id)); - - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices]); - - if (device_memory_infos.size() < device_id + 1) { - device_memory_infos.resize(device_id + 1); - device_allocators.resize(device_id + 1); - } - - if (device_memory_infos[device_id] == nullptr) { - // Create memory info for the device if it doesn't already exist - device_memory_infos[device_id] = std::make_unique( - "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, - narrow(device_id))); - } - - // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. - // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], - // device_memory_infos[device_id].get()); - - if (api_status != nullptr) { - return api_status; - } - - ++num_ep_devices; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "DML EP factory currently only supports one device at a time."); - } - - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, - ep_options); - - *ep = dml_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, - const OrtKeyValuePairs* /*allocator_options*/, - OrtAllocator** allocator) noexcept override { - // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That - // requires pulling lots of things out of the DML EP to get the D3D12 device and create a - // BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp - //*allocator = device_allocators[memory_info->device.Id()].get(); - *allocator = nullptr; - return nullptr; - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. - *data_transfer = nullptr; - return nullptr; - } - - std::vector> device_memory_infos; // memory info for each device - std::vector> device_allocators; // allocators for each device -}; - -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - auto dml_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(dml_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -#if defined(USE_WEBGPU) -class WebGpuEpFactory : public EpFactoryInternalImpl { - public: - WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: any metadata or options to add? - ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "WebGPU EP factory currently only supports one device at a time."); - } - - auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); - *ep = webgpu_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of - an InferenceSession. - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - *allocator = device_allocators[memory_info->device.Id()].get(); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. - *data_transfer = nullptr; - return nullptr; - } - */ -}; - -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - auto webgpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -std::vector> EpLibraryInternal::CreateInternalEps() { - std::vector> internal_eps; - internal_eps.reserve(4); - - // CPU EP - internal_eps.push_back(CreateCpuEp()); - -#if defined(USE_WEBGPU) - internal_eps.push_back(CreateWebGpuEp()); -#endif - -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - - return internal_eps; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc deleted file mode 100644 index ae553891beaa7..0000000000000 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library_provider_bridge.h" - -#include "core/common/status.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/session_options.h" -#include "core/providers/cuda/cuda_provider_options.h" -#include "core/providers/shared_library/provider_host_api.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_factory_internal.h" - -namespace onnxruntime { -class ProviderBridgeEpFactory : public EpFactoryInternalImpl { - public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) - : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), - ep_factory.GetVendor(&ep_factory), - ep_factory.GetVendorId(&ep_factory)), - ep_factory_{ep_factory}, - provider_library_{provider_library} { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept override { - ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, - max_ep_devices, num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. - for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = &ep_factory; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - // get the provider specific options - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto& provider = provider_library_.Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *session_logger, *ep); - - return ToOrtStatus(status); - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); - } - - void ReleaseAllocator(OrtAllocator* allocator) noexcept override { - ep_factory_.ReleaseAllocator(&ep_factory_, allocator); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); - } - - bool IsStreamAware() const noexcept override { - return ep_factory_.IsStreamAware(&ep_factory_); - } - - OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, - const OrtKeyValuePairs* stream_options, - OrtSyncStreamImpl** stream) noexcept override { - return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); - } - - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP -}; - -Status EpLibraryProviderBridge::Load() { - std::lock_guard lock{mutex_}; - - if (!factories_.empty()) { - // already loaded - return Status::OK(); - } - - // if we have been unloaded we can't just be reloaded. - if (!ep_library_plugin_ || !provider_library_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EpLibraryProviderBridge has been unloaded. " - "Please create a new instance using LoadPluginOrProviderBridge."); - } - - // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. - // use GetSupportedDevices from the library's factory. - // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. - // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can - // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. - for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); - auto internal_factory = std::make_unique(std::move(factory_impl)); - - factory_ptrs_.push_back(internal_factory.get()); - internal_factory_ptrs_.push_back(internal_factory.get()); - factories_.push_back(std::move(internal_factory)); - } - - return Status::OK(); -} - -Status EpLibraryProviderBridge::Unload() { - std::lock_guard lock{mutex_}; - - internal_factory_ptrs_.clear(); - factory_ptrs_.clear(); - factories_.clear(); - - // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. - ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); - ep_library_plugin_ = nullptr; - - provider_library_->Unload(); - provider_library_ = nullptr; - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 27f81b18be0c9..37f4fe7312bb4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,8 +38,8 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" -#include "core/session/ep_api.h" -#include "core/session/ep_library_internal.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/IOBinding.h" diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc similarity index 99% rename from onnxruntime/core/session/ep_api.cc rename to onnxruntime/core/session/plugin_ep/ep_api.cc index 8fd1fc198374f..cae0b086af66c 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_api.h" +#include "core/session/plugin_ep/ep_api.h" #include #include diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h similarity index 100% rename from onnxruntime/core/session/ep_api.h rename to onnxruntime/core/session/plugin_ep/ep_api.h diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc new file mode 100644 index 0000000000000..7e6d0dd2ae5df --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_cpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* CpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* CpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CPU EP factory currently only supports one device at a time."); + } + + CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; + *ep = std::make_unique(epi); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h new file mode 100644 index 0000000000000..fba9bac976bb2 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc new file mode 100644 index 0000000000000..2f12ffa394537 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_dml.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/dml/dml_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* DmlEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + auto ep_options = std::make_unique(); + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? + int32_t device_id = 0; // If no device_id was found default to 0 + if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { + device_id = std::stoi(it->second); + } + + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. + // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); + + if (api_status != nullptr) { + return api_status; + } + + ++num_ep_devices; + } + } + + return nullptr; +} + +OrtStatus* DmlEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "DML EP factory currently only supports one device at a time."); + } + + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, + ep_options); + + *ep = dml_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* +// TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That +// requires pulling lots of things out of the DML EP to get the D3D12 device and create a +// BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp +OrtStatus* DmlEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept { +} + +// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. +OrtStatus* DmlEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { +} +*/ +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.h b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h new file mode 100644 index 0000000000000..1cdd172901942 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; + +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc similarity index 58% rename from onnxruntime/core/session/ep_factory_internal.cc rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 9804aa6a5c42d..3610b0f797a46 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -1,18 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api_utils.h" +#include "core/session/plugin_ep/forward_to_factory_impl.h" #include "core/session/ort_apis.h" -#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { - -using Forward = ForwardToFactory; +using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) : impl_{std::move(impl)} { @@ -32,38 +30,6 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; } -const char* EpFactoryInternal::GetVersion() const noexcept { - return ORT_VERSION; -} - -OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t /*num_devices*/, - const OrtSessionOptions* /*api_session_options*/, - const OrtLogger* /*api_logger*/, - OrtEp** /*ep*/) { - ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); -} - -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. -ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} - InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, gsl::span ep_devices) : ep_factory_{ep_factory} { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h similarity index 50% rename from onnxruntime/core/session/ep_factory_internal.h rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae450efa394e8..0e34fef0ff74c 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -7,85 +7,16 @@ #include #include "core/common/common.h" -#include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { -class EpFactoryInternal; -class EpLibraryInternal; struct SessionOptions; - -// class with virtual methods that are implemented for each internal EP -class EpFactoryInternalImpl { - public: - EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) - : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { - } - - const char* GetName() const noexcept { return ep_name_.c_str(); } - const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } - const char* GetVersion() const noexcept; - - virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept = 0; - - virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ std::unique_ptr* ep) = 0; - - virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, - _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, - _Outptr_ OrtAllocator** allocator) noexcept { - // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned - // so this should never be called - *allocator = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); - } - - virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { - // we don't create any allocators so we don't need to release any - } - - virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { - *data_transfer = nullptr; - return nullptr; // Default implementation does nothing - } - - virtual bool IsStreamAware() const { - return false; - } - - virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, - _In_opt_ const OrtKeyValuePairs* /*stream_options*/, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { - *stream = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, - "CreateSyncStreamForDevice is not implemented for this EP factory."); - } - - // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* ep); - - virtual ~EpFactoryInternalImpl() = default; - - protected: - ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; - - private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID -}; +class EpFactoryInternalImpl; // this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. class EpFactoryInternal : public OrtEpFactory { @@ -95,7 +26,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return impl_->GetName(); } const char* GetVendor() const noexcept { return impl_->GetVendor(); } uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } - const char* GetVersion() const noexcept; + const char* GetVersion() const noexcept { return ORT_VERSION; } OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, @@ -106,11 +37,14 @@ class EpFactoryInternal : public OrtEpFactory { } // we don't implement this. CreateIExecutionProvider should be used. - OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ OrtEp** ep); + OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); + } // same input args as CreateEp in case we need something from device or ep_metadata_pairs in the future. OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -132,24 +66,23 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ReleaseAllocator(allocator); } - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { return impl_->CreateDataTransfer(data_transfer); } - bool IsStreamAware() const { + bool IsStreamAware() const noexcept { return impl_->IsStreamAware(); } OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* /*ep*/) { + void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); } private: diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc new file mode 100644 index 0000000000000..e61804d842859 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" + +namespace onnxruntime { + +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. +ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } + } + + return ep_options; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h new file mode 100644 index 0000000000000..bd0b76b21511f --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +class EpFactoryInternal; +struct SessionOptions; + +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { + public: + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } + + const char* GetName() const noexcept { return ep_name_.c_str(); } + const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } + const char* GetVersion() const noexcept; + + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const noexcept { + return false; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc new file mode 100644 index 0000000000000..d6e51a44c1c69 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +#include "core/providers/shared_library/provider_host_api.h" + +namespace onnxruntime { +OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + } + } + + return nullptr; +} + +OrtStatus* ProviderBridgeEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h new file mode 100644 index 0000000000000..437af62dc2c0c --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/provider_bridge_library.h" + +namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc new file mode 100644 index 0000000000000..0f955e0bab248 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* WebGpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "WebGPU EP factory currently only supports one device at a time."); + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); + *ep = webgpu_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. +OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); +} + +OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; +} +*/ +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h new file mode 100644 index 0000000000000..06ecfa744bbda --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h similarity index 100% rename from onnxruntime/core/session/ep_library.h rename to onnxruntime/core/session/plugin_ep/ep_library.h diff --git a/onnxruntime/core/session/plugin_ep/ep_library_internal.cc b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc new file mode 100644 index 0000000000000..d4015e0bbd366 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_factory_cpu.h" +#include "core/session/plugin_ep/ep_factory_dml.h" +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +namespace onnxruntime { + +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} + +#if defined(USE_DML) + +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +#if defined(USE_WEBGPU) +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +std::vector> EpLibraryInternal::CreateInternalEps() { + std::vector> internal_eps; + internal_eps.reserve(4); + + // CPU EP + internal_eps.push_back(CreateCpuEp()); + +#if defined(USE_WEBGPU) + internal_eps.push_back(CreateWebGpuEp()); +#endif + +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + + return internal_eps; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.h b/onnxruntime/core/session/plugin_ep/ep_library_internal.h similarity index 94% rename from onnxruntime/core/session/ep_library_internal.h rename to onnxruntime/core/session/plugin_ep/ep_library_internal.h index ab529edc2507f..1587f01360e26 100644 --- a/onnxruntime/core/session/ep_library_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.h @@ -4,8 +4,8 @@ #pragma once #include "core/common/common.h" -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/provider_bridge_library.h" diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc similarity index 98% rename from onnxruntime/core/session/ep_library_plugin.cc rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.cc index 32ddd8a765b4c..ebfa364f4f1df 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_plugin.h" #include "core/common/logging/logging.h" #include "core/framework/error_code_helper.h" diff --git a/onnxruntime/core/session/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h similarity index 96% rename from onnxruntime/core/session/ep_library_plugin.h rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.h index e2b02ccc654da..e044e91b61e37 100644 --- a/onnxruntime/core/session/ep_library_plugin.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" namespace onnxruntime { /// diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc new file mode 100644 index 0000000000000..06cf54aea4071 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_library_provider_bridge.h" + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +namespace onnxruntime { +Status EpLibraryProviderBridge::Load() { + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can + // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { + auto factory_impl = std::make_unique(*factory, *provider_library_); + auto internal_factory = std::make_unique(std::move(factory_impl)); + + factory_ptrs_.push_back(internal_factory.get()); + internal_factory_ptrs_.push_back(internal_factory.get()); + factories_.push_back(std::move(internal_factory)); + } + + return Status::OK(); +} + +Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. + ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + + provider_library_->Unload(); + provider_library_ = nullptr; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h similarity index 95% rename from onnxruntime/core/session/ep_library_provider_bridge.h rename to onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 0717ccd957de7..c7e8ebefc3785 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -5,8 +5,8 @@ #include #include -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_bridge_library.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc similarity index 99% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.cc rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c7d7ea2e8a4ec..2aac1e1c21cc7 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include #include diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h similarity index 100% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.h rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h similarity index 99% rename from onnxruntime/core/session/ep_api_utils.h rename to onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 77528565eced7..67b22779395ec 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -7,7 +7,7 @@ namespace onnxruntime { // helper to forward a call from the C API to an instance of the factory implementation. // used by EpFactoryInternal and EpFactoryProviderBridge. template -struct ForwardToFactory { +struct ForwardToFactoryImpl { static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetName(); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..6bcbda0f13b92 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -11,8 +11,8 @@ #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 69039beb49363..f90ace95d6e58 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -19,10 +19,10 @@ #include "core/session/ort_env.h" #if !defined(ORT_MINIMAL_BUILD) -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/model_compilation_options.h" #include "core/session/provider_policy_context.h" #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ec4d8c6330c8d..acf0681cf8752 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -46,7 +46,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include "core/session/abi_devices.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_policy_context.h" #include "core/session/utils.h" #endif diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 4c5dcd2bd7580..35f7d06fb0912 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "gsl/gsl" #include "gtest/gtest.h" From 2e0f717b12bb57de4de78861a324815b0559c24c Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 28 Jul 2025 09:17:14 -0700 Subject: [PATCH 27/33] Remove the python installation steps from win-qnn-arm64-ci-pipeline.yml (#25552) ### Description Yesterday I updated the machine images. Now they already have python preinstalled. We don't need to do this anymore. Remove the steps to avoid conflicts. Also, refactor the yaml file a little bit. Refactors templates to use parameterized Python versions instead of matrix strategy. --- .../stages/py-cpu-packaging-stage.yml | 17 +++++++- .../templates/py-win-arm64-qnn.yml | 39 +++++-------------- .../templates/py-win-arm64ec-qnn.yml | 5 --- .../win-qnn-arm64-ci-pipeline.yml | 10 ----- 4 files changed, 24 insertions(+), 47 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index c1b83c5e579dc..f4a62208059c8 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -316,7 +316,21 @@ stages: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true + PYTHON_VERSION: '3.11' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.12' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.13' - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - stage: Python_Packaging_Windows_arm64ec_QNN @@ -327,7 +341,6 @@ stages: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - stage: Python_Packaging_Windows_x64_QNN diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 761c551e9f4d9..3c2ef4741f049 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -4,6 +4,10 @@ parameters: type: string default: 'onnxruntime-qnn-windows-vs-2022-arm64' +- name: PYTHON_VERSION + type: string + default: '3.11' + - name: QNN_SDK displayName: QNN SDK Version type: string @@ -19,13 +23,8 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: -- job: Win_py_arm64_qnn_Wheels +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 210 workspace: clean: all @@ -48,41 +47,21 @@ jobs: outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) - - strategy: - matrix: - Python311_arm64: - PythonVersion: '3.11.0' - LocalPythonDir: 'C:\Python\Python311' - Python312_arm64: - PythonVersion: '3.12.6' - LocalPythonDir: 'C:\Python\Python312' - Python313_arm64: - PythonVersion: '3.13.2' - LocalPythonDir: 'C:\Python\Python313' + artifactName: onnxruntime_qnn_arm64_${{ parameters.PYTHON_VERSION }} + variables: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - checkout: self clean: true - submodules: recursive + submodules: none - template: telemetry-steps.yml - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - task: UsePythonVersion@0 inputs: - versionSpec: $(PythonVersion) + versionSpec: ${{ parameters.PYTHON_VERSION }} addToPath: true architecture: 'arm64' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 74cae38393ea6..c8d37457a1034 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 7ebf5394e4530..66d1cd1687d99 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -61,16 +61,6 @@ jobs: # because the python bindings also use the USE__PROVIDER_INTERFACE preprocessor macros. ExtraQnnBuildArgs: '--enable_generic_interface --build_wheel' steps: - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\3.11.0 - DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - displayName: Copy python 3.11.0 version to agent tools directory - - task: UsePythonVersion@0 inputs: versionSpec: '3.x' From 413d38d0b0a72ca87fbcc3b6d7a26bef11910c4b Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Mon, 28 Jul 2025 12:59:36 -0700 Subject: [PATCH 28/33] [QNN EP] Support more Einsum equation: bhwc,hkc->bhwk (#25518) Additional equation support for QNN EP on einsum op. --- .../builder/opbuilder/einsum_op_builder.cc | 61 ++++++++++++++++--- .../test/providers/qnn/einsum_op_test.cc | 52 ++++++++++++++++ 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index 9db0b5202dcd4..7e17addf2f577 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -45,13 +45,7 @@ std::optional ParseEquation(std::string_view equation_string) { if (term_1.empty() || term_2.empty()) { return std::nullopt; } - if (term_1.size() < 2) { - return std::nullopt; - } - if (term_1.size() != term_2.size()) { - return std::nullopt; - } - if (term_1.size() != result.size()) { + if (term_1.size() < 2 || term_2.size() < 2 || result.size() < 2) { return std::nullopt; } if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) { @@ -154,6 +148,50 @@ bool IsEquationMatMulTransposeAll(const Equation& equation) { return true; } +bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) { + // E.g., bhwc,hkc->bhwk + const auto& [term_1, term_2, result] = equation; + const size_t term1_dims = term_1.size(); + if (term1_dims != 4) { + return false; + } + const size_t term2_dims = term_2.size(); + if (term2_dims != 3) { + return false; + } + const size_t result_dims = result.size(); + if (result_dims != 4) { + return false; + } + // Check matrix multiplication dimensions + char term_1_m = term_1[term1_dims - 2]; + char term_1_k = term_1[term1_dims - 1]; + char term_2_k = term_2[term2_dims - 1]; + char term_2_n = term_2[term2_dims - 2]; + char result_m = result[result_dims - 2]; + char result_n = result[result_dims - 1]; + if (term_1_m != result_m) { + return false; + } + if (term_1_k != term_2_k) { + return false; + } + if (term_2_n != result_n) { + return false; + } + // Check batch dimensions + if (term_1[0] != result[0]) { + return false; + } + if (term_1[1] != result[1]) { + return false; + } + if (term_2[0] != result[1]) { + return false; + } + return true; +} + /** * @brief Sets the parameter tensor names for a MatMul op. * @@ -317,6 +355,7 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, } if (!IsEquationMatMul(parsed_equation.value()) && !IsEquationMatMulTransposeY(parsed_equation.value()) && + !IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) && !IsEquationMatMulTransposeAll(parsed_equation.value())) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } @@ -353,7 +392,8 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*logger=*/logger, /*do_op_validation=*/do_op_validation, /*qnn_op_type=*/QNN_OP_MAT_MUL)); - } else if (IsEquationMatMulTransposeY(parsed_equation.value())) { + } else if (IsEquationMatMulTransposeY(parsed_equation.value()) || + IsEquationMatMulBroadcastTransposeY(parsed_equation.value())) { std::vector param_tensor_names = SetMatMulParamTensorNames( &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/true); ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper, @@ -364,7 +404,10 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*do_op_validation=*/do_op_validation, /*qnn_op_type=*/QNN_OP_MAT_MUL)); } else if (IsEquationMatMulTransposeAll(parsed_equation.value())) { - ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(&qnn_model_wrapper, node_unit, std::move(input_names), do_op_validation)); + ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(/*qnn_model_wrapper=*/&qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*do_op_validation=*/do_op_validation)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index a2a0ce485bb35..d8dbbd799a427 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { const std::vector shape0{1, 7, 1, 7}; const std::vector shape1{1, 9, 1, 7}; @@ -273,6 +286,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll2) { /*tolerance=*/1e-2f); } +TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-2f); +} + // // QNN HTP QDQ // @@ -337,6 +363,18 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll2) { /*tolerance=*/QDQTolerance()); } +TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/QDQTolerance()); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #if defined(_M_ARM64) @@ -422,6 +460,20 @@ TEST_F(QnnGPUBackendTests, EinsumRank4MatMulTransposeAll2) { /*tolerance=*/1e-4f); } +// Numeric instability in GPU backend, see also MatMul tests. +TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-4f); +} + #endif // defined(_M_ARM64) GPU tests } // namespace test From f80697a79e2313f46b4cf580cbc40e72dc882fb4 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 28 Jul 2025 14:07:33 -0700 Subject: [PATCH 29/33] Cherry-pick round 1 (#25563) - **DynamicQuantizeMatMul - handle case where B zero point input is provided but not constant. (#25544)** - **Refactor plugin EP support (#25541)** - **Remove the python installation steps from win-qnn-arm64-ci-pipeline.yml (#25552)** From bac8af3631f8fc3830a77c7da6c2e17fc6d5c144 Mon Sep 17 00:00:00 2001 From: Fanchen Kong Date: Tue, 29 Jul 2025 05:14:59 +0800 Subject: [PATCH 30/33] Upgrade xnnpack to latest (#25275) ### Description This change is based on #25135. Upgrade xnnpack and several related third-party dependencies, including pthreadpool, cpuinfo, and kleidiai. This change also updates the xnnpack execution provider code to accommodate changes in the xnnpack api. Average pooling qu8 is removed as the corresponding microkernel seems no longer exist in xnnpack. --- cmake/deps.txt | 8 +-- .../external/onnxruntime_external_deps.cmake | 2 +- cmake/external/xnnpack.cmake | 2 +- .../xnnpack/AddEmscriptenAndIosSupport.patch | 31 +++++++---- cmake/vcpkg-ports/cpuinfo/portfile.cmake | 4 +- .../pthreadpool/fix-cmakelists.patch | 15 +++--- cmake/vcpkg-ports/pthreadpool/portfile.cmake | 4 +- cmake/vcpkg-ports/xnnpack/fix-build.patch | 28 +++++----- cmake/vcpkg-ports/xnnpack/portfile.cmake | 4 +- cmake/vcpkg-ports/xnnpack/vcpkg.json | 2 +- .../core/providers/xnnpack/math/gemm.cc | 5 +- .../core/providers/xnnpack/math/matmul.cc | 4 -- .../core/providers/xnnpack/nn/average_pool.cc | 51 +++---------------- onnxruntime/core/providers/xnnpack/nn/conv.cc | 3 +- .../core/providers/xnnpack/nn/conv_base.cc | 13 +++-- .../core/providers/xnnpack/tensor/resize.cc | 36 +++++-------- .../core/providers/xnnpack/xnnpack_kernel.h | 7 --- 17 files changed, 84 insertions(+), 135 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 01e5c809640f9..ed1de06f33dcb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -27,8 +27,8 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28 -#xnnpack 2024.09.04 -googlexnnpack;https://github.com/google/XNNPACK/archive/fe98e0b93565382648129271381c14d6205255e3.zip;14f61dcf17cec2cde34ba2dcf61d6f24bf6059f3 +#xnnpack 2025.06.22 +googlexnnpack;https://github.com/google/XNNPACK/archive/3cf85e705098622d59056dcb8f5f963ea7bb0a00.zip;6f6bbba627241f89463ca845febaf063982b34fe json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 @@ -45,9 +45,9 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617 protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 -pthreadpool;https://github.com/google/pthreadpool/archive/4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0.zip;bd4ea65c8292801e9555b527a0ecbb2e0092c917 +pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9 -pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6.zip;85bf8a60dae026b99b6ccd78606c85ed83bfb2cd +pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/de0ce7c7251372892e53ce9bc891750d2c9a4fd8.zip;c45b8d3619b9bccbd26dc5f657959aee38b18b7a re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index f76ad642447ba..0d1f47f195ba5 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -570,7 +570,7 @@ if (onnxruntime_USE_XNNPACK) ENDIF() ADD_LIBRARY(xnnpack STATIC IMPORTED) find_library(xnnpack_LIBRARY NAMES XNNPACK) - find_library(microkernels_prod_LIBRARY NAMES microkernels-prod) + find_library(microkernels_prod_LIBRARY NAMES xnnpack-microkernels-prod) find_package(unofficial-pthreadpool CONFIG REQUIRED) target_include_directories(xnnpack INTERFACE "${XNNPACK_HDR}") diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index d0ab770053be1..c994e7e15aac4 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -90,7 +90,7 @@ onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) -set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK microkernels-prod pthreadpool) +set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK xnnpack-microkernels-prod pthreadpool) if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai) endif() diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index c9cb4bcad9e20..ea0bb61274f84 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f0b3410ae..1e3cb8178 100644 +index 94bcad92e3..be7dfe95fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -337,7 +337,7 @@ ENDIF() +@@ -360,7 +360,7 @@ ENDIF() # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") @@ -11,21 +11,30 @@ index f0b3410ae..1e3cb8178 100644 MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() IF(CMAKE_SYSTEM_NAME MATCHES "Windows") -@@ -848,7 +848,12 @@ IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging) - TARGET_LINK_LIBRARIES(reference-ukernels PRIVATE xnnpack-base) - TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run datatype) -- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph datatype reference-ukernels) +@@ -903,10 +903,18 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(xnnpack-operator-utils PRIVATE xnnpack-base xnnpack-logging) + TARGET_LINK_LIBRARIES(xnnpack-reference-ukernels PRIVATE xnnpack-base xnnpack-datatype) + TARGET_LINK_LIBRARIES(xnnpack-subgraph PRIVATE xnnpack-base xnnpack-allocator xnnpack-logging xnnpack-memory xnnpack-mutex xnnpack-operators xnnpack-operator-run xnnpack-datatype) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache +- xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init +- xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing +- xnnpack-microkernels-prod xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing subgraph datatype reference-ukernels) ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache ++ xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init ++ xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing ++ xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + ELSE() -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph datatype reference-ukernels) ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache ++ xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init ++ xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing ++ xnnpack-microkernels-prod xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + ENDIF() - TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool logging) + TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool xnnpack-logging) SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() -@@ -857,7 +862,8 @@ IF(NOT MSVC) +@@ -915,7 +923,8 @@ IF(NOT MSVC) ENDIF() IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index e61308bf643b4..6722f10a72857 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -6,8 +6,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pytorch/cpuinfo - REF 8a1772a0c5c447df2d18edf33ec4603a8c9c04a6 - SHA512 b94ccbfa886221d6bb16513d074675af0a72928a9dd9485dcacdc1124a8a60aacbbe91913a1579e766dfb024f0be1d52eeead40342004ff0238a8b94a095ed08 + REF de0ce7c7251372892e53ce9bc891750d2c9a4fd8 + SHA512 0fde9210b700d2648d37c8deeb0d5c0d007d8ca5689578dd3bce4c460886b20d7649f0194d2ea06b02238fe9d4f06193599ec3ab5cafb19f1f860b00404264fa HEAD_REF master ) diff --git a/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch b/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch index 97fd1ac7a2bb1..cf7df0ea22980 100644 --- a/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch +++ b/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f06aada..3c6c6e2 100644 +index efff8cc..1a0f7e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -31,8 +31,6 @@ IF(CCACHE_BINARY) +@@ -41,8 +41,6 @@ IF(CMAKE_C_COMPILER_ID STREQUAL "MSVC") ENDIF() # ---[ Options. @@ -11,7 +11,7 @@ index f06aada..3c6c6e2 100644 OPTION(PTHREADPOOL_ALLOW_DEPRECATED_API "Enable deprecated API functions" ON) SET(PTHREADPOOL_SYNC_PRIMITIVE "default" CACHE STRING "Synchronization primitive (condvar, futex, gcd, event, or default) for worker threads") SET_PROPERTY(CACHE PTHREADPOOL_SYNC_PRIMITIVE PROPERTY STRINGS default condvar futex gcd event) -@@ -41,7 +39,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$") +@@ -51,7 +49,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$") ELSE() OPTION(PTHREADPOOL_ENABLE_FASTPATH "Enable fast path using atomic decrement instead of atomic compare-and-swap" OFF) ENDIF() @@ -20,8 +20,8 @@ index f06aada..3c6c6e2 100644 OPTION(PTHREADPOOL_BUILD_TESTS "Build pthreadpool unit tests" ON) OPTION(PTHREADPOOL_BUILD_BENCHMARKS "Build pthreadpool micro-benchmarks" ON) ELSE() -@@ -67,7 +65,8 @@ MACRO(PTHREADPOOL_TARGET_ENABLE_CXX11 target) - ENDMACRO() +@@ -71,7 +69,8 @@ IF(PTHREADPOOL_BUILD_TESTS) + ENDIF() # ---[ Download deps -IF(NOT DEFINED FXDIV_SOURCE_DIR) @@ -30,7 +30,7 @@ index f06aada..3c6c6e2 100644 MESSAGE(STATUS "Downloading FXdiv to ${CMAKE_BINARY_DIR}/FXdiv-source (define FXDIV_SOURCE_DIR to avoid it)") CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CMAKE_BINARY_DIR}/FXdiv-download/CMakeLists.txt") EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -@@ -118,21 +117,13 @@ ELSE() +@@ -122,21 +121,13 @@ ELSE() ENDIF() ADD_LIBRARY(pthreadpool_interface INTERFACE) @@ -54,7 +54,7 @@ index f06aada..3c6c6e2 100644 IF(PTHREADPOOL_SYNC_PRIMITIVE STREQUAL "condvar") TARGET_COMPILE_DEFINITIONS(pthreadpool PRIVATE PTHREADPOOL_USE_FUTEX=0) -@@ -181,18 +172,22 @@ IF(CMAKE_SYSTEM_NAME STREQUAL "Linux") +@@ -182,18 +173,22 @@ IF(CMAKE_SYSTEM_NAME STREQUAL "Linux") ENDIF() # ---[ Configure FXdiv @@ -80,3 +80,4 @@ index f06aada..3c6c6e2 100644 IF(PTHREADPOOL_BUILD_TESTS) # ---[ Build google test + diff --git a/cmake/vcpkg-ports/pthreadpool/portfile.cmake b/cmake/vcpkg-ports/pthreadpool/portfile.cmake index 9400e5e886639..449459feb33cc 100644 --- a/cmake/vcpkg-ports/pthreadpool/portfile.cmake +++ b/cmake/vcpkg-ports/pthreadpool/portfile.cmake @@ -5,8 +5,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO google/pthreadpool - REF 4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0 - SHA512 776017cc5d2aa94337292f2f4fbd54d099ef29abf736ab8147f07f98f12b7654cbd2fe38d34646a479a519c261ac253bbaf19c6dcbb0ec4cc0859de70f7e6472 + REF dcc9f28589066af0dbd4555579281230abbf74dd + SHA512 61853fa8f6c3297d8760be3af1df3f2a00583c1e0e58bdd03cd9cb915e8660a4f2817b22e6463cf53f10de902a1c6204ec6054fcbeada72eeee9e44baeb97178 PATCHES fix-cmakelists.patch ) diff --git a/cmake/vcpkg-ports/xnnpack/fix-build.patch b/cmake/vcpkg-ports/xnnpack/fix-build.patch index b867377d2ff9e..3da8825e2b57d 100644 --- a/cmake/vcpkg-ports/xnnpack/fix-build.patch +++ b/cmake/vcpkg-ports/xnnpack/fix-build.patch @@ -1,21 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f0b3410ae..ba54c3bfe 100644 +index 9f6fb5e256..4387298e59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -1047,9 +1047,11 @@ ENDIF() - IF(XNNPACK_BUILD_ALL_MICROKERNELS) - TARGET_INCLUDE_DIRECTORIES(microkernels-all PRIVATE include src) +@@ -1125,7 +1125,7 @@ ELSE() ENDIF() -+ - TARGET_INCLUDE_DIRECTORIES(datatype PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(microkernels-prod PRIVATE include src) --TARGET_INCLUDE_DIRECTORIES(hardware-config PRIVATE include src ${CPUINFO_SOURCE_DIR}/include) -+TARGET_INCLUDE_DIRECTORIES(hardware-config PRIVATE include src) -+ - TARGET_INCLUDE_DIRECTORIES(indirection PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(microparams-init PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(normalization PRIVATE include src) -@@ -1104,14 +1106,9 @@ IF(NOT TARGET cpuinfo) + + INCLUDE_DIRECTORIES(.) +-TARGET_INCLUDE_DIRECTORIES(xnnpack-hardware-config PRIVATE include src ${CPUINFO_SOURCE_DIR}/include) ++TARGET_INCLUDE_DIRECTORIES(xnnpack-hardware-config PRIVATE include src) + IF(XNNPACK_BUILD_LIBRARY) + TARGET_INCLUDE_DIRECTORIES(XNNPACK PUBLIC include) + IF(WIN32) +@@ -1164,14 +1164,9 @@ IF(NOT TARGET cpuinfo) "${CPUINFO_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/cpuinfo") ELSE() @@ -33,7 +29,7 @@ index f0b3410ae..ba54c3bfe 100644 ENDIF() ENDIF() IF(XNNPACK_BUILD_LIBRARY) -@@ -1129,16 +1126,12 @@ IF(NOT TARGET pthreadpool) +@@ -1189,16 +1184,12 @@ IF(NOT TARGET pthreadpool) "${PTHREADPOOL_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/pthreadpool") ELSE() @@ -53,7 +49,7 @@ index f0b3410ae..ba54c3bfe 100644 ENDIF() ENDIF() TARGET_LINK_LIBRARIES(xnnpack-base INTERFACE pthreadpool) -@@ -1152,12 +1145,12 @@ IF(NOT TARGET fxdiv) +@@ -1212,12 +1203,12 @@ IF(NOT TARGET fxdiv) "${FXDIV_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/FXdiv") ELSE() diff --git a/cmake/vcpkg-ports/xnnpack/portfile.cmake b/cmake/vcpkg-ports/xnnpack/portfile.cmake index d63ad0fbd0cce..60b3566629e10 100644 --- a/cmake/vcpkg-ports/xnnpack/portfile.cmake +++ b/cmake/vcpkg-ports/xnnpack/portfile.cmake @@ -5,8 +5,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO google/XNNPACK - REF 953dcb96cc1b21b4b966952f8ee67a9e1f0d3e71 - SHA512 8c12930ef3b2f832962682d73c362518c014bb4e56d0c5cad2b8b63a03c91dccf6e6a3fd0eb91931fc5872c7df9773e76bf08553fc9c3cc22c94636c74815e94 + REF 3cf85e705098622d59056dcb8f5f963ea7bb0a00 + SHA512 af10afde80def08dc3b20a35bd38e84f9f749865ecc4bc9733b5d99d8a2f0f30c19c3f23472d65462a907b3a58226e3b254354a92a6baa31031824f68012a055 HEAD_REF master PATCHES fix-build.patch diff --git a/cmake/vcpkg-ports/xnnpack/vcpkg.json b/cmake/vcpkg-ports/xnnpack/vcpkg.json index e0d0600902f36..643b5c4abe166 100644 --- a/cmake/vcpkg-ports/xnnpack/vcpkg.json +++ b/cmake/vcpkg-ports/xnnpack/vcpkg.json @@ -1,6 +1,6 @@ { "name": "xnnpack", - "version-date": "2025-01-23", + "version-date": "2025-06-22", "description": "High-efficiency floating-point neural network inference operators for mobile, server, and Web", "homepage": "https://github.com/google/XNNPACK", "license": "BSD-3-Clause", diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index a3ff3b585ae45..9b78e943122de 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -139,7 +139,6 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, // flags - 1 - for no transpose - 0 for transpose uint32_t flags = trans_B_ == CblasTrans ? 0 : XNN_FLAG_TRANSPOSE_WEIGHTS; - auto code_cache = GetCodeCache(); auto weights_cache = GetWeightsCache(); xnn_status status = xnn_status::xnn_status_uninitialized; struct xnn_operator* p = nullptr; @@ -159,7 +158,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, bias_data, // const float* bias, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) { const MLFloat16* bias_data = nullptr; @@ -175,7 +174,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, bias_data, // const float* bias, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 9083b9c22f64a..7870bcff298f2 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -102,10 +102,8 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } #ifdef XNN_CACHE_ENABLE - xnn_code_cache_t code_cache = GetCodeCache(); xnn_weights_cache_t weight_cache = GetWeightsCache(); #else - xnn_code_cache_t code_cache = nullptr; xnn_weights_cache_t weight_cache = nullptr; #endif @@ -122,7 +120,6 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, foutput_min, foutput_max, flags, - code_cache, weight_cache, &p); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { @@ -136,7 +133,6 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, foutput_min, foutput_max, flags, - code_cache, weight_cache, &p); } diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index f320274f65db3..963dfa5fa26d7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -17,7 +17,6 @@ namespace { Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, const std::optional>& clip_min_max, struct xnn_operator*& p, - const OpQuantParam& quant_param, OpComputeType avgpool_type) { uint32_t input_padding_top = narrow(pool_attrs.pads[0]); uint32_t input_padding_left = narrow(pool_attrs.pads[1]); @@ -48,20 +47,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, pooling_height, pooling_width, stride_height, stride_width, foutput_min, foutput_max, flags, &p); - } else if (avgpool_type == OpComputeType::op_compute_type_qu8) { - const float output_scale = quant_param[1].first[0]; - const uint8_t output_zero_point = quant_param[1].second; - const uint8_t output_min = xnn_u8s8_quantize(foutput_min, output_scale, output_zero_point); - const uint8_t output_max = xnn_u8s8_quantize(foutput_max, output_scale, output_zero_point); - status = xnn_create_average_pooling2d_nhwc_qu8(input_padding_top, input_padding_right, - input_padding_bottom, input_padding_left, - pooling_height, pooling_width, - stride_height, stride_width, - quant_param[0].second, - quant_param[0].first[0], - quant_param[1].second, - quant_param[1].first[0], - output_min, output_max, flags, &p); } if (status != xnn_status_success) { @@ -72,9 +57,9 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, } bool IsQuantAvgPoolSupported(const NodeUnit& node_unit, const GraphViewer& graph) { - TensorQuantType x_input_type = GetTensorQuantType(node_unit, 0, false, graph); - TensorQuantType output_type = GetTensorQuantType(node_unit, 0, true, graph); - return (x_input_type == TensorTypeUint8 && output_type == TensorTypeUint8); + (void)node_unit; + (void)graph; + return false; } bool IsQuantizedAvgPool(QuantizedOpType quant_op_type) { @@ -209,14 +194,10 @@ AveragePool::AveragePool(const OpKernelInfo& info) avgpool_type_ = OpComputeType::op_compute_type_fp32; } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { avgpool_type_ = OpComputeType::op_compute_type_fp16; - } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - // the order of input tensor, x,x_scale, x_zp, y_scale, y_zp - quant_param = ParseQuantParamForOp(info, input_dtype, 1); - avgpool_type_ = OpComputeType::op_compute_type_qu8; } struct xnn_operator* p; auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p, - quant_param, avgpool_type_); + avgpool_type_); ORT_ENFORCE(ret.IsOK(), ret.ErrorMessage()); op0_.reset(p); } @@ -242,23 +223,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { pthreadpool_t threadpool = GetThreadPool(); - // setup allocator/automated dellocate for workspace - size_t workspace_size = 0; - size_t workspace_alignment = 0; - xnn_allocator* allocator = GetStoredAllocator().second; - auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; - - std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = xnn_reshape_average_pooling2d_nhwc_f32; if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { reshape_fn = xnn_reshape_average_pooling2d_nhwc_f16; - } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - reshape_fn = xnn_reshape_average_pooling2d_nhwc_qu8; } auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, - &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); @@ -267,17 +237,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { " returned ", status); } - workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); - if (avgpool_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); + status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), X.Data(), + Y.MutableData()); } else if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { - status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); - } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); + status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), X.Data(), + Y.MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 4e6b308e28ae5..3ef0c1a7cf495 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -91,7 +91,6 @@ Status Conv::Compute(OpKernelContext* context) const { // setup allocator/automated dellocate for workspace size_t workspace_size = 0; - size_t workspace_alignment = 0; xnn_allocator* allocator = GetStoredAllocator().second; auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; std::unique_ptr workspace(nullptr, deallocator); @@ -108,7 +107,7 @@ Status Conv::Compute(OpKernelContext* context) const { } auto status = reshape_fn(op0_.get(), N, H, W, - &workspace_size, &workspace_alignment, + &workspace_size, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 44962c1796631..9742f397315a7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -24,7 +24,6 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, const std::optional>& clip_min_max, const Tensor& Weight, const Tensor* Bias, XnnpackOperator& op_uptr, - xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, const OpQuantParam& quant_param, OpComputeType conv_type, @@ -79,7 +78,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, C, M, // input channel stride, output channel stride Weight.Data(), B_data, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_fp16) { const auto* B_data = Bias ? Bias->Data() : nullptr; @@ -97,7 +96,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, Weight.Data(), B_data, // kernel, bias foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; @@ -121,7 +120,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8_per_channel) { auto* B_data = Bias ? Bias->Data() : nullptr; @@ -145,7 +144,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qu8) { const auto* B_data = Bias ? Bias->Data() : nullptr; @@ -170,7 +169,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } @@ -521,7 +520,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) Status ConvBase::CreateKernel() { auto ret = CreateXnnpackKernel(convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_, B_, op0_, - GetCodeCache(), GetWeightsCache(), + GetWeightsCache(), quant_param_, conv_type_, is_transpose_); return ret; } diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 0bb1194643743..32d91084d3507 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -228,13 +228,13 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf auto out_h = output_dims_[1]; auto out_w = output_dims_[2]; if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f32(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_fp32, out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f16(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_fp16, out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - xstatus = xnn_create_resize_bilinear2d_nhwc_u8(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_quint8, out_h, out_w, flags, &p); } else { - xstatus = xnn_create_resize_bilinear2d_nhwc_s8(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_qint8, out_h, out_w, flags, &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:", @@ -257,22 +257,14 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, // setup allocator/automated dellocate for workspace size_t workspace_size = 0; - size_t workspace_alignment = 0; xnn_allocator* allocator = GetStoredAllocator().second; auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32; - if (op_type_ == OpComputeType::op_compute_type_fp16) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f16; - } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8; - } else if (op_type_ == OpComputeType::op_compute_type_qs8) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; - } + auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc; auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, - &workspace_size, &workspace_alignment, threadpool); + &workspace_size, threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " returned ", status); @@ -281,17 +273,17 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); if (op_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { - status = xnn_setup_resize_bilinear2d_nhwc_f16(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else { - status = xnn_setup_resize_bilinear2d_nhwc_s8(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h index 31512586be19d..1779f51046c59 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h @@ -24,8 +24,6 @@ class XnnpackKernel : public OpKernel { } // see comment below about enabling code cache - // xnn_code_cache_t GetCodeCache() { return caches_.auto_code_cache.get();} - xnn_code_cache_t GetCodeCache() { return nullptr; } xnn_weights_cache_t GetWeightsCache() { return caches_.auto_weights_cache.get(); } private: @@ -42,11 +40,6 @@ class XnnpackKernel : public OpKernel { if (enable) { #ifdef XNN_CACHE_ENABLE xnn_status status = xnn_status_success; -#if XNN_PLATFORM_JIT - // status = xnn_init_code_cache(&code_cache_); - // ORT_ENFORCE(status == xnn_status_success, "Failed to initialize XNNPACK code cache");) - // auto_code_cache.reset(&code_cache_); -#endif // status = xnn_init_weights_cache(&weights_cache_); xnn_weights_cache_t weights_cache_provider = nullptr; status = xnn_create_weights_cache(&weights_cache, 0); From 38e660c31bf3837838d922410f442ee94f8498e5 Mon Sep 17 00:00:00 2001 From: shaoboyan091 Date: Tue, 29 Jul 2025 06:08:44 +0800 Subject: [PATCH 31/33] Fix webgpu_pix_frame_generator by adding missing present mode attribute (#25553) This PR fixed webgpu_fix_frame_generator by adding present mode to the surface configuration. This new attribute is required by laste Dawn to rendering frames. --- onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc index 9b287b7b7df99..bc5c755160cb5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc @@ -36,6 +36,7 @@ WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu:: format = capabilities.formats[0]; wgpu::SurfaceConfiguration config; + config.presentMode = capabilities.presentModes[0]; config.device = device; config.format = format; config.width = kWidth; From a2b4546c4449e153de236d8070984e4bbc526216 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 28 Jul 2025 15:27:57 -0700 Subject: [PATCH 32/33] [CUDA] Support SwiGlu in MoE and qMoE (#25530) ### Description This implements the SwiGLU activation for MoE and qMoE. The activation is corresponding to https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/swiglu.py. Also update test_parity_moe.py to enable test for qMoE in CI pipelines. ### Motivation and Context This is naive implementation of the activation. Since the activation will reduce each row length to half, we cannot directly use epilogue. Current implementations need an extra buffer to run SwiGLU kernel. In the future, we might take a look at other alternatives that does not need extra buffer. --- docs/ContribOperators.md | 14 +- .../cuda/collective/sharded_moe.cc | 7 +- .../cuda/moe/ft_moe/moe_gemm_kernels.h | 1 + .../moe/ft_moe/moe_gemm_kernels_template.h | 5 +- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 165 ++++- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 11 +- onnxruntime/contrib_ops/cuda/moe/moe.cc | 7 +- onnxruntime/contrib_ops/cuda/moe/moe_base.h | 29 +- .../cuda/quantization/moe_quantization.cc | 3 +- .../core/graph/contrib_ops/contrib_defs.cc | 14 +- .../python/transformers/test_parity_moe.py | 659 ++++++++++++++---- 11 files changed, 725 insertions(+), 190 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..b59ff63ea8260 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3089,7 +3089,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
k : int
Number of top experts to select from expert pool
normalize_routing_weights : int
@@ -3106,9 +3106,9 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
-
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc2_experts_bias (optional) : T
@@ -4523,7 +4523,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
expert_weight_bits : int
Number of bits used in quantized weights. Default is 4 bits
k : int
@@ -4542,11 +4542,11 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
-
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
+
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).
fc1_scales : T
-
2D input tensor with shape (num_experts, inter_size)
+
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
fc2_scales : T
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 1a4a63de38790..e8cdc50ed4ca7 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -78,8 +78,11 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 36127054cfd5e..d5ad8161e100e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -52,6 +52,7 @@ enum class ActivationType { Gelu, GeGLU, ReGLU, SiGLU, + SwiGLU, Identity, InvalidType }; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index ef1f97b9e57a2..8b8f45e77ab9d 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -391,12 +391,10 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else if (sm_ >= 80 && sm_ < 90) { + } else if (sm_ >= 80) { // Hopper and Blackwell will fallback to use Ampere kernels. dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else { - ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -478,6 +476,7 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream) { + // Swiglu will use Identity to call this function so we not need to handle it here. switch (activation_type) { case ActivationType::Relu: run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index bfbe1d81b1c15..4268b79e1e4f8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -44,6 +44,72 @@ namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; + +// SwiGLU with interleaved is like the following python code using PyTorch: +// dim = x.shape[-1] +// x = x.view(-1, dim // 2, 2) +// x_glu, x_linear = x[..., 0], x[..., 1] +// y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) +template +__global__ void swiglu_kernel_interleaved(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + T x_glu = row_input[2 * i]; + T x_linear = row_input[2 * i + 1]; + + float sigmoid_arg = swiglu_alpha * static_cast(x_glu); + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = static_cast(x_glu) * sigmoid_out; + row_output[i] = static_cast(swish_out * (static_cast(x_linear) + 1.f)); + } +} + +// Non interleaved version of SwiGLU kernel, which splits each row into two chunks of same size. +template +__global__ void swiglu_kernel_chunked(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + T x_glu = row_input[i]; + T x_linear = row_input[i + intermediate_size]; + + float sigmoid_arg = swiglu_alpha * static_cast(x_glu); + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = static_cast(x_glu) * sigmoid_out; + row_output[i] = static_cast(swish_out * (static_cast(x_linear) + 1.f)); + } +} + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream) { + if (num_rows == 0) { + return; + } + dim3 block(std::min(intermediate_size, 1024)); + dim3 grid(num_rows); + + if constexpr (interleaved) { + swiglu_kernel_interleaved<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); + } else { + swiglu_kernel_chunked<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); + } +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. @@ -666,9 +732,14 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer) - : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { + : activation_type_(activation_type), + has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights), + use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -695,8 +766,16 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + + size_t bytes_for_fc1_result; + if (activation_type_ == ActivationType::SwiGLU) { + // Space for both fc1_result_ and act_result_. + bytes_for_fc1_result = (2 * interbuf_size + interbuf_size) * sizeof(T); + } else { + bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + } + + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)); sorter_.update_num_experts(static_cast(num_experts)); size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; @@ -705,7 +784,7 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro bytes_for_intermediate_and_sorting += remaining_bytes; } - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + total_ws_bytes += bytes_for_intermediate_and_sorting; return total_ws_bytes; } @@ -725,16 +804,34 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + char* current_ptr = reinterpret_cast(total_rows_before_expert_ + padded_experts); + + if (activation_type_ == ActivationType::SwiGLU) { + // fc1_result_ is used for GEMM1 output (2 * inter_size) + fc1_result_ = reinterpret_cast(current_ptr); + current_ptr += 2 * interbuf_size * sizeof(T); + + // act_result_ is used for SwiGLU output (inter_size) + act_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); + + ORT_ENFORCE(!has_fc3_, "SwiGLU activation is not supported with fc3"); + } else { + fc1_result_ = reinterpret_cast(current_ptr); + act_result_ = nullptr; // No extra buffer for activation since it is done inplace. + current_ptr += interbuf_size * sizeof(T); + } + if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + fc3_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc3_result_ = nullptr; } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(current_ptr); } else { softmax_out_ = nullptr; } @@ -880,8 +977,51 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, - // expanded_active_expert_rows); + if (fc1_activation_type == ActivationType::SwiGLU) { + T* gemm1_output_buffer = fc1_result_; + T* swiglu_output_buffer = act_result_; + + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, + fc1_scales, + fc1_expert_biases, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + 2 * inter_size, + hidden_size, + local_num_experts, + ActivationType::Identity, + stream); + + constexpr bool swiglu_interleaved = true; + constexpr float swiglu_alpha = 1.702f; + invokeSwiGLU( + swiglu_output_buffer + total_past_rows_ * inter_size, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + inter_size, + static_cast(total_covered_rows_), + swiglu_alpha, + stream); + + moe_gemm_runner_.moe_gemm( + swiglu_output_buffer + total_past_rows_ * inter_size, + fc2_expert_weights, + fc2_scales, + nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + hidden_size, + inter_size, + local_num_experts, + stream); + + // No fc3 for SwiGLU + return; + } + moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -1178,4 +1318,7 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); +template void invokeSwiGLU(float*, float const*, int, int, float, cudaStream_t); +template void invokeSwiGLU(half*, half const*, int, int, float, cudaStream_t); + } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index c457b608decbf..3ac4862e101c3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -54,7 +54,10 @@ static inline size_t pad_to_multiple_of_16(size_t input) { template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, int* indices, int* source_row, int num_rows, int num_experts, int k, - cudaStream_t stream); + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream); + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream); class CubKeyValueSorter { public: @@ -109,7 +112,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); @@ -157,8 +160,10 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* act_result_; T* fc3_result_; + ActivationType activation_type_; bool has_fc3_; bool normalize_routing_weights_; bool use_sparse_mixer_; @@ -176,7 +181,7 @@ class CutlassMoeFCRunner { template class CutlassMoeFCRunner::value>> { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { return 0; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index c5352d931ce2c..cc6fe871a3bc1 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -48,8 +48,11 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 6b65557444a66..194f33acbeb59 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -76,15 +76,16 @@ class MoEBase { } const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1; - if (fc1_experts_weights_dims[2] != inter_size / coe) { + const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1; + if (fc1_experts_weights_dims[2] != act * inter_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], " and ", inter_size); + "fc1_experts_weights_dims[2] is ", + fc1_experts_weights_dims[2], " expected ", act * inter_size / coe); } if (fc2_experts_weights_dims[2] != hidden_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); + "fc2_experts_weights_dims[2] is ", + fc2_experts_weights_dims[2], " expected ", hidden_size / coe); } if (router_probs_dims.size() != 2) { @@ -116,10 +117,10 @@ class MoEBase { "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], " and ", num_experts); } - if (fc1_experts_bias_dims[1] != inter_size) { + if (fc1_experts_bias_dims[1] != act * inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], - " and ", inter_size); + "fc1_experts_bias_dims[1] is ", fc1_experts_bias_dims[1], + ", expected ", act * inter_size); } if (fc2_experts_bias_dims[1] != hidden_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -182,10 +183,14 @@ class MoEBase { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", fc1_experts_scales_dims[0], " and ", num_experts); } - if (fc1_experts_scales_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", - fc1_experts_scales_dims[1], " and ", inter_size); + + // The activation type affects the output dimension of the first FC layer. + const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1; + if (fc1_experts_scales_dims[1] != act * inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to act * inter_size, got ", + fc1_experts_scales_dims[1], " and ", act * inter_size); } + if (fc2_experts_scales_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", fc2_experts_scales->Shape().GetDims().size()); @@ -219,6 +224,8 @@ class MoEBase { activation_type_ = ort_fastertransformer::ActivationType::Gelu; } else if (activation_type_str == "silu") { activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "swiglu") { + activation_type_ = ort_fastertransformer::ActivationType::SwiGLU; } else if (activation_type_str == "identity") { activation_type_ = ort_fastertransformer::ActivationType::Identity; } else { diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 4dd5a079d1a29..db6d99674cf5a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -72,6 +72,7 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, using CudaT = typename ToCudaType::MappedType; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, fc3_experts_weights_optional != nullptr, normalize_routing_weights_, use_sparse_mixer_); @@ -185,4 +186,4 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5511275239e45..39bf2bf855976 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1392,14 +1392,14 @@ constexpr const char* MoE_ver1_doc = R"DOC( ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, OpSchema() .SetDoc(MoE_ver1_doc) - .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") - .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu", "T") + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) @@ -1413,7 +1413,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc("Quantized MoE") .Attr("activation_type", - "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", @@ -1438,12 +1438,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).", "T1") - .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T") .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size) " diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index 252d89a2257fc..d805c8f9cae3c 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -9,6 +9,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import itertools import unittest from collections import OrderedDict @@ -24,11 +25,6 @@ torch.manual_seed(42) numpy.random.seed(42) -USE_QUANT = False -ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 5e-1 if USE_QUANT else 1e-2 - def value_string_of(numpy_array): arr = numpy_array.flatten() @@ -40,26 +36,69 @@ def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") -def quant_dequant(weights, quant_mode: bool = True): - # use the test version `_symmetric_...` to get the non-interleaved weights - type = torch.quint4x2 if quant_mode else torch.int8 - # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() - # Comment out this line for passing the lintrunner check in the CI. - # import tensorrt_llm +def quant_dequant(weights: torch.Tensor, is_4_bit_quantization: bool): + """ + Performs symmetric per-column quantization and dequantization on a weight tensor. + + This implementation is a pure PyTorch replacement for the original function that + relied on a custom tensorrt_llm operator. It supports both 8-bit (int8) and + 4-bit (quint4x2 style) quantization. + + Args: + weights (torch.Tensor): The input weight tensor to be quantized. + is_4_bit_quantization (bool): If True, performs 4-bit quantization. If False, + performs 8-bit quantization. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - scales (torch.float16): The quantization scales for each column. + - processed_q_weight (torch.int8): The packed quantized weights. For + 4-bit mode, two 4-bit values are packed into a single int8. For + 8-bit mode, this is the standard int8 quantized tensor. It is + transposed relative to the input weights' shape. + - dequantized_weights (torch.Tensor): The weights after being dequantized, + restored to the original dtype and device. + """ + # Determine quantization bits and range based on the mode + if is_4_bit_quantization: + # 4-bit symmetric quantization path + q_bits = 4 + q_max = 2 ** (q_bits - 1) - 1 # 7 + q_min = -(2 ** (q_bits - 1)) # -8 - quant_weights, processed_q_weight, torch_weight_scales = ( - torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) - ) + max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values + max_abs_val[max_abs_val == 0] = 1.0 + scales = max_abs_val / q_max + + quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8) + + # Pack two 4-bit integers into a single int8 + q_weights_t = quant_weights.T.contiguous() + shape = q_weights_t.shape + q_weights_t_reshaped = q_weights_t.view(shape[0], shape[1] // 2, 2) + lower_nibble = q_weights_t_reshaped[..., 0] + upper_nibble = q_weights_t_reshaped[..., 1] + processed_q_weight = (lower_nibble & 0x0F) | (upper_nibble << 4) + + else: + # 8-bit symmetric quantization path + q_bits = 8 + q_max = 2 ** (q_bits - 1) - 1 # 127 + q_min = -(2 ** (q_bits - 1)) # -128 - # Unpack the int4s int int8s - if quant_mode: - upper = quant_weights >> 4 - lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends - quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) + max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values + max_abs_val[max_abs_val == 0] = 1.0 + scales = max_abs_val / q_max - quant_weights = quant_weights.to(dtype=weights.dtype) - result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous() - return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device) + quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8) + + # For 8-bit, the processed weights are just the transposed quantized weights (no packing) + processed_q_weight = quant_weights.T.contiguous() + + # Dequantize the weights to verify and return for PyTorch-side parity check + dequantized_weights = quant_weights.to(weights.dtype) * scales.to(weights.dtype) + + return (scales.squeeze(0).to(torch.float16), processed_q_weight, dequantized_weights.T.to(device=weights.device)) def create_moe_onnx_graph( @@ -71,6 +110,7 @@ def create_moe_onnx_graph( fc1_experts_bias, fc2_experts_weights, fc2_experts_bias, + ort_dtype, ): nodes = [ helper.make_node( @@ -94,19 +134,19 @@ def create_moe_onnx_graph( fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + ort_dtype, fc1_shape, fc1_experts_weights.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + ort_dtype, fc2_shape, fc2_experts_weights.to(torch_type).flatten().tolist(), raw=False, @@ -119,14 +159,14 @@ def create_moe_onnx_graph( [ helper.make_tensor( "fc1_experts_bias", - ORT_DTYPE, + ort_dtype, fc1_bias_shape, fc1_experts_bias.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_bias", - ORT_DTYPE, + ort_dtype, fc2_bias_shape, fc2_experts_bias.to(torch_type).flatten().tolist(), raw=False, @@ -135,19 +175,19 @@ def create_moe_onnx_graph( ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + ort_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -171,6 +211,7 @@ def create_mixtral_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, + ort_dtype, ): nodes = [ helper.make_node( @@ -197,26 +238,26 @@ def create_mixtral_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + ort_dtype, fc1_shape, fc1_experts_weights.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + ort_dtype, fc2_shape, fc2_experts_weights.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE, + ort_dtype, fc3_shape, fc3_experts_weights.to(torch_type).flatten().tolist(), raw=False, @@ -224,19 +265,19 @@ def create_mixtral_moe_onnx_graph( ] graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + ort_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -259,12 +300,14 @@ def create_phi_moe_onnx_graph( fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, - fc1_scales, - fc2_scales, - fc3_scales, topk, + ort_dtype, + quant_bits=0, + fc1_scales=None, + fc2_scales=None, + fc3_scales=None, ): - use_quant = USE_QUANT + use_quant = quant_bits > 0 if use_quant: assert fc1_experts_weights.dtype == torch.int8 assert fc2_experts_weights.dtype == torch.int8 @@ -276,34 +319,37 @@ def create_phi_moe_onnx_graph( assert fc2_scales.dtype == torch.float16 assert fc3_scales.dtype == torch.float16 + op_name = "QMoE" if use_quant else "MoE" + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + ) + nodes = [ helper.make_node( - "MoE" if not use_quant else "QMoE", - ( - [ - "input", - "router_probs", - "fc1_experts_weights", - "", - "fc2_experts_weights", - "", - "fc3_experts_weights", - ] - if not use_quant - else [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - "fc3_experts_weights", - "fc3_scales", - "", - ] - ), + op_name, + inputs, ["output"], "MoE_0", k=topk, @@ -315,37 +361,38 @@ def create_phi_moe_onnx_graph( ] if use_quant: - nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)]) + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - fc1_shape = [num_experts, hidden_size, inter_size] - fc2_shape = [num_experts, inter_size, hidden_size] - fc3_shape = [num_experts, hidden_size, inter_size] + components = 2 if quant_bits == 4 else 1 + fc1_shape = [num_experts, hidden_size, inter_size // components] + fc2_shape = [num_experts, inter_size, hidden_size // components] + fc3_shape = [num_experts, hidden_size, inter_size // components] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 - if use_quant: - numpy_type = numpy.uint8 + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 + numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32 + weight_numpy_type = numpy.uint8 if use_quant else numpy_type + weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc3_shape, - fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc3_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), raw=False, ), ] @@ -358,21 +405,21 @@ def create_phi_moe_onnx_graph( [ helper.make_tensor( "fc1_scales", - ORT_DTYPE, + ort_dtype, fc1_scale_shape, fc1_scales.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_scales", - ORT_DTYPE, + ort_dtype, fc2_scale_shape, fc2_scales.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_scales", - ORT_DTYPE, + ort_dtype, fc3_scale_shape, fc3_scales.to(torch_type).flatten().tolist(), raw=False, @@ -381,19 +428,19 @@ def create_phi_moe_onnx_graph( ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + ort_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -546,8 +593,11 @@ def __init__(self, config: PhiMoEConfig): class SparseMoeBlockORTHelper(nn.Module): - def __init__(self): + def __init__(self, quant_bits=0): super().__init__() + self.quant_bits = quant_bits + self.ort_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + self.np_type = numpy.float16 if self.ort_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415 @@ -573,8 +623,8 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten router_logits = self.gate(hidden_states) ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), + "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(self.np_type)), + "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(self.np_type)), } ort_output = None @@ -586,13 +636,6 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten self.ort_run_with_iobinding(ort_inputs) return None - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) - return None def ort_run_with_iobinding(self, ort_inputs, repeat=1000): @@ -603,7 +646,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): name="input", device_type="cuda", device_id=device_id, - element_type=NP_TYPE, + element_type=self.np_type, shape=ort_inputs["input"].shape, buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), ) @@ -612,7 +655,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): name="router_probs", device_type="cuda", device_id=device_id, - element_type=NP_TYPE, + element_type=self.np_type, shape=ort_inputs["router_probs"].shape, buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( ort_inputs["router_probs"], "cuda", device_id @@ -623,7 +666,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): name="output", device_type="cuda", device_id=device_id, - element_type=NP_TYPE, + element_type=self.np_type, shape=ort_inputs["input"].shape, buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( numpy.zeros(ort_inputs["input"].shape), "cuda", device_id @@ -646,22 +689,27 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): e = time.time() print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") - def parity_check(self): + def parity_check(self, atol=None, rtol=None): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) + + if atol is None: + atol = 5e-2 if self.quant_bits == 0 else (2.0 if self.quant_bits == 8 else 3.0) + + if rtol is None: + rtol = 1e-3 if self.quant_bits == 0 else 1e-2 + if ort_output is not None: + dtype_str = "FP32" if self.quant_bits == 0 else "FP16" print( - "name:", - self.__class__.__name__, - " batch_size:", - self.batch_size, - " sequence_length:", - self.sequence_length, - " max_diff:", - (torch_output - ort_output).abs().max(), + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output - ort_output).abs().max()}" + ) + torch.testing.assert_close( + ort_output.to(torch.float32), torch_output.to(torch.float32), rtol=rtol, atol=atol ) - torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) @@ -680,7 +728,7 @@ def __init__( eval_capacity=-1, activation="gelu", ): - super().__init__() + super().__init__(quant_bits=0) # SwitchMoE is not quantized self.batch_size = batch_size self.sequence_length = sequence_length self.num_experts = num_experts @@ -709,6 +757,7 @@ def __init__( self.moe_experts.bias1, self.moe_experts.weight2.transpose(1, 2), self.moe_experts.bias2, + self.ort_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -744,7 +793,7 @@ class MixtralSparseMoeBlock(SparseMoeBlockORTHelper): """ def __init__(self, config, batch_size, sequence_length): - super().__init__() + super().__init__(quant_bits=0) # Mixtral test is not quantized self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -778,6 +827,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, + self.ort_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -874,43 +924,44 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): and memory on padding. """ - def __init__(self, config, batch_size, sequence_length): - super().__init__() + def __init__(self, config, batch_size, sequence_length, quant_bits=0): + super().__init__(quant_bits) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise + use_quant = self.quant_bits > 0 # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - w1_list = [] - w2_list = [] - w3_list = [] - w1_scale_list = [] - w2_scale_list = [] - w3_scale_list = [] - if not USE_QUANT: + w1_list, w2_list, w3_list = [], [], [] + w1_scale_list, w2_scale_list, w3_scale_list = [], [], [] + + if not use_quant: for i in range(self.num_experts): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) else: + is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + # Corrected quantization logic for per-output-channel quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight.T, is_4_bit) self.experts[i].w1.weight.data = w1_qdq self.experts[i].w2.weight.data = w2_qdq self.experts[i].w3.weight.data = w3_qdq - w1_list.append(pre_qweight1) - w2_list.append(pre_qweight2) - w3_list.append(pre_qweight3) + # Transpose quantized weights to match the expected ONNX layout + w1_list.append(pre_qweight1.T) + w2_list.append(pre_qweight2.T) + w3_list.append(pre_qweight3.T) w1_scale_list.append(w1_scale) w2_scale_list.append(w2_scale) w3_scale_list.append(w3_scale) @@ -919,9 +970,9 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2 = torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if use_quant else None self.batch_size = batch_size self.sequence_length = sequence_length @@ -933,10 +984,12 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight1, self.moe_experts_weight2, self.moe_experts_weight3, + self.top_k, + self.ort_dtype, + self.quant_bits, moe_experts_weight_scale1, moe_experts_weight_scale2, moe_experts_weight_scale3, - self.top_k, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -992,19 +1045,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def small_test_cases(): for batch_size in [1, 4, 16]: for sequence_length in [128, 512, 1024]: - yield batch_size, sequence_length + yield batch_size, sequence_length, 0 -def phi3_test_cases(): - # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. - for batch_size in [1, 4, 16]: - for sequence_length in [128]: - yield batch_size, sequence_length +# Test cases for Phi-3 MoE. +# We test three modes: no quantization, 8-bit, and 4-bit. +phi3_test_params = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) - def test_switch_moe_parity(self, batch_size, sequence_length): + def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits): # if platform.system() == "Windows": # pytest.skip("Skip on Windows") switch_moe = SwitchMoE( @@ -1020,8 +1077,8 @@ def test_switch_moe_parity(self, batch_size, sequence_length): class TestMixtralMoE(unittest.TestCase): - @parameterized.expand(small_test_cases()) - def test_mixtral_moe_parity(self, batch_size, sequence_length): + @parameterized.expand([(b, s, q) for b, s, q in small_test_cases() if q == 0]) # only run non-quantized + def test_mixtral_moe_parity(self, batch_size, sequence_length, quant_bits): config = MixtralConfig(hidden_size=256, intermediate_size=1024) mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) mixtral_moe.parity_check() @@ -1029,13 +1086,329 @@ def test_mixtral_moe_parity(self, batch_size, sequence_length): class TestPhiMoE(unittest.TestCase): - @parameterized.expand(phi3_test_cases()) - def test_phi3_moe_parity(self, batch_size, sequence_length): + @parameterized.expand(phi3_test_params) + def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) + phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits) phi3_moe.parity_check() # phi3_moe.benchmark_ort() +# --------------------------------------------- +# The following test are for swiglu activation +# --------------------------------------------- +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def swiglu(self, x: torch.Tensor): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1) + return y + + def forward(self, x): + y = self.swiglu(self.w1(x)) + y = self.w2(y) + return y + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + ort_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, hidden_size, 2 * inter_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, inter_size, hidden_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 + numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32 + weight_numpy_type = numpy.uint8 if use_quant else numpy_type + weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + weight_onnx_type, + fc1_weight_shape, + fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist() + if use_quant + else fc1_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc1_experts_bias", + ort_dtype, + fc1_bias_shape, + fc1_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + weight_onnx_type, + fc2_weight_shape, + fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist() + if use_quant + else fc2_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_bias", + ort_dtype, + fc2_bias_shape, + fc2_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + + if use_quant: + initializers.extend( + [ + helper.make_tensor( + "fc1_experts_weight_scale", + ort_dtype, + fc1_experts_weight_scale_shape, + fc1_experts_weight_scale.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weight_scale", + ort_dtype, + fc2_experts_weight_scale_shape, + fc2_experts_weight_scale.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ort_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ort_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ort_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0): + super().__init__(quant_bits) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + weight_1_list, weight_2_list = [], [] + bias_1_list, bias_2_list = [], [] + scale_1_list, scale_2_list = [], [] + + for i in range(self.num_experts): + bias_1_list.append(self.experts[i].w1.bias) + bias_2_list.append(self.experts[i].w2.bias) + if not use_quant: + weight_1_list.append(self.experts[i].w1.weight) + weight_2_list.append(self.experts[i].w2.weight) + else: + is_4_bit = self.quant_bits == 4 + # Pass the transposed weight to quant_dequant to get correct scales, + # then transpose the resulting quantized weight back to the expected layout. + scale1, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit) + + self.experts[i].w1.weight.data = w1_qdq + self.experts[i].w2.weight.data = w2_qdq + + weight_1_list.append(pre_qweight1.T) + weight_2_list.append(pre_qweight2.T) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + self.moe_experts_weight1 = torch.stack(weight_1_list, dim=0) + self.moe_experts_weight2 = torch.stack(weight_2_list, dim=0) + + self.moe_experts_bias1 = torch.stack(bias_1_list, dim=0) + self.moe_experts_bias2 = torch.stack(bias_2_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + ort_dtype=self.ort_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=self.moe_experts_weight1, + fc1_experts_bias=self.moe_experts_bias1, + fc2_experts_weights=self.moe_experts_weight2, + fc2_experts_bias=self.moe_experts_bias2, + fc1_experts_weight_scale=moe_experts_weight_scale1, + fc2_experts_weight_scale=moe_experts_weight_scale2, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) # router_logits shape is (batch * sequence_length, num_experts) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_params = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +class TestSwigluMoE(unittest.TestCase): + @parameterized.expand(swiglu_test_params) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=128, intermediate_size=512, num_experts_per_token=1, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.parity_check() + + if __name__ == "__main__": unittest.main() From 6ee4ea3b05423aaa3ecd3698a56b83eb45f4b2ad Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 28 Jul 2025 22:36:31 -0700 Subject: [PATCH 33/33] Fix C/C++ documentation generation (#25569) ### Description Fixes documentation error in onnxruntime_c_api.h: parameter name mismatch for `Graph_GetGraphView` ### Motivation and Context Fix errors in the GitHub action for generating the C/C++ documentation from public header files. --- include/onnxruntime/core/session/onnxruntime_c_api.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2f0e4aa7ce108..d87e9e083185b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5846,14 +5846,13 @@ struct OrtApi { /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference * the same underlying graph. * * \param[in] src_graph The source OrtGraph instance. * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. * \param[in] num_nodes Number of nodes. - * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. * * \snippet{doc} snippets.dox OrtStatus Return Value *