"""Minimal reproduction of onnxscript ReshapeReshape rewrite bug.
When two Reshape->Reshape chains share the same shape constant containing -1,
the ReshapeReshape rule resolves -1 to different concrete values per match,
but creates new initializers reusing the original name (name=shape.name).
Both rewrites write to graph.initializers["shared_shape"], so only the last
one survives. NameFixPass then renames the orphaned Value to "shared_shape_1"
to ensure uniqueness, but does NOT register it as an initializer. The result:
a Reshape node references "shared_shape_1" which is neither a graph input nor
an initializer nor a node output -> toposort validation failure.
Bug location: onnxscript/rewriter/rules/common/_basic_rules.py
class ReshapeReshape.rewrite():
    new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name))
                                                          ^^^^^^^^^^^^^^^
Reusing the shared input name causes both rewrites to clobber the same
initializer dict entry. The loser becomes an unregistered dangling Value.
Fix: generate a unique name instead of reusing the shared input name, e.g.:
new_shape = op.initializer(ir.Tensor(self._new_shape))
"""
import logging
import numpy as np
import onnx
import onnx_ir as ir
from onnxscript import rewriter
from onnxscript.optimizer import _constant_folding, common_passes
# ``ir.from_onnx_text`` attaches the initializers only after the text has been
# parsed, so onnx_ir logs warnings about names that are not yet in scope.
# Those warnings are expected noise for this repro: only let errors through.
_onnx_ir_logger = logging.getLogger("onnx_ir")
_onnx_ir_logger.setLevel(logging.ERROR)
# ONNX textual model with two independent Reshape->Reshape chains whose outer
# Reshape nodes both consume the same ``shared_shape`` tensor. The three shape
# tensors are referenced by name only; build_model() attaches them as
# initializers after parsing.
MODEL_TEXT = (
    '<ir_version: 9, opset_import: ["" : 21]>\n'
    "test (float[2, 3] input1, float[2, 6] input2) => (float[2, 3] out1, float[2, 6] out2) {\n"
    "mid1 = Reshape (input1, shape_mid_a)\n"
    "mid2 = Reshape (input2, shape_mid_b)\n"
    "out1 = Reshape (mid1, shared_shape)\n"
    "out2 = Reshape (mid2, shared_shape)\n"
    "}\n"
)
def build_model() -> ir.Model:
    """Build the repro model: two Reshape->Reshape chains sharing ``shared_shape``.

    Both outer Reshape nodes read the same ``[2, -1]`` shape initializer, yet
    the ``-1`` resolves to a different size in each chain::

        input1 [2,3] -> Reshape([6])  -> Reshape([2,-1]) -> out1 [2,3]  (-1 -> 3)
        input2 [2,6] -> Reshape([12]) -> Reshape([2,-1]) -> out2 [2,6]  (-1 -> 6)

    Returns:
        The model parsed from ``MODEL_TEXT`` with the three int64 shape
        tensors attached as initializers.
    """
    # Insertion order is preserved, so the initializers are attached in the
    # same order as before: shape_mid_a, shape_mid_b, shared_shape.
    shape_constants = {
        "shape_mid_a": [6],
        "shape_mid_b": [12],
        "shared_shape": [2, -1],
    }
    shape_tensors = [
        ir.Tensor(np.array(values, dtype=np.int64), name=tensor_name)
        for tensor_name, values in shape_constants.items()
    ]
    return ir.from_onnx_text(MODEL_TEXT, initializers=shape_tensors)
def _optimize(model: ir.Model) -> None:
    """Run, in place, the onnxscript ``optimize_ir`` sub-pipeline that triggers the bug."""
    _constant_folding.FoldConstantsPass(
        shape_inference=True, input_size_limit=1024, output_size_limit=1024
    )(model)
    # Private attribute on purpose: the repro must use the exact default rule set.
    rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES)(model)
    common_passes.RemoveUnusedNodesPass()(model)
    common_passes.LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=0)(model)
    common_passes.DeduplicateInitializersPass()(model)


def _print_reshape_shape_inputs(model: ir.Model) -> None:
    """Print each Reshape's shape input and whether the graph owns it as an initializer."""
    print("After optimization (IR):")
    for node in model.graph:
        if node.op_type != "Reshape":
            continue
        shape_value = node.inputs[1]
        # Compare identity, not just name membership: after the clobbering this
        # repro exercises, a Value can share its name with a registered
        # initializer while the dict entry is a *different* Value object. That
        # Value is still dangling and must not be reported as registered.
        registered = model.graph.initializers.get(shape_value.name) is shape_value
        const = shape_value.const_value
        # Guard against a non-constant shape input instead of crashing on None.
        shape = const.numpy() if const is not None else None
        print(
            f" {node.name}: shape={shape}, "
            f"name='{shape_value.name}', registered_initializer={registered}"
        )


def _print_serialized_initializers(model_proto: onnx.ModelProto) -> None:
    """Print every initializer that survived serialization to ``ModelProto``."""
    print("\nSerialized initializers:")
    for init in model_proto.graph.initializer:
        print(f" {init.name}: {onnx.numpy_helper.to_array(init)}")


def _print_checker_result(model_proto: onnx.ModelProto) -> None:
    """Run the full onnx checker and print PASS, or FAIL with the error text."""
    print("\nonnx.checker:")
    try:
        onnx.checker.check_model(model_proto, full_check=True)
    except Exception as err:  # reporting boundary: show any checker failure, don't crash
        print(f" FAIL: {err}")
    else:
        print(" PASS (bug not triggered)")


def reproduce() -> None:
    """Reproduce the ReshapeReshape shared-initializer bug and print a diagnosis.

    Builds the two-chain model, confirms it is valid before optimization, and
    runs the optimization passes. It then prints each Reshape's shape input, the
    initializers that survive serialization, and the result of
    ``onnx.checker.check_model``. FAIL there means the bug was triggered.
    """
    model = build_model()
    # Sanity check: the unoptimized model is valid, so any checker failure
    # reported below was introduced by the optimization passes.
    onnx.checker.check_model(ir.to_proto(model), full_check=True)
    _optimize(model)
    _print_reshape_shape_inputs(model)
    out_proto = ir.to_proto(model)
    _print_serialized_initializers(out_proto)
    _print_checker_result(out_proto)
# Run the reproduction only when executed as a script, not when imported.
if __name__ == "__main__":
    reproduce()
# After optimization (IR):
#   node_Reshape_0: shape=[2 3], name='shared_shape_1', registered_initializer=False
#   node_Reshape_1: shape=[2 6], name='shared_shape', registered_initializer=True
# Serialized initializers:
#   shared_shape: [2 6]
# onnx.checker:
#   FAIL: Nodes in a graph must be topologically sorted, however input 'shared_shape_1' of node:
#   name: node_Reshape_0 OpType: Reshape
#   is not output of any previous nodes.
#
# ^ Output of the repro script above when run against the onnxscript main branch.