From 7aa2bb6dcf038d93b75af6edef1098b10906a9b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 21:58:14 +0000 Subject: [PATCH 01/25] Initial plan From 1f758586eeda5ed81a4a326d19e7d0ad65c2f856 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:13:04 +0000 Subject: [PATCH 02/25] Implement NameFixPass for ensuring unique names Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/__init__.py | 2 + src/onnx_ir/passes/common/naming.py | 255 ++++++++++++++ src/onnx_ir/passes/common/naming_test.py | 410 +++++++++++++++++++++++ 3 files changed, 667 insertions(+) create mode 100644 src/onnx_ir/passes/common/naming.py create mode 100644 src/onnx_ir/passes/common/naming_test.py diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 2aee4df7..a8d73175 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -11,6 +11,7 @@ "InlinePass", "LiftConstantsToInitializersPass", "LiftSubgraphInitializersToMainGraphPass", + "NameFixPass", "RemoveInitializersFromInputsPass", "RemoveUnusedFunctionsPass", "RemoveUnusedNodesPass", @@ -38,6 +39,7 @@ DeduplicateInitializersPass, ) from onnx_ir.passes.common.inliner import InlinePass +from onnx_ir.passes.common.naming import NameFixPass from onnx_ir.passes.common.onnx_checker import CheckerPass from onnx_ir.passes.common.shape_inference import ShapeInferencePass from onnx_ir.passes.common.topological_sort import TopologicalSortPass diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py new file mode 100644 index 00000000..09ccfe98 --- /dev/null +++ b/src/onnx_ir/passes/common/naming.py @@ -0,0 +1,255 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Name fix pass for ensuring unique names for all values and nodes.""" + +from __future__ import annotations + +__all__ = [ + "NameFixPass", +] + +import logging +from collections.abc import Set as AbstractSet + +import onnx_ir as ir + +logger = logging.getLogger(__name__) + + +class NameFixPass(ir.passes.InPlacePass): + """Pass for fixing names to ensure all values and nodes have unique names. + + This pass ensures that: + 1. Graph inputs and outputs have unique names (take precedence) + 2. All intermediate values have unique names (assign names to unnamed values) + 3. All values in subgraphs have unique names + 4. All nodes have unique names (assign names to unnamed nodes) + + The pass maintains global uniqueness across the entire model. + """ + + def call(self, model: ir.Model) -> ir.passes.PassResult: + """Main entry point for the name fix pass.""" + modified = False + + # Use sets to track seen names globally + seen_value_names: set[str] = set() + seen_node_names: set[str] = set() + + # Counters for generating unique names (using list to pass by reference) + value_counter = [0] + node_counter = [0] + + # Process the main graph + if self._fix_graph_names( + model.graph, seen_value_names, seen_node_names, value_counter, node_counter + ): + modified = True + + # Process functions + for function in model.functions.values(): + if self._fix_function_names( + function, seen_value_names, seen_node_names, value_counter, node_counter + ): + modified = True + + if modified: + logger.info("Name fix pass modified the model") + + return ir.passes.PassResult(model, modified=modified) + + + def _fix_graph_names( + self, + graph: ir.Graph, + seen_value_names: set[str], + seen_node_names: set[str], + value_counter: list[int], + node_counter: list[int], + ) -> bool: + """Fix names in a graph and return whether modifications were made.""" + modified = False + + # Keep track of values we've already processed to avoid double-processing + processed_values: set[ir.Value] = set() + + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph.inputs: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Step 2: Fix graph output names (they have precedence) + for output_value in graph.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + # Step 3: Fix initializer names + for initializer in graph.initializers.values(): + if self._process_value(initializer, seen_value_names, value_counter, processed_values): + modified = True + + # Step 4: Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(graph): + # Fix node name + if node.name is None or node.name == "": + if self._assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if self._fix_duplicate_node_name(node, seen_node_names): + modified = True + + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + return modified + + def _fix_function_names( + self, + function: ir.Function, + seen_value_names: set[str], + seen_node_names: set[str], + value_counter: list[int], + node_counter: list[int], + ) -> bool: + """Fix names in a function and return whether modifications were made.""" + modified = False + + # Keep track of values we've already processed to avoid double-processing + processed_values: set[ir.Value] = set() + + # Process function inputs first (they have precedence) + for input_value in function.inputs: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Process function outputs (they have precedence) + for output_value in function.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + # Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(function): + # Fix node name + if node.name is None or node.name == "": + if self._assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if self._fix_duplicate_node_name(node, seen_node_names): + modified = True + + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + return modified + + def _process_value( + self, + value: ir.Value, + seen_value_names: set[str], + value_counter: list[int], + processed_values: set[ir.Value] + ) -> bool: + """Process a value only if it hasn't been processed before.""" + if value in processed_values: + return False + + processed_values.add(value) + + if value.name is None or value.name == "": + return self._assign_value_name(value, seen_value_names, value_counter) + else: + return self._fix_duplicate_value_name(value, seen_value_names) + + def _assign_value_name( + self, value: ir.Value, seen_names: set[str], counter: list[int] + ) -> bool: + """Assign a name to an unnamed value. Returns True if modified.""" + while True: + new_name = f"val_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + value.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed value", new_name) + return True + + def _assign_node_name( + self, node: ir.Node, seen_names: set[str], counter: list[int] + ) -> bool: + """Assign a name to an unnamed node. Returns True if modified.""" + while True: + new_name = f"node_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + node.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed node", new_name) + return True + + def _fix_duplicate_value_name( + self, value: ir.Value, seen_names: set[str] + ) -> bool: + """Fix a value's name if it conflicts with existing names. Returns True if modified.""" + original_name = value.name + + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 + while True: + new_name = f"{base_name}_{suffix}" + if new_name not in seen_names: + value.name = new_name + seen_names.add(new_name) + logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) + return True + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False + + def _fix_duplicate_node_name( + self, node: ir.Node, seen_names: set[str] + ) -> bool: + """Fix a node's name if it conflicts with existing names. Returns True if modified.""" + original_name = node.name + + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 + while True: + new_name = f"{base_name}_{suffix}" + if new_name not in seen_names: + node.name = new_name + seen_names.add(new_name) + logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) + return True + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False \ No newline at end of file diff --git a/src/onnx_ir/passes/common/naming_test.py b/src/onnx_ir/passes/common/naming_test.py new file mode 100644 index 00000000..5eb7863f --- /dev/null +++ b/src/onnx_ir/passes/common/naming_test.py @@ -0,0 +1,410 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the name fix pass.""" + +from __future__ import annotations + +import unittest + +import onnx_ir as ir +from onnx_ir.passes.common import naming + + +class TestNameFixPass(unittest.TestCase): + """Test cases for NameFixPass.""" + + def test_assign_names_to_unnamed_values(self): + """Test ensuring all values have names even if IR auto-assigned them.""" + # Create a simple model with auto-assigned names + input_value = ir.Input( + None, shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) # Will get auto-assigned name when added to graph + + # Create Add node + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify IR has auto-assigned names + self.assertIsNotNone(input_value.name) + self.assertIsNotNone(add_node.outputs[0].name) + + # Store original names + original_input_name = input_value.name + original_output_name = add_node.outputs[0].name + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass didn't modify anything (names were already assigned and unique) + self.assertFalse(result.modified) + + # Verify names remain the same + self.assertEqual(input_value.name, original_input_name) + self.assertEqual(add_node.outputs[0].name, original_output_name) + + def test_assign_names_to_unnamed_nodes(self): + """Test ensuring all nodes have names even if IR auto-assigned them.""" + # Create a simple model + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Add node - IR will auto-assign name when added to graph + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify IR has auto-assigned node name + self.assertIsNotNone(add_node.name) + original_node_name = add_node.name + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass didn't modify anything (node already had unique name) + self.assertFalse(result.modified) + + # Verify node name remains the same + self.assertEqual(add_node.name, original_node_name) + + def test_assigns_names_when_truly_unnamed(self): + """Test that the pass assigns names when values/nodes are created without names and manually cleared.""" + # Create a model and manually clear names to test assignment + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Manually clear some names to test assignment + add_node.name = None + add_node.outputs[0].name = "" + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names were assigned + self.assertIsNotNone(add_node.name) + self.assertIsNotNone(add_node.outputs[0].name) + self.assertNotEqual(add_node.outputs[0].name, "") + + def test_handles_global_uniqueness_across_subgraphs(self): + """Test that names are unique globally, including across subgraphs.""" + # Create main graph input + main_input = ir.Input( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a simple subgraph for an If node + # Subgraph input and output (with potential name conflicts) + sub_input = ir.Input( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) # Same name as main input - should cause conflict + + sub_add_node = ir.Node("", "Add", inputs=[sub_input, sub_input]) + sub_add_node.outputs[0].name = "main_input" # Another conflict + sub_add_node.outputs[0].shape = sub_input.shape + sub_add_node.outputs[0].type = sub_input.type + + subgraph = ir.Graph( + inputs=[sub_input], + outputs=[sub_add_node.outputs[0]], + nodes=[sub_add_node], + name="subgraph", + ) + + # Create condition input for If node + condition_input = ir.Input( + "condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) + + # Create If node with subgraph + if_node = ir.Node( + "", "If", + inputs=[condition_input], + attributes={"then_branch": ir.Attr("then_branch", ir.AttributeType.GRAPH, subgraph)} + ) + if_node.outputs[0].name = "if_output" + if_node.outputs[0].shape = main_input.shape + if_node.outputs[0].type = main_input.type + + # Create main graph + main_graph = ir.Graph( + inputs=[main_input, condition_input], + outputs=[if_node.outputs[0]], + nodes=[if_node], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10) + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied (should fix duplicates) + self.assertTrue(result.modified) + + # Collect all value names to verify uniqueness + all_value_names = set() + + # Main graph values + for input_val in main_graph.inputs: + self.assertIsNotNone(input_val.name) + self.assertNotIn(input_val.name, all_value_names, f"Duplicate value name: {input_val.name}") + all_value_names.add(input_val.name) + + for output_val in main_graph.outputs: + self.assertIsNotNone(output_val.name) + if output_val.name not in all_value_names: # Could be same as input + all_value_names.add(output_val.name) + + # Node values in main graph + for node in main_graph: + for input_val in node.inputs: + if input_val is not None: + if input_val.name not in all_value_names: + all_value_names.add(input_val.name) + for output_val in node.outputs: + if output_val.name not in all_value_names: + all_value_names.add(output_val.name) + + # Subgraph values + for input_val in subgraph.inputs: + self.assertIsNotNone(input_val.name) + self.assertNotIn(input_val.name, all_value_names, f"Duplicate value name in subgraph: {input_val.name}") + all_value_names.add(input_val.name) + + for output_val in subgraph.outputs: + if output_val.name not in all_value_names: # Could be same as input + all_value_names.add(output_val.name) + + # Node values in subgraph + for node in subgraph: + for input_val in node.inputs: + if input_val is not None: + if input_val.name not in all_value_names: + all_value_names.add(input_val.name) + for output_val in node.outputs: + if output_val.name not in all_value_names: + all_value_names.add(output_val.name) + + # Verify main_input keeps its name (has precedence as graph input) + self.assertEqual(main_input.name, "main_input") + + def test_handle_duplicate_value_names(self): + """Test handling duplicate value names by making them unique.""" + # Create values with duplicate names + input1 = ir.Input( + "duplicate_name", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.Input( + "duplicate_name", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input1, input2]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input1.shape + add_node.outputs[0].type = input1.type + + graph = ir.Graph( + inputs=[input1, input2], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify both inputs have the same name initially + self.assertEqual(input1.name, "duplicate_name") + self.assertEqual(input2.name, "duplicate_name") + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names are now unique + self.assertNotEqual(input1.name, input2.name) + # One should keep the original name, the other should have a suffix + names = {input1.name, input2.name} + self.assertIn("duplicate_name", names) + self.assertTrue( + "duplicate_name_1" in names, + f"Expected 'duplicate_name_1' in {names}" + ) + + def test_handle_duplicate_node_names(self): + """Test handling duplicate node names by making them unique.""" + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create nodes with duplicate names + add_node1 = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node1.name = "duplicate_node" + add_node1.outputs[0].name = "output1" + add_node1.outputs[0].shape = input_value.shape + add_node1.outputs[0].type = input_value.type + + add_node2 = ir.Node("", "Add", inputs=[input_value, add_node1.outputs[0]]) + add_node2.name = "duplicate_node" # Same name as first node + add_node2.outputs[0].name = "output2" + add_node2.outputs[0].shape = input_value.shape + add_node2.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node2.outputs[0]], + nodes=[add_node1, add_node2], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify both nodes have the same name initially + self.assertEqual(add_node1.name, "duplicate_node") + self.assertEqual(add_node2.name, "duplicate_node") + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names are now unique + self.assertNotEqual(add_node1.name, add_node2.name) + # One should keep the original name, the other should have a suffix + names = {add_node1.name, add_node2.name} + self.assertIn("duplicate_node", names) + self.assertTrue( + "duplicate_node_1" in names, + f"Expected 'duplicate_node_1' in {names}" + ) + + def test_no_modification_when_all_names_unique(self): + """Test that the pass doesn't modify anything when all names are already unique.""" + input_value = ir.Input( + "unique_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.name = "unique_node" + add_node.outputs[0].name = "unique_output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Store original names + original_input_name = input_value.name + original_node_name = add_node.name + original_output_name = add_node.outputs[0].name + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass didn't modify anything + self.assertFalse(result.modified) + + # Verify names remain unchanged + self.assertEqual(input_value.name, original_input_name) + self.assertEqual(add_node.name, original_node_name) + self.assertEqual(add_node.outputs[0].name, original_output_name) + + def test_graph_inputs_outputs_have_precedence(self): + """Test that graph inputs and outputs keep their names when there are conflicts.""" + # Create an input with a specific name + input_value = ir.Input( + "important_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a node that produces an intermediate value with the same name + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "important_input" # Conflicts with input name + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + # Create another node that uses the intermediate value and produces the final output + mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], input_value]) + mul_node.outputs[0].name = "important_output" + mul_node.outputs[0].shape = input_value.shape + mul_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[mul_node.outputs[0]], + nodes=[add_node, mul_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify input keeps its original name (has precedence) + self.assertEqual(input_value.name, "important_input") + + # Verify output keeps its original name (has precedence) + self.assertEqual(mul_node.outputs[0].name, "important_output") + + # Verify intermediate value got renamed to avoid conflict + self.assertNotEqual(add_node.outputs[0].name, "important_input") + self.assertTrue(add_node.outputs[0].name.startswith("important_input_")) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From b979a3d444b274c3d078b82d655e7b52350ad5dd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:22:57 +0000 Subject: [PATCH 03/25] Replace processed values set with value-to-name mapping for clearer tracking Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/naming.py | 50 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 09ccfe98..c2f86d50 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -36,20 +36,23 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: seen_value_names: set[str] = set() seen_node_names: set[str] = set() + # Dictionary to track which values have been assigned names + value_to_name: dict[ir.Value, str] = {} + # Counters for generating unique names (using list to pass by reference) value_counter = [0] node_counter = [0] # Process the main graph if self._fix_graph_names( - model.graph, seen_value_names, seen_node_names, value_counter, node_counter + model.graph, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter ): modified = True # Process functions for function in model.functions.values(): if self._fix_function_names( - function, seen_value_names, seen_node_names, value_counter, node_counter + function, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter ): modified = True @@ -64,28 +67,26 @@ def _fix_graph_names( graph: ir.Graph, seen_value_names: set[str], seen_node_names: set[str], + value_to_name: dict[ir.Value, str], value_counter: list[int], node_counter: list[int], ) -> bool: """Fix names in a graph and return whether modifications were made.""" modified = False - - # Keep track of values we've already processed to avoid double-processing - processed_values: set[ir.Value] = set() # Step 1: Fix graph input names first (they have precedence) for input_value in graph.inputs: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Step 2: Fix graph output names (they have precedence) for output_value in graph.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True # Step 3: Fix initializer names for initializer in graph.initializers.values(): - if self._process_value(initializer, seen_value_names, value_counter, processed_values): + if self._process_value(initializer, seen_value_names, value_to_name, value_counter): modified = True # Step 4: Process all nodes and their values @@ -101,12 +102,12 @@ def _fix_graph_names( # Fix input value names (only if not already processed) for input_value in node.inputs: if input_value is not None: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True return modified @@ -116,23 +117,21 @@ def _fix_function_names( function: ir.Function, seen_value_names: set[str], seen_node_names: set[str], + value_to_name: dict[ir.Value, str], value_counter: list[int], node_counter: list[int], ) -> bool: """Fix names in a function and return whether modifications were made.""" modified = False - - # Keep track of values we've already processed to avoid double-processing - processed_values: set[ir.Value] = set() # Process function inputs first (they have precedence) for input_value in function.inputs: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Process function outputs (they have precedence) for output_value in function.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True # Process all nodes and their values @@ -148,12 +147,12 @@ def _fix_function_names( # Fix input value names (only if not already processed) for input_value in node.inputs: if input_value is not None: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True return modified @@ -162,19 +161,22 @@ def _process_value( self, value: ir.Value, seen_value_names: set[str], - value_counter: list[int], - processed_values: set[ir.Value] + value_to_name: dict[ir.Value, str], + value_counter: list[int] ) -> bool: """Process a value only if it hasn't been processed before.""" - if value in processed_values: + if value in value_to_name: return False - processed_values.add(value) - + modified = False if value.name is None or value.name == "": - return self._assign_value_name(value, seen_value_names, value_counter) + modified = self._assign_value_name(value, seen_value_names, value_counter) else: - return self._fix_duplicate_value_name(value, seen_value_names) + modified = self._fix_duplicate_value_name(value, seen_value_names) + + # Record the final name for this value + value_to_name[value] = value.name + return modified def _assign_value_name( self, value: ir.Value, seen_names: set[str], counter: list[int] From 3cbe736d4b0aa268ef8cfb54066b1a5d0ce8c47b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Jul 2025 09:05:17 -0700 Subject: [PATCH 04/25] wip Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 355 ++++++++++++++-------------- 1 file changed, 181 insertions(+), 174 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index c2f86d50..4e0a3033 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -9,7 +9,6 @@ ] import logging -from collections.abc import Set as AbstractSet import onnx_ir as ir @@ -29,30 +28,39 @@ class NameFixPass(ir.passes.InPlacePass): """ def call(self, model: ir.Model) -> ir.passes.PassResult: - """Main entry point for the name fix pass.""" modified = False # Use sets to track seen names globally seen_value_names: set[str] = set() seen_node_names: set[str] = set() - + # Dictionary to track which values have been assigned names value_to_name: dict[ir.Value, str] = {} - + # Counters for generating unique names (using list to pass by reference) value_counter = [0] node_counter = [0] # Process the main graph - if self._fix_graph_names( - model.graph, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter + if _fix_graph_names( + model.graph, + seen_value_names, + seen_node_names, + value_to_name, + value_counter, + node_counter, ): modified = True # Process functions for function in model.functions.values(): - if self._fix_function_names( - function, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter + if _fix_function_names( + function, + seen_value_names, + seen_node_names, + value_to_name, + value_counter, + node_counter, ): modified = True @@ -62,196 +70,195 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: return ir.passes.PassResult(model, modified=modified) - def _fix_graph_names( - self, - graph: ir.Graph, - seen_value_names: set[str], - seen_node_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], - node_counter: list[int], - ) -> bool: - """Fix names in a graph and return whether modifications were made.""" - modified = False +def _fix_graph_names( + graph: ir.Graph, + seen_value_names: set[str], + seen_node_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], + node_counter: list[int], +) -> bool: + """Fix names in a graph and return whether modifications were made.""" + modified = False + + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph.inputs: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): + modified = True - # Step 1: Fix graph input names first (they have precedence) - for input_value in graph.inputs: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True + # Step 2: Fix graph output names (they have precedence) + for output_value in graph.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True - # Step 2: Fix graph output names (they have precedence) - for output_value in graph.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True + # Step 3: Fix initializer names + for initializer in graph.initializers.values(): + if _process_value(initializer, seen_value_names, value_to_name, value_counter): + modified = True - # Step 3: Fix initializer names - for initializer in graph.initializers.values(): - if self._process_value(initializer, seen_value_names, value_to_name, value_counter): + # Step 4: Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(graph): + # Fix node name + if node.name is None or node.name == "": + if _assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if _fix_duplicate_node_name(node, seen_node_names): modified = True - # Step 4: Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator(graph): - # Fix node name - if node.name is None or node.name == "": - if self._assign_node_name(node, seen_node_names, node_counter): - modified = True - else: - if self._fix_duplicate_node_name(node, seen_node_names): + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True - # Fix input value names (only if not already processed) - for input_value in node.inputs: - if input_value is not None: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True - # Fix output value names (only if not already processed) - for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True + return modified - return modified - - def _fix_function_names( - self, - function: ir.Function, - seen_value_names: set[str], - seen_node_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], - node_counter: list[int], - ) -> bool: - """Fix names in a function and return whether modifications were made.""" - modified = False - # Process function inputs first (they have precedence) - for input_value in function.inputs: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True +def _fix_function_names( + function: ir.Function, + seen_value_names: set[str], + seen_node_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], + node_counter: list[int], +) -> bool: + """Fix names in a function and return whether modifications were made.""" + modified = False - # Process function outputs (they have precedence) - for output_value in function.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True + # Process function inputs first (they have precedence) + for input_value in function.inputs: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): + modified = True - # Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator(function): - # Fix node name - if node.name is None or node.name == "": - if self._assign_node_name(node, seen_node_names, node_counter): - modified = True - else: - if self._fix_duplicate_node_name(node, seen_node_names): - modified = True + # Process function outputs (they have precedence) + for output_value in function.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True - # Fix input value names (only if not already processed) - for input_value in node.inputs: - if input_value is not None: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True + # Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(function): + # Fix node name + if node.name is None or node.name == "": + if _assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if _fix_duplicate_node_name(node, seen_node_names): + modified = True - # Fix output value names (only if not already processed) - for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True - return modified - - def _process_value( - self, - value: ir.Value, - seen_value_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int] - ) -> bool: - """Process a value only if it hasn't been processed before.""" - if value in value_to_name: - return False - - modified = False - if value.name is None or value.name == "": - modified = self._assign_value_name(value, seen_value_names, value_counter) - else: - modified = self._fix_duplicate_value_name(value, seen_value_names) - - # Record the final name for this value - value_to_name[value] = value.name - return modified - - def _assign_value_name( - self, value: ir.Value, seen_names: set[str], counter: list[int] - ) -> bool: - """Assign a name to an unnamed value. Returns True if modified.""" + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True + + return modified + + +def _process_value( + value: ir.Value, + seen_value_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], +) -> bool: + """Process a value only if it hasn't been processed before.""" + if value in value_to_name: + return False + + modified = False + if value.name is None or value.name == "": + modified = _assign_value_name(value, seen_value_names, value_counter) + else: + modified = _fix_duplicate_value_name(value, seen_value_names) + + # Record the final name for this value + value_to_name[value] = value.name + return modified + + +def _assign_value_name(value: ir.Value, seen_names: set[str], counter: list[int]) -> bool: + """Assign a name to an unnamed value. Returns True if modified.""" + while True: + new_name = f"val_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + value.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed value", new_name) + return True + + +def _assign_node_name(node: ir.Node, seen_names: set[str], counter: list[int]) -> bool: + """Assign a name to an unnamed node. Returns True if modified.""" + while True: + new_name = f"node_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + node.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed node", new_name) + return True + + +def _fix_duplicate_value_name(value: ir.Value, seen_names: set[str]) -> bool: + """Fix a value's name if it conflicts with existing names. Returns True if modified.""" + original_name = value.name + + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 while True: - new_name = f"val_{counter[0]}" - counter[0] += 1 + new_name = f"{base_name}_{suffix}" if new_name not in seen_names: value.name = new_name seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed value", new_name) + logger.debug( + "Renamed value from %s to %s for uniqueness", original_name, new_name + ) return True + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False + + +def _fix_duplicate_node_name(node: ir.Node, seen_names: set[str]) -> bool: + """Fix a node's name if it conflicts with existing names. Returns True if modified.""" + original_name = node.name - def _assign_node_name( - self, node: ir.Node, seen_names: set[str], counter: list[int] - ) -> bool: - """Assign a name to an unnamed node. Returns True if modified.""" + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 while True: - new_name = f"node_{counter[0]}" - counter[0] += 1 + new_name = f"{base_name}_{suffix}" if new_name not in seen_names: node.name = new_name seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed node", new_name) + logger.debug( + "Renamed node from %s to %s for uniqueness", original_name, new_name + ) return True - - def _fix_duplicate_value_name( - self, value: ir.Value, seen_names: set[str] - ) -> bool: - """Fix a value's name if it conflicts with existing names. Returns True if modified.""" - original_name = value.name - - if original_name is None or original_name == "": - return False # Should not happen if called correctly - - # If name is already seen, make it unique - if original_name in seen_names: - base_name = original_name - suffix = 1 - while True: - new_name = f"{base_name}_{suffix}" - if new_name not in seen_names: - value.name = new_name - seen_names.add(new_name) - logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) - return True - suffix += 1 - else: - # Name is unique, just record it - seen_names.add(original_name) - return False - - def _fix_duplicate_node_name( - self, node: ir.Node, seen_names: set[str] - ) -> bool: - """Fix a node's name if it conflicts with existing names. Returns True if modified.""" - original_name = node.name - - if original_name is None or original_name == "": - return False # Should not happen if called correctly - - # If name is already seen, make it unique - if original_name in seen_names: - base_name = original_name - suffix = 1 - while True: - new_name = f"{base_name}_{suffix}" - if new_name not in seen_names: - node.name = new_name - seen_names.add(new_name) - logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) - return True - suffix += 1 - else: - # Name is unique, just record it - seen_names.add(original_name) - return False \ No newline at end of file + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False From 42efd784222ddf8a1bd21030aed36e8b5454c744 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Jul 2025 16:46:49 -0700 Subject: [PATCH 05/25] Simplify implementation Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 172 +++++++++++----------------- 1 file changed, 66 insertions(+), 106 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 4e0a3033..1dc2424d 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -54,7 +54,13 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: # Process functions for function in model.functions.values(): - if _fix_function_names( + # Reset seen names and counters for each function + seen_value_names: set[str] = set() + seen_node_names: set[str] = set() + value_to_name: dict[ir.Value, str] = {} + value_counter = [0] + node_counter = [0] + if _fix_graph_names( function, seen_value_names, seen_node_names, @@ -71,7 +77,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: def _fix_graph_names( - graph: ir.Graph, + graph_like: ir.Graph | ir.Function, seen_value_names: set[str], seen_node_names: set[str], value_to_name: dict[ir.Value, str], @@ -82,69 +88,19 @@ def _fix_graph_names( modified = False # Step 1: Fix graph input names first (they have precedence) - for input_value in graph.inputs: + for input_value in graph_like.inputs: if _process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Step 2: Fix graph output names (they have precedence) - for output_value in graph.outputs: + for output_value in graph_like.outputs: if _process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True - # Step 3: Fix initializer names - for initializer in graph.initializers.values(): - if _process_value(initializer, seen_value_names, value_to_name, value_counter): - modified = True - - # Step 4: Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator(graph): - # Fix node name - if node.name is None or node.name == "": - if _assign_node_name(node, seen_node_names, node_counter): - modified = True - else: - if _fix_duplicate_node_name(node, seen_node_names): - modified = True - - # Fix input value names (only if not already processed) - for input_value in node.inputs: - if input_value is not None: - if _process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True - - # Fix output value names (only if not already processed) - for output_value in node.outputs: - if _process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True - - return modified - - -def _fix_function_names( - function: ir.Function, - seen_value_names: set[str], - seen_node_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], - node_counter: list[int], -) -> bool: - """Fix names in a function and return whether modifications were made.""" - modified = False - - # Process function inputs first (they have precedence) - for input_value in function.inputs: - if _process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True - - # Process function outputs (they have precedence) - for output_value in function.outputs: - if _process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True - - # Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator(function): + # Step 3: Process all nodes and their values. Initializers are processed as node inputs. + for node in ir.traversal.RecursiveGraphIterator(graph_like): # Fix node name - if node.name is None or node.name == "": + if not node.name: if _assign_node_name(node, seen_node_names, node_counter): modified = True else: @@ -176,89 +132,93 @@ def _process_value( return False modified = False - if value.name is None or value.name == "": + if not value.name: modified = _assign_value_name(value, seen_value_names, value_counter) else: modified = _fix_duplicate_value_name(value, seen_value_names) # Record the final name for this value + assert value.name is not None value_to_name[value] = value.name return modified def _assign_value_name(value: ir.Value, seen_names: set[str], counter: list[int]) -> bool: """Assign a name to an unnamed value. Returns True if modified.""" - while True: - new_name = f"val_{counter[0]}" + assert not value.name, ( + "value should not have a name already if function is called correctly" + ) + + new_name = f"val_{counter[0]}" + while new_name in seen_names: counter[0] += 1 - if new_name not in seen_names: - value.name = new_name - seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed value", new_name) - return True + new_name = f"val_{counter[0]}" + + value.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed value", new_name) + return True def _assign_node_name(node: ir.Node, seen_names: set[str], counter: list[int]) -> bool: """Assign a name to an unnamed node. Returns True if modified.""" - while True: - new_name = f"node_{counter[0]}" + assert not node.name, "node should not have a name already if function is called correctly" + + new_name = f"node_{counter[0]}" + + while new_name in seen_names: counter[0] += 1 - if new_name not in seen_names: - node.name = new_name - seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed node", new_name) - return True + new_name = f"node_{counter[0]}" + + node.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed node", new_name) + return True def _fix_duplicate_value_name(value: ir.Value, seen_names: set[str]) -> bool: """Fix a value's name if it conflicts with existing names. Returns True if modified.""" original_name = value.name - if original_name is None or original_name == "": - return False # Should not happen if called correctly + assert original_name, "value should have a name already if function is called correctly" - # If name is already seen, make it unique - if original_name in seen_names: - base_name = original_name - suffix = 1 - while True: - new_name = f"{base_name}_{suffix}" - if new_name not in seen_names: - value.name = new_name - seen_names.add(new_name) - logger.debug( - "Renamed value from %s to %s for uniqueness", original_name, new_name - ) - return True - suffix += 1 - else: + if original_name not in seen_names: # Name is unique, just record it seen_names.add(original_name) return False + # If name is already seen, make it unique + base_name = original_name + suffix = 1 + new_name = base_name + while new_name in seen_names: + new_name = f"{base_name}_{suffix}" + suffix += 1 + value.name = new_name + seen_names.add(new_name) + logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) + return True + def _fix_duplicate_node_name(node: ir.Node, seen_names: set[str]) -> bool: """Fix a node's name if it conflicts with existing names. Returns True if modified.""" original_name = node.name - if original_name is None or original_name == "": - return False # Should not happen if called correctly + assert original_name, "node should have a name already if function is called correctly" - # If name is already seen, make it unique - if original_name in seen_names: - base_name = original_name - suffix = 1 - while True: - new_name = f"{base_name}_{suffix}" - if new_name not in seen_names: - node.name = new_name - seen_names.add(new_name) - logger.debug( - "Renamed node from %s to %s for uniqueness", original_name, new_name - ) - return True - suffix += 1 - else: + if original_name not in seen_names: # Name is unique, just record it seen_names.add(original_name) return False + + # If name is already seen, make it unique + base_name = original_name + suffix = 1 + new_name = base_name + while new_name in seen_names: + new_name = f"{base_name}_{suffix}" + suffix += 1 + node.name = new_name + seen_names.add(new_name) + logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) + return True From 6977a1fea187630f0611efd18b9614f29201edbc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Jul 2025 16:48:33 -0700 Subject: [PATCH 06/25] format test Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming_test.py | 45 +++++++++++++----------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/onnx_ir/passes/common/naming_test.py b/src/onnx_ir/passes/common/naming_test.py index 5eb7863f..30f5aae7 100644 --- a/src/onnx_ir/passes/common/naming_test.py +++ b/src/onnx_ir/passes/common/naming_test.py @@ -35,7 +35,7 @@ def test_assign_names_to_unnamed_values(self): # Verify IR has auto-assigned names self.assertIsNotNone(input_value.name) self.assertIsNotNone(add_node.outputs[0].name) - + # Store original names original_input_name = input_value.name original_output_name = add_node.outputs[0].name @@ -53,7 +53,7 @@ def test_assign_names_to_unnamed_values(self): def test_assign_names_to_unnamed_nodes(self): """Test ensuring all nodes have names even if IR auto-assigned them.""" - # Create a simple model + # Create a simple model input_value = ir.Input( "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) ) @@ -136,7 +136,7 @@ def test_handles_global_uniqueness_across_subgraphs(self): sub_input = ir.Input( "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) ) # Same name as main input - should cause conflict - + sub_add_node = ir.Node("", "Add", inputs=[sub_input, sub_input]) sub_add_node.outputs[0].name = "main_input" # Another conflict sub_add_node.outputs[0].shape = sub_input.shape @@ -156,9 +156,12 @@ def test_handles_global_uniqueness_across_subgraphs(self): # Create If node with subgraph if_node = ir.Node( - "", "If", + "", + "If", inputs=[condition_input], - attributes={"then_branch": ir.Attr("then_branch", ir.AttributeType.GRAPH, subgraph)} + attributes={ + "then_branch": ir.Attr("then_branch", ir.AttributeType.GRAPH, subgraph) + }, ) if_node.outputs[0].name = "if_output" if_node.outputs[0].shape = main_input.shape @@ -183,13 +186,15 @@ def test_handles_global_uniqueness_across_subgraphs(self): # Collect all value names to verify uniqueness all_value_names = set() - + # Main graph values for input_val in main_graph.inputs: self.assertIsNotNone(input_val.name) - self.assertNotIn(input_val.name, all_value_names, f"Duplicate value name: {input_val.name}") + self.assertNotIn( + input_val.name, all_value_names, f"Duplicate value name: {input_val.name}" + ) all_value_names.add(input_val.name) - + for output_val in main_graph.outputs: self.assertIsNotNone(output_val.name) if output_val.name not in all_value_names: # Could be same as input @@ -208,9 +213,13 @@ def test_handles_global_uniqueness_across_subgraphs(self): # Subgraph values for input_val in subgraph.inputs: self.assertIsNotNone(input_val.name) - self.assertNotIn(input_val.name, all_value_names, f"Duplicate value name in subgraph: {input_val.name}") + self.assertNotIn( + input_val.name, + all_value_names, + f"Duplicate value name in subgraph: {input_val.name}", + ) all_value_names.add(input_val.name) - + for output_val in subgraph.outputs: if output_val.name not in all_value_names: # Could be same as input all_value_names.add(output_val.name) @@ -268,10 +277,7 @@ def test_handle_duplicate_value_names(self): # One should keep the original name, the other should have a suffix names = {input1.name, input2.name} self.assertIn("duplicate_name", names) - self.assertTrue( - "duplicate_name_1" in names, - f"Expected 'duplicate_name_1' in {names}" - ) + self.assertTrue("duplicate_name_1" in names, f"Expected 'duplicate_name_1' in {names}") def test_handle_duplicate_node_names(self): """Test handling duplicate node names by making them unique.""" @@ -317,10 +323,7 @@ def test_handle_duplicate_node_names(self): # One should keep the original name, the other should have a suffix names = {add_node1.name, add_node2.name} self.assertIn("duplicate_node", names) - self.assertTrue( - "duplicate_node_1" in names, - f"Expected 'duplicate_node_1' in {names}" - ) + self.assertTrue("duplicate_node_1" in names, f"Expected 'duplicate_node_1' in {names}") def test_no_modification_when_all_names_unique(self): """Test that the pass doesn't modify anything when all names are already unique.""" @@ -397,14 +400,14 @@ def test_graph_inputs_outputs_have_precedence(self): # Verify input keeps its original name (has precedence) self.assertEqual(input_value.name, "important_input") - + # Verify output keeps its original name (has precedence) self.assertEqual(mul_node.outputs[0].name, "important_output") - + # Verify intermediate value got renamed to avoid conflict self.assertNotEqual(add_node.outputs[0].name, "important_input") self.assertTrue(add_node.outputs[0].name.startswith("important_input_")) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 212c1a855a252b4891d4dc2fe31a0c2d03d92c61 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Jul 2025 16:55:37 -0700 Subject: [PATCH 07/25] Handle initializers Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 1dc2424d..60bc8325 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -132,10 +132,18 @@ def _process_value( return False modified = False + if not value.name: modified = _assign_value_name(value, seen_value_names, value_counter) else: + old_name = value.name modified = _fix_duplicate_value_name(value, seen_value_names) + if modified: + assert value.graph is not None + if value.is_initializer(): + value.graph.initializers.pop(old_name) + # Add the initializer back with the new name + value.graph.initializers.add(value) # Record the final name for this value assert value.name is not None From 463d5002995bbab7f8c5d055660a5692914cd15e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Jul 2025 16:57:18 -0700 Subject: [PATCH 08/25] fix Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming_test.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/onnx_ir/passes/common/naming_test.py b/src/onnx_ir/passes/common/naming_test.py index 30f5aae7..2c44c810 100644 --- a/src/onnx_ir/passes/common/naming_test.py +++ b/src/onnx_ir/passes/common/naming_test.py @@ -41,8 +41,7 @@ def test_assign_names_to_unnamed_values(self): original_output_name = add_node.outputs[0].name # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass didn't modify anything (names were already assigned and unique) self.assertFalse(result.modified) @@ -78,8 +77,7 @@ def test_assign_names_to_unnamed_nodes(self): original_node_name = add_node.name # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass didn't modify anything (node already had unique name) self.assertFalse(result.modified) @@ -113,8 +111,7 @@ def test_assigns_names_when_truly_unnamed(self): add_node.outputs[0].name = "" # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass was applied self.assertTrue(result.modified) @@ -178,8 +175,7 @@ def test_handles_global_uniqueness_across_subgraphs(self): model = ir.Model(main_graph, ir_version=10) # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass was applied (should fix duplicates) self.assertTrue(result.modified) @@ -266,8 +262,7 @@ def test_handle_duplicate_value_names(self): self.assertEqual(input2.name, "duplicate_name") # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass was applied self.assertTrue(result.modified) @@ -312,8 +307,7 @@ def test_handle_duplicate_node_names(self): self.assertEqual(add_node2.name, "duplicate_node") # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass was applied self.assertTrue(result.modified) @@ -352,8 +346,7 @@ def test_no_modification_when_all_names_unique(self): original_output_name = add_node.outputs[0].name # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass didn't modify anything self.assertFalse(result.modified) @@ -392,8 +385,7 @@ def test_graph_inputs_outputs_have_precedence(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = naming.NameFixPass() - result = pass_instance(model) + result = naming.NameFixPass()(model) # Verify the pass was applied self.assertTrue(result.modified) From b6be30adfe3afa48b8d8e0767b86e54764b2ca0f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 10:47:21 -0700 Subject: [PATCH 09/25] Create callback Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 89 +++++++++++++---------------- src/onnx_ir/traversal.py | 20 +++++++ 2 files changed, 61 insertions(+), 48 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 60bc8325..a6ef2dfc 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -30,44 +30,13 @@ class NameFixPass(ir.passes.InPlacePass): def call(self, model: ir.Model) -> ir.passes.PassResult: modified = False - # Use sets to track seen names globally - seen_value_names: set[str] = set() - seen_node_names: set[str] = set() - - # Dictionary to track which values have been assigned names - value_to_name: dict[ir.Value, str] = {} - - # Counters for generating unique names (using list to pass by reference) - value_counter = [0] - node_counter = [0] - # Process the main graph - if _fix_graph_names( - model.graph, - seen_value_names, - seen_node_names, - value_to_name, - value_counter, - node_counter, - ): + if _fix_graph_names(model.graph): modified = True # Process functions for function in model.functions.values(): - # Reset seen names and counters for each function - seen_value_names: set[str] = set() - seen_node_names: set[str] = set() - value_to_name: dict[ir.Value, str] = {} - value_counter = [0] - node_counter = [0] - if _fix_graph_names( - function, - seen_value_names, - seen_node_names, - value_to_name, - value_counter, - node_counter, - ): + if _fix_graph_names(function): modified = True if modified: @@ -76,46 +45,70 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: return ir.passes.PassResult(model, modified=modified) -def _fix_graph_names( - graph_like: ir.Graph | ir.Function, - seen_value_names: set[str], - seen_node_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], - node_counter: list[int], -) -> bool: +def _fix_graph_names(graph_like: ir.Graph | ir.Function) -> bool: """Fix names in a graph and return whether modifications were made.""" modified = False + # Dictionaries to track which values have been assigned names + value_to_name: dict[ir.Value, str] = {} + scoped_seen_value_names: list[set[str]] = [set()] + scoped_seen_node_names: list[set[str]] = [set()] + + # Counters for generating unique names (using list to pass by reference) + value_counter = [0] + node_counter = [0] + + def enter_graph(graph: ir.Graph, node: ir.Node) -> None: + """Callback for entering a subgraph.""" + # Initialize new scopes with all names from the parent scope + scoped_seen_value_names.append(set(scoped_seen_value_names[-1])) + scoped_seen_node_names.append(set()) + + def exit_graph(graph: ir.Graph, node: ir.Node) -> None: + """Callback for exiting a subgraph.""" + # Pop the current scope + scoped_seen_value_names.pop() + scoped_seen_node_names.pop() + # Step 1: Fix graph input names first (they have precedence) for input_value in graph_like.inputs: - if _process_value(input_value, seen_value_names, value_to_name, value_counter): + if _process_value( + input_value, scoped_seen_value_names[0], value_to_name, value_counter + ): modified = True # Step 2: Fix graph output names (they have precedence) for output_value in graph_like.outputs: - if _process_value(output_value, seen_value_names, value_to_name, value_counter): + if _process_value( + output_value, scoped_seen_value_names[0], value_to_name, value_counter + ): modified = True # Step 3: Process all nodes and their values. Initializers are processed as node inputs. - for node in ir.traversal.RecursiveGraphIterator(graph_like): + for node in ir.traversal.RecursiveGraphIterator( + graph_like, enter_graph=enter_graph, exit_graph=exit_graph + ): # Fix node name if not node.name: - if _assign_node_name(node, seen_node_names, node_counter): + if _assign_node_name(node, scoped_seen_node_names[-1], node_counter): modified = True else: - if _fix_duplicate_node_name(node, seen_node_names): + if _fix_duplicate_node_name(node, scoped_seen_node_names[-1]): modified = True # Fix input value names (only if not already processed) for input_value in node.inputs: if input_value is not None: - if _process_value(input_value, seen_value_names, value_to_name, value_counter): + if _process_value( + input_value, scoped_seen_value_names[-1], value_to_name, value_counter + ): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: - if _process_value(output_value, seen_value_names, value_to_name, value_counter): + if _process_value( + output_value, scoped_seen_value_names[-1], value_to_name, value_counter + ): modified = True return modified diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index 26c4e008..4e5df4df 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -25,6 +25,8 @@ def __init__( *, recursive: Callable[[_core.Node], bool] | None = None, reverse: bool = False, + enter_graph: Callable[[_core.Graph, _core.Node], None] | None = None, + exit_graph: Callable[[_core.Graph, _core.Node], None] | None = None, ): """Iterate over the nodes in the graph, recursively visiting subgraphs. @@ -33,11 +35,15 @@ def __init__( recursive: A callback that determines whether to recursively visit the subgraphs contained in a node. If not provided, all nodes in subgraphs are visited. reverse: Whether to iterate in reverse order. + enter_graph: An optional callback that is called when entering a subgraph. + exit_graph: An optional callback that is called when exiting a subgraph. """ self._graph = graph_like self._recursive = recursive self._reverse = reverse self._iterator = self._recursive_node_iter(graph_like) + self._enter_graph = enter_graph + self._exit_graph = exit_graph def __iter__(self) -> Self: self._iterator = self._recursive_node_iter(self._graph) @@ -61,23 +67,37 @@ def _iterate_subgraphs(self, node: _core.Node): if not isinstance(attr, _core.Attr): continue if attr.type == _enums.AttributeType.GRAPH: + if self._enter_graph is not None: + self._enter_graph(attr.value, node) yield from RecursiveGraphIterator( attr.value, recursive=self._recursive, reverse=self._reverse, + enter_graph=self._enter_graph, + exit_graph=self._exit_graph, ) + if self._exit_graph is not None: + self._exit_graph(attr.value, node) elif attr.type == _enums.AttributeType.GRAPHS: graphs = reversed(attr.value) if self._reverse else attr.value for graph in graphs: + if self._enter_graph is not None: + self._enter_graph(graph, node) yield from RecursiveGraphIterator( graph, recursive=self._recursive, reverse=self._reverse, + enter_graph=self._enter_graph, + exit_graph=self._exit_graph, ) + if self._exit_graph is not None: + self._exit_graph(graph, node) def __reversed__(self) -> Iterator[_core.Node]: return RecursiveGraphIterator( self._graph, recursive=self._recursive, reverse=not self._reverse, + enter_graph=self._enter_graph, + exit_graph=self._exit_graph, ) From f01fea91d26219a4a9de558e64b543a568549426 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:02:49 -0700 Subject: [PATCH 10/25] update Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 48 +++++++++++++++++------------ src/onnx_ir/traversal.py | 19 ++++++++---- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index a6ef2dfc..e89e6aca 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -51,40 +51,50 @@ def _fix_graph_names(graph_like: ir.Graph | ir.Function) -> bool: # Dictionaries to track which values have been assigned names value_to_name: dict[ir.Value, str] = {} - scoped_seen_value_names: list[set[str]] = [set()] - scoped_seen_node_names: list[set[str]] = [set()] + scoped_seen_value_names: list[set[str]] = [] + scoped_seen_node_names: list[set[str]] = [] # Counters for generating unique names (using list to pass by reference) value_counter = [0] node_counter = [0] - def enter_graph(graph: ir.Graph, node: ir.Node) -> None: + def enter_graph(graph_like) -> None: """Callback for entering a subgraph.""" # Initialize new scopes with all names from the parent scope scoped_seen_value_names.append(set(scoped_seen_value_names[-1])) scoped_seen_node_names.append(set()) - def exit_graph(graph: ir.Graph, node: ir.Node) -> None: + nonlocal modified + + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph_like.inputs: + if _process_value( + input_value, scoped_seen_value_names[-1], value_to_name, value_counter + ): + modified = True + + # Step 2: Fix graph output names (they have precedence) + for output_value in graph_like.outputs: + if _process_value( + output_value, scoped_seen_value_names[-1], value_to_name, value_counter + ): + modified = True + + if isinstance(graph_like, ir.Graph): + # For graphs, also fix initializers + for initializer in graph_like.initializers.values(): + if _process_value( + initializer, scoped_seen_value_names[-1], value_to_name, value_counter + ): + modified = True + + def exit_graph(_) -> None: """Callback for exiting a subgraph.""" # Pop the current scope scoped_seen_value_names.pop() scoped_seen_node_names.pop() - # Step 1: Fix graph input names first (they have precedence) - for input_value in graph_like.inputs: - if _process_value( - input_value, scoped_seen_value_names[0], value_to_name, value_counter - ): - modified = True - - # Step 2: Fix graph output names (they have precedence) - for output_value in graph_like.outputs: - if _process_value( - output_value, scoped_seen_value_names[0], value_to_name, value_counter - ): - modified = True - - # Step 3: Process all nodes and their values. Initializers are processed as node inputs. + # Step 3: Process all nodes and their values for node in ir.traversal.RecursiveGraphIterator( graph_like, enter_graph=enter_graph, exit_graph=exit_graph ): diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index 4e5df4df..b4b41802 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -25,8 +25,8 @@ def __init__( *, recursive: Callable[[_core.Node], bool] | None = None, reverse: bool = False, - enter_graph: Callable[[_core.Graph, _core.Node], None] | None = None, - exit_graph: Callable[[_core.Graph, _core.Node], None] | None = None, + enter_graph: Callable[[GraphLike], None] | None = None, + exit_graph: Callable[[GraphLike], None] | None = None, ): """Iterate over the nodes in the graph, recursively visiting subgraphs. @@ -56,19 +56,26 @@ def _recursive_node_iter( self, graph: _core.Graph | _core.Function | _core.GraphView ) -> Iterator[_core.Node]: iterable = reversed(graph) if self._reverse else graph + + if self._enter_graph is not None: + self._enter_graph(graph) + for node in iterable: # type: ignore[union-attr] yield node if self._recursive is not None and not self._recursive(node): continue yield from self._iterate_subgraphs(node) + if self._exit_graph is not None: + self._exit_graph(graph) + def _iterate_subgraphs(self, node: _core.Node): for attr in node.attributes.values(): if not isinstance(attr, _core.Attr): continue if attr.type == _enums.AttributeType.GRAPH: if self._enter_graph is not None: - self._enter_graph(attr.value, node) + self._enter_graph(attr.value) yield from RecursiveGraphIterator( attr.value, recursive=self._recursive, @@ -77,12 +84,12 @@ def _iterate_subgraphs(self, node: _core.Node): exit_graph=self._exit_graph, ) if self._exit_graph is not None: - self._exit_graph(attr.value, node) + self._exit_graph(attr.value) elif attr.type == _enums.AttributeType.GRAPHS: graphs = reversed(attr.value) if self._reverse else attr.value for graph in graphs: if self._enter_graph is not None: - self._enter_graph(graph, node) + self._enter_graph(graph) yield from RecursiveGraphIterator( graph, recursive=self._recursive, @@ -91,7 +98,7 @@ def _iterate_subgraphs(self, node: _core.Node): exit_graph=self._exit_graph, ) if self._exit_graph is not None: - self._exit_graph(graph, node) + self._exit_graph(graph) def __reversed__(self) -> Iterator[_core.Node]: return RecursiveGraphIterator( From 3b7fa60e47d6731bf01ee8847b4e00bbf77dfec9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:04:44 -0700 Subject: [PATCH 11/25] scope Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index e89e6aca..06542612 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -51,8 +51,11 @@ def _fix_graph_names(graph_like: ir.Graph | ir.Function) -> bool: # Dictionaries to track which values have been assigned names value_to_name: dict[ir.Value, str] = {} - scoped_seen_value_names: list[set[str]] = [] - scoped_seen_node_names: list[set[str]] = [] + + # The first set is a dummy placeholder so that there is always a [-1] scope for access + # (even though we don't write to it) + scoped_seen_value_names: list[set[str]] = [set()] + scoped_seen_node_names: list[set[str]] = [set()] # Counters for generating unique names (using list to pass by reference) value_counter = [0] From 7b49ab81b46221e3d047870a9aff216068941b9e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:24:09 -0700 Subject: [PATCH 12/25] Support _generate_node_name Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 361 +++++++++++++++------------- 1 file changed, 188 insertions(+), 173 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 06542612..f8026463 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -3,6 +3,7 @@ """Name fix pass for ensuring unique names for all values and nodes.""" from __future__ import annotations +from typing import Callable __all__ = [ "NameFixPass", @@ -27,16 +28,31 @@ class NameFixPass(ir.passes.InPlacePass): The pass maintains global uniqueness across the entire model. """ + def __init__( + self, + generate_node_name: Callable[[ir.Node], str] = lambda n: n.name or "node", + generate_value_name: Callable[[ir.Value], str] = lambda v: v.name or "v", + ) -> None: + """Initialize the NameFixPass with custom name generation functions. + + Args: + generate_node_name: Function to generate a unique name for a node. + generate_value_name: Function to generate a unique name for a value. + """ + super().__init__() + self._generate_node_name = generate_node_name + self._generate_value_name = generate_value_name + def call(self, model: ir.Model) -> ir.passes.PassResult: modified = False # Process the main graph - if _fix_graph_names(model.graph): + if self._fix_graph_names(model.graph): modified = True # Process functions for function in model.functions.values(): - if _fix_graph_names(function): + if self._fix_graph_names(function): modified = True if modified: @@ -44,195 +60,194 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: return ir.passes.PassResult(model, modified=modified) + def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: + """Fix names in a graph and return whether modifications were made.""" + modified = False -def _fix_graph_names(graph_like: ir.Graph | ir.Function) -> bool: - """Fix names in a graph and return whether modifications were made.""" - modified = False - - # Dictionaries to track which values have been assigned names - value_to_name: dict[ir.Value, str] = {} - - # The first set is a dummy placeholder so that there is always a [-1] scope for access - # (even though we don't write to it) - scoped_seen_value_names: list[set[str]] = [set()] - scoped_seen_node_names: list[set[str]] = [set()] - - # Counters for generating unique names (using list to pass by reference) - value_counter = [0] - node_counter = [0] + # Dictionaries to track which values have been assigned names + value_to_name: dict[ir.Value, str] = {} - def enter_graph(graph_like) -> None: - """Callback for entering a subgraph.""" - # Initialize new scopes with all names from the parent scope - scoped_seen_value_names.append(set(scoped_seen_value_names[-1])) - scoped_seen_node_names.append(set()) + # The first set is a dummy placeholder so that there is always a [-1] scope for access + # (even though we don't write to it) + scoped_seen_value_names: list[set[str]] = [set()] + scoped_seen_node_names: list[set[str]] = [set()] - nonlocal modified + # Counters for generating unique names (using list to pass by reference) + value_counter = [0] + node_counter = [0] - # Step 1: Fix graph input names first (they have precedence) - for input_value in graph_like.inputs: - if _process_value( - input_value, scoped_seen_value_names[-1], value_to_name, value_counter - ): - modified = True + def enter_graph(graph_like) -> None: + """Callback for entering a subgraph.""" + # Initialize new scopes with all names from the parent scope + scoped_seen_value_names.append(set(scoped_seen_value_names[-1])) + scoped_seen_node_names.append(set()) - # Step 2: Fix graph output names (they have precedence) - for output_value in graph_like.outputs: - if _process_value( - output_value, scoped_seen_value_names[-1], value_to_name, value_counter - ): - modified = True + nonlocal modified - if isinstance(graph_like, ir.Graph): - # For graphs, also fix initializers - for initializer in graph_like.initializers.values(): - if _process_value( - initializer, scoped_seen_value_names[-1], value_to_name, value_counter + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph_like.inputs: + if self._process_value( + input_value, scoped_seen_value_names[-1], value_to_name, value_counter ): modified = True - def exit_graph(_) -> None: - """Callback for exiting a subgraph.""" - # Pop the current scope - scoped_seen_value_names.pop() - scoped_seen_node_names.pop() - - # Step 3: Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator( - graph_like, enter_graph=enter_graph, exit_graph=exit_graph - ): - # Fix node name - if not node.name: - if _assign_node_name(node, scoped_seen_node_names[-1], node_counter): - modified = True - else: - if _fix_duplicate_node_name(node, scoped_seen_node_names[-1]): - modified = True - - # Fix input value names (only if not already processed) - for input_value in node.inputs: - if input_value is not None: - if _process_value( - input_value, scoped_seen_value_names[-1], value_to_name, value_counter + # Step 2: Fix graph output names (they have precedence) + for output_value in graph_like.outputs: + if self._process_value( + output_value, scoped_seen_value_names[-1], value_to_name, value_counter ): modified = True - # Fix output value names (only if not already processed) - for output_value in node.outputs: - if _process_value( - output_value, scoped_seen_value_names[-1], value_to_name, value_counter - ): - modified = True - - return modified - - -def _process_value( - value: ir.Value, - seen_value_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], -) -> bool: - """Process a value only if it hasn't been processed before.""" - if value in value_to_name: - return False - - modified = False - - if not value.name: - modified = _assign_value_name(value, seen_value_names, value_counter) - else: - old_name = value.name - modified = _fix_duplicate_value_name(value, seen_value_names) - if modified: - assert value.graph is not None - if value.is_initializer(): - value.graph.initializers.pop(old_name) - # Add the initializer back with the new name - value.graph.initializers.add(value) - - # Record the final name for this value - assert value.name is not None - value_to_name[value] = value.name - return modified - + if isinstance(graph_like, ir.Graph): + # For graphs, also fix initializers + for initializer in graph_like.initializers.values(): + if self._process_value( + initializer, scoped_seen_value_names[-1], value_to_name, value_counter + ): + modified = True + + def exit_graph(_) -> None: + """Callback for exiting a subgraph.""" + # Pop the current scope + scoped_seen_value_names.pop() + scoped_seen_node_names.pop() + + # Step 3: Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator( + graph_like, enter_graph=enter_graph, exit_graph=exit_graph + ): + # Fix node name + if not node.name: + if self._assign_node_name(node, scoped_seen_node_names[-1], node_counter): + modified = True + else: + if self._fix_duplicate_node_name(node, scoped_seen_node_names[-1]): + modified = True -def _assign_value_name(value: ir.Value, seen_names: set[str], counter: list[int]) -> bool: - """Assign a name to an unnamed value. Returns True if modified.""" - assert not value.name, ( - "value should not have a name already if function is called correctly" - ) + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if self._process_value( + input_value, scoped_seen_value_names[-1], value_to_name, value_counter + ): + modified = True + + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if self._process_value( + output_value, scoped_seen_value_names[-1], value_to_name, value_counter + ): + modified = True - new_name = f"val_{counter[0]}" - while new_name in seen_names: - counter[0] += 1 - new_name = f"val_{counter[0]}" + return modified - value.name = new_name - seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed value", new_name) - return True + def _process_value( + self, + value: ir.Value, + seen_value_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], + ) -> bool: + """Process a value only if it hasn't been processed before.""" + if value in value_to_name: + return False -def _assign_node_name(node: ir.Node, seen_names: set[str], counter: list[int]) -> bool: - """Assign a name to an unnamed node. Returns True if modified.""" - assert not node.name, "node should not have a name already if function is called correctly" + modified = False - new_name = f"node_{counter[0]}" + if not value.name: + modified = self._assign_value_name(value, seen_value_names, value_counter) + else: + old_name = value.name + modified = self._fix_duplicate_value_name(value, seen_value_names) + if modified: + assert value.graph is not None + if value.is_initializer(): + value.graph.initializers.pop(old_name) + # Add the initializer back with the new name + value.graph.initializers.add(value) + + # Record the final name for this value + assert value.name is not None + value_to_name[value] = value.name + return modified + + + def _assign_value_name(self, value: ir.Value, seen_names: set[str], counter: list[int]) -> bool: + """Assign a name to an unnamed value. Returns True if modified.""" + assert not value.name, ( + "value should not have a name already if function is called correctly" + ) + + new_name = f"v_{counter[0]}" + while new_name in seen_names: + counter[0] += 1 + new_name = f"v_{counter[0]}" + + value.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed value", new_name) + return True + + def _assign_node_name(self, node: ir.Node, seen_names: set[str], counter: list[int]) -> bool: + """Assign a name to an unnamed node. Returns True if modified.""" + assert not node.name, "node should not have a name already if function is called correctly" - while new_name in seen_names: - counter[0] += 1 new_name = f"node_{counter[0]}" - node.name = new_name - seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed node", new_name) - return True - - -def _fix_duplicate_value_name(value: ir.Value, seen_names: set[str]) -> bool: - """Fix a value's name if it conflicts with existing names. Returns True if modified.""" - original_name = value.name - - assert original_name, "value should have a name already if function is called correctly" - - if original_name not in seen_names: - # Name is unique, just record it - seen_names.add(original_name) - return False - - # If name is already seen, make it unique - base_name = original_name - suffix = 1 - new_name = base_name - while new_name in seen_names: - new_name = f"{base_name}_{suffix}" - suffix += 1 - value.name = new_name - seen_names.add(new_name) - logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) - return True - - -def _fix_duplicate_node_name(node: ir.Node, seen_names: set[str]) -> bool: - """Fix a node's name if it conflicts with existing names. Returns True if modified.""" - original_name = node.name - - assert original_name, "node should have a name already if function is called correctly" - - if original_name not in seen_names: - # Name is unique, just record it - seen_names.add(original_name) - return False - - # If name is already seen, make it unique - base_name = original_name - suffix = 1 - new_name = base_name - while new_name in seen_names: - new_name = f"{base_name}_{suffix}" - suffix += 1 - node.name = new_name - seen_names.add(new_name) - logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) - return True + while new_name in seen_names: + counter[0] += 1 + new_name = f"node_{counter[0]}" + + node.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed node", new_name) + return True + + + def _fix_duplicate_value_name(self, value: ir.Value, seen_names: set[str]) -> bool: + """Fix a value's name if it conflicts with existing names. Returns True if modified.""" + original_name = value.name + + assert original_name, "value should have a name already if function is called correctly" + + if original_name not in seen_names: + # Name is unique, just record it + seen_names.add(original_name) + return False + + # If name is already seen, make it unique + base_name = self._generate_value_name(value) + suffix = 1 + new_name = base_name + while new_name in seen_names: + new_name = f"{base_name}_{suffix}" + suffix += 1 + value.name = new_name + seen_names.add(new_name) + logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) + return True + + + def _fix_duplicate_node_name(self, node: ir.Node, seen_names: set[str]) -> bool: + """Fix a node's name if it conflicts with existing names. Returns True if modified.""" + original_name = node.name + + assert original_name, "node should have a name already if function is called correctly" + + if original_name not in seen_names: + # Name is unique, just record it + seen_names.add(original_name) + return False + + # If name is already seen, make it unique + base_name = self._generate_node_name(node) + suffix = 1 + new_name = base_name + while new_name in seen_names: + new_name = f"{base_name}_{suffix}" + suffix += 1 + node.name = new_name + seen_names.add(new_name) + logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) + return True From 291b21c50b3ee0619f5c56ed4180b13139ebecd6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:24:58 -0700 Subject: [PATCH 13/25] lint Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index f8026463..7b7ec8c1 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -3,6 +3,7 @@ """Name fix pass for ensuring unique names for all values and nodes.""" from __future__ import annotations + from typing import Callable __all__ = [ @@ -36,8 +37,8 @@ def __init__( """Initialize the NameFixPass with custom name generation functions. Args: - generate_node_name: Function to generate a unique name for a node. - generate_value_name: Function to generate a unique name for a value. + generate_node_name: Function to generate a preferred name for a node. + generate_value_name: Function to generate a preferred name for a value. """ super().__init__() self._generate_node_name = generate_node_name From 4c35f2be2e036602a0f841367f158b3b8d693fff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:28:52 -0700 Subject: [PATCH 14/25] versionadded and docs Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 7b7ec8c1..9fe12565 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -23,10 +23,27 @@ class NameFixPass(ir.passes.InPlacePass): This pass ensures that: 1. Graph inputs and outputs have unique names (take precedence) 2. All intermediate values have unique names (assign names to unnamed values) - 3. All values in subgraphs have unique names - 4. All nodes have unique names (assign names to unnamed nodes) + 3. All values in subgraphs have unique names within their graph and parent graphs + 4. All nodes have unique names within their graph The pass maintains global uniqueness across the entire model. + + You can customize the name generation functions for nodes and values by passing + `generate_node_name` and `generate_value_name` parameters to the constructor. + + For example, you can use a custom naming scheme like this:: + + def custom_node_name(node: ir.Node) -> str: + return f"custom_node_{node.op_type}" + def custom_value_name(value: ir.Value) -> str: + return f"custom_value_{value.type}" + + name_fix_pass = NameFixPass( + generate_node_name=custom_node_name, + generate_value_name=custom_value_name + ) + + .. versionadded:: 0.1.5 """ def __init__( From da3ea58643a6fe5df084cdc6d3a04d3e49f2b9af Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:32:27 -0700 Subject: [PATCH 15/25] _generate_value_name Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 9fe12565..37eeacef 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -197,7 +197,8 @@ def _assign_value_name(self, value: ir.Value, seen_names: set[str], counter: lis "value should not have a name already if function is called correctly" ) - new_name = f"v_{counter[0]}" + new_name = self._generate_value_name(value) + while new_name in seen_names: counter[0] += 1 new_name = f"v_{counter[0]}" @@ -211,7 +212,7 @@ def _assign_node_name(self, node: ir.Node, seen_names: set[str], counter: list[i """Assign a name to an unnamed node. Returns True if modified.""" assert not node.name, "node should not have a name already if function is called correctly" - new_name = f"node_{counter[0]}" + new_name = self._generate_node_name(node) while new_name in seen_names: counter[0] += 1 From 5b6c7c934838837db1887487b17dd841febe1713 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:42:22 -0700 Subject: [PATCH 16/25] refactor unique name finding Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 78 +++++++++++++---------------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 37eeacef..0b4b9e2e 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -159,7 +159,6 @@ def exit_graph(_) -> None: return modified - def _process_value( self, value: ir.Value, @@ -190,45 +189,39 @@ def _process_value( value_to_name[value] = value.name return modified - - def _assign_value_name(self, value: ir.Value, seen_names: set[str], counter: list[int]) -> bool: + def _assign_value_name( + self, value: ir.Value, seen_names: set[str], counter: list[int] + ) -> bool: """Assign a name to an unnamed value. Returns True if modified.""" assert not value.name, ( "value should not have a name already if function is called correctly" ) - new_name = self._generate_value_name(value) - - while new_name in seen_names: - counter[0] += 1 - new_name = f"v_{counter[0]}" - - value.name = new_name - seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed value", new_name) + preferred_name = self._generate_value_name(value) + value.name = _find_and_record_next_unique_name(preferred_name, seen_names, counter) + logger.debug("Assigned name %s to unnamed value", value.name) return True - def _assign_node_name(self, node: ir.Node, seen_names: set[str], counter: list[int]) -> bool: + def _assign_node_name( + self, node: ir.Node, seen_names: set[str], counter: list[int] + ) -> bool: """Assign a name to an unnamed node. Returns True if modified.""" - assert not node.name, "node should not have a name already if function is called correctly" - - new_name = self._generate_node_name(node) - - while new_name in seen_names: - counter[0] += 1 - new_name = f"node_{counter[0]}" + assert not node.name, ( + "node should not have a name already if function is called correctly" + ) - node.name = new_name - seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed node", new_name) + preferred_name = self._generate_node_name(node) + node.name = _find_and_record_next_unique_name(preferred_name, seen_names, counter) + logger.debug("Assigned name %s to unnamed node", node.name) return True - def _fix_duplicate_value_name(self, value: ir.Value, seen_names: set[str]) -> bool: """Fix a value's name if it conflicts with existing names. Returns True if modified.""" original_name = value.name - assert original_name, "value should have a name already if function is called correctly" + assert original_name, ( + "value should have a name already if function is called correctly" + ) if original_name not in seen_names: # Name is unique, just record it @@ -237,17 +230,10 @@ def _fix_duplicate_value_name(self, value: ir.Value, seen_names: set[str]) -> bo # If name is already seen, make it unique base_name = self._generate_value_name(value) - suffix = 1 - new_name = base_name - while new_name in seen_names: - new_name = f"{base_name}_{suffix}" - suffix += 1 - value.name = new_name - seen_names.add(new_name) - logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) + value.name = _find_and_record_next_unique_name(base_name, seen_names) + logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name) return True - def _fix_duplicate_node_name(self, node: ir.Node, seen_names: set[str]) -> bool: """Fix a node's name if it conflicts with existing names. Returns True if modified.""" original_name = node.name @@ -261,12 +247,20 @@ def _fix_duplicate_node_name(self, node: ir.Node, seen_names: set[str]) -> bool: # If name is already seen, make it unique base_name = self._generate_node_name(node) - suffix = 1 - new_name = base_name - while new_name in seen_names: - new_name = f"{base_name}_{suffix}" - suffix += 1 - node.name = new_name - seen_names.add(new_name) - logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) + node.name = _find_and_record_next_unique_name(base_name, seen_names) + logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name) return True + + +def _find_and_record_next_unique_name( + preferred_name: str, seen_names: set[str], counter: list[int] | None = None +) -> str: + """Generate a unique name based on the preferred name and current counter.""" + new_name = preferred_name + if counter is None: + counter = [0] + while new_name in seen_names: + counter[0] += 1 + new_name = f"{preferred_name}_{counter[0]}" + seen_names.add(new_name) + return new_name From 2efce8820db2e1d58f3d8188d064131379ae504f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:45:50 -0700 Subject: [PATCH 17/25] Use a set Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 0b4b9e2e..2c1c95f8 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -82,8 +82,8 @@ def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: """Fix names in a graph and return whether modifications were made.""" modified = False - # Dictionaries to track which values have been assigned names - value_to_name: dict[ir.Value, str] = {} + # Set to track which values have been assigned names + seen_values: set[ir.Value] = set() # The first set is a dummy placeholder so that there is always a [-1] scope for access # (even though we don't write to it) @@ -105,14 +105,14 @@ def enter_graph(graph_like) -> None: # Step 1: Fix graph input names first (they have precedence) for input_value in graph_like.inputs: if self._process_value( - input_value, scoped_seen_value_names[-1], value_to_name, value_counter + input_value, scoped_seen_value_names[-1], seen_values, value_counter ): modified = True # Step 2: Fix graph output names (they have precedence) for output_value in graph_like.outputs: if self._process_value( - output_value, scoped_seen_value_names[-1], value_to_name, value_counter + output_value, scoped_seen_value_names[-1], seen_values, value_counter ): modified = True @@ -120,7 +120,7 @@ def enter_graph(graph_like) -> None: # For graphs, also fix initializers for initializer in graph_like.initializers.values(): if self._process_value( - initializer, scoped_seen_value_names[-1], value_to_name, value_counter + initializer, scoped_seen_value_names[-1], seen_values, value_counter ): modified = True @@ -146,14 +146,14 @@ def exit_graph(_) -> None: for input_value in node.inputs: if input_value is not None: if self._process_value( - input_value, scoped_seen_value_names[-1], value_to_name, value_counter + input_value, scoped_seen_value_names[-1], seen_values, value_counter ): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: if self._process_value( - output_value, scoped_seen_value_names[-1], value_to_name, value_counter + output_value, scoped_seen_value_names[-1], seen_values, value_counter ): modified = True @@ -163,11 +163,11 @@ def _process_value( self, value: ir.Value, seen_value_names: set[str], - value_to_name: dict[ir.Value, str], + seen_values: set[ir.Value], value_counter: list[int], ) -> bool: """Process a value only if it hasn't been processed before.""" - if value in value_to_name: + if value in seen_values: return False modified = False @@ -186,7 +186,7 @@ def _process_value( # Record the final name for this value assert value.name is not None - value_to_name[value] = value.name + seen_values.add(value) return modified def _assign_value_name( From 797a368ac7bf6711e5fd3b95932ddde64fa618ba Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:47:20 -0700 Subject: [PATCH 18/25] seen -> used Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 66 ++++++++++++++--------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 2c1c95f8..a79467ec 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -87,8 +87,8 @@ def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: # The first set is a dummy placeholder so that there is always a [-1] scope for access # (even though we don't write to it) - scoped_seen_value_names: list[set[str]] = [set()] - scoped_seen_node_names: list[set[str]] = [set()] + scoped_used_value_names: list[set[str]] = [set()] + scoped_used_node_names: list[set[str]] = [set()] # Counters for generating unique names (using list to pass by reference) value_counter = [0] @@ -97,22 +97,22 @@ def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: def enter_graph(graph_like) -> None: """Callback for entering a subgraph.""" # Initialize new scopes with all names from the parent scope - scoped_seen_value_names.append(set(scoped_seen_value_names[-1])) - scoped_seen_node_names.append(set()) + scoped_used_value_names.append(set(scoped_used_value_names[-1])) + scoped_used_node_names.append(set()) nonlocal modified # Step 1: Fix graph input names first (they have precedence) for input_value in graph_like.inputs: if self._process_value( - input_value, scoped_seen_value_names[-1], seen_values, value_counter + input_value, scoped_used_value_names[-1], seen_values, value_counter ): modified = True # Step 2: Fix graph output names (they have precedence) for output_value in graph_like.outputs: if self._process_value( - output_value, scoped_seen_value_names[-1], seen_values, value_counter + output_value, scoped_used_value_names[-1], seen_values, value_counter ): modified = True @@ -120,15 +120,15 @@ def enter_graph(graph_like) -> None: # For graphs, also fix initializers for initializer in graph_like.initializers.values(): if self._process_value( - initializer, scoped_seen_value_names[-1], seen_values, value_counter + initializer, scoped_used_value_names[-1], seen_values, value_counter ): modified = True def exit_graph(_) -> None: """Callback for exiting a subgraph.""" # Pop the current scope - scoped_seen_value_names.pop() - scoped_seen_node_names.pop() + scoped_used_value_names.pop() + scoped_used_node_names.pop() # Step 3: Process all nodes and their values for node in ir.traversal.RecursiveGraphIterator( @@ -136,24 +136,24 @@ def exit_graph(_) -> None: ): # Fix node name if not node.name: - if self._assign_node_name(node, scoped_seen_node_names[-1], node_counter): + if self._assign_node_name(node, scoped_used_node_names[-1], node_counter): modified = True else: - if self._fix_duplicate_node_name(node, scoped_seen_node_names[-1]): + if self._fix_duplicate_node_name(node, scoped_used_node_names[-1]): modified = True # Fix input value names (only if not already processed) for input_value in node.inputs: if input_value is not None: if self._process_value( - input_value, scoped_seen_value_names[-1], seen_values, value_counter + input_value, scoped_used_value_names[-1], seen_values, value_counter ): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: if self._process_value( - output_value, scoped_seen_value_names[-1], seen_values, value_counter + output_value, scoped_used_value_names[-1], seen_values, value_counter ): modified = True @@ -162,7 +162,7 @@ def exit_graph(_) -> None: def _process_value( self, value: ir.Value, - seen_value_names: set[str], + used_value_names: set[str], seen_values: set[ir.Value], value_counter: list[int], ) -> bool: @@ -173,10 +173,10 @@ def _process_value( modified = False if not value.name: - modified = self._assign_value_name(value, seen_value_names, value_counter) + modified = self._assign_value_name(value, used_value_names, value_counter) else: old_name = value.name - modified = self._fix_duplicate_value_name(value, seen_value_names) + modified = self._fix_duplicate_value_name(value, used_value_names) if modified: assert value.graph is not None if value.is_initializer(): @@ -190,7 +190,7 @@ def _process_value( return modified def _assign_value_name( - self, value: ir.Value, seen_names: set[str], counter: list[int] + self, value: ir.Value, used_names: set[str], counter: list[int] ) -> bool: """Assign a name to an unnamed value. Returns True if modified.""" assert not value.name, ( @@ -198,12 +198,12 @@ def _assign_value_name( ) preferred_name = self._generate_value_name(value) - value.name = _find_and_record_next_unique_name(preferred_name, seen_names, counter) + value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) logger.debug("Assigned name %s to unnamed value", value.name) return True def _assign_node_name( - self, node: ir.Node, seen_names: set[str], counter: list[int] + self, node: ir.Node, used_names: set[str], counter: list[int] ) -> bool: """Assign a name to an unnamed node. Returns True if modified.""" assert not node.name, ( @@ -211,11 +211,11 @@ def _assign_node_name( ) preferred_name = self._generate_node_name(node) - node.name = _find_and_record_next_unique_name(preferred_name, seen_names, counter) + node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) logger.debug("Assigned name %s to unnamed node", node.name) return True - def _fix_duplicate_value_name(self, value: ir.Value, seen_names: set[str]) -> bool: + def _fix_duplicate_value_name(self, value: ir.Value, used_names: set[str]) -> bool: """Fix a value's name if it conflicts with existing names. Returns True if modified.""" original_name = value.name @@ -223,44 +223,44 @@ def _fix_duplicate_value_name(self, value: ir.Value, seen_names: set[str]) -> bo "value should have a name already if function is called correctly" ) - if original_name not in seen_names: + if original_name not in used_names: # Name is unique, just record it - seen_names.add(original_name) + used_names.add(original_name) return False - # If name is already seen, make it unique + # If name is already used, make it unique base_name = self._generate_value_name(value) - value.name = _find_and_record_next_unique_name(base_name, seen_names) + value.name = _find_and_record_next_unique_name(base_name, used_names) logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name) return True - def _fix_duplicate_node_name(self, node: ir.Node, seen_names: set[str]) -> bool: + def _fix_duplicate_node_name(self, node: ir.Node, used_names: set[str]) -> bool: """Fix a node's name if it conflicts with existing names. Returns True if modified.""" original_name = node.name assert original_name, "node should have a name already if function is called correctly" - if original_name not in seen_names: + if original_name not in used_names: # Name is unique, just record it - seen_names.add(original_name) + used_names.add(original_name) return False - # If name is already seen, make it unique + # If name is already used, make it unique base_name = self._generate_node_name(node) - node.name = _find_and_record_next_unique_name(base_name, seen_names) + node.name = _find_and_record_next_unique_name(base_name, used_names) logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name) return True def _find_and_record_next_unique_name( - preferred_name: str, seen_names: set[str], counter: list[int] | None = None + preferred_name: str, used_names: set[str], counter: list[int] | None = None ) -> str: """Generate a unique name based on the preferred name and current counter.""" new_name = preferred_name if counter is None: counter = [0] - while new_name in seen_names: + while new_name in used_names: counter[0] += 1 new_name = f"{preferred_name}_{counter[0]}" - seen_names.add(new_name) + used_names.add(new_name) return new_name From 4e07db2984b8554737af688c0e29604880adb0a3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:47:49 -0700 Subject: [PATCH 19/25] remove logging Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index a79467ec..e79b0049 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -73,9 +73,6 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: if self._fix_graph_names(function): modified = True - if modified: - logger.info("Name fix pass modified the model") - return ir.passes.PassResult(model, modified=modified) def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: From 784552c85d5813a5070755f84831d770a22bf7c6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Jul 2025 11:48:23 -0700 Subject: [PATCH 20/25] docs Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index e79b0049..72fb3337 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -55,7 +55,9 @@ def __init__( Args: generate_node_name: Function to generate a preferred name for a node. + By default, it uses the node's existing name or "node". generate_value_name: Function to generate a preferred name for a value. + By default, it uses the value's existing name or "v". """ super().__init__() self._generate_node_name = generate_node_name From b1e0956d1152f8ef6b936175d6441ab7de268ed0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 31 Jul 2025 11:45:27 -0700 Subject: [PATCH 21/25] Create NameGenerator Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 62 +++++++++++++++-------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 72fb3337..233cc3c5 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -4,10 +4,9 @@ from __future__ import annotations -from typing import Callable - __all__ = [ "NameFixPass", + "NameGenerator", ] import logging @@ -17,6 +16,18 @@ logger = logging.getLogger(__name__) +class NameGenerator: + """Base class for name generation functions.""" + + def generate_node_name(self, node: ir.Node) -> str: + """Generate a preferred name for a node.""" + return node.name or "node" + + def generate_value_name(self, value: ir.Value) -> str: + """Generate a preferred name for a value.""" + return value.name or "v" + + class NameFixPass(ir.passes.InPlacePass): """Pass for fixing names to ensure all values and nodes have unique names. @@ -29,51 +40,44 @@ class NameFixPass(ir.passes.InPlacePass): The pass maintains global uniqueness across the entire model. You can customize the name generation functions for nodes and values by passing - `generate_node_name` and `generate_value_name` parameters to the constructor. + a subclass of :class:`NameGenerator`. For example, you can use a custom naming scheme like this:: - def custom_node_name(node: ir.Node) -> str: - return f"custom_node_{node.op_type}" - def custom_value_name(value: ir.Value) -> str: - return f"custom_value_{value.type}" + class CustomNameGenerator: + def custom_node_name(node: ir.Node) -> str: + return f"custom_node_{node.op_type}" - name_fix_pass = NameFixPass( - generate_node_name=custom_node_name, - generate_value_name=custom_value_name - ) + def custom_value_name(value: ir.Value) -> str: + return f"custom_value_{value.type}" + + name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator()) .. versionadded:: 0.1.5 """ def __init__( self, - generate_node_name: Callable[[ir.Node], str] = lambda n: n.name or "node", - generate_value_name: Callable[[ir.Value], str] = lambda v: v.name or "v", + name_generator: NameGenerator | None = None, ) -> None: """Initialize the NameFixPass with custom name generation functions. Args: - generate_node_name: Function to generate a preferred name for a node. - By default, it uses the node's existing name or "node". - generate_value_name: Function to generate a preferred name for a value. - By default, it uses the value's existing name or "v". + name_generator (NameGenerator, optional): An instance of a subclass of + :class:`NameGenerator` to customize name generation for nodes and values. + If not provided, defaults to a basic implementation that uses + the node's or value's existing name or a generic name like "node" or "v". """ super().__init__() - self._generate_node_name = generate_node_name - self._generate_value_name = generate_value_name + self._name_generator = name_generator or NameGenerator() def call(self, model: ir.Model) -> ir.passes.PassResult: - modified = False - # Process the main graph - if self._fix_graph_names(model.graph): - modified = True + modified = self._fix_graph_names(model.graph) # Process functions for function in model.functions.values(): - if self._fix_graph_names(function): - modified = True + modified = self._fix_graph_names(function) or modified return ir.passes.PassResult(model, modified=modified) @@ -196,7 +200,7 @@ def _assign_value_name( "value should not have a name already if function is called correctly" ) - preferred_name = self._generate_value_name(value) + preferred_name = self._name_generator.generate_value_name(value) value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) logger.debug("Assigned name %s to unnamed value", value.name) return True @@ -209,7 +213,7 @@ def _assign_node_name( "node should not have a name already if function is called correctly" ) - preferred_name = self._generate_node_name(node) + preferred_name = self._name_generator.generate_node_name(node) node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) logger.debug("Assigned name %s to unnamed node", node.name) return True @@ -228,7 +232,7 @@ def _fix_duplicate_value_name(self, value: ir.Value, used_names: set[str]) -> bo return False # If name is already used, make it unique - base_name = self._generate_value_name(value) + base_name = self._name_generator.generate_value_name(value) value.name = _find_and_record_next_unique_name(base_name, used_names) logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name) return True @@ -245,7 +249,7 @@ def _fix_duplicate_node_name(self, node: ir.Node, used_names: set[str]) -> bool: return False # If name is already used, make it unique - base_name = self._generate_node_name(node) + base_name = self._name_generator.generate_node_name(node) node.name = _find_and_record_next_unique_name(base_name, used_names) logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name) return True From bc69b7123308764d522269cff62eb41fe07e9f41 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 8 Aug 2025 09:55:58 -0700 Subject: [PATCH 22/25] Add NameGenerator Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 233cc3c5..90ca0b79 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -7,16 +7,28 @@ __all__ = [ "NameFixPass", "NameGenerator", + "SimpleNameGenerator", ] import logging +from typing import Protocol import onnx_ir as ir logger = logging.getLogger(__name__) -class NameGenerator: +class NameGenerator(Protocol): + def generate_node_name(self, node: ir.Node) -> str: + """Generate a preferred name for a node.""" + ... + + def generate_value_name(self, value: ir.Value) -> str: + """Generate a preferred name for a value.""" + ... + + +class SimpleNameGenerator(NameGenerator): """Base class for name generation functions.""" def generate_node_name(self, node: ir.Node) -> str: @@ -53,7 +65,7 @@ def custom_value_name(value: ir.Value) -> str: name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator()) - .. versionadded:: 0.1.5 + .. versionadded:: 0.1.6 """ def __init__( @@ -69,7 +81,7 @@ def __init__( the node's or value's existing name or a generic name like "node" or "v". """ super().__init__() - self._name_generator = name_generator or NameGenerator() + self._name_generator = name_generator or SimpleNameGenerator() def call(self, model: ir.Model) -> ir.passes.PassResult: # Process the main graph From ff712bc5fabc65ae0cb767829fcf25a5719bd403 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 8 Aug 2025 10:11:10 -0700 Subject: [PATCH 23/25] docs Signed-off-by: Justin Chu --- src/onnx_ir/traversal.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index b4b41802..15f620ca 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -30,6 +30,14 @@ def __init__( ): """Iterate over the nodes in the graph, recursively visiting subgraphs. + This iterator allows for traversing the nodes of a graph, including subgraphs, + in a depth-first manner. It supports optional callbacks for entering and exiting + subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs + contained within nodes. + + .. versionadded:: 0.1.6 + Added the `enter_graph` and `exit_graph` callbacks. + Args: graph_like: The graph to traverse. recursive: A callback that determines whether to recursively visit the subgraphs From 425c469d74642e26eb316401c988ac3cf7a9e49c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 8 Aug 2025 10:11:51 -0700 Subject: [PATCH 24/25] subgraphs Signed-off-by: Justin Chu --- src/onnx_ir/traversal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index 15f620ca..efc9c39f 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -30,7 +30,7 @@ def __init__( ): """Iterate over the nodes in the graph, recursively visiting subgraphs. - This iterator allows for traversing the nodes of a graph, including subgraphs, + This iterator allows for traversing the nodes of a graph and its subgraphs in a depth-first manner. It supports optional callbacks for entering and exiting subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs contained within nodes. From 50bce439203d5a08bd80fc90d3de68a29846dc35 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 8 Aug 2025 14:57:18 -0700 Subject: [PATCH 25/25] Use a counter for all name stems Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 37 ++++++++++++++++------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 90ca0b79..be5469e5 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -10,6 +10,7 @@ "SimpleNameGenerator", ] +import collections import logging from typing import Protocol @@ -106,8 +107,8 @@ def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: scoped_used_node_names: list[set[str]] = [set()] # Counters for generating unique names (using list to pass by reference) - value_counter = [0] - node_counter = [0] + value_counter = collections.Counter() + node_counter = collections.Counter() def enter_graph(graph_like) -> None: """Callback for entering a subgraph.""" @@ -154,7 +155,9 @@ def exit_graph(_) -> None: if self._assign_node_name(node, scoped_used_node_names[-1], node_counter): modified = True else: - if self._fix_duplicate_node_name(node, scoped_used_node_names[-1]): + if self._fix_duplicate_node_name( + node, scoped_used_node_names[-1], node_counter + ): modified = True # Fix input value names (only if not already processed) @@ -179,7 +182,7 @@ def _process_value( value: ir.Value, used_value_names: set[str], seen_values: set[ir.Value], - value_counter: list[int], + value_counter: collections.Counter, ) -> bool: """Process a value only if it hasn't been processed before.""" if value in seen_values: @@ -191,7 +194,7 @@ def _process_value( modified = self._assign_value_name(value, used_value_names, value_counter) else: old_name = value.name - modified = self._fix_duplicate_value_name(value, used_value_names) + modified = self._fix_duplicate_value_name(value, used_value_names, value_counter) if modified: assert value.graph is not None if value.is_initializer(): @@ -205,7 +208,7 @@ def _process_value( return modified def _assign_value_name( - self, value: ir.Value, used_names: set[str], counter: list[int] + self, value: ir.Value, used_names: set[str], counter: collections.Counter ) -> bool: """Assign a name to an unnamed value. Returns True if modified.""" assert not value.name, ( @@ -218,7 +221,7 @@ def _assign_value_name( return True def _assign_node_name( - self, node: ir.Node, used_names: set[str], counter: list[int] + self, node: ir.Node, used_names: set[str], counter: collections.Counter ) -> bool: """Assign a name to an unnamed node. Returns True if modified.""" assert not node.name, ( @@ -230,7 +233,9 @@ def _assign_node_name( logger.debug("Assigned name %s to unnamed node", node.name) return True - def _fix_duplicate_value_name(self, value: ir.Value, used_names: set[str]) -> bool: + def _fix_duplicate_value_name( + self, value: ir.Value, used_names: set[str], counter: collections.Counter + ) -> bool: """Fix a value's name if it conflicts with existing names. Returns True if modified.""" original_name = value.name @@ -245,11 +250,13 @@ def _fix_duplicate_value_name(self, value: ir.Value, used_names: set[str]) -> bo # If name is already used, make it unique base_name = self._name_generator.generate_value_name(value) - value.name = _find_and_record_next_unique_name(base_name, used_names) + value.name = _find_and_record_next_unique_name(base_name, used_names, counter) logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name) return True - def _fix_duplicate_node_name(self, node: ir.Node, used_names: set[str]) -> bool: + def _fix_duplicate_node_name( + self, node: ir.Node, used_names: set[str], counter: collections.Counter + ) -> bool: """Fix a node's name if it conflicts with existing names. Returns True if modified.""" original_name = node.name @@ -262,20 +269,18 @@ def _fix_duplicate_node_name(self, node: ir.Node, used_names: set[str]) -> bool: # If name is already used, make it unique base_name = self._name_generator.generate_node_name(node) - node.name = _find_and_record_next_unique_name(base_name, used_names) + node.name = _find_and_record_next_unique_name(base_name, used_names, counter) logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name) return True def _find_and_record_next_unique_name( - preferred_name: str, used_names: set[str], counter: list[int] | None = None + preferred_name: str, used_names: set[str], counter: collections.Counter ) -> str: """Generate a unique name based on the preferred name and current counter.""" new_name = preferred_name - if counter is None: - counter = [0] while new_name in used_names: - counter[0] += 1 - new_name = f"{preferred_name}_{counter[0]}" + counter[preferred_name] += 1 + new_name = f"{preferred_name}_{counter[preferred_name]}" used_names.add(new_name) return new_name