Fix two breaking CI failures: reshape-reshape rule test and logsumexp float16 tolerance #2907
Open
Conversation
…erance
Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/cb0312cb-87e4-4b34-845d-32155f2f2eb0
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot (AI) changed the title from "[WIP] Fix breaking CIs in ort-nightly and torch-nightly" to "Fix two breaking CI failures: reshape-reshape rule test and logsumexp float16 tolerance" on May 2, 2026.
Collaborator: @copilot create a PR on onnxruntime to fix the reshape bug
Contributor (Author): I'm unable to create PRs on external repositories; I can only push to the current repository. The repro is: input shape (0, 6, 2), reshaped to [3, 2, -1] and then to [0, 0, 3] with allowzero=1; ORT folds the two reshapes and returns shape (0, 6, 3) instead of the correct (0, 0, 3). See the onnxruntime commit referenced below for the full repro.
Codecov Report ✅ All modified and coverable lines are covered by tests.

@@             Coverage Diff             @@
##               main    #2907   +/-   ##
=========================================
  Coverage     72.52%   72.52%
=========================================
  Files           258      258
  Lines         31480    31480
  Branches       2972     2972
=========================================
  Hits          22831    22831
  Misses         7640     7640
  Partials       1009     1009
justinchuby approved these changes on May 5, 2026.

justinchuby pushed a commit to microsoft/onnxruntime that referenced this pull request on May 5, 2026:
… dims (#28349)

### Description

`ReshapeFusion::FuseContiguousReshapes` collapses a chain of `Reshape` / `Squeeze` / `Unsqueeze` nodes into a single `Reshape` whose shape data is taken verbatim from the fully-inferred output shape of the last node in the chain. The new node is created without an `allowzero` attribute, so it defaults to `allowzero = 0`. When that inferred shape contains a literal `0` dim (legitimate when the original chain used `allowzero=1`, or when intermediate tensors had zero-sized dimensions), the fused `Reshape` misinterprets the `0` as "copy the corresponding dim from the input tensor"; but the input here is the original input of the *first* reshape in the chain, with unrelated dims. The result is a silently wrong output shape (and a benign-looking `MergeShapeInfo` warning at graph load).

### Repro (before the fix)

```python
import numpy as np, onnx, onnxruntime as ort, onnx.reference
from onnx import helper, TensorProto

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [0, 6, 2])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None])
s1 = helper.make_tensor("s1", TensorProto.INT64, [3], [3, 2, -1])
s2 = helper.make_tensor("s2", TensorProto.INT64, [3], [0, 0, 3])
n1 = helper.make_node("Reshape", ["X", "s1"], ["mid"])
n2 = helper.make_node("Reshape", ["mid", "s2"], ["Y"], allowzero=1)
m = helper.make_model(
    helper.make_graph([n1, n2], "g", [X], [Y], initializer=[s1, s2]),
    opset_imports=[helper.make_opsetid("", 18)],
)
inp = np.random.default_rng(7).random((0, 6, 2), dtype=np.float32)
print("REF:", onnx.reference.ReferenceEvaluator(m).run(None, {"X": inp})[0].shape)
print("ORT:", ort.InferenceSession(m.SerializeToString(), providers=["CPUExecutionProvider"]).run(None, {"X": inp})[0].shape)
```

Output on `main` (`40c9f85f69`):

```
REF: (0, 0, 3)
[W ... graph.cc:122 MergeShapeInfo] Error merging shape info for output. 'Y' source:{0,6,3} target:{0,0,3}. Falling back to lenient merge.
ORT: (0, 6, 3) ❌
```

### Fix

Setting `allowzero=1` on the fused node would also work but requires opset >= 14, which this transformer cannot assume (it accepts `Reshape` opset 5+). Bail out of fusion conservatively when `shape_value` contains any literal `0` dim.

### Test

Adds `ReshapeFusionContiguousReshapesWithZeroDim`, which builds the bug repro programmatically and asserts:

- the two reshapes are NOT collapsed
- the inferred output shape stays `(0, 0, 3)`

The existing happy-path test `ReshapeFusion_Contiguous_Reshape` (added in #22494) is unaffected: its inferred output shape `(2, 1, 64, 32)` contains no zero dims, so the new guard does not trigger.

### Provenance

`FuseContiguousReshapes` was introduced in #22494 (Feb 2025). The bug has been latent in `main` since then.

### Motivation and Context

Found while reviewing microsoft/onnxscript#2907: the rewriter rule under test there is semantically correct, but its numerical-equivalence check using ORT as the oracle fails because of this fusion bug.

Fixes #28348.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
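For readers unfamiliar with the attribute, here is an illustrative sketch (not part of either PR) of the two meanings of a literal `0` in Reshape's shape tensor, using the ONNX reference evaluator; the helper name `reshape_shape` is invented for this example.

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

def reshape_shape(input_shape, target, allowzero):
    # Build a one-node Reshape model and return its output shape.
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_shape))
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None] * len(target))
    s = helper.make_tensor("s", TensorProto.INT64, [len(target)], list(target))
    node = helper.make_node("Reshape", ["X", "s"], ["Y"], allowzero=allowzero)
    graph = helper.make_graph([node], "g", [X], [Y], initializer=[s])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    x = np.zeros(input_shape, dtype=np.float32)
    return ReferenceEvaluator(model).run(None, {"X": x})[0].shape

# allowzero=0 (the fused node's default): each 0 copies the matching input dim.
print(reshape_shape((2, 3, 4), (0, 0, 4), allowzero=0))  # (2, 3, 4)

# allowzero=1: 0 is a literal zero-sized dim; the element counts still match
# here because both sides hold zero elements.
print(reshape_shape((0, 6, 2), (0, 0, 3), allowzero=1))  # (0, 0, 3)
```

This is exactly why dropping `allowzero` during fusion is unsafe: the same `0` entry changes meaning depending on the attribute, and after fusion it is resolved against the dims of a different input tensor.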
Two unrelated CI failures:

- ORT-nightly incorrectly folds consecutive `Reshape` ops with `allowzero=1`, causing a shape mismatch between the original and rewritten models.
- torch-nightly exposes float16 precision gaps in `logsumexp` that exceed default tolerances.

Changes

- `_basic_rules_test.py` (`test_reshape_reshape_dynamic_rule`): switch the numerical comparison to `use_reference=True` (the ONNX reference evaluator) instead of ORT. ORT's graph optimizer incorrectly folds `Reshape(Reshape(x, s1), s2, allowzero=1)` when the intermediate has zero-sized dimensions, producing wrong output shapes (`(1, 0, 6, 3)` vs the correct `(1, 0, 0, 3)`). The rewrite rule itself is semantically correct; ORT is the broken oracle.
- `ops_test_data.py`: add `tolerance={torch.float16: (2e-2, 1e-4)}` to the `logsumexp` entry. Observed float16 drift in torch-nightly: absolute Δ ≈ 2.3e-5 (limit 1e-5), relative Δ ≈ 2.3e-3 (limit 1e-3). A minimal sketch of such a drift check follows this list.
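The following is an illustrative sketch, not code from the PR: it measures float16 `logsumexp` drift against a float64 reference, and assumes the `tolerance` tuple is ordered `(rtol, atol)` as in `torch.testing.assert_close`.

```python
import torch

# Hypothetical drift check: compare float16 logsumexp against a
# float64 reference computed on the same values.
x = torch.randn(8, 256, dtype=torch.float64)
ref = torch.logsumexp(x, dim=-1)
fp16 = torch.logsumexp(x.to(torch.float16), dim=-1).to(torch.float64)

abs_err = (fp16 - ref).abs().max().item()
rel_err = ((fp16 - ref).abs() / ref.abs().clamp_min(1e-12)).max().item()
print(f"max abs err {abs_err:.2e}, max rel err {rel_err:.2e}")

# The PR's new tolerance entry, assuming (rtol, atol) ordering:
torch.testing.assert_close(fp16, ref, rtol=2e-2, atol=1e-4)
```

Loosening only the float16 entry keeps the stricter default tolerances for float32 and float64, so the change is scoped to the dtype where the drift was observed.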