Commit 4a3931c

Fix code style

sshlyapn committed May 21, 2024
1 parent d07d200 commit 4a3931c
Showing 4 changed files with 41 additions and 15 deletions.
@@ -34,6 +34,9 @@ struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_product_attention>
     kernel_selector::sdpa_configuration config;

     auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order) {
+        if (order.empty())
+            return pshape;
+
         auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
         for (size_t i = 0; i < order.size(); i++) {
             transposed_pshape[i] = pshape[order[i]];
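For readers outside the diff: the new guard makes an empty transpose order an explicit no-op. Below is a minimal standalone sketch of the same behavior, with plain vectors standing in for ov::PartialShape; the function name and types are illustrative, not the plugin's API.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Transpose a shape by an axis order; an empty order means "no transpose",
    // which the added guard handles before any indexing can go out of bounds.
    std::vector<int64_t> transpose_shape(const std::vector<int64_t>& shape,
                                         const std::vector<int64_t>& order) {
        if (order.empty())
            return shape;
        std::vector<int64_t> transposed(shape.size());
        for (size_t i = 0; i < order.size(); i++)
            transposed[i] = shape[order[i]];
        return transposed;
    }

    // Example: {1, 32, 128, 64} with order {0, 2, 1, 3} -> {1, 128, 32, 64};
    // with an empty order the input comes back unchanged.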
@@ -102,6 +102,7 @@ inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint
 #ifdef SDPA_STAGE_0

 #if TARGET_SEQ_LEN_BLOCK_SIZE == 1
+/* This version is used for 2nd token */

 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
 KERNEL(sdpa_opt)(
@@ -529,6 +530,7 @@ KERNEL(sdpa_opt)(
 }

 #else
+/* This version is used for 1st token */

 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
 KERNEL(sdpa_opt)(
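The two comments added above label the kernel's two compile-time variants: when TARGET_SEQ_LEN_BLOCK_SIZE == 1 the kernel serves token-by-token generation (one query position per launch, the "2nd token" path), while the other branch serves prefill (the "1st token" path, where the whole prompt is scored at once). A hedged C++ rendering of that dispatch pattern follows; the real code is OpenCL C behind KERNEL()/REQD_SUB_GROUP_SIZE macros, and the bodies here are placeholders only.

    // Sketch of the structure only, not the kernel's actual contents.
    #define TARGET_SEQ_LEN_BLOCK_SIZE 1  // assumption: injected via kernel build options

    #if TARGET_SEQ_LEN_BLOCK_SIZE == 1
    // 2nd-token (generation) variant: each launch attends a single new
    // query position against the cached keys/values.
    void sdpa_opt() { /* single-position attention loop */ }
    #else
    // 1st-token (prefill) variant: the whole target sequence is processed,
    // tiled in blocks of TARGET_SEQ_LEN_BLOCK_SIZE query positions.
    void sdpa_opt() { /* blocked attention loop */ }
    #endif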
@@ -51,14 +51,30 @@ struct TransposedDimensionAccessHelperJit : DimensionAccessHelperJit, Transposed
     std::string f() { return dims_sizes[transposed_order[1]]; }
     std::string b() { return dims_sizes[transposed_order[0]]; }

-    std::pair<std::string, std::string> x_pad() { return {pad_before_after_sizes[(transposed_order[7] * 2) + 0], pad_before_after_sizes[(transposed_order[7] * 2) + 1]}; }
-    std::pair<std::string, std::string> y_pad() { return {pad_before_after_sizes[(transposed_order[6] * 2) + 0], pad_before_after_sizes[(transposed_order[6] * 2) + 1]}; }
-    std::pair<std::string, std::string> z_pad() { return {pad_before_after_sizes[(transposed_order[5] * 2) + 0], pad_before_after_sizes[(transposed_order[5] * 2) + 1]}; }
-    std::pair<std::string, std::string> w_pad() { return {pad_before_after_sizes[(transposed_order[4] * 2) + 0], pad_before_after_sizes[(transposed_order[4] * 2) + 1]}; }
-    std::pair<std::string, std::string> v_pad() { return {pad_before_after_sizes[(transposed_order[3] * 2) + 0], pad_before_after_sizes[(transposed_order[3] * 2) + 1]}; }
-    std::pair<std::string, std::string> u_pad() { return {pad_before_after_sizes[(transposed_order[2] * 2) + 0], pad_before_after_sizes[(transposed_order[2] * 2) + 1]}; }
-    std::pair<std::string, std::string> f_pad() { return {pad_before_after_sizes[(transposed_order[1] * 2) + 0], pad_before_after_sizes[(transposed_order[1] * 2) + 1]}; }
-    std::pair<std::string, std::string> b_pad() { return {pad_before_after_sizes[(transposed_order[0] * 2) + 0], pad_before_after_sizes[(transposed_order[0] * 2) + 1]}; }
+    std::pair<std::string, std::string> x_pad() {
+        return {pad_before_after_sizes[(transposed_order[7] * 2) + 0], pad_before_after_sizes[(transposed_order[7] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> y_pad() {
+        return {pad_before_after_sizes[(transposed_order[6] * 2) + 0], pad_before_after_sizes[(transposed_order[6] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> z_pad() {
+        return {pad_before_after_sizes[(transposed_order[5] * 2) + 0], pad_before_after_sizes[(transposed_order[5] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> w_pad() {
+        return {pad_before_after_sizes[(transposed_order[4] * 2) + 0], pad_before_after_sizes[(transposed_order[4] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> v_pad() {
+        return {pad_before_after_sizes[(transposed_order[3] * 2) + 0], pad_before_after_sizes[(transposed_order[3] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> u_pad() {
+        return {pad_before_after_sizes[(transposed_order[2] * 2) + 0], pad_before_after_sizes[(transposed_order[2] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> f_pad() {
+        return {pad_before_after_sizes[(transposed_order[1] * 2) + 0], pad_before_after_sizes[(transposed_order[1] * 2) + 1]};
+    }
+    std::pair<std::string, std::string> b_pad() {
+        return {pad_before_after_sizes[(transposed_order[0] * 2) + 0], pad_before_after_sizes[(transposed_order[0] * 2) + 1]};
+    }
 };

 struct GQA_configuration {
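The reflowed accessors above all follow one lookup scheme: pads are stored as a flat array of {before, after} entries, two per dimension, and transposed_order maps a logical axis (b, f, u, v, w, z, y, x) to its physical slot. A condensed sketch of that scheme, with illustrative names rather than the kernel_selector API:

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>

    struct TransposedPadsSketch {
        // Two entries per dimension: [dim * 2 + 0] is the pad before,
        // [dim * 2 + 1] is the pad after.
        std::vector<std::string> pad_before_after_sizes;
        std::vector<int64_t> transposed_order;  // logical axis -> physical slot

        std::pair<std::string, std::string> pad(size_t logical_dim) const {
            const auto physical = static_cast<size_t>(transposed_order[logical_dim]);
            return {pad_before_after_sizes[physical * 2 + 0],
                    pad_before_after_sizes[physical * 2 + 1]};
        }
    };

    // Under this scheme, b_pad() above is pad(0), f_pad() is pad(1), ...,
    // y_pad() is pad(6) and x_pad() is pad(7).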
19 changes: 12 additions & 7 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -323,13 +323,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
     // 2) Indirect inputs support
     // 3) GQA related optimization (Broadcast fusion)
     pass_config->set_callback<ov::pass::ScaledDotProductAttentionDecomposition>([&](const std::shared_ptr<const ov::Node> node){
-        // Known limitations:
-        // - The head size of all Q, K, and V inputs should be the same static value
-        // - The head size should be divisible by 16
-        // - All inputs and outputs must have the same data type
-        // - The number of dimensions for each input is expected to be 4
-        // - SDPA impl could be slower on GPUs with IMMAD support in non-LLM scenarios,
-        //   because oneDNN can be used for those cases - SDPA requires DPAS support
         auto sdpa = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(node);
         const auto& query_ps = sdpa->get_input_partial_shape(0);
         const auto& key_ps = sdpa->get_input_partial_shape(1);
@@ -341,10 +334,21 @@
         return use_sdpa;
     }

+    // Known limitations:
+    // - SDPA impl could be slower in non-LLM scenarios than decomposed version
+    if (func->get_variables().size() == 0)
+        return false;
+
+    // - The data type of SDPA should be fp16
+    if (sdpa->get_output_element_type(0) != ov::element::f16)
+        return false;
+
+    // - The number of dimensions for each input is expected to be 4
     if (query_ps.size() != 4 || key_ps.size() != 4 || value_ps.size() != 4) {
         return false;
     }

+    // - The head size of all Q, K, and V inputs should be the same static value
     if (query_ps[query_ps.size() - 1].is_dynamic() || key_ps[key_ps.size() - 1].is_dynamic() || value_ps[value_ps.size() - 1].is_dynamic()) {
         return false;
     }
@@ -354,6 +358,7 @@
         return false;
     }

+    // - The head size should be divisible by 16
     const auto optimal_subgroup_size = 16;
     if (query_ps[query_ps.size() - 1].is_dynamic() ||
         query_ps[query_ps.size() - 1].get_length() > 256 ||
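Taken together, the relocated comments now sit next to the checks they describe: the callback keeps SDPA fused only when every known limitation is satisfied, and otherwise falls back to the decomposed implementation. Below is a simplified, self-contained sketch of that gating logic; plain vectors and flags stand in for ov::PartialShape and the model queries, the thresholds mirror the diff, and all names are illustrative.

    #include <cstdint>
    #include <vector>

    bool keep_fused_sdpa(const std::vector<int64_t>& q, const std::vector<int64_t>& k,
                         const std::vector<int64_t>& v, bool model_has_state_variables,
                         bool output_is_f16) {
        constexpr int64_t kDynamic = -1;  // stands in for a dynamic dimension
        constexpr int64_t kOptimalSubgroupSize = 16;
        constexpr int64_t kMaxHeadSize = 256;

        // Non-LLM (stateless) models: the decomposed version can be faster.
        if (!model_has_state_variables)
            return false;
        // The data type should be fp16.
        if (!output_is_f16)
            return false;
        // Each input must be 4-dimensional.
        if (q.size() != 4 || k.size() != 4 || v.size() != 4)
            return false;
        // Head size (last dim): static, equal across Q/K/V, at most 256,
        // and divisible by the subgroup size of 16.
        const auto head_size = q.back();
        if (head_size == kDynamic || k.back() != head_size || v.back() != head_size)
            return false;
        return head_size <= kMaxHeadSize && head_size % kOptimalSubgroupSize == 0;
    }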
