Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7aa2bb6
Initial plan
Copilot Jul 11, 2025
1f75858
Implement NameFixPass for ensuring unique names
Copilot Jul 11, 2025
b979a3d
Replace processed values set with value-to-name mapping for clearer t…
Copilot Jul 11, 2025
3cbe736
wip
justinchuby Jul 14, 2025
b9cf0b6
Merge branch 'main' into copilot/fix-123
justinchuby Jul 28, 2025
42efd78
Simplify implementation
justinchuby Jul 28, 2025
6977a1f
format test
justinchuby Jul 28, 2025
212c1a8
Handle initializers
justinchuby Jul 28, 2025
463d500
fix
justinchuby Jul 28, 2025
b6be30a
Create callback
justinchuby Jul 30, 2025
f01fea9
update
justinchuby Jul 30, 2025
3b7fa60
scope
justinchuby Jul 30, 2025
7b49ab8
Support _generate_node_name
justinchuby Jul 30, 2025
291b21c
lint
justinchuby Jul 30, 2025
4c35f2b
versionadded and docs
justinchuby Jul 30, 2025
acd6d8e
Merge branch 'main' into copilot/fix-123
justinchuby Jul 30, 2025
da3ea58
_generate_value_name
justinchuby Jul 30, 2025
5b6c7c9
refactor unique name finding
justinchuby Jul 30, 2025
2efce88
Use a set
justinchuby Jul 30, 2025
797a368
seen -> used
justinchuby Jul 30, 2025
4e07db2
remove logging
justinchuby Jul 30, 2025
784552c
docs
justinchuby Jul 30, 2025
b1e0956
Create NameGenerator
justinchuby Jul 31, 2025
bc69b71
Add NameGenerator
justinchuby Aug 8, 2025
5a498a8
Merge branch 'main' into copilot/fix-123
justinchuby Aug 8, 2025
ff712bc
docs
justinchuby Aug 8, 2025
425c469
subgraphs
justinchuby Aug 8, 2025
50bce43
Use a counter for all name stems
justinchuby Aug 8, 2025
8f213bf
Merge branch 'main' into copilot/fix-123
justinchuby Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/onnx_ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
"NameFixPass",
"RemoveInitializersFromInputsPass",
"RemoveUnusedFunctionsPass",
"RemoveUnusedNodesPass",
Expand Down Expand Up @@ -38,6 +39,7 @@
DeduplicateInitializersPass,
)
from onnx_ir.passes.common.inliner import InlinePass
from onnx_ir.passes.common.naming import NameFixPass
from onnx_ir.passes.common.onnx_checker import CheckerPass
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
Expand Down
286 changes: 286 additions & 0 deletions src/onnx_ir/passes/common/naming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Name fix pass for ensuring unique names for all values and nodes."""

from __future__ import annotations

__all__ = [
"NameFixPass",
"NameGenerator",
"SimpleNameGenerator",
]

import collections
import logging
from typing import Protocol

import onnx_ir as ir

logger = logging.getLogger(__name__)


class NameGenerator(Protocol):
def generate_node_name(self, node: ir.Node) -> str:
"""Generate a preferred name for a node."""
...

def generate_value_name(self, value: ir.Value) -> str:
"""Generate a preferred name for a value."""
...


class SimpleNameGenerator(NameGenerator):
"""Base class for name generation functions."""

def generate_node_name(self, node: ir.Node) -> str:
"""Generate a preferred name for a node."""
return node.name or "node"

def generate_value_name(self, value: ir.Value) -> str:
"""Generate a preferred name for a value."""
return value.name or "v"


class NameFixPass(ir.passes.InPlacePass):
"""Pass for fixing names to ensure all values and nodes have unique names.

This pass ensures that:
1. Graph inputs and outputs have unique names (take precedence)
2. All intermediate values have unique names (assign names to unnamed values)
3. All values in subgraphs have unique names within their graph and parent graphs
4. All nodes have unique names within their graph

The pass maintains global uniqueness across the entire model.

You can customize the name generation functions for nodes and values by passing
a subclass of :class:`NameGenerator`.

For example, you can use a custom naming scheme like this::

class CustomNameGenerator:
def custom_node_name(node: ir.Node) -> str:
return f"custom_node_{node.op_type}"

def custom_value_name(value: ir.Value) -> str:
return f"custom_value_{value.type}"

name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator())

.. versionadded:: 0.1.6
"""

def __init__(
self,
name_generator: NameGenerator | None = None,
) -> None:
"""Initialize the NameFixPass with custom name generation functions.

Args:
name_generator (NameGenerator, optional): An instance of a subclass of
:class:`NameGenerator` to customize name generation for nodes and values.
If not provided, defaults to a basic implementation that uses
the node's or value's existing name or a generic name like "node" or "v".
"""
super().__init__()
self._name_generator = name_generator or SimpleNameGenerator()

def call(self, model: ir.Model) -> ir.passes.PassResult:
# Process the main graph
modified = self._fix_graph_names(model.graph)

# Process functions
for function in model.functions.values():
modified = self._fix_graph_names(function) or modified

return ir.passes.PassResult(model, modified=modified)

def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
"""Fix names in a graph and return whether modifications were made."""
modified = False

# Set to track which values have been assigned names
seen_values: set[ir.Value] = set()

# The first set is a dummy placeholder so that there is always a [-1] scope for access
# (even though we don't write to it)
scoped_used_value_names: list[set[str]] = [set()]
scoped_used_node_names: list[set[str]] = [set()]

# Counters for generating unique names (using list to pass by reference)
value_counter = collections.Counter()
node_counter = collections.Counter()

def enter_graph(graph_like) -> None:
"""Callback for entering a subgraph."""
# Initialize new scopes with all names from the parent scope
scoped_used_value_names.append(set(scoped_used_value_names[-1]))
scoped_used_node_names.append(set())

nonlocal modified

# Step 1: Fix graph input names first (they have precedence)
for input_value in graph_like.inputs:
if self._process_value(
input_value, scoped_used_value_names[-1], seen_values, value_counter
):
modified = True

# Step 2: Fix graph output names (they have precedence)
for output_value in graph_like.outputs:
if self._process_value(
output_value, scoped_used_value_names[-1], seen_values, value_counter
):
modified = True

if isinstance(graph_like, ir.Graph):
# For graphs, also fix initializers
for initializer in graph_like.initializers.values():
if self._process_value(
initializer, scoped_used_value_names[-1], seen_values, value_counter
):
modified = True

def exit_graph(_) -> None:
"""Callback for exiting a subgraph."""
# Pop the current scope
scoped_used_value_names.pop()
scoped_used_node_names.pop()

# Step 3: Process all nodes and their values
for node in ir.traversal.RecursiveGraphIterator(
graph_like, enter_graph=enter_graph, exit_graph=exit_graph
):
# Fix node name
if not node.name:
if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
modified = True
else:
if self._fix_duplicate_node_name(
node, scoped_used_node_names[-1], node_counter
):
modified = True

# Fix input value names (only if not already processed)
for input_value in node.inputs:
if input_value is not None:
if self._process_value(
input_value, scoped_used_value_names[-1], seen_values, value_counter
):
modified = True

# Fix output value names (only if not already processed)
for output_value in node.outputs:
if self._process_value(
output_value, scoped_used_value_names[-1], seen_values, value_counter
):
modified = True

return modified

def _process_value(
self,
value: ir.Value,
used_value_names: set[str],
seen_values: set[ir.Value],
value_counter: collections.Counter,
) -> bool:
"""Process a value only if it hasn't been processed before."""
if value in seen_values:
return False

modified = False

if not value.name:
modified = self._assign_value_name(value, used_value_names, value_counter)
else:
old_name = value.name
modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
if modified:
assert value.graph is not None
if value.is_initializer():
value.graph.initializers.pop(old_name)
# Add the initializer back with the new name
value.graph.initializers.add(value)

# Record the final name for this value
assert value.name is not None
seen_values.add(value)
return modified

def _assign_value_name(
self, value: ir.Value, used_names: set[str], counter: collections.Counter
) -> bool:
"""Assign a name to an unnamed value. Returns True if modified."""
assert not value.name, (
"value should not have a name already if function is called correctly"
)

preferred_name = self._name_generator.generate_value_name(value)
value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
logger.debug("Assigned name %s to unnamed value", value.name)
return True

def _assign_node_name(
self, node: ir.Node, used_names: set[str], counter: collections.Counter
) -> bool:
"""Assign a name to an unnamed node. Returns True if modified."""
assert not node.name, (
"node should not have a name already if function is called correctly"
)

preferred_name = self._name_generator.generate_node_name(node)
node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
logger.debug("Assigned name %s to unnamed node", node.name)
return True

def _fix_duplicate_value_name(
self, value: ir.Value, used_names: set[str], counter: collections.Counter
) -> bool:
"""Fix a value's name if it conflicts with existing names. Returns True if modified."""
original_name = value.name

assert original_name, (
"value should have a name already if function is called correctly"
)

if original_name not in used_names:
# Name is unique, just record it
used_names.add(original_name)
return False

# If name is already used, make it unique
base_name = self._name_generator.generate_value_name(value)
value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
return True

def _fix_duplicate_node_name(
self, node: ir.Node, used_names: set[str], counter: collections.Counter
) -> bool:
"""Fix a node's name if it conflicts with existing names. Returns True if modified."""
original_name = node.name

assert original_name, "node should have a name already if function is called correctly"

if original_name not in used_names:
# Name is unique, just record it
used_names.add(original_name)
return False

# If name is already used, make it unique
base_name = self._name_generator.generate_node_name(node)
node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
return True


def _find_and_record_next_unique_name(
preferred_name: str, used_names: set[str], counter: collections.Counter
) -> str:
"""Generate a unique name based on the preferred name and current counter."""
new_name = preferred_name
while new_name in used_names:
counter[preferred_name] += 1
new_name = f"{preferred_name}_{counter[preferred_name]}"
used_names.add(new_name)
return new_name
Loading
Loading