# Learn PyTree!

PyTree is a utility module in PyTorch that makes it easy to work with tree structures. A tree is built out of Python container objects.
Any container type can be used as long as it is in the PyTree registry. `tuple`, `list`, `dict`, `namedtuple`, `OrderedDict`, `defaultdict`
and `deque` are registered by default.

PyTree has two major rules:

1. Any object whose type is not in PyTree registry is considered a leaf.
2. Any object whose type is in PyTree registry is considered a node. 

Let's dive deeper into PyTree.

In [1]:
# noinspection PyProtectedMember
import torch.utils._pytree as pytree

First, let's check the rule for nodes. `list` and `dict` are nodes because their types are in the PyTree registry.

In [2]:
# noinspection PyProtectedMember
print(f'Is list a node? {"Yes" if not pytree._is_leaf(list()) else "No"}.')
# noinspection PyProtectedMember
print(f'Is dict a node? {"Yes" if not pytree._is_leaf(dict()) else "No"}.')

Is list a node? Yes.
Is dict a node? Yes.


Now let's check the rule for leaves. `Foo` is not in the PyTree registry, so an instance of it is a leaf.

In [3]:
class Foo:
    pass

# noinspection PyProtectedMember
print(f'Is Foo a node? {"Yes" if not pytree._is_leaf(Foo()) else "No"}.')

Is Foo a node? No.


Easy. Now let's see how we can leverage PyTree in our code.

In [4]:
numbers = [0,
           1,
           [10, 11, 12, 13],
           (20, 21, 22, 23),
           {'0': 30},
           3]

# We want to add 1 to every number in `numbers`. Doing this by hand is tedious because the children must be visited recursively.
# PyTree does the traversal for us.

leaves, spec = pytree.tree_flatten(numbers)
print(f'leaves: {leaves}')

for i, leaf in enumerate(leaves):
    leaves[i] = leaf + 1

numbers = pytree.tree_unflatten(leaves, spec)
print(f'numbers: {numbers}')

leaves: [0, 1, 10, 11, 12, 13, 20, 21, 22, 23, 30, 3]
numbers: [1, 2, [11, 12, 13, 14], (21, 22, 23, 24), {'0': 31}, 4]
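
For comparison, the manual recursion that `tree_flatten` and `tree_unflatten` replace might look like the sketch below. It is only an illustration: `add_one_manually` is a hypothetical helper, it handles just `list`, `tuple` and `dict`, and it is not how PyTree is implemented.

```python
def add_one_manually(obj):
    """Recursively add 1 to every number nested in lists, tuples, and dicts."""
    if isinstance(obj, list):
        return [add_one_manually(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(add_one_manually(item) for item in obj)
    if isinstance(obj, dict):
        return {key: add_one_manually(value) for key, value in obj.items()}
    # Anything else is treated as a leaf, which is assumed to be a number here.
    return obj + 1


nested = [0, 1, [10, 11], (20, 21), {'0': 30}, 3]
print(add_one_manually(nested))
# [1, 2, [11, 12], (21, 22), {'0': 31}, 4]
```

Every new container type would need another branch in this function, which is the bookkeeping that the PyTree registry takes over.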


By the way, what's `spec`? The type of `spec` is `TreeSpec`. It records the structure of the tree, so it can be used to reconstruct the tree from the leaves.
Let's print it out.

In [5]:
print(f'spec: {spec}')

spec: TreeSpec(list, None, [*,
  *,
  TreeSpec(list, None, [*,
    *,
    *,
    *]),
  TreeSpec(tuple, None, [*,
    *,
    *,
    *]),
  TreeSpec(dict, ['0'], [*]),
  *])


It's hard to read. Let's go through it step by step.

```python
@dataclasses.dataclass
class TreeSpec:
    type: Any
    context: Context
    children_specs: List["TreeSpec"]
```

The snippet above shows the main fields of `TreeSpec`.

```
TreeSpec(list, None, [*,
                      *,
                      *,
                      *])
```

Let's read this with the definition of `TreeSpec` in mind. It represents a node whose type is `list` and which has no context.
The node has four children, each shown as `*`, which marks a leaf. This spec is built from `[10, 11, 12, 13]`. Now the full output is readable.
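
The same fields can also be inspected directly. The sketch below assumes that the `type`, `context` and `children_specs` fields from the definition above are available on `TreeSpec` in your PyTorch version; the input dictionary is made up for illustration.

```python
# noinspection PyProtectedMember
import torch.utils._pytree as pytree

leaves, spec = pytree.tree_flatten({'a': 1, 'b': [2, 3]})

print(spec.type)                    # container type of the root node
print(spec.context)                 # for a dict, the context holds the keys
print(len(spec.children_specs))     # one child spec per dict value
print(spec.children_specs[1].type)  # the nested list is itself a node
print(leaves)                       # the leaves, in traversal order

# The leaves together with the spec are enough to rebuild the original tree.
rebuilt = pytree.tree_unflatten(leaves, spec)
print(rebuilt)
```

This is why the `dict` entry in the printed spec shows `['0']`: the keys are not leaves, so they are kept in the context and restored on unflatten.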

PyTree provides many utility functions. With `tree_map`, adding 1 to every number in `numbers` is even easier.

In [6]:
numbers = pytree.tree_map(lambda x: x + 1, numbers)
print(numbers)

[2, 3, [12, 13, 14, 15], (22, 23, 24, 25), {'0': 32}, 5]


Now it's time to move to the next level. Let's define a custom class as a node in PyTree.

In [7]:
class Child:
    def __init__(self, age):
        self.age = age
        
    def __str__(self):
        return f'Child(age: {self.age})'

class Parent:
    def __init__(self, num_children):
        self.children = [Child(0) for _ in range(num_children)]
        
    def __str__(self):
        return f'Parent([{", ".join(str(child) for child in self.children)})]'
    
family = Parent(3)
print(family)

Parent([Child(age: 0), Child(age: 0), Child(age: 0))]


It looks like `Parent` has a tree structure: `Parent` is the root and each `Child` is a leaf. Let's check.

In [8]:
# noinspection PyProtectedMember
print(f'Is Parent a leaf? {"Yes" if pytree._is_leaf(Parent(3)) else "No"}.')

Is Parent a leaf? Yes.


It's a leaf. Why? Because `Parent` is not in the PyTree registry. We can add `Parent` to the registry using
`register_pytree_node`.

In [9]:
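from typing import Iterable, List, Tuple

def parent_flatten(parent: Parent) -> Tuple[List[Child], pytree.Context]:
    # Return the children to recurse into, plus a context object that is stored in the TreeSpec.
    return parent.children, 'P'

def parent_unflatten(children: Iterable[Child], context: pytree.Context) -> Parent:
    # Rebuild a Parent from its children; the context is not needed here.
    parent = Parent(0)
    parent.children = list(children)
    return parent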

pytree.register_pytree_node(Parent, parent_flatten, parent_unflatten)

# noinspection PyProtectedMember
print(f'Is Parent a leaf? {"Yes" if pytree._is_leaf(Parent(3)) else "No"}.')

Is Parent a leaf? No.


Let's print the result of flattening `family`.

In [10]:
leaves, spec = pytree.tree_flatten(family)
print(f'leaves: {leaves}')
print(f'spec: {spec}')

leaves: [<__main__.Child object at 0x14a5ebf10>, <__main__.Child object at 0x14a5eb7c0>, <__main__.Child object at 0x14a5d3250>]
spec: TreeSpec(Parent, P, [*,
  *,
  *])


Let's use one of PyTree's utility functions.

In [11]:
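def happy_birthday(child: Child):
    # Mutates the child in place and returns None.
    child.age = child.age + 1
    
# The tree returned by tree_map is discarded; we only rely on the in-place updates.
pytree.tree_map(happy_birthday, family)
print(family)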

Parent([Child(age: 1), Child(age: 1), Child(age: 1))]


Enjoy!