## Requirements

In [2]:
from typing import Any, Callable, Self  # typing.Self requires Python 3.11 or later

# Binary tree implementation

You can start by creating a class to represent the nodes in a binary tree. Each node holds a value of any type and, optionally, a left and a right child. A missing child is represented as `None`.

In [3]:
class Node:
    
    value: Any
    _left: Self | None
    _right: Self | None
    
    def __init__(self: Self, value: Any) -> None:
        self.value = value
        self._left, self._right = None, None
    
    @property
    def left(self: Self) -> Self | None:
        return self._left
    
    def add_left(self: Self, child: Self) -> Self:
        self._left = child
        return self
    
    @property
    def right(self: Self) -> Self | None:
        return self._right
    
    def add_right(self: Self, child: Self) -> Self:
        self._right = child
        return self
    
    def to_string(self: Self, prefix: str = '') -> str:
        # Pre-order rendering: this node, then its left and right subtrees,
        # each level indented two more spaces; a missing child prints as None
        text = f'{prefix}{self.value}\n'
        prefix += '  '
        no_child_str = f'{prefix}None\n'
        text += self.left.to_string(prefix) if self.left is not None else no_child_str
        text += self.right.to_string(prefix) if self.right is not None else no_child_str
        return text

Because `add_left` and `add_right` return the node itself, calls can be chained, which allows builder-style tree creation, as you will see below.

Next, you can implement a class that represents trees.

In [4]:
class Tree:
    
    root: Node | None
    
    def __init__(self: Self, root: Node | None) -> None:
        self.root = root
        
    def __repr__(self: Self) -> str:
        return self.root.to_string() if self.root is not None else 'None'

The initialization below shows why `add_left` and `add_right` are convenient: the chained calls mirror the shape of the tree.

In [5]:
tree = Tree(
    Node(1)
        .add_left(Node(2))
        .add_right(
            Node(3)
                .add_left(
                    Node(4)
                        .add_right(Node(5))
                )
                .add_right(Node(6))
        )
)

The `__repr__` method calls the root node's `to_string` method, which recursively renders the tree structure, indenting each level by two spaces.

In [6]:
print(tree)

1
  2
    None
    None
  3
    4
      None
      5
        None
        None
    6
      None
      None



Now you can write a function that creates a perfectly balanced binary tree of a given depth for testing purposes. If the depth of the tree is $d$, the tree has $2^d - 1$ nodes, numbered from 1 to $2^d - 1$ level by level.
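The node count follows from summing over the levels: counting the root as level 0, level $k$ of a perfectly balanced tree holds $2^k$ nodes, so a tree of depth $d$ has

$$
\sum_{k=0}^{d-1} 2^k = 1 + 2 + \dots + 2^{d-1} = 2^d - 1
$$

nodes. For example, $d = 2$ gives $1 + 2 = 3$ nodes, as in the `create_balanced_tree(2)` output below.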

In [6]:
def create_balanced_tree(max_depth: int) -> Tree:
    tree = Tree(None)
    node_value = 0
    if max_depth > 0:
        node_value += 1
        tree.root = Node(node_value)
        # Build the tree level by level, so node values are
        # assigned 1, 2, 3, ... in breadth-first order
        nodes = [tree.root]
        for _ in range(1, max_depth):
            next_generation = []
            for node in nodes:
                node_value += 1
                node.add_left(Node(node_value))
                next_generation.append(node.left)
                node_value += 1
                node.add_right(Node(node_value))
                next_generation.append(node.right)
            nodes = next_generation
    return tree

In [10]:
create_balanced_tree(2)

1
  2
    None
    None
  3
    None
    None

# Tree traversal

The most intuitive way to traverse a tree is recursion. However, the same traversal can also be implemented with iteration only, using an explicit stack.
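One consideration beyond speed is CPython's recursion limit (1000 frames by default, see `sys.getrecursionlimit()`): a recursive traversal of a very deep, unbalanced tree raises `RecursionError`, while an explicit stack does not. The self-contained sketch below illustrates this on a degenerate tree, a chain of 5,000 left children. It assumes a simplified `Node` dataclass with plain attributes as a stand-in for the `Node` class above, and the helpers `count_recursive` and `count_iterative` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Simplified stand-in for the notebook's Node class (plain attributes)
    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def count_recursive(node: Optional[Node]) -> int:
    # One Python call per node and per missing child
    if node is None:
        return 0
    return 1 + count_recursive(node.left) + count_recursive(node.right)


def count_iterative(root: Optional[Node]) -> int:
    # Explicit stack: depth is limited by memory, not by the call stack
    count = 0
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        count += 1
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return count


# Degenerate tree: 5,000 nodes linked through left children only
root = Node(1)
current = root
for value in range(2, 5001):
    current.left = Node(value)
    current = current.left

print(count_iterative(root))  # 5000

try:
    count_recursive(root)
    outcome = "ok"
except RecursionError:
    outcome = "RecursionError"
print(outcome)  # RecursionError (with the default recursion limit)
```

For balanced trees such as those built by `create_balanced_tree`, the recursion depth equals the tree depth, so this limit only matters for deep, lopsided trees.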

## Recursive traversal

You can write a function that applies a given function to each node of a tree, visiting each node before its left subtree and then its right subtree (pre-order traversal).

In [7]:
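def recursive_traversal(tree: Tree, func: Callable) -> Any | None:
    def recurse(node: Node | None) -> Any | None:
        # Pre-order: visit the node, then its left and right subtrees
        if node is not None:
            func(node)
            recurse(node.left)
            recurse(node.right)
        # Every call returns the result; only the outermost return value is used
        return func.result() if hasattr(func, 'result') else None
    # recurse is defined unconditionally, so an empty tree no longer raises an error
    return recurse(tree.root)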

### Node value sum

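By passing a stateful callable object to this higher-order function, you can compute the sum of all node values. The `Sum` class below accumulates the values in its `__call__` method and exposes the total through its `result` method, whose value `recursive_traversal` returns.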

In [8]:
class Sum:
    
    def __init__(self: Self):
        self._result = 0
        
    def __call__(self: Self, node: Node) -> None:
        self._result += node.value
        
    def result(self: Self) -> int:
        return self._result
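Any callable that accepts a node can be passed to the traversal, and if it also has a `result` method, the traversal returns that method's value. As a further sketch of the same protocol, a hypothetical `Max` class could track the largest node value. Since it only reads `node.value`, it can be tried without building a tree, here using `types.SimpleNamespace` objects as stand-ins for nodes.

```python
from types import SimpleNamespace
from typing import Any


class Max:
    # Hypothetical stateful callable following the same protocol as Sum:
    # __call__ receives a node, result() returns the accumulated value
    def __init__(self) -> None:
        self._result: Any = None

    def __call__(self, node: Any) -> None:
        # Keep the largest value seen so far; None until the first node arrives
        if self._result is None or node.value > self._result:
            self._result = node.value

    def result(self) -> Any:
        return self._result


# Stand-ins for tree nodes: only the value attribute is used
max_value = Max()
for value in [3, 7, 5]:
    max_value(SimpleNamespace(value=value))
print(max_value.result())  # 7
```

With the functions in this notebook, `recursive_traversal(create_balanced_tree(10), Max())` would return 1023, the largest node value ($2^{10} - 1$).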

In [29]:
recursive_traversal(create_balanced_tree(2), Sum())

6

In [9]:
recursive_traversal(create_balanced_tree(10), Sum())

523776

The sum of the node values is $\sum_{i=1}^{2^d - 1} i$, so for $d = 10$:

In [39]:
sum(range(1, 2**10))

523776

The result is indeed correct.
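The same value also follows from the closed form $\sum_{i=1}^{n} i = n(n+1)/2$ with $n = 2^d - 1$:

$$
\sum_{i=1}^{2^d - 1} i = \frac{(2^d - 1)\,2^d}{2} = 2^{d-1}\,(2^d - 1)
$$

For $d = 10$ this is $2^9 \cdot 1023 = 512 \cdot 1023 = 523776$, and for $d = 2$ it is $2 \cdot 3 = 6$, matching both traversal results.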

### Node value transformation

You can also use `recursive_traversal` for node transformations, e.g., multiplying the value of each node by 2.

In [41]:
tree = create_balanced_tree(3)
tree

1
  2
    4
      None
      None
    5
      None
      None
  3
    6
      None
      None
    7
      None
      None

In [10]:
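def times_2(node: Node) -> None:
    # Double the node's value in place; there is no result method, so the traversal returns None
    node.value *= 2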

In [43]:
recursive_traversal(tree, times_2)
tree

2
  4
    8
      None
      None
    10
      None
      None
  6
    12
      None
      None
    14
      None
      None
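Because the traversal only needs a function that takes a node, transformations can also be built from ordinary functions on values. The sketch below uses a hypothetical helper, `apply_to_value`, that wraps a value function into a node visitor. It is tried here on `types.SimpleNamespace` stand-ins for nodes rather than on a tree.

```python
from types import SimpleNamespace
from typing import Any, Callable


def apply_to_value(f: Callable[[Any], Any]) -> Callable[[Any], None]:
    # Hypothetical helper: returns a visitor that replaces node.value with f(node.value)
    def visit(node: Any) -> None:
        node.value = f(node.value)
    return visit


square = apply_to_value(lambda x: x * x)
nodes = [SimpleNamespace(value=v) for v in [1, 2, 3]]
for node in nodes:
    square(node)
print([node.value for node in nodes])  # [1, 4, 9]
```

With the notebook's functions, `recursive_traversal(tree, apply_to_value(lambda x: x * 2))` would have the same effect as `times_2`.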

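## Iterative traversal

The iterative version replaces the call stack with an explicit stack, a Python list. Because the right child is pushed before the left child, the left child is popped first, so the nodes are visited in the same pre-order as in the recursive version.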

In [11]:
def iterative_traversal(tree: Tree, func: Callable) -> Any | None:
    if tree.root is not None:
        stack = [tree.root]
        while stack:
            node = stack.pop()
            func(node)
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)
    return func.result() if hasattr(func, 'result') else None
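To check that the explicit stack reproduces the recursive visiting order, the self-contained sketch below records the visited values from both approaches on a small balanced tree. It assumes a simplified `Node` dataclass in place of the classes above, and `preorder_recursive` and `preorder_iterative` are hypothetical helper names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Simplified stand-in for the notebook's Node class (plain attributes)
    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def preorder_recursive(node: Optional[Node], out: list[int]) -> list[int]:
    # Visit the node, then its left subtree, then its right subtree
    if node is not None:
        out.append(node.value)
        preorder_recursive(node.left, out)
        preorder_recursive(node.right, out)
    return out


def preorder_iterative(root: Optional[Node]) -> list[int]:
    out: list[int] = []
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        out.append(node.value)
        # Push right before left so the left child is popped (visited) first
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return out


# Same shape and numbering as create_balanced_tree(3)
root = Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6), Node(7)))
print(preorder_recursive(root, []))  # [1, 2, 4, 5, 3, 6, 7]
print(preorder_iterative(root))      # [1, 2, 4, 5, 3, 6, 7]
```

The order matches the one printed by `to_string` for `create_balanced_tree(3)`.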

### Node value sum

Using the `Sum` class defined above, you can again compute the sum of all node values.

In [36]:
iterative_traversal(create_balanced_tree(2), Sum())

6

In [37]:
iterative_traversal(create_balanced_tree(10), Sum())

523776

### Node value transformation

Just as for the recursive implementation, the `times_2` function defined above can be used to multiply the value of each node by 2.

In [44]:
tree = create_balanced_tree(3)
tree

1
  2
    4
      None
      None
    5
      None
      None
  3
    6
      None
      None
    7
      None
      None

In [45]:
iterative_traversal(tree, times_2)
tree

2
  4
    8
      None
      None
    10
      None
      None
  6
    12
      None
      None
    14
      None
      None

## Performance

To compare the performance of both implementations, you can create a large tree. With depth 20, it has $2^{20} - 1 = 1{,}048{,}575$ nodes.

In [17]:
large_tree = create_balanced_tree(20)

In [18]:
%timeit recursive_traversal(large_tree, Sum())

414 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
%timeit iterative_traversal(large_tree, Sum())

249 ms ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


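In this measurement, the iterative implementation takes about 40% less time than the recursive one (249 ms versus 414 ms). A likely contributor is Python's function-call overhead: the recursive version makes one call per node and one per missing child, and every call also evaluates `hasattr(func, 'result')`.

Because `times_2` modifies the tree in place, a fresh tree is created before each of the following measurements so that both implementations start from the same node values.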

In [20]:
large_tree = create_balanced_tree(20)

In [21]:
%timeit recursive_traversal(large_tree, times_2)

271 ms ± 18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
large_tree = create_balanced_tree(20)

In [23]:
%timeit iterative_traversal(large_tree, times_2)

197 ms ± 7.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


The iterative implementation is also faster for the in-place transformation (197 ms versus 271 ms).
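`%timeit` is an IPython magic command. In a plain Python script, a comparable measurement can be made with the standard library `timeit` module. The sketch below is self-contained and makes several assumptions: a simplified `Node` dataclass, a hypothetical `build_balanced` helper that mirrors `create_balanced_tree`, and plain summing functions instead of the callable-based traversals above. It uses a depth-16 tree so that it runs quickly; timings depend on the machine and Python version, so no numbers are shown.

```python
import timeit
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Simplified stand-in for the notebook's Node class (plain attributes)
    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build_balanced(max_depth: int) -> Optional[Node]:
    # Hypothetical helper: level-by-level construction, values 1, 2, 3, ...
    if max_depth <= 0:
        return None
    counter = 1
    root = Node(counter)
    level = [root]
    for _ in range(1, max_depth):
        next_level = []
        for node in level:
            node.left = Node(counter + 1)
            node.right = Node(counter + 2)
            counter += 2
            next_level.extend([node.left, node.right])
        level = next_level
    return root


def sum_recursive(node: Optional[Node]) -> int:
    if node is None:
        return 0
    return node.value + sum_recursive(node.left) + sum_recursive(node.right)


def sum_iterative(root: Optional[Node]) -> int:
    total = 0
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        total += node.value
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return total


tree = build_balanced(16)
# 5 repeats of 1 call each; the timeit docs suggest the minimum as the most useful figure
t_rec = min(timeit.repeat(lambda: sum_recursive(tree), number=1, repeat=5))
t_it = min(timeit.repeat(lambda: sum_iterative(tree), number=1, repeat=5))
print(f"recursive: {t_rec * 1000:.1f} ms, iterative: {t_it * 1000:.1f} ms")
```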

## Note

Although it might seem inefficient to use a list with `append` and `pop` as a stack, it is in fact efficient: both operations act on the end of the list and run in amortized constant time. It is also faster than pre-allocating a list and keeping track of the stack size separately.
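For readers who want to compare the two approaches themselves, the sketch below implements the pre-allocated alternative the note describes: a fixed-size list plus a separate `size` counter. It assumes a simplified `Node` dataclass, and `iterative_values_preallocated` is a hypothetical name. The caller must supply a `capacity` at least as large as the maximum stack size; the number of nodes is always a safe bound, because each node is pushed exactly once.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Simplified stand-in for the notebook's Node class (plain attributes)
    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def iterative_values_preallocated(root: Optional[Node], capacity: int) -> list[int]:
    # Fixed-size stack plus an explicit size counter instead of append/pop;
    # capacity must be >= the maximum number of pending nodes
    out: list[int] = []
    if root is None:
        return out
    stack: list[Optional[Node]] = [None] * capacity
    stack[0] = root
    size = 1
    while size > 0:
        size -= 1
        node = stack[size]
        out.append(node.value)
        # Same order as the list-based version: push right, then left
        if node.right is not None:
            stack[size] = node.right
            size += 1
        if node.left is not None:
            stack[size] = node.left
            size += 1
    return out


root = Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6), Node(7)))
print(iterative_values_preallocated(root, capacity=7))  # [1, 2, 4, 5, 3, 6, 7]
```

Timing this function against a list-based version with `append` and `pop` (for example with `%timeit`) is one way to check the note's claim on your own Python version.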