14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -4,6 +4,7 @@
#include "contrib_ops/webgpu/bert/attention.h"

#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/webgpu/bert/flash_attention.h"
#include "contrib_ops/webgpu/bert/multihead_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
@@ -736,6 +737,19 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
// Compute Q, K, V from input, weights, and bias
ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));

// Check if we can use flash attention
// For Attention operator, we need to create present_key and present_value tensors for flash attention
// even though they are not exposed as outputs
TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.num_heads_,
parameters.total_sequence_length_, parameters.head_size_});
Tensor present_key = context.CreateGPUTensor(input->DataType(), present_kv_shape);
Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);

if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
parameters, context, nullptr);
}

// Apply the actual attention computation
return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
/* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
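Note: the scratch present K/V allocated above uses the BNSH ordering (batch, num_heads, total_sequence_length, head_size) that the flash-attention kernels index into. A minimal sketch of the row-major offset that shape implies, assuming a densely packed tensor (the helper below is hypothetical; the actual kernels emit the equivalent WGSL through IndicesToOffset):

#include <cstddef>

// Illustrative only: row-major offset into a [B, N, T, H] (BNSH) present K/V tensor.
inline size_t PresentKVOffset(size_t b, size_t n, size_t s, size_t h,
                              size_t num_heads, size_t total_seq_len, size_t head_size) {
  return ((b * num_heads + n) * total_seq_len + s) * head_size + h;
}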
71 changes: 52 additions & 19 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -76,7 +76,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
} else {
shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n";
}
shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
if (past_present_share_buffer_) {
shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n";
} else {
shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n";
}

// Add indirect dispatch logic for thread 0
if (prepare_indirect_dispatch_) {
@@ -93,8 +98,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (has_past_) {
const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
shader.AddInput("past_value", ShaderUsage::UseUniform);
shader.MainFunctionBody() << "let present_offset = global_idx;"
<< "if (sequence_id < past_sequence_length) {\n"
shader.MainFunctionBody() << "if (sequence_id < past_sequence_length) {\n"
<< " let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n"
<< " " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n"
<< " " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n"
@@ -104,8 +108,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"
<< "}";
} else {
shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"
<< " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
shader.MainFunctionBody() << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
<< " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n"
<< " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n";
}
@@ -134,10 +137,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
// Determine if we need to prepare indirect dispatch
bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
bool use_seqlen_k = (seqlen_k != nullptr);

CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH;
CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_,
prepare_indirect_dispatch, use_seqlen_k};
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
if (kv_BNSH) {
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
} else {
@@ -207,6 +210,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
@@ -256,10 +260,20 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
{metadata, ProgramTensorMetadataDependency::Rank, 2}});

const uint32_t vectorized_head_size = parameters.head_size_ / components;

// Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
const auto& bias_shape = attention_bias->Shape();
attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
}

if (use_indirect_dispatch) {
program.SetIndirectDispatchTensor(indirect_buffer);
} else {
program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
}
program.SetWorkgroupSize(64)
.CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
@@ -269,7 +283,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
present_sequence_length,
{static_cast<uint32_t>(parameters.n_reps)},
{num_present_sequence_length_tile},
{static_cast<uint32_t>(parameters.num_heads_)}});
{static_cast<uint32_t>(parameters.num_heads_)},
{static_cast<uint32_t>(parameters.batch_size_)},
{attn_bias_dim0},
{attn_bias_dim1}});

return context.RunProgram(program);
}
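With the batch dimension folded into the dispatch above, each workgroup has to recover (batch, head, tile) from its linear index, which is why batch_size and num_heads are now passed as uniforms. A hedged sketch of that decomposition, assuming the same linear order used for SetDispatchGroupSize (the real mapping lives in the WGSL template, which is not part of this diff):

#include <cstdint>

// Hypothetical mirror of the workgroup-index decomposition for a
// batch_size * num_heads * num_total_seq_length_tile dispatch (tile varies fastest).
struct WorkgroupCoords { uint32_t batch; uint32_t head; uint32_t tile; };

inline WorkgroupCoords DecomposeWorkgroup(uint32_t workgroup_idx,
                                          uint32_t num_heads,
                                          uint32_t num_total_seq_length_tile) {
  WorkgroupCoords c;
  c.tile = workgroup_idx % num_total_seq_length_tile;
  c.head = (workgroup_idx / num_total_seq_length_tile) % num_heads;
  c.batch = workgroup_idx / (num_total_seq_length_tile * num_heads);
  return c;
}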
@@ -313,11 +330,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
{qk, ProgramTensorMetadataDependency::TypeAndRank},
{present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size]
const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
if (use_indirect_dispatch) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
.SetIndirectDispatchTensor(indirect_buffer);
} else {
program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
}
program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
.SetWorkgroupSize(64)
@@ -326,7 +344,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
present_sequence_length,
{static_cast<uint32_t>(parameters.n_reps)},
num_present_sequence_length_tile,
{static_cast<uint32_t>(parameters.num_heads_)}});
{batch_heads}});

return context.RunProgram(program);
}
@@ -363,14 +381,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
program.SetDispatchGroupSize(batch_heads * num_head_size_tile)
.CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
.SetWorkgroupSize(tile_size * tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
num_total_seq_length_tile,
num_present_sequence_length_tile,
{num_head_size_tile},
{static_cast<uint32_t>(parameters.num_heads_)}});
{batch_heads}});

return context.RunProgram(program);
}
@@ -429,6 +448,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
FlashAttentionProgram program{"FlashAttention",
has_attention_bias,
is_qualcomm,
@@ -437,6 +457,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
parameters.num_heads_,
parameters.is_unidirectional_,
is_nvidia,
q_BNSH,
use_seqlen_k};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
@@ -451,15 +472,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;
const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)

// Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
const auto& bias_shape = attention_bias->Shape();
attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
}

program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
.SetWorkgroupSize(tile_size)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(present_sequence_length)},
{static_cast<uint32_t>(parameters.batch_size_)},
{static_cast<uint32_t>(parameters.n_reps)},
{alpha},
{num_seq_tile}});
{num_seq_tile},
{attn_bias_dim0},
{attn_bias_dim1}});

return context.RunProgram(program);
}
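The attn_bias_dim0/attn_bias_dim1 uniforms carry the two leading dims of the attention-bias tensor, whose shape is broadcastable as [batch_size or 1, num_heads or 1, sequence_length, total_sequence_length]. A sketch of the broadcasting rule those dims enable, assuming standard broadcasting semantics (illustrative; the actual indexing is generated in the WGSL template):

#include <cstdint>

// Hypothetical offset into attention_bias with shape
// [attn_bias_dim0, attn_bias_dim1, sequence_length, total_sequence_length];
// a leading dim of 1 broadcasts over batch or head respectively.
inline uint32_t AttnBiasOffset(uint32_t batch, uint32_t head, uint32_t q_row, uint32_t k_col,
                               uint32_t attn_bias_dim0, uint32_t attn_bias_dim1,
                               uint32_t seq_len, uint32_t total_seq_len) {
  const uint32_t b = (attn_bias_dim0 == 1) ? 0u : batch;
  const uint32_t n = (attn_bias_dim1 == 1) ? 0u : head;
  return ((b * attn_bias_dim1 + n) * seq_len + q_row) * total_seq_len + k_col;
}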
@@ -500,8 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
return parameters.batch_size_ == 1 &&
!parameters.is_packed_qkv_ &&
return !parameters.is_packed_qkv_ &&
parameters.head_size_ == parameters.v_head_size_ &&
bias == nullptr &&
context.HasFeature(wgpu::FeatureName::Subgroups) &&
22 changes: 16 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -43,9 +43,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<S

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer,
bool prepare_indirect_dispatch = false, bool use_seqlen_k = false)
: Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
: Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -59,6 +59,7 @@ class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
private:
bool has_past_;
bool kv_BNSH_;
bool past_present_share_buffer_;
bool prepare_indirect_dispatch_;
bool use_seqlen_k_;
};
@@ -73,6 +74,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
int qkv_num_heads,
bool is_unidirectional,
bool is_nvidia,
bool q_BNSH,
bool use_seqlen_k = false)
: Program{kernel_name},
has_attention_bias_(has_attention_bias),
@@ -82,6 +84,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
qkv_num_heads_(qkv_num_heads),
is_unidirectional_(is_unidirectional),
is_nvidia_(is_nvidia),
q_BNSH_(q_BNSH),
use_seqlen_k_(use_seqlen_k) {
}

@@ -90,9 +93,12 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"batch_size", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"num_seq_tile", ProgramUniformVariableDataType::Uint32});
{"num_seq_tile", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});

private:
bool has_attention_bias_;
@@ -102,6 +108,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
int qkv_num_heads_;
bool is_unidirectional_;
bool is_nvidia_;
bool q_BNSH_;
bool use_seqlen_k_;
};

@@ -120,7 +127,10 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecode
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"batch_size", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});

private:
bool has_attention_bias_;
@@ -141,7 +151,7 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDe
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});
{"batch_heads", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
@@ -161,7 +171,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD
{"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_head_size_tile", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});
{"batch_heads", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;