Create a name fix pass to ensure unique names for all values and nodes #124
Merged
29 commits
7aa2bb6 Initial plan (Copilot)
1f75858 Implement NameFixPass for ensuring unique names (Copilot)
b979a3d Replace processed values set with value-to-name mapping for clearer t… (Copilot)
3cbe736 wip (justinchuby)
b9cf0b6 Merge branch 'main' into copilot/fix-123 (justinchuby)
42efd78 Simplify implementation (justinchuby)
6977a1f format test (justinchuby)
212c1a8 Handle initializers (justinchuby)
463d500 fix (justinchuby)
b6be30a Create callback (justinchuby)
f01fea9 update (justinchuby)
3b7fa60 scope (justinchuby)
7b49ab8 Support _generate_node_name (justinchuby)
291b21c lint (justinchuby)
4c35f2b versionadded and docs (justinchuby)
acd6d8e Merge branch 'main' into copilot/fix-123 (justinchuby)
da3ea58 _generate_value_name (justinchuby)
5b6c7c9 refactor unique name finding (justinchuby)
2efce88 Use a set (justinchuby)
797a368 seen -> used (justinchuby)
4e07db2 remove logging (justinchuby)
784552c docs (justinchuby)
b1e0956 Create NameGenerator (justinchuby)
bc69b71 Add NameGenerator (justinchuby)
5a498a8 Merge branch 'main' into copilot/fix-123 (justinchuby)
ff712bc docs (justinchuby)
425c469 subgraphs (justinchuby)
50bce43 Use a counter for all name stems (justinchuby)
8f213bf Merge branch 'main' into copilot/fix-123 (justinchuby)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
# Copyright (c) ONNX Project Contributors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Name fix pass for ensuring unique names for all values and nodes.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"NameFixPass", | ||
"NameGenerator", | ||
"SimpleNameGenerator", | ||
] | ||
|
||
import collections | ||
import logging | ||
from typing import Protocol | ||
|
||
import onnx_ir as ir | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class NameGenerator(Protocol): | ||
def generate_node_name(self, node: ir.Node) -> str: | ||
"""Generate a preferred name for a node.""" | ||
... | ||
|
||
def generate_value_name(self, value: ir.Value) -> str: | ||
"""Generate a preferred name for a value.""" | ||
... | ||
|
||
|
||
class SimpleNameGenerator(NameGenerator): | ||
"""Default name generator: keeps an existing name, otherwise falls back to "node" or "v".""" | |
|
||
def generate_node_name(self, node: ir.Node) -> str: | ||
"""Generate a preferred name for a node.""" | ||
return node.name or "node" | ||
|
||
def generate_value_name(self, value: ir.Value) -> str: | ||
"""Generate a preferred name for a value.""" | ||
return value.name or "v" | ||
|
||
|
||
class NameFixPass(ir.passes.InPlacePass): | ||
"""Pass for fixing names to ensure all values and nodes have unique names. | ||
|
||
This pass ensures that: | ||
1. Graph inputs and outputs have unique names (take precedence) | ||
2. All intermediate values have unique names (assign names to unnamed values) | ||
3. All values in subgraphs have unique names within their graph and parent graphs | ||
4. All nodes have unique names within their graph | ||
|
||
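Uniqueness is scoped rather than global: value names are unique within a graph and | |
its enclosing graphs, node names are unique within each graph, and the main graph | |
and each function are processed independently. | |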
|
||
You can customize name generation for nodes and values by passing an object | |
that implements the :class:`NameGenerator` protocol. | |
|
||
For example, you can use a custom naming scheme like this:: | |
|
||
class CustomNameGenerator: | |
def generate_node_name(self, node: ir.Node) -> str: | |
return f"custom_node_{node.op_type}" | |
|
||
def generate_value_name(self, value: ir.Value) -> str: | |
return f"custom_value_{value.type}" | |
|
||
name_fix_pass = NameFixPass(name_generator=CustomNameGenerator()) | |
|
||
.. versionadded:: 0.1.6 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name_generator: NameGenerator | None = None, | ||
) -> None: | ||
"""Initialize the NameFixPass with custom name generation functions. | ||
|
||
Args: | ||
name_generator (NameGenerator, optional): An object implementing the | |
:class:`NameGenerator` protocol to customize name generation for nodes and values. | |
If not provided, defaults to a basic implementation that uses | ||
the node's or value's existing name or a generic name like "node" or "v". | ||
""" | ||
super().__init__() | ||
self._name_generator = name_generator or SimpleNameGenerator() | ||
|
||
def call(self, model: ir.Model) -> ir.passes.PassResult: | ||
# Process the main graph | ||
modified = self._fix_graph_names(model.graph) | ||
|
||
# Process functions | ||
for function in model.functions.values(): | ||
modified = self._fix_graph_names(function) or modified | ||
|
||
return ir.passes.PassResult(model, modified=modified) | ||
|
||
def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool: | ||
"""Fix names in a graph and return whether modifications were made.""" | ||
modified = False | ||
|
||
# Set to track which values have been assigned names | ||
seen_values: set[ir.Value] = set() | ||
|
||
# The first set is a dummy placeholder so that there is always a [-1] scope for access | ||
# (even though we don't write to it) | ||
scoped_used_value_names: list[set[str]] = [set()] | ||
scoped_used_node_names: list[set[str]] = [set()] | ||
|
||
# Per-stem counters for generating unique names (shared across all scopes) | |
value_counter = collections.Counter() | ||
node_counter = collections.Counter() | ||
|
||
def enter_graph(graph_like) -> None: | ||
"""Callback for entering a subgraph.""" | ||
# Initialize new scopes with all names from the parent scope | ||
scoped_used_value_names.append(set(scoped_used_value_names[-1])) | ||
scoped_used_node_names.append(set()) | ||
|
||
nonlocal modified | ||
|
||
# Step 1: Fix graph input names first (they have precedence) | ||
for input_value in graph_like.inputs: | ||
if self._process_value( | ||
input_value, scoped_used_value_names[-1], seen_values, value_counter | ||
): | ||
modified = True | ||
|
||
# Step 2: Fix graph output names (they have precedence) | ||
for output_value in graph_like.outputs: | ||
if self._process_value( | ||
output_value, scoped_used_value_names[-1], seen_values, value_counter | ||
): | ||
modified = True | ||
|
||
if isinstance(graph_like, ir.Graph): | ||
# For graphs, also fix initializers | ||
for initializer in graph_like.initializers.values(): | ||
if self._process_value( | ||
initializer, scoped_used_value_names[-1], seen_values, value_counter | ||
): | ||
modified = True | ||
|
||
def exit_graph(_) -> None: | ||
"""Callback for exiting a subgraph.""" | ||
# Pop the current scope | ||
scoped_used_value_names.pop() | ||
scoped_used_node_names.pop() | ||
|
||
# Step 3: Process all nodes and their values | ||
for node in ir.traversal.RecursiveGraphIterator( | ||
graph_like, enter_graph=enter_graph, exit_graph=exit_graph | ||
): | ||
# Fix node name | ||
if not node.name: | ||
if self._assign_node_name(node, scoped_used_node_names[-1], node_counter): | ||
modified = True | ||
else: | ||
if self._fix_duplicate_node_name( | ||
node, scoped_used_node_names[-1], node_counter | ||
): | ||
modified = True | ||
|
||
# Fix input value names (only if not already processed) | ||
for input_value in node.inputs: | ||
if input_value is not None: | ||
if self._process_value( | ||
input_value, scoped_used_value_names[-1], seen_values, value_counter | ||
): | ||
modified = True | ||
|
||
# Fix output value names (only if not already processed) | ||
for output_value in node.outputs: | ||
if self._process_value( | ||
output_value, scoped_used_value_names[-1], seen_values, value_counter | ||
): | ||
modified = True | ||
|
||
return modified | ||
|
||
def _process_value( | ||
self, | ||
value: ir.Value, | ||
used_value_names: set[str], | ||
seen_values: set[ir.Value], | ||
value_counter: collections.Counter, | ||
) -> bool: | ||
"""Process a value only if it hasn't been processed before.""" | ||
if value in seen_values: | ||
return False | ||
|
||
modified = False | ||
|
||
if not value.name: | ||
modified = self._assign_value_name(value, used_value_names, value_counter) | ||
else: | ||
old_name = value.name | ||
modified = self._fix_duplicate_value_name(value, used_value_names, value_counter) | ||
if modified: | ||
assert value.graph is not None | ||
if value.is_initializer(): | ||
value.graph.initializers.pop(old_name) | ||
# Add the initializer back with the new name | ||
value.graph.initializers.add(value) | ||
|
||
# Record the final name for this value | ||
assert value.name is not None | ||
seen_values.add(value) | ||
return modified | ||
|
||
def _assign_value_name( | ||
self, value: ir.Value, used_names: set[str], counter: collections.Counter | ||
) -> bool: | ||
"""Assign a name to an unnamed value. Returns True if modified.""" | ||
assert not value.name, ( | ||
"value should not have a name already if function is called correctly" | ||
) | ||
|
||
preferred_name = self._name_generator.generate_value_name(value) | ||
value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) | ||
logger.debug("Assigned name %s to unnamed value", value.name) | ||
return True | ||
|
||
def _assign_node_name( | ||
self, node: ir.Node, used_names: set[str], counter: collections.Counter | ||
) -> bool: | ||
"""Assign a name to an unnamed node. Returns True if modified.""" | ||
assert not node.name, ( | ||
"node should not have a name already if function is called correctly" | ||
) | ||
|
||
preferred_name = self._name_generator.generate_node_name(node) | ||
node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter) | ||
logger.debug("Assigned name %s to unnamed node", node.name) | ||
return True | ||
|
||
def _fix_duplicate_value_name( | ||
self, value: ir.Value, used_names: set[str], counter: collections.Counter | ||
) -> bool: | ||
"""Fix a value's name if it conflicts with existing names. Returns True if modified.""" | ||
original_name = value.name | ||
|
||
assert original_name, ( | ||
"value should have a name already if function is called correctly" | ||
) | ||
|
||
if original_name not in used_names: | ||
# Name is unique, just record it | ||
used_names.add(original_name) | ||
return False | ||
|
||
# If name is already used, make it unique | ||
base_name = self._name_generator.generate_value_name(value) | ||
value.name = _find_and_record_next_unique_name(base_name, used_names, counter) | ||
logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name) | ||
return True | ||
|
||
def _fix_duplicate_node_name( | ||
self, node: ir.Node, used_names: set[str], counter: collections.Counter | ||
) -> bool: | ||
"""Fix a node's name if it conflicts with existing names. Returns True if modified.""" | ||
original_name = node.name | ||
|
||
assert original_name, "node should have a name already if function is called correctly" | ||
|
||
if original_name not in used_names: | ||
# Name is unique, just record it | ||
used_names.add(original_name) | ||
return False | ||
|
||
# If name is already used, make it unique | ||
base_name = self._name_generator.generate_node_name(node) | ||
node.name = _find_and_record_next_unique_name(base_name, used_names, counter) | ||
logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name) | ||
return True | ||
|
||
|
||
def _find_and_record_next_unique_name( | ||
preferred_name: str, used_names: set[str], counter: collections.Counter | ||
) -> str: | ||
"""Generate a unique name based on the preferred name and current counter.""" | ||
new_name = preferred_name | ||
while new_name in used_names: | ||
counter[preferred_name] += 1 | ||
new_name = f"{preferred_name}_{counter[preferred_name]}" | ||
used_names.add(new_name) | ||
return new_name |
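The renaming behavior in the diff above can be tried in isolation. The following is a minimal standalone sketch, not part of the PR: it copies the logic of `_find_and_record_next_unique_name` under a public name and uses plain string sets in place of `onnx_ir` graphs. The names `find_and_record_next_unique_name`, `parent_scope`, and `child_scope` are introduced here for illustration only.

```python
from __future__ import annotations

import collections


def find_and_record_next_unique_name(
    preferred_name: str, used_names: set[str], counter: collections.Counter
) -> str:
    """Return preferred_name, or preferred_name_<n> for the next free n."""
    new_name = preferred_name
    while new_name in used_names:
        # The counter is keyed by the name stem and only ever increases,
        # so a suffix that was handed out once is never tried again.
        counter[preferred_name] += 1
        new_name = f"{preferred_name}_{counter[preferred_name]}"
    used_names.add(new_name)
    return new_name


# One counter is shared by every scope, as in _fix_graph_names.
counter: collections.Counter = collections.Counter()

# Hypothetical parent graph scope that already contains "v" and "v_1".
parent_scope: set[str] = {"v", "v_1"}

first = find_and_record_next_unique_name("v", parent_scope, counter)
second = find_and_record_next_unique_name("v", parent_scope, counter)
third = find_and_record_next_unique_name("w", parent_scope, counter)

# Entering a subgraph copies the parent's value names (see enter_graph),
# so names chosen inside the subgraph never leak back into the parent.
child_scope = set(parent_scope)
child_name = find_and_record_next_unique_name("v", child_scope, counter)

print(first, second, third, child_name)
```

Because the counter is keyed by stem and shared, a collision skips suffixes that are already taken instead of restarting at `_1`, and the name picked inside the copied child scope stays out of `parent_scope`, which mirrors the `exit_graph` pop in the pass.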