From d07d20080e385c7545cf6175ad0bc946fe1cfe57 Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Tue, 21 May 2024 09:26:34 +0400 Subject: [PATCH 1/4] [GPU] Add SDPA impl; SDPA input transpose fusion support; GQA optimization --- .../intel_gpu/include/intel_gpu/op/sdpa.hpp | 94 ++ .../intel_gpu/plugin/primitives_list.hpp | 2 + .../scaled_dot_product_attention.hpp | 95 ++ src/plugins/intel_gpu/src/graph/gemm.cpp | 3 + .../src/graph/impls/ocl/register.cpp | 1 + .../src/graph/impls/ocl/register.hpp | 2 + .../ocl/scaled_dot_product_attention.cpp | 135 ++ .../scaled_dot_product_attention_inst.h | 40 + .../intel_gpu/src/graph/primitive_inst.cpp | 5 + .../graph/scaled_dot_product_attention.cpp | 87 ++ .../kernel_selector/cl_kernels/sdpa_opt.cl | 1167 +++++++++++++++++ .../kernel_selector/cl_kernels/sdpa_ref.cl | 212 +++ .../src/kernel_selector/common_types.h | 1 + .../intel_gpu/src/kernel_selector/jitter.cpp | 4 +- .../kernel_selector/kernel_selector_utils.h | 23 +- .../arg_max_min/arg_max_min_kernel_axis.cpp | 2 +- .../fully_connected_kernel_base.cpp | 2 +- .../kernels/gemm/gemm_kernel_tiled_opt.cpp | 8 +- .../kernels/mvn/mvn_kernel_bfyx_opt.cpp | 2 +- .../non_zero/count_nonzero_kernel_ref.cpp | 2 +- .../non_zero/gather_nonzero_kernel_ref.cpp | 2 +- .../permute/permute_kernel_tile_8x8_4x4.cpp | 2 +- .../kernels/reduce/reduce_kernel_base.cpp | 2 +- .../kernels/rms/rms_kernel_bfyx_opt.cpp | 2 +- .../kernels/sdpa/sdpa_kernel_base.cpp | 126 ++ .../kernels/sdpa/sdpa_kernel_base.h | 111 ++ .../kernels/sdpa/sdpa_kernel_opt.cpp | 258 ++++ .../kernels/sdpa/sdpa_kernel_opt.h | 29 + .../kernels/sdpa/sdpa_kernel_ref.cpp | 107 ++ .../kernels/sdpa/sdpa_kernel_ref.h | 28 + .../kernels/sdpa/sdpa_kernel_selector.cpp | 19 + .../kernels/sdpa/sdpa_kernel_selector.h | 23 + .../kernels/slice/slice_kernel_ref.cpp | 2 +- .../kernels/softmax/softmax_kernel_bf.cpp | 2 +- .../kernels/unique/unique_kernel_ref.cpp | 4 +- .../ops/scaled_dot_product_attention.cpp | 59 + .../src/plugin/transformations/op/sdpa.cpp | 171 +++ ...matmul_fusion.cpp => transpose_fusion.cpp} | 140 +- .../transformations/transpose_fusion.hpp | 37 + .../transpose_matmul_fusion.hpp | 19 - ...nsqueeze_broadcast_reshape_sdpa_fusion.cpp | 134 ++ ...nsqueeze_broadcast_reshape_sdpa_fusion.hpp | 19 + .../src/plugin/transformations_pipeline.cpp | 72 +- .../skip_tests_config.cpp | 2 + .../dynamic/scaled_dot_product_attention.cpp | 215 +++ .../transpose_matmul_fusion_test.cpp | 10 +- .../transpose_sdpa_fusion_test.cpp | 178 +++ 47 files changed, 3599 insertions(+), 61 deletions(-) create mode 100644 src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp create mode 100644 src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp create mode 100644 src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h create mode 100644 src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl create mode 100644 src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h create 
mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.h create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h create mode 100644 src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp rename src/plugins/intel_gpu/src/plugin/transformations/{transpose_matmul_fusion.cpp => transpose_fusion.cpp} (56%) create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp delete mode 100644 src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.hpp create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp create mode 100644 src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp create mode 100644 src/plugins/intel_gpu/tests/unit/transformations/transpose_sdpa_fusion_test.cpp diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp new file mode 100644 index 00000000000000..45416b4e53810b --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp @@ -0,0 +1,94 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/core/partial_shape.hpp" +#include "openvino/op/op.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" + +namespace ov { +namespace intel_gpu { +namespace op { + +class SDPA : public ov::op::v13::ScaledDotProductAttention { +public: + OPENVINO_OP("SDPA", "gpu_opset"); + + SDPA() = default; + + SDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const bool is_causal, + const ov::element::Type output_type = ov::element::undefined); + + SDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const bool is_causal, + const ov::element::Type output_type = ov::element::undefined); + + SDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const ov::Output& scale, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const bool is_causal, + const ov::element::Type output_type = ov::element::undefined); + + bool visit_attributes(ov::AttributeVisitor &visitor) override; + + void validate_and_infer_types() override; + + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + bool get_causal() const { return m_is_causal; } + + std::vector get_input0_transpose_order() const { return m_order_q; } + std::vector get_input1_transpose_order() const { return m_order_k; } + std::vector get_input2_transpose_order() const { return m_order_v; } + std::vector get_output_transpose_order() const { return m_order_out; } + ov::element::Type get_output_type() const 
{ return m_output_type; } + + static std::vector default_order(size_t rank) { + std::vector order(rank); + std::iota(order.begin(), order.end(), 0); + return order; + } + +protected: + std::vector m_order_q; + std::vector m_order_k; + std::vector m_order_v; + std::vector m_order_out; + bool m_is_causal; + ov::element::Type m_output_type; +}; + +std::vector shape_infer(const SDPA* op, + std::vector input_shapes, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out); + + +} // namespace op +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index 68cb607b116f24..3af4b61474d522 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -263,6 +263,7 @@ REGISTER_FACTORY(v12, ScatterElementsUpdate); // ------------------------------ Supported v13 ops ----------------------------- // REGISTER_FACTORY(v13, Multinomial); +REGISTER_FACTORY(v13, ScaledDotProductAttention); // ------------------------------ Supported v14 ops ----------------------------- // REGISTER_FACTORY(v14, ROIAlignRotated); @@ -283,3 +284,4 @@ REGISTER_FACTORY(internal, SwiGLU); REGISTER_FACTORY(internal, IndirectGemm); REGISTER_FACTORY(internal, Convolution); REGISTER_FACTORY(internal, Placeholder); +REGISTER_FACTORY(internal, SDPA); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp new file mode 100644 index 00000000000000..f4f32a6af37d87 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp @@ -0,0 +1,95 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "primitive.hpp" + +namespace cldnn { + +struct scaled_dot_product_attention : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(scaled_dot_product_attention) + + scaled_dot_product_attention() : primitive_base("", {}) {} + + /// @brief Constructs scaled_dot_product_attention primitive. + /// @param id This primitive id. + /// @param inputs Input data primitives id (query, keys, values, [attention_mask], [scale]). + /// @param is_causal If true, assumes causal attention masking. In this case attention_mask input is ignored. 
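+    /// @param input_q_transpose_order Transpose order fused into this primitive for the query input (empty means keep the original order).
+    /// @param input_k_transpose_order Transpose order fused into this primitive for the key input (empty means keep the original order).
+    /// @param input_v_transpose_order Transpose order fused into this primitive for the value input (empty means keep the original order).
+    /// @param output_transpose_order Transpose order applied to the primitive output (empty means keep the original order).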
+    scaled_dot_product_attention(const primitive_id& id,
+                                 const std::vector<input_info> inputs,
+                                 bool is_causal,
+                                 const std::vector<int64_t>& input_q_transpose_order = {},
+                                 const std::vector<int64_t>& input_k_transpose_order = {},
+                                 const std::vector<int64_t>& input_v_transpose_order = {},
+                                 const std::vector<int64_t>& output_transpose_order = {},
+                                 const padding& output_padding = padding())
+        : primitive_base(id, inputs, {output_padding})
+        , is_causal(is_causal)
+        , has_attn_mask_input(inputs.size() > 3)
+        , has_scale_input(inputs.size() > 4)
+        , input_q_transpose_order(input_q_transpose_order)
+        , input_k_transpose_order(input_k_transpose_order)
+        , input_v_transpose_order(input_v_transpose_order)
+        , output_transpose_order(output_transpose_order) {}
+
+
+    bool is_causal = false;
+    bool has_attn_mask_input = false;
+    bool has_scale_input = false;
+
+    std::vector<int64_t> input_q_transpose_order;
+    std::vector<int64_t> input_k_transpose_order;
+    std::vector<int64_t> input_v_transpose_order;
+    std::vector<int64_t> output_transpose_order;
+
+    size_t hash() const override {
+        size_t seed = primitive::hash();
+        seed = hash_combine(seed, is_causal);
+        seed = hash_combine(seed, has_attn_mask_input);
+        seed = hash_combine(seed, has_scale_input);
+        seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
+        seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
+        seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
+        seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
+        return seed;
+    }
+
+    bool operator==(const primitive& rhs) const override {
+        if (!compare_common_params(rhs))
+            return false;
+
+        auto rhs_casted = downcast<const scaled_dot_product_attention>(rhs);
+
+        return is_causal == rhs_casted.is_causal &&
+               has_attn_mask_input == rhs_casted.has_attn_mask_input &&
+               has_scale_input == rhs_casted.has_scale_input &&
+               input_q_transpose_order == rhs_casted.input_q_transpose_order &&
+               input_k_transpose_order == rhs_casted.input_k_transpose_order &&
+               input_v_transpose_order == rhs_casted.input_v_transpose_order &&
+               output_transpose_order == rhs_casted.output_transpose_order;
+    }
+
+    void save(BinaryOutputBuffer& ob) const override {
+        primitive_base<scaled_dot_product_attention>::save(ob);
+        ob << is_causal;
+        ob << has_attn_mask_input;
+        ob << has_scale_input;
+        ob << input_q_transpose_order;
+        ob << input_k_transpose_order;
+        ob << input_v_transpose_order;
+        ob << output_transpose_order;
+    }
+
+    void load(BinaryInputBuffer& ib) override {
+        primitive_base<scaled_dot_product_attention>::load(ib);
+        ib >> is_causal;
+        ib >> has_attn_mask_input;
+        ib >> has_scale_input;
+        ib >> input_q_transpose_order;
+        ib >> input_k_transpose_order;
+        ib >> input_v_transpose_order;
+        ib >> output_transpose_order;
+    }
+};
+} // namespace cldnn
diff --git a/src/plugins/intel_gpu/src/graph/gemm.cpp b/src/plugins/intel_gpu/src/graph/gemm.cpp
index 4af921d566bffc..a8b196bd45885f 100644
--- a/src/plugins/intel_gpu/src/graph/gemm.cpp
+++ b/src/plugins/intel_gpu/src/graph/gemm.cpp
@@ -272,6 +272,9 @@ std::string gemm_inst::to_string(gemm_node const& node) {
     gemm_info.add("transpose_input1", transpose_input1);
     gemm_info.add("indirect_input0", indirect_input0);
     gemm_info.add("indirect_input1", indirect_input1);
+    gemm_info.add("transpose_order_input0", desc->input0_transpose_order);
+    gemm_info.add("transpose_order_input1", desc->input1_transpose_order);
+    gemm_info.add("transpose_order_output", desc->output_transpose_order);
     node_info->add("gemm info", gemm_info);
     node_info->dump(primitive_description);
diff --git
a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp index 40264d856035e2..855ae9c421b235 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp @@ -93,6 +93,7 @@ void register_implementations() { REGISTER_OCL(eye); REGISTER_OCL(unique_count); REGISTER_OCL(unique_gather); + REGISTER_OCL(scaled_dot_product_attention); } } // namespace ocl diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp index a2f3202f816671..f0d2a72e51d848 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp @@ -74,6 +74,7 @@ #include "intel_gpu/primitives/eye.hpp" #include "intel_gpu/primitives/unique.hpp" #include "intel_gpu/primitives/kv_cache.hpp" +#include "intel_gpu/primitives/scaled_dot_product_attention.hpp" namespace cldnn { namespace ocl { @@ -172,6 +173,7 @@ REGISTER_OCL(gather_nonzero); REGISTER_OCL(eye); REGISTER_OCL(unique_count); REGISTER_OCL(unique_gather); +REGISTER_OCL(scaled_dot_product_attention); #undef REGISTER_OCL diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp new file mode 100644 index 00000000000000..d9303f058814a2 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp @@ -0,0 +1,135 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "primitive_base.hpp" +#include "scaled_dot_product_attention_inst.h" +#include "sdpa/sdpa_kernel_selector.h" +#include "sdpa/sdpa_kernel_base.h" + +namespace cldnn { +namespace ocl { +struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl { + using parent = typed_primitive_impl_ocl; + using parent::parent; + using kernel_selector_t = kernel_selector::sdpa_kernel_selector; + using kernel_params_t = kernel_selector::sdpa_params; + + DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::scaled_dot_product_attention_impl) + + std::unique_ptr clone() const override { + return make_unique(*this); + } + + void load(BinaryInputBuffer& ib) override { + parent::load(ib); + if (is_dynamic()) { + auto& kernel_selector = kernel_selector_t::Instance(); + auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName); + kernel_impl->GetUpdateDispatchDataFunc(_kernel_data); + } + } + + static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) { + kernel_selector::sdpa_configuration config; + + auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector& order) { + auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank()); + for (size_t i = 0; i < order.size(); i++) { + transposed_pshape[i] = pshape[order[i]]; + } + return transposed_pshape; + }; + + const auto& prim = impl_param.typed_desc(); + const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), prim->input_q_transpose_order); + const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), prim->input_k_transpose_order); + const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), prim->input_v_transpose_order); + + OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal"); + for (size_t i = 0; i < 
query_shape.size(); ++i) { + if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) { + if (query_shape[i].get_length() > key_shape[i].get_length()) { + config.broadcast_axis = prim->input_k_transpose_order[i]; + config.group_size = query_shape[i].get_length() / key_shape[i].get_length(); + } + } + } + + if (query_shape[query_shape.size() - 1].is_static()) + config.head_size = query_shape[query_shape.size() - 1].get_length(); + + config.is_causal = prim->is_causal; + + return config; + } + + static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) { + auto params = get_default_params(impl_param, is_dynamic); + + const auto inputs_num = impl_param.input_layouts.size(); + params.inputs.resize(inputs_num); + for (size_t i = 0; i < inputs_num; i++) { + params.inputs[i] = convert_data_tensor(impl_param.get_input_layout(i)); + } + + params.conf = get_sdpa_configuration(impl_param); + + const auto& prim = impl_param.typed_desc(); + params.input0_order = prim->input_q_transpose_order; + params.input1_order = prim->input_k_transpose_order; + params.input2_order = prim->input_v_transpose_order; + params.output_order = prim->output_transpose_order; + + params.set_dynamic_shape_offsets(); + + return params; + } + + static std::unique_ptr create(const typed_program_node& arg, const kernel_impl_params& impl_param) { + auto sdpa_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic()); + auto& sdpa_kernel_selector = kernel_selector_t::Instance(); + auto kd = sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params); + + return cldnn::make_unique(kd); + } + + void update_dispatch_data(const kernel_impl_params& impl_param) override { + auto kernel_params = get_kernel_params(impl_param, true); + (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data); + } +}; + +namespace detail { + +attach_scaled_dot_product_attention_impl::attach_scaled_dot_product_attention_impl() { + using sdpa_prim = scaled_dot_product_attention; + + auto types = { + data_types::f32, + data_types::f16, + }; + + auto formats = { + format::bfyx, + }; + + implementation_map::add(impl_types::ocl, + shape_types::static_shape, + scaled_dot_product_attention_impl::create, + types, + formats); + + implementation_map::add(impl_types::ocl, + shape_types::dynamic_shape, + scaled_dot_product_attention_impl::create, + types, + formats); +} + +} // namespace detail +} // namespace ocl +} // namespace cldnn + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::scaled_dot_product_attention_impl) +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::scaled_dot_product_attention) diff --git a/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h new file mode 100644 index 00000000000000..cecb2a0f609550 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "intel_gpu/primitives/scaled_dot_product_attention.hpp" +#include "primitive_inst.h" + +#include + +namespace cldnn { + +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input(size_t index = 0) const { return get_dependency(index); } + std::vector get_shape_infer_dependencies() const override { return {}; } +}; +using 
scaled_dot_product_attention_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + +public: + template + static std::vector calc_output_layouts(scaled_dot_product_attention_node const& /*node*/, const kernel_impl_params& impl_param); + static layout calc_output_layout(scaled_dot_product_attention_node const& node, kernel_impl_params const& impl_param); + static std::string to_string(scaled_dot_product_attention_node const& node); + + typed_primitive_inst(network& network, scaled_dot_product_attention_node const& desc); +}; + +using scaled_dot_product_attention_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 6a71cbc8981587..5772470bad54f8 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -394,6 +394,7 @@ void primitive_inst::update_shape() { } if (has_runtime_deps) { + std::cout << "Runtime deps\n"; OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("update_shape_sync: " + id())); if (!dependencies_events.empty() && queue_type == QueueTypes::out_of_order) { _network.get_stream().wait_for_events(dependencies_events); @@ -1455,7 +1456,11 @@ event::ptr primitive_inst::execute(const std::vector& events) { { GPU_DEBUG_PROFILED_STAGE(instrumentation::pipeline_stage::inference); + auto time0 = std::chrono::high_resolution_clock::now(); auto ev = _impl->execute(dependencies, *this); + auto time1 = std::chrono::high_resolution_clock::now(); + auto time_res0 = std::chrono::duration_cast(time1 - time0).count(); + GPU_DEBUG_TRACE_DETAIL << "Enqueu time = " << time_res0 << "\n"; GPU_DEBUG_IF(!debug_config->dump_profiling_data.empty()) { get_network().get_stream().wait_for_events({ev}); diff --git a/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp new file mode 100644 index 00000000000000..42e5aeb9f1302e --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "scaled_dot_product_attention_inst.h" + +#include "primitive_type_base.h" +#include "intel_gpu/runtime/error_handler.hpp" +#include "json_object.h" +#include +#include + +#include "scaled_dot_product_attention_shape_inference.hpp" +#include "intel_gpu/op/sdpa.hpp" + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(scaled_dot_product_attention) + +layout scaled_dot_product_attention_inst::calc_output_layout(scaled_dot_product_attention_node const& /* node */, + kernel_impl_params const& impl_param) { + auto desc = impl_param.typed_desc(); + + return impl_param.get_input_layout(0); +} + +template +std::vector scaled_dot_product_attention_inst::calc_output_layouts(scaled_dot_product_attention_node const& /*node*/, + const kernel_impl_params& impl_param) { + auto prim = impl_param.typed_desc(); + auto input0_layout = impl_param.get_input_layout(0); + + auto default_out_dt = data_type_traits::is_floating_point(input0_layout.data_type) ? 
input0_layout.data_type : data_types::f32;
+    auto output_type = prim->output_data_types[0].value_or(default_out_dt);
+
+    if (impl_param.has_fused_primitives()) {
+        output_type = impl_param.get_output_element_type();
+    }
+
+    ov::intel_gpu::op::SDPA op;
+
+    std::vector<ShapeType> input_shapes;
+    for (size_t i = 0; i < impl_param.input_layouts.size(); i++) {
+        input_shapes.push_back(impl_param.get_input_layout(i).get<ShapeType>());
+    }
+
+    std::vector<ShapeType> output_shapes = ov::intel_gpu::op::shape_infer(&op,
+                                                                          input_shapes,
+                                                                          prim->input_q_transpose_order,
+                                                                          prim->input_k_transpose_order,
+                                                                          prim->input_v_transpose_order,
+                                                                          prim->output_transpose_order);
+
+    cldnn::format output_format = input0_layout.format;
+
+    return { layout{output_shapes[0], output_type, output_format, prim->output_paddings[0]} };
+}
+
+template std::vector<layout> scaled_dot_product_attention_inst::calc_output_layouts<ov::PartialShape>(scaled_dot_product_attention_node const& node,
+                                                                                                      const kernel_impl_params& impl_param);
+
+std::string scaled_dot_product_attention_inst::to_string(scaled_dot_product_attention_node const& node) {
+    auto desc = node.get_primitive();
+    auto node_info = node.desc_to_json();
+    auto& input = node.input();
+
+    std::stringstream primitive_description;
+
+    json_composite scaled_dot_product_attention_info;
+    scaled_dot_product_attention_info.add("input id", input.id());
+    scaled_dot_product_attention_info.add("is_causal", desc->is_causal);
+    scaled_dot_product_attention_info.add("has_attn_mask_input", desc->has_attn_mask_input);
+    scaled_dot_product_attention_info.add("has_scale_input", desc->has_scale_input);
+    scaled_dot_product_attention_info.add("input_q_transpose_order", desc->input_q_transpose_order);
+    scaled_dot_product_attention_info.add("input_k_transpose_order", desc->input_k_transpose_order);
+    scaled_dot_product_attention_info.add("input_v_transpose_order", desc->input_v_transpose_order);
+    scaled_dot_product_attention_info.add("output_transpose_order", desc->output_transpose_order);
+
+    node_info->add("scaled_dot_product_attention_info", scaled_dot_product_attention_info);
+    node_info->dump(primitive_description);
+
+    return primitive_description.str();
+}
+
+scaled_dot_product_attention_inst::typed_primitive_inst(network& network, scaled_dot_product_attention_node const& node)
+    : parent(network, node) {}
+
+} // namespace cldnn
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
new file mode 100644
index 00000000000000..f9a1d31bc434ee
--- /dev/null
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -0,0 +1,1167 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "include/batch_headers/fetch_data.cl"
+#include "include/batch_headers/common.cl"
+#include "include/batch_headers/sub_group_block_read.cl"
+#include "include/batch_headers/sub_group_block_write.cl"
+#include "include/batch_headers/sub_group_shuffle.cl"
+
+// query_input [batch, heads_num, q_len, head_size]
+// key_input [batch, kv_heads_num, kv_len, head_size]
+// value_input [batch, kv_heads_num, kv_len, head_size]
+// attn_mask [1, 1, q_len, kv_len]
+// output [batch, heads_num, q_len, head_size]
+// exp_sums [batch, heads_num, q_len, partition_idx]
+// max_logits [batch, heads_num, q_len, partition_idx]
+// tmp_out [batch, heads_num, q_len, partition_idx, head_size]
+
+
+inline uint FUNC(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) {
+#if INPUT0_SIMPLE
+    return
GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x); +#else +#if INPUT0_DIMS == 4 + return INPUT0_GET_INDEX_SAFE(b, f, y, x); +#elif INPUT0_DIMS == 5 + return INPUT0_GET_INDEX_SAFE(b, f, z, y, x); +#elif INPUT0_DIMS == 6 + return INPUT0_GET_INDEX_SAFE(b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported input 0 format +#endif +#endif +} + +inline uint FUNC(get_input0_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef INPUT0_DIMS_ORDER + return FUNC_CALL(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT0_DIMS_ORDER); +#else + return FUNC_CALL(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x); +#endif +} + +inline uint FUNC(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef DO_BROADCAST_KEY_VALUE + DO_BROADCAST_KEY_VALUE; +#endif +#if INPUT1_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x); +#else +#if INPUT1_DIMS == 4 + return INPUT1_GET_INDEX_SAFE(b, f, y, x); +#elif INPUT1_DIMS == 5 + return INPUT1_GET_INDEX_SAFE(b, f, z, y, x); +#elif INPUT1_DIMS == 6 + return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported input 1 format +#endif +#endif +} + +inline uint FUNC(get_input1_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef INPUT1_DIMS_ORDER + return FUNC_CALL(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT1_DIMS_ORDER); +#else + return FUNC_CALL(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x); +#endif +} + +inline uint FUNC(get_input2_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef DO_BROADCAST_KEY_VALUE + DO_BROADCAST_KEY_VALUE; +#endif +#if INPUT2_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x); +#else +#if INPUT2_DIMS == 4 + return INPUT2_GET_INDEX_SAFE(b, f, y, x); +#elif INPUT2_DIMS == 5 + return INPUT2_GET_INDEX_SAFE(b, f, z, y, x); +#elif INPUT2_DIMS == 6 + return INPUT2_GET_INDEX_SAFE(b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported input 1 format +#endif +#endif +} + +inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef INPUT2_DIMS_ORDER + return FUNC_CALL(get_input2_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT2_DIMS_ORDER); +#else + return FUNC_CALL(get_input2_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x); +#endif +} + +#define VALUE_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT2_TYPE, 1, ptr, offset) +#define SUBGROUPS_PER_WG (HEAD_SIZE / SUBGROUP_SIZE) + +#ifdef SDPA_STAGE_0 + +#if TARGET_SEQ_LEN_BLOCK_SIZE == 1 + +REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) +KERNEL(sdpa_opt)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* query_input, + const __global INPUT1_TYPE* key_input, + const __global INPUT2_TYPE* value_input, +#if HAS_ATTN_MASK_INPUT + const __global INPUT3_TYPE* attn_mask, +#endif +#if HAS_SCALE_INPUT + const __global INPUT4_TYPE* scale, +#endif + __global OUTPUT_TYPE* output, + __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, + __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, + __global OUTPUT_TYPE* tmp_out +) +{ + const uint batch_idx = get_global_id(0); + const uint b0_idx = batch_idx / NUM_HEADS; /* BATCH dim */ + const uint b1_idx = batch_idx % NUM_HEADS; /* HEADS_NUM dim */ + +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint target_seq_idx = (uint)get_global_id(1) * TARGET_SEQ_LEN_BLOCK_SIZE; +#else + const uint target_seq_idx = get_global_id(1); +#endif + const uint lid = get_local_id(2); + const uint 
head_size_idx = lid; + + const uint sgid = get_sub_group_id(); + const uint sglid = get_sub_group_local_id(); + + const uint partition_idx = get_group_id(2); + const uint num_of_partitions = get_num_groups(2); + const uint wi_num_per_partition = get_local_size(2); + + const uint start_partition_idx = partition_idx * SEQ_LEN_PARTITION_SIZE; + const uint partition_seq_len = + ((partition_idx + 1) < num_of_partitions) ? (SEQ_LEN_PARTITION_SIZE) + : (SOURCE_SEQ_LEN - partition_idx * SEQ_LEN_PARTITION_SIZE); + + // SLM for query inputs + __local INPUT0_TYPE query_local[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE]; + // SLM for intermediate QK results + __local OUTPUT_TYPE qk_local[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE]; + // SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WG + __local SOFTMAX_ACCUMULATOR_TYPE qk_max_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE]; + __local SOFTMAX_ACCUMULATOR_TYPE qk_sum_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE]; + + { + // Gemm1 and SoftMax calculation + + SOFTMAX_ACCUMULATOR_TYPE qk_max[TARGET_SEQ_LEN_BLOCK_SIZE] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; + for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { + qk_max[i] = SOFTMAX_ACCUMULATOR_VAL_MIN; + } + + { + // Gemm1 calculation +#if HAS_SCALE_INPUT + const OUTPUT_TYPE scale_val = *scale; +#else + const OUTPUT_TYPE scale_val = OUTPUT_VAL_ONE / sqrt(TO_OUTPUT_TYPE(HEAD_SIZE)); +#endif + { + // Query input loading to SLM + #define QUERY_STEP_LOCAL SUBGROUP_SIZE * SUBGROUPS_PER_WG + uint query_local_offset = sgid * SUBGROUP_SIZE + sglid; + +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif +#ifdef INPUT0_DIMS_ORDER + uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx, (sgid * SUBGROUP_SIZE)); + uint query_offset_next_seq = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx + 1, (sgid * SUBGROUP_SIZE)); + const uint query_pitch = query_offset_next_seq - query_offset; +#else + uint query_offset = INPUT0_GET_INDEX(b0_idx, b1_idx, target_seq_idx, (sgid * SUBGROUP_SIZE)); + const uint query_pitch = QUERY_STEP_LOCAL; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + #define QUERY_BLOCK_SIZE 1 + + INPUT0_TYPE val = BLOCK_READN(INPUT0_TYPE, QUERY_BLOCK_SIZE, query_input, query_offset); + + query_local[query_local_offset] = val; + query_local_offset += QUERY_STEP_LOCAL; + query_offset += query_pitch; + } + #undef QUERY_BLOCK_SIZE + #undef QUERY_STEP + + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Main Gemm1 calculation loop + // Each SG performs element-wise multiplications of Q[HEAD_SIZE]xK[HEAD_SIZE] values + // HEAD_SIZE / SUBGROUPS_PER_WG times in the loop and saves the result to the qk_local SLM buffer + for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE / SUBGROUP_SIZE)) { +#ifdef INPUT1_DIMS_ORDER + uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, 0); +#else + uint key_offset = INPUT1_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, 0); +#endif + + INPUT0_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {INPUT0_VAL_ZERO}; + + uint head_idx_index = 0; + #define KEY_BLOCK_SIZE 8 + for (; head_idx_index + (KEY_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * KEY_BLOCK_SIZE) { + #define 
KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, KEY_BLOCK_SIZE, ptr, offset); + #define KEY_BLOCK MAKE_VECTOR_TYPE(INPUT1_TYPE, KEY_BLOCK_SIZE) + #define QUERY_BLOCK MAKE_VECTOR_TYPE(INPUT0_TYPE, KEY_BLOCK_SIZE) + + KEY_BLOCK key_vals = KEY_BLOCK_READ(key_input, key_offset + head_idx_index); + + uint query_offset = head_idx_index + sglid; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + QUERY_BLOCK query_vals_reg; + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + query_vals_reg[i] = query_local[query_offset + i * SUBGROUP_SIZE]; + } + + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]); + } + + query_offset += HEAD_SIZE; + } + } + + #define KEY_BLOCK_SIZE 4 + for (; head_idx_index + (KEY_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * KEY_BLOCK_SIZE) { + #define KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, KEY_BLOCK_SIZE, ptr, offset); + #define KEY_BLOCK MAKE_VECTOR_TYPE(INPUT1_TYPE, KEY_BLOCK_SIZE) + #define QUERY_BLOCK MAKE_VECTOR_TYPE(INPUT0_TYPE, KEY_BLOCK_SIZE) + + KEY_BLOCK key_vals = KEY_BLOCK_READ(key_input, key_offset + head_idx_index); + + uint query_offset = head_idx_index + sglid; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + QUERY_BLOCK query_vals_reg; + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + query_vals_reg[i] = query_local[query_offset + i * SUBGROUP_SIZE]; + } + + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]); + } + + query_offset += HEAD_SIZE; + } + } + + #define KEY_BLOCK_SIZE 2 + for (; head_idx_index + (KEY_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * KEY_BLOCK_SIZE) { + #define KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, KEY_BLOCK_SIZE, ptr, offset); + #define KEY_BLOCK MAKE_VECTOR_TYPE(INPUT1_TYPE, KEY_BLOCK_SIZE) + #define QUERY_BLOCK MAKE_VECTOR_TYPE(INPUT0_TYPE, KEY_BLOCK_SIZE) + + KEY_BLOCK key_vals = KEY_BLOCK_READ(key_input, key_offset + head_idx_index); + + uint query_offset = head_idx_index + sglid; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + QUERY_BLOCK query_vals_reg; + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + query_vals_reg[i] = query_local[query_offset + i * SUBGROUP_SIZE]; + } + + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + acc[seq_idx] = mad(query_vals_reg[i], key_vals[i], acc[seq_idx]); + } + + query_offset += HEAD_SIZE; + } + } + + #define KEY_BLOCK_SIZE 1 + for (; head_idx_index + (KEY_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * KEY_BLOCK_SIZE) { + #define KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, KEY_BLOCK_SIZE, ptr, offset); + #define KEY_BLOCK MAKE_VECTOR_TYPE(INPUT1_TYPE, KEY_BLOCK_SIZE) + #define QUERY_BLOCK MAKE_VECTOR_TYPE(INPUT0_TYPE, KEY_BLOCK_SIZE) + + KEY_BLOCK key_vals = KEY_BLOCK_READ(key_input, key_offset + head_idx_index); + + uint query_offset = head_idx_index + sglid; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + QUERY_BLOCK query_vals_reg; + unroll_for(uint i = 0; i < KEY_BLOCK_SIZE; i++) { + query_vals_reg = query_local[query_offset + i * SUBGROUP_SIZE]; + } + + acc[seq_idx] = mad(query_vals_reg, key_vals, acc[seq_idx]); + query_offset += HEAD_SIZE; + } + } + + // Sum up all accumulators accross single SG and save result to SLM + unroll_for (uint seq_idx = 0; seq_idx < 
TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + acc[seq_idx] = sub_group_reduce_add(acc[seq_idx]); + qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len] = acc[seq_idx]; + } + } + + { + // Wait until all SG finishes their calculations and apply scale and attention mask to the results + barrier(CLK_LOCAL_MEM_FENCE); + + INPUT0_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + // Iterate over all values QK values in SLM and apply scale and attention mask + for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE)) { + // Read value from SLM and apply scale + qk_val[seq_idx] = qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len]; + qk_val[seq_idx] *= scale_val; + + // Apply attention mask +#if IS_CAUSAL + if (start_partition_idx + seq_len > target_seq_idx + seq_idx) + qk_val[seq_idx] += INPUT0_VAL_MIN; +#elif !IS_CAUSAL && HAS_ATTN_MASK_INPUT + const uint attn_mask_offset = INPUT3_GET_INDEX_SAFE(b0_idx, b1_idx, target_seq_idx + seq_idx, start_partition_idx + seq_len); + qk_val[seq_idx] += attn_mask[attn_mask_offset]; +#endif + + // Update qk_max value + qk_max[seq_idx] = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max[seq_idx], TO_SOFTMAX_ACCUMULATOR_TYPE(qk_val[seq_idx])); + + // Save modified qk value back to SLM + qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len] = qk_val[seq_idx]; + } + } + } + } // Gemm1 calculation end + + { + // SoftMax calculation +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + // Find the maximum value of qk in the subgroup + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + qk_max[seq_idx] = sub_group_reduce_max(qk_max[seq_idx]); + } + + // Find the maximum value of qk across all subgroups in the workgroup + if (sglid == 0) { + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + qk_max_vals[seq_idx * SUBGROUPS_PER_WG + sgid] = qk_max[seq_idx]; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + qk_max[seq_idx] = SOFTMAX_ACCUMULATOR_VAL_MIN; + + if (sglid < SUBGROUPS_PER_WG) + qk_max[seq_idx] = qk_max_vals[seq_idx * SUBGROUPS_PER_WG + sglid]; + + // Final maximum value of qk after reduction across all subgroups + qk_max[seq_idx] = sub_group_reduce_max(qk_max[seq_idx]); + } + + SOFTMAX_ACCUMULATOR_TYPE exp_sum[TARGET_SEQ_LEN_BLOCK_SIZE] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; + const uint qk_num_per_wi = CEIL_DIV(partition_seq_len, SUBGROUPS_PER_WG * SUBGROUP_SIZE); + for (uint qk_idx = 0; qk_idx < qk_num_per_wi; qk_idx++) { + const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + head_size_idx; + if (local_data_idx < partition_seq_len) { + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + local_data_idx]) - qk_max[seq_idx]); + qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + local_data_idx] = TO_OUTPUT_TYPE(qk_new); + + exp_sum[seq_idx] += qk_new; + } + } + } + + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + exp_sum[seq_idx] = sub_group_reduce_add(exp_sum[seq_idx]); + + if (sglid == 0) + qk_sum_vals[seq_idx * SUBGROUPS_PER_WG + sgid] = 
exp_sum[seq_idx]; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + exp_sum[seq_idx] = SOFTMAX_ACCUMULATOR_VAL_ZERO; + + if (sglid < SUBGROUPS_PER_WG) + exp_sum[seq_idx] = qk_sum_vals[seq_idx * SUBGROUPS_PER_WG + sglid]; + + // Find the final sum of all exp_sum[seq_idx] values in workgroup + exp_sum[seq_idx] = sub_group_reduce_add(exp_sum[seq_idx]); + } + + // const SOFTMAX_ACCUMULATOR_TYPE inv_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ONE / exp_sum[seq_idx]; + for (uint qk_idx = 0; qk_idx < qk_num_per_wi; qk_idx++) { + const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid; + if (local_data_idx < partition_seq_len) { + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + local_data_idx]) / exp_sum[seq_idx]; + qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + local_data_idx] = TO_OUTPUT_TYPE(qk_new); + } + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + { + // If the number of partitions is greater than 1, save exm_sums and max_logits to the temporary buffers + // Use single WI in the WG, since all the WIs have the same value + if (num_of_partitions > 1 && head_size_idx == 0) { + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + (seq_idx + target_seq_idx) * (num_of_partitions) + + partition_idx; + exp_sums[exp_sums_offset] = exp_sum[seq_idx]; + + const uint max_logits_offset = exp_sums_offset; + max_logits[max_logits_offset] = qk_max[seq_idx]; + } + } + } + } // SoftMax calculation end + } // Gemm1 + SoftMax calculations end + + { + // Gemm2 calculation + OUTPUT_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {OUTPUT_VAL_ZERO}; + +#ifdef INPUT2_DIMS_ORDER + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0); + uint value_offset_next_seq = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0); + const uint value_pitch = value_offset_next_seq - value_offset; +#else + const uint value_pitch = HEAD_SIZE; +#endif + + for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) { +#ifdef INPUT2_DIMS_ORDER + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#else + uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#endif + + OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + qk_val[seq_idx] = qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len * SUBGROUP_SIZE + sglid]; + } + + unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc[seq_idx]); + } + + value_offset += value_pitch; + } + } + + const uint seq_len_leftovers_start = (partition_seq_len / SUBGROUP_SIZE) * SUBGROUP_SIZE; + for (uint seq_len = seq_len_leftovers_start; seq_len < partition_seq_len; seq_len++) { +#ifdef INPUT2_DIMS_ORDER + const uint value_offset = 
FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, head_size_idx); +#else + const uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, head_size_idx); +#endif + + OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + qk_val[seq_idx] = qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len]; + } + + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); + + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + acc[seq_idx] = mad(qk_val[seq_idx], value_val, acc[seq_idx]); + } + } + + // If the number of partitions is greater than 1, save results to the temporary buffer; + // otherwise, save results directly to the main output. + if (num_of_partitions > 1) { +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + // Data layout of tmp_output buf: [batch, heads_num, q_len, partition_idx, head_size] + const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + (target_seq_idx + seq_idx) * (num_of_partitions * HEAD_SIZE) + + partition_idx * (HEAD_SIZE) + + head_size_idx; + tmp_out[tmp_out_offset] = acc[seq_idx]; + } + } else { +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + const uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx + seq_idx, head_size_idx); + + output[output_offset] = acc[seq_idx]; + } + } + } // Gemm2 calculation end +} + +#else + +REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) +KERNEL(sdpa_opt)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* query_input, + const __global INPUT1_TYPE* key_input, + const __global INPUT2_TYPE* value_input, +#if HAS_ATTN_MASK_INPUT + const __global INPUT3_TYPE* attn_mask, +#endif +#if HAS_SCALE_INPUT + const __global INPUT4_TYPE* scale, +#endif + __global OUTPUT_TYPE* output, + __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, + __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, + __global OUTPUT_TYPE* tmp_out +) +{ + const uint batch_idx = get_global_id(0); + const uint b0_idx = batch_idx / NUM_HEADS; /* BATCH dim */ + const uint b1_idx = batch_idx % NUM_HEADS; /* HEADS_NUM dim */ + +#if TARGET_SEQ_LEN_BLOCK_SIZE != 1 && TARGET_SEQ_LEN_BLOCK_SIZE != 16 + #error TARGET_SEQ_LEN_BLOCK_SIZE unexpected size +#endif + +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint target_seq_idx = (uint)get_global_id(1) * TARGET_SEQ_LEN_BLOCK_SIZE; +#else + const uint target_seq_idx = get_global_id(1); +#endif + const uint lid = get_local_id(2); + const uint head_size_idx = lid; + + const uint sgid = get_sub_group_id(); + const uint sglid = get_sub_group_local_id(); + + const uint partition_idx = get_group_id(2); + const uint num_of_partitions = get_num_groups(2); + const uint wi_num_per_partition = get_local_size(2); + + const uint start_partition_idx = partition_idx * SEQ_LEN_PARTITION_SIZE; + const uint partition_seq_len = + ((partition_idx + 1) < num_of_partitions) ? 
(SEQ_LEN_PARTITION_SIZE) + : (SOURCE_SEQ_LEN - partition_idx * SEQ_LEN_PARTITION_SIZE); + + const uint target_seq_len_bs = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); + + // SLM for query inputs + __local INPUT0_TYPE query_local[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE]; + // SLM for intermediate QK results + __local OUTPUT_TYPE qk_local[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE]; + // SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WG + __local SOFTMAX_ACCUMULATOR_TYPE qk_max_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE]; + __local SOFTMAX_ACCUMULATOR_TYPE qk_sum_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE]; + + { + // Gemm1 and SoftMax calculation + + SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN; + + { + // Gemm1 calculation +#if HAS_SCALE_INPUT + const OUTPUT_TYPE scale_val = *scale; +#else + const OUTPUT_TYPE scale_val = OUTPUT_VAL_ONE / sqrt(TO_OUTPUT_TYPE(HEAD_SIZE)); +#endif + { + // Load Query input to SLM and transpose it +#ifdef INPUT0_DIMS_ORDER + uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx, (sgid * SUBGROUP_SIZE)); + uint query_offset_next_seq = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx + 1, (sgid * SUBGROUP_SIZE)); + const uint query_pitch = query_offset_next_seq - query_offset; +#else + uint query_offset = INPUT0_GET_INDEX(b0_idx, b1_idx, target_seq_idx, (sgid * SUBGROUP_SIZE)); + const uint query_pitch = SUBGROUP_SIZE * SUBGROUPS_PER_WG; +#endif + uint query_local_offset = (sgid * SUBGROUP_SIZE + sglid) * TARGET_SEQ_LEN_BLOCK_SIZE; + if (target_seq_len_bs != TARGET_SEQ_LEN_BLOCK_SIZE) { + for (uint seq_idx = 0; seq_idx < target_seq_len_bs; seq_idx++) { + INPUT0_TYPE val = BLOCK_READN(INPUT0_TYPE, 1, query_input, query_offset); + + query_local[query_local_offset] = val; + query_offset += query_pitch; + query_local_offset++; + } + } else { + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + INPUT0_TYPE val = BLOCK_READN(INPUT0_TYPE, 1, query_input, query_offset); + + query_local[query_local_offset] = val; + query_offset += query_pitch; + query_local_offset++; + } + } + } + + { + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Main Gemm1 calculation loop + uint seq_len = sgid * TARGET_SEQ_LEN_BLOCK_SIZE; + for (; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG * SUBGROUP_SIZE) { +#ifdef INPUT1_DIMS_ORDER + uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, 0); + uint key_offset_next_seq = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len + 1, 0); + const uint key_pitch = key_offset_next_seq - key_offset; +#else + uint key_offset = INPUT1_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, 0); + const uint key_pitch = HEAD_SIZE; +#endif + + INPUT0_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {INPUT0_VAL_ZERO}; + + for (uint head_idx_index = 0; head_idx_index < HEAD_SIZE; head_idx_index += SUBGROUP_SIZE) { + #define KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, 1, ptr, offset); + #define QUERY_VEC MAKE_VECTOR_TYPE(INPUT1_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) + + QUERY_VEC queries_vec; + uint query_local_offset = (head_idx_index * TARGET_SEQ_LEN_BLOCK_SIZE) + sglid; + unroll_for (uint q_row_idx = 0; q_row_idx < TARGET_SEQ_LEN_BLOCK_SIZE; q_row_idx++) { + queries_vec[q_row_idx] = 
query_local[query_local_offset]; + query_local_offset += TARGET_SEQ_LEN_BLOCK_SIZE; + } + + unroll_for (uint key_row_idx = 0; key_row_idx < TARGET_SEQ_LEN_BLOCK_SIZE; key_row_idx++) { + INPUT1_TYPE key_vals = KEY_BLOCK_READ(key_input, key_offset + key_row_idx * key_pitch + head_idx_index); + + unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { + acc[key_row_idx] = mad(sub_group_broadcast(key_vals, i), queries_vec[i], acc[key_row_idx]); + } + } + } + + { +#if !IS_CAUSAL && HAS_ATTN_MASK_INPUT + const uint attn_mask_offset = INPUT3_GET_INDEX_SAFE(b0_idx, b1_idx, target_seq_idx + sglid, start_partition_idx + seq_len); + MAKE_VECTOR_TYPE(INPUT3_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) attn_mask_vec = INPUT3_VAL_MIN; + for (uint i = 0; i < min(partition_seq_len - seq_len, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); i++) { + attn_mask_vec[i] = attn_mask[attn_mask_offset + i]; + } +#endif + unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { + acc[i] *= scale_val; +#if IS_CAUSAL + if (start_partition_idx + seq_len + i > target_seq_idx + sglid) + acc[i] += INPUT0_VAL_MIN; +#elif !IS_CAUSAL && HAS_ATTN_MASK_INPUT + acc[i] += attn_mask_vec[i]; +#endif +#if INPUT0_TYPE_SIZE == 2 + /* Adding this clamp improves performance for some reason */ + acc[i] = SOFTMAX_ACCUMULATOR_MIN_FUNC(SOFTMAX_ACCUMULATOR_MAX_FUNC(acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX); +#endif + if (seq_len + i >= partition_seq_len) { + acc[i] = INPUT0_VAL_MIN; + } + + qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(acc[i])); + qk_local[sglid * SEQ_LEN_PARTITION_SIZE + seq_len + i] = acc[i]; + } + } + } + } // Gemm1 calculation end + + { + // Save QK max to SLM + qk_max_vals[sglid * SUBGROUPS_PER_WG + sgid] = qk_max; + } + + { + // SoftMax calculation +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = target_seq_len_bs; +#else + const uint seq_idx_end = 1; +#endif + #define QK_MAX_NUMS_PER_SG CEIL_DIV(TARGET_SEQ_LEN_BLOCK_SIZE, SUBGROUPS_PER_WG) + #if (TARGET_SEQ_LEN_BLOCK_SIZE % SUBGROUPS_PER_WG != 0) + /* /* If TARGET_SEQ_LEN_BLOCK_SIZE is not divisible by SUBGROUPS_PER_WG, then some subgroups will have to handle more QK rows than others */ + #define QK_ITERS_END \ + (TARGET_SEQ_LEN_BLOCK_SIZE / SUBGROUPS_PER_WG + (sgid < TARGET_SEQ_LEN_BLOCK_SIZE % SUBGROUPS_PER_WG ? 
1 : 0)) + #else + #define QK_ITERS_END QK_MAX_NUMS_PER_SG + #endif + + OUTPUT_TYPE qk_max[QK_MAX_NUMS_PER_SG]; + for (uint i = 0; i < QK_MAX_NUMS_PER_SG; i++) + qk_max[i] = SOFTMAX_ACCUMULATOR_VAL_MIN; + + barrier(CLK_LOCAL_MEM_FENCE); + + if (sglid < SUBGROUPS_PER_WG) + for (uint i = 0; i < QK_ITERS_END; i++) + qk_max[i] = qk_max_vals[(i * SUBGROUPS_PER_WG * SUBGROUPS_PER_WG) + sgid * SUBGROUPS_PER_WG + sglid]; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + for (uint i = 0; i < QK_ITERS_END; i++) { + qk_max[i] = sub_group_reduce_max(qk_max[i]); + } + + SOFTMAX_ACCUMULATOR_TYPE exp_sum[QK_MAX_NUMS_PER_SG]; + for (uint i = 0; i < QK_MAX_NUMS_PER_SG; i++) + exp_sum[i] = SOFTMAX_ACCUMULATOR_VAL_ZERO; + + for (uint i = 0; i < QK_ITERS_END; i++) { + // TODO: Try full loop, with ternary operator + for (uint qk_idx = sglid; qk_idx < partition_seq_len; qk_idx += SUBGROUP_SIZE) { + const uint qk_offset = i * SUBGROUPS_PER_WG * SEQ_LEN_PARTITION_SIZE + sgid * SEQ_LEN_PARTITION_SIZE + qk_idx; + SOFTMAX_ACCUMULATOR_TYPE qk_val = qk_local[qk_offset]; + SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(qk_val) - qk_max[i]); + qk_local[qk_offset] = qk_new; + exp_sum[i] += qk_new; + } + } + + for (uint i = 0; i < QK_ITERS_END; i++) { + exp_sum[i] = sub_group_reduce_add(exp_sum[i]); + } + + for (uint i = 0; i < QK_ITERS_END; i++) { + for (uint qk_idx = sglid; qk_idx < partition_seq_len; qk_idx += SUBGROUP_SIZE) { + const uint qk_offset = i * SUBGROUPS_PER_WG * SEQ_LEN_PARTITION_SIZE + sgid * SEQ_LEN_PARTITION_SIZE + qk_idx; + SOFTMAX_ACCUMULATOR_TYPE qk_val = TO_SOFTMAX_ACCUMULATOR_TYPE(qk_local[qk_offset]); + SOFTMAX_ACCUMULATOR_TYPE qk_new = qk_val / exp_sum[i]; + qk_local[qk_offset] = qk_new; + } + } + + { + // If the number of partitions is greater than 1, save exm_sums and max_logits to the temporary buffers + // Use single WI in the WG, since all the WIs have the same value + if (num_of_partitions > 1 && sglid == 0) { + for (uint i = 0; i < QK_MAX_NUMS_PER_SG; i++) { + if (target_seq_idx + sgid + (i * SUBGROUPS_PER_WG) < TARGET_SEQ_LEN) { + const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + (target_seq_idx + sgid + (i * SUBGROUPS_PER_WG)) * (num_of_partitions) + + partition_idx; + exp_sums[exp_sums_offset] = exp_sum[i]; + + const uint max_logits_offset = exp_sums_offset; + max_logits[max_logits_offset] = qk_max[i]; + } + } + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + } // SoftMax calculation end + } // Gemm1 + SoftMax calculations end + + const uint seq_len_leftovers_start = (partition_seq_len / SUBGROUP_SIZE) * SUBGROUP_SIZE; + if (seq_len_leftovers_start != partition_seq_len) { + // Gemm2 calculation + OUTPUT_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {OUTPUT_VAL_ZERO}; + +#ifdef INPUT2_DIMS_ORDER + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0); + uint value_offset_next_seq = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0); + const uint value_pitch = value_offset_next_seq - value_offset; +#else + const uint value_pitch = HEAD_SIZE; +#endif + + for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) { +#ifdef INPUT2_DIMS_ORDER + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#else + uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + (seq_len * 
SUBGROUP_SIZE), head_size_idx); +#endif + + OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + qk_val[seq_idx] = qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len * SUBGROUP_SIZE + sglid]; + } + + unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc[seq_idx]); + } + + value_offset += value_pitch; + } + } + + + /* The handling of leftovers causes significantly worse assembly code generation for the above main calculation loop. + Therefore, there are two independent branches for the calculation of QK*V matrices: + one with leftovers handling (when seq_len_leftovers_start != partition_seq_len) and one without. */ + { + OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; + uint qk_offset = min(seq_len_leftovers_start + sglid, partition_seq_len); + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + qk_val[seq_idx] = qk_local[qk_offset]; + qk_offset += SEQ_LEN_PARTITION_SIZE; + } + + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len_leftovers_start, head_size_idx); + + for (uint seq_len_idx = 0; seq_len_idx < partition_seq_len - seq_len_leftovers_start; seq_len_idx++) { + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); + + for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], seq_len_idx), value_val, acc[seq_idx]); + } + + value_offset += value_pitch; + } + } + + // If the number of partitions is greater than 1, save results to the temporary buffer; + // otherwise, save results directly to the main output. 
+ if (num_of_partitions > 1) { +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + // Data layout of tmp_output buf: [batch, heads_num, q_len, partition_idx, head_size] + const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + (target_seq_idx + seq_idx) * (num_of_partitions * HEAD_SIZE) + + partition_idx * (HEAD_SIZE) + + head_size_idx; + + tmp_out[tmp_out_offset] = acc[seq_idx]; + } + } else { +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + const uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx + seq_idx, head_size_idx); + + output[output_offset] = acc[seq_idx]; + } + } + } else { + // Gemm2 calculation + OUTPUT_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {OUTPUT_VAL_ZERO}; + +#ifdef INPUT2_DIMS_ORDER + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0); + uint value_offset_next_seq = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0); + const uint value_pitch = value_offset_next_seq - value_offset; +#else + const uint value_pitch = HEAD_SIZE; +#endif + + for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) { +#ifdef INPUT2_DIMS_ORDER + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#else + uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#endif + + OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + qk_val[seq_idx] = qk_local[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len * SUBGROUP_SIZE + sglid]; + } + + unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); + unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc[seq_idx]); + } + + value_offset += value_pitch; + } + } + + // If the number of partitions is greater than 1, save results to the temporary buffer; + // otherwise, save results directly to the main output. 
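    /* How the per-partition partial results are later combined (a simplified per-element sketch of what the
       sdpa_opt_finalization_stage kernel below computes from exp_sums, max_logits and tmp_out):
           global_max = max over partitions p of max_logit[p]
           adj_sum[p] = exp_sum[p] * exp(max_logit[p] - global_max)
           global_sum = sum over p of adj_sum[p]
           output     = sum over p of tmp_out[p] * adj_sum[p] / global_sum
       i.e. each partition's locally-normalized output (tmp_out) is re-weighted by its exponent sum rescaled
       to the global maximum, then divided by the global sum. */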
+ if (num_of_partitions > 1) { +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + // Data layout of tmp_output buf: [batch, heads_num, q_len, partition_idx, head_size] + const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + (target_seq_idx + seq_idx) * (num_of_partitions * HEAD_SIZE) + + partition_idx * (HEAD_SIZE) + + head_size_idx; + tmp_out[tmp_out_offset] = acc[seq_idx]; + } + } else { +#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#else + const uint seq_idx_end = 1; +#endif + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + const uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx + seq_idx, head_size_idx); + + output[output_offset] = acc[seq_idx]; + } + } + } // Gemm2 calculation end +} + +#endif // TARGET_SEQ_LEN_BLOCK_SIZE != 1 + +#endif // SDPA_STAGE_0 + +#ifdef SDPA_STAGE_1 + +// MTL iGPU faces high register pressure issue with a higher number of REG_VERSION_MAX_VALUES_PER_WI. +// To mitigate this, add an additional level of SDPA results processing +// with lower register pressure (REG_VERSION_MAX_VALUES_PER_WI_LOWER). + +#if SOFTMAX_ACCUMULATOR_TYPE_SIZE == 4 +#define REG_VERSION_MAX_VALUES_PER_WI 24 +#define REG_VERSION_MAX_VALUES_PER_WI_LOWER 8 +#elif SOFTMAX_ACCUMULATOR_TYPE_SIZE == 2 +#define REG_VERSION_MAX_VALUES_PER_WI 48 +#define REG_VERSION_MAX_VALUES_PER_WI_LOWER 16 +#else +#error Unexpected SOFTMAX_ACCUMULATOR data type size +#endif + +// query_input [batch, heads_num, q_len, head_size] +// key_input [batch, kv_heads_num, kv_len, head_size] +// value_input [batch, kv_heads_num, kv_len, head_size] +// attn_mask [1, 1, q_len, kv_len] +// output [batch, heads_num, q_len, head_size] +// exp_sums [batch, heads_num, q_len, partition_idx] +// max_logits [batch, heads_num, q_len, partition_idx] +// tmp_out [batch, heads_num, q_len, partition_idx, head_size] + +REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) +KERNEL(sdpa_opt_finalization_stage)( + OPTIONAL_SHAPE_INFO_ARG + __global OUTPUT_TYPE* output, + const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, + const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, + const __global OUTPUT_TYPE* tmp_out, + const uint num_of_partitions) { + const uint batch_idx = get_global_id(0); + const uint b0_idx = batch_idx / NUM_HEADS; + const uint b1_idx = batch_idx % NUM_HEADS; + const uint target_seq_idx = get_global_id(1); + const uint sglid = get_sub_group_local_id(); + + if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI_LOWER) { + /* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI_LOWER(8/16) = 32768/65536 tokens */ + SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; + SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; + SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; + + const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE); + for (uint i = 0; i < iters_num; i++) { + const uint partition_idx = i * SUBGROUP_SIZE + sglid; + const uint 
exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + target_seq_idx * (num_of_partitions) + + partition_idx; + const uint max_logit_offset = exp_sums_offset; + + if (partition_idx < num_of_partitions) { + exp_sum[i] = exp_sums[exp_sums_offset]; + max_logit[i] = max_logits[max_logit_offset]; + local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]); + } + } + + SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit); + + // Update exp_sum with respect to the global maximum + for (uint i = 0; i < iters_num; i++) { + const uint partition_idx = i * SUBGROUP_SIZE + sglid; + if (partition_idx < num_of_partitions) { + exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max); + local_exp_sum += exp_sum[i]; + } + } + + SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum); + + for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) { + SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; + for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { + const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + target_seq_idx * (num_of_partitions * HEAD_SIZE) + + partition_idx * (HEAD_SIZE) + + (head_size_idx * SUBGROUP_SIZE + sglid); + OUTPUT_TYPE out_val = tmp_out[tmp_out_offset]; + acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * + TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) / + TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum); + } + const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + + target_seq_idx * (HEAD_SIZE) + + (head_size_idx * SUBGROUP_SIZE + sglid); + + output[out_offset] = TO_OUTPUT_TYPE(acc); + } + } else if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI) { + /* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI(24/48) = 98304/196608 tokens */ + SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; + SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; + SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; + + const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE); + for (uint i = 0; i < iters_num; i++) { + const uint partition_idx = i * SUBGROUP_SIZE + sglid; + const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + target_seq_idx * (num_of_partitions) + + partition_idx; + const uint max_logit_offset = exp_sums_offset; + + if (partition_idx < num_of_partitions) { + exp_sum[i] = exp_sums[exp_sums_offset]; + max_logit[i] = max_logits[max_logit_offset]; + local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]); + } + } + + SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit); + + // Update exp_sum with respect to the global maximum + for (uint i = 0; i < iters_num; i++) { + const uint partition_idx = i * SUBGROUP_SIZE + sglid; + if (partition_idx < num_of_partitions) { + exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max); + local_exp_sum += exp_sum[i]; + 
} + } + + SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum); + + for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) { + SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; + for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { + const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + target_seq_idx * (num_of_partitions * HEAD_SIZE) + + partition_idx * (HEAD_SIZE) + + (head_size_idx * SUBGROUP_SIZE + sglid); + OUTPUT_TYPE out_val = tmp_out[tmp_out_offset]; + acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * + TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) / + TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum); + } + const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + + target_seq_idx * (HEAD_SIZE) + + (head_size_idx * SUBGROUP_SIZE + sglid); + + output[out_offset] = TO_OUTPUT_TYPE(acc); + } + } else { + /* Global memory kernel version, can handle any number of tokens, but could be very slow. */ + SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; + + const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE); + for (uint i = 0; i < iters_num; i++) { + const uint partition_idx = i * SUBGROUP_SIZE + sglid; + const uint max_logit_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + target_seq_idx * (num_of_partitions) + + partition_idx; + + + if (partition_idx < num_of_partitions) { + local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits[max_logit_offset]); + } + } + + SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit); + + // Calculate global sum + for (uint i = 0; i < iters_num; i++) { + const uint partition_idx = i * SUBGROUP_SIZE + sglid; + const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + target_seq_idx * (num_of_partitions) + + partition_idx; + const uint max_logit_offset = exp_sums_offset; + + if (partition_idx < num_of_partitions) { + local_exp_sum += exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max); + } + } + + SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum); + + for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) { + SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; + for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { + const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + + target_seq_idx * (num_of_partitions * HEAD_SIZE) + + partition_idx * (HEAD_SIZE) + + (head_size_idx * SUBGROUP_SIZE + sglid); + + const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + target_seq_idx * (num_of_partitions) + + partition_idx; + const uint max_logit_offset = exp_sums_offset; + + SOFTMAX_ACCUMULATOR_TYPE new_exp_sum = exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max); + + OUTPUT_TYPE out_val = tmp_out[tmp_out_offset]; + acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) 
* new_exp_sum / TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum); + } + + const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + + target_seq_idx * (HEAD_SIZE) + + (head_size_idx * SUBGROUP_SIZE + sglid); + + output[out_offset] = TO_OUTPUT_TYPE(acc); + } + } +} + +#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl new file mode 100644 index 00000000000000..cd289be026e7e3 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl @@ -0,0 +1,212 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "include/batch_headers/fetch_data.cl" + +// query_input [batch, heads_num, q_len, head_size] +// key_input [batch, kv_heads_num, kv_len, head_size] +// value_input [batch, kv_heads_num, kv_len, head_size] +// attn_mask [1, 1, q_len, kv_len] +// output [batch, heads_num, q_len, head_size] +// tmp_buf [batch, heads_num, q_len, kv_len] + +// When handling long sequences and executing in FP16, accuracy can significantly vary based on two factors: +// 1) The order of scale application (which can be controlled using the APPLY_SCALE_TO_QUERY macro) +// 2) The type of SoftMax accumulator + +inline uint FUNC(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if INPUT0_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x); +#else +#if INPUT0_DIMS == 4 + return INPUT0_GET_INDEX_SAFE(b, f, y, x); +#elif INPUT0_DIMS == 5 + return INPUT0_GET_INDEX_SAFE(b, f, z, y, x); +#elif INPUT0_DIMS == 6 + return INPUT0_GET_INDEX_SAFE(b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported input 0 format +#endif +#endif +} + +inline uint FUNC(get_input0_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef INPUT0_DIMS_ORDER + return FUNC_CALL(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT0_DIMS_ORDER); +#else + return FUNC_CALL(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x); +#endif +} + +inline uint FUNC(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef DO_BROADCAST_KEY_VALUE + DO_BROADCAST_KEY_VALUE; +#endif +#if INPUT1_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x); +#else +#if INPUT1_DIMS == 4 + return INPUT1_GET_INDEX_SAFE(b, f, y, x); +#elif INPUT1_DIMS == 5 + return INPUT1_GET_INDEX_SAFE(b, f, z, y, x); +#elif INPUT1_DIMS == 6 + return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported input 1 format +#endif +#endif +} + +inline uint FUNC(get_input1_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef INPUT1_DIMS_ORDER + return FUNC_CALL(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT1_DIMS_ORDER); +#else + return FUNC_CALL(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x); +#endif +} + +inline uint FUNC(get_input2_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef DO_BROADCAST_KEY_VALUE + DO_BROADCAST_KEY_VALUE; +#endif +#if INPUT2_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, y, x); +#else +#if INPUT2_DIMS == 4 + return INPUT2_GET_INDEX_SAFE(b, f, y, x); +#elif INPUT2_DIMS == 5 + return INPUT2_GET_INDEX_SAFE(b, f, z, y, x); +#elif INPUT2_DIMS == 6 + return INPUT2_GET_INDEX_SAFE(b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported input 1 format +#endif 
+#endif +} + +inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#ifdef INPUT2_DIMS_ORDER + return FUNC_CALL(get_input2_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT2_DIMS_ORDER); +#else + return FUNC_CALL(get_input2_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x); +#endif +} + +#define APPLY_SCALE_TO_QUERY 1 + +KERNEL(sdpa_ref)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* query_input, + const __global INPUT1_TYPE* key_input, + const __global INPUT2_TYPE* value_input, +#if HAS_ATTN_MASK_INPUT + const __global INPUT3_TYPE* attn_mask, +#endif +#if HAS_SCALE_INPUT + const __global INPUT4_TYPE* scale, +#endif + __global OUTPUT_TYPE* output, + __global OUTPUT_TYPE* tmp_buf +) +{ + const uint batch_idx = get_global_id(0); + const uint b0 = batch_idx / NUM_HEADS; /* BATCH dim */ + const uint b1 = batch_idx % NUM_HEADS; /* HEADS_NUM dim */ + const uint target_seq_idx = get_global_id(1); + const uint head_size_idx = get_global_id(2); + +#if HAS_SCALE_INPUT + const OUTPUT_TYPE scale_val = *scale; +#else + const OUTPUT_TYPE scale_val = OUTPUT_VAL_ONE / sqrt(TO_OUTPUT_TYPE(INPUT1_SIZE_X)); +#endif + + // Process 1*seq_len elements (Gemm1 + SoftMax) using a single work item, saving results to tmp_buf and + // reusing them between all work items within a single workgroup for Gemm2 calculations. + if (get_local_id(2) == 0) { + for (uint s = 0; s < SOURCE_SEQ_LEN /* seq_len */; s++) { + OUTPUT_TYPE acc = 0; + for (uint h = 0; h < HEAD_SIZE /* head_size */; h++) { + uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, target_seq_idx, h); + uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, s, h); + +#if APPLY_SCALE_TO_QUERY + INPUT0_TYPE q_val = query_input[query_offset] * scale_val; +#else + INPUT0_TYPE q_val = query_input[query_offset]; +#endif + INPUT1_TYPE k_val = key_input[key_offset]; + acc += q_val * k_val; + } + +#if !APPLY_SCALE_TO_QUERY + acc *= scale_val; +#endif + + uint tmp_buf_offset = b0 * (NUM_HEADS * TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + b1 * (TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + target_seq_idx * (SOURCE_SEQ_LEN) + s; + tmp_buf[tmp_buf_offset] = acc; + } + + ACCUMULATOR_TYPE qk_max = ACCUMULATOR_VAL_MIN; + for (uint s = 0; s < SOURCE_SEQ_LEN /* seq_len */; s++) { + uint tmp_buf_offset = b0 * (NUM_HEADS * TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + b1 * (TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + target_seq_idx * (SOURCE_SEQ_LEN) + s; +#if IS_CAUSAL + OUTPUT_TYPE attn_mask_val = s > target_seq_idx ? 
OUTPUT_VAL_MIN : 0; +#elif !IS_CAUSAL && HAS_ATTN_MASK_INPUT + uint attn_mask_offset = INPUT3_GET_INDEX_SAFE(b0, b1, target_seq_idx, s); + OUTPUT_TYPE attn_mask_val = attn_mask[attn_mask_offset]; +#else + OUTPUT_TYPE attn_mask_val = OUTPUT_VAL_ZERO; +#endif + + OUTPUT_TYPE qk_val = tmp_buf[tmp_buf_offset] + attn_mask_val; + tmp_buf[tmp_buf_offset] = qk_val; + + qk_max = ACCUMULATOR_MAX_FUNC(qk_max, TO_ACCUMULATOR_TYPE(qk_val)); + } + + ACCUMULATOR_TYPE exp_sum = ACCUMULATOR_VAL_ZERO; + for (uint s = 0; s < SOURCE_SEQ_LEN /* seq_len */; s++) { + uint tmp_buf_offset = b0 * (NUM_HEADS * TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + b1 * (TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + target_seq_idx * (SOURCE_SEQ_LEN) + s; + + OUTPUT_TYPE qk_val = tmp_buf[tmp_buf_offset]; + ACCUMULATOR_TYPE val = native_exp(TO_ACCUMULATOR_TYPE(qk_val) - qk_max); + exp_sum += val; + + tmp_buf[tmp_buf_offset] = TO_OUTPUT_TYPE(val); + } + + const ACCUMULATOR_TYPE inv_sum = ACCUMULATOR_VAL_ONE / exp_sum; + for (uint s = 0; s < SOURCE_SEQ_LEN /* seq_len */; s++) { + uint tmp_buf_offset = b0 * (NUM_HEADS * TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + b1 * (TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + target_seq_idx * (SOURCE_SEQ_LEN) + s; + + OUTPUT_TYPE qk_val = tmp_buf[tmp_buf_offset]; + ACCUMULATOR_TYPE val = TO_ACCUMULATOR_TYPE(qk_val) * inv_sum; + tmp_buf[tmp_buf_offset] = TO_OUTPUT_TYPE(val); + } + } + + barrier(CLK_GLOBAL_MEM_FENCE); + + OUTPUT_TYPE acc = 0; + for (uint s = 0; s < SOURCE_SEQ_LEN /* seq_len */; s++) { + uint tmp_buf_offset = b0 * (NUM_HEADS * TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + b1 * (TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + + target_seq_idx * (SOURCE_SEQ_LEN) + s; + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, s, head_size_idx); + + acc += tmp_buf[tmp_buf_offset] * value_input[value_offset]; + } + + uint output_offset = OUTPUT_GET_INDEX(b0, b1, target_seq_idx, head_size_idx); + output[output_offset] = acc; +} diff --git a/src/plugins/intel_gpu/src/kernel_selector/common_types.h b/src/plugins/intel_gpu/src/kernel_selector/common_types.h index c2a4ef1653472d..768a0fc3c4f854 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/common_types.h +++ b/src/plugins/intel_gpu/src/kernel_selector/common_types.h @@ -59,6 +59,7 @@ enum class KernelType { DEPTH_TO_SPACE, BATCH_TO_SPACE, SHAPE_OF, + SDPA, SHUFFLE_CHANNELS, SLICE, STRIDED_SLICE, diff --git a/src/plugins/intel_gpu/src/kernel_selector/jitter.cpp b/src/plugins/intel_gpu/src/kernel_selector/jitter.cpp index 084ae71e42732c..fcd35d13a3639b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/jitter.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/jitter.cpp @@ -326,8 +326,8 @@ JitDefinitions DataTensorJitConstant::GetDefinitions() const { JitDefinitions baseDefinitions = TensorBaseTJitConstant::GetDefinitions(_tensor); JitDefinitions definitions{}; - DimensionAccessHelper dims(_tensor); - DimensionAccessHelper dims_padded(_tensor, true); + DimensionAccessHelperJit dims(_tensor); + DimensionAccessHelperJit dims_padded(_tensor, true); // shape_info layout // if only y has dynamic padding: // [dim_b, dim_f, dim_v, dim_u, dim_w, dim_z, dim_y, dim_x, pad_before_y, pad_after_y] diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernel_selector_utils.h b/src/plugins/intel_gpu/src/kernel_selector/kernel_selector_utils.h index e3e5f3dcc47a2d..2c8256b8551b89 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernel_selector_utils.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernel_selector_utils.h @@ -11,9 +11,9 @@ namespace 
kernel_selector { struct weight_bias_params; struct WeightsReorderParams; -struct DimensionAccessHelper { - explicit DimensionAccessHelper(const DataTensor& t, bool padded = false) { - std::vector dims = { +struct DimensionAccessHelperBase { + explicit DimensionAccessHelperBase(const DataTensor& t) { + dims = { t.Batch(), t.Feature(), t.U(), @@ -23,6 +23,23 @@ struct DimensionAccessHelper { t.Y(), t.X(), }; + } + + Tensor::Dim& x_dim() { return dims[7]; } + Tensor::Dim& y_dim() { return dims[6]; } + Tensor::Dim& z_dim() { return dims[5]; } + Tensor::Dim& w_dim() { return dims[4]; } + Tensor::Dim& v_dim() { return dims[3]; } + Tensor::Dim& u_dim() { return dims[2]; } + Tensor::Dim& f_dim() { return dims[1]; } + Tensor::Dim& b_dim() { return dims[0]; } + + std::vector dims; +}; + +struct DimensionAccessHelperJit : virtual DimensionAccessHelperBase { + explicit DimensionAccessHelperJit(const DataTensor& t, bool padded = false) + : DimensionAccessHelperBase(t) { size_t dyn_shape_offset = t.get_dynamic_shape_offset(); size_t dyn_pad_offset = dyn_shape_offset + DataTensor::max_rank(); for (auto d : dims) { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp index 2d878e4a9f28e1..ecb6be6f17020d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp @@ -22,7 +22,7 @@ size_t getOperationNumber(const arg_max_min_params& params) { std::string getOperationNumberString(const arg_max_min_params& params) { const auto& output = params.outputs[0]; - DimensionAccessHelper dims(output); + DimensionAccessHelperJit dims(output); switch (params.argMaxMinAxis) { case ArgMaxMinAxis::BATCH: return toVectorMulString({dims.x(), dims.y(), dims.z(), dims.f()}); case ArgMaxMinAxis::FEATURE: return toVectorMulString({dims.x(), dims.y(), dims.z(), dims.b()}); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_base.cpp index 07734e85b9dd4a..cdd8d7fc56e39e 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_base.cpp @@ -15,7 +15,7 @@ JitConstants FullyConnectedKernelBase::GetJitConstants(const fully_connected_par JitConstants jit = WeightBiasKernelBase::GetJitConstants(params); const auto& input = params.inputs[0]; if (input.is_dynamic()) { - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); jit.AddConstant(MakeJitConstant("INPUT0_ELEMENTS_COUNT", toVectorMulString({dims.x(), dims.y(), dims.z(), dims.w(), dims.f()}))); } else { const auto x_size = input.LogicalSize() / input.Batch().v; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp index b367e40308104d..0faec1c9d13696 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp @@ -135,10 +135,10 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons 
jit.Merge(MakeTypeJitConstants(params.inputs[0].GetDType(), "ACCUMULATOR")); if (params.has_dynamic_tensors()) { - DimensionAccessHelper dims0(params.inputs[0]); - DimensionAccessHelper dims1(params.inputs[1]); - DimensionAccessHelper dims0_padded(params.inputs[0], true); - DimensionAccessHelper dims1_padded(params.inputs[1], true); + DimensionAccessHelperJit dims0(params.inputs[0]); + DimensionAccessHelperJit dims1(params.inputs[1]); + DimensionAccessHelperJit dims0_padded(params.inputs[0], true); + DimensionAccessHelperJit dims1_padded(params.inputs[1], true); // Note: Actually currently this kernel is not being selected if it is shape agnostic impl && transposed inputs // Because we cannot get the original rank auto input0_dims = ConvTo8dims(params.input0_order); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/mvn/mvn_kernel_bfyx_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/mvn/mvn_kernel_bfyx_opt.cpp index 806bb90ba67b43..923bd98814a46f 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/mvn/mvn_kernel_bfyx_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/mvn/mvn_kernel_bfyx_opt.cpp @@ -80,7 +80,7 @@ JitConstants MVNKernelBfyxOpt::GetJitConstants(const mvn_params& params, MVNKern if (params.has_dynamic_tensors()) { const auto& input = params.inputs[0]; - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); std::string data_set_size; std::string data_set_count; if (params.mvnMode == MVNMode::WITHIN_CHANNELS) { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/count_nonzero_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/count_nonzero_kernel_ref.cpp index d3132e4357fa07..7e6c1397b988e4 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/count_nonzero_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/count_nonzero_kernel_ref.cpp @@ -76,7 +76,7 @@ KernelsData CountNonzeroKernelRef::GetKernelsData(const Params& params) const { auto cldnn_jit = MakeBaseParamsJitConstants(newParams); if (newParams.has_dynamic_tensors()) { const auto& input = newParams.inputs[0]; - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); const std::string total_data_size = toVectorMulString({dims.x(), dims.y(), dims.z(), dims.w(), dims.f(), dims.b()}); cldnn_jit.AddConstants({MakeJitConstant("DATA_SIZE", total_data_size)}); } else { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/gather_nonzero_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/gather_nonzero_kernel_ref.cpp index 0672566e0ed2ad..bac2237893bef3 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/gather_nonzero_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/non_zero/gather_nonzero_kernel_ref.cpp @@ -46,7 +46,7 @@ JitConstants GatherNonzeroKernelRef::GetJitConstants(const gather_nonzero_params jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size)); if (input.is_dynamic()) { - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); const std::string total_data_size = toVectorMulString({dims.x(), dims.y(), dims.z(), dims.w(), dims.f(), dims.b()}); jit.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", total_data_size)); } else { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/permute/permute_kernel_tile_8x8_4x4.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/permute/permute_kernel_tile_8x8_4x4.cpp index 
13eb399ef8ef4d..06ee5a2bc4b6ef 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/permute/permute_kernel_tile_8x8_4x4.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/permute/permute_kernel_tile_8x8_4x4.cpp @@ -107,7 +107,7 @@ static inline std::string GetTiledOutputOrder(const permute_params& params) { std::string out_z_str = ""; const auto& output = params.outputs[0]; if (params.has_dynamic_outputs()) { - DimensionAccessHelper dims(output); + DimensionAccessHelperJit dims(output); out_y_str = dims.y(); out_z_str = dims.z(); } else { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/reduce/reduce_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/reduce/reduce_kernel_base.cpp index 80e16939bab248..318daac3b5b30e 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/reduce/reduce_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/reduce/reduce_kernel_base.cpp @@ -30,7 +30,7 @@ JitConstants ReduceKernelBase::GetJitConstants(const reduce_params& params) cons const auto& output = params.outputs[0]; if (output.is_dynamic()) { - DimensionAccessHelper dims(output); + DimensionAccessHelperJit dims(output); jit.AddConstant(MakeJitConstant("COMPUTATIONAL_OPERATIONS_NUMBER", toVectorMulString({dims.x(), dims.y(), dims.z(), diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp index db5e8c6beb1588..15043ef2624053 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp @@ -30,7 +30,7 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, Dispatc if (params.has_dynamic_tensors()) { const auto& input = params.inputs[0]; - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); std::string data_size; switch (params.ov_input_rank) { case 1 : diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp new file mode 100644 index 00000000000000..a66618aa1f3f95 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp @@ -0,0 +1,126 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "sdpa_kernel_base.h" +#include "kernel_selector_utils.h" + +namespace kernel_selector { + +static std::string GetDimsOrder(const std::vector& order_idx) { + auto get_order_idx = [](std::vector order_idx, int64_t dim_idx) { + int loc = 0; + for (auto idx : order_idx) { + if (idx == dim_idx) + break; + loc += 1; + } + return loc; + }; + + std::string dims_order = ""; + if (order_idx.size() == 2) { + const std::vector dims2 = {"y", "x"}; + dims_order = "b,f,w,z," + + dims2[get_order_idx(order_idx, 0)] + "," + dims2[get_order_idx(order_idx, 1)]; + } else if (order_idx.size() == 3) { + const std::vector dims3 = {"f", "y", "x"}; + dims_order = "b," + dims3[get_order_idx(order_idx, 0)] + ",w,z," + + dims3[get_order_idx(order_idx, 1)] + "," + dims3[get_order_idx(order_idx, 2)]; + } else if (order_idx.size() == 4) { + const std::vector dims4 = {"b", "f", "y", "x"}; + dims_order = dims4[get_order_idx(order_idx, 0)] + "," + dims4[get_order_idx(order_idx, 1)] + ",w,z," + + dims4[get_order_idx(order_idx, 2)] + "," + dims4[get_order_idx(order_idx, 3)]; + } else if (order_idx.size() == 5) { + const 
std::vector dims5 = {"b", "f", "z", "y", "x"}; + dims_order = dims5[get_order_idx(order_idx, 0)] + "," + dims5[get_order_idx(order_idx, 1)] + ",w," + + dims5[get_order_idx(order_idx, 2)] + "," + dims5[get_order_idx(order_idx, 3)] + "," + + dims5[get_order_idx(order_idx, 4)]; + } else if (order_idx.size() == 6) { + const std::vector dims6 = {"b", "f", "w", "z", "y", "x"}; + dims_order = dims6[get_order_idx(order_idx, 0)] + "," + dims6[get_order_idx(order_idx, 1)] + "," + + dims6[get_order_idx(order_idx, 2)] + "," + dims6[get_order_idx(order_idx, 3)] + "," + + dims6[get_order_idx(order_idx, 4)] + "," + dims6[get_order_idx(order_idx, 5)]; + } else { + dims_order = "b,f,w,z,y,x"; + } + return dims_order; +} + +static std::string GetBroadcastInputStr(const size_t input_rank, const int64_t axes, const int64_t val) { + std::vector dims; + if (input_rank == 1) { + dims = {"x"}; + } else if (input_rank == 2) { + dims = {"y", "x"}; + } else if (input_rank == 3) { + dims = {"f", "y", "x"}; + } else if (input_rank == 4) { + dims = {"b", "f", "y", "x"}; + } else if (input_rank == 5) { + dims = {"b", "f", "z", "y", "x"}; + } else if (input_rank == 6) { + dims = {"b", "f", "w", "z", "y", "x"}; + } + return dims[axes] + " /= " + std::to_string(val) + ";"; +} + +JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const { + auto jit = MakeBaseParamsJitConstants(params); + + if (params.conf.broadcast_axis != -1) { + jit.AddConstant(MakeJitConstant("DO_BROADCAST_KEY_VALUE", GetBroadcastInputStr(params.inputs[0].GetDims().size(), + params.conf.broadcast_axis, + params.conf.group_size))); + } + + jit.AddConstant(MakeJitConstant("IS_CAUSAL", params.conf.is_causal)); + jit.AddConstant(MakeJitConstant("HAS_ATTN_MASK_INPUT", params.inputs.size() > 3)); + jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", params.inputs.size() > 4)); + + auto is_default_order = [](const std::vector& order) { + for (size_t i = 0; i < order.size(); i++) + if (order[i] != static_cast(i)) + return false; + return true; + }; + + if ((!params.input0_order.empty() && !is_default_order(params.input0_order)) || params.conf.broadcast_axis != -1) { + jit.AddConstant(MakeJitConstant("INPUT0_DIMS_ORDER", GetDimsOrder(params.input0_order))); + } + if ((!params.input1_order.empty() && !is_default_order(params.input1_order)) || params.conf.broadcast_axis != -1) { + jit.AddConstant(MakeJitConstant("INPUT1_DIMS_ORDER", GetDimsOrder(params.input1_order))); + } + if ((!params.input2_order.empty() && !is_default_order(params.input2_order)) || params.conf.broadcast_axis != -1) { + jit.AddConstant(MakeJitConstant("INPUT2_DIMS_ORDER", GetDimsOrder(params.input2_order))); + } + + TransposedDimensionAccessHelperJit dims_q(params.inputs[0], params.input0_order); + jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN", dims_q.y())); + jit.AddConstant(MakeJitConstant("HEAD_SIZE", dims_q.x())); + jit.AddConstant(MakeJitConstant("NUM_HEADS", dims_q.f())); + + TransposedDimensionAccessHelperJit dims_k(params.inputs[1], params.input1_order); + jit.AddConstant(MakeJitConstant("SOURCE_SEQ_LEN", dims_k.y())); + + return jit; +} + +bool SDPAKernelBase::Validate(const Params& p) const { + if (p.GetType() != KernelType::SDPA) { + return false; + } + + const sdpa_params& params = static_cast(p); + + for (size_t i = 0; i < params.inputs.size(); i++) { + if (params.inputs[i].Dimentions() != 4) + return false; + } + + if (params.outputs[0].Dimentions() != 4) + return false; + + return true; +} +} // namespace kernel_selector diff --git 
a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h new file mode 100644 index 00000000000000..319742860c6f73 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -0,0 +1,111 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_base_opencl.h" +#include "kernel_selector_params.h" +#include "kernel_selector_utils.h" +#include + +namespace kernel_selector { +struct TransposedDimensionAccessHelperBase : virtual DimensionAccessHelperBase { + explicit TransposedDimensionAccessHelperBase(const DataTensor& t, std::vector order) + : DimensionAccessHelperBase(t) { + size_t total_dims_count = dims.size(); + size_t new_axis_count = total_dims_count - order.size(); + + transposed_order.resize(total_dims_count); + std::iota(transposed_order.begin(), transposed_order.end(), 0); + for (size_t i = 0; i < order.size(); i++) { + size_t transposed_order_pos = i < 2 ? i : i + new_axis_count; + transposed_order[transposed_order_pos] = order[i] < 2 ? order[i] : order[i] + new_axis_count; + } + } + + Tensor::Dim& x_dim() { return dims[transposed_order[7]]; } + Tensor::Dim& y_dim() { return dims[transposed_order[6]]; } + Tensor::Dim& z_dim() { return dims[transposed_order[5]]; } + Tensor::Dim& w_dim() { return dims[transposed_order[4]]; } + Tensor::Dim& v_dim() { return dims[transposed_order[3]]; } + Tensor::Dim& u_dim() { return dims[transposed_order[2]]; } + Tensor::Dim& f_dim() { return dims[transposed_order[1]]; } + Tensor::Dim& b_dim() { return dims[transposed_order[0]]; } + + std::vector transposed_order; +}; + +struct TransposedDimensionAccessHelperJit : DimensionAccessHelperJit, TransposedDimensionAccessHelperBase { + explicit TransposedDimensionAccessHelperJit(const DataTensor& t, std::vector order, bool padded = false) + : DimensionAccessHelperBase(t) + , DimensionAccessHelperJit(t, padded) + , TransposedDimensionAccessHelperBase(t, order) {} + + std::string x() { return dims_sizes[transposed_order[7]]; } + std::string y() { return dims_sizes[transposed_order[6]]; } + std::string z() { return dims_sizes[transposed_order[5]]; } + std::string w() { return dims_sizes[transposed_order[4]]; } + std::string v() { return dims_sizes[transposed_order[3]]; } + std::string u() { return dims_sizes[transposed_order[2]]; } + std::string f() { return dims_sizes[transposed_order[1]]; } + std::string b() { return dims_sizes[transposed_order[0]]; } + + std::pair x_pad() { return {pad_before_after_sizes[(transposed_order[7] * 2) + 0], pad_before_after_sizes[(transposed_order[7] * 2) + 1]}; } + std::pair y_pad() { return {pad_before_after_sizes[(transposed_order[6] * 2) + 0], pad_before_after_sizes[(transposed_order[6] * 2) + 1]}; } + std::pair z_pad() { return {pad_before_after_sizes[(transposed_order[5] * 2) + 0], pad_before_after_sizes[(transposed_order[5] * 2) + 1]}; } + std::pair w_pad() { return {pad_before_after_sizes[(transposed_order[4] * 2) + 0], pad_before_after_sizes[(transposed_order[4] * 2) + 1]}; } + std::pair v_pad() { return {pad_before_after_sizes[(transposed_order[3] * 2) + 0], pad_before_after_sizes[(transposed_order[3] * 2) + 1]}; } + std::pair u_pad() { return {pad_before_after_sizes[(transposed_order[2] * 2) + 0], pad_before_after_sizes[(transposed_order[2] * 2) + 1]}; } + std::pair f_pad() { return {pad_before_after_sizes[(transposed_order[1] * 2) + 0], 
pad_before_after_sizes[(transposed_order[1] * 2) + 1]}; } + std::pair b_pad() { return {pad_before_after_sizes[(transposed_order[0] * 2) + 0], pad_before_after_sizes[(transposed_order[0] * 2) + 1]}; } +}; + +struct GQA_configuration { +}; + +struct sdpa_configuration { + int64_t head_size = -1; + int64_t heads_num = -1; + int64_t kv_heads_num = -1; + + // GQA configuration + int64_t group_size = -1; + int64_t broadcast_axis = -1; + + bool is_causal = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// sdpa_params +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct sdpa_params : public base_params { + sdpa_params() : base_params(KernelType::SDPA) {} + + std::vector input0_order; + std::vector input1_order; + std::vector input2_order; + std::vector output_order; + + sdpa_configuration conf; +}; + +struct sdpa_fuse_params : fuse_params { + sdpa_fuse_params() : fuse_params(KernelType::SDPA) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// SDPAKernelBase +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class SDPAKernelBase : public KernelBaseOpenCL { +public: + using KernelBaseOpenCL::KernelBaseOpenCL; + virtual ~SDPAKernelBase() {} + + struct DispatchData : public CommonDispatchData {}; + +protected: + bool Validate(const Params& p) const override; + JitConstants GetJitConstants(const sdpa_params& params) const; +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp new file mode 100644 index 00000000000000..b64c9ffa16ce5d --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp @@ -0,0 +1,258 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "sdpa_kernel_opt.h" +#include "kernel_selector_utils.h" +#include +#include + +namespace kernel_selector { + +constexpr size_t subgroup_size = 16; + +enum KernelsTypes { + SINGLE_TOKEN = 0, + MULTI_TOKENS, + FINALIZATION, + TOTAL_KERNELS_NUM +}; + +static size_t get_target_seq_len_block_size() { + const size_t block_size = 16; + return block_size; +} + + +static size_t get_seq_len_partition_size() { + const size_t seq_len = 256; + return seq_len; +} + +ParamsKey SDPAKernelOpt::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + + k.EnableDifferentTypes(); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDynamicShapesSupport(); + + return k; +} + +bool SDPAKernelOpt::Validate(const Params& p) const { + if (!Parent::Validate(p)) + return false; + + const sdpa_params& params = static_cast(p); + + if (params.conf.head_size < 1 || params.conf.head_size % subgroup_size != 0) + return false; + + return true; +} + +JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t kernel_idx) const { + auto jit = SDPAKernelBase::GetJitConstants(params); + + const auto softmax_acc_dt = params.inputs[0].GetDType(); + 
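    // The softmax accumulator type follows the input precision (half for FP16 inputs, float for FP32).
    // The constants added below tie the kernel variants together: SEQ_LEN_PARTITION_SIZE (256) is the
    // number of source-sequence tokens handled per partition in stage 0, and TARGET_SEQ_LEN_BLOCK_SIZE
    // is 1 for the single-token kernel and 16 otherwise.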
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR")); + + const auto& config = params.conf; + jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size)); + jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size)); + jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size())); + + auto target_seq_len_block_size = kernel_idx == KernelsTypes::SINGLE_TOKEN ? 1 : get_target_seq_len_block_size(); + jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN_BLOCK_SIZE", target_seq_len_block_size)); + + auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION ? 1 : 0; + jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1)); + + return jit; +} + +CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t kernel_idx) const { + CommonDispatchData dispatch_data; + + const auto& query_input = params.inputs[0]; + + if (!query_input.is_dynamic()) { + TransposedDimensionAccessHelperBase dims_q(params.inputs[0], params.input0_order); + TransposedDimensionAccessHelperBase dims_k(params.inputs[1], params.input1_order); + TransposedDimensionAccessHelperBase output(params.outputs[0], params.output_order); + + const size_t batch_size = output.b_dim().v; + const size_t heads_num = output.f_dim().v; + const size_t source_seq_len = dims_k.y_dim().v; + const size_t target_seq_len = dims_q.y_dim().v; + const size_t head_size = static_cast(params.conf.head_size); + const size_t num_of_partitions = CeilDiv(source_seq_len, get_seq_len_partition_size()); + const size_t target_seq_len_block_size = kernel_idx == 1 ? get_target_seq_len_block_size() : 1; + + if (kernel_idx == KernelsTypes::SINGLE_TOKEN || kernel_idx == KernelsTypes::MULTI_TOKENS) { + dispatch_data.gws = { batch_size * heads_num, + CeilDiv(target_seq_len, target_seq_len_block_size), + head_size * num_of_partitions }; + dispatch_data.lws = { 1, 1, head_size }; + } else if (kernel_idx == 2) { + dispatch_data.gws = { batch_size * heads_num, + target_seq_len, + 16 }; + dispatch_data.lws = { 1, 1, 16 }; + } + } + + return dispatch_data; +} + +KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const { + if (!Validate(params)) { + return {}; + } + + // Implementation contains multiple kernels: + // kernel[0] - single token generation stage (2nd token) + // kernel[1] - multi tokens processing stage (1st token) + // kernel[2] - results aggregation + + const size_t kernels_num = KernelsTypes::TOTAL_KERNELS_NUM; + KernelData kd = KernelData::Default(params, kernels_num); + kd.needs_sub_kernels_sync = true; + + GetUpdateDispatchDataFunc(kd); + + const auto& prim_params = dynamic_cast(params); + for (size_t kernel_idx = 0; kernel_idx < kernels_num; kernel_idx++) { + auto dispatch_data = SetDefault(prim_params, kernel_idx); + auto kernel_name = kernel_idx == 0 ? kernelName + "_single_token" : + kernel_idx == 1 ? kernelName + "_multi_tokens" : kernelName + "_finalization"; + auto entry_point = GetEntryPoint(kernel_name, prim_params.layerID, params); + auto jit_constants = GetJitConstants(prim_params, kernel_idx); + auto jit = CreateJit(kernel_name, jit_constants, entry_point); + + auto& kernel = kd.kernels[kernel_idx]; + + auto inputs_num = + kernel_idx == KernelsTypes::FINALIZATION ? 
0 : static_cast(prim_params.inputs.size()); + + FillCLKernelData(kernel, + dispatch_data, + params.engineInfo, + kernelName, + jit, + entry_point, + {}, + false, + false, + inputs_num, + GetFusedPrimitiveInputsCount(params), + static_cast(prim_params.outputs.size()), + prim_params.is_shape_agnostic); + + const auto num_of_partitions = 1; + auto& output = prim_params.outputs[0]; + auto head_size = output.X().v; + + auto buf_dt_size = 4; + auto buf_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() / head_size * num_of_partitions; + auto buf_size = buf_elements_count * buf_dt_size; + + auto tmp_out_dt_size = 4; + auto tmp_out_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() * num_of_partitions; + auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size; + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + + kd.internalBufferSizes.clear(); + kd.internalBufferSizes.push_back(buf_size); + kd.internalBufferSizes.push_back(buf_size); + kd.internalBufferSizes.push_back(tmp_out_size); + kd.internalBufferDataType = prim_params.inputs[0].GetDType(); + + if (kernel_idx == KernelsTypes::FINALIZATION) { + kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); + + ScalarDescriptor num_of_partitions_scalar; + num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32; + num_of_partitions_scalar.v.u32 = num_of_partitions; + + kernel.params.scalars.clear(); + kernel.params.scalars.push_back(num_of_partitions_scalar); + } + } + + return { kd }; +} + +void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const { + kd.update_dispatch_data_func = [this](const Params& params, KernelData& kernel_data) { + const auto& prim_params = static_cast(params); + + const size_t expected_kernels_num = KernelsTypes::TOTAL_KERNELS_NUM; + OPENVINO_ASSERT(kernel_data.kernels.size() == expected_kernels_num, + "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel"); + + TransposedDimensionAccessHelperBase dims_q(prim_params.inputs[0], prim_params.input0_order); + TransposedDimensionAccessHelperBase dims_k(prim_params.inputs[1], prim_params.input1_order); + auto& output = prim_params.outputs[0]; + + auto target_seq_len = dims_q.y_dim().v; + auto head_size = dims_q.x_dim().v; + auto source_seq_len = dims_k.y_dim().v; + + auto num_of_partitions = CeilDiv(source_seq_len, get_seq_len_partition_size()); + + auto buf_dt_size = output.ElementSize(); + auto buf_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() / head_size * num_of_partitions; + auto buf_size = buf_elements_count * buf_dt_size; + + auto tmp_out_dt_size = output.ElementSize(); + auto tmp_out_elements_count = (num_of_partitions == 1) ? 
1 : output.LogicalSize() * num_of_partitions; + auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size; + + auto dispatch_data1 = SetDefault(prim_params, 0); + kernel_data.kernels[0].params.workGroups.global = dispatch_data1.gws; + kernel_data.kernels[0].params.workGroups.local = dispatch_data1.lws; + kernel_data.kernels[0].skip_execution = target_seq_len > 1; + + auto dispatch_data2 = SetDefault(prim_params, 1); + kernel_data.kernels[1].params.workGroups.global = dispatch_data2.gws; + kernel_data.kernels[1].params.workGroups.local = dispatch_data2.lws; + kernel_data.kernels[1].skip_execution = target_seq_len == 1; + + ScalarDescriptor num_of_partitions_scalar; + num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32; + num_of_partitions_scalar.v.u32 = num_of_partitions; + + auto dispatch_data3 = SetDefault(prim_params, 2); + kernel_data.kernels[2].params.workGroups.global = dispatch_data3.gws; + kernel_data.kernels[2].params.workGroups.local = dispatch_data3.lws; + kernel_data.kernels[2].skip_execution = num_of_partitions == 1; + + kernel_data.kernels[2].params.scalars.clear(); + kernel_data.kernels[2].params.scalars.push_back(num_of_partitions_scalar); + + kernel_data.internalBufferSizes.clear(); + kernel_data.internalBufferSizes.push_back(buf_size); + kernel_data.internalBufferSizes.push_back(buf_size); + kernel_data.internalBufferSizes.push_back(tmp_out_size); + kernel_data.internalBufferDataType = prim_params.inputs[0].GetDType(); + }; +} + +KernelsPriority SDPAKernelOpt::GetKernelsPriority(const Params& /*params*/) const { + return FORCE_PRIORITY_1; +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h new file mode 100644 index 00000000000000..8d7279f5546112 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h @@ -0,0 +1,29 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "sdpa_kernel_base.h" + +namespace kernel_selector { +class SDPAKernelOpt : public SDPAKernelBase { +public: + using Parent = SDPAKernelBase; + SDPAKernelOpt() : SDPAKernelBase("sdpa_opt") {} + virtual ~SDPAKernelOpt() {} + + KernelsData GetKernelsData(const Params& params) const override; + KernelsPriority GetKernelsPriority(const Params& params) const override; + ParamsKey GetSupportedKey() const override; + +protected: + bool Validate(const Params& p) const override; + void GetUpdateDispatchDataFunc(KernelData& kd) const override; + CommonDispatchData SetDefault(const sdpa_params& params, size_t kernel_idx) const; + JitConstants GetJitConstants(const sdpa_params& params, size_t kernel_idx) const; + std::vector GetSupportedFusedOps() const override { + return {}; + } +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp new file mode 100644 index 00000000000000..5ea3ccd4224c7c --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "sdpa_kernel_ref.h" +#include "kernel_selector_utils.h" +#include +#include + +namespace kernel_selector { + +ParamsKey SDPAKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + 
k.EnableInputDataType(Datatype::F32); + + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + + k.EnableDifferentTypes(); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDynamicShapesSupport(); + + return k; +} + +JitConstants SDPAKernelRef::GetJitConstants(const sdpa_params& params) const { + auto jit = SDPAKernelBase::GetJitConstants(params); + + auto acc_dt = params.inputs[0].GetDType(); + jit.Merge(MakeTypeJitConstants(acc_dt, "ACCUMULATOR")); + + return jit; +} + +CommonDispatchData SDPAKernelRef::SetDefault(const sdpa_params& params) const { + CommonDispatchData dispatchData; + + const auto& output = params.outputs[0]; + dispatchData.gws = { output.Batch().v * output.Feature().v, output.Y().v, output.X().v }; + dispatchData.lws = { 1, 1, output.X().v }; + + return dispatchData; +} + +KernelsData SDPAKernelRef::GetKernelsData(const Params& params) const { + KernelData kd = KernelData::Default(params); + const auto& prim_params = dynamic_cast<const sdpa_params&>(params); + + if (!Validate(params)) { + return {}; + } + + auto dispatchData = SetDefault(prim_params); + auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params); + auto cldnn_jit = GetJitConstants(prim_params); + auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + + auto& kernel = kd.kernels[0]; + + GetUpdateDispatchDataFunc(kd); + + FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, + "", false, false, static_cast<int>(prim_params.inputs.size()), + GetFusedPrimitiveInputsCount(params), 1, prim_params.is_shape_agnostic); + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); + + kd.internalBufferSizes.clear(); + kd.internalBufferSizes.push_back(prim_params.inputs[0].ElementSize()); + kd.internalBufferDataType = prim_params.inputs[0].GetDType(); + + return { kd }; +} + +void SDPAKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const { + kd.update_dispatch_data_func = [this](const Params& params, KernelData& kernel_data) { + const auto& prim_params = static_cast<const sdpa_params&>(params); + auto dispatchData = SetDefault(prim_params); + OPENVINO_ASSERT(kernel_data.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kernel_data.kernels[0].params.workGroups.global = dispatchData.gws; + kernel_data.kernels[0].params.workGroups.local = dispatchData.lws; + kernel_data.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params); + + auto& in_q = prim_params.inputs[0]; + auto& in_k = prim_params.inputs[1]; + TransposedDimensionAccessHelperBase dims_q(in_q, prim_params.input0_order); + TransposedDimensionAccessHelperBase dims_k(in_k, prim_params.input1_order); + + auto elem_size = in_q.ElementSize(); + auto batch_size = in_q.LogicalSize() / dims_q.x_dim().v / dims_q.y_dim().v; + kernel_data.internalBufferSizes.clear(); + kernel_data.internalBufferSizes.push_back(batch_size * dims_q.y_dim().v * dims_k.y_dim().v * elem_size); + + kernel_data.internalBufferDataType = in_q.GetDType(); + }; +} + +KernelsPriority SDPAKernelRef::GetKernelsPriority(const Params& /*params*/) const { + return DONT_USE_IF_HAVE_SOMETHING_ELSE; +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.h new file mode 100644 index 00000000000000..c570f32cc1e94e --- /dev/null +++ 
b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.h @@ -0,0 +1,28 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "sdpa_kernel_base.h" + +namespace kernel_selector { +class SDPAKernelRef : public SDPAKernelBase { +public: + using Parent = SDPAKernelBase; + SDPAKernelRef() : SDPAKernelBase("sdpa_ref") {} + virtual ~SDPAKernelRef() {} + + KernelsData GetKernelsData(const Params& params) const override; + KernelsPriority GetKernelsPriority(const Params& params) const override; + ParamsKey GetSupportedKey() const override; + +protected: + void GetUpdateDispatchDataFunc(KernelData& kd) const override; + CommonDispatchData SetDefault(const sdpa_params& params) const; + JitConstants GetJitConstants(const sdpa_params& params) const; + std::vector<FusedOpType> GetSupportedFusedOps() const override { + return {}; + } +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp new file mode 100644 index 00000000000000..b58f04f23e2643 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "sdpa_kernel_selector.h" +#include "sdpa_kernel_ref.h" +#include "sdpa_kernel_opt.h" + +namespace kernel_selector { + +sdpa_kernel_selector::sdpa_kernel_selector() { + Attach<SDPAKernelRef>(); + Attach<SDPAKernelOpt>(); +} + +KernelsData sdpa_kernel_selector::GetBestKernels(const Params& params) const { + return GetNaiveBestKernel(params, KernelType::SDPA); +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h new file mode 100644 index 00000000000000..e4a5f245bfe18b --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_selector.h" + +namespace kernel_selector { +class sdpa_kernel_selector : public kernel_selector_base { +public: + static sdpa_kernel_selector& Instance() { + static sdpa_kernel_selector instance_; + return instance_; + } + + sdpa_kernel_selector(); + + virtual ~sdpa_kernel_selector() {} + + KernelsData GetBestKernels(const Params& params) const override; +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/slice/slice_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/slice/slice_kernel_ref.cpp index ee6f39c3c3c71e..34279dd7de148c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/slice/slice_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/slice/slice_kernel_ref.cpp @@ -122,7 +122,7 @@ JitConstants SliceKernelRef::GetJitConstants(const slice_params& params) const { // Define axes size as constant: if (params.compile_time_axes.empty()) { - kernel_selector::DimensionAccessHelper dims(params.inputs.back()); + kernel_selector::DimensionAccessHelperJit dims(params.inputs.back()); jit.AddConstant(MakeJitConstant(JIT_AXES_BUFF_SIZE_NAME, toVectorMulString({dims.b(), dims.f(), dims.x(), dims.y(), dims.z()}))); } else { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp 
b/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp index 335c2bc1017303..338ed8d3fb1077 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp @@ -115,7 +115,7 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis if (params.has_dynamic_tensors()) { const auto& input = params.inputs[0]; - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); auto softmax_dim_y_bfyx = (params.dim == SoftmaxDim::Y && input.GetLayout() == DataLayout::bfyx); auto softmax_dim_x_bfyx = (params.dim == SoftmaxDim::X && input.GetLayout() == DataLayout::bfyx); const std::string lws_0 = "get_local_size(0)"; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/unique/unique_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/unique/unique_kernel_ref.cpp index 5aafdd309ae6d0..5d20503919241b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/unique/unique_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/unique/unique_kernel_ref.cpp @@ -216,7 +216,7 @@ JitConstants UniqueCountKernelRef::GetJitConstants(const unique_count_params& ke } if (input.is_dynamic()) { - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); const std::string total_data_size = toVectorMulString({dims.x(), dims.y(), dims.z(), dims.w(), dims.f(), dims.b()}); jit_constants.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", total_data_size)); @@ -326,7 +326,7 @@ JitConstants UniqueGatherKernelRef::GetJitConstants(const unique_gather_params& } if (input.is_dynamic()) { - DimensionAccessHelper dims(input); + DimensionAccessHelperJit dims(input); const std::string total_data_size = toVectorMulString({dims.x(), dims.y(), dims.z(), dims.w(), dims.f(), dims.b()}); jit_constants.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", total_data_size)); diff --git a/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp new file mode 100644 index 00000000000000..c07c501a1f970b --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp @@ -0,0 +1,59 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/plugin/program_builder.hpp" +#include "intel_gpu/plugin/common_utils.hpp" + +#include "intel_gpu/op/sdpa.hpp" + +#include "openvino/op/scaled_dot_product_attention.hpp" + +#include "intel_gpu/primitives/scaled_dot_product_attention.hpp" + +namespace ov { +namespace op { +namespace internal { +using SDPA = ov::intel_gpu::op::SDPA; +} // namespace internal +} // namespace op +} // namespace ov + +namespace ov { +namespace intel_gpu { + +static void CreateScaledDotProductAttentionOp(ProgramBuilder& p, const std::shared_ptr& op) { + validate_inputs_count(op, {3, 4, 5}); + auto inputs = p.GetInputInfo(op); + auto layerName = layer_type_name_ID(op); + + bool is_causal = op->get_causal(); + auto sdpa_prim = cldnn::scaled_dot_product_attention(layerName, + inputs, + is_causal); + + p.add_primitive(*op, sdpa_prim); +} + +static void CreateSDPAOp(ProgramBuilder& p, const std::shared_ptr& op) { + validate_inputs_count(op, {3, 4, 5}); + auto inputs = p.GetInputInfo(op); + auto layerName = layer_type_name_ID(op); + + bool is_causal = op->get_causal(); + auto sdpa_prim = cldnn::scaled_dot_product_attention(layerName, + inputs, + 
is_causal, + op->get_input0_transpose_order(), + op->get_input1_transpose_order(), + op->get_input2_transpose_order(), + op->get_output_transpose_order()); + + p.add_primitive(*op, sdpa_prim); +} + +REGISTER_FACTORY_IMPL(internal, SDPA); +REGISTER_FACTORY_IMPL(v13, ScaledDotProductAttention); + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp new file mode 100644 index 00000000000000..67e927abb43f97 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp @@ -0,0 +1,171 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/op/sdpa.hpp" +#include "intel_gpu/plugin/common_utils.hpp" +#include "scaled_dot_product_attention_shape_inference.hpp" +#include "openvino/core/partial_shape.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/reshape.hpp" + +namespace ov { +namespace intel_gpu { +namespace op { + +SDPA::SDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const bool is_causal, + const ov::element::Type output_type) + : m_order_q(order_q) + , m_order_k(order_k) + , m_order_v(order_v) + , m_order_out(order_out) + , m_is_causal(is_causal) + , m_output_type(output_type) { + set_arguments({Q, K, V}); + validate_and_infer_types(); +} + +SDPA::SDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const bool is_causal, + const ov::element::Type output_type) + : m_order_q(order_q) + , m_order_k(order_k) + , m_order_v(order_v) + , m_order_out(order_out) + , m_is_causal(is_causal) + , m_output_type(output_type) { + set_arguments({Q, K, V, attn_mask}); + validate_and_infer_types(); +} + +SDPA::SDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const ov::Output& scale, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const bool is_causal, + const ov::element::Type output_type) + : m_order_q(order_q) + , m_order_k(order_k) + , m_order_v(order_v) + , m_order_out(order_out) + , m_is_causal(is_causal) + , m_output_type(output_type) { + set_arguments({Q, K, V, attn_mask, scale}); + validate_and_infer_types(); +} + +std::shared_ptr SDPA::clone_with_new_inputs(const ov::OutputVector& new_args) const { + check_new_args_count(this, new_args); + + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_order_q, m_order_k, m_order_v, m_order_out, m_is_causal, m_output_type); +} + +void SDPA::validate_and_infer_types() { + const auto input_size = get_input_size(); + NODE_VALIDATION_CHECK(this, + input_size == 3 || input_size == 4 || input_size == 5, + "Number of inputs is incorrect. Current value is: ", + input_size, + ", expected 3, 4 or 5."); + + std::vector input_shapes; + for (size_t i = 0; i < input_size; i++) { + input_shapes.push_back(get_input_partial_shape(i)); + } + + auto out_shapes = shape_infer(this, + input_shapes, + m_order_q, + m_order_k, + m_order_v, + m_order_out); + + auto output_type = m_output_type == ov::element::undefined ? 
get_input_element_type(0) : m_output_type; + set_output_type(0, output_type, out_shapes[0]); +} + +bool SDPA::visit_attributes(ov::AttributeVisitor &visitor) { + visitor.on_attribute("order_q", m_order_q); + visitor.on_attribute("order_k", m_order_k); + visitor.on_attribute("order_v", m_order_v); + visitor.on_attribute("order_out", m_order_out); + visitor.on_attribute("output_type", m_output_type); + return true; +} + +std::vector shape_infer(const SDPA* op, + std::vector input_shapes, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out) { + auto shape_q = input_shapes[0]; + auto shape_k = input_shapes[1]; + auto shape_v = input_shapes[2]; + + // transposed shape + auto transpose_pshape = [](const ov::PartialShape pshape, const std::vector& order) { + auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank()); + for (size_t i = 0; i < order.size(); i++) { + transposed_pshape[i] = pshape[order[i]]; + } + + return transposed_pshape; + }; + + auto shape_q_t = (order_q.size() > 1) ? transpose_pshape(shape_q, order_q) : shape_q; + auto shape_k_t = (order_k.size() > 1) ? transpose_pshape(shape_k, order_k) : shape_k; + auto shape_v_t = (order_v.size() > 1) ? transpose_pshape(shape_v, order_v) : shape_v; + + const auto is_broadcastable = shape_k_t.rank().is_static() && + shape_v_t.rank().is_static() && + ((shape_q_t.size() == shape_k_t.size()) && (shape_q_t.size() == shape_v_t.size())); + if (is_broadcastable) { + size_t max_rank = shape_q_t.size(); + for (size_t i = 0; i < max_rank; ++i) { + if (shape_q_t[i].is_static() && shape_k_t[i].is_static() && shape_v_t[i].is_static()) { + auto broadcasted_dim = shape_q_t[i].get_length(); + shape_k_t[i] = broadcasted_dim; + shape_v_t[i] = broadcasted_dim; + } + } + } + + std::vector transposed_input_shapes{ shape_q_t, shape_k_t, shape_v_t }; + for (size_t i = 3; i < transposed_input_shapes.size(); i++) { + transposed_input_shapes.push_back(input_shapes[i]); + } + + OPENVINO_ASSERT(op != nullptr, "op should not be nullptr for shape_infer."); + auto out_shapes = ov::op::v13::shape_infer(dynamic_cast(op), transposed_input_shapes); + + if (order_out.size() > 0) { + return { transpose_pshape(out_shapes[0], order_out) }; + } else { + return { out_shapes[0] }; + } +} + +} // namespace op +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp similarity index 56% rename from src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp rename to src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp index e57a7978a5e7bf..614a42845ec521 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp @@ -3,14 +3,16 @@ // #include "intel_gpu/op/gemm.hpp" +#include "intel_gpu/op/sdpa.hpp" #include "openvino/core/node_vector.hpp" #include "openvino/core/partial_shape.hpp" #include "openvino/core/type/element_type.hpp" #include "openvino/op/constant.hpp" #include "openvino/pass/pattern/op/label.hpp" #include "openvino/pass/pattern/op/pattern.hpp" -#include "transpose_matmul_fusion.hpp" +#include "transpose_fusion.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/transpose.hpp" #include "openvino/core/rt_info.hpp" @@ 
-25,23 +27,133 @@ using ov::pass::pattern::op::Or; namespace ov { namespace intel_gpu { -class TransposeMatMulMatcher : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("TransposeMatMulMatcher", "0"); - TransposeMatMulMatcher(); -}; - -class TransposeMatMulTransposeMatcher : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("TransposeMatMulTransposeMatcher", "0"); - TransposeMatMulTransposeMatcher(); -}; - -TransposeMatMulFusion::TransposeMatMulFusion() { +TransposeFusion::TransposeFusion() { add_matcher(); add_matcher(); + add_matcher(); } +TransposeSDPAMatcher::TransposeSDPAMatcher() { + auto is_fp_type = [](const ov::Output& output) -> bool { + switch (output.get_element_type()) { + case ov::element::f16: + case ov::element::f32: return true; + default: return false; + } + }; + auto not_transpose = [is_fp_type](const ov::Output& output) -> bool { + return std::dynamic_pointer_cast(output.get_node_shared_ptr()) == nullptr + && is_fp_type(output); + }; + auto is_dynamic = [](const ov::Output& output) -> bool { + bool is_dynamic = output.get_node_shared_ptr()->get_output_partial_shape(0).is_dynamic(); + size_t num_inputs = output.get_node_shared_ptr()->get_input_size(); + for (size_t idx = 0; idx < num_inputs; idx++) { + is_dynamic |= output.get_node_shared_ptr()->get_input_partial_shape(idx).is_dynamic(); + } + return is_dynamic; + }; + + auto input_q_m = any_input(not_transpose); + auto input_k_m = any_input(not_transpose); + auto input_v_m = any_input(not_transpose); + auto input_attn_mask = any_input(not_transpose); + auto input_scale = any_input(not_transpose); + auto transpose_q_order_m = wrap_type(consumers_count(1)); + auto transpose_k_order_m = wrap_type(consumers_count(1)); + auto transpose_v_order_m = wrap_type(consumers_count(1)); + auto transpose_q_m = wrap_type({input_q_m, transpose_q_order_m}, is_fp_type); + auto transpose_k_m = wrap_type({input_k_m, transpose_k_order_m}, is_fp_type); + auto transpose_v_m = wrap_type({input_v_m, transpose_v_order_m}, is_fp_type); + + auto sdpa_in_q = std::make_shared(OutputVector{input_q_m, transpose_q_m}); + auto sdpa_in_k = std::make_shared(OutputVector{input_k_m, transpose_k_m}); + auto sdpa_in_v = std::make_shared(OutputVector{input_v_m, transpose_v_m}); + + auto sdpa_without_attn_mask_m = wrap_type({ sdpa_in_q, sdpa_in_k, sdpa_in_v }, is_dynamic); + auto sdpa_with_attn_mask_m = wrap_type({ sdpa_in_q, sdpa_in_k, sdpa_in_v, input_attn_mask }, is_dynamic); + auto sdpa_with_attn_mask_and_scale_m = + wrap_type({ sdpa_in_q, sdpa_in_k, sdpa_in_v, input_attn_mask, input_scale }, is_dynamic); + + auto sdpa_m = std::make_shared(OutputVector{sdpa_without_attn_mask_m, sdpa_with_attn_mask_m, sdpa_with_attn_mask_and_scale_m}); + + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + + std::shared_ptr sdpa; + if (pattern_map.find(sdpa_without_attn_mask_m) != pattern_map.end()) { + sdpa = std::dynamic_pointer_cast(pattern_map.at(sdpa_without_attn_mask_m).get_node_shared_ptr()); + } else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) { + sdpa = std::dynamic_pointer_cast(pattern_map.at(sdpa_with_attn_mask_m).get_node_shared_ptr()); + } else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) { + sdpa = std::dynamic_pointer_cast(pattern_map.at(sdpa_with_attn_mask_and_scale_m).get_node_shared_ptr()); + } + + if (!sdpa || transformation_callback(sdpa)) { + return false; + } + + auto order_q = 
op::SDPA::default_order(sdpa->get_input_partial_shape(0).size()); + auto order_k = op::SDPA::default_order(sdpa->get_input_partial_shape(1).size()); + auto order_v = op::SDPA::default_order(sdpa->get_input_partial_shape(2).size()); + auto order_output = op::SDPA::default_order(sdpa->get_output_partial_shape(0).size()); + size_t input_q_output_idx = sdpa->get_input_source_output(0).get_index(); + size_t input_k_output_idx = sdpa->get_input_source_output(1).get_index(); + size_t input_v_output_idx = sdpa->get_input_source_output(2).get_index(); + + if (pattern_map.count(transpose_q_m) > 0) { + auto tranpose_a_order = std::dynamic_pointer_cast(pattern_map.at(transpose_q_order_m).get_node_shared_ptr()); + order_q = tranpose_a_order->cast_vector(); + if (order_q.back() != static_cast(order_q.size() - 1)) // Allow any transposes without head_size dim position change + return false; + + auto tranpose_a = std::dynamic_pointer_cast(pattern_map.at(transpose_q_m).get_node_shared_ptr()); + input_q_output_idx = tranpose_a->get_input_source_output(0).get_index(); + } + if (pattern_map.count(transpose_k_m) > 0) { + auto tranpose_b_order = std::dynamic_pointer_cast(pattern_map.at(transpose_k_order_m).get_node_shared_ptr()); + order_k = tranpose_b_order->cast_vector(); + if (order_k.back() != static_cast(order_k.size() - 1)) // Allow any transposes without head_size dim position change + return false; + + auto tranpose_b = std::dynamic_pointer_cast(pattern_map.at(transpose_k_m).get_node_shared_ptr()); + input_k_output_idx = tranpose_b->get_input_source_output(0).get_index(); + } + if (pattern_map.count(transpose_v_m) > 0) { + auto tranpose_c_order = std::dynamic_pointer_cast(pattern_map.at(transpose_v_order_m).get_node_shared_ptr()); + order_v = tranpose_c_order->cast_vector(); + if (order_v.back() != static_cast(order_v.size() - 1)) // Allow any transposes without head_size dim position change + return false; + + auto tranpose_c = std::dynamic_pointer_cast(pattern_map.at(transpose_k_m).get_node_shared_ptr()); + input_v_output_idx = tranpose_c->get_input_source_output(0).get_index(); + } + + auto input_q = ov::Output(pattern_map.at(input_q_m).get_node_shared_ptr(), input_q_output_idx); + auto input_k = ov::Output(pattern_map.at(input_k_m).get_node_shared_ptr(), input_k_output_idx); + auto input_v = ov::Output(pattern_map.at(input_v_m).get_node_shared_ptr(), input_v_output_idx); + + std::shared_ptr sdpa_new; + if (pattern_map.find(sdpa_without_attn_mask_m) != pattern_map.end()) { + sdpa_new = std::make_shared(input_q, input_k, input_v, order_q, order_k, order_v, order_output, sdpa->get_causal()); + } else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) { + auto attn_mask = sdpa->get_input_source_output(3); + sdpa_new = std::make_shared(input_q, input_k, input_v, attn_mask, order_q, order_k, order_v, order_output, sdpa->get_causal()); + } else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) { + auto attn_mask = sdpa->get_input_source_output(3); + auto scale = sdpa->get_input_source_output(4); + sdpa_new = std::make_shared(input_q, input_k, input_v, attn_mask, scale, order_q, order_k, order_v, order_output, sdpa->get_causal()); + } + + sdpa_new->set_friendly_name(sdpa->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa_new); + ov::replace_node(sdpa, sdpa_new); + return true; + }; + + auto m = std::make_shared(sdpa_m, "TransposeSDPAMatcher"); + this->register_matcher(m, callback); +} TransposeMatMulMatcher::TransposeMatMulMatcher() { auto 
is_fp_type = [](const ov::Output& output) -> bool { diff --git a/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp new file mode 100644 index 00000000000000..b43b74adf396d5 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp @@ -0,0 +1,37 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov { +namespace intel_gpu { + +class TransposeFusion: public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("TransposeMatMulFusion", "0"); + TransposeFusion(); +}; + +class TransposeMatMulMatcher : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TransposeMatMulMatcher", "0"); + TransposeMatMulMatcher(); +}; + +class TransposeMatMulTransposeMatcher : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TransposeMatMulTransposeMatcher", "0"); + TransposeMatMulTransposeMatcher(); +}; + +class TransposeSDPAMatcher : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TransposeSDPAMatcher", "0"); + TransposeSDPAMatcher(); +}; + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.hpp deleted file mode 100644 index b24d76059ada11..00000000000000 --- a/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.hpp +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (C) 2018-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" - -namespace ov { -namespace intel_gpu { - -class TransposeMatMulFusion: public ov::pass::GraphRewrite { -public: - OPENVINO_RTTI("TransposeMatMulFusion", "0"); - TransposeMatMulFusion(); -}; - -} // namespace intel_gpu -} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp new file mode 100644 index 00000000000000..3fdb3794585106 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp @@ -0,0 +1,134 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "unsqueeze_broadcast_reshape_sdpa_fusion.hpp" + +#include "intel_gpu/op/sdpa.hpp" +#include "intel_gpu/op/kv_cache.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "transformations/utils/utils.hpp" + +namespace ov { +namespace intel_gpu { +using ov::pass::pattern::op::Or; + +UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() { + using namespace ov::pass::pattern; + + auto not_reshape = [](const ov::Output& output) -> bool { + return std::dynamic_pointer_cast(output.get_node_shared_ptr()) == nullptr; + }; + + auto unsqueeze_predicate = [](const ov::Output& output) -> bool { + return rank_equals(5)(output) && consumers_count(1); + }; + + auto broadcast_predicate = [](const ov::Output& output) -> bool { + const auto broadcast = ov::as_type_ptr(output.get_node_shared_ptr()); + if (!broadcast || 
broadcast->get_broadcast_spec().m_type != ov::op::BroadcastType::BIDIRECTIONAL) + return false; + return rank_equals(5)(output) && consumers_count(1); + }; + + auto reshape_predicate = [](const ov::Output& output) -> bool { + return rank_equals(4)(output) && consumers_count(1); + }; + + auto input_a_m = any_input(not_reshape); + auto input_attn_mask = any_input(); + auto input_scale = any_input(); + auto input_b_m = wrap_type({any_input(), any_input()}); + auto input_c_m = wrap_type({any_input(), any_input()}); + auto axes_const_b_m = wrap_type(); + auto axes_const_c_m = wrap_type(); + auto unsqueeze_b_m = wrap_type({input_b_m, axes_const_b_m}, unsqueeze_predicate); + auto unsqueeze_c_m = wrap_type({input_c_m, axes_const_c_m}, unsqueeze_predicate); + auto broadcast_b_m = wrap_type({unsqueeze_b_m, any_input()}, broadcast_predicate); + auto broadcast_c_m = wrap_type({unsqueeze_c_m, any_input()}, broadcast_predicate); + auto reshape_b_m = wrap_type({broadcast_b_m, any_input()}, reshape_predicate); + auto reshape_c_m = wrap_type({broadcast_c_m, any_input()}, reshape_predicate); + + auto sdpa_without_attn_mask_m = wrap_type({ input_a_m, reshape_b_m, reshape_c_m }); + auto sdpa_with_attn_mask_m = wrap_type({ input_a_m, reshape_b_m, reshape_c_m, input_attn_mask }); + auto sdpa_with_attn_mask_and_scale_m = wrap_type({ input_a_m, reshape_b_m, reshape_c_m, input_attn_mask, input_scale }); + + auto sdpa_m = std::make_shared(OutputVector{sdpa_without_attn_mask_m, sdpa_with_attn_mask_m, sdpa_with_attn_mask_and_scale_m}); + + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + if (transformation_callback(m.get_match_root())) { + return false; + } + const auto& pattern_map = m.get_pattern_value_map(); + + auto valid_broadcast_target_shape = [](const std::vector& target_shape) { + return std::count_if(target_shape.begin(), target_shape.end(), [](int32_t s) { return s != 1; }) == 1; + }; + auto broadcast_b = std::dynamic_pointer_cast(pattern_map.at(broadcast_b_m).get_node_shared_ptr()); + auto broadcast_c = std::dynamic_pointer_cast(pattern_map.at(broadcast_c_m).get_node_shared_ptr()); + + std::vector target_shape_val_b; + auto target_shape_constant_b = std::dynamic_pointer_cast(broadcast_c->get_input_node_shared_ptr(1)); + if (target_shape_constant_b) { + target_shape_val_b = target_shape_constant_b->cast_vector(); + if (!valid_broadcast_target_shape(target_shape_val_b)) { + return false; + } + } + + std::vector target_shape_val_c; + auto target_shape_constant_c = std::dynamic_pointer_cast(broadcast_b->get_input_node_shared_ptr(1)); + if (target_shape_constant_c) { + target_shape_val_c = target_shape_constant_c->cast_vector(); + if (!valid_broadcast_target_shape(target_shape_val_c)) { + return false; + } + } + + // Expect the same broadcast rules for key and value inputs + if (target_shape_val_b != target_shape_val_c) { + return false; + } + + auto input_a = pattern_map.at(input_a_m).get_node_shared_ptr(); + auto input_b = pattern_map.at(input_b_m).get_node_shared_ptr(); + auto input_c = pattern_map.at(input_c_m).get_node_shared_ptr(); + + auto sdpa = std::dynamic_pointer_cast(m.get_match_root()); + auto order_a = sdpa->get_input0_transpose_order(); + auto order_b = sdpa->get_input1_transpose_order(); + auto order_c = sdpa->get_input2_transpose_order(); + auto order_d = sdpa->get_output_transpose_order(); + + std::shared_ptr sdpa_new; + if (pattern_map.find(sdpa_without_attn_mask_m) != pattern_map.end()) { + sdpa_new = std::make_shared(input_a, input_b, 
input_c, order_a, order_b, order_c, order_d, sdpa->get_causal()); + } else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) { + auto attn_mask = sdpa->get_input_source_output(3); + sdpa_new = std::make_shared(input_a, input_b, input_c, attn_mask, order_a, order_b, order_c, order_d, sdpa->get_causal()); + } else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) { + auto attn_mask = sdpa->get_input_source_output(3); + auto scale = sdpa->get_input_source_output(4); + sdpa_new = std::make_shared(input_a, input_b, input_c, attn_mask, scale, order_a, order_b, order_c, order_d, sdpa->get_causal()); + } + + sdpa_new->set_friendly_name(sdpa->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa_new); + ov::replace_node(sdpa, sdpa_new); + + return true; + }; + + auto m = std::make_shared(sdpa_m, "UnsqueezeBroadcastReshapeSDPAFusion"); + this->register_matcher(m, callback); +} + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp new file mode 100644 index 00000000000000..ede3ac16fb51b5 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov { +namespace intel_gpu { + +class UnsqueezeBroadcastReshapeSDPAFusion : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("UnsqueezeBroadcastReshapeSDPAFusion", "0"); + UnsqueezeBroadcastReshapeSDPAFusion(); +}; + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 0c690dfe7d6df1..1dad68a1ecd997 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -61,11 +61,12 @@ #include "plugin/transformations/bcast_and_pad_zp_buffers.hpp" #include "transformations/common_optimizations/rms_fusion.hpp" #include "plugin/transformations/swiglu_fusion.hpp" -#include "plugin/transformations/transpose_matmul_fusion.hpp" +#include "plugin/transformations/transpose_fusion.hpp" #include "plugin/transformations/indirect_kv_cache.hpp" #include "plugin/transformations/convert_convolution.hpp" #include "plugin/transformations/unsqueeze_broadcast_reshape_matmul_fusion.hpp" #include "transformations/common_optimizations/rms_fusion.hpp" +#include "plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp" #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp" #include "transformations/common_optimizations/broadcast_transition.hpp" #include "transformations/common_optimizations/common_optimizations.hpp" @@ -134,6 +135,7 @@ #include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp" #include "transformations/op_conversions/softmax_decomposition.hpp" #include "transformations/op_conversions/softplus_decomposition.hpp" +#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp" #include "transformations/opset_conversions/convert_opset2_to_opset1.hpp" #include "transformations/opset_conversions/convert_opset3_to_opset2.hpp" #include "transformations/resolve_names_collisions.hpp" @@ -141,6 
+143,19 @@ #include "transformations/rt_info/keep_const_precision.hpp" #include "transformations/smart_reshape/matmul_sr.hpp" +template +T convert_to(const std::string &str) { + std::istringstream ss(str); + T res; + ss >> res; + return res; +} + +template <> +std::string convert_to(const std::string &str) { + return str; +} + namespace { template static bool disable_reduce_decomposition(const std::shared_ptr node) { @@ -303,6 +318,52 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); + // Disable SDPA decomposition once additional transformations are added: + // 1) Input/Output Transpose fusion + // 2) Indirect inputs support + // 3) GQA related optimization (Broadcast fusion) + pass_config->set_callback([&](const std::shared_ptr node){ + // Known limitations: + // - The head size of all Q, K, and V inputs should be the same static value + // - The head size should be divisible by 16 + // - All inputs and outputs must have the same data type + // - The number of dimensions for each input is expected to be 4 + // - SDPA impl could be slower on GPUs with IMMAD support in non-LLM scenarios, + // because oneDNN can be used for those cases - SDPA requires DPAS support + auto sdpa = std::dynamic_pointer_cast(node); + const auto& query_ps = sdpa->get_input_partial_shape(0); + const auto& key_ps = sdpa->get_input_partial_shape(1); + const auto& value_ps = sdpa->get_input_partial_shape(2); + + if (const auto env_var = std::getenv("USE_SDPA")) { + bool use_sdpa = convert_to(env_var); + std::cout << "Use SDPA forced to " << (use_sdpa ? "TRUE" : "FALSE") << "\n"; + return use_sdpa; + } + + if (query_ps.size() != 4 || key_ps.size() != 4 || value_ps.size() != 4) { + return false; + } + + if (query_ps[query_ps.size() - 1].is_dynamic() || key_ps[key_ps.size() - 1].is_dynamic() || value_ps[query_ps.size() - 1].is_dynamic()) { + return false; + } + + if (query_ps[query_ps.size() - 1].get_length() != key_ps[key_ps.size() - 1].get_length() || + query_ps[query_ps.size() - 1].get_length() != value_ps[query_ps.size() - 1].get_length()) { + return false; + } + + const auto optimal_subgroup_size = 16; + if (query_ps[query_ps.size() - 1].is_dynamic() || + query_ps[query_ps.size() - 1].get_length() > 256 || + query_ps[query_ps.size() - 1].get_length() % optimal_subgroup_size != 0) { + return false; + } + + return true; + }); + manager.register_pass(); manager.register_pass(); @@ -749,10 +810,17 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); + if (device_info.supports_immad) { + manager.get_pass_config()->disable(); + manager.get_pass_config()->disable(); + } + if (!device_info.supports_immad) { - manager.register_pass(); manager.register_pass(); } + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 11360f1fe80faa..1d1b9e8dd9e9a5 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -205,6 +205,8 @@ std::vector disabledTestPatterns() { R"(.*smoke_RDFT_5d_last_axis/RDFTLayerTest.Inference/IS=\(10.4.8.2.5\)_modelType=f32_Axes=\(0.1.2.3.4\)_SignalSize=\(\).*)", // Issue: 136862 
R"(.*smoke_ConditionGPUTest_static/StaticConditionLayerGPUTest.CompareWithRefs/IS=\(3.6\)_netPRC=i8_ifCond=PARAM_targetDevice=GPU_.*)", + // Uncomment once SDPA decomposition is disabled + R"(.*smoke_ScaledAttn_GPU.*)", #if defined(_WIN32) R"(.*smoke_RemoteTensor/OVRemoteTensorBatched_Test.NV12toBGR_buffer/(num_batch_4|num_batch_2).*)", R"(.*smoke_Check/ConstantResultSubgraphTest.Inference/SubgraphType=SINGLE_COMPONENT_IS=\[1,3,10,10\]_IT=i16_Device=GPU.*)", diff --git a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp new file mode 100644 index 00000000000000..29498e65965d37 --- /dev/null +++ b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_test_utils/ov_tensor_utils.hpp" +#include "common_test_utils/test_enums.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" + + +#include "openvino/opsets/opset13.hpp" +#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp" +#include "openvino/pass/manager.hpp" + +#include "openvino/op/parameter.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/result.hpp" +#include "openvino/op/matmul.hpp" + +namespace { +using ov::test::InputShape; + +typedef std::tuple, // shape + bool, // is_causal + bool, // has_attn + bool, // has_scale + std::string // targetDevice + > ScaledAttnGPUTestParams; + +class ScaledAttnLayerGPUTest : public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj); + +protected: + void SetUp() override; + void generate_inputs(const std::vector& targetInputStaticShapes) override; + bool is_causal; + bool has_attn; + bool has_scale; +}; + +std::string ScaledAttnLayerGPUTest::getTestCaseName(const testing::TestParamInfo& obj) { + ov::element::Type inType; + std::vector inputShapes; + bool is_causal; + bool has_attn; + bool has_scale; + std::string targetDevice; + std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice) = obj.param; + + std::ostringstream result; + result << "netPRC=" << inType << "_"; + result << "IS="; + for (const auto& inputShape : inputShapes) { + result << ov::test::utils::partialShape2str({inputShape.first}) << "_"; + } + result << "TS="; + for (const auto& shapes : inputShapes) { + for (const auto& shape : shapes.second) { + result << ov::test::utils::vec2str(shape); + result << "_"; + } + } + result << "is_causal=" << is_causal << "_"; + result << "has_attn=" << has_attn << "_"; + result << "has_scale=" << has_scale << "_"; + result << "trgDev=" << targetDevice; + + return result.str(); +} + +void ScaledAttnLayerGPUTest::SetUp() { + ov::element::Type inType; + std::vector inputShapes; + std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice) = this->GetParam(); + + init_input_shapes(inputShapes); + ov::ParameterVector inputParams; + // q, k, v + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); + inputParams[0]->set_friendly_name("q"); + inputParams[1]->set_friendly_name("k"); + inputParams[2]->set_friendly_name("v"); + // special case: 
only scale but no attn + if (!has_attn && has_scale) { + // attention_mask:[1] + inputParams.push_back(std::make_shared(inType, ov::PartialShape{})); + inputParams.back()->set_friendly_name("attention_mask"); + // scale:[1] + inputParams.push_back(std::make_shared(inType, ov::PartialShape{1})); + inputParams.back()->set_friendly_name("scale"); + } else { + if (has_attn) { + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[2])); + inputParams.back()->set_friendly_name("attention_mask"); + } + if (has_scale) { + // scale:[1] + inputParams.push_back(std::make_shared(inType, ov::PartialShape{1})); + inputParams.back()->set_friendly_name("scale"); + } + } + ov::OutputVector inputs; + for (auto& input : inputParams) { + inputs.push_back(input); + } + auto sdp = std::make_shared(inputs, is_causal); + sdp->set_friendly_name("sdpa"); + + auto output = std::make_shared(sdp->output(0)); + + function = std::make_shared(ov::OutputVector{output}, inputParams, "sdpa_model"); + + functionRefs = function->clone(); + ov::pass::Manager manager; + + // decompose ScaledDotProductAttention + manager.register_pass(); + manager.run_passes(functionRefs); +} + +void ScaledAttnLayerGPUTest::generate_inputs(const std::vector& targetInputStaticShapes) { + std::vector shapes(3); + shapes[0] = targetInputStaticShapes[0]; + shapes[1] = targetInputStaticShapes[1]; + shapes[2] = targetInputStaticShapes[1]; + if (!has_attn && has_scale) { + shapes.push_back(ov::Shape{}); + shapes.push_back(ov::Shape{1}); + } else { + if (has_attn) { + shapes.push_back(targetInputStaticShapes[2]); + } + if (has_scale) { + shapes.push_back(ov::Shape{1}); + } + } + SubgraphBaseTest::generate_inputs(shapes); +} + +TEST_P(ScaledAttnLayerGPUTest, CompareWithRefs) { + ov::element::Type inType; + std::vector inputShapes; + bool is_causal; + bool has_attn; + bool has_scale; + std::string targetDevice; + std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice) = this->GetParam(); + run(); +} + +const std::vector> shapes{ + // normal case, shapes of q,k,v are same + { + // q shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // kv shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // attn shape: [B, 1, -1, L0+L1] + {ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1}, + {ov::Shape{1, 1, 100, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}} + }, + }, + { + // q shape + {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 64}, + {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 512, 64}}} + }, + // kv shape + {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 64}, + {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 512, 64}}} + }, + // attn shape: [B, 1, -1, L0+L1] + {ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1}, + {ov::Shape{1, 1, 100, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 512, 512}}} + }, + }, + // Currently unsupported + // heads number of kv is 1, attn mask: [B, H, L1, L0+L1] + // { + // // q shape + // {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + // {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + // }, + // // kv shape + // {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64}, + // {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}} + // }, + // // attn shape + // 
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1}, + // {ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}} + // }, + // }, +}; + +const auto params = testing::Combine(testing::Values(/* ov::element::f16, */ov::element::f32), + testing::ValuesIn(shapes), + testing::Values(true, false), + testing::Values(true, false), + testing::Values(true, false), + testing::Values(ov::test::utils::DEVICE_GPU)); + +INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU, + ScaledAttnLayerGPUTest, + params, + ScaledAttnLayerGPUTest::getTestCaseName); + +} // namespace diff --git a/src/plugins/intel_gpu/tests/unit/transformations/transpose_matmul_fusion_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/transpose_matmul_fusion_test.cpp index 61638930c3b63f..f97ac8f9c433a1 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/transpose_matmul_fusion_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/transpose_matmul_fusion_test.cpp @@ -13,7 +13,7 @@ #include "openvino/op/result.hpp" #include "intel_gpu/op/gemm.hpp" -#include "plugin/transformations/transpose_matmul_fusion.hpp" +#include "plugin/transformations/transpose_fusion.hpp" #include @@ -31,7 +31,7 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion1) { auto matmul = std::make_shared(input_a, input_b); model = std::make_shared(ov::NodeVector{ matmul }, ov::ParameterVector{ input_a, input_b }); - manager.register_pass(); + manager.register_pass(); } { std::vector order_a = {0, 1, 2, 3}; @@ -55,7 +55,7 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion2) { auto matmul = std::make_shared(tranpose_a, input_b); model = std::make_shared(ov::NodeVector{ matmul }, ov::ParameterVector{ input_a, input_b }); - manager.register_pass(); + manager.register_pass(); } { std::vector order_a = {0, 2, 1, 3}; @@ -81,7 +81,7 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion3) { auto matmul = std::make_shared(tranpose_a, tranpose_b); model = std::make_shared(ov::NodeVector{ matmul }, ov::ParameterVector{ input_a, input_b }); - manager.register_pass(); + manager.register_pass(); } { std::vector order_a = {0, 2, 1, 3}; @@ -109,7 +109,7 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion4) { auto tranpose_c = std::make_shared(matmul, tranpose_c_const); model = std::make_shared(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b }); - manager.register_pass(); + manager.register_pass(); } { std::vector order_a = {0, 2, 1, 3}; diff --git a/src/plugins/intel_gpu/tests/unit/transformations/transpose_sdpa_fusion_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/transpose_sdpa_fusion_test.cpp new file mode 100644 index 00000000000000..ebe15f4d806b31 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/transformations/transpose_sdpa_fusion_test.cpp @@ -0,0 +1,178 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_test_utils/ov_test_utils.hpp" + +#include "openvino/core/model.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/pass/manager.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/result.hpp" +#include "intel_gpu/op/sdpa.hpp" + +#include "plugin/transformations/transpose_fusion.hpp" + +#include + +using namespace testing; +using namespace ov::intel_gpu; + +namespace ov { +namespace test { +namespace intel_gpu { + +TEST_F(TransformationTestsF, TranposeSDPAFusion1) { + { + auto input_a = std::make_shared(ov::element::f32, 
ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto sdpa = std::make_shared(input_a, input_b, input_c, true); + + model = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + manager.register_pass(); + } + { + std::vector order_a = {0, 1, 2, 3}; + std::vector order_b = {0, 1, 2, 3}; + std::vector order_c = {0, 1, 2, 3}; + std::vector order_output = {0, 1, 2, 3}; + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto sdpa = std::make_shared(input_a, input_b, input_c, order_a, order_b, order_c, order_output, true, ov::element::undefined ); + + model_ref = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +TEST_F(TransformationTestsF, TranposeSDPAFusion2) { + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_a_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_a = std::make_shared(input_a, tranpose_a_const); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto sdpa = std::make_shared(tranpose_a, input_b, input_c, true); + + model = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + manager.register_pass(); + } + { + std::vector order_a = {0, 2, 1, 3}; + std::vector order_b = {0, 1, 2, 3}; + std::vector order_c = {0, 1, 2, 3}; + std::vector order_output = {0, 1, 2, 3}; + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto sdpa = std::make_shared(input_a, input_b, input_c, order_a, order_b, order_c, order_output, true, ov::element::undefined); + + model_ref = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +TEST_F(TransformationTestsF, TranposeSDPAFusion3) { + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_a_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_a = std::make_shared(input_a, tranpose_a_const); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {1, 2, 0, 3}); + auto tranpose_b = std::make_shared(input_b, tranpose_b_const); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + + auto sdpa = std::make_shared(tranpose_a, tranpose_b, input_c, false); + + model = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + manager.register_pass(); + } + { + std::vector order_a = {0, 2, 1, 3}; + std::vector order_b = {1, 2, 0, 3}; + std::vector order_c = {0, 1, 2, 3}; + std::vector order_output = {0, 1, 2, 3}; + auto input_a = 
std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto sdpa = std::make_shared(input_a, input_b, input_c, order_a, order_b, order_c, order_output, false, ov::element::undefined); + + model_ref = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +TEST_F(TransformationTestsF, TranposeSDPAFusion4) { + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_a_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_a = std::make_shared(input_a, tranpose_a_const); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_b = std::make_shared(input_b, tranpose_b_const); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_c = std::make_shared(input_c, tranpose_c_const); + + auto sdpa = std::make_shared(tranpose_a, tranpose_b, tranpose_c, false); + + model = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + manager.register_pass(); + } + { + std::vector order_a = {0, 2, 1, 3}; + std::vector order_b = {0, 2, 1, 3}; + std::vector order_c = {0, 2, 1, 3}; + std::vector order_output = {0, 1, 2, 3}; + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto sdpa = std::make_shared(input_a, input_b, input_c, order_a, order_b, order_c, order_output, false, ov::element::undefined); + + model_ref = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +TEST_F(TransformationTestsF, TranposeSDPAFusion5) { + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_a_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_a = std::make_shared(input_a, tranpose_a_const); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_b = std::make_shared(input_b, tranpose_b_const); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {3, 2, 1, 0}); + auto tranpose_c = std::make_shared(input_c, tranpose_c_const); + + auto sdpa = std::make_shared(tranpose_a, tranpose_b, tranpose_c, false); + + model = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + manager.register_pass(); + } + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_a_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_a = std::make_shared(input_a, tranpose_a_const); + auto input_b = 
std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + auto tranpose_b = std::make_shared(input_b, tranpose_b_const); + auto input_c = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {3, 2, 1, 0}); + auto tranpose_c = std::make_shared(input_c, tranpose_c_const); + + auto sdpa = std::make_shared(tranpose_a, tranpose_b, tranpose_c, false); + + model_ref = std::make_shared(ov::NodeVector{ sdpa }, ov::ParameterVector{ input_a, input_b, input_c }); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +} // namespace intel_gpu +} // namespace test +} // namespace ov From 013e0cfb144e5bccd838b09ade9fdb5f78821d13 Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Tue, 21 May 2024 13:17:00 +0400 Subject: [PATCH 2/4] Fix code style --- .../ocl/scaled_dot_product_attention.cpp | 3 ++ .../kernel_selector/cl_kernels/sdpa_opt.cl | 2 ++ .../kernels/sdpa/sdpa_kernel_base.h | 32 ++++++++++++++----- .../kernels/sdpa/sdpa_kernel_opt.cpp | 2 +- .../src/plugin/transformations_pipeline.cpp | 19 +++++++---- 5 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp index d9303f058814a2..d60098aca74588 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp @@ -34,6 +34,9 @@ struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl& order) { + if (order.empty()) + return pshape; + auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank()); for (size_t i = 0; i < order.size(); i++) { transposed_pshape[i] = pshape[order[i]]; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index f9a1d31bc434ee..14cef4010c6bea 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -102,6 +102,7 @@ inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint #ifdef SDPA_STAGE_0 #if TARGET_SEQ_LEN_BLOCK_SIZE == 1 +/* This version is used for 2nd token */ REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) KERNEL(sdpa_opt)( @@ -529,6 +530,7 @@ KERNEL(sdpa_opt)( } #else +/* This version is used for 1st token */ REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) KERNEL(sdpa_opt)( diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h index 319742860c6f73..644d9930f69c1f 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -51,14 +51,30 @@ struct TransposedDimensionAccessHelperJit : DimensionAccessHelperJit, Transposed std::string f() { return dims_sizes[transposed_order[1]]; } std::string b() { return dims_sizes[transposed_order[0]]; } - std::pair x_pad() { return {pad_before_after_sizes[(transposed_order[7] * 2) + 0], pad_before_after_sizes[(transposed_order[7] * 2) + 1]}; } - std::pair y_pad() { return {pad_before_after_sizes[(transposed_order[6] * 2) + 0], pad_before_after_sizes[(transposed_order[6] * 2) + 1]}; } - std::pair 
z_pad() { return {pad_before_after_sizes[(transposed_order[5] * 2) + 0], pad_before_after_sizes[(transposed_order[5] * 2) + 1]}; } - std::pair w_pad() { return {pad_before_after_sizes[(transposed_order[4] * 2) + 0], pad_before_after_sizes[(transposed_order[4] * 2) + 1]}; } - std::pair v_pad() { return {pad_before_after_sizes[(transposed_order[3] * 2) + 0], pad_before_after_sizes[(transposed_order[3] * 2) + 1]}; } - std::pair u_pad() { return {pad_before_after_sizes[(transposed_order[2] * 2) + 0], pad_before_after_sizes[(transposed_order[2] * 2) + 1]}; } - std::pair f_pad() { return {pad_before_after_sizes[(transposed_order[1] * 2) + 0], pad_before_after_sizes[(transposed_order[1] * 2) + 1]}; } - std::pair b_pad() { return {pad_before_after_sizes[(transposed_order[0] * 2) + 0], pad_before_after_sizes[(transposed_order[0] * 2) + 1]}; } + std::pair x_pad() { + return {pad_before_after_sizes[(transposed_order[7] * 2) + 0], pad_before_after_sizes[(transposed_order[7] * 2) + 1]}; + } + std::pair y_pad() { + return {pad_before_after_sizes[(transposed_order[6] * 2) + 0], pad_before_after_sizes[(transposed_order[6] * 2) + 1]}; + } + std::pair z_pad() { + return {pad_before_after_sizes[(transposed_order[5] * 2) + 0], pad_before_after_sizes[(transposed_order[5] * 2) + 1]}; + } + std::pair w_pad() { + return {pad_before_after_sizes[(transposed_order[4] * 2) + 0], pad_before_after_sizes[(transposed_order[4] * 2) + 1]}; + } + std::pair v_pad() { + return {pad_before_after_sizes[(transposed_order[3] * 2) + 0], pad_before_after_sizes[(transposed_order[3] * 2) + 1]}; + } + std::pair u_pad() { + return {pad_before_after_sizes[(transposed_order[2] * 2) + 0], pad_before_after_sizes[(transposed_order[2] * 2) + 1]}; + } + std::pair f_pad() { + return {pad_before_after_sizes[(transposed_order[1] * 2) + 0], pad_before_after_sizes[(transposed_order[1] * 2) + 1]}; + } + std::pair b_pad() { + return {pad_before_after_sizes[(transposed_order[0] * 2) + 0], pad_before_after_sizes[(transposed_order[0] * 2) + 1]}; + } }; struct GQA_configuration { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp index b64c9ffa16ce5d..581565874f7fbb 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp @@ -234,7 +234,7 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const { ScalarDescriptor num_of_partitions_scalar; num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32; - num_of_partitions_scalar.v.u32 = num_of_partitions; + num_of_partitions_scalar.v.u32 = static_cast(num_of_partitions); auto dispatch_data3 = SetDefault(prim_params, 2); kernel_data.kernels[2].params.workGroups.global = dispatch_data3.gws; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 1dad68a1ecd997..68b5b02780fd0f 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -323,13 +323,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // 2) Indirect inputs support // 3) GQA related optimization (Broadcast fusion) pass_config->set_callback([&](const std::shared_ptr node){ - // Known limitations: - // - The head size of all Q, K, and V inputs should be the same static value - // - The head size should be divisible by 16 - // - 
All inputs and outputs must have the same data type - // - The number of dimensions for each input is expected to be 4 - // - SDPA impl could be slower on GPUs with IMMAD support in non-LLM scenarios, - // because oneDNN can be used for those cases - SDPA requires DPAS support auto sdpa = std::dynamic_pointer_cast(node); const auto& query_ps = sdpa->get_input_partial_shape(0); const auto& key_ps = sdpa->get_input_partial_shape(1); @@ -341,10 +334,21 @@ void TransformationsPipeline::apply(std::shared_ptr func) { return use_sdpa; } + // Known limitations: + // - SDPA impl could be slower in non-LLM scenarios than decomposed version + if (func->get_variables().size() == 0) + return false; + + // - The data type of SDPA should be fp16 + if (sdpa->get_output_element_type(0) != ov::element::f16) + return false; + + // - The number of dimensions for each input is expected to be 4 if (query_ps.size() != 4 || key_ps.size() != 4 || value_ps.size() != 4) { return false; } + // - The head size of all Q, K, and V inputs should be the same static value if (query_ps[query_ps.size() - 1].is_dynamic() || key_ps[key_ps.size() - 1].is_dynamic() || value_ps[query_ps.size() - 1].is_dynamic()) { return false; } @@ -354,6 +358,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { return false; } + // - The head size should be divisible by 16 const auto optimal_subgroup_size = 16; if (query_ps[query_ps.size() - 1].is_dynamic() || query_ps[query_ps.size() - 1].get_length() > 256 || From a7002c548b404ce5a72c8f8452b534692005a150 Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Wed, 22 May 2024 11:35:05 +0400 Subject: [PATCH 3/4] Add custom property and fix debug build issue --- .../openvino/runtime/intel_gpu/properties.hpp | 8 ++++++ .../intel_gpu/src/graph/primitive_inst.cpp | 5 ---- .../kernels/sdpa/sdpa_kernel_base.cpp | 1 - .../kernels/sdpa/sdpa_kernel_base.h | 3 --- .../kernels/sdpa/sdpa_kernel_ref.cpp | 3 +++ .../transformations/transpose_fusion.hpp | 2 +- .../src/plugin/transformations_pipeline.cpp | 26 +++---------------- .../src/runtime/execution_config.cpp | 1 + 8 files changed, 16 insertions(+), 33 deletions(-) diff --git a/src/inference/include/openvino/runtime/intel_gpu/properties.hpp b/src/inference/include/openvino/runtime/intel_gpu/properties.hpp index 7f661d5b67a74a..185195e288805c 100644 --- a/src/inference/include/openvino/runtime/intel_gpu/properties.hpp +++ b/src/inference/include/openvino/runtime/intel_gpu/properties.hpp @@ -115,6 +115,14 @@ static constexpr Property host_task_priority{"GPU_HOST_TASK_ * @ingroup ov_runtime_ocl_gpu_prop_cpp_api */ static constexpr Property available_device_mem{"AVAILABLE_DEVICE_MEM_SIZE"}; + +/** + * @brief Turning on this key disables SDPA operation decomposition and keeps SDPA operation in the graph. + * Enabling SDPA optimization may provide performance improvements and memory usage reduction. + * This key serves as a recommendation and may be ignored in known sub-optimal cases. 
+ * @ingroup ov_runtime_ocl_gpu_prop_cpp_api + */ +static constexpr Property enable_sdpa_optimization{"GPU_ENABLE_SDPA_OPTIMIZATION"}; } // namespace hint /** diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 5772470bad54f8..6a71cbc8981587 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -394,7 +394,6 @@ void primitive_inst::update_shape() { } if (has_runtime_deps) { - std::cout << "Runtime deps\n"; OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("update_shape_sync: " + id())); if (!dependencies_events.empty() && queue_type == QueueTypes::out_of_order) { _network.get_stream().wait_for_events(dependencies_events); @@ -1456,11 +1455,7 @@ event::ptr primitive_inst::execute(const std::vector& events) { { GPU_DEBUG_PROFILED_STAGE(instrumentation::pipeline_stage::inference); - auto time0 = std::chrono::high_resolution_clock::now(); auto ev = _impl->execute(dependencies, *this); - auto time1 = std::chrono::high_resolution_clock::now(); - auto time_res0 = std::chrono::duration_cast(time1 - time0).count(); - GPU_DEBUG_TRACE_DETAIL << "Enqueu time = " << time_res0 << "\n"; GPU_DEBUG_IF(!debug_config->dump_profiling_data.empty()) { get_network().get_stream().wait_for_events({ev}); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp index a66618aa1f3f95..61028ef5348a1a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp @@ -97,7 +97,6 @@ JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const { TransposedDimensionAccessHelperJit dims_q(params.inputs[0], params.input0_order); jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN", dims_q.y())); - jit.AddConstant(MakeJitConstant("HEAD_SIZE", dims_q.x())); jit.AddConstant(MakeJitConstant("NUM_HEADS", dims_q.f())); TransposedDimensionAccessHelperJit dims_k(params.inputs[1], params.input1_order); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h index 644d9930f69c1f..1d4f30512df06b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -77,9 +77,6 @@ struct TransposedDimensionAccessHelperJit : DimensionAccessHelperJit, Transposed } }; -struct GQA_configuration { -}; - struct sdpa_configuration { int64_t head_size = -1; int64_t heads_num = -1; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp index 5ea3ccd4224c7c..a80f3c31dfc8f3 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp @@ -35,6 +35,9 @@ JitConstants SDPAKernelRef::GetJitConstants(const sdpa_params& params) const { auto acc_dt = params.inputs[0].GetDType(); jit.Merge(MakeTypeJitConstants(acc_dt, "ACCUMULATOR")); + TransposedDimensionAccessHelperJit dims_q(params.inputs[0], params.input0_order); + jit.AddConstant(MakeJitConstant("HEAD_SIZE", dims_q.x())); + return jit; } diff --git 
a/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp index b43b74adf396d5..a9b3ebe05317f3 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.hpp @@ -11,7 +11,7 @@ namespace intel_gpu { class TransposeFusion: public ov::pass::GraphRewrite { public: - OPENVINO_RTTI("TransposeMatMulFusion", "0"); + OPENVINO_RTTI("TransposeFusion", "0"); TransposeFusion(); }; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 68b5b02780fd0f..5d8db18151cd4e 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -143,19 +143,6 @@ #include "transformations/rt_info/keep_const_precision.hpp" #include "transformations/smart_reshape/matmul_sr.hpp" -template -T convert_to(const std::string &str) { - std::istringstream ss(str); - T res; - ss >> res; - return res; -} - -template <> -std::string convert_to(const std::string &str) { - return str; -} - namespace { template static bool disable_reduce_decomposition(const std::shared_ptr node) { @@ -318,22 +305,15 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); - // Disable SDPA decomposition once additional transformations are added: - // 1) Input/Output Transpose fusion - // 2) Indirect inputs support - // 3) GQA related optimization (Broadcast fusion) pass_config->set_callback([&](const std::shared_ptr node){ + if (!config.get_property(ov::intel_gpu::hint::enable_sdpa_optimization)) + return false; + auto sdpa = std::dynamic_pointer_cast(node); const auto& query_ps = sdpa->get_input_partial_shape(0); const auto& key_ps = sdpa->get_input_partial_shape(1); const auto& value_ps = sdpa->get_input_partial_shape(2); - if (const auto env_var = std::getenv("USE_SDPA")) { - bool use_sdpa = convert_to(env_var); - std::cout << "Use SDPA forced to " << (use_sdpa ? 
"TRUE" : "FALSE") << "\n"; - return use_sdpa; - } - // Known limitations: // - SDPA impl could be slower in non-LLM scenarios than decomposed version if (func->get_variables().size() == 0) diff --git a/src/plugins/intel_gpu/src/runtime/execution_config.cpp b/src/plugins/intel_gpu/src/runtime/execution_config.cpp index 8a57759bff9413..66b8d3e70cab1f 100644 --- a/src/plugins/intel_gpu/src/runtime/execution_config.cpp +++ b/src/plugins/intel_gpu/src/runtime/execution_config.cpp @@ -50,6 +50,7 @@ void ExecutionConfig::set_default() { std::make_tuple(ov::intel_gpu::hint::host_task_priority, ov::hint::Priority::MEDIUM), std::make_tuple(ov::intel_gpu::hint::queue_throttle, ov::intel_gpu::hint::ThrottleLevel::MEDIUM), std::make_tuple(ov::intel_gpu::hint::queue_priority, ov::hint::Priority::MEDIUM), + std::make_tuple(ov::intel_gpu::hint::enable_sdpa_optimization, false), std::make_tuple(ov::intel_gpu::enable_loop_unrolling, true), std::make_tuple(ov::intel_gpu::disable_winograd_convolution, false), std::make_tuple(ov::internal::exclusive_async_requests, false), From b51bc3bd3ac97c8dd9ac1c5160d72c7bf744cb29 Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Wed, 22 May 2024 13:22:41 +0400 Subject: [PATCH 4/4] Remove SDPA tests from skip_config --- .../skip_tests_config.cpp | 2 - .../dynamic/scaled_dot_product_attention.cpp | 79 +++++++++++++------ 2 files changed, 56 insertions(+), 25 deletions(-) diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 1d1b9e8dd9e9a5..11360f1fe80faa 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -205,8 +205,6 @@ std::vector disabledTestPatterns() { R"(.*smoke_RDFT_5d_last_axis/RDFTLayerTest.Inference/IS=\(10.4.8.2.5\)_modelType=f32_Axes=\(0.1.2.3.4\)_SignalSize=\(\).*)", // Issue: 136862 R"(.*smoke_ConditionGPUTest_static/StaticConditionLayerGPUTest.CompareWithRefs/IS=\(3.6\)_netPRC=i8_ifCond=PARAM_targetDevice=GPU_.*)", - // Uncomment once SDPA decomposition is disabled - R"(.*smoke_ScaledAttn_GPU.*)", #if defined(_WIN32) R"(.*smoke_RemoteTensor/OVRemoteTensorBatched_Test.NV12toBGR_buffer/(num_batch_4|num_batch_2).*)", R"(.*smoke_Check/ConstantResultSubgraphTest.Inference/SubgraphType=SINGLE_COMPONENT_IS=\[1,3,10,10\]_IT=i16_Device=GPU.*)", diff --git a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp index 29498e65965d37..3b97cde5cfe636 100644 --- a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp @@ -16,6 +16,8 @@ #include "openvino/op/result.hpp" #include "openvino/op/matmul.hpp" +#include "intel_gpu/runtime/execution_config.hpp" + namespace { using ov::test::InputShape; @@ -103,23 +105,55 @@ void ScaledAttnLayerGPUTest::SetUp() { inputParams.back()->set_friendly_name("scale"); } } + + // Add artificial read/value operations to the model to trigger the enabling of the SDPA operation + auto read_key = std::make_shared(inputParams.at(1), "v0"); + auto assign_key = std::make_shared(read_key, "v0"); + + auto read_value = std::make_shared(inputParams.at(2), "v0"); + auto assign_value = 
std::make_shared(read_value, "v0"); + ov::OutputVector inputs; - for (auto& input : inputParams) { - inputs.push_back(input); + for (size_t i = 0; i < inputParams.size(); i++) { + if (i == 1) + inputs.push_back(read_key); + else if (i == 2) + inputs.push_back(read_value); + else + inputs.push_back(inputParams[i]); } + auto sdp = std::make_shared(inputs, is_causal); sdp->set_friendly_name("sdpa"); auto output = std::make_shared(sdp->output(0)); - function = std::make_shared(ov::OutputVector{output}, inputParams, "sdpa_model"); + function = std::make_shared(ov::OutputVector{output}, ov::SinkVector{assign_key, assign_value}, inputParams, "sdpa_model"); functionRefs = function->clone(); ov::pass::Manager manager; - // decompose ScaledDotProductAttention + // Decompose ScaledDotProductAttention manager.register_pass(); manager.run_passes(functionRefs); + + // Enable SDPA + configuration.insert(ov::intel_gpu::hint::enable_sdpa_optimization(true)); + + auto it = std::find_if(inputShapes[1].second.begin(), inputShapes[1].second.end(), [&](const ov::Shape& shape){ + return shape[2] >= 384; + }); + + bool has_long_seq = it != inputShapes[1].second.end(); + if (inType == ov::element::f16) { + if (has_long_seq) { + abs_threshold = 0.025; + rel_threshold = 0.025; + } else { + abs_threshold = 0.005; + rel_threshold = 0.005; + } + } } void ScaledAttnLayerGPUTest::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -171,36 +205,35 @@ const std::vector> shapes{ { // q shape {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 64}, - {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 512, 64}}} + {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 384, 64}}} }, // kv shape {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 64}, - {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 512, 64}}} + {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 384, 64}}} }, // attn shape: [B, 1, -1, L0+L1] {ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1}, - {ov::Shape{1, 1, 100, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 512, 512}}} + {ov::Shape{1, 1, 100, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 384, 384}}} }, }, - // Currently unsupported // heads number of kv is 1, attn mask: [B, H, L1, L0+L1] - // { - // // q shape - // {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, - // {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} - // }, - // // kv shape - // {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64}, - // {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}} - // }, - // // attn shape - // {ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1}, - // {ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}} - // }, - // }, + { + // q shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // kv shape + {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64}, + {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}} + }, + // attn shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1}, + {ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}} + }, + }, }; -const auto params = testing::Combine(testing::Values(/* ov::element::f16, */ov::element::f32), +const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */), testing::ValuesIn(shapes), testing::Values(true, false), testing::Values(true, 
false),
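
    The test hunk above routes the key and value parameters through ReadValue/Assign pairs so that the model owns at least one variable, which is exactly what the plugin callback checks (func->get_variables().size() != 0) before keeping SDPA fused. The following is only a sketch of that pattern written out as a self-contained helper: opset-3 ReadValue/Assign are assumed because their constructors take a plain string variable id (matching the arguments used in the hunk), and the distinct variable ids, fixed f16 precision, and non-causal flag are illustrative choices rather than patch content.

        #include <memory>
        #include <openvino/core/model.hpp>
        #include <openvino/op/assign.hpp>
        #include <openvino/op/parameter.hpp>
        #include <openvino/op/read_value.hpp>
        #include <openvino/op/result.hpp>
        #include <openvino/op/scaled_dot_product_attention.hpp>

        // Builds a minimal stateful SDPA model, mirroring what the test's SetUp() does:
        // K and V go through ReadValue/Assign so that model->get_variables() is non-empty
        // and the transformation callback keeps ScaledDotProductAttention undecomposed.
        std::shared_ptr<ov::Model> make_stateful_sdpa_model(const ov::PartialShape& qkv_shape) {
            auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, qkv_shape);
            auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, qkv_shape);
            auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, qkv_shape);

            // Opset-3 ReadValue/Assign assumed; distinct ids are used here for clarity.
            auto read_k   = std::make_shared<ov::op::v3::ReadValue>(k, "v0");
            auto assign_k = std::make_shared<ov::op::v3::Assign>(read_k, "v0");
            auto read_v   = std::make_shared<ov::op::v3::ReadValue>(v, "v1");
            auto assign_v = std::make_shared<ov::op::v3::Assign>(read_v, "v1");

            auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
                ov::OutputVector{q, read_k, read_v}, /*causal=*/false);
            auto result = std::make_shared<ov::op::v0::Result>(sdpa->output(0));

            // Assign nodes are registered as sinks so they are kept in the model.
            return std::make_shared<ov::Model>(ov::OutputVector{result},
                                               ov::SinkVector{assign_k, assign_v},
                                               ov::ParameterVector{q, k, v},
                                               "sdpa_model");
        }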
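    For completeness, a minimal application-side sketch of opting into the new behaviour through the property introduced in this series. It assumes enable_sdpa_optimization is a plain boolean Property, consistent with the `false` default registered in execution_config.cpp above; the model path is a placeholder, and the request remains advisory and subject to the callback's eligibility checks.

        #include <openvino/openvino.hpp>
        #include <openvino/runtime/intel_gpu/properties.hpp>

        int main() {
            ov::Core core;

            // "model.xml" is a placeholder; to actually hit the fused kernel the model
            // should be stateful, run in f16, and use 4D Q/K/V with a static head size
            // that is a multiple of 16 (see the callback in transformations_pipeline.cpp).
            auto model = core.read_model("model.xml");

            // Advisory hint only: the plugin may still decompose SDPA in the known
            // sub-optimal cases listed above.
            auto compiled = core.compile_model(
                model, "GPU", ov::intel_gpu::hint::enable_sdpa_optimization(true));

            ov::InferRequest request = compiled.create_infer_request();
            return 0;
        }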