Skip to content

Commit

Permalink
Create a recursive graph iterator and use it to refactor UnusedFuncti…
Browse files Browse the repository at this point in the history
…onRemover (#1565)

- Create `traversal.py` for graph traversal utilities and implemented
`RecursiveGraphIterator`. Expose `traversal` to the `ir` module. Fixes
#1556
- Remove `NodeTransformer` because `RecursiveGraphIterator` is more
flexible.
- Refactor remove_unused_function.py to use `RecursiveGraphIterator`
  • Loading branch information
justinchuby committed May 24, 2024
1 parent a6843da commit c41ded5
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 115 deletions.
3 changes: 2 additions & 1 deletion onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@
"tensor",
# Pass infrastructure
"passes",
"traversal",
]

from onnxscript.ir import passes, serde
from onnxscript.ir import passes, serde, traversal
from onnxscript.ir._convenience import tensor
from onnxscript.ir._core import (
Attr,
Expand Down
2 changes: 0 additions & 2 deletions onnxscript/ir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"PassBase",
"PassResult",
"PassManager",
"NodeTransformer",
# Errors
"InvariantError",
"PreconditionError",
Expand All @@ -17,7 +16,6 @@

from onnxscript.ir.passes._pass_infra import (
InvariantError,
NodeTransformer,
PassBase,
PassError,
PassManager,
Expand Down
82 changes: 0 additions & 82 deletions onnxscript/ir/passes/_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Sequence

__all__ = [
"NodeTransformer",
"PassBase",
"PassManager",
"PassResult",
Expand Down Expand Up @@ -100,87 +99,6 @@ def ensures(self, model: ir.Model) -> None:
del model # Unused


class NodeTransformer(PassBase):
"""NodeTransformer for the ONNX IR.
An NodeTransformer is a pass that traverses the IR and performs some
operation on the nodes. The operation can be anything, such as
checking invariants, transforming the IR, or generating code.
By default, the NodeTransformer updates the model in place.
.. warning::
Users should not depend on this class before the warning is removed, because it is not stable.
Attributes:
model: ir.Model: The model being interpreted.
scope (list[ir.Graph]): The current graph the NodeTransformer is running on.
reversed (bool): Whether to traverse the graph in reverse order.
modified (bool): Whether the model was modified.
"""

def __init__(self, reversed: bool = False):
self._model: ir.Model | None = None
self.scope: list[ir.Graph] = []
self.reversed = reversed
self.modified: bool | None = None

@property
def model(self) -> ir.Model:
"""Return the model being interpreted."""
if self._model is None:
raise ValueError("Model is not set. The model is set during the pass execution.")
return self._model

def call(self, model: ir.Model) -> PassResult:
self._model = model
self.enter_pass()
self._call_graph(self._model.graph)
self.exit_pass()
if self.modified is None:
raise PassError("The modified attribute was not set. Please set it in the pass.")
return PassResult(self._model, self.modified)

def _call_graph(self, graph: ir.Graph):
self.enter_graph(graph)
self.scope.append(graph)
iterable = reversed(graph) if self.reversed else graph
for node in iterable:
self.call_node_recursive(node)
self.exit_graph(graph)
self.scope.pop()

def call_node_recursive(self, node: ir.Node):
self.call_node(node)
for attr in node.attributes.values():
if not isinstance(attr, ir.Attr):
continue
if attr.type == ir.AttributeType.GRAPH:
self._call_graph(attr.value)
elif attr.type == ir.AttributeType.GRAPHS:
for graph in attr.value:
self._call_graph(graph)

def enter_pass(self):
"""Called when entering the pass. Optional to implement."""

def exit_pass(self):
"""Called when exiting the pass. Optional to implement."""

def enter_graph(self, graph: ir.Graph):
"""Called when entering a graph. Optional to implement."""
del graph # Unused

def exit_graph(self, graph: ir.Graph):
"""Called when exiting a graph. Optional to implement."""
del graph # Unused

@abc.abstractmethod
def call_node(self, node: ir.Node):
"""Called when visiting a node."""
...


class PassManager:
"""Pass manager for the IR.
Expand Down
82 changes: 82 additions & 0 deletions onnxscript/ir/traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Utilities for traversing the IR graph."""

from __future__ import annotations

__all__ = [
"RecursiveGraphIterator",
]

from typing import Callable, Iterator, Reversible

from typing_extensions import Self

from onnxscript.ir import _core, _enums


class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
def __init__(
self,
graph: _core.Graph | _core.Function | _core.GraphView,
*,
recursive: Callable[[_core.Node], bool] | None = None,
reverse: bool = False,
):
"""Iterate over the nodes in the graph, recursively visiting subgraphs.
Args:
graph: The graph to traverse.
recursive: A callback that determines whether to recursively visit the subgraphs
contained in a node. If not provided, all nodes in subgraphs are visited.
reverse: Whether to iterate in reverse order.
"""
self._graph = graph
self._recursive = recursive
self._reverse = reverse
self._iterator = self._recursive_node_iter(graph)

def __iter__(self) -> Self:
self._iterator = self._recursive_node_iter(self._graph)
return self

def __next__(self) -> _core.Node:
return next(self._iterator)

def _recursive_node_iter(
self, graph: _core.Graph | _core.Function | _core.GraphView
) -> Iterator[_core.Node]:
iterable = reversed(graph) if self._reverse else graph
for node in iterable: # type: ignore[union-attr]
yield node
if self._recursive is not None and not self._recursive(node):
continue
yield from self._iterate_subgraphs(node)

def _iterate_subgraphs(self, node: _core.Node):
for attr in node.attributes.values():
if not isinstance(attr, _core.Attr):
continue
if attr.type == _enums.AttributeType.GRAPH:
yield from RecursiveGraphIterator(
attr.value,
recursive=self._recursive,
reverse=self._reverse,
)
elif attr.type == _enums.AttributeType.GRAPHS:
graphs = reversed(attr.value) if self._reverse else attr.value
for graph in graphs:
yield from RecursiveGraphIterator(
graph,
recursive=self._recursive,
reverse=self._reverse,
)

def __reversed__(self) -> Iterator[_core.Node]:
return RecursiveGraphIterator(
self._graph,
recursive=self._recursive,
reverse=not self._reverse,
)
83 changes: 83 additions & 0 deletions onnxscript/ir/traversal_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

import unittest

import parameterized

from onnxscript import ir
from onnxscript.ir import traversal


class RecursiveGraphIteratorTest(unittest.TestCase):
def setUp(self):
self.graph = ir.Graph(
[],
[],
nodes=[
ir.Node("", "Node1", []),
ir.Node("", "Node2", []),
ir.Node(
"",
"If",
[],
attributes=[
ir.AttrGraph(
"then_branch",
ir.Graph(
[],
[],
nodes=[ir.Node("", "Node3", []), ir.Node("", "Node4", [])],
name="then_graph",
),
),
ir.AttrGraph(
"else_branch",
ir.Graph(
[],
[],
nodes=[ir.Node("", "Node5", []), ir.Node("", "Node6", [])],
name="else_graph",
),
),
],
),
],
name="main_graph",
)

@parameterized.parameterized.expand(
[
("forward", False, ("Node1", "Node2", "If", "Node3", "Node4", "Node5", "Node6")),
("reversed", True, ("If", "Node4", "Node3", "Node6", "Node5", "Node2", "Node1")),
]
)
def test_recursive_graph_iterator(self, _: str, reverse: bool, expected: tuple[str, ...]):
iterator = traversal.RecursiveGraphIterator(self.graph)
if reverse:
iterator = reversed(iterator)
nodes = list(iterator)
self.assertEqual(tuple(node.op_type for node in nodes), expected)

@parameterized.parameterized.expand(
[
("forward", False, ("Node1", "Node2", "If")),
("reversed", True, ("If", "Node2", "Node1")),
]
)
def test_recursive_graph_iterator_recursive_controls_recursive_behavior(
self, _: str, reverse: bool, expected: list[str]
):
nodes = list(
traversal.RecursiveGraphIterator(
self.graph, recursive=lambda node: node.op_type != "If", reverse=reverse
)
)
self.assertEqual(tuple(node.op_type for node in nodes), expected)


if __name__ == "__main__":
unittest.main()
75 changes: 45 additions & 30 deletions onnxscript/optimizer/remove_unused_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import logging
from typing import TypeVar

import onnx

Expand All @@ -13,47 +14,61 @@
logger = logging.getLogger(__name__)


class UnusedFunctionRemover(ir.passes.NodeTransformer):
TModel = TypeVar("TModel", ir.Model, onnx.ModelProto)


def _clean_up_unused_functions(model: ir.Model, unused: set[ir.OperatorIdentifier]) -> None:
"""Removes unused functions from the model."""
for op_identifier in unused:
del model.functions[op_identifier]

logger.info("Removed %s unused functions", len(unused))
logger.debug("Functions left: %s", list(model.functions))
logger.debug("Functions removed: %s", unused)


class RemoveUnusedFunctionPass(ir.passes.PassBase):
def __init__(self):
super().__init__()
self.used: set[ir.OperatorIdentifier] = set()
self.used: set[ir.OperatorIdentifier] | None = None

def call(self, model: ir.Model) -> ir.passes.PassResult:
self.used = set()
for node in ir.traversal.RecursiveGraphIterator(model.graph):
self._call_node(model, node)

# Update the model to remove unused functions
unused = set(model.functions) - self.used
if not unused:
logger.info("No unused functions to remove")
return ir.passes.PassResult(model, modified=False)

def _call_function(self, function: ir.Function) -> None:
_clean_up_unused_functions(model, unused)
self.used = None
return ir.passes.PassResult(model, modified=True)

def _call_function(self, model: ir.Model, function: ir.Function) -> None:
assert self.used is not None
if function.identifier() in self.used:
# The function and its nodes are already recorded as used
return
self.used.add(function.identifier())
for node in function:
self.call_node_recursive(node)
for node in ir.traversal.RecursiveGraphIterator(function):
self._call_node(model, node)

def call_node(self, node: ir.Node) -> None:
def _call_node(self, model: ir.Model, node: ir.Node) -> None:
op_identifier = node.op_identifier()
if op_identifier in self.model.functions:
self._call_function(self.model.functions[op_identifier])
else:
self.used.add(op_identifier)

def exit_pass(self) -> None:
# Update the model to remove unused functions
unused = set(self.model.functions) - self.used
if not unused:
logger.info("No unused functions to remove")
self.modified = False
if op_identifier not in model.functions:
return
for op_identifier in unused:
if op_identifier not in self.used:
del self.model.functions[op_identifier]
self.modified = True
logger.info("Removed %s unused functions", len(unused))
logger.debug("Functions left: %s", list(self.model.functions))
logger.debug("Functions removed: %s", unused)
self._call_function(model, model.functions[op_identifier])


def remove_unused_functions(model_proto: onnx.ModelProto) -> onnx.ModelProto:
def remove_unused_functions(model: TModel) -> TModel:
"""Removes unused function protos from the model."""
# TODO(justinchuby): Update this to accept an ir.Model
model = ir.serde.deserialize_model(model_proto)
UnusedFunctionRemover()(model)
model_proto = ir.serde.serialize_model(model)

return model_proto
if isinstance(model, ir.Model):
return RemoveUnusedFunctionPass()(model).model # type: ignore[return-value]

model_ = ir.serde.deserialize_model(model)
result = RemoveUnusedFunctionPass()(model_)
return ir.serde.serialize_model(result.model)

0 comments on commit c41ded5

Please sign in to comment.