Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions onnxscript/rewriter/ort_fusions/gqa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import onnx_ir as ir
import onnx_ir.passes.common.shape_inference as shape_inference
import onnxruntime as ort
import parameterized
import torch

import onnxscript
Expand Down Expand Up @@ -361,14 +362,26 @@ def test_fusion(self):
assert_allclose(outputs3, source_model_outputs)


@parameterized.parameterized_class(
[
{"with_past": True, "transpose_first": True},
{"with_past": True, "transpose_first": False},
{"with_past": False, "transpose_first": True},
{"with_past": False, "transpose_first": False},
]
)
class GemmaGQAFusionTest(unittest.TestCase):
with_past = True
transpose_first = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Config parameters
self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1?
self.seqlen = 8
self.kv_seqlen = self.seqlen
self.past_seqlen = 16
self.past_seqlen = 16 if self.with_past else 0
self.head_size = 16
self.num_heads = 20
self.kv_num_heads = 10
Expand Down Expand Up @@ -425,6 +438,8 @@ def __init__(self, *args, **kwargs):
}

def source_model_script(self):
with_past = self.with_past
transpose_first = self.transpose_first
scale_factor = math.sqrt(math.sqrt(self.head_size))
minval = torch.finfo(torch.float32).min
minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
Expand Down Expand Up @@ -458,16 +473,30 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
# We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
# one sequence length (S) for all Q, K, and V (with no cache).
query_BSHDh = op.Reshape(query, shape_BSHDh)
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
query_BHSDh_normalized = op.SimplifiedLayerNormalization(
query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
)

key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
)

if transpose_first:
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
query_BHSDh_normalized = op.SimplifiedLayerNormalization(
query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
)
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
)
else:
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
)
query_BHSDh_normalized = op.Transpose(
query_BSHDh_normalized, perm=[0, 2, 1, 3]
)
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
)
key_BHkvSDh_normalized = op.Transpose(
key_BSHkvDh_normalized, perm=[0, 2, 1, 3]
)

value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
Expand All @@ -489,9 +518,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
cos,
sin,
)
key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)

value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
if with_past:
key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
else:
key_seq_BHkvSkvDh = key_BHkvSDh_rope
value_seq_BHkvSkvDh = value_BHkvSDh

# Now, expand from shared heads to all heads
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
Expand Down
Loading