Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 87 additions & 112 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "contrib_ops/webgpu/bert/multihead_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/math/matmul.h"
using namespace onnxruntime::webgpu;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
Expand Down Expand Up @@ -70,6 +72,50 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
return context.RunProgram(program);
};

Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
  // The packed input is laid out [B, S, hidden + kv_hidden + v_hidden]; every
  // invocation copies exactly one (possibly vectorized) element into whichever
  // of the Q/K/V outputs owns that slice of the last dimension.
  const auto& input = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
  const auto& q_out = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
  const auto& k_out = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
  const auto& v_out = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);

  auto& body = sh.MainFunctionBody();
  body << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size");
  // Recover (batch, seq, d) from the flat invocation index, then route by d.
  body << "  let packed_qkv_indices = " << input.OffsetToIndices("global_idx") << ";\n"
       << "  let batch = packed_qkv_indices[0];\n"
       << "  let seq = packed_qkv_indices[1];\n"
       << "  let d = packed_qkv_indices[2];\n"
       << "  let input_data = " << input.GetByOffset("global_idx") << ";\n";
  // d in [0, hidden)                        -> query
  // d in [hidden, hidden + kv_hidden)      -> key (rebased by hidden)
  // d in [hidden + kv_hidden, end)         -> value (rebased by hidden + kv_hidden)
  body << "  if (d < uniforms.hidden_size) {\n"
       << "    " << q_out.SetByIndices("vec3<u32>(batch, seq, d)", "input_data") << ";\n"
       << "  } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n"
       << "    let kd = d - uniforms.hidden_size;\n"
       << "    " << k_out.SetByIndices("vec3<u32>(batch, seq, kd)", "input_data") << ";\n"
       << "  } else {\n"
       << "    let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;\n"
       << "    " << v_out.SetByIndices("vec3<u32>(batch, seq, vd)", "input_data") << ";\n"
       << "  }\n";
  return Status::OK();
}

Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
                      const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size) {
  // Split a packed [B, S, hidden + kv_hidden + v_hidden] tensor into separate
  // Q, K and V tensors, all in BSD layout. The copy is vectorized with the
  // widest component count that evenly divides every per-tensor row width.
  const int vec_width = std::min({GetMaxComponents(params.hidden_size_),
                                  GetMaxComponents(kv_hidden_size),
                                  GetMaxComponents(params.v_hidden_size_)});
  const auto total_elements = packedQKV->Shape().Size();
  // One shader invocation per vectorized element of the packed input.
  const uint32_t num_workitems = static_cast<uint32_t>(total_elements / vec_width);

  SplitPackedQKVProgram program;
  program
      .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, vec_width})
      .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank, vec_width},
                   {key, ProgramTensorMetadataDependency::TypeAndRank, vec_width},
                   {val, ProgramTensorMetadataDependency::TypeAndRank, vec_width}})
      .AddUniformVariables({
          {num_workitems},
          {static_cast<uint32_t>(params.hidden_size_ / vec_width)},
          {static_cast<uint32_t>(kv_hidden_size / vec_width)},
      })
      .SetDispatchGroupSize((num_workitems + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  return context.RunProgram(program);
}

void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
if (has_seqlen_k) {
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
Expand Down Expand Up @@ -566,113 +612,26 @@ Attention::Attention(const OpKernelInfo& info)
onnxruntime::contrib::AttentionBase(info, false) {
}

// QKV preparation program - computes Q, K, V from input, weights, and bias
// Emits a tiled-matmul WGSL shader that multiplies the input activations by the
// packed QKV weight matrix and adds the packed bias, producing the three
// projections in a single dispatch. Each workgroup computes one
// TILE_SIZE x TILE_SIZE output tile for all of Q, K and V simultaneously,
// sharing the staged input tile across the three weight tiles.
class AttentionPrepareProgram final : public Program<AttentionPrepareProgram> {
 public:
 AttentionPrepareProgram() : Program{"AttentionPrepare"} {}

 Status GenerateShaderCode(ShaderHelper& shader) const override {
 shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 shader.AddInput("weight", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 shader.AddOutput("output_q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 shader.AddOutput("output_k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 shader.AddOutput("output_v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

 // Workgroup dimensions must match this tile size (see PrepareQKV's
 // SetWorkgroupSize(TILE_SIZE, TILE_SIZE)).
 constexpr int TILE_SIZE = 12;

 // Workgroup-shared staging tiles: one for the input slice, and one per
 // projection for the corresponding weight slice.
 shader.AdditionalImplementation() << "const TILE_SIZE = " << TILE_SIZE << "u;\n"
 << "var<workgroup> tileInput: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
 << "var<workgroup> tileWeightQ: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
 << "var<workgroup> tileWeightK: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
 << "var<workgroup> tileWeightV: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n";

 // NOTE(review): the out-of-bounds guard is intentionally disabled here; all
 // invocations must reach both workgroupBarrier() calls (uniform control
 // flow), and the final write below is bounds-checked instead.
 shader.MainFunctionBody() //<< shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.M * uniforms.N")
 << "let batchIndex = workgroup_id.z / uniforms.num_heads;\n"
 << "let headNumber = workgroup_id.z % uniforms.num_heads;\n"
 << "let m = global_id.y;\n"
 << "let n = global_id.x;\n"
 << "let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
 // Q/K/V weight and bias columns live side-by-side in the packed tensors:
 // Q starts at headNumber*head_size, K is offset by hidden_size from Q, and
 // V by hidden_size from K.
 << "let biasOffsetQ = headNumber * uniforms.head_size;\n"
 << "let biasOffsetK = uniforms.hidden_size + biasOffsetQ;\n"
 << "let biasOffsetV = uniforms.hidden_size + biasOffsetK;\n"
 << "var valueQ = input_value_t(0);\n"
 << "var valueK = input_value_t(0);\n"
 << "var valueV = input_value_t(0);\n"
 // March along the shared K dimension one tile at a time.
 << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
 << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
 << " tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];\n"
 << " }\n"
 << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
 // ldb is the row stride of the packed weight matrix (full QKV width).
 << " let offset = n + (w + local_id.y) * uniforms.ldb;\n"
 << " tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];\n"
 << " tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];\n"
 << " tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];\n"
 << " }\n"
 << " workgroupBarrier();\n"
 // Accumulate the partial dot products from the staged tiles.
 << " for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {\n"
 << " let inputTileOffset = TILE_SIZE * local_id.y + k;\n"
 << " let weightTileOffset = TILE_SIZE * k + local_id.x;\n"
 << " valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];\n"
 << " valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];\n"
 << " valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];\n"
 << " }\n"
 << " workgroupBarrier();\n"
 << "}\n"
 // NOTE(review): headOffset assumes the per-head output width equals
 // uniforms.N (head_size) — confirm N == head_size at the dispatch site.
 << "let headOffset = (m * uniforms.N + n) % uniforms.head_size;\n"
 << "valueQ += bias[headOffset + biasOffsetQ];\n"
 << "valueK += bias[headOffset + biasOffsetK];\n"
 << "valueV += bias[headOffset + biasOffsetV];\n"
 // Outputs are laid out per (batch, head) slab of M*N elements; the write is
 // guarded so out-of-tile threads contribute nothing.
 << "let offset = workgroup_id.z * uniforms.M * uniforms.N;\n"
 << "if (m < uniforms.M && n < uniforms.N) {\n"
 << " let outputIdx = offset + m * uniforms.N + n;\n"
 << " output_q[outputIdx] = valueQ;\n"
 << " output_k[outputIdx] = valueK;\n"
 << " output_v[outputIdx] = valueV;\n"
 << "}\n";

 return Status::OK();
 }

 // M: sequence length, K: input hidden size, N: head size, ldb: row stride of
 // the packed weight matrix (q_hidden + k_hidden + v_hidden).
 WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
 {"K", ProgramUniformVariableDataType::Uint32},
 {"N", ProgramUniformVariableDataType::Uint32},
 {"num_heads", ProgramUniformVariableDataType::Uint32},
 {"head_size", ProgramUniformVariableDataType::Uint32},
 {"hidden_size", ProgramUniformVariableDataType::Uint32},
 {"ldb", ProgramUniformVariableDataType::Uint32});
};

Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
const Tensor* input, const Tensor* weights, const Tensor* bias,
Tensor* q, Tensor* k, Tensor* v) {
constexpr int TILE_SIZE = 12;
const int M = parameters.sequence_length_;
const int K = parameters.input_hidden_size_;
const int N = parameters.head_size_;

const uint32_t dispatch_x = (parameters.head_size_ + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t dispatch_y = (parameters.sequence_length_ + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t dispatch_z = parameters.batch_size_ * parameters.num_heads_;

AttentionPrepareProgram program{};
program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
{weights, ProgramTensorMetadataDependency::TypeAndRank},
{bias, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutputs({{q, ProgramTensorMetadataDependency::TypeAndRank},
{k, ProgramTensorMetadataDependency::TypeAndRank},
{v, ProgramTensorMetadataDependency::TypeAndRank}})
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(parameters.num_heads_)},
{static_cast<uint32_t>(parameters.head_size_)},
{static_cast<uint32_t>(parameters.hidden_size_)},
{static_cast<uint32_t>(parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_)}});
// Use MatMul to compute packed QKV output: input * weights + bias
// Then use SplitPackedQKV to split into Q, K, V in BSD format
// Returns Q, K, V in BSD format

return context.RunProgram(program);
// Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size]
const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_;
TensorShapeVector packed_qkv_shape({parameters.batch_size_, parameters.sequence_length_, packed_qkv_size});
Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape));

// Prepare inputs for MatMul
std::vector<const Tensor*> matmul_inputs = {input, weights, bias};

// Call MatMul: packed_qkv = input * weights + bias
ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true));

// Output Q, K, V in BSD format
return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_);
}

Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
Expand Down Expand Up @@ -727,15 +686,16 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
ORT_NOT_IMPLEMENTED("present tensor not implemented for webgpu Attention");
}

// Create Q, K, V tensors by computing input * weights + bias
TensorShapeVector qkv_shape({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, parameters.head_size_});
Tensor Q = context.CreateGPUTensor(input->DataType(), qkv_shape);
Tensor K = context.CreateGPUTensor(input->DataType(), qkv_shape);
Tensor V = context.CreateGPUTensor(input->DataType(), qkv_shape);
// Create Q, K, V tensors in BSD format from input * weights + bias
TensorShapeVector qkv_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_});
TensorShapeVector v_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.v_hidden_size_});
Tensor Q_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape));
Tensor K_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape));
Tensor V_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(v_bsd_shape));

// Compute Q, K, V from input, weights, and bias
ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));
// Compute Q, K, V from input, weights, and bias (returns BSD format)
ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q_bsd, &K_bsd, &V_bsd));
parameters.qkv_format_ = Q_K_V_BSNH;

// Check if we can use flash attention
// For Attention operator, we need to create present_key and present_value tensors for flash attention
Expand All @@ -746,10 +706,25 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);

if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
// FlashAttention supports Q_K_V_BSNH format directly
return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
parameters, context, nullptr);
}

// For non-flash attention path, convert BSD to BNSH format
TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_});
TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.v_head_size_});
Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape));
Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape));
Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape));

ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
parameters.head_size_, &Q_bsd, nullptr, 0, &Q));
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
parameters.head_size_, &K_bsd, nullptr, 0, &K));
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
parameters.v_head_size_, &V_bsd, nullptr, 0, &V));

// Apply the actual attention computation
return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
/* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
bool has_bias_;
};

// Compute program that splits a packed QKV tensor of shape
// [batch, seq, hidden + kv_hidden + v_hidden] into three separate tensors
// (query/key/value) in BSD layout. Dispatched by the SplitPackedQKV() helper.
class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
 public:
 SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {}

 Status GenerateShaderCode(ShaderHelper& sh) const override;

 // input_size: total number of (vectorized) elements in the packed input;
 // hidden_size / kv_hidden_size: last-dimension widths in vectorized units,
 // used to route each element to the Q, K or V output.
 WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
 {"hidden_size", ProgramUniformVariableDataType::Uint32},
 {"kv_hidden_size", ProgramUniformVariableDataType::Uint32});
};

class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct WebgpuAttentionParameters {
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor);

// Splits packedQKV ([batch, seq, hidden + kv_hidden + v_hidden]) into separate
// query/key/val tensors, all in BSD layout.
Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size);

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
Expand Down
Loading
Loading