-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] Add initial SDPA implementation #24466
Merged
Merged
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cstdint>
#include <memory>
#include <numeric>
#include <vector>

#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

/// @brief GPU-plugin internal ScaledDotProductAttention operation.
///
/// Extends ov::op::v13::ScaledDotProductAttention with per-tensor axis
/// permutations (so surrounding transposes can be fused into the op), an
/// explicit causal-mask flag, and an optional output element type override.
class SDPA : public ov::op::v13::ScaledDotProductAttention {
public:
    OPENVINO_OP("SDPA", "gpu_opset");

    SDPA() = default;

    /// @brief Q/K/V-only variant (no attention-mask and no scale inputs).
    /// @param Q, K, V     Query / key / value inputs.
    /// @param order_q     Axis permutation applied to Q (empty or identity means no transpose).
    /// @param order_k     Axis permutation applied to K.
    /// @param order_v     Axis permutation applied to V.
    /// @param order_out   Axis permutation applied to the output.
    /// @param is_causal   If true, causal attention masking is assumed.
    /// @param output_type Output element type; ov::element::undefined keeps the inferred type.
    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    /// @brief Variant with an explicit attention-mask input (ignored when is_causal is true).
    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const ov::Output<Node>& attn_mask,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    /// @brief Variant with explicit attention-mask and scale inputs.
    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const ov::Output<Node>& attn_mask,
         const ov::Output<Node>& scale,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    bool get_causal() const { return m_is_causal; }

    std::vector<int64_t> get_input0_transpose_order() const { return m_order_q; }
    std::vector<int64_t> get_input1_transpose_order() const { return m_order_k; }
    std::vector<int64_t> get_input2_transpose_order() const { return m_order_v; }
    std::vector<int64_t> get_output_transpose_order() const { return m_order_out; }
    ov::element::Type get_output_type() const { return m_output_type; }

    /// @brief Returns the identity permutation {0, 1, ..., rank-1}.
    static std::vector<int64_t> default_order(size_t rank) {
        std::vector<int64_t> order(rank);
        std::iota(order.begin(), order.end(), 0);
        return order;
    }

protected:
    std::vector<int64_t> m_order_q;
    std::vector<int64_t> m_order_k;
    std::vector<int64_t> m_order_v;
    std::vector<int64_t> m_order_out;
    // In-class initializers: the defaulted ctor previously left m_is_causal
    // (a plain bool) indeterminate, which is UB to read.
    bool m_is_causal = false;
    ov::element::Type m_output_type = ov::element::undefined;
};

/// @brief Shape inference for SDPA that accounts for the input/output transpose orders.
std::vector<ov::PartialShape> shape_infer(const SDPA* op,
                                          std::vector<ov::PartialShape> input_shapes,
                                          const std::vector<int64_t>& order_q,
                                          const std::vector<int64_t>& order_k,
                                          const std::vector<int64_t>& order_v,
                                          const std::vector<int64_t>& order_out);

}  // namespace op
}  // namespace intel_gpu
}  // namespace ov
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

#include <cstdint>
#include <vector>

namespace cldnn {

/// @brief Scaled dot-product attention (SDPA) primitive.
///
/// Inputs are query, keys, values and, optionally, an attention mask and a
/// scale. Optional transpose orders describe how each input (and the output)
/// is permuted; an empty order means identity.
struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
    CLDNN_DECLARE_PRIMITIVE(scaled_dot_product_attention)

    scaled_dot_product_attention() : primitive_base("", {}) {}

    /// @brief Constructs scaled_dot_product_attention primitive.
    /// @param id This primitive id.
    /// @param inputs Input data primitives id (query, keys, values, [attention_mask], [scale]).
    /// @param is_causal If true, assumes causal attention masking. In this case attention_mask input is ignored.
    /// @param input_q_transpose_order Axis permutation for the query input; empty means identity.
    /// @param input_k_transpose_order Axis permutation for the keys input; empty means identity.
    /// @param input_v_transpose_order Axis permutation for the values input; empty means identity.
    /// @param output_transpose_order Axis permutation for the output; empty means identity.
    // NOTE: `inputs` is taken by const reference — the original const-by-value
    // parameter copied the whole vector on every construction.
    scaled_dot_product_attention(const primitive_id& id,
                                 const std::vector<cldnn::input_info>& inputs,
                                 bool is_causal,
                                 const std::vector<int64_t>& input_q_transpose_order = {},
                                 const std::vector<int64_t>& input_k_transpose_order = {},
                                 const std::vector<int64_t>& input_v_transpose_order = {},
                                 const std::vector<int64_t>& output_transpose_order = {},
                                 const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding})
        , is_causal(is_causal)
        , has_attn_mask_input(inputs.size() > 3)
        , has_scale_input(inputs.size() > 4)
        , input_q_transpose_order(input_q_transpose_order)
        , input_k_transpose_order(input_k_transpose_order)
        , input_v_transpose_order(input_v_transpose_order)
        , output_transpose_order(output_transpose_order) {}

    bool is_causal = false;
    // The two flags below are derived from the number of inputs in the ctor
    // (4th input = attention mask, 5th input = scale).
    bool has_attn_mask_input = false;
    bool has_scale_input = false;

    std::vector<int64_t> input_q_transpose_order;
    std::vector<int64_t> input_k_transpose_order;
    std::vector<int64_t> input_v_transpose_order;
    std::vector<int64_t> output_transpose_order;

    /// @brief Hashes every field that affects kernel selection/compilation.
    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, is_causal);
        seed = hash_combine(seed, has_attn_mask_input);
        seed = hash_combine(seed, has_scale_input);
        seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
        seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
        seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
        seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
        return seed;
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;

        // const auto& avoids copying the whole primitive just to compare it
        // (plain `auto` in the original made a copy of the downcast result).
        const auto& rhs_casted = downcast<const scaled_dot_product_attention>(rhs);

        return is_causal == rhs_casted.is_causal &&
               has_attn_mask_input == rhs_casted.has_attn_mask_input &&
               has_scale_input == rhs_casted.has_scale_input &&
               input_q_transpose_order == rhs_casted.input_q_transpose_order &&
               input_k_transpose_order == rhs_casted.input_k_transpose_order &&
               input_v_transpose_order == rhs_casted.input_v_transpose_order &&
               output_transpose_order == rhs_casted.output_transpose_order;
    }

    /// @brief Serializes all fields; must stay in sync with load().
    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<scaled_dot_product_attention>::save(ob);
        ob << is_causal;
        ob << has_attn_mask_input;
        ob << has_scale_input;
        ob << input_q_transpose_order;
        ob << input_k_transpose_order;
        ob << input_v_transpose_order;
        ob << output_transpose_order;
    }

    /// @brief Deserializes all fields; must stay in sync with save().
    void load(BinaryInputBuffer& ib) override {
        primitive_base<scaled_dot_product_attention>::load(ib);
        ib >> is_causal;
        ib >> has_attn_mask_input;
        ib >> has_scale_input;
        ib >> input_q_transpose_order;
        ib >> input_k_transpose_order;
        ib >> input_v_transpose_order;
        ib >> output_transpose_order;
    }
};
}  // namespace cldnn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
138 changes: 138 additions & 0 deletions
138
src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "primitive_base.hpp" | ||
#include "scaled_dot_product_attention_inst.h" | ||
#include "sdpa/sdpa_kernel_selector.h" | ||
#include "sdpa/sdpa_kernel_base.h" | ||
|
||
namespace cldnn { | ||
namespace ocl { | ||
/// OCL implementation of the scaled_dot_product_attention primitive: converts
/// the primitive's parameters into kernel-selector SDPA params and dispatches
/// the selected kernel for both static and dynamic shapes.
struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_product_attention> {
    using parent = typed_primitive_impl_ocl<scaled_dot_product_attention>;
    using parent::parent;
    using kernel_selector_t = kernel_selector::sdpa_kernel_selector;
    using kernel_params_t = kernel_selector::sdpa_params;

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::scaled_dot_product_attention_impl)

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<scaled_dot_product_attention_impl>(*this);
    }

    // The dispatch-data update callback is not serialized with the kernel data,
    // so after deserialization it must be re-queried from the kernel selector
    // (only needed for dynamic shapes, where dispatch data is recalculated).
    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        if (is_dynamic()) {
            auto& kernel_selector = kernel_selector_t::Instance();
            auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
            kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
        }
    }

    // Builds the SDPA kernel configuration (head size, KV broadcast axis and
    // group size for grouped-query attention, causal flag) from the transposed
    // input shapes.
    static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
        kernel_selector::sdpa_configuration config;

        // Applies the given axis permutation to a partial shape; an empty
        // order means identity (no transpose).
        auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order) {
            if (order.empty())
                return pshape;

            auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
            for (size_t i = 0; i < order.size(); i++) {
                transposed_pshape[i] = pshape[order[i]];
            }
            return transposed_pshape;
        };

        const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
        const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), prim->input_q_transpose_order);
        const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), prim->input_k_transpose_order);
        const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), prim->input_v_transpose_order);

        OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");

        // Grouped-/multi-query attention: a static axis where Q is larger than
        // K/V becomes the broadcast axis with the corresponding group size.
        for (size_t i = 0; i < query_shape.size(); ++i) {
            if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
                if (query_shape[i].get_length() > key_shape[i].get_length()) {
                    // An empty transpose order means identity, so fall back to
                    // the axis index itself; the original code indexed the
                    // (possibly empty) order vector unconditionally, which is UB.
                    config.broadcast_axis = prim->input_k_transpose_order.empty()
                                                ? static_cast<int64_t>(i)
                                                : prim->input_k_transpose_order[i];
                    config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
                }
            }
        }

        // Head size is the innermost (last) dimension of the transposed query.
        if (query_shape[query_shape.size() - 1].is_static())
            config.head_size = query_shape[query_shape.size() - 1].get_length();

        config.is_causal = prim->is_causal;

        return config;
    }

    // Converts impl params into kernel-selector SDPA params: per-input
    // tensors, SDPA configuration, transpose orders, dynamic-shape offsets.
    static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) {
        auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);

        // Register every input (3 to 5: Q, K, V, [attn_mask], [scale]).
        const auto inputs_num = impl_param.input_layouts.size();
        params.inputs.resize(inputs_num);
        for (size_t i = 0; i < inputs_num; i++) {
            params.inputs[i] = convert_data_tensor(impl_param.get_input_layout(i));
        }

        params.conf = get_sdpa_configuration(impl_param);

        const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
        params.input0_order = prim->input_q_transpose_order;
        params.input1_order = prim->input_k_transpose_order;
        params.input2_order = prim->input_v_transpose_order;
        params.output_order = prim->output_transpose_order;

        params.set_dynamic_shape_offsets();

        return params;
    }

    static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg, const kernel_impl_params& impl_param) {
        auto sdpa_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic());
        auto& sdpa_kernel_selector = kernel_selector_t::Instance();
        auto kd = sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params);

        return cldnn::make_unique<scaled_dot_product_attention_impl>(kd);
    }

    // Recomputes kernel dispatch data for the actual shapes of a dynamic
    // execution via the callback restored in load()/set by the selector.
    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kernel_params = get_kernel_params(impl_param, true);
        (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
    }
};
|
||
namespace detail { | ||
|
||
attach_scaled_dot_product_attention_impl::attach_scaled_dot_product_attention_impl() { | ||
using sdpa_prim = scaled_dot_product_attention; | ||
|
||
auto types = { | ||
data_types::f32, | ||
data_types::f16, | ||
}; | ||
|
||
auto formats = { | ||
format::bfyx, | ||
}; | ||
|
||
implementation_map<sdpa_prim>::add(impl_types::ocl, | ||
shape_types::static_shape, | ||
scaled_dot_product_attention_impl::create, | ||
types, | ||
formats); | ||
|
||
implementation_map<sdpa_prim>::add(impl_types::ocl, | ||
shape_types::dynamic_shape, | ||
scaled_dot_product_attention_impl::create, | ||
types, | ||
formats); | ||
} | ||
|
||
} // namespace detail | ||
} // namespace ocl | ||
} // namespace cldnn | ||
|
||
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::scaled_dot_product_attention_impl) | ||
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::scaled_dot_product_attention) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the default value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ilya-lavrenov currently it's disabled by default. However, in the final version, it will depend on whether support for indirect inputs is implemented for SDPA in time or not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This allows to switch on for models where indirect inputs are not required