From 1135a8e07b115948bf5b3c0a77e1211c4269b962 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 30 Oct 2025 20:26:18 -0700
Subject: [PATCH 1/3] Add test case for recent update to GQA fusion

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/gqa_test.py | 49 +++++++++++++++------
 1 file changed, 35 insertions(+), 14 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py
index c7ed888142..bc5e545d16 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_test.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -5,6 +5,7 @@
 import math
 import unittest
 
+import parameterized
 import numpy as np
 import onnx
 import onnx_ir as ir
@@ -424,7 +425,7 @@ def __init__(self, *args, **kwargs):
             "key_scale": np.random.rand(Dh).astype(np.float32),
         }
 
-    def source_model_script(self):
+    def source_model_script(self, with_past: bool, transpose_first: bool):
         scale_factor = math.sqrt(math.sqrt(self.head_size))
         minval = torch.finfo(torch.float32).min
         minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
@@ -458,16 +459,26 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
             # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
             # one sequence length (S) for all Q, K, and V (with no cache).
             query_BSHDh = op.Reshape(query, shape_BSHDh)
-            query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
-            query_BHSDh_normalized = op.SimplifiedLayerNormalization(
-                query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
-            )
-            key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
-            key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
-            key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
-                key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
-            )
+
+            if transpose_first:
+                query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+                query_BHSDh_normalized = op.SimplifiedLayerNormalization(
+                    query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
+                )
+                key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
+                key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
+                    key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
+                )
+            else:
+                query_BSHDh_normalized = op.SimplifiedLayerNormalization(
+                    query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
+                )
+                query_BHSDh_normalized = op.Transpose(query_BSHDh_normalized, perm=[0, 2, 1, 3])
+                key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
+                    key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
+                )
+                key_BHkvSDh_normalized = op.Transpose(key_BSHkvDh_normalized, perm=[0, 2, 1, 3])
 
             value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
             value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
@@ -489,9 +500,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
                 cos,
                 sin,
             )
-            key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
-            value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
+            if with_past:
+                key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
+                value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
+            else:
+                key_seq_BHkvSkvDh = key_BHkvSDh_rope
+                value_seq_BHkvSkvDh = value_BHkvSDh
 
             # Now, expand from shared heads to all heads
             key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
@@ -552,11 +567,17 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
 
         return gqa
 
-    def test_fusion(self):
+    @parameterized.parameterized.expand([
+        (True, True),  # with_past=True, transpose_first=True
+        (True, False),  # with_past=True, transpose_first=False
+        (False, True),  # with_past=False, transpose_first=True
+        (False, False),  # with_past=False, transpose_first=False
+    ])
+    def test_fusion(self, with_past, transpose_first):
         """Test that GQA fusion is successful on source model and produces an equivalent model."""
 
         inputs = self.inputs
 
-        source_model = self.source_model_script().to_model_proto(
+        source_model = self.source_model_script(with_past, transpose_first).to_model_proto(
             input_types=self.input_types,
             output_types=self.output_types,
         )

From cfc0a47fef449a488123cbbecb73162936b23828 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 30 Oct 2025 21:18:10 -0700
Subject: [PATCH 2/3] Minor fixes

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/gqa_test.py | 27 ++++++++++++---------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py
index bc5e545d16..14449667f4 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_test.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -5,12 +5,12 @@
 import math
 import unittest
 
-import parameterized
 import numpy as np
 import onnx
 import onnx_ir as ir
 import onnx_ir.passes.common.shape_inference as shape_inference
 import onnxruntime as ort
+import parameterized
 import torch
 
 import onnxscript
@@ -362,14 +362,23 @@ def test_fusion(self):
         assert_allclose(outputs3, source_model_outputs)
 
 
+@parameterized.parameterized_class([
+    {"with_past": True, "transpose_first": True},
+    {"with_past": True, "transpose_first": False},
+    {"with_past": False, "transpose_first": True},
+    {"with_past": False, "transpose_first": False},
+])
 class GemmaGQAFusionTest(unittest.TestCase):
+    with_past = True
+    transpose_first = True
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+        # Config parameters
         self.batchsize = 1  # Note: GQA (cpu) seems to require batch-size 1?
         self.seqlen = 8
         self.kv_seqlen = self.seqlen
-        self.past_seqlen = 16
+        self.past_seqlen = 16 if self.with_past else 0
         self.head_size = 16
         self.num_heads = 20
         self.kv_num_heads = 10
@@ -425,7 +434,9 @@ def __init__(self, *args, **kwargs):
             "key_scale": np.random.rand(Dh).astype(np.float32),
         }
 
-    def source_model_script(self, with_past: bool, transpose_first: bool):
+    def source_model_script(self):
+        with_past = self.with_past
+        transpose_first = self.transpose_first
         scale_factor = math.sqrt(math.sqrt(self.head_size))
         minval = torch.finfo(torch.float32).min
         minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
@@ -567,17 +578,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
 
         return gqa
 
-    @parameterized.parameterized.expand([
-        (True, True),  # with_past=True, transpose_first=True
-        (True, False),  # with_past=True, transpose_first=False
-        (False, True),  # with_past=False, transpose_first=True
-        (False, False),  # with_past=False, transpose_first=False
-    ])
-    def test_fusion(self, with_past, transpose_first):
+    def test_fusion(self):
         """Test that GQA fusion is successful on source model and produces an equivalent model."""
 
         inputs = self.inputs
 
-        source_model = self.source_model_script(with_past, transpose_first).to_model_proto(
+        source_model = self.source_model_script().to_model_proto(
             input_types=self.input_types,
             output_types=self.output_types,
         )

From 70b5298ad45d341415a477c5dbaccb777e0d04af Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Thu, 30 Oct 2025 21:56:14 -0700
Subject: [PATCH 3/3] Run lint

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/gqa_test.py | 23 ++++++++++++-------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py
index 14449667f4..038749c017 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_test.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -362,15 +362,18 @@ def test_fusion(self):
         assert_allclose(outputs3, source_model_outputs)
 
 
-@parameterized.parameterized_class([
-    {"with_past": True, "transpose_first": True},
-    {"with_past": True, "transpose_first": False},
-    {"with_past": False, "transpose_first": True},
-    {"with_past": False, "transpose_first": False},
-])
+@parameterized.parameterized_class(
+    [
+        {"with_past": True, "transpose_first": True},
+        {"with_past": True, "transpose_first": False},
+        {"with_past": False, "transpose_first": True},
+        {"with_past": False, "transpose_first": False},
+    ]
+)
 class GemmaGQAFusionTest(unittest.TestCase):
     with_past = True
     transpose_first = True
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -485,11 +488,15 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
                 query_BSHDh_normalized = op.SimplifiedLayerNormalization(
                     query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
                 )
-                query_BHSDh_normalized = op.Transpose(query_BSHDh_normalized, perm=[0, 2, 1, 3])
+                query_BHSDh_normalized = op.Transpose(
+                    query_BSHDh_normalized, perm=[0, 2, 1, 3]
+                )
                 key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
                     key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
                 )
-                key_BHkvSDh_normalized = op.Transpose(key_BSHkvDh_normalized, perm=[0, 2, 1, 3])
+                key_BHkvSDh_normalized = op.Transpose(
+                    key_BSHkvDh_normalized, perm=[0, 2, 1, 3]
+                )
 
             value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
             value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])