Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,16 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
output = node.outputs[0]
if input is not None and output is not None:
# NOTE: backward shape inference
input.shape = _merge_shapes(input.shape, output.shape)
try:
input.shape = _merge_shapes(input.shape, output.shape)
except Exception as e:
logger.warning(
"[Constant folder] Cannot merge shapes on Identity node '%s' "
"(folded from: %s) because of error: %s",
node.name,
input.meta.get(FOLDED_FROM_KEY, set()),
e,
)
if input.type is None:
input.type = output.type
state.set_sym_value(output, input)
Expand Down Expand Up @@ -919,7 +928,9 @@ def merge_dims(dim1, dim2):
if other_shape is None:
return preferred_shape
if len(preferred_shape) != len(other_shape):
raise ValueError("Shapes must have the same rank.")
raise ValueError(
f"Shapes must have the same rank, got preferred_shape={preferred_shape}, other_shape={other_shape}"
)
return ir.Shape(
[merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)]
)
Expand Down Expand Up @@ -1035,7 +1046,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
except Exception as e:
logger.debug(
"Skipping shape inference for node %r due to exception: %s",
node.name,
node,
e,
)

Expand Down Expand Up @@ -1124,7 +1135,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
for optimizer in op_optimizers:
assert optimizer
context = RewriterContext()
output = optimizer(node, context, self._state)
try:
output = optimizer(node, context, self._state)
except Exception as e:
raise RuntimeError(
f"Error during constant folding for node {node.name!r} ({node.domain}::{node.op_type})"
) from e
if output is not None:
if isinstance(output, Replacement):
return output
Expand Down
Loading