Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8de7231
First version
gramalingam Dec 18, 2024
a20b903
Add rotary embedding
gramalingam Dec 18, 2024
b8f7a08
Remove SDPA
gramalingam Dec 18, 2024
315c94e
Add comment
gramalingam Dec 18, 2024
2219fd3
Remove MHA
gramalingam Dec 18, 2024
f77f0e7
Merge branch 'main' into rama/fuse-attn
gramalingam Dec 18, 2024
5ec9d1e
Add rewrite for cos-sin computation
gramalingam Dec 20, 2024
90f0b7b
Merge branch 'rama/fuse-attn' of https://github.com/microsoft/onnx-sc…
gramalingam Dec 20, 2024
1fdc19b
Run lint
gramalingam Dec 20, 2024
eb916b8
Add cos sin test
gramalingam Dec 20, 2024
d874dbc
Extend rewriter to support node reuse
gramalingam Dec 20, 2024
a745039
Minor fixes
gramalingam Dec 21, 2024
17c06c3
Fix concat bug in rotary embedding
gramalingam Dec 22, 2024
c7c7c79
Minor cleanup
gramalingam Dec 23, 2024
834815b
Merge branch 'main' into rama/fuse-attn
gramalingam Dec 23, 2024
9a4a58e
Use callable to test callable
gramalingam Dec 23, 2024
766791d
Fix lint issues
gramalingam Dec 23, 2024
c7384af
Attention fusion
gramalingam Dec 24, 2024
d0254d1
Add support for cached state in rewrite
gramalingam Dec 24, 2024
b91166b
Cleanup MHA pattern
gramalingam Dec 24, 2024
205805c
Complete MHA pattern
gramalingam Dec 26, 2024
e907f3e
Add MHA fusion test
gramalingam Dec 26, 2024
82f1919
Add validation condition
gramalingam Dec 26, 2024
fa3b94d
Run lint
gramalingam Dec 26, 2024
9310b67
Merge with main
gramalingam Jan 7, 2025
e0f29e2
Fix merge conflict
gramalingam Jan 7, 2025
41aa177
Fix merge conflict
gramalingam Jan 7, 2025
2688d6e
Merge conflict fix
gramalingam Jan 7, 2025
c080f4a
Merge with main
gramalingam Jan 9, 2025
2e6de3d
Merge branch 'main' into rama/fuse-attn-2
gramalingam Jan 15, 2025
889f7d2
Address lint issues
gramalingam Jan 15, 2025
a9947dc
Add smollm models to mypy exclusion
gramalingam Jan 15, 2025
c5c6588
Rename unused variable in test onnx model
gramalingam Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ exclude_patterns = [
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py', # onnxscript code
'onnxscript/rewriter/onnxruntime/xformers/_smollm_*.py', # onnxscript code
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
Expand Down
6 changes: 6 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
"fuse_normalization",
"fuse_rotary_embedding",
"fuse_cos_sin_cache",
"fuse_sdpa",
"fuse_mha",
"fuse_xformers",
]

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.

"""
A one-layer SmolLM model test case.
A one-layer SmolLM model test case, with inputs: input_ids, attention_mask, and position_ids.
This is an onnxscript version of the model.
"""

Expand Down Expand Up @@ -234,7 +234,7 @@ def make_model_with_random_weights():
return model


class _SmollmTestData:
class TestData:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = make_model_with_random_weights()
Expand Down
467 changes: 467 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_smollm_2.py

Large diffs are not rendered by default.

23 changes: 16 additions & 7 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def __init__(self, name: str, max_pos_id: int):
# pass to remove unused nodes.
super().__init__(name, remove_nodes=False)
self._max_pos_id = max_pos_id
# map from inv_freq to (cos, sin) values for transformed graph
self._inv_freq_cos_sin_cache: dict[ir.Value, tuple[ir.Value, ir.Value]] = {}

def cleanup(self) -> None:
    """Empty the inv_freq -> (cos, sin) cache kept by this rule.

    NOTE(review): presumably called by the rewriter once a model has been
    processed, so that cached values are not reused for another graph —
    confirm against the base class.
    """
    self._inv_freq_cos_sin_cache.clear()

def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads):
position_ids_expanded = op.Unsqueeze(position_ids, 1)
Expand Down Expand Up @@ -72,13 +77,17 @@ def check(self, context, inv_freq, position_ids, **_) -> bool:
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1

def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_):
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
sin_value = np.sin(angles)
cos_2d = op.Constant(value=ir.tensor(cos_value))
sin_2d = op.Constant(value=ir.tensor(sin_value))
if inv_freq in self._inv_freq_cos_sin_cache:
cos_2d, sin_2d = self._inv_freq_cos_sin_cache[inv_freq]
else:
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
sin_value = np.sin(angles)
cos_2d = op.Constant(value=ir.tensor(cos_value))
sin_2d = op.Constant(value=ir.tensor(sin_value))
self._inv_freq_cos_sin_cache[inv_freq] = (cos_2d, sin_2d)
return op.RotaryEmbedding(
x,
position_ids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._smollm_1 import TestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run


class TestCosSinCacheTransform(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
smollm_test = TestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
Expand Down
19 changes: 19 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization


def fuse_xformers(model):
    """Run all transformer fusion passes on ``model``, one after another.

    The passes run in a fixed sequence: RMS normalization, skip
    normalization, rotary embedding, cos/sin cache, SDPA and finally MHA.
    Nothing is returned; each pass is called only with ``model``.

    NOTE(review): the sequence looks significant (the MHA pattern matches
    RotaryEmbedding and SDPA nodes, which earlier passes presumably
    produce) — do not reorder without checking.
    """
    ordered_passes = (
        fuse_rms_normalization,
        fuse_normalization,
        fuse_rotary_embedding,
        fuse_cos_sin_cache,
        fuse_sdpa,
        fuse_mha,
    )
    for run_pass in ordered_passes:
        run_pass(model)
178 changes: 178 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/mha.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Sequence

import onnxscript.ir as ir
from onnxscript.rewriter import pattern

"""
The MultiHeadAttention pattern:

B: Batch size
S: Sequence length
D: input embedding dimension
H: number of heads
d_h: head size (usually, D = H * d_h)

Thus, the Q, K, and V projection weights each usually have shape (D, D).

for each of Q, K, and V, we have the following pattern:
MatMul (Input, W), producing output of shape (B, S, D)
Reshape to produce a matrix of shape (B, S, H, d_h)
Transpose middle two axes to produce a matrix of shape (B, H, S, d_h)

This is followed by a RotaryEmbedding pattern for Q and K

The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)

The dot-product attention is then computed using SDPA.
Finally, the output is transposed and reshaped back to (B, S, D) shape
"""


def _project_transpose_head(op, input, weight, reshape_var: str):
    """Build the projection sub-pattern that is shared by Q, K and V.

    The sub-pattern is MatMul(input, weight) -> Reshape -> Transpose with
    ``perm=[0, 2, 1, 3]``. The Reshape output is given the name
    ``reshape_var`` so that it can be referred to by that name elsewhere
    (the validation and rewrite functions take parameters with these names).

    Returns the Transpose output.
    """
    # (B, S, D) -> (B, S, H, D/H). Extra inputs/attributes of the Reshape
    # are permitted by the two ``_allow_other_*`` flags.
    split_heads = op.Reshape(
        op.MatMul(input, weight),
        _allow_other_inputs=True,
        _allow_other_attributes=True,
        _outputs=[reshape_var],
    )
    # (B, S, H, D/H) -> (B, H, S, D/H)
    return op.Transpose(split_heads, perm=[0, 2, 1, 3])


def _multi_head_attention_pattern(
    op,
    input,
    query_weight,
    key_weight,
    value_weight,
    mask,
    cos,
    sin,
    past_key,
    past_value,
    position_ids,
):
    """Describe the attention subgraph that is to be fused.

    Q, K and V are each projected from ``input`` and split into heads.
    RotaryEmbedding is applied to Q and K. The past key/value are
    concatenated in front of the new key/value along axis -2. The key is
    then passed through Reshape/Transpose/Reshape before SDPA, and the SDPA
    result is transposed and reshaped.

    Returns a triple: the reshaped attention output, the concatenated key
    and the concatenated value.
    """
    # Query branch.
    q_heads = _project_transpose_head(op, input, query_weight, "query_mm_reshaped")
    q_rotated = op.RotaryEmbedding(q_heads, position_ids, cos, sin, _domain="com.microsoft")

    # Key branch, with the cached keys prepended.
    k_heads = _project_transpose_head(op, input, key_weight, "key_mm_reshaped")
    k_rotated = op.RotaryEmbedding(k_heads, position_ids, cos, sin, _domain="com.microsoft")
    k_with_past = op.Concat(past_key, k_rotated, axis=-2)
    # Swap the last two axes of the key so the dot-product is a plain matmul:
    # Reshape to rank 3, Transpose with perm [0, 2, 1], Reshape again.
    k_rank3 = op.Reshape(k_with_past, _allow_other_inputs=True, _outputs=["key_reshaped"])
    k_rank3_swapped = op.Transpose(k_rank3, perm=[0, 2, 1])
    k_swapped = op.Reshape(
        k_rank3_swapped, _allow_other_inputs=True, _outputs=["key_transposed"]
    )

    # Value branch, with the cached values prepended.
    v_heads = _project_transpose_head(op, input, value_weight, "value_mm_reshaped")
    v_with_past = op.Concat(past_value, v_heads, axis=-2)

    sdpa_out = op.SDPA(
        q_rotated, k_swapped, v_with_past, mask, _domain="ai.onnxruntime.fusion"
    )
    # Back to (B, S, H, D/H) ...
    sdpa_out_swapped = op.Transpose(sdpa_out, perm=[0, 2, 1, 3])
    # ... and then to (B, S, D).
    merged_heads = op.Reshape(
        sdpa_out_swapped, _allow_other_inputs=True, _outputs=["attention_reshaped"]
    )
    return merged_heads, k_with_past, v_with_past


def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str]) -> bool:
if val.shape is None:
return False
if val.shape.rank() != len(shape):
return False
for actual, expected in zip(val.shape, shape):
if expected not in bindings:
bindings[expected] = actual # type: ignore[assignment]
elif actual != bindings[expected]:
return False
return True


def _mha_validation(
    op,
    query_mm_reshaped,
    key_mm_reshaped,
    value_mm_reshaped,
    key_reshaped,
    key_transposed,
    attention_reshaped,
    **_,
):
    """Check that the matched values have mutually consistent shapes.

    The named intermediate values of the pattern are matched, in order,
    against symbolic shapes built from B, S, KVS, H, d_h, TS and the
    compound names "B*H" and "H*d_h". The compound dimensions are then
    verified to equal the product of their factors.

    Returns True only when every shape matches and both products agree.
    When a dimension needed for the product check is not a concrete ``int``
    (for example a symbolic dimension), the products cannot be computed, so
    the match is rejected rather than raising ``TypeError``.

    ``op`` and any further keyword arguments are accepted but not used.
    """
    bindings: dict[str, int] = {}
    expected_shapes = (
        (query_mm_reshaped, ["B", "S", "H", "d_h"]),
        (key_mm_reshaped, ["B", "KVS", "H", "d_h"]),
        (value_mm_reshaped, ["B", "KVS", "H", "d_h"]),
        (key_reshaped, ["B*H", "TS", "d_h"]),
        (key_transposed, ["B", "H", "d_h", "TS"]),
        (attention_reshaped, ["B", "S", "H*d_h"]),
    )
    # all() over a generator stops at the first failing shape, like the
    # chained ``and`` it replaces.
    if not all(_check_shape(bindings, value, dims) for value, dims in expected_shapes):
        return False

    # The product checks below are only defined for concrete integers.
    product_operands = ("B", "H", "d_h", "B*H", "H*d_h")
    if not all(isinstance(bindings[name], int) for name in product_operands):
        return False

    return (
        bindings["B"] * bindings["H"] == bindings["B*H"]
        and bindings["H"] * bindings["d_h"] == bindings["H*d_h"]
    )


def _multi_head_attention(
    op,
    input,
    query_weight,
    key_weight,
    value_weight,
    mask,
    cos,
    sin,
    past_key,
    past_value,
    position_ids,
    query_mm_reshaped,
    **_,
):
    """Build the replacement: a single com.microsoft MultiHeadAttention node.

    The number of heads is read from dimension 2 of ``query_mm_reshaped``.
    Q, K and V are projected with plain MatMuls (no reshape/transpose);
    RotaryEmbedding is applied to the Q and K projections. The mask is
    tiled ``num_heads`` times along axis 1 before being passed to
    MultiHeadAttention, which is requested with three outputs.
    """
    head_count = query_mm_reshaped.shape[2]

    q_projected = op.MatMul(input, query_weight)
    q_rotated = op.RotaryEmbedding(
        q_projected, position_ids, cos, sin, _domain="com.microsoft"
    )
    k_projected = op.MatMul(input, key_weight)
    k_rotated = op.RotaryEmbedding(
        k_projected, position_ids, cos, sin, _domain="com.microsoft"
    )
    v_projected = op.MatMul(input, value_weight)

    # Repeat the mask once per head along axis 1.
    mask_repeats = op.Constant(value_ints=[1, head_count, 1, 1])
    per_head_mask = op.Tile(mask, mask_repeats)

    return op.MultiHeadAttention(
        q_rotated,
        k_rotated,
        v_projected,
        None,  # bias
        None,  # key padding mask
        per_head_mask,  # attention mask/bias
        past_key,
        past_value,
        num_heads=head_count,
        _domain="com.microsoft",
        _outputs=3,
    )


# The single MHA rewrite rule: match _multi_head_attention_pattern, replace
# it with _multi_head_attention, subject to the _mha_validation check.
_rule1 = pattern.RewriteRule(
    _multi_head_attention_pattern, _multi_head_attention, _mha_validation
)


# Rule set that fuse_mha applies to a model.
mha_rules = pattern.RewriteRuleSet([_rule1])


def fuse_mha(model: ir.Model) -> int:
    """Apply the MHA rewrite rules to ``model`` and report the result.

    The value returned by ``mha_rules.apply_to_model`` is printed to stdout
    and then returned to the caller.
    NOTE(review): presumably this is the number of rewrites applied — confirm.
    """
    count = mha_rules.apply_to_model(model)
    print(f"MHA count: {count}")
    return count
40 changes: 40 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/mha_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
import onnxscript.rewriter.onnxruntime.xformers as xformers
from onnxscript.rewriter.onnxruntime.xformers._smollm_2 import TestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run


class TestMultiHeadAttention(unittest.TestCase):
    """Checks that SDPA + MHA fusion fires and leaves the model outputs unchanged."""

    def test_smollm(self):
        # Build the test model and apply the fusions that run before SDPA/MHA.
        test_data = TestData()
        model = test_data.get_onnx_model()
        onnxscript.optimizer.optimize(model)
        xformers.fuse_rms_normalization(model)
        xformers.fuse_normalization(model)
        xformers.fuse_rotary_embedding(model)
        xformers.fuse_cos_sin_cache(model)

        # Reference outputs, taken before the SDPA/MHA fusions.
        ort_inputs = test_data.get_ort_inputs()
        expected = ort_run("original", model, ort_inputs)

        # Both fusions must rewrite at least one subgraph.
        self.assertGreater(xformers.fuse_sdpa(model), 0)
        self.assertGreater(xformers.fuse_mha(model), 0)

        # The fused model must reproduce the reference outputs.
        actual = ort_run("optimized", model, ort_inputs)
        assert_allclose(actual, expected)


# Allow this test module to be run directly as a script.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._smollm_1 import TestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization


class TestRmsNormalization(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
smollm_test = TestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._smollm_1 import TestData
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
smollm_test = TestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
fuse_rotary_embedding(model)
Expand Down
Loading
Loading