Skip to content

Commit

Permalink
feat(common): node.replace() now supports mappings for quick lookup…
Browse files Browse the repository at this point in the history
…-like substitutions
  • Loading branch information
kszucs committed Dec 14, 2023
1 parent 1d314f7 commit bbc93c7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
45 changes: 27 additions & 18 deletions ibis/common/graph.py
Expand Up @@ -3,7 +3,7 @@

from abc import abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator, KeysView, Sequence
from collections.abc import Iterable, Iterator, KeysView, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar

from ibis.common.bases import Hashable
Expand Down Expand Up @@ -233,24 +233,33 @@ def replace(
-------
The root node of the graph with the replaced nodes.
"""
pat = pattern(pat)
ctx = context or {}

def fn(node, _, **kwargs):
# need to first reconstruct the node from the possible rewritten
# children, so we can match on the new node containing the rewritten
# child arguments, this way we can propagate the rewritten nodes
# upward in the hierarchy
# TODO(kszucs): add a __recreate__() method to the Node interface
# with a default implementation that uses the __class__ constructor
# which is supposed to provide an implementation for quick object
# reconstruction (the __recreate__ implementation in grounds.py
# should be sped up as well by totally avoiding the validation)
recreated = node.__class__(**kwargs)
if (result := pat.match(recreated, ctx)) is NoMatch:
return recreated
else:
return result
if isinstance(pat, Mapping):

def fn(node, _, **kwargs):
try:
return pat[node]
except KeyError:
return node.__class__(**kwargs)
else:
pat = pattern(pat)
ctx = context or {}

def fn(node, _, **kwargs):
# need to first reconstruct the node from the possible rewritten
# children, so we can match on the new node containing the rewritten
# child arguments, this way we can propagate the rewritten nodes
# upward in the hierarchy
# TODO(kszucs): add a __recreate__() method to the Node interface
# with a default implementation that uses the __class__ constructor
# which is supposed to provide an implementation for quick object
# reconstruction (the __recreate__ implementation in grounds.py
# should be sped up as well by totally avoiding the validation)
recreated = node.__class__(**kwargs)
if (result := pat.match(recreated, ctx)) is NoMatch:
return recreated
else:
return result

results = self.map(fn, filter=filter)
return results.get(self, self)
Expand Down
18 changes: 16 additions & 2 deletions ibis/common/tests/test_graph.py
Expand Up @@ -28,11 +28,11 @@ def __init__(self, name, children):

@property
def __args__(self):
return (self.children,)
return (self.name, self.children)

@property
def __argnames__(self):
return ("children",)
return ("name", "children")

def __repr__(self):
return f"{self.__class__.__name__}({self.name})"
Expand Down Expand Up @@ -145,6 +145,20 @@ def test_replace_with_filtering_out_root():
assert result == A


def test_replace_with_mapping():
new_E = MyNode(name="e", children=[])
new_D = MyNode(name="d", children=[])
new_B = MyNode(name="B", children=[new_D, new_E])
new_A = MyNode(name="A", children=[new_B, C])

subs = {
E: new_E,
D: new_D,
}
result = A.replace(subs)
assert result == new_A


def test_example():
class Example(Annotable, Node):
def __hash__(self):
Expand Down

0 comments on commit bbc93c7

Please sign in to comment.