diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index ea7af31b3e..67e8f7839a 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -29,6 +29,7 @@ fuse_rotary_embedding, ) from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa +from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha from onnxscript.rewriter.ort_fusions.skip_normalization import ( fuse_skip_layer_normalization, fuse_skip_rms_normalization, @@ -104,6 +105,7 @@ def fuse(func, **kwargs): fusion_count["attention"] = fuse(fuse_attention) fusion_count["gelu"] = fuse(fuse_gelu) fusion_count["bias_gelu"] = fuse(fuse_bias_gelu) + fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha) # Finally: inline any intermediate fusion functions introduced that were not # consumed by other fusions, and eliminate any remaining unused nodes. optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 55b38e9ad4..821537afe5 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -12,6 +12,18 @@ Dim = Union[int, ir.SymbolicDim] +# This file contains a fusion rule that recognizes various patterns of scaled dot-product attention +# (SDPA) implementations and replaces them with a single SDPA op. The SDPA op is a temporary fusion +# op defined in the ai.onnxruntime._fusion domain. Subsequent fusion rules will map it into one +# of the various ops defined in ORT: MHA, GQA, or Attention depending on the input patterns. +# The SDPA is a standard scaled dot-product attention with an optional mask input and scaling factor. 
+# Currently, it is restricted to query, key, and value of rank 4 with shapes: +# Query: [batch_size, num_heads, seq_len, head_size_qk] +# Key: [batch_size, num_heads, seq_len_kv, head_size_qk] +# or [batch_size, seq_len_kv, num_heads, head_size_qk] +# Value: [batch_size, num_heads, seq_len_kv, head_size_v] +# The key_format attribute indicates which of the two formats the key uses and can be either "BHSd" or "BSHd". + class SDPA(pattern.RewriteRuleClassBase): _scale: float | None diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index c5326a77b9..3b29418cc6 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask): return attn_output +# This tests a scenario where the key is in BSHd format instead of BHSd, which +# happens due to an optimization that fuses two transposes together: the one +# converting from BSHd to BHSd and the one converting to BHdS before MatMul. Hence, the first +# transpose down below is different from other test cases. 
+@script() +def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 2, 3, 1]) # BSHd to BHdS + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + class SDPATestCase: - def __init__(self, script_func, *, with_mask): + def __init__(self, script_func, *, with_mask, BSHd_key=False): self.script_func = script_func self.with_mask = with_mask + self.BSHd_key = BSHd_key def get_onnx_model(self): if not hasattr(self, "_onnx_model"): - qkv_type = FLOAT[B, N, S, H] + qv_type = FLOAT[B, N, S, H] mask_type = FLOAT[B, N, S, S] - input_types = [qkv_type, qkv_type, qkv_type] + k_type = FLOAT[B, S, N, H] if self.BSHd_key else FLOAT[B, N, S, H] + input_types = [qv_type, k_type, qv_type] if self.with_mask: input_types.append(mask_type) model_proto = self.script_func.to_model_proto( - input_types=input_types, output_types=[qkv_type] + input_types=input_types, output_types=[qv_type] ) self._onnx_model = ir.serde.deserialize_model(model_proto) return self._onnx_model @@ -314,7 +335,9 @@ def get_ort_inputs(self): if not hasattr(self, "_ort_inputs"): inputs = { "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), - "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "key": numpy.random.rand(B, S, N, H).astype(numpy.float32) + if self.BSHd_key + else numpy.random.rand(B, N, S, H).astype(numpy.float32), "value": numpy.random.rand(B, N, S, H).astype(numpy.float32), } if self.with_mask: @@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase): "_custom_multi_scale_pre_mul_sdpa_script", _custom_multi_scale_pre_mul_sdpa_script, ), + 
("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script), ] ) def test_sdpa_fusion(self, name, script_func): - test_case = SDPATestCase(script_func, with_mask="masked" in name) + test_case = SDPATestCase( + script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name + ) model = test_case.get_onnx_model() onnxscript.optimizer.optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py index e6484406a9..acbc0705fa 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py @@ -7,43 +7,57 @@ import onnx_ir as ir from onnxscript.rewriter import _fusion_utils, pattern +from onnxscript.rewriter._basics import MatchFailureError Dim = Union[int, ir.SymbolicDim] class SDPAImplementation(pattern.RewriteRuleClassBase): - def pattern(self, op, query, key, value): + def pattern(self, op, query, key, value, key_format): + """Pattern matches any call to SDPA. See sdpa.py for documentation on the SDPA op.""" return op.SDPA( query, key, value, - key_format="BHSd", + key_format=key_format, _allow_other_inputs=True, # Mask is optional _outputs=["sdpa_output"], _domain="ai.onnxruntime._fusion", ) - def check(self, context, query, key, value, sdpa_output): + def check(self, context, query, key, value, key_format, sdpa_output): bindings: dict[str, Dim] = {} _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) - _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) + if key_format.value == "BHSd": + _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) + elif key_format.value == "BSHd": + _fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"]) + else: + raise MatchFailureError( + f"Unexpected key_format value: {key_format.value}", key_format + ) + self._num_heads = bindings["H"] if not isinstance(self._num_heads, int): return False 
self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed return isinstance(self._num_heads, int) - def rewrite(self, op, query, key, value, sdpa_output): + def rewrite(self, op, query, key, value, key_format, sdpa_output): sdpa_node = sdpa_output.producer() scale = sdpa_node.attributes.get("scale", None) to_3d_shape = op.Constant(value_ints=[0, 0, -1]) to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1]) query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape) - key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape) + if key_format.value == "BHSd": + key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) + else: # BSHd + key_3d = op.Reshape(key, to_3d_shape) + inputs = [query_3d, key_3d, value_3d] if len(sdpa_node.inputs) > 3: mask = sdpa_node.inputs[3]