[GPU] Add SDPA impl; SDPA input transpose fusion support; GQA optimization
Showing 47 changed files with 3,598 additions and 61 deletions.
@@ -0,0 +1,94 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <numeric>  // std::iota, used by default_order()

#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

class SDPA : public ov::op::v13::ScaledDotProductAttention {
public:
    OPENVINO_OP("SDPA", "gpu_opset");

    SDPA() = default;

    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const ov::Output<Node>& attn_mask,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const ov::Output<Node>& attn_mask,
         const ov::Output<Node>& scale,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    bool get_causal() const { return m_is_causal; }

    std::vector<int64_t> get_input0_transpose_order() const { return m_order_q; }
    std::vector<int64_t> get_input1_transpose_order() const { return m_order_k; }
    std::vector<int64_t> get_input2_transpose_order() const { return m_order_v; }
    std::vector<int64_t> get_output_transpose_order() const { return m_order_out; }
    ov::element::Type get_output_type() const { return m_output_type; }

    // Identity permutation {0, 1, ..., rank-1}, used when no transpose is fused.
    static std::vector<int64_t> default_order(size_t rank) {
        std::vector<int64_t> order(rank);
        std::iota(order.begin(), order.end(), 0);
        return order;
    }

protected:
    std::vector<int64_t> m_order_q;
    std::vector<int64_t> m_order_k;
    std::vector<int64_t> m_order_v;
    std::vector<int64_t> m_order_out;
    bool m_is_causal = false;  // initialized so the defaulted constructor leaves no indeterminate state
    ov::element::Type m_output_type;
};

std::vector<ov::PartialShape> shape_infer(const SDPA* op,
                                          std::vector<ov::PartialShape> input_shapes,
                                          const std::vector<int64_t>& order_q,
                                          const std::vector<int64_t>& order_k,
                                          const std::vector<int64_t>& order_v,
                                          const std::vector<int64_t>& order_out);

}  // namespace op
}  // namespace intel_gpu
}  // namespace ov
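The per-input transpose orders are what make the fusion work: each order is a permutation mapping the op's actual input axes to the logical Q/K/V layout, so a preceding Transpose node can be absorbed into the SDPA op instead of executing separately. A minimal standalone sketch of how such an order is applied to a shape (the helper name `apply_order` is illustrative, not part of the OpenVINO API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the plugin's shape transposition:
// order[i] names the source axis that lands at position i.
static std::vector<int64_t> apply_order(const std::vector<int64_t>& shape,
                                        const std::vector<int64_t>& order) {
    std::vector<int64_t> out(shape.size());
    for (size_t i = 0; i < order.size(); ++i)
        out[i] = shape[order[i]];
    return out;
}

int main() {
    // A query produced as [batch, seq_len, heads, head_size]; order {0, 2, 1, 3}
    // fuses the Transpose that would yield [batch, heads, seq_len, head_size].
    std::vector<int64_t> q_shape = {1, 128, 32, 64};
    std::vector<int64_t> order_q = {0, 2, 1, 3};
    for (int64_t d : apply_order(q_shape, order_q))
        std::printf("%lld ", static_cast<long long>(d));  // prints: 1 32 128 64
    std::printf("\n");
    return 0;
}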
95 changes: 95 additions & 0 deletions
src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp
@@ -0,0 +1,95 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

namespace cldnn {

struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
    CLDNN_DECLARE_PRIMITIVE(scaled_dot_product_attention)

    scaled_dot_product_attention() : primitive_base("", {}) {}

    /// @brief Constructs scaled_dot_product_attention primitive.
    /// @param id This primitive id.
    /// @param inputs Input data primitive ids (query, keys, values, [attention_mask], [scale]).
    /// @param is_causal If true, causal attention masking is assumed and the attention_mask input is ignored.
    scaled_dot_product_attention(const primitive_id& id,
                                 const std::vector<cldnn::input_info>& inputs,
                                 bool is_causal,
                                 const std::vector<int64_t>& input_q_transpose_order = {},
                                 const std::vector<int64_t>& input_k_transpose_order = {},
                                 const std::vector<int64_t>& input_v_transpose_order = {},
                                 const std::vector<int64_t>& output_transpose_order = {},
                                 const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding})
        , is_causal(is_causal)
        , has_attn_mask_input(inputs.size() > 3)
        , has_scale_input(inputs.size() > 4)
        , input_q_transpose_order(input_q_transpose_order)
        , input_k_transpose_order(input_k_transpose_order)
        , input_v_transpose_order(input_v_transpose_order)
        , output_transpose_order(output_transpose_order) {}

    bool is_causal = false;
    bool has_attn_mask_input = false;
    bool has_scale_input = false;

    std::vector<int64_t> input_q_transpose_order;
    std::vector<int64_t> input_k_transpose_order;
    std::vector<int64_t> input_v_transpose_order;
    std::vector<int64_t> output_transpose_order;

    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, is_causal);
        seed = hash_combine(seed, has_attn_mask_input);
        seed = hash_combine(seed, has_scale_input);
        seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
        seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
        seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
        seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
        return seed;
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;

        auto rhs_casted = downcast<const scaled_dot_product_attention>(rhs);

        return is_causal == rhs_casted.is_causal &&
               has_attn_mask_input == rhs_casted.has_attn_mask_input &&
               has_scale_input == rhs_casted.has_scale_input &&
               input_q_transpose_order == rhs_casted.input_q_transpose_order &&
               input_k_transpose_order == rhs_casted.input_k_transpose_order &&
               input_v_transpose_order == rhs_casted.input_v_transpose_order &&
               output_transpose_order == rhs_casted.output_transpose_order;
    }

    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<scaled_dot_product_attention>::save(ob);
        ob << is_causal;
        ob << has_attn_mask_input;
        ob << has_scale_input;
        ob << input_q_transpose_order;
        ob << input_k_transpose_order;
        ob << input_v_transpose_order;
        ob << output_transpose_order;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<scaled_dot_product_attention>::load(ib);
        ib >> is_causal;
        ib >> has_attn_mask_input;
        ib >> has_scale_input;
        ib >> input_q_transpose_order;
        ib >> input_k_transpose_order;
        ib >> input_v_transpose_order;
        ib >> output_transpose_order;
    }
};
}  // namespace cldnn
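Note that the primitive derives `has_attn_mask_input` and `has_scale_input` purely from the input count (3 = Q/K/V only, 4 adds the mask, 5 adds the scale), and `hash()` folds every field into the seed so cached kernels are keyed on the full configuration. A self-contained sketch of that hashing pattern, assuming a boost-style combine; the free functions below are stand-ins for cldnn's `hash_combine`/`hash_range`, not the real ones:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Stand-in for cldnn::hash_combine: boost-style mixing of one value.
template <typename T>
static size_t hash_combine(size_t seed, const T& v) {
    return seed ^ (std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Stand-in for cldnn::hash_range: fold every element of a range.
template <typename It>
static size_t hash_range(size_t seed, It first, It last) {
    for (; first != last; ++first)
        seed = hash_combine(seed, *first);
    return seed;
}

int main() {
    bool is_causal = true;
    bool has_attn_mask_input = false;  // inputs.size() > 3
    bool has_scale_input = false;      // inputs.size() > 4
    std::vector<int64_t> order_q = {0, 2, 1, 3};

    size_t seed = 0;
    seed = hash_combine(seed, is_causal);
    seed = hash_combine(seed, has_attn_mask_input);
    seed = hash_combine(seed, has_scale_input);
    seed = hash_range(seed, order_q.begin(), order_q.end());
    std::printf("%zu\n", seed);  // distinct configurations yield distinct cache keys
    return 0;
}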
135 changes: 135 additions & 0 deletions
src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp
@@ -0,0 +1,135 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"
#include "scaled_dot_product_attention_inst.h"
#include "sdpa/sdpa_kernel_selector.h"
#include "sdpa/sdpa_kernel_base.h"

namespace cldnn {
namespace ocl {
struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_product_attention> {
    using parent = typed_primitive_impl_ocl<scaled_dot_product_attention>;
    using parent::parent;
    using kernel_selector_t = kernel_selector::sdpa_kernel_selector;
    using kernel_params_t = kernel_selector::sdpa_params;

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::scaled_dot_product_attention_impl)

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<scaled_dot_product_attention_impl>(*this);
    }

    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        if (is_dynamic()) {
            auto& kernel_selector = kernel_selector_t::Instance();
            auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
            kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
        }
    }

    static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
        kernel_selector::sdpa_configuration config;

        // Apply a fused transpose order to a partial shape.
        auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order) {
            auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
            for (size_t i = 0; i < order.size(); i++) {
                transposed_pshape[i] = pshape[order[i]];
            }
            return transposed_pshape;
        };

        const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
        const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), prim->input_q_transpose_order);
        const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), prim->input_k_transpose_order);
        const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), prim->input_v_transpose_order);

        OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");

        // Grouped-query attention: if the query has more heads than key/value
        // along some static axis, record that axis and the head-count ratio so
        // the kernel can share one K/V head across a group of query heads.
        for (size_t i = 0; i < query_shape.size(); ++i) {
            if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
                if (query_shape[i].get_length() > key_shape[i].get_length()) {
                    config.broadcast_axis = prim->input_k_transpose_order[i];
                    config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
                }
            }
        }

        if (query_shape[query_shape.size() - 1].is_static())
            config.head_size = query_shape[query_shape.size() - 1].get_length();

        config.is_causal = prim->is_causal;

        return config;
    }

    static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) {
        auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);

        const auto inputs_num = impl_param.input_layouts.size();
        params.inputs.resize(inputs_num);
        for (size_t i = 0; i < inputs_num; i++) {
            params.inputs[i] = convert_data_tensor(impl_param.get_input_layout(i));
        }

        params.conf = get_sdpa_configuration(impl_param);

        const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
        params.input0_order = prim->input_q_transpose_order;
        params.input1_order = prim->input_k_transpose_order;
        params.input2_order = prim->input_v_transpose_order;
        params.output_order = prim->output_transpose_order;

        params.set_dynamic_shape_offsets();

        return params;
    }

    static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg, const kernel_impl_params& impl_param) {
        auto sdpa_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic());
        auto& sdpa_kernel_selector = kernel_selector_t::Instance();
        auto kd = sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params);

        return cldnn::make_unique<scaled_dot_product_attention_impl>(kd);
    }

    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kernel_params = get_kernel_params(impl_param, true);
        (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
    }
};

namespace detail {

attach_scaled_dot_product_attention_impl::attach_scaled_dot_product_attention_impl() {
    using sdpa_prim = scaled_dot_product_attention;

    auto types = {
        data_types::f32,
        data_types::f16,
    };

    auto formats = {
        format::bfyx,
    };

    implementation_map<sdpa_prim>::add(impl_types::ocl,
                                       shape_types::static_shape,
                                       scaled_dot_product_attention_impl::create,
                                       types,
                                       formats);

    implementation_map<sdpa_prim>::add(impl_types::ocl,
                                       shape_types::dynamic_shape,
                                       scaled_dot_product_attention_impl::create,
                                       types,
                                       formats);
}

}  // namespace detail
}  // namespace ocl
}  // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::scaled_dot_product_attention_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::scaled_dot_product_attention)
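The commit's GQA optimization hinges on the broadcast detection in `get_sdpa_configuration` above: where a static dimension of the transposed query shape exceeds the matching key/value dimension, that axis and the head-count ratio become the broadcast axis and group size, letting the kernel reuse one K/V head per group of query heads rather than materializing a broadcast. A standalone sketch of that detection under the same assumptions (static, already-transposed shapes; names are illustrative, and unlike the real code this sketch does not map the axis back through the K transpose order):

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative mirror of the GQA detection: compare transposed Q and K
// shapes axis by axis and record where Q has more heads than K/V.
struct GqaConfig {
    int64_t broadcast_axis = -1;  // axis where query heads outnumber K/V heads
    int64_t group_size = 1;       // query heads sharing one K/V head
};

static GqaConfig detect_gqa(const std::vector<int64_t>& query_shape,
                            const std::vector<int64_t>& key_shape) {
    GqaConfig config;
    for (size_t i = 0; i < query_shape.size(); ++i) {
        if (query_shape[i] > key_shape[i]) {
            config.broadcast_axis = static_cast<int64_t>(i);
            config.group_size = query_shape[i] / key_shape[i];
        }
    }
    return config;
}

int main() {
    // 32 query heads attending over 8 shared K/V heads -> groups of 4.
    std::vector<int64_t> q = {1, 32, 128, 64};
    std::vector<int64_t> k = {1, 8, 128, 64};
    auto cfg = detect_gqa(q, k);
    std::printf("axis=%lld group=%lld\n",
                static_cast<long long>(cfg.broadcast_axis),
                static_cast<long long>(cfg.group_size));  // axis=1 group=4
    return 0;
}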