Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions onnxscript/rewriter/ort_fusions/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.ir.passes.common import shape_inference
from onnxscript.optimizer import optimize, remove_unused_nodes
from onnxscript.rewriter import rewrite
from onnxscript.rewriter.ort_fusions import (
Expand All @@ -12,9 +13,13 @@
softmax,
)
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
fuse_partial_rotary_embedding,
fuse_rotary_embedding,
)
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization

Expand All @@ -27,14 +32,29 @@
]


def fuse_xformers(model: ir.Model) -> None:
# Preliminary optimizations before applying the transformer fusions.
# TODO: There are some potential redundancies below. Can be targeted for optimization
# once we have robust fusion.
def _pre_optimize(model: ir.Model) -> ir.Model:
optimize(model)

Check warning on line 39 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L39

Added line #L39 was not covered by tests
# TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
# extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
# incorporated in our optimizer.
model = shape_inference.infer_shapes(model)

Check warning on line 43 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L43

Added line #L43 was not covered by tests
optimize(model)
return model

Check warning on line 45 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L45

Added line #L45 was not covered by tests


def fuse_xformers(model: ir.Model) -> None:
model = _pre_optimize(model)

Check warning on line 49 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L49

Added line #L49 was not covered by tests
fuse_rms_normalization(model)
fuse_normalization(model)
fuse_rotary_embedding(model)
fuse_partial_rotary_embedding(model)

Check warning on line 53 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L53

Added line #L53 was not covered by tests
fuse_cos_sin_cache(model)
fuse_sdpa(model)
fuse_mha(model)
fuse_gelu(model)

Check warning on line 57 in onnxscript/rewriter/ort_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/_core.py#L57

Added line #L57 was not covered by tests
remove_unused_nodes(model)


Expand Down
Loading