Add rotary embedding fusion rule (part 1) #1981
Merged
23 commits
8de7231
First version
gramalingam a20b903
Add rotary embedding
gramalingam b8f7a08
Remove SDPA
gramalingam 315c94e
Add comment
gramalingam 2219fd3
Remove MHA
gramalingam f77f0e7
Merge branch 'main' into rama/fuse-attn
gramalingam 5ec9d1e
Add rewrite for cos-sin computation
gramalingam 90f0b7b
Merge branch 'rama/fuse-attn' of https://github.com/microsoft/onnx-sc…
gramalingam 1fdc19b
Run lint
gramalingam eb916b8
Add cos sin test
gramalingam d874dbc
Extend rewriter to support node reuse
gramalingam a745039
Minor fixes
gramalingam 17c06c3
Fix concat bug in rotary embedding
gramalingam c7c7c79
Minor cleanup
gramalingam 834815b
Merge branch 'main' into rama/fuse-attn
gramalingam 9a4a58e
Use callable to test callable
gramalingam 766791d
Fix lint issues
gramalingam 2b5309a
Update generic matcher for new parameter
gramalingam 4c0e5f9
Merge branch 'main' into rama/fuse-attn
gramalingam 867cbd8
Merge branch 'main' into rama/fuse-attn
gramalingam ed781df
Update onnxscript/rewriter/onnxruntime/xformers/__init__.py
gramalingam f346d0b
Update onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
gramalingam 7399941
Merge branch 'main' into rama/fuse-attn
gramalingam
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,15 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| from __future__ import annotations | ||
|
|
||
| __all__ = [ | ||
| "fuse_rms_normalization", | ||
| "fuse_normalization", | ||
| "fuse_rotary_embedding", | ||
| "fuse_cos_sin_cache", | ||
| ] | ||
|
|
||
| from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | ||
| from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | ||
| from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
| from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization |
102 changes: 102 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
|
|
||
| import onnxscript.ir as ir | ||
| from onnxscript.optimizer import remove_unused_nodes | ||
| from onnxscript.rewriter import _ir_utils, pattern | ||
|
|
||
| # Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. | ||
|
|
||
| # We match against the following code pattern: | ||
| # Original code (from transformers) for computing cos/sin cache for RoPE: | ||
| # https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 | ||
| # position_ids_expanded = position_ids[:, None, :].float() | ||
| # freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) | ||
| # emb = torch.cat((freqs, freqs), dim=-1) | ||
| # cos = emb.cos() | ||
| # sin = emb.sin() | ||
| # | ||
| # We rewrite this pattern into the following form: | ||
| # inv_freq_values = inv_freq_expanded.reshape(1, -1) | ||
| # pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1) | ||
| # angles = np.matmul(pos_id_range, inv_freq_values) | ||
| # cos_value = np.cos(angles) | ||
| # sin_value = np.sin(angles) | ||
| # cos_2d = op.Constant(value=ir.tensor(cos_value)) | ||
| # sin_2d = op.Constant(value=ir.tensor(sin_value)) | ||
| # | ||
| # This produces cos/sin values in a form that can be used by ORT's custom ops. | ||
|
|
||
| # TODO: To apply the pattern-rewrite, we need to know the maximum position id. | ||
| # Need to find a way to get this information from the model or its config. | ||
|
|
||
|
|
||
| class CosSinCacheFusion(pattern.RewriteRuleClassBase): | ||
| def __init__(self, name: str, max_pos_id: int): | ||
| # This pattern makes use of shared Cos/Sin values. So, we can't remove the | ||
| # matched nodes as part of the rewrite-step. We apply a separate final | ||
| # pass to remove unused nodes. | ||
| super().__init__(name, remove_nodes=False) | ||
| self._max_pos_id = max_pos_id | ||
|
|
||
| def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): | ||
| position_ids_expanded = op.Unsqueeze(position_ids, 1) | ||
| position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) | ||
| freqs = op.MatMul(inv_freq, position_ids_expanded) | ||
| freqs = op.Transpose(freqs, perm=[0, 2, 1]) | ||
| emb = op.Concat(freqs, freqs, axis=-1) | ||
| cos = op.Cos(emb) | ||
| sin = op.Sin(emb) | ||
| cos_4d = op.Unsqueeze(cos, 1) # insert axis 1 to make cos 4D | ||
| sin_4d = op.Unsqueeze(sin, 1) | ||
| return op.RotaryEmbedding( | ||
| x, | ||
| cos_4d, | ||
| sin_4d, | ||
| interleaved=interleaved, | ||
| num_heads=num_heads, | ||
| _domain="ai.onnxruntime.fusion", | ||
| ) | ||
|
|
||
| def check(self, context, inv_freq, position_ids, **_) -> bool: | ||
| if not _ir_utils.has_rank(position_ids, 2): | ||
| return False | ||
| if not _ir_utils.has_rank(inv_freq, 3): | ||
| return False | ||
| inv_freq_shape = inv_freq.shape | ||
| if inv_freq.const_value is None: | ||
| return False | ||
| return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 | ||
|
|
||
| def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): | ||
| inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) | ||
| pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) | ||
| angles = np.matmul(pos_id_range, inv_freq_values) | ||
| cos_value = np.cos(angles) | ||
| sin_value = np.sin(angles) | ||
| cos_2d = op.Constant(value=ir.tensor(cos_value)) | ||
| sin_2d = op.Constant(value=ir.tensor(sin_value)) | ||
| return op.RotaryEmbedding( | ||
| x, | ||
| position_ids, | ||
| cos_2d, | ||
| sin_2d, | ||
| interleaved=interleaved, | ||
| num_heads=num_heads, | ||
| _domain="com.microsoft", | ||
| ) | ||
|
|
||
|
|
||
| _rule = CosSinCacheFusion.rule("CosSinCache", 2048) | ||
|
|
||
| cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) | ||
|
|
||
|
|
||
| def fuse_cos_sin_cache(model: ir.Model) -> int: | ||
| count = cos_sin_cache_rules.apply_to_model(model) | ||
| print(f"CosSinCache count: {count}") | ||
| remove_unused_nodes(model) | ||
| return count | ||
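
The equivalence behind this rewrite can be checked outside ONNX. The following is a minimal NumPy sketch, not part of the PR: it compares the per-call cos/sin computation that `pattern` matches with the precomputed 2D tables that `rewrite` builds. The `lookup` helper, the toy shapes, and the assumption that the fused op gathers table rows by `position_ids` are illustrative only.

```python
import numpy as np


def original_cos_sin(inv_freq, position_ids):
    """Per-call cos/sin, following the matched (transformers-style) pattern."""
    # inv_freq: (1, D/2, 1) float32; position_ids: (B, S) integers.
    pos = position_ids[:, None, :].astype(np.float32)       # (B, 1, S)
    freqs = np.matmul(inv_freq, pos).transpose(0, 2, 1)     # (B, S, D/2)
    emb = np.concatenate((freqs, freqs), axis=-1)           # (B, S, D)
    return np.cos(emb), np.sin(emb)


def cached_cos_sin(inv_freq, max_pos_id):
    """Precomputed 2D tables, following the rewrite step."""
    inv_freq_values = inv_freq.reshape(1, -1)                # (1, D/2)
    pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1)
    angles = np.matmul(pos_id_range, inv_freq_values)        # (max_pos_id, D/2)
    return np.cos(angles), np.sin(angles)


def lookup(table, position_ids):
    """Hypothetical helper: gather rows by position id, then duplicate the halves."""
    rows = table[position_ids]                               # (B, S, D/2)
    return np.concatenate((rows, rows), axis=-1)             # (B, S, D)


# Toy configuration (assumed values, chosen only for illustration).
head_size = 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2, dtype=np.float32) / head_size))
inv_freq = inv_freq.astype(np.float32).reshape(1, -1, 1)
position_ids = np.array([[0, 1, 2, 5], [3, 4, 7, 9]])

cos_ref, sin_ref = original_cos_sin(inv_freq, position_ids)
cos_tab, sin_tab = cached_cos_sin(inv_freq, max_pos_id=16)
print(np.allclose(cos_ref, lookup(cos_tab, position_ids), atol=1e-6))
print(np.allclose(sin_ref, lookup(sin_tab, position_ids), atol=1e-6))
```

The tables only cover positions below `max_pos_id`, which is why the rule needs that value up front; in this PR it is hard-coded to 2048 when the rule is instantiated.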
29 changes: 29 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| from __future__ import annotations | ||
|
|
||
| import unittest | ||
|
|
||
| import onnxscript.optimizer | ||
| from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding | ||
| from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
| from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run | ||
|
|
||
|
|
||
| class TestCosSinCacheTransform(unittest.TestCase): | ||
| def test_smollm(self): | ||
| smollm_test = _SmollmTestData() | ||
| model = smollm_test.get_onnx_model() | ||
| onnxscript.optimizer.optimize(model) | ||
| inputs = smollm_test.get_ort_inputs() | ||
| original_outputs = ort_run("original", model, inputs) | ||
| count = fuse_rotary_embedding(model) | ||
| self.assertGreater(count, 0) | ||
| count = fuse_cos_sin_cache(model) | ||
| self.assertGreater(count, 0) | ||
| new_outputs = ort_run("optimized", model, inputs) | ||
| assert_allclose(new_outputs, original_outputs) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
64 changes: 64 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| from __future__ import annotations | ||
|
|
||
| import onnxscript.ir as ir | ||
| from onnxscript.rewriter import _ir_utils, pattern | ||
|
|
||
| # Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern | ||
| # for full rotation without interleaving. | ||
| # TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation). | ||
|
|
||
| # Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet, | ||
| # so it can't be tested by running against ORT. See cos_sin_cache.py for a transformation that | ||
| # rewrites the pattern into one that can be run against ORT. | ||
|
|
||
|
|
||
| def _rotate_half_pattern(op, x, start1, end1, start2, end2): | ||
| # Slice(input, starts, ends, axes, steps) | ||
| x1 = op.Slice(x, start1, end1, [3], [1]) | ||
| x2 = op.Slice(x, start2, end2, [3], [1]) | ||
| minus_x2 = op.Neg(x2) | ||
| rotated_x = op.Concat(minus_x2, x1, axis=-1) | ||
| return rotated_x | ||
|
|
||
|
|
||
| class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): | ||
| def pattern(self, op, x, cos, sin, start1, end1, start2, end2): | ||
| return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin | ||
|
|
||
| def check(self, op, x, start1, end1, start2, end2, **_): | ||
| # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) | ||
| if x is None or x.shape is None or len(x.shape) != 4: | ||
| return False | ||
| if not isinstance(x.shape[1], int): | ||
| return False | ||
| head_size = x.shape[3] | ||
| if not isinstance(head_size, int): | ||
| return False | ||
| half_head_size = head_size // 2 | ||
|
|
||
| # Check that x is being split into two equal halves of size half_head_size | ||
| return ( | ||
| _ir_utils.is_singleton_value(start1, 0) | ||
| and _ir_utils.is_singleton_value(end1, half_head_size) | ||
| and _ir_utils.is_singleton_value(start2, half_head_size) | ||
| and _ir_utils.is_singleton_value(end2, lambda end: end >= head_size) | ||
| ) | ||
|
|
||
| def rewrite(self, op, x, cos, sin, **_): | ||
| num_heads = x.shape[1] | ||
| return op.RotaryEmbedding( | ||
| x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion" | ||
| ) | ||
|
|
||
|
|
||
| _rule = RotaryEmbeddingFusion.rule() | ||
|
|
||
| rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) | ||
|
|
||
|
|
||
| def fuse_rotary_embedding(model: ir.Model) -> int: | ||
| count = rotary_embedding_rules.apply_to_model(model) | ||
| print(f"Rotary Embedding count: {count}") | ||
| return count | ||
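
For readers unfamiliar with the matched expression, here is a minimal NumPy sketch of the "rotate half" form of rotary embedding that `_rotate_half_pattern` and `RotaryEmbeddingFusion.pattern` describe. It is not part of the PR; the function names `rotate_half` and `apply_rotary` and the toy shapes are assumptions chosen for illustration.

```python
import numpy as np


def rotate_half(x):
    """Split the last axis into two halves (x1, x2) and return (-x2, x1)."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]      # corresponds to Slice(x, 0, half) on the last axis
    x2 = x[..., half:]      # corresponds to Slice(x, half, end) on the last axis
    return np.concatenate((-x2, x1), axis=-1)


def apply_rotary(x, cos, sin):
    """Full, non-interleaved rotation: x * cos + rotate_half(x) * sin."""
    return x * cos + rotate_half(x) * sin


print(rotate_half(np.array([1.0, 2.0, 3.0, 4.0])).tolist())  # [-3.0, -4.0, 1.0, 2.0]

# Illustrative shapes: (batch, num_heads, seq_len, head_size).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 8)).astype(np.float32)
angles = rng.standard_normal((2, 1, 5, 4)).astype(np.float32)
emb = np.concatenate((angles, angles), axis=-1)               # (2, 1, 5, 8)
y = apply_rotary(x, np.cos(emb), np.sin(emb))

# Each (x[i], x[i + half]) pair is rotated by one angle, so norms are preserved.
print(np.allclose(np.linalg.norm(x, axis=-1), np.linalg.norm(y, axis=-1), atol=1e-5))
```

The `end2 >= head_size` condition in `check` presumably allows for exporters that emit a very large end index for an open-ended slice such as `x[..., half:]`; ONNX `Slice` clamps such an index to the dimension size.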
23 changes: 23 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| from __future__ import annotations | ||
|
|
||
| import unittest | ||
|
|
||
| import onnxscript.optimizer | ||
| from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
| from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
|
|
||
|
|
||
| class TestRotaryEmbedding(unittest.TestCase): | ||
| def test_smollm(self): | ||
| smollm_test = _SmollmTestData() | ||
| model = smollm_test.get_onnx_model() | ||
| onnxscript.optimizer.optimize(model) | ||
| fuse_rotary_embedding(model) | ||
| op_types = [n.op_type for n in model.graph] | ||
| self.assertIn("RotaryEmbedding", op_types) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |