Skip to content

Commit

Permalink
feat(common): expose node.__children__ property to access the flatt…
Browse files Browse the repository at this point in the history
…ened list of children of a node
  • Loading branch information
kszucs authored and cpcloud committed Dec 14, 2023
1 parent 15acf7d commit 2e91476
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
19 changes: 12 additions & 7 deletions ibis/common/graph.py
Expand Up @@ -128,6 +128,11 @@ def __args__(self) -> tuple[Any, ...]:
def __argnames__(self) -> tuple[str, ...]:
"""Sequence of argument names."""

@property
def __children__(self) -> tuple[Node, ...]:
"""Sequence of children nodes."""
return tuple(_flatten_collections(self.__args__))

def __rich_repr__(self):
"""Support for rich reprerentation of the node."""
return zip(self.__argnames__, self.__args__)
Expand Down Expand Up @@ -228,7 +233,7 @@ def find_topmost(self, pat: type, context: Optional[dict] = None) -> list[Node]:
if pat.match(node, ctx) is not NoMatch:
result.append(node)
else:
queue.extend(_flatten_collections(node.__args__))
queue.extend(node.__children__)
seen.add(node)
else:
# fast path for locating a specific type
Expand All @@ -237,7 +242,7 @@ def find_topmost(self, pat: type, context: Optional[dict] = None) -> list[Node]:
if isinstance(node, pat):
result.append(node)
else:
queue.extend(_flatten_collections(node.__args__))
queue.extend(node.__children__)
seen.add(node)

return result
Expand Down Expand Up @@ -454,7 +459,7 @@ def traverse(

if control is not halt:
if control is proceed:
children = tuple(_flatten_collections(node.__args__))
children = node.__children__
elif isinstance(control, Iterable):
children = control
else:
Expand Down Expand Up @@ -488,7 +493,7 @@ def bfs(root: Node) -> Graph:

while queue:
if (node := queue.popleft()) not in graph:
children = tuple(_flatten_collections(node.__args__))
children = node.__children__
graph[node] = children
queue.extend(children)

Expand Down Expand Up @@ -524,7 +529,7 @@ def bfs_while(root: Node, filter: Optional[Any] = None) -> Graph:
if (node := queue.popleft()) not in graph:
children = tuple(
child
for child in _flatten_collections(node.__args__)
for child in node.__children__
if filter.match(child, {}) is not NoMatch
)
graph[node] = children
Expand Down Expand Up @@ -555,7 +560,7 @@ def dfs(root: Node) -> Graph:

while stack:
if (node := stack.pop()) not in graph:
children = tuple(_flatten_collections(node.__args__))
children = node.__children__
graph[node] = children
stack.extend(children)

Expand Down Expand Up @@ -591,7 +596,7 @@ def dfs_while(root: Node, filter: Optional[Any] = None) -> Graph:
if (node := stack.pop()) not in graph:
children = tuple(
child
for child in _flatten_collections(node.__args__)
for child in node.__children__
if filter.match(child, {}) is not NoMatch
)
graph[node] = children
Expand Down
8 changes: 7 additions & 1 deletion ibis/common/tests/test_graph.py
Expand Up @@ -117,7 +117,7 @@ def test_nested_children():
b = MyNode(name="b", children=[a])
c = MyNode(name="c", children=[])
d = MyNode(name="d", children=[])
e = MyNode(name="e", children=[[b, c], d])
e = MyNode(name="e", children=[[b, c], {"d": d}])
assert bfs(e) == {
e: (b, c, d),
b: (a,),
Expand All @@ -126,6 +126,12 @@ def test_nested_children():
a: (),
}

assert a.__children__ == ()
assert b.__children__ == (a,)
assert c.__children__ == ()
assert d.__children__ == ()
assert e.__children__ == (b, c, d)


@pytest.mark.parametrize("func", [bfs_while, dfs_while, Graph.from_bfs, Graph.from_dfs])
def test_traversals_with_filter(func):
Expand Down

0 comments on commit 2e91476

Please sign in to comment.