[GPU] Add initial SDPA implementation #24466

Merged · 4 commits · May 22, 2024
@@ -115,6 +115,14 @@ static constexpr Property<ov::hint::Priority> host_task_priority{"GPU_HOST_TASK_
* @ingroup ov_runtime_ocl_gpu_prop_cpp_api
*/
static constexpr Property<int64_t> available_device_mem{"AVAILABLE_DEVICE_MEM_SIZE"};

/**
* @brief Turning on this key disables SDPA operation decomposition and keeps the SDPA operation in the graph.
* Enabling SDPA optimization may improve performance and reduce memory usage.
* This key serves as a recommendation and may be ignored in known sub-optimal cases.
* @ingroup ov_runtime_ocl_gpu_prop_cpp_api
Contributor: What is the default value?

Contributor (author): @ilya-lavrenov currently it's disabled by default. However, in the final version it will depend on whether support for indirect inputs is implemented for SDPA in time.

Contributor: This allows switching it on for models where indirect inputs are not required.

*/
static constexpr Property<bool> enable_sdpa_optimization{"GPU_ENABLE_SDPA_OPTIMIZATION"};
} // namespace hint

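For illustration, a minimal usage sketch (an assumption-based example, not part of this PR): it presumes the property is exposed as ov::intel_gpu::hint::enable_sdpa_optimization via openvino/runtime/intel_gpu/properties.hpp, and the model path is a placeholder.

#include <openvino/openvino.hpp>
#include <openvino/runtime/intel_gpu/properties.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder path
    // Ask the GPU plugin to keep SDPA fused rather than decomposing it;
    // per the doc comment above, the hint may be ignored in known sub-optimal cases.
    auto compiled = core.compile_model(model, "GPU",
                                       ov::intel_gpu::hint::enable_sdpa_optimization(true));
    return 0;
}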
94 changes: 94 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp
@@ -0,0 +1,94 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

class SDPA : public ov::op::v13::ScaledDotProductAttention {
public:
OPENVINO_OP("SDPA", "gpu_opset");

SDPA() = default;

SDPA(const ov::Output<Node>& Q,
const ov::Output<Node>& K,
const ov::Output<Node>& V,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const bool is_causal,
const ov::element::Type output_type = ov::element::undefined);

SDPA(const ov::Output<Node>& Q,
const ov::Output<Node>& K,
const ov::Output<Node>& V,
const ov::Output<Node>& attn_mask,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const bool is_causal,
const ov::element::Type output_type = ov::element::undefined);

SDPA(const ov::Output<Node>& Q,
const ov::Output<Node>& K,
const ov::Output<Node>& V,
const ov::Output<Node>& attn_mask,
const ov::Output<Node>& scale,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const bool is_causal,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

bool get_causal() const { return m_is_causal; }

std::vector<int64_t> get_input0_transpose_order() const { return m_order_q; }
std::vector<int64_t> get_input1_transpose_order() const { return m_order_k; }
std::vector<int64_t> get_input2_transpose_order() const { return m_order_v; }
std::vector<int64_t> get_output_transpose_order() const { return m_order_out; }
ov::element::Type get_output_type() const { return m_output_type; }

static std::vector<int64_t> default_order(size_t rank) {
std::vector<int64_t> order(rank);
std::iota(order.begin(), order.end(), 0);
return order;
}

protected:
std::vector<int64_t> m_order_q;
std::vector<int64_t> m_order_k;
std::vector<int64_t> m_order_v;
std::vector<int64_t> m_order_out;
bool m_is_causal;
ov::element::Type m_output_type;
};

std::vector<ov::PartialShape> shape_infer(const SDPA* op,
std::vector<ov::PartialShape> input_shapes,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out);


} // namespace op
} // namespace intel_gpu
} // namespace ov
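As a reading aid, a hedged construction sketch for this op; the shapes, element type, and variable names are illustrative assumptions, not taken from the PR.

#include <memory>
#include "intel_gpu/op/sdpa.hpp"
#include "openvino/op/parameter.hpp"

// Illustrative 4-D [batch, heads, seq_len, head_size] inputs (assumed shapes).
auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 8, -1, 64});
auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 8, -1, 64});
auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 8, -1, 64});

// Identity orders via the helper above; a transposed layout would pass e.g. {0, 2, 1, 3}.
auto order = ov::intel_gpu::op::SDPA::default_order(4);  // {0, 1, 2, 3}
auto sdpa = std::make_shared<ov::intel_gpu::op::SDPA>(q, k, v, order, order, order, order,
                                                      /*is_causal=*/false);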
@@ -263,6 +263,7 @@ REGISTER_FACTORY(v12, ScatterElementsUpdate);

// ------------------------------ Supported v13 ops ----------------------------- //
REGISTER_FACTORY(v13, Multinomial);
REGISTER_FACTORY(v13, ScaledDotProductAttention);

// ------------------------------ Supported v14 ops ----------------------------- //
REGISTER_FACTORY(v14, ROIAlignRotated);
@@ -283,3 +284,4 @@ REGISTER_FACTORY(internal, SwiGLU);
REGISTER_FACTORY(internal, IndirectGemm);
REGISTER_FACTORY(internal, Convolution);
REGISTER_FACTORY(internal, Placeholder);
REGISTER_FACTORY(internal, SDPA);
@@ -0,0 +1,95 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

namespace cldnn {

struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
CLDNN_DECLARE_PRIMITIVE(scaled_dot_product_attention)

scaled_dot_product_attention() : primitive_base("", {}) {}

/// @brief Constructs scaled_dot_product_attention primitive.
/// @param id This primitive id.
/// @param inputs Input data primitives id (query, keys, values, [attention_mask], [scale]).
/// @param is_causal If true, assumes causal attention masking. In this case attention_mask input is ignored.
scaled_dot_product_attention(const primitive_id& id,
const std::vector<cldnn::input_info>& inputs,
bool is_causal,
const std::vector<int64_t>& input_q_transpose_order = {},
const std::vector<int64_t>& input_k_transpose_order = {},
const std::vector<int64_t>& input_v_transpose_order = {},
const std::vector<int64_t>& output_transpose_order = {},
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding})
, is_causal(is_causal)
, has_attn_mask_input(inputs.size() > 3)
, has_scale_input(inputs.size() > 4)
, input_q_transpose_order(input_q_transpose_order)
, input_k_transpose_order(input_k_transpose_order)
, input_v_transpose_order(input_v_transpose_order)
, output_transpose_order(output_transpose_order) {}


bool is_causal = false;
bool has_attn_mask_input = false;
bool has_scale_input = false;

std::vector<int64_t> input_q_transpose_order;
std::vector<int64_t> input_k_transpose_order;
std::vector<int64_t> input_v_transpose_order;
std::vector<int64_t> output_transpose_order;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, is_causal);
seed = hash_combine(seed, has_attn_mask_input);
seed = hash_combine(seed, has_scale_input);
seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
return seed;
}

bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;

auto rhs_casted = downcast<const scaled_dot_product_attention>(rhs);

return is_causal == rhs_casted.is_causal &&
has_attn_mask_input == rhs_casted.has_attn_mask_input &&
has_scale_input == rhs_casted.has_scale_input &&
input_q_transpose_order == rhs_casted.input_q_transpose_order &&
input_k_transpose_order == rhs_casted.input_k_transpose_order &&
input_v_transpose_order == rhs_casted.input_v_transpose_order &&
output_transpose_order == rhs_casted.output_transpose_order;
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<scaled_dot_product_attention>::save(ob);
ob << is_causal;
ob << has_attn_mask_input;
ob << has_scale_input;
ob << input_q_transpose_order;
ob << input_k_transpose_order;
ob << input_v_transpose_order;
ob << output_transpose_order;
}

void load(BinaryInputBuffer& ib) override {
primitive_base<scaled_dot_product_attention>::load(ib);
ib >> is_causal;
ib >> has_attn_mask_input;
ib >> has_scale_input;
ib >> input_q_transpose_order;
ib >> input_k_transpose_order;
ib >> input_v_transpose_order;
ib >> output_transpose_order;
}
};
} // namespace cldnn
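For orientation, a hedged example of instantiating the primitive; the ids and transpose orders are illustrative, assuming [batch, seq, heads, head_size] inputs that the kernel should consume as [batch, heads, seq, head_size].

#include "intel_gpu/primitives/scaled_dot_product_attention.hpp"

// Four inputs => has_attn_mask_input is deduced as true, has_scale_input as false.
cldnn::scaled_dot_product_attention sdpa("sdpa",
    {cldnn::input_info("query"), cldnn::input_info("key"),
     cldnn::input_info("value"), cldnn::input_info("attn_mask")},
    /*is_causal=*/false,
    /*input_q_transpose_order=*/{0, 2, 1, 3},
    /*input_k_transpose_order=*/{0, 2, 1, 3},
    /*input_v_transpose_order=*/{0, 2, 1, 3},
    /*output_transpose_order=*/{0, 1, 2, 3});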
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/gemm.cpp
@@ -272,6 +272,9 @@ std::string gemm_inst::to_string(gemm_node const& node) {
gemm_info.add("transpose_input1", transpose_input1);
gemm_info.add("indirect_input0", indirect_input0);
gemm_info.add("indirect_input1", indirect_input1);
gemm_info.add("trasnpose_order_input0", desc->input0_transpose_order);
gemm_info.add("trasnpose_order_input1", desc->input1_transpose_order);
gemm_info.add("trasnpose_order_output", desc->output_transpose_order);
node_info->add("gemm info", gemm_info);
node_info->dump(primitive_description);

1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp
@@ -93,6 +93,7 @@ void register_implementations() {
REGISTER_OCL(eye);
REGISTER_OCL(unique_count);
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);
}

} // namespace ocl
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp
@@ -74,6 +74,7 @@
#include "intel_gpu/primitives/eye.hpp"
#include "intel_gpu/primitives/unique.hpp"
#include "intel_gpu/primitives/kv_cache.hpp"
#include "intel_gpu/primitives/scaled_dot_product_attention.hpp"

namespace cldnn {
namespace ocl {
@@ -172,6 +173,7 @@ REGISTER_OCL(gather_nonzero);
REGISTER_OCL(eye);
REGISTER_OCL(unique_count);
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);

#undef REGISTER_OCL

@@ -0,0 +1,138 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"
#include "scaled_dot_product_attention_inst.h"
#include "sdpa/sdpa_kernel_selector.h"
#include "sdpa/sdpa_kernel_base.h"

namespace cldnn {
namespace ocl {
struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_product_attention> {
using parent = typed_primitive_impl_ocl<scaled_dot_product_attention>;
using parent::parent;
using kernel_selector_t = kernel_selector::sdpa_kernel_selector;
using kernel_params_t = kernel_selector::sdpa_params;

DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::scaled_dot_product_attention_impl)

std::unique_ptr<primitive_impl> clone() const override {
return make_unique<scaled_dot_product_attention_impl>(*this);
}

void load(BinaryInputBuffer& ib) override {
parent::load(ib);
if (is_dynamic()) {
auto& kernel_selector = kernel_selector_t::Instance();
auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
}
}

static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
kernel_selector::sdpa_configuration config;

auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order) {
if (order.empty())
return pshape;

auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
for (size_t i = 0; i < order.size(); i++) {
transposed_pshape[i] = pshape[order[i]];
}
return transposed_pshape;
};
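// Example: order = {0, 2, 1, 3} maps a [B, L, H, S] pshape to [B, H, L, S].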

const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), prim->input_q_transpose_order);
const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), prim->input_k_transpose_order);
const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), prim->input_v_transpose_order);

OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");
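// Grouped-query attention detection: a static query dim larger than the matching
// key/value dim means key/value heads are shared across query-head groups and must
// be broadcast along that axis. E.g. (illustrative) Q = [B, 32, L, 64] with
// K/V = [B, 8, L, 64] yields broadcast_axis mapped through the key transpose order
// and group_size = 32 / 8 = 4.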
for (size_t i = 0; i < query_shape.size(); ++i) {
if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
if (query_shape[i].get_length() > key_shape[i].get_length()) {
config.broadcast_axis = prim->input_k_transpose_order[i];
config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
}
}
}

if (query_shape[query_shape.size() - 1].is_static())
config.head_size = query_shape[query_shape.size() - 1].get_length();

config.is_causal = prim->is_causal;

return config;
}

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) {
auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);

const auto inputs_num = impl_param.input_layouts.size();
params.inputs.resize(inputs_num);
for (size_t i = 0; i < inputs_num; i++) {
params.inputs[i] = convert_data_tensor(impl_param.get_input_layout(i));
}

params.conf = get_sdpa_configuration(impl_param);

const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
params.input0_order = prim->input_q_transpose_order;
params.input1_order = prim->input_k_transpose_order;
params.input2_order = prim->input_v_transpose_order;
params.output_order = prim->output_transpose_order;

params.set_dynamic_shape_offsets();

return params;
}

static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg, const kernel_impl_params& impl_param) {
Contributor: Does this impl differ from the common version in the base class?

auto sdpa_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic());
auto& sdpa_kernel_selector = kernel_selector_t::Instance();
auto kd = sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params);

return cldnn::make_unique<scaled_dot_product_attention_impl>(kd);
}

void update_dispatch_data(const kernel_impl_params& impl_param) override {
auto kernel_params = get_kernel_params(impl_param, true);
(_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
}
};

namespace detail {

attach_scaled_dot_product_attention_impl::attach_scaled_dot_product_attention_impl() {
using sdpa_prim = scaled_dot_product_attention;

auto types = {
data_types::f32,
data_types::f16,
};

auto formats = {
format::bfyx,
};

implementation_map<sdpa_prim>::add(impl_types::ocl,
shape_types::static_shape,
scaled_dot_product_attention_impl::create,
types,
formats);

implementation_map<sdpa_prim>::add(impl_types::ocl,
shape_types::dynamic_shape,
scaled_dot_product_attention_impl::create,
types,
formats);
}

} // namespace detail
} // namespace ocl
} // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::scaled_dot_product_attention_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::scaled_dot_product_attention)