
Add an option to not inline a function when building the graph #2851

Merged
justinchuby merged 34 commits into main from justinchu/function-inline
Apr 17, 2026
Conversation

justinchuby (Collaborator) commented Mar 13, 2026

  • Introduced distinct call and call_inline methods in GraphBuilder and OpBuilder to differentiate between creating a single function call node (call) and inlining a function's body directly into the graph (call_inline). The call method now registers the function in the builder, while call_inline does not.
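The distinction can be illustrated with a toy sketch. The `GraphBuilder` and `Function` classes below are minimal stand-ins for illustration, not the real onnxscript internals:

```python
# Toy illustration of call vs. call_inline: call emits one node and
# registers the callee; call_inline copies the body and registers nothing.

class Function:
    def __init__(self, name, domain, body):
        self.name = name
        self.domain = domain
        self.body = body  # list of node op types making up the function


class GraphBuilder:
    def __init__(self):
        self.nodes = []      # op types of emitted nodes
        self.functions = {}  # (domain, name) -> Function registry

    def call(self, func):
        # Emit a single function-call node and track the callee so it
        # can be attached to the model at export time.
        self.nodes.append(f"{func.domain}::{func.name}")
        self.functions[(func.domain, func.name)] = func

    def call_inline(self, func):
        # Copy the function body directly into the graph.
        self.nodes.extend(func.body)


gelu = Function("gelu", "local", ["Mul", "Erf", "Add", "Mul"])

gb = GraphBuilder()
gb.call(gelu)          # one call node, one registered function

gb2 = GraphBuilder()
gb2.call_inline(gelu)  # four inlined nodes, empty registry
```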

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Copilot AI (Contributor) left a comment

Pull request overview

Adds support in the internal GraphBuilder to optionally not inline an ONNXScript function call, emitting a single function-call node and tracking the referenced function definitions for later export/attachment.

Changes:

  • Add Op.domain convenience property to expose an op’s opset domain.
  • Extend GraphBuilder.call/OpBuilder.call with _inline flag; when _inline=False, emit a single call node and register the callee in GraphBuilder.functions.
  • Add unit tests covering _inline=False behavior (single node emission, output renaming, prefix handling, function registration).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File Description
onnxscript/_internal/values.py Adds Op.domain property forwarding to the bound opset domain.
onnxscript/_internal/builder.py Introduces function registry and _inline option to either inline or emit a function-call node.
onnxscript/_internal/builder_test.py Adds tests verifying non-inlined call behavior and registration.
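A plausible shape for the `Op.domain` convenience property is a simple forward to the bound opset's domain. The classes below are simplified stand-ins mirroring the names in the PR description:

```python
# Sketch of the Op.domain property: it forwards to the opset the op is
# bound to, so callers don't have to reach through op.opset themselves.

class Opset:
    def __init__(self, domain, version):
        self.domain = domain
        self.version = version


class Op:
    def __init__(self, opset, name):
        self.opset = opset
        self.name = name

    @property
    def domain(self):
        # Forward to the bound opset's domain.
        return self.opset.domain


op = Op(Opset("com.microsoft", 1), "FusedMatMul")
```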

codecov Bot commented Mar 13, 2026

Codecov Report

❌ Patch coverage is 83.44371% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.52%. Comparing base (90f754a) to head (461a10c).
⚠️ Report is 1 commit behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
onnxscript/_internal/builder_test.py 86.17% 13 Missing ⚠️
onnxscript/_internal/builder.py 79.62% 5 Missing and 6 partials ⚠️
onnxscript/_internal/values.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2851      +/-   ##
==========================================
+ Coverage   72.48%   72.52%   +0.03%     
==========================================
  Files         241      241              
  Lines       29915    30032     +117     
  Branches     2935     2940       +5     
==========================================
+ Hits        21684    21780      +96     
- Misses       7233     7247      +14     
- Partials      998     1005       +7     


Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
gramalingam (Collaborator) commented:
Not sure if it is better to have a separate function or an option (as in this PR). For now, this seems fine.

…args

OpBuilder._call_op was inserting _domain, _version into the kwargs dict,
but GraphBuilder.call_op expects domain, version, outputs as separate
keyword arguments. This caused them to be treated as node attributes,
breaking custom domain handling, schema lookup, type inference, shape
inference, and output naming.

Changes:
- OpBuilder._call_op: pop _domain, _version, _outputs from kwargs and
  pass as separate keyword args to call_op
- Remove _prefix from GraphBuilder.call and OpBuilder.call (only
  call_inline needs it)
- Update test to use push_module/pop_module instead of _prefix on call
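The fix can be sketched as follows. `call_op` here is a hypothetical stand-in with the signature the commit message describes; the point is that reserved underscore-prefixed keys are popped out of `kwargs` before forwarding, so they are passed as separate keyword arguments instead of leaking in as node attributes:

```python
# Sketch of the kwargs fix: _domain/_version/_outputs are popped before
# forwarding; whatever remains in kwargs becomes node attributes.

def call_op(op_type, *inputs, domain="", version=None, outputs=1, **attributes):
    # Stand-in for GraphBuilder.call_op: returns a dict describing the node.
    return {
        "op_type": op_type,
        "domain": domain,
        "version": version,
        "outputs": outputs,
        "attributes": attributes,
    }


def _call_op(op_type, *inputs, **kwargs):
    # Pop the reserved keys so they do not get misinterpreted as attributes.
    domain = kwargs.pop("_domain", "")
    version = kwargs.pop("_version", None)
    outputs = kwargs.pop("_outputs", 1)
    return call_op(
        op_type, *inputs, domain=domain, version=version, outputs=outputs, **kwargs
    )


node = _call_op("MyOp", _domain="custom", _version=2, alpha=0.5)
```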
Copilot AI (Contributor) left a comment

Pull request overview

This PR enhances the ONNXScript IR builder’s function-call behavior by supporting both (a) inlining a function body into the current graph and (b) emitting a single function-call node while tracking the called function for downstream export/inspection.

Changes:

  • Added function tracking/registration in GraphBuilder and a new non-inlining function-call path.
  • Introduced a separate call_inline(...) path and updated/added tests to cover both behaviors.
  • Added a domain accessor on Op (covering OnnxFunction as a subclass) for easier domain access.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
onnxscript/_internal/builder.py Adds function registry, introduces non-inlined function call node creation, and refactors op-call argument handling.
onnxscript/_internal/builder_test.py Updates existing call tests to call_inline and adds new tests for single-node function calls + registration.
onnxscript/_internal/values.py Adds domain property (via Op) to expose opset domain.

justinchuby added this to the 0.7.0 milestone Apr 17, 2026
justinchuby and others added 2 commits April 16, 2026 17:24
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
The constant folding pass was eliminating `DequantizeLinear` nodes that
operated on constant weight tensors during `optimize()`, collapsing the
quantization structure into a plain `Conv` and losing quantization
semantics in QAT-exported models.

### Changes

- **`optimizer/_constant_folding.py`**: Add `DynamicQuantizeLinear` to
`DEFAULT_CONSTANT_FOLD_BLACKLIST` alongside the existing
`QuantizeLinear` and `DequantizeLinear` entries; reorder alphabetically
for consistency
- **`optimizer/_constant_folding_test.py`**: Add tests verifying
`QuantizeLinear` and `DequantizeLinear` are not folded when all inputs
are constant initializers
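The change boils down to a do-not-fold set consulted by the constant-folding pass. The sketch below is illustrative; only the names quoted above are taken from the actual code:

```python
# Sketch of the blacklist check: even fully-constant Q/DQ nodes must
# survive folding so quantization semantics are preserved.

DEFAULT_CONSTANT_FOLD_BLACKLIST = frozenset({
    "DequantizeLinear",
    "DynamicQuantizeLinear",  # newly added, kept in alphabetical order
    "QuantizeLinear",
})


def should_fold(op_type, all_inputs_constant):
    # A node is foldable only if every input is constant AND its op type
    # is not on the do-not-fold list.
    return all_inputs_constant and op_type not in DEFAULT_CONSTANT_FOLD_BLACKLIST
```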




<details>

<summary>Original prompt</summary>


----

*This section details the original issue to resolve*

<issue_title>[ONNX] Optimize should not fold
DequantizeLinear</issue_title>
<issue_description>### 🐛 Describe the bug

After the QAT model goes through `onnx_program.optimize()`, quantization nodes are lost. In the figure, the left side shows the normal export graph and the right side the abnormal one.

<img width="898" height="884" alt="Image"
src="https://github.com/user-attachments/assets/481bc3c0-38fe-45f6-9fde-bc1a287617a3"
/>


This bug occurred in `torch/onnx/_internal/exporter/_onnx_program.py`:
```python
def optimize(self) -> None:
    self.model = onnxscript_apis.optimize(self.model)
```
and it internally called the optimize_ir function in
`onnxscript/optimizer/_optimizer.py`.
The default value of `input_size_limit` is 512; constant nodes whose inputs are smaller than this limit get folded.
```python
def optimize_ir(
    model: ir.Model,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> None:
    passes = [
        ir.passes.PassManager(
            [
                _constant_folding.FoldConstantsPass(
                    shape_inference=onnx_shape_inference,
                    input_size_limit=input_size_limit,
                    output_size_limit=output_size_limit,
                ),
                rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
                common_passes.RemoveUnusedNodesPass(),
                common_passes.RemoveUnusedFunctionsPass(),
                common_passes.RemoveUnusedOpsetsPass(),
            ],
            steps=num_iterations,
            early_stop=stop_if_no_change,
        ),
    ......
```
⭐ Please add a parameter to `torch/onnx/_internal/exporter/_onnx_program.py` that controls this optimization. Otherwise, the only workaround is to modify the onnxscript source.

The smallest reproducible example:
```python
import copy
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from onnxslim import slim
import onnx


class ConvBnReluModel(nn.Module):
    def __init__(self, eps=1e-3, momentum=0.03):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


def get_batch_norm_node_args(gm):
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.batch_norm.default:
            return tuple(node.args)
    raise RuntimeError("No aten.batch_norm.default node found")


torch.manual_seed(0)
device = 'cuda' 

model = ConvBnReluModel().train().to(device)
inputs = (torch.randn(2, 4, 8, 8).to(device),)
exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module()
print("before prepare:", get_batch_norm_node_args(exported))

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)
prepared.to(device)
torch.ao.quantization.move_exported_model_to_eval(prepared)
torch.ao.quantization.allow_exported_model_train_eval(prepared)
prepared.eval()

#---export quantized model to onnx---
qat_onnx_sp = './quant.onnx'
quantized_model = convert_pt2e(prepared)
print('convert_pt2e done!')
onnx_program = torch.onnx.export(quantized_model, inputs, dynamo=True, opset_version=21)
""" bug """
onnx_program.optimize()
onnx_program.save(qat_onnx_sp)
print(f'export qat model to [{qat_onnx_sp}] done!')

model_simp = slim(onnx_program.model_proto)
sim_path = qat_onnx_sp.replace('.onnx', '_slim.onnx')
onnx.save(model_simp, sim_path)
print(f"save onnx model to [{sim_path}] Successfully!")
```

### Versions

Versions of relevant libraries:
[pip3] executorch==0.5.0
[pip3] numpy==1.23.5
[pip3] nvidia-cublas-cu11==11.11.3.6
[pip3] nvidia-cuda-cupti-cu11==11.8.87
[pip3] nvidia-cuda-nvrtc-cu11==11.8.89
[pip3] nvidia-cuda-runtime-cu11==11.8.89
[pip3] nvidia-cudnn-cu11==9.1.0.70
[pip3] nvidia-cufft-cu11==10.9.0.58
[pip3] nvidia-curand-cu11==10.3.0.86
[pip3] nvidia-cusolver-cu11==11.4.1.48
[pip3] nvidia-cusparse-cu11==11.7.5.86
[pip3] nvidia-nccl-cu11==2.21.5
[pip3] nvidia-nvtx-cu11==11.8.86
[pip3] onnx==1.17.0
[pip3] onnx_graphsurgeon==0.5.8
[pip3] onnx-ir==0.1.12
[pip3] onnx-simplifier==0.4.36
[pip3] onnxruntime==1.21.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxscript==0.4.0
[pip3] onnxslim==0.1.48
[pip3] torch==2.6.0+cu118
[pip3] torchao==0.14.1
[pip3] torchaudio==2.6.0+cu118
[pip3] torchvision==0.21.0+cu118
[pip3] ...

</details>




- Fixes pytorch/pytorch#177611


---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
AyoubMDL and others added 5 commits April 16, 2026 17:27
…#2888)

When a Python int literal (e.g. `1`) is used in both untyped positions
(like a Gather index, where the ONNX schema type variable is unbound)
and typed positions (like Add with an INT64 tensor), the constant cache
created two entries: (1, None) and (1, INT64). Both generated the same
initializer name 'const_1_i64' but as different ir.Value objects,
causing register_initializer to raise ValueError.

Fix: before cache lookup, normalize dtype=None to the default ONNX dtype
for the Python type (_PYTHON_TYPE_TO_DTYPE: int->INT64, float->FLOAT).
This merges both entries into a single cache key and reuses the same
ir.Value. Applied to both scalar and sequence (list/tuple) branches.

---------

Signed-off-by: G Ramalingam <grama@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
(a) Update the logic for creating initializers to ensure that they are
always added to the root (main) Graph.
(b) Add a utility to convert initializers into Constant nodes (which is
necessary for a Function in ONNX).

Note: An alternative considered was adding an option to the GraphBuilder so that it automatically constructs Constant nodes instead of initializers. While that avoids the extra pass at the end, it has minor implications for what the graph looks like (whether all Constant nodes appear upfront in one place, and what node names, based on node numbering, get generated). Hence, leaving it in the current form; it can be changed as above if desirable.
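Utility (b) can be sketched roughly as below. The `Graph`/`Node` classes are minimal stand-ins; the point is that ONNX functions cannot carry initializers, so each one is replaced by a `Constant` node prepended to the node list:

```python
# Rough sketch of converting initializers into Constant nodes so a
# graph body can be used as an ONNX Function.

class Node:
    def __init__(self, op_type, attrs):
        self.op_type = op_type
        self.attrs = attrs


class Graph:
    def __init__(self, initializers, nodes):
        self.initializers = initializers  # name -> tensor-like payload
        self.nodes = nodes


def initializers_to_constants(graph):
    const_nodes = [
        Node("Constant", {"output": name, "value": tensor})
        for name, tensor in graph.initializers.items()
    ]
    # Constant nodes go upfront, before the original nodes that use them.
    graph.nodes = const_nodes + graph.nodes
    graph.initializers = {}


g = Graph({"w": [1.0, 2.0]}, [Node("MatMul", {})])
initializers_to_constants(g)
```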

---------

Signed-off-by: G Ramalingam <grama@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
justinchuby force-pushed the justinchu/function-inline branch from 5dad7c6 to 1322a73 on April 17, 2026 00:45
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

…ns, handle edge cases

- Delegate _functions registration to root builder so subgraph
  function calls are not lost when the child builder is discarded
- Fix call_op kwargs type annotation from dict[str, ir.Value | ir.TensorProtocol] to dict[str, Any]
- Handle 0-output case in call() and call_inline() to avoid IndexError
- Add test for calling the same function twice (2 nodes, 1 registration)
- Clarify docstring for call vs call_inline comparison test

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
justinchuby (Collaborator, Author) left a comment:

Addressed the review comments in 0c240a6:

  • Thread 23 (functions in subgraphs): _functions now delegates to the root builder (same pattern as _constant_cache), so function registrations in subgraphs are preserved.
  • Thread 22 (kwargs annotation): Changed call_op kwargs type to dict[str, Any].
  • Thread 25 (_outputs validation): Intentionally kept flexible — users should be able to customize the number of outputs.
  • Thread 26+27 (0-output IndexError): Both call() and call_inline() now return () for 0 outputs.
  • Thread 28 (comparison test): Updated docstring to clarify its purpose.
  • Thread 29 (same function twice): Added test_call_same_function_twice verifying 2 nodes created but only 1 function registered.
  • Thread 30 (design question): Function registration delegates to root; duplicate calls overwrite with the same object. Will address copy semantics in follow-up if needed.
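The zero-output fix from threads 26 and 27 amounts to an early return. A minimal sketch, with a hypothetical `call` that mimics the usual "single value or tuple" return convention:

```python
# Sketch of the 0-output handling: return an empty tuple instead of
# indexing into an empty output list (which raised IndexError before).

def call(num_outputs):
    outputs = [f"out{i}" for i in range(num_outputs)]
    if not outputs:
        return ()  # avoid IndexError on outputs[0]
    if len(outputs) == 1:
        return outputs[0]  # single output returned bare
    return tuple(outputs)
```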

justinchuby and others added 3 commits April 17, 2026 10:06
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
tadani3 self-requested a review April 17, 2026 23:10
justinchuby enabled auto-merge (squash) April 17, 2026 23:12
justinchuby merged commit df97c94 into main Apr 17, 2026
30 of 33 checks passed
justinchuby deleted the justinchu/function-inline branch April 17, 2026 23:30