
[WebNN EP] Support MultiHeadAttention(MHA) #24079

Merged
fdwr merged 9 commits into microsoft:main from peishenyan:mha_attention
Apr 23, 2025

Conversation

@peishenyan
Contributor

Description

Adds support for MultiHeadAttention via WebNN matmul, transpose, reshape, and other operations, following the logic of the MHA subgraph below:

 Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length,
                N is the number of attention heads, H is the head size, W=N*H, and h=Sqrt(H)
    Note: If the data type of the inputs (qkv and past kv) is float16, we cast them to float32 to preserve precision.

                 query     key     value
                   |        |        |
           q_Reshape   k_Reshape   v_Reshape  (shape=B,S,H,N)
                   |        |        |
          q_Transpose  k_Transpose v_Transpose (perm=0,2,1,3)
             \           /           |
              \         /            |
present_key<---\----Concat <---------|----past_key
               |      |              |
               |  opt_k_transpose    |
               \  (0,1,3,2)          |
                \    /               |  past_value
                qk_MatMul            |     /
                     |  scale        |    /
                     |   /           |   /
                  qk_Div           Concat------> present_value
                      |              |
                      |              /
                     Add <----------/---------------attention_bias
                      |            /
                    Softmax       /
                       \         /
                        \       /
                      qkv_MatMul
                             |
                          Transpose (perm=0,2,1,3)
                             |
                          Reshape---(shape=B,P,W)
                             |
                           output
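
For readers who want to see the decomposition concretely, here is a minimal TypeScript sketch of the same subgraph expressed against the WebNN MLGraphBuilder JavaScript API. It is illustrative only, not the EP's actual C++ implementation: the tensor sizes, the self-attention-only case (no past key/value, no attention_bias), the conventional (B,S,N,H) head split, and the availability of WebNN typings (MLGraphBuilder, MLOperand, MLGraph) are all assumptions, and the builder method and option names follow the current WebNN spec, which may differ between versions.

```typescript
// Hypothetical sizes, for illustration only.
const B = 1, S = 8, N = 12, H = 64, W = N * H;

async function buildMhaSketch(builder: MLGraphBuilder): Promise<MLGraph> {
  const qkvDesc = { dataType: 'float32', shape: [B, S, W] };
  const query = builder.input('query', qkvDesc);
  const key = builder.input('key', qkvDesc);
  const value = builder.input('value', qkvDesc);
  // float16 inputs would first be cast, e.g. builder.cast(query, 'float32').

  // q/k/v_Reshape + q/k/v_Transpose: (B,S,W) -> (B,S,N,H) -> (B,N,S,H).
  const toHeads = (x: MLOperand) =>
    builder.transpose(builder.reshape(x, [B, S, N, H]), { permutation: [0, 2, 1, 3] });
  const q = toHeads(query);
  const k = toHeads(key);   // with past_key: concat([past_key, k], 2) -> present_key
  const v = toHeads(value); // with past_value: concat([past_value, v], 2) -> present_value

  // qk_MatMul + qk_Div: scores = (q x k^T) / Sqrt(H), shape (B,N,S,S).
  const kT = builder.transpose(k, { permutation: [0, 1, 3, 2] });
  const scale = builder.constant(
    { dataType: 'float32', shape: [1] }, new Float32Array([Math.sqrt(H)]));
  const scores = builder.div(builder.matmul(q, kT), scale);

  // Add (attention_bias, omitted here) + Softmax over the last axis.
  const probs = builder.softmax(scores, 3);

  // qkv_MatMul -> Transpose (perm=0,2,1,3) -> Reshape back to (B,S,W).
  const context = builder.matmul(probs, v);
  const output = builder.reshape(
    builder.transpose(context, { permutation: [0, 2, 1, 3] }), [B, S, W]);

  return builder.build({ output });
}
```

The real implementation additionally wires up past/present key-value and the optional attention_bias via the Concat and Add nodes shown in the diagram above.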

Motivation and Context

@peishenyan peishenyan marked this pull request as draft March 18, 2025 08:20
@peishenyan peishenyan marked this pull request as ready for review April 10, 2025 14:10
@peishenyan
Contributor Author

Made some modifications according to #23416.
Now, this PR is ready for review. @Honry @fdwr @guschmue PTAL. Thanks.

Contributor

@Honry Honry left a comment


Thanks @peishenyan, some comments; please also add this new op to the webnn-operators.md file.

@fdwr
Contributor

fdwr commented Apr 11, 2025

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline,Windows GPU WebGPU CI Pipeline,Windows OpenVINO CI Pipeline

@fdwr
Contributor

fdwr commented Apr 11, 2025

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@fdwr
Contributor

fdwr commented Apr 11, 2025

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI

@fdwr
Contributor

fdwr commented Apr 11, 2025

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 2 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@fdwr
Contributor

fdwr commented Apr 12, 2025

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline,Windows GPU WebGPU CI Pipeline,Windows OpenVINO CI Pipeline

@fdwr
Contributor

fdwr commented Apr 12, 2025

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@fdwr
Contributor

fdwr commented Apr 12, 2025

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI

@fdwr
Contributor

fdwr commented Apr 12, 2025

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 2 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).


@peishenyan
Contributor Author

peishenyan commented Apr 12, 2025

Oh, my fault... I forgot to format the op_builder_factory.cc file. Maybe I should force-push a new commit? I apologize for the inconvenience.

@fdwr
Contributor

fdwr commented Apr 14, 2025

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI

@fdwr
Contributor

fdwr commented Apr 14, 2025

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 2 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).

fdwr previously approved these changes Apr 14, 2025
Contributor

@fdwr fdwr left a comment


👍

@fdwr
Contributor

fdwr commented Apr 14, 2025

Will await @Honry's re-review.

@Honry
Contributor

Honry commented Apr 14, 2025

@peishenyan, you forgot to add the op info to webnn-operators.md; everything else LGTM.

@fdwr
Contributor

fdwr commented Apr 14, 2025

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline,Windows GPU WebGPU CI Pipeline,Windows OpenVINO CI Pipeline

@fdwr
Contributor

fdwr commented Apr 14, 2025

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@fdwr
Contributor

fdwr commented Apr 14, 2025

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI

@fdwr
Contributor

fdwr commented Apr 14, 2025

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 2 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).


Contributor

@fdwr fdwr left a comment


👍

Contributor

@Honry Honry left a comment


👍

@peishenyan
Contributor Author

That's so weird. ONNX Runtime CUDA Builds / Windows GPU CUDA CI Pipeline (pull_request) failed...

This test had passed every time before this commit, but I only changed a doc file in this commit.

@peishenyan
Contributor Author

Hi @fdwr, is it possible to re-trigger the test to get a passing result?

@Honry
Contributor

Honry commented Apr 22, 2025

@peishenyan, you may need to rebase the code onto the latest main.

@fdwr
Contributor

fdwr commented Apr 22, 2025

/azp run ONNX Runtime CUDA Builds / Windows GPU CUDA CI Pipeline (pull_request)

@azure-pipelines

No pipelines are associated with this pull request.

@fdwr
Contributor

fdwr commented Apr 22, 2025

I'll retry the 2 required ones again (Linux CI / Build Linux x64 Release / build_test_pipeline (pull_request) ...). If they don't pass today, you'll need to try remerging with main.

@peishenyan
Contributor Author

Amazing... they finally passed😂

@fdwr
Contributor

fdwr commented Apr 23, 2025

Amazing... they finally passed

Merging, as the 5 remaining failing tests are unrelated, pervasive, and persistent infrastructure issues.

@fdwr fdwr merged commit df581e1 into microsoft:main Apr 23, 2025
71 of 76 checks passed
intbf pushed a commit to intbf/onnxruntime that referenced this pull request Apr 25, 2025