diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 2aee4df7..a8d73175 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -11,6 +11,7 @@ "InlinePass", "LiftConstantsToInitializersPass", "LiftSubgraphInitializersToMainGraphPass", + "NameFixPass", "RemoveInitializersFromInputsPass", "RemoveUnusedFunctionsPass", "RemoveUnusedNodesPass", @@ -38,6 +39,7 @@ DeduplicateInitializersPass, ) from onnx_ir.passes.common.inliner import InlinePass +from onnx_ir.passes.common.naming import NameFixPass from onnx_ir.passes.common.onnx_checker import CheckerPass from onnx_ir.passes.common.shape_inference import ShapeInferencePass from onnx_ir.passes.common.topological_sort import TopologicalSortPass diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py new file mode 100644 index 00000000..be5469e5 --- /dev/null +++ b/src/onnx_ir/passes/common/naming.py @@ -0,0 +1,286 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Name fix pass for ensuring unique names for all values and nodes.""" + +from __future__ import annotations + +__all__ = [ + "NameFixPass", + "NameGenerator", + "SimpleNameGenerator", +] + +import collections +import logging +from typing import Protocol + +import onnx_ir as ir + +logger = logging.getLogger(__name__) + + +class NameGenerator(Protocol): + def generate_node_name(self, node: ir.Node) -> str: + """Generate a preferred name for a node.""" + ... + + def generate_value_name(self, value: ir.Value) -> str: + """Generate a preferred name for a value.""" + ... + + +class SimpleNameGenerator(NameGenerator): + """Base class for name generation functions.""" + + def generate_node_name(self, node: ir.Node) -> str: + """Generate a preferred name for a node.""" + return node.name or "node" + + def generate_value_name(self, value: ir.Value) -> str: + """Generate a preferred name for a value.""" + return value.name or "v" + + +class NameFixPass(ir.passes.InPlacePass): + """Pass for fixing names to ensure all values and nodes have unique names. + + This pass ensures that: + 1. Graph inputs and outputs have unique names (take precedence) + 2. All intermediate values have unique names (assign names to unnamed values) + 3. All values in subgraphs have unique names within their graph and parent graphs + 4. All nodes have unique names within their graph + + The pass maintains global uniqueness across the entire model. + + You can customize the name generation functions for nodes and values by passing + a subclass of :class:`NameGenerator`. + + For example, you can use a custom naming scheme like this:: + + class CustomNameGenerator: + def custom_node_name(node: ir.Node) -> str: + return f"custom_node_{node.op_type}" + + def custom_value_name(value: ir.Value) -> str: + return f"custom_value_{value.type}" + + name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator()) + + .. versionadded:: 0.1.6 + """ + + def __init__( + self, + name_generator: NameGenerator | None = None, + ) -> None: + """Initialize the NameFixPass with custom name generation functions. + + Args: + name_generator (NameGenerator, optional): An instance of a subclass of + :class:`NameGenerator` to customize name generation for nodes and values. + If not provided, defaults to a basic implementation that uses + the node's or value's existing name or a generic name like "node" or "v". + """ + super().__init__() + self._name_generator = name_generator or SimpleNameGenerator() + + def call(self, model: ir.Model) -> ir.passes.PassResult: + # Process the main graph + modified = self._fix_graph_names(model.graph) + + # Process functions + for function in model.functions.values(): + modified = self._fix_graph_names(function) or modified + + return ir.passes.PassResult(model, modified=modified) + + def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: + """Fix names in a graph and return whether modifications were made.""" + modified = False + + # Set to track which values have been assigned names + seen_values: set[ir.Value] = set() + + # The first set is a dummy placeholder so that there is always a [-1] scope for access + # (even though we don't write to it) + scoped_used_value_names: list[set[str]] = [set()] + scoped_used_node_names: list[set[str]] = [set()] + + # Counters for generating unique names (using list to pass by reference) + value_counter = collections.Counter() + node_counter = collections.Counter() + + def enter_graph(graph_like) -> None: + """Callback for entering a subgraph.""" + # Initialize new scopes with all names from the parent scope + scoped_used_value_names.append(set(scoped_used_value_names[-1])) + scoped_used_node_names.append(set()) + + nonlocal modified + + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph_like.inputs: + if self._process_value( + input_value, scoped_used_value_names[-1], seen_values, value_counter + ): + modified = True + + # Step 2: Fix graph output names (they have precedence) + for output_value in graph_like.outputs: + if self._process_value( + output_value, scoped_used_value_names[-1], seen_values, value_counter + ): + modified = True + + if isinstance(graph_like, ir.Graph): + # For graphs, also fix initializers + for initializer in graph_like.initializers.values(): + if self._process_value( + initializer, scoped_used_value_names[-1], seen_values, value_counter + ): + modified = True + + def exit_graph(_) -> None: + """Callback for exiting a subgraph.""" + # Pop the current scope + scoped_used_value_names.pop() + scoped_used_node_names.pop() + + # Step 3: Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator( + graph_like, enter_graph=enter_graph, exit_graph=exit_graph + ): + # Fix node name + if not node.name: + if self._assign_node_name(node, scoped_used_node_names[-1], node_counter): + modified = True + else: + if self._fix_duplicate_node_name( + node, scoped_used_node_names[-1], node_counter + ): + modified = True + + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if self._process_value( + input_value, scoped_used_value_names[-1], seen_values, value_counter + ): + modified = True + + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if self._process_value( + output_value, scoped_used_value_names[-1], seen_values, value_counter + ): + modified = True + + return modified + + def _process_value( + self, + value: ir.Value, + used_value_names: set[str], + seen_values: set[ir.Value], + value_counter: collections.Counter, + ) -> bool: + """Process a value only if it hasn't been processed before.""" + if value in seen_values: + return False + + modified = False + + if not value.name: + modified = self._assign_value_name(value, used_value_names, value_counter) + else: + old_name = value.name + modified = self._fix_duplicate_value_name(value, used_value_names, value_counter) + if modified: + assert value.graph is not None + if value.is_initializer(): + value.graph.initializers.pop(old_name) + # Add the initializer back with the new name + value.graph.initializers.add(value) + + # Record the final name for this value + assert value.name is not None + seen_values.add(value) + return modified + + def _assign_value_name( + self, value: ir.Value, used_names: set[str], counter: collections.Counter + ) -> bool: + """Assign a name to an unnamed value. Returns True if modified.""" + assert not value.name, ( + "value should not have a name already if function is called correctly" + ) + + preferred_name = self._name_generator.generate_value_name(value) + value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) + logger.debug("Assigned name %s to unnamed value", value.name) + return True + + def _assign_node_name( + self, node: ir.Node, used_names: set[str], counter: collections.Counter + ) -> bool: + """Assign a name to an unnamed node. Returns True if modified.""" + assert not node.name, ( + "node should not have a name already if function is called correctly" + ) + + preferred_name = self._name_generator.generate_node_name(node) + node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) + logger.debug("Assigned name %s to unnamed node", node.name) + return True + + def _fix_duplicate_value_name( + self, value: ir.Value, used_names: set[str], counter: collections.Counter + ) -> bool: + """Fix a value's name if it conflicts with existing names. Returns True if modified.""" + original_name = value.name + + assert original_name, ( + "value should have a name already if function is called correctly" + ) + + if original_name not in used_names: + # Name is unique, just record it + used_names.add(original_name) + return False + + # If name is already used, make it unique + base_name = self._name_generator.generate_value_name(value) + value.name = _find_and_record_next_unique_name(base_name, used_names, counter) + logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name) + return True + + def _fix_duplicate_node_name( + self, node: ir.Node, used_names: set[str], counter: collections.Counter + ) -> bool: + """Fix a node's name if it conflicts with existing names. Returns True if modified.""" + original_name = node.name + + assert original_name, "node should have a name already if function is called correctly" + + if original_name not in used_names: + # Name is unique, just record it + used_names.add(original_name) + return False + + # If name is already used, make it unique + base_name = self._name_generator.generate_node_name(node) + node.name = _find_and_record_next_unique_name(base_name, used_names, counter) + logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name) + return True + + +def _find_and_record_next_unique_name( + preferred_name: str, used_names: set[str], counter: collections.Counter +) -> str: + """Generate a unique name based on the preferred name and current counter.""" + new_name = preferred_name + while new_name in used_names: + counter[preferred_name] += 1 + new_name = f"{preferred_name}_{counter[preferred_name]}" + used_names.add(new_name) + return new_name diff --git a/src/onnx_ir/passes/common/naming_test.py b/src/onnx_ir/passes/common/naming_test.py new file mode 100644 index 00000000..2c44c810 --- /dev/null +++ b/src/onnx_ir/passes/common/naming_test.py @@ -0,0 +1,405 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the name fix pass.""" + +from __future__ import annotations + +import unittest + +import onnx_ir as ir +from onnx_ir.passes.common import naming + + +class TestNameFixPass(unittest.TestCase): + """Test cases for NameFixPass.""" + + def test_assign_names_to_unnamed_values(self): + """Test ensuring all values have names even if IR auto-assigned them.""" + # Create a simple model with auto-assigned names + input_value = ir.Input( + None, shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) # Will get auto-assigned name when added to graph + + # Create Add node + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify IR has auto-assigned names + self.assertIsNotNone(input_value.name) + self.assertIsNotNone(add_node.outputs[0].name) + + # Store original names + original_input_name = input_value.name + original_output_name = add_node.outputs[0].name + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass didn't modify anything (names were already assigned and unique) + self.assertFalse(result.modified) + + # Verify names remain the same + self.assertEqual(input_value.name, original_input_name) + self.assertEqual(add_node.outputs[0].name, original_output_name) + + def test_assign_names_to_unnamed_nodes(self): + """Test ensuring all nodes have names even if IR auto-assigned them.""" + # Create a simple model + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Add node - IR will auto-assign name when added to graph + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify IR has auto-assigned node name + self.assertIsNotNone(add_node.name) + original_node_name = add_node.name + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass didn't modify anything (node already had unique name) + self.assertFalse(result.modified) + + # Verify node name remains the same + self.assertEqual(add_node.name, original_node_name) + + def test_assigns_names_when_truly_unnamed(self): + """Test that the pass assigns names when values/nodes are created without names and manually cleared.""" + # Create a model and manually clear names to test assignment + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Manually clear some names to test assignment + add_node.name = None + add_node.outputs[0].name = "" + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names were assigned + self.assertIsNotNone(add_node.name) + self.assertIsNotNone(add_node.outputs[0].name) + self.assertNotEqual(add_node.outputs[0].name, "") + + def test_handles_global_uniqueness_across_subgraphs(self): + """Test that names are unique globally, including across subgraphs.""" + # Create main graph input + main_input = ir.Input( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a simple subgraph for an If node + # Subgraph input and output (with potential name conflicts) + sub_input = ir.Input( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) # Same name as main input - should cause conflict + + sub_add_node = ir.Node("", "Add", inputs=[sub_input, sub_input]) + sub_add_node.outputs[0].name = "main_input" # Another conflict + sub_add_node.outputs[0].shape = sub_input.shape + sub_add_node.outputs[0].type = sub_input.type + + subgraph = ir.Graph( + inputs=[sub_input], + outputs=[sub_add_node.outputs[0]], + nodes=[sub_add_node], + name="subgraph", + ) + + # Create condition input for If node + condition_input = ir.Input( + "condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) + + # Create If node with subgraph + if_node = ir.Node( + "", + "If", + inputs=[condition_input], + attributes={ + "then_branch": ir.Attr("then_branch", ir.AttributeType.GRAPH, subgraph) + }, + ) + if_node.outputs[0].name = "if_output" + if_node.outputs[0].shape = main_input.shape + if_node.outputs[0].type = main_input.type + + # Create main graph + main_graph = ir.Graph( + inputs=[main_input, condition_input], + outputs=[if_node.outputs[0]], + nodes=[if_node], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10) + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass was applied (should fix duplicates) + self.assertTrue(result.modified) + + # Collect all value names to verify uniqueness + all_value_names = set() + + # Main graph values + for input_val in main_graph.inputs: + self.assertIsNotNone(input_val.name) + self.assertNotIn( + input_val.name, all_value_names, f"Duplicate value name: {input_val.name}" + ) + all_value_names.add(input_val.name) + + for output_val in main_graph.outputs: + self.assertIsNotNone(output_val.name) + if output_val.name not in all_value_names: # Could be same as input + all_value_names.add(output_val.name) + + # Node values in main graph + for node in main_graph: + for input_val in node.inputs: + if input_val is not None: + if input_val.name not in all_value_names: + all_value_names.add(input_val.name) + for output_val in node.outputs: + if output_val.name not in all_value_names: + all_value_names.add(output_val.name) + + # Subgraph values + for input_val in subgraph.inputs: + self.assertIsNotNone(input_val.name) + self.assertNotIn( + input_val.name, + all_value_names, + f"Duplicate value name in subgraph: {input_val.name}", + ) + all_value_names.add(input_val.name) + + for output_val in subgraph.outputs: + if output_val.name not in all_value_names: # Could be same as input + all_value_names.add(output_val.name) + + # Node values in subgraph + for node in subgraph: + for input_val in node.inputs: + if input_val is not None: + if input_val.name not in all_value_names: + all_value_names.add(input_val.name) + for output_val in node.outputs: + if output_val.name not in all_value_names: + all_value_names.add(output_val.name) + + # Verify main_input keeps its name (has precedence as graph input) + self.assertEqual(main_input.name, "main_input") + + def test_handle_duplicate_value_names(self): + """Test handling duplicate value names by making them unique.""" + # Create values with duplicate names + input1 = ir.Input( + "duplicate_name", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.Input( + "duplicate_name", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input1, input2]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input1.shape + add_node.outputs[0].type = input1.type + + graph = ir.Graph( + inputs=[input1, input2], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify both inputs have the same name initially + self.assertEqual(input1.name, "duplicate_name") + self.assertEqual(input2.name, "duplicate_name") + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names are now unique + self.assertNotEqual(input1.name, input2.name) + # One should keep the original name, the other should have a suffix + names = {input1.name, input2.name} + self.assertIn("duplicate_name", names) + self.assertTrue("duplicate_name_1" in names, f"Expected 'duplicate_name_1' in {names}") + + def test_handle_duplicate_node_names(self): + """Test handling duplicate node names by making them unique.""" + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create nodes with duplicate names + add_node1 = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node1.name = "duplicate_node" + add_node1.outputs[0].name = "output1" + add_node1.outputs[0].shape = input_value.shape + add_node1.outputs[0].type = input_value.type + + add_node2 = ir.Node("", "Add", inputs=[input_value, add_node1.outputs[0]]) + add_node2.name = "duplicate_node" # Same name as first node + add_node2.outputs[0].name = "output2" + add_node2.outputs[0].shape = input_value.shape + add_node2.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node2.outputs[0]], + nodes=[add_node1, add_node2], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify both nodes have the same name initially + self.assertEqual(add_node1.name, "duplicate_node") + self.assertEqual(add_node2.name, "duplicate_node") + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names are now unique + self.assertNotEqual(add_node1.name, add_node2.name) + # One should keep the original name, the other should have a suffix + names = {add_node1.name, add_node2.name} + self.assertIn("duplicate_node", names) + self.assertTrue("duplicate_node_1" in names, f"Expected 'duplicate_node_1' in {names}") + + def test_no_modification_when_all_names_unique(self): + """Test that the pass doesn't modify anything when all names are already unique.""" + input_value = ir.Input( + "unique_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.name = "unique_node" + add_node.outputs[0].name = "unique_output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Store original names + original_input_name = input_value.name + original_node_name = add_node.name + original_output_name = add_node.outputs[0].name + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass didn't modify anything + self.assertFalse(result.modified) + + # Verify names remain unchanged + self.assertEqual(input_value.name, original_input_name) + self.assertEqual(add_node.name, original_node_name) + self.assertEqual(add_node.outputs[0].name, original_output_name) + + def test_graph_inputs_outputs_have_precedence(self): + """Test that graph inputs and outputs keep their names when there are conflicts.""" + # Create an input with a specific name + input_value = ir.Input( + "important_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a node that produces an intermediate value with the same name + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "important_input" # Conflicts with input name + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + # Create another node that uses the intermediate value and produces the final output + mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], input_value]) + mul_node.outputs[0].name = "important_output" + mul_node.outputs[0].shape = input_value.shape + mul_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[mul_node.outputs[0]], + nodes=[add_node, mul_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + result = naming.NameFixPass()(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify input keeps its original name (has precedence) + self.assertEqual(input_value.name, "important_input") + + # Verify output keeps its original name (has precedence) + self.assertEqual(mul_node.outputs[0].name, "important_output") + + # Verify intermediate value got renamed to avoid conflict + self.assertNotEqual(add_node.outputs[0].name, "important_input") + self.assertTrue(add_node.outputs[0].name.startswith("important_input_")) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index 26c4e008..efc9c39f 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -25,19 +25,33 @@ def __init__( *, recursive: Callable[[_core.Node], bool] | None = None, reverse: bool = False, + enter_graph: Callable[[GraphLike], None] | None = None, + exit_graph: Callable[[GraphLike], None] | None = None, ): """Iterate over the nodes in the graph, recursively visiting subgraphs. + This iterator allows for traversing the nodes of a graph and its subgraphs + in a depth-first manner. It supports optional callbacks for entering and exiting + subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs + contained within nodes. + + .. versionadded:: 0.1.6 + Added the `enter_graph` and `exit_graph` callbacks. + Args: graph_like: The graph to traverse. recursive: A callback that determines whether to recursively visit the subgraphs contained in a node. If not provided, all nodes in subgraphs are visited. reverse: Whether to iterate in reverse order. + enter_graph: An optional callback that is called when entering a subgraph. + exit_graph: An optional callback that is called when exiting a subgraph. """ self._graph = graph_like self._recursive = recursive self._reverse = reverse self._iterator = self._recursive_node_iter(graph_like) + self._enter_graph = enter_graph + self._exit_graph = exit_graph def __iter__(self) -> Self: self._iterator = self._recursive_node_iter(self._graph) @@ -50,34 +64,55 @@ def _recursive_node_iter( self, graph: _core.Graph | _core.Function | _core.GraphView ) -> Iterator[_core.Node]: iterable = reversed(graph) if self._reverse else graph + + if self._enter_graph is not None: + self._enter_graph(graph) + for node in iterable: # type: ignore[union-attr] yield node if self._recursive is not None and not self._recursive(node): continue yield from self._iterate_subgraphs(node) + if self._exit_graph is not None: + self._exit_graph(graph) + def _iterate_subgraphs(self, node: _core.Node): for attr in node.attributes.values(): if not isinstance(attr, _core.Attr): continue if attr.type == _enums.AttributeType.GRAPH: + if self._enter_graph is not None: + self._enter_graph(attr.value) yield from RecursiveGraphIterator( attr.value, recursive=self._recursive, reverse=self._reverse, + enter_graph=self._enter_graph, + exit_graph=self._exit_graph, ) + if self._exit_graph is not None: + self._exit_graph(attr.value) elif attr.type == _enums.AttributeType.GRAPHS: graphs = reversed(attr.value) if self._reverse else attr.value for graph in graphs: + if self._enter_graph is not None: + self._enter_graph(graph) yield from RecursiveGraphIterator( graph, recursive=self._recursive, reverse=self._reverse, + enter_graph=self._enter_graph, + exit_graph=self._exit_graph, ) + if self._exit_graph is not None: + self._exit_graph(graph) def __reversed__(self) -> Iterator[_core.Node]: return RecursiveGraphIterator( self._graph, recursive=self._recursive, reverse=not self._reverse, + enter_graph=self._enter_graph, + exit_graph=self._exit_graph, )