Skip to content

Bug in the Reshape->Reshape pattern when the shape includes -1 #2945

Description

@OYCN

repro script:

"""Minimal reproduction of onnxscript ReshapeReshape rewrite bug.

When two Reshape->Reshape chains share the same shape constant containing -1,
the ReshapeReshape rule resolves -1 to different concrete values per match,
but creates new initializers reusing the original name (name=shape.name).

Both rewrites write to graph.initializers["shared_shape"], so only the last
one survives. NameFixPass then renames the orphaned Value to "shared_shape_1"
to ensure uniqueness, but does NOT register it as an initializer. The result:
a Reshape node references "shared_shape_1" which is neither a graph input nor
an initializer nor a node output -> toposort validation failure.

Bug location: onnxscript/rewriter/rules/common/_basic_rules.py
    class ReshapeReshape.rewrite():
        new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name))
                                                              ^^^^^^^^^^^^^^
    Reusing the shared input name causes both rewrites to clobber the same
    initializer dict entry. The loser becomes an unregistered dangling Value.

Fix: generate a unique name instead of reusing the shared input name, e.g.:
        new_shape = op.initializer(ir.Tensor(self._new_shape))
"""

import logging

import numpy as np
import onnx
import onnx_ir as ir
from onnxscript import rewriter
from onnxscript.optimizer import _constant_folding, common_passes

# Suppress onnx_ir deserialization warnings (known from_onnx_text limitation:
# initializers are attached after parsing, so their names aren't in scope yet).
logging.getLogger("onnx_ir").setLevel(logging.ERROR)

# ONNX textual graph with two independent Reshape -> Reshape chains. Both
# final Reshape nodes read the same `shared_shape` input. The three shape
# names are not declared in the text; build_model() supplies them as
# initializers when parsing.
MODEL_TEXT = """\
<ir_version: 9, opset_import: ["" : 21]>
test (float[2, 3] input1, float[2, 6] input2) => (float[2, 3] out1, float[2, 6] out2) {
    mid1 = Reshape (input1, shape_mid_a)
    mid2 = Reshape (input2, shape_mid_b)
    out1 = Reshape (mid1, shared_shape)
    out2 = Reshape (mid2, shared_shape)
}
"""


def build_model() -> ir.Model:
    """Parse MODEL_TEXT into an IR model with its three shape initializers.

    The graph holds two Reshape->Reshape chains whose final Reshape nodes
    share the shape constant [2, -1]:

    input1 [2,3] -> Reshape([6]) -> Reshape([2,-1]) -> out1 [2,3]  (-1 resolves to 3)
    input2 [2,6] -> Reshape([12]) -> Reshape([2,-1]) -> out2 [2,6] (-1 resolves to 6)

    Returns:
        The model produced by ``ir.from_onnx_text``.
    """
    # (initializer name, int64 contents), in the order they are attached.
    shape_specs = (
        ("shape_mid_a", [6]),
        ("shape_mid_b", [12]),
        ("shared_shape", [2, -1]),
    )
    shape_tensors = [
        ir.Tensor(np.array(dims, dtype=np.int64), name=tensor_name)
        for tensor_name, dims in shape_specs
    ]
    return ir.from_onnx_text(MODEL_TEXT, initializers=shape_tensors)


def reproduce():
    """Run the optimizer sub-pipeline on the repro model and print a diagnosis.

    Steps:
      1. Build the model and check that it is valid before any rewriting
         (``onnx.checker.check_model`` raises if it is not).
      2. Apply constant folding, the default rewrite rules, and the cleanup
         passes, in that order, mutating the model in place.
      3. Print, for every Reshape node, its shape input's value and name and
         whether that name is a key of the graph's initializers.
      4. Serialize the model, print the initializers present in the proto,
         and print whether ``onnx.checker`` accepts the result.
    """
    model = build_model()

    # Sanity check: the untouched model must be valid.
    original_proto = ir.to_proto(model)
    onnx.checker.check_model(original_proto, full_check=True)

    # Standard onnxscript optimize_ir sub-pipeline, one pass at a time.
    fold_constants = _constant_folding.FoldConstantsPass(
        shape_inference=True, input_size_limit=1024, output_size_limit=1024
    )
    fold_constants(model)

    apply_rewrites = rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES)
    apply_rewrites(model)

    drop_unused_nodes = common_passes.RemoveUnusedNodesPass()
    drop_unused_nodes(model)

    lift_constants = common_passes.LiftConstantsToInitializersPass(
        lift_all_constants=True, size_limit=0
    )
    lift_constants(model)

    dedupe_initializers = common_passes.DeduplicateInitializersPass()
    dedupe_initializers(model)

    # Diagnosis: report each Reshape's shape input and whether its name is
    # registered in the graph's initializers.
    print("After optimization (IR):")
    for node in model.graph:
        if node.op_type != "Reshape":
            continue
        s = node.inputs[1]
        registered = s.name in model.graph.initializers
        print(
            f"  {node.name}: shape={s.const_value.numpy()}, "
            f"name='{s.name}', registered_initializer={registered}"
        )

    # Serialize and list the initializers that made it into the proto.
    out_proto = ir.to_proto(model)
    print("\nSerialized initializers:")
    for init in out_proto.graph.initializer:
        print(f"  {init.name}: {onnx.numpy_helper.to_array(init)}")

    # Let the onnx checker judge the serialized model; a failure is reported
    # rather than propagated so the script always finishes its report.
    print("\nonnx.checker:")
    try:
        onnx.checker.check_model(out_proto, full_check=True)
        print("  PASS (bug not triggered)")
    except Exception as e:
        print(f"  FAIL: {e}")


# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    reproduce()

The output on the main branch:

After optimization (IR):
  node_Reshape_0: shape=[2 3], name='shared_shape_1', registered_initializer=False
  node_Reshape_1: shape=[2 6], name='shared_shape', registered_initializer=True

Serialized initializers:
  shared_shape: [2 6]

onnx.checker:
  FAIL: Nodes in a graph must be topologically sorted, however input 'shared_shape_1' of node: 
name: node_Reshape_0 OpType: Reshape
 is not output of any previous nodes.

Metadata

Metadata

Labels

No labels
No labels

Type

No type
No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions