diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index c7ed888142..038749c017 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -10,6 +10,7 @@ import onnx_ir as ir import onnx_ir.passes.common.shape_inference as shape_inference import onnxruntime as ort +import parameterized import torch import onnxscript @@ -361,14 +362,26 @@ def test_fusion(self): assert_allclose(outputs3, source_model_outputs) +@parameterized.parameterized_class( + [ + {"with_past": True, "transpose_first": True}, + {"with_past": True, "transpose_first": False}, + {"with_past": False, "transpose_first": True}, + {"with_past": False, "transpose_first": False}, + ] +) class GemmaGQAFusionTest(unittest.TestCase): + with_past = True + transpose_first = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Config parameters self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? self.seqlen = 8 self.kv_seqlen = self.seqlen - self.past_seqlen = 16 + self.past_seqlen = 16 if self.with_past else 0 self.head_size = 16 self.num_heads = 20 self.kv_num_heads = 10 @@ -425,6 +438,8 @@ def __init__(self, *args, **kwargs): } def source_model_script(self): + with_past = self.with_past + transpose_first = self.transpose_first scale_factor = math.sqrt(math.sqrt(self.head_size)) minval = torch.finfo(torch.float32).min minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) @@ -458,16 +473,30 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only # one sequence length (S) for all Q, K, and V (with no cache). query_BSHDh = op.Reshape(query, shape_BSHDh) - query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - query_BHSDh_normalized = op.SimplifiedLayerNormalization( - query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 - ) - key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) - key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( - key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 - ) + + if transpose_first: + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + query_BHSDh_normalized = op.SimplifiedLayerNormalization( + query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( + key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + else: + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + query_BHSDh_normalized = op.Transpose( + query_BSHDh_normalized, perm=[0, 2, 1, 3] + ) + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + key_BHkvSDh_normalized = op.Transpose( + key_BSHkvDh_normalized, perm=[0, 2, 1, 3] + ) value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) @@ -489,9 +518,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal cos, sin, ) - key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + if with_past: + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + else: + key_seq_BHkvSkvDh = key_BHkvSDh_rope + value_seq_BHkvSkvDh = value_BHkvSDh # Now, expand from shared heads to all heads key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)