diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index dfb072417a..2c6d9b46ff 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -602,7 +602,16 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: output = node.outputs[0] if input is not None and output is not None: # NOTE: backward shape inference - input.shape = _merge_shapes(input.shape, output.shape) + try: + input.shape = _merge_shapes(input.shape, output.shape) + except Exception as e: + logger.warning( + "[Constant folder] Cannot merge shapes on Identity node '%s' " + "(folded from: %s) because of error: %s", + node.name, + input.meta.get(FOLDED_FROM_KEY, set()), + e, + ) if input.type is None: input.type = output.type state.set_sym_value(output, input) @@ -919,7 +928,9 @@ def merge_dims(dim1, dim2): if other_shape is None: return preferred_shape if len(preferred_shape) != len(other_shape): - raise ValueError("Shapes must have the same rank.") + raise ValueError( + f"Shapes must have the same rank, got preferred_shape={preferred_shape}, other_shape={other_shape}" + ) return ir.Shape( [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)] ) @@ -1035,7 +1046,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: except Exception as e: logger.debug( "Skipping shape inference for node %r due to exception: %s", - node.name, + node, e, ) @@ -1124,7 +1135,12 @@ def process_node(self, node: ir.Node) -> Replacement | None: for optimizer in op_optimizers: assert optimizer context = RewriterContext() - output = optimizer(node, context, self._state) + try: + output = optimizer(node, context, self._state) + except Exception as e: + raise RuntimeError( + f"Error during constant folding for node {node.name!r} ({node.domain}::{node.op_type})" + ) from e if output is not None: if isinstance(output, Replacement): return output