Migrate graph surgeries to onnx_ir + onnxscript rewriter (batch 1: infra + first surgeries)#2550
Merged
Merged
Conversation
Introduce a RewriteRuleSurgeon base class that lets graph surgeries be expressed as onnxscript rewrite rules over the ONNX IR model, instead of manual protobuf / OnnxDAG manipulation. Subclasses implement rules() returning a RewriteRuleSet; the base applies them via call_ir, so the rewriter handles operand commutativity, use-count bookkeeping, and dead-node cleanup. Port the first two pattern-based surgeries to this base: - ReciprocalMulToDiv: a * Reciprocal(x) -> Div(a, x) (commute=True covers both operand orders). - ReplaceErfWithTanh: Erf(x) -> Tanh(x * 605/503), emitting the scale as an initializer of the input's floating-point dtype. This is the first batch of an incremental migration of graph_surgeries.py off the protobuf/OnnxDAG approach; subsequent batches will port the remaining pattern-based surgeries (Gemm<->MatMul+Add, QDQ, RMSNorm variants, decompositions, ...) and move the whole-graph surgeries to plain onnx_ir. Update the ReplaceErfWithTanh test to read the scale via numpy_helper.to_array so it is agnostic to raw_data vs float_data tensor storage. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR is the first batch of an incremental refactor of GraphSurgeries patterns from manual ONNX proto/DAG manipulation to ONNX IR (onnx_ir) plus onnxscript’s pattern rewriter, with a small test adjustment to accommodate IR-emitted initializers.
Changes:
- Introduces a
RewriteRuleSurgeonbase class to implement local graph surgeries asonnxscript.rewriter.patternrewrite rule sets applied on the IR model. - Ports two surgeries to the rewrite-rule approach:
ReplaceErfWithTanhandReciprocalMulToDiv. - Updates the
ReplaceErfWithTanhunit test to read initializer values vianumpy_helper.to_array(works for bothraw_dataandfloat_datastorage).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
olive/passes/onnx/graph_surgeries.py |
Adds rewrite-rule surgeon infrastructure and migrates two surgeries to ONNX IR + onnxscript rewriter. |
test/passes/onnx/test_graph_surgeries.py |
Makes the scale-constant assertion robust to initializer serialization format (raw_data vs float_data). |
Convert RemoveGidxFromMatMulNBits from protobuf iteration to an onnx_ir call_ir implementation: drop a sorted (identity-permutation) g_idx input via resize_inputs and prune the now-unused g_idx initializer. Behavior unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- InferShapes: delegate to onnx_ir ShapeInferencePass. - RemoveShapes: clear type/shape on intermediate values (empties value_info). - RemoveInputs: drop named graph inputs and their node references via onnx_ir, removing nodes left with no inputs. Behavior unchanged; verified by existing surgery tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Rebuild ZeroOutInput on the onnx_ir API: read the target input's shape/dtype from the IR value, emit a zero Constant, and rewire the node input. Update the test to read the constant via numpy_helper.to_array (IR stores tensors as raw_data). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Reimplement RemoveMemcpy on onnx_ir: bypass 1-in/1-out MemcpyToHost/MemcpyFromHost nodes via Value.replace_all_uses_with (which follows consumers into subgraphs), recurse into Loop/If/Scan subgraphs, preserve public output names on the output boundary, and re-order with TopologicalSortPass. Replaces ~185 lines of manual proto bypass/rename/topo-sort logic with ~40. Behavior verified by the 4 existing RemoveMemcpy tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Rewrite ReplaceAttentionMaskValue on onnx_ir: clamp below-threshold entries in float Constant/ConstantOfShape node values and initializers whose consumers are all mask-compatible ops. Behavior unchanged; verified by the existing test. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert RMSNorm, SimplifiedLayerNorm, and Pow/ReduceSum norm graph surgeries to mutate onnx_ir directly while preserving weight scaling, all-ones weights, and ReduceMean opset handling. Full graph surgery tests pass and lint is clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert the shape-dependent MatMul/Add and Gemm rewrites to the onnx_ir Surgeon path while preserving reshape, Relu, and transB handling. Targeted and full graph surgery tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert the GQA RoPE cache, attention-mask sequence length, and quantized-output exposure surgeries to operate through onnx_ir. Behavior is unchanged; graph surgery tests and lint pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert both graph surgeons to implement call_ir using onnx_ir while preserving their quantized initializer creation, shared-weight rewiring, output-name handling, and cleanup behavior. Verified with ad-hoc tiny ONNX models for both surgeries plus the existing graph_surgeries pytest module. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
justinchuby
commented
Jul 1, 2026
Add regression coverage for two previously untested migrated surgeries: - PowReduceSumPowDiv2LpNorm: Pow(2)->ReduceSum->Pow(0.5)->Div collapses to LpNormalization. - QuantizeEmbeddingInt8: an embed_tokens Gather over an FP16 weight becomes an INT8 GatherBlockQuantized with a uint8 quantized table. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- RemoveShapes: iterate model.graph.all_nodes() so value_info is cleared in subgraphs too, preserving every graph's declared outputs. - ReplaceErfWithTanh: restore BFLOAT16 support (via ml_dtypes) so the scale is emitted in the input's floating-point dtype as documented. - ExposeQuantizedOutput: guard against missing/None scale/zero-point inputs instead of dereferencing .name on None (and assuming >=3 inputs). - RemoveRopeMultiCache: only remove the If-condition producer when it is a Greater node with no remaining consumers (avoid removing an unrelated node or raising StopIteration). - QuantizeEmbeddingInt8 / ShareEmbeddingLmHead: do not downgrade an existing com.microsoft opset version; bump up to at least 1. - AttentionMaskToSequenceLengths: default batch dim to 1 when input_ids is missing or its shape is unknown (dynamic models). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
jambayk
approved these changes
Jul 2, 2026
Contributor
|
/azp run |
|
Azure Pipelines successfully started running 1 pipeline(s). |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Describe your changes
First batch of an incremental migration of
graph_surgeries.pyoff theprotobuf /
OnnxDAGapproach onto the ONNX IR (onnx_ir) +onnxscriptrewriter.
Infrastructure
RewriteRuleSurgeon(Surgeon)base class. Subclasses implementrules()returning an
onnxscript.rewriter.pattern.RewriteRuleSet; the base applies itto the IR model via
call_ir. This lets local subgraph pattern replacements beexpressed declaratively, and the rewriter handles operand commutativity,
use-count bookkeeping, and dead-node cleanup for us.
First surgeries ported
ReciprocalMulToDiv:a * Reciprocal(x)→Div(a, x)(commute=Truecoversboth
Muloperand orders; a sharedReciprocalis preserved automatically).ReplaceErfWithTanh:Erf(x)→Tanh(x * 605/503), emitting the scale as aninitializer of the input's floating-point dtype; non-float inputs are skipped.
This trims ~120 lines of manual proto walking. Subsequent batches will port the
remaining pattern-based surgeries (
Gemm↔MatMul+Add, QDQ passes, RMSNormvariants, decompositions, ...) and move the whole-graph surgeries (rename/expose
I/O,
Non4D*, dedup,TieWordEmbeddings, ...) to plainonnx_ir.Test change
test_replace_erf_with_tanhnow reads the scale initializer vianumpy_helper.to_arrayinstead of.float_data, so it is agnostic to whetherthe tensor is stored as
raw_dataorfloat_data(IR emitsraw_data).No behavior change for users; the two ported surgeries produce equivalent graphs.
Checklist before requesting a review
test_graph_surgeries.py: 82 passed, 2 skipped)lintrunner -a(Optional) Issue link