
ReshapeFusion drops allowzero, producing wrong shape when inferred intermediate has 0-sized dim #28348

@titaiwangms

Description


Describe the issue

ReshapeFusion::FuseContiguousReshapes fuses a chain of contiguous Reshape (and Squeeze / Unsqueeze) nodes into a single Reshape whose shape is taken verbatim from the inferred output shape of the last node in the chain. The new Reshape is created without an allowzero attribute, so it defaults to allowzero = 0.

When the inferred shape contains a literal 0 dimension, the fused Reshape then interprets that 0 as "copy the corresponding dim from the input tensor" — but the input tensor here is the original input of the first reshape in the chain, which generally has unrelated dims. The result is a wrong output shape (and silently wrong outputs), accompanied by a benign-looking MergeShapeInfo warning.
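To make the misinterpretation concrete, here is a minimal sketch (a hypothetical helper, not ORT code) of the ONNX Reshape rule for 0 and -1 entries, applied to the shapes from the repro below:

```python
from math import prod

def resolve_reshape(input_shape, target, allowzero):
    """Simplified sketch of the ONNX Reshape rule for 0 and -1 entries."""
    # With allowzero=0, a 0 entry copies the corresponding input dim.
    shape = [input_shape[i] if (d == 0 and not allowzero) else d
             for i, d in enumerate(target)]
    if -1 in shape:
        known = prod(d for d in shape if d != -1)
        total = prod(input_shape)
        shape[shape.index(-1)] = total // known if known else 0
    return tuple(shape)

# The fused node reshapes X (shape (0, 6, 2)) directly with the inferred
# final shape (0, 0, 3); without allowzero the zeros are copied from X:
print(resolve_reshape((0, 6, 2), (0, 0, 3), allowzero=0))  # (0, 6, 3) -- wrong
print(resolve_reshape((0, 6, 2), (0, 0, 3), allowzero=1))  # (0, 0, 3) -- intended
```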

Repro

import numpy as np, onnx, onnxruntime as ort, onnx.reference
from onnx import helper, TensorProto

X  = helper.make_tensor_value_info("X", TensorProto.FLOAT, [0, 6, 2])
Y  = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None])
s1 = helper.make_tensor("s1", TensorProto.INT64, [3], [3, 2, -1])
s2 = helper.make_tensor("s2", TensorProto.INT64, [3], [0, 0, 3])

n1 = helper.make_node("Reshape", ["X",   "s1"], ["mid"])
n2 = helper.make_node("Reshape", ["mid", "s2"], ["Y"], allowzero=1)
m  = helper.make_model(helper.make_graph([n1, n2], "g", [X], [Y], initializer=[s1, s2]),
                       opset_imports=[helper.make_opsetid("", 18)])

inp = np.random.default_rng(7).random((0, 6, 2), dtype=np.float32)
print("REF:", onnx.reference.ReferenceEvaluator(m).run(None, {"X": inp})[0].shape)
print("ORT:", ort.InferenceSession(m.SerializeToString(),
                                   providers=["CPUExecutionProvider"]).run(None, {"X": inp})[0].shape)

Output on main (55f8234) and on the 1.26.0 release build:

REF: (0, 0, 3)
[W ... graph.cc:122 MergeShapeInfo] Error merging shape info for output. 'Y' source:{0,6,3} target:{0,0,3}. Falling back to lenient merge.
ORT: (0, 6, 3)

Root cause

onnxruntime/core/optimizer/reshape_fusion.cc::FuseContiguousReshapes (around line 496):

Node& reshape_node = graph.AddNode(
    graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
    {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg},
    {contiguous_reshapes.back().get().MutableOutputDefs()[0]},
    reshape);   // last arg is the layering-annotation source, NOT an attribute source

The 6-arg AddNode overload only copies the layering annotation from reshape; it does not copy attributes. So the new Reshape is created with no allowzero, defaulting to 0. shape_value is the fully resolved output shape from shape inference (the tensor_shape.Size() == -1 early-out guarantees no symbolic dims) and may legitimately contain literal 0 dims — those are then misinterpreted.

Suggested fix

Since shape_value is the fully resolved target output shape, any 0 it contains is meant literally. Setting allowzero = 1 on the fused node makes it invariant to the original nodes' allowzero attributes and prevents zeros from being reinterpreted as copy-from-input. PR follows.
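A quick numpy sanity check of that claim (numpy's reshape always treats dims literally, i.e. allowzero = 1 semantics): collapsing the chain into a single reshape to the inferred final shape preserves the output shape, even with a zero-sized input:

```python
import numpy as np

x = np.zeros((0, 6, 2), dtype=np.float32)
mid = x.reshape(3, 2, -1)       # first Reshape: -1 resolves to 0 -> (3, 2, 0)
chained = mid.reshape(0, 0, 3)  # second Reshape, zeros taken literally
fused = x.reshape(0, 0, 3)      # single fused Reshape with allowzero = 1
assert chained.shape == fused.shape == (0, 0, 3)
```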

Found via

While reviewing microsoft/onnxscript#2907 — the rewriter rule reshape_reshape_rule in onnxscript is semantically correct, but its numerical-equivalence test using ORT as an oracle fails because ORT's own fusion is wrong on inputs with zero-sized intermediates.

To reproduce

See repro above.

Urgency

Low to medium. Silently wrong output for any model whose graph contains contiguous reshapes where shape inference yields a literal-0 intermediate dim.

Platform

Linux

OS Version

Ubuntu

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

main @ 55f8234 (also reproduces on 1.26.0)

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU
