# Object-Oriented Tree Walk Generator

The problem is to walk a tree (say, depth-first) as efficiently as possible, given that:
- the tree is too big to load fully into memory
- we cannot know a node's children without accessing the node itself
- each node is very slow to access
- we want to run every node through filters that decide whether to keep it or skip it
- we want to paginate the walk (pagination being a kind of filter)

The solution is to chain `Generator`s (a producer then a series of filters), with a way to propagate an `Exception` up the `Generator` chain to signal that the current node and all its (yet unaccessed) descendants may be skipped.
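
To make the pagination point concrete, here is a minimal sketch, assuming a lazy source of nodes. The names `slow_nodes`, `paginate` and `accessed` are hypothetical and only serve the illustration. It shows that a page taken from a lazy generator only accesses the nodes up to the end of that page, although the nodes before the offset still have to be accessed.

```python
from itertools import islice
from typing import Iterable, Iterator, List

# Records every simulated access, so we can see how much work was done.
accessed: List[int] = []


def slow_nodes(count: int) -> Iterator[str]:
    """Hypothetical lazy source: each iteration stands in for one slow access."""
    for i in range(count):
        accessed.append(i)
        yield f"node-{i}"


def paginate(nodes: Iterable[str], offset: int, limit: int) -> Iterator[str]:
    """Pagination as a filter: drop `offset` nodes, keep `limit`, then stop."""
    return islice(nodes, offset, offset + limit)


page = list(paginate(slow_nodes(1000), offset=2, limit=3))
print(page)
print(accessed)
```

Only five of the thousand simulated nodes are accessed here, because `islice` stops pulling from the source once the page is full.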

In [74]:
from __future__ import annotations

from typing import (
    TypeAlias,
    Tuple,
    List,
    Any,
    Generator,
    Type,
    Callable,
    Optional,
)

from dataclasses import dataclass

from collections import (
    deque
)

Let's create a sample tree to walk through.

To keep this as simple as possible each node will be a tuple composed of the node name and its child nodes list. Neither the root nor the leaves will be special in any way.

_Imagine, however, that each node is a remote resource which is slow to fetch, and which itself holds its list of children, unavailable anywhere else._

In [75]:
Tree: TypeAlias = Tuple[str, List["Tree"]]

tree: Tree = ("a", [
    ("a.a", [
        ("a.a.a", []),
        ("a.a.b", [
            ("a.a.b.a", []),
            ("a.a.b.b", [
                ("a.a.b.b.a", []),
            ]),
        ])
    ]),
    ("a.b", [
        ("a.b.a", []),
        ("a.b.b", [])
    ]),
    ("a.c", [])
])
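
The remote-resource constraint can be simulated with a small sketch like the one below. Everything in it (`_REMOTE`, `fetch_children`, `access_log`, and the reduced `sample` tree) is a hypothetical stand-in, not part of the walk implemented in this notebook. It only illustrates that a node's children become known by accessing that node, and that each access can be counted.

```python
from typing import Dict, List, Tuple

Tree = Tuple[str, List["Tree"]]

# A reduced tree, in the same (name, children) shape as the sample above.
sample: Tree = ("a", [("a.a", [("a.a.a", [])]), ("a.b", [])])


def _index(node: Tree, table: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Flatten the nested tuples into a name -> child names table."""
    name, children = node
    table[name] = [child[0] for child in children]
    for child in children:
        _index(child, table)
    return table


# Hypothetical "remote" storage: in real use this would not be in memory.
_REMOTE: Dict[str, List[str]] = _index(sample, {})

# Every access is logged, standing in for the cost of a slow fetch.
access_log: List[str] = []


def fetch_children(name: str) -> List[str]:
    """Pretend to fetch a node; its children are only known afterwards."""
    access_log.append(name)
    return _REMOTE[name]


children_of_root = fetch_children("a")
print(children_of_root)
print(access_log)
```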

We then create two generator classes.

A recursive `TreeWalk` generator handles walking through the tree and yields each node. A `TreeWalkFilter` generator takes a `TreeWalk` or another `TreeWalkFilter` as input, applies custom filtering logic to each node it receives, then either yields the node again or drops it, possibly sending a _skip_ signal up the chain. This signal (an `Exception`) is propagated from downstream filter generators or user code up to the source walk generator using each generator's `.throw()` method.
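
For comparison, the same skip signal can be sketched with a plain generator function. This is a minimal illustration under assumptions, not the implementation used in this notebook: `SkipBranch`, `walk` and `small_tree` are hypothetical names. On a native generator, `.throw()` resumes execution and returns the next yielded value, so the sketch yields `None` as an acknowledgement after catching the signal. The classes in this notebook sidestep that by implementing `throw()` by hand.

```python
from typing import Generator, List, Optional, Tuple

Tree = Tuple[str, List["Tree"]]


class SkipBranch(Exception):
    """Hypothetical signal, standing in for the notebook's skip exception."""


def walk(node: Tree) -> Generator[Optional[str], None, None]:
    _name, children = node
    for child in children:
        try:
            yield child[0]
        except SkipBranch:
            # throw() returns the next yielded value: yield None as an
            # acknowledgement, then move on to the next sibling.
            yield None
            continue
        # With `yield from`, a later throw() is forwarded to the innermost
        # generator, so the skip applies at the right depth.
        yield from walk(child)


small_tree: Tree = ("r", [("r.a", [("r.a.a", []), ("r.a.b", [])]), ("r.b", [])])

seen: List[str] = []
nodes = walk(small_tree)
for name in nodes:
    if name == "r.a":
        nodes.throw(SkipBranch)  # returns the None acknowledgement
    else:
        seen.append(name)

print(seen)
```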

First, our special _signal_. Any other exception is delegated to the default exception handling behaviour.

In [76]:
class SkipBranchWalk(Exception):
    pass
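
A generator's `.throw()` may be handed either an exception class or an exception instance, which is why the `throw()` methods in this notebook test both forms. The check can be sketched as a small helper. `is_skip_signal` is a hypothetical name, and the exception class is redefined here only to keep the sketch self-contained.

```python
class SkipBranchWalk(Exception):
    """Redefined here only so that the sketch runs on its own."""


def is_skip_signal(exc: object) -> bool:
    """Return True for the SkipBranchWalk class itself or any instance of it."""
    if isinstance(exc, type):
        # Called as gen.throw(SkipBranchWalk): we received the class.
        return issubclass(exc, SkipBranchWalk)
    # Called as gen.throw(SkipBranchWalk()): we received an instance.
    return isinstance(exc, SkipBranchWalk)


print(is_skip_signal(SkipBranchWalk))
print(is_skip_signal(SkipBranchWalk()))
print(is_skip_signal(ValueError))
```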

The recursive `TreeWalk` generator.

In [77]:
class TreeWalk(Generator):

    root: Tree

    _remaining_nodes: deque[Tree]
    _current_yielding_node: TreeWalk | None = None
    _current_yielded_node: str | None = None  # name of the last yielded node

    def __init__(self, root: Tree):

        self.root = root
        self._remaining_nodes = deque(root[1])

    def __next__(self) -> str:
        if self._current_yielding_node:
            try:
                self._current_yielded_node = next(self._current_yielding_node)
                return self._current_yielded_node
            except StopIteration:
                self._current_yielding_node = None

        if not self._remaining_nodes:
            self._current_yielded_node = None
            raise StopIteration

        node = self._remaining_nodes.popleft()

        self._current_yielded_node = node[0]
        self._current_yielding_node = TreeWalk(node)

        return self._current_yielded_node

    def __iter__(self):
        return self

    def send(self, value: Any = None):
        return super().send(value)

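    def throw(self, type: Type[Exception], value: Exception = None, traceback: Any = None):
        if type is SkipBranchWalk or isinstance(type, SkipBranchWalk):
            child = self._current_yielding_node
            if child is not None and child._current_yielding_node is not None:
                # The last yielded node sits deeper in the tree: delegate so that
                # only its own descendants are skipped, not its siblings.
                return child.throw(type, value, traceback)

            print(f"SkipBranchWalk thrown on {self._current_yielded_node}")

            self._current_yielding_node = None
        else:
            return super().throw(type, value, traceback)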


The `TreeWalkFilter` generator. 

The `filter` is intended to be a pure function which receives a node name and either:
- returns `True` to let the node through unchanged
- returns `False` or `None` to drop the node (but not its descendants)
- returns any other value to replace the node name with it, then lets it through to the next filter (or the end consumer)
- raises `SkipBranchWalk` to skip the node **and all its descendants**.

Since the filter is itself a generator, possibly sitting in the middle of a chain of other filters, it must also handle a `SkipBranchWalk` exception passed to its own `.throw()` and forward it up to the root walk generator.
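
The possible outcomes of a filter can be illustrated with a few example functions. This is only a sketch: the filters `keep_all`, `drop_leaf_a`, `shout` and `prune_b`, as well as the `apply_filter` helper, are hypothetical. The helper is a stand-alone summary of the dispatch and is not the `TreeWalkFilter` class itself.

```python
from typing import Callable, Optional, Tuple, Union


class SkipBranchWalk(Exception):
    """Redefined here only so that the sketch runs on its own."""


def keep_all(name: str) -> bool:
    return True  # let every node through unchanged


def drop_leaf_a(name: str) -> bool:
    return not name.endswith(".a")  # False drops the node, not its descendants


def shout(name: str) -> str:
    return name.upper()  # any other value replaces the node name


def prune_b(name: str) -> bool:
    if name.startswith("a.b"):
        raise SkipBranchWalk  # skip the node and all its descendants
    return True


def apply_filter(
    filter_fn: Callable[[str], Union[bool, str]], name: str
) -> Tuple[str, Optional[str]]:
    """Classify what a filter decided for one node name."""
    try:
        result = filter_fn(name)
    except SkipBranchWalk:
        return ("skip_branch", None)
    if result is True:
        return ("keep", name)
    if result is False:
        return ("drop", None)
    return ("keep", result)


print(apply_filter(keep_all, "a.a"))
print(apply_filter(drop_leaf_a, "a.b.a"))
print(apply_filter(shout, "a.c"))
print(apply_filter(prune_b, "a.b"))
```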

In [78]:


class TreeWalkFilter(Generator):

    _walk: TreeWalk | TreeWalkFilter
    _filter: Callable[[str], Optional[bool | str]]
    _current_yielded_node: str | None = None

    def __init__(self, walk: TreeWalk | TreeWalkFilter, filter: Callable[[str], Optional[bool | str]]):
        self._walk = walk
        self._filter = filter

    def __next__(self) -> str:
        current_node = next(self._walk)

        try:
            filtered_current_node = self._filter(current_node)

            # Identity checks avoid treating values such as 0 or 1 as booleans.
            if filtered_current_node is False or filtered_current_node is None:
                self._current_yielded_node = None
            elif filtered_current_node is True:
                self._current_yielded_node = current_node
            else:
                self._current_yielded_node = filtered_current_node
        except SkipBranchWalk as error:
            print(f"SkipBranchWalk raised on {current_node} in {str(self._filter)}")
            self._current_yielded_node = None

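            # Propagate the signal up the chain. The call is kept out of an
            # `assert`, which would be stripped (call included) under `python -O`.
            self._walk.throw(error)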

            return next(self)
        else:
            if self._current_yielded_node is not None:
                return self._current_yielded_node
            else:
                return next(self)

    def __iter__(self):
        return self

    def send(self, value: Any = None):
        return super().send(value)

    def throw(self, type: Type[Exception], value: Exception = None, traceback: Any = None):
        if type is SkipBranchWalk or isinstance(type, SkipBranchWalk):
            print(f"SkipBranchWalk thrown on {self._current_yielded_node} in {str(self._filter)}")

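            # Forward the signal up the chain. The call is kept out of an
            # `assert`, which would be stripped (call included) under `python -O`.
            self._walk.throw(type, value, traceback)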
        else:
            return super().throw(type, value, traceback)



Finally, let's see the result of this (slightly unconventional) setup.

In [79]:

print(" --- Full tree --- ")

for node in (walk := TreeWalk(tree)):
    print(node)

print(" --- Without branch 'a' --- ")

for node in (walk := TreeWalk(tree)):
    if node.startswith("a.a"):
        walk.throw(SkipBranchWalk)
    else:
        print(node)

print(" --- Without branch 'a' using a filter --- ")


def filter(node: str):
    if node.startswith("a.a"):
        raise SkipBranchWalk
    else:
        return True

walk = TreeWalk(tree)
filtered_walk = TreeWalkFilter(walk, filter)

for node in filtered_walk:
    print(node)

print(" --- Without branch 'a' and 'b' using chained filters --- ")

def filter_a(node: str):
    if node.startswith("a.a"):
        raise SkipBranchWalk
    else:
        return True

def filter_b(node: str):
    if node.startswith("a.b"):
        raise SkipBranchWalk
    else:
        return True

walk = TreeWalk(tree)
filtered_walk = TreeWalkFilter(TreeWalkFilter(walk, filter_a), filter_b)

for node in filtered_walk:
    print(node)

 --- Full tree --- 
a.a
a.a.a
a.a.b
a.a.b.a
a.a.b.b
a.a.b.b.a
a.b
a.b.a
a.b.b
a.c
 --- Without branch 'a' --- 
SkipBranchWalk thrown on a.a
a.b
a.b.a
a.b.b
a.c
 --- Without branch 'a' using a filter --- 
SkipBranchWalk raised on a.a in <function filter at 0x7b9a2cce6700>
SkipBranchWalk thrown on a.a
a.b
a.b.a
a.b.b
a.c
 --- Without branch 'a' and 'b' using chained filters --- 
SkipBranchWalk raised on a.a in <function filter_a at 0x7b9a2cce5440>
SkipBranchWalk thrown on a.a
SkipBranchWalk raised on a.b in <function filter_b at 0x7b9a2cce6160>
SkipBranchWalk thrown on a.b in <function filter_a at 0x7b9a2cce5440>
SkipBranchWalk thrown on a.b
a.c
