Skip to content

Commit

Permalink
perf(common): improve the performance of replacing nodes by using a s…
Browse files Browse the repository at this point in the history
…pecialized `node.__recreate__()` method
  • Loading branch information
kszucs committed Dec 14, 2023
1 parent a1881eb commit f3da926
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
17 changes: 8 additions & 9 deletions ibis/common/graph.py
Expand Up @@ -169,22 +169,16 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
-------
A callable replacer function which can be used to replace nodes.
"""

# TODO(kszucs): add a __recreate__() method to the Node interface
# with a default implementation that uses the __class__ constructor
# which is supposed to provide an implementation for quick object
# reconstruction (the __recreate__ implementation in grounds.py
# should be sped up as well by totally avoiding the validation)

if isinstance(obj, Pattern):
ctx = context or {}

def fn(node, _, **kwargs):
# need to first reconstruct the node from the possible rewritten
# children, so we can match on the new node containing the rewritten
# child arguments, this way we can propagate the rewritten nodes
# upward in the hierarchy
recreated = node.__class__(**kwargs)
# upward in the hierarchy, using a specialized __recreate__ method
# improves the performance by 17% compared node.__class__(**kwargs)
recreated = node.__recreate__(kwargs)
if (result := obj.match(recreated, ctx)) is NoMatch:
return recreated
else:
Expand All @@ -208,6 +202,11 @@ def fn(node, _, **kwargs):
class Node(Hashable):
__slots__ = ()

@classmethod
def __recreate__(cls, kwargs: Any) -> Self:
"""Reconstruct the node from the given arguments."""
return cls(**kwargs)

@property
@abstractmethod
def __args__(self) -> tuple[Any, ...]:
Expand Down
10 changes: 9 additions & 1 deletion ibis/common/tests/test_graph_benchmarks.py
Expand Up @@ -6,8 +6,10 @@
from typing_extensions import Self # noqa: TCH002

from ibis.common.collections import frozendict
from ibis.common.deferred import _
from ibis.common.graph import Graph, Node
from ibis.common.grounds import Concrete
from ibis.common.patterns import Between, Object


class MyNode(Concrete, Node):
Expand All @@ -24,7 +26,7 @@ def generate_node(depth):
if depth == 0:
return MyNode(10, "20", c=(30, 40), d=frozendict(e=50, f=60))
return MyNode(
1,
depth,
"2",
c=(3, 4),
d=frozendict(e=5, f=6),
Expand All @@ -48,3 +50,9 @@ def test_bfs(benchmark):
def test_dfs(benchmark):
node = generate_node(500)
benchmark(Graph.from_dfs, node)


def test_replace(benchmark):
node = generate_node(500)
pattern = Object(MyNode, a=Between(lower=100)) >> _.copy(a=_.a + 1)
benchmark(node.replace, pattern)

0 comments on commit f3da926

Please sign in to comment.