
[GPU] Add SDPA impl; SDPA input transpose fusion support; GQA optimization
sshlyapn committed May 21, 2024
1 parent 01de846 commit 63fd7cc
Showing 47 changed files with 3,598 additions and 61 deletions.
94 changes: 94 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp
@@ -0,0 +1,94 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

class SDPA : public ov::op::v13::ScaledDotProductAttention {
public:
    OPENVINO_OP("SDPA", "gpu_opset");

    SDPA() = default;

    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const ov::Output<Node>& attn_mask,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    SDPA(const ov::Output<Node>& Q,
         const ov::Output<Node>& K,
         const ov::Output<Node>& V,
         const ov::Output<Node>& attn_mask,
         const ov::Output<Node>& scale,
         const std::vector<int64_t>& order_q,
         const std::vector<int64_t>& order_k,
         const std::vector<int64_t>& order_v,
         const std::vector<int64_t>& order_out,
         const bool is_causal,
         const ov::element::Type output_type = ov::element::undefined);

    bool visit_attributes(ov::AttributeVisitor &visitor) override;

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    bool get_causal() const { return m_is_causal; }

    std::vector<int64_t> get_input0_transpose_order() const { return m_order_q; }
    std::vector<int64_t> get_input1_transpose_order() const { return m_order_k; }
    std::vector<int64_t> get_input2_transpose_order() const { return m_order_v; }
    std::vector<int64_t> get_output_transpose_order() const { return m_order_out; }
    ov::element::Type get_output_type() const { return m_output_type; }

    static std::vector<int64_t> default_order(size_t rank) {
        std::vector<int64_t> order(rank);
        std::iota(order.begin(), order.end(), 0);
        return order;
    }

protected:
    std::vector<int64_t> m_order_q;
    std::vector<int64_t> m_order_k;
    std::vector<int64_t> m_order_v;
    std::vector<int64_t> m_order_out;
    bool m_is_causal;
    ov::element::Type m_output_type;
};

std::vector<ov::PartialShape> shape_infer(const SDPA* op,
                                          std::vector<ov::PartialShape> input_shapes,
                                          const std::vector<int64_t>& order_q,
                                          const std::vector<int64_t>& order_k,
                                          const std::vector<int64_t>& order_v,
                                          const std::vector<int64_t>& order_out);

} // namespace op
} // namespace intel_gpu
} // namespace ov
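As a hedged illustration of how this internal op is intended to be used by the plugin's transformation passes, the sketch below constructs an SDPA node directly with fused transpose orders. It is not part of the commit: the shapes, the {0, 2, 1, 3} permutation, and the helper name make_sdpa_example are illustrative assumptions.

    // Minimal sketch (assumption): build an SDPA node that consumes un-transposed
    // Q/K/V and carries the desired permutation, so the Transpose ops can be fused
    // into the SDPA kernel instead of being executed separately.
    #include <memory>
    #include <vector>

    #include "intel_gpu/op/sdpa.hpp"
    #include "openvino/op/parameter.hpp"

    std::shared_ptr<ov::intel_gpu::op::SDPA> make_sdpa_example() {
        // Hypothetical [batch, seq_len, num_heads, head_size] inputs; the key/value
        // branches have fewer heads than the query (a GQA-style layout).
        auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 7, 32, 128});
        auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 7, 8, 128});
        auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 7, 8, 128});

        const std::vector<int64_t> order = {0, 2, 1, 3};  // bring heads in front of the sequence dim
        const bool is_causal = true;

        return std::make_shared<ov::intel_gpu::op::SDPA>(q, k, v,
                                                         order, order, order,
                                                         ov::intel_gpu::op::SDPA::default_order(4),
                                                         is_causal);
    }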
@@ -263,6 +263,7 @@ REGISTER_FACTORY(v12, ScatterElementsUpdate);

// ------------------------------ Supported v13 ops ----------------------------- //
REGISTER_FACTORY(v13, Multinomial);
REGISTER_FACTORY(v13, ScaledDotProductAttention);

// ------------------------------ Supported v14 ops ----------------------------- //
REGISTER_FACTORY(v14, ROIAlignRotated);
@@ -283,3 +284,4 @@ REGISTER_FACTORY(internal, SwiGLU);
REGISTER_FACTORY(internal, IndirectGemm);
REGISTER_FACTORY(internal, Convolution);
REGISTER_FACTORY(internal, Placeholder);
REGISTER_FACTORY(internal, SDPA);
95 changes: 95 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp
@@ -0,0 +1,95 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

namespace cldnn {

struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
    CLDNN_DECLARE_PRIMITIVE(scaled_dot_product_attention)

    scaled_dot_product_attention() : primitive_base("", {}) {}

    /// @brief Constructs scaled_dot_product_attention primitive.
    /// @param id This primitive id.
    /// @param inputs Input data primitives id (query, keys, values, [attention_mask], [scale]).
    /// @param is_causal If true, assumes causal attention masking. In this case attention_mask input is ignored.
    scaled_dot_product_attention(const primitive_id& id,
                                 const std::vector<cldnn::input_info> inputs,
                                 bool is_causal,
                                 const std::vector<int64_t>& input_q_transpose_order = {},
                                 const std::vector<int64_t>& input_k_transpose_order = {},
                                 const std::vector<int64_t>& input_v_transpose_order = {},
                                 const std::vector<int64_t>& output_transpose_order = {},
                                 const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding})
        , is_causal(is_causal)
        , has_attn_mask_input(inputs.size() > 3)
        , has_scale_input(inputs.size() > 4)
        , input_q_transpose_order(input_q_transpose_order)
        , input_k_transpose_order(input_k_transpose_order)
        , input_v_transpose_order(input_v_transpose_order)
        , output_transpose_order(output_transpose_order) {}

    bool is_causal = false;
    bool has_attn_mask_input = false;
    bool has_scale_input = false;

    std::vector<int64_t> input_q_transpose_order;
    std::vector<int64_t> input_k_transpose_order;
    std::vector<int64_t> input_v_transpose_order;
    std::vector<int64_t> output_transpose_order;

    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, is_causal);
        seed = hash_combine(seed, has_attn_mask_input);
        seed = hash_combine(seed, has_scale_input);
        seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
        seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
        seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
        seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
        return seed;
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;

        auto rhs_casted = downcast<const scaled_dot_product_attention>(rhs);

        return is_causal == rhs_casted.is_causal &&
               has_attn_mask_input == rhs_casted.has_attn_mask_input &&
               has_scale_input == rhs_casted.has_scale_input &&
               input_q_transpose_order == rhs_casted.input_q_transpose_order &&
               input_k_transpose_order == rhs_casted.input_k_transpose_order &&
               input_v_transpose_order == rhs_casted.input_v_transpose_order &&
               output_transpose_order == rhs_casted.output_transpose_order;
    }

    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<scaled_dot_product_attention>::save(ob);
        ob << is_causal;
        ob << has_attn_mask_input;
        ob << has_scale_input;
        ob << input_q_transpose_order;
        ob << input_k_transpose_order;
        ob << input_v_transpose_order;
        ob << output_transpose_order;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<scaled_dot_product_attention>::load(ib);
        ib >> is_causal;
        ib >> has_attn_mask_input;
        ib >> has_scale_input;
        ib >> input_q_transpose_order;
        ib >> input_k_transpose_order;
        ib >> input_v_transpose_order;
        ib >> output_transpose_order;
    }
};
} // namespace cldnn
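A minimal usage sketch of the primitive above, under stated assumptions: the primitive ids ("query", "key", "value", "attn_mask", "sdpa") and the {0, 2, 1, 3} orders are hypothetical, not taken from this diff. Note how has_attn_mask_input and has_scale_input are derived purely from the number of inputs, matching the constructor in the header.

    // Minimal sketch (assumption): add a scaled_dot_product_attention primitive
    // to a cldnn topology with an explicit attention mask and fused transposes.
    #include <vector>
    #include "intel_gpu/primitives/scaled_dot_product_attention.hpp"

    cldnn::scaled_dot_product_attention make_sdpa_prim_example() {
        // Four inputs -> has_attn_mask_input becomes true, has_scale_input stays false.
        std::vector<cldnn::input_info> inputs = {
            cldnn::input_info("query"),
            cldnn::input_info("key"),
            cldnn::input_info("value"),
            cldnn::input_info("attn_mask"),
        };
        const bool is_causal = false;
        const std::vector<int64_t> qkv_order = {0, 2, 1, 3};  // hypothetical fused transpose
        return cldnn::scaled_dot_product_attention("sdpa", inputs, is_causal,
                                                   qkv_order, qkv_order, qkv_order,
                                                   /*output_transpose_order=*/{0, 2, 1, 3});
    }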
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/gemm.cpp
@@ -272,6 +272,9 @@ std::string gemm_inst::to_string(gemm_node const& node) {
    gemm_info.add("transpose_input1", transpose_input1);
    gemm_info.add("indirect_input0", indirect_input0);
    gemm_info.add("indirect_input1", indirect_input1);
    gemm_info.add("transpose_order_input0", desc->input0_transpose_order);
    gemm_info.add("transpose_order_input1", desc->input1_transpose_order);
    gemm_info.add("transpose_order_output", desc->output_transpose_order);
    node_info->add("gemm info", gemm_info);
    node_info->dump(primitive_description);

1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp
@@ -93,6 +93,7 @@ void register_implementations() {
    REGISTER_OCL(eye);
    REGISTER_OCL(unique_count);
    REGISTER_OCL(unique_gather);
    REGISTER_OCL(scaled_dot_product_attention);
}

} // namespace ocl
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp
@@ -74,6 +74,7 @@
#include "intel_gpu/primitives/eye.hpp"
#include "intel_gpu/primitives/unique.hpp"
#include "intel_gpu/primitives/kv_cache.hpp"
#include "intel_gpu/primitives/scaled_dot_product_attention.hpp"

namespace cldnn {
namespace ocl {
@@ -172,6 +173,7 @@ REGISTER_OCL(gather_nonzero);
REGISTER_OCL(eye);
REGISTER_OCL(unique_count);
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);

#undef REGISTER_OCL

@@ -0,0 +1,135 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"
#include "scaled_dot_product_attention_inst.h"
#include "sdpa/sdpa_kernel_selector.h"
#include "sdpa/sdpa_kernel_base.h"

namespace cldnn {
namespace ocl {
struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_product_attention> {
    using parent = typed_primitive_impl_ocl<scaled_dot_product_attention>;
    using parent::parent;
    using kernel_selector_t = kernel_selector::sdpa_kernel_selector;
    using kernel_params_t = kernel_selector::sdpa_params;

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::scaled_dot_product_attention_impl)

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<scaled_dot_product_attention_impl>(*this);
    }

    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        if (is_dynamic()) {
            auto& kernel_selector = kernel_selector_t::Instance();
            auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
            kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
        }
    }

    static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
        kernel_selector::sdpa_configuration config;

        auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order) {
            auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
            for (size_t i = 0; i < order.size(); i++) {
                transposed_pshape[i] = pshape[order[i]];
            }
            return transposed_pshape;
        };

        const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
        const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), prim->input_q_transpose_order);
        const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), prim->input_k_transpose_order);
        const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), prim->input_v_transpose_order);

        OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");
        for (size_t i = 0; i < query_shape.size(); ++i) {
            if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
                if (query_shape[i].get_length() > key_shape[i].get_length()) {
                    config.broadcast_axis = prim->input_k_transpose_order[i];
                    config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
                }
            }
        }

        if (query_shape[query_shape.size() - 1].is_static())
            config.head_size = query_shape[query_shape.size() - 1].get_length();

        config.is_causal = prim->is_causal;

        return config;
    }

    static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) {
        auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);

        const auto inputs_num = impl_param.input_layouts.size();
        params.inputs.resize(inputs_num);
        for (size_t i = 0; i < inputs_num; i++) {
            params.inputs[i] = convert_data_tensor(impl_param.get_input_layout(i));
        }

        params.conf = get_sdpa_configuration(impl_param);

        const auto& prim = impl_param.typed_desc<scaled_dot_product_attention>();
        params.input0_order = prim->input_q_transpose_order;
        params.input1_order = prim->input_k_transpose_order;
        params.input2_order = prim->input_v_transpose_order;
        params.output_order = prim->output_transpose_order;

        params.set_dynamic_shape_offsets();

        return params;
    }

    static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg, const kernel_impl_params& impl_param) {
        auto sdpa_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic());
        auto& sdpa_kernel_selector = kernel_selector_t::Instance();
        auto kd = sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params);

        return cldnn::make_unique<scaled_dot_product_attention_impl>(kd);
    }

    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kernel_params = get_kernel_params(impl_param, true);
        (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
    }
};

namespace detail {

attach_scaled_dot_product_attention_impl::attach_scaled_dot_product_attention_impl() {
    using sdpa_prim = scaled_dot_product_attention;

    auto types = {
        data_types::f32,
        data_types::f16,
    };

    auto formats = {
        format::bfyx,
    };

    implementation_map<sdpa_prim>::add(impl_types::ocl,
                                       shape_types::static_shape,
                                       scaled_dot_product_attention_impl::create,
                                       types,
                                       formats);

    implementation_map<sdpa_prim>::add(impl_types::ocl,
                                       shape_types::dynamic_shape,
                                       scaled_dot_product_attention_impl::create,
                                       types,
                                       formats);
}

} // namespace detail
} // namespace ocl
} // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::scaled_dot_product_attention_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::scaled_dot_product_attention)
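For readers focused on the GQA part of the change: get_sdpa_configuration above infers the grouped-query-attention layout by comparing the (already transposed) query and key shapes dimension by dimension. The standalone sketch below restates that inference with plain integer shapes; the GqaConfig struct and detect_gqa name are illustrative assumptions, and the real code additionally maps the axis back through input_k_transpose_order before storing it as broadcast_axis.

    // Standalone sketch (assumption): mirror the intent of the GQA detection in
    // get_sdpa_configuration. Where the query has more heads than key/value, that
    // axis becomes the broadcast axis and group_size = q_heads / kv_heads.
    #include <cstdint>
    #include <vector>

    struct GqaConfig {
        int64_t broadcast_axis = -1;  // -1 means "no key/value broadcast needed"
        int64_t group_size = 1;
    };

    // Shapes are assumed to be permuted to [batch, heads, seq_len, head_size].
    GqaConfig detect_gqa(const std::vector<int64_t>& query_shape,
                         const std::vector<int64_t>& key_shape) {
        GqaConfig cfg;
        for (size_t i = 0; i < query_shape.size(); ++i) {
            if (query_shape[i] > key_shape[i]) {
                cfg.broadcast_axis = static_cast<int64_t>(i);
                cfg.group_size = query_shape[i] / key_shape[i];
            }
        }
        return cfg;
    }

    // Example: Q = [1, 32, 128, 64], K/V = [1, 8, 128, 64] -> broadcast_axis = 1,
    // group_size = 4: each key/value head serves four query heads, so the kernel
    // can reuse K/V rows instead of physically repeating them in memory.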

