From 3fb7223f50cf5ab56d358b1e4c5a65f34d0e6941 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 17:18:20 -0700 Subject: [PATCH] Improve GQA fusion Signed-off-by: Justin Chu --- onnxscript/rewriter/rules/fusion/_gqa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/rules/fusion/_gqa.py b/onnxscript/rewriter/rules/fusion/_gqa.py index 8d6f156ed5..c12dcc7140 100644 --- a/onnxscript/rewriter/rules/fusion/_gqa.py +++ b/onnxscript/rewriter/rules/fusion/_gqa.py @@ -52,7 +52,7 @@ def pattern( _outputs=["attention_BHSDh"], ) - return attention_BHSDh + return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD def check( self, @@ -103,6 +103,7 @@ def rewrite( past_key_BHkvSpD, past_value_BHkvSpD, **original_attrs, + _outputs=3, )