Skip to content

Migrate graph surgeries to onnx_ir + onnxscript rewriter (batch 1: infra + first surgeries)#2550

Merged
jambayk merged 13 commits into
mainfrom
justinchu/graph-surgeries-ir-rewriter
Jul 2, 2026
Merged

Migrate graph surgeries to onnx_ir + onnxscript rewriter (batch 1: infra + first surgeries)#2550
jambayk merged 13 commits into
mainfrom
justinchu/graph-surgeries-ir-rewriter

Conversation

@justinchuby

Copy link
Copy Markdown
Contributor

Describe your changes

First batch of an incremental migration of graph_surgeries.py off the
protobuf / OnnxDAG approach onto the ONNX IR (onnx_ir) + onnxscript
rewriter.

Infrastructure

  • Add a RewriteRuleSurgeon(Surgeon) base class. Subclasses implement rules()
    returning an onnxscript.rewriter.pattern.RewriteRuleSet; the base applies it
    to the IR model via call_ir. This lets local subgraph pattern replacements be
    expressed declaratively, and the rewriter handles operand commutativity,
    use-count bookkeeping, and dead-node cleanup for us.

First surgeries ported

  • ReciprocalMulToDiv: a * Reciprocal(x)Div(a, x) (commute=True covers
    both Mul operand orders; a shared Reciprocal is preserved automatically).
  • ReplaceErfWithTanh: Erf(x)Tanh(x * 605/503), emitting the scale as an
    initializer of the input's floating-point dtype; non-float inputs are skipped.

This trims ~120 lines of manual proto walking. Subsequent batches will port the
remaining pattern-based surgeries (GemmMatMul+Add, QDQ passes, RMSNorm
variants, decompositions, ...) and move the whole-graph surgeries (rename/expose
I/O, Non4D*, dedup, TieWordEmbeddings, ...) to plain onnx_ir.

Test change

  • test_replace_erf_with_tanh now reads the scale initializer via
    numpy_helper.to_array instead of .float_data, so it is agnostic to whether
    the tensor is stored as raw_data or float_data (IR emits raw_data).

No behavior change for users; the two ported surgeries produce equivalent graphs.

Checklist before requesting a review

  • Add unit tests for this change. (existing surgery tests cover both; all pass)
  • Make sure all tests can pass. (test_graph_surgeries.py: 82 passed, 2 skipped)
  • Update documents if necessary.
  • Lint and apply fixes to your code by running lintrunner -a
  • Is this a user-facing change? No — internal refactor, equivalent output.

(Optional) Issue link

Introduce a RewriteRuleSurgeon base class that lets graph surgeries be expressed
as onnxscript rewrite rules over the ONNX IR model, instead of manual protobuf /
OnnxDAG manipulation. Subclasses implement rules() returning a RewriteRuleSet;
the base applies them via call_ir, so the rewriter handles operand commutativity,
use-count bookkeeping, and dead-node cleanup.

Port the first two pattern-based surgeries to this base:
- ReciprocalMulToDiv: a * Reciprocal(x) -> Div(a, x) (commute=True covers both
  operand orders).
- ReplaceErfWithTanh: Erf(x) -> Tanh(x * 605/503), emitting the scale as an
  initializer of the input's floating-point dtype.

This is the first batch of an incremental migration of graph_surgeries.py off the
protobuf/OnnxDAG approach; subsequent batches will port the remaining pattern-based
surgeries (Gemm<->MatMul+Add, QDQ, RMSNorm variants, decompositions, ...) and move
the whole-graph surgeries to plain onnx_ir.

Update the ReplaceErfWithTanh test to read the scale via numpy_helper.to_array so
it is agnostic to raw_data vs float_data tensor storage.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Copilot AI review requested due to automatic review settings July 1, 2026 18:01

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR is the first batch of an incremental refactor of GraphSurgeries patterns from manual ONNX proto/DAG manipulation to ONNX IR (onnx_ir) plus onnxscript’s pattern rewriter, with a small test adjustment to accommodate IR-emitted initializers.

Changes:

  • Introduces a RewriteRuleSurgeon base class to implement local graph surgeries as onnxscript.rewriter.pattern rewrite rule sets applied on the IR model.
  • Ports two surgeries to the rewrite-rule approach: ReplaceErfWithTanh and ReciprocalMulToDiv.
  • Updates the ReplaceErfWithTanh unit test to read initializer values via numpy_helper.to_array (works for both raw_data and float_data storage).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
olive/passes/onnx/graph_surgeries.py Adds rewrite-rule surgeon infrastructure and migrates two surgeries to ONNX IR + onnxscript rewriter.
test/passes/onnx/test_graph_surgeries.py Makes the scale-constant assertion robust to initializer serialization format (raw_data vs float_data).

Comment thread olive/passes/onnx/graph_surgeries.py
justinchuby and others added 9 commits July 1, 2026 19:14
Convert RemoveGidxFromMatMulNBits from protobuf iteration to an onnx_ir call_ir
implementation: drop a sorted (identity-permutation) g_idx input via
resize_inputs and prune the now-unused g_idx initializer. Behavior unchanged.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- InferShapes: delegate to onnx_ir ShapeInferencePass.
- RemoveShapes: clear type/shape on intermediate values (empties value_info).
- RemoveInputs: drop named graph inputs and their node references via onnx_ir,
  removing nodes left with no inputs.

Behavior unchanged; verified by existing surgery tests.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Rebuild ZeroOutInput on the onnx_ir API: read the target input's shape/dtype from
the IR value, emit a zero Constant, and rewire the node input. Update the test to
read the constant via numpy_helper.to_array (IR stores tensors as raw_data).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Reimplement RemoveMemcpy on onnx_ir: bypass 1-in/1-out MemcpyToHost/MemcpyFromHost
nodes via Value.replace_all_uses_with (which follows consumers into subgraphs),
recurse into Loop/If/Scan subgraphs, preserve public output names on the output
boundary, and re-order with TopologicalSortPass. Replaces ~185 lines of manual
proto bypass/rename/topo-sort logic with ~40. Behavior verified by the 4 existing
RemoveMemcpy tests.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Rewrite ReplaceAttentionMaskValue on onnx_ir: clamp below-threshold entries in
float Constant/ConstantOfShape node values and initializers whose consumers are
all mask-compatible ops. Behavior unchanged; verified by the existing test.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert RMSNorm, SimplifiedLayerNorm, and Pow/ReduceSum norm graph surgeries to mutate onnx_ir directly while preserving weight scaling, all-ones weights, and ReduceMean opset handling. Full graph surgery tests pass and lint is clean.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert the shape-dependent MatMul/Add and Gemm rewrites to the onnx_ir Surgeon path while preserving reshape, Relu, and transB handling. Targeted and full graph surgery tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert the GQA RoPE cache, attention-mask sequence length,
and quantized-output exposure surgeries to operate through onnx_ir.
Behavior is unchanged; graph surgery tests and lint pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert both graph surgeons to implement call_ir using onnx_ir while preserving their quantized initializer creation, shared-weight rewiring, output-name handling, and cleanup behavior. Verified with ad-hoc tiny ONNX models for both surgeries plus the existing graph_surgeries pytest module.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Comment thread olive/passes/onnx/graph_surgeries.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 1 out of 2 changed files in this pull request and generated 5 comments.

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment thread olive/passes/onnx/graph_surgeries.py Outdated
justinchuby and others added 3 commits July 1, 2026 21:18
Add regression coverage for two previously untested migrated surgeries:
- PowReduceSumPowDiv2LpNorm: Pow(2)->ReduceSum->Pow(0.5)->Div collapses to LpNormalization.
- QuantizeEmbeddingInt8: an embed_tokens Gather over an FP16 weight becomes an INT8
  GatherBlockQuantized with a uint8 quantized table.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- RemoveShapes: iterate model.graph.all_nodes() so value_info is cleared in
  subgraphs too, preserving every graph's declared outputs.
- ReplaceErfWithTanh: restore BFLOAT16 support (via ml_dtypes) so the scale is
  emitted in the input's floating-point dtype as documented.
- ExposeQuantizedOutput: guard against missing/None scale/zero-point inputs
  instead of dereferencing .name on None (and assuming >=3 inputs).
- RemoveRopeMultiCache: only remove the If-condition producer when it is a
  Greater node with no remaining consumers (avoid removing an unrelated node
  or raising StopIteration).
- QuantizeEmbeddingInt8 / ShareEmbeddingLmHead: do not downgrade an existing
  com.microsoft opset version; bump up to at least 1.
- AttentionMaskToSequenceLengths: default batch dim to 1 when input_ids is
  missing or its shape is unknown (dynamic models).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@jambayk

jambayk commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

/azp run

@azure-pipelines

Copy link
Copy Markdown
Azure Pipelines successfully started running 1 pipeline(s).

@jambayk jambayk enabled auto-merge (squash) July 2, 2026 21:17
@jambayk jambayk merged commit 5957997 into main Jul 2, 2026
12 of 13 checks passed
@jambayk jambayk deleted the justinchu/graph-surgeries-ir-rewriter branch July 2, 2026 22:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants