In [4]:
# A tree is conceptually very similar to a linked list

In [5]:
class TreeNode:
    """A binary-tree node: a payload plus optional left and right children."""

    def __init__(self, x):
        """Create a leaf node holding ``x``; children are attached later by
        assigning other nodes to ``left`` / ``right``."""
        self.val = x
        # No children yet: both links start out empty.
        self.left = self.right = None

In [6]:
# To visualize the tree, first install Graphviz itself (e.g. 'sudo apt install graphviz'),
# then install the Python package that communicates with Graphviz.

!pip install graphviz



In [7]:
# The following function simply traverses the whole tree, adding nodes and edges to a graph, and then
# uses Graphviz to visualize it. You don't need to understand it.

from graphviz import Digraph

def visualize_tree(tree):
    """Render a binary tree with Graphviz and open the resulting file.

    Walks every node reachable from ``tree`` through ``.left`` / ``.right``,
    adds one graph node per tree node and one labelled edge ('L' or 'R') per
    parent-child link, then renders the graph to a file named 'tree' and opens
    it in the default viewer.

    The root is drawn as a plain circle, left children as light-blue circles
    and right children as pink double circles.

    Args:
        tree: Root node exposing ``val``, ``left`` and ``right`` attributes,
            or None for an empty tree.

    Returns:
        None. When ``tree`` is None a message is printed and nothing is
        rendered.
    """
    if tree is None:
        print("Nothing in the tree")
        return

    # Node styling per side, selected by the edge label rather than by
    # comparing the child against tree.left / tree.right.
    child_styles = {
        'L': dict(shape='circle', style='filled', fillcolor='lightblue'),
        'R': dict(shape='doublecircle', style='filled', fillcolor='pink'),
    }

    def node_name(node):
        # id() is unique for every live object, so graph-node names never
        # collide and do not depend on how the node class prints itself
        # (a custom __str__/__repr__ could produce duplicates, or a ':' that
        # Graphviz edges interpret as a port separator).
        return str(id(node))

    def add_nodes_edges(node, dot):
        # Pre-order walk: emit each existing child and its edge, then descend.
        for child, label in ((node.left, 'L'), (node.right, 'R')):
            if child is None:
                continue
            # 'node' attribute statements apply to nodes declared after them,
            # so the style is set right before the child is added.
            dot.attr('node', **child_styles[label])
            dot.node(name=node_name(child), label=str(child.val))
            dot.edge(node_name(node), node_name(child), label=label)
            add_nodes_edges(child, dot)

    dot = Digraph()
    dot.attr('node', shape='circle')
    dot.node(name=node_name(tree), label=str(tree.val))
    add_nodes_edges(tree, dot)

    dot.render('tree', view=True)  # Saves and opens the file
    # NOTE(review): inside a notebook, display(dot) may be preferable for
    # showing the graph inline instead of opening an external viewer.


In [8]:
# Build a small sample tree by hand, starting from a single root node.
t1 = TreeNode(1)

# Children are attached by assigning new nodes to .left / .right.
t1.left = TreeNode(3)
t1.right = TreeNode(2)

# Resulting shape:
#       1
#      / \
#     3   2
visualize_tree(t1)           # this function needs a bit of extra work

In [9]:
# Grow the sample tree: node 3 gets two children, and a right-leaning
# branch is hung under node 2. Resulting shape:
#
#            1
#          /   \
#         3     2
#        / \     \
#       5  15    55
#               /  \
#             87   101
#                  /  \
#               200   1019
t1.left.left = TreeNode(5)
t1.left.right = TreeNode(15)

t1.right.right = TreeNode(55)
t1.right.right.left = TreeNode(87)
t1.right.right.right = TreeNode(101)
t1.right.right.right.right = TreeNode(1019)
t1.right.right.right.left = TreeNode(200)

visualize_tree(t1)

In [10]:
# Tree traversal - Depth first search

In [11]:
# when we say depth first search, it usually means pre-order traversal: root, left, right

def dfs(self):
    """Depth-first, pre-order walk: print this node's value, then walk the
    left subtree, then the right subtree.

    Meant to be attached to the node class, since it recurses through the
    children's own ``dfs`` method.
    """
    print(self.val)

    # Left is visited before right; a missing (falsy) child is skipped.
    for side in ("left", "right"):
        child = getattr(self, side)
        if child:
            child.dfs()

# Attach the function to the class so every TreeNode can call it as node.dfs().
TreeNode.dfs = dfs

In [12]:
# Pre-order walk of the sample tree, starting at the root.
t1.dfs()

1
3
5
15
2
55
87
101
200
1019


In [13]:
# in-order traversal
def dfs_inorder(self):
    """In-order walk: left subtree first, then this node's value, then the
    right subtree.

    Meant to be attached to the node class, since it recurses through the
    children's own ``dfs_inorder`` method.
    """
    left_child = self.left
    if left_child:
        left_child.dfs_inorder()

    # The node itself is printed between its two subtrees.
    print(self.val)

    right_child = self.right
    if right_child:
        right_child.dfs_inorder()

# Attach the function to the class so every TreeNode can call it as node.dfs_inorder().
TreeNode.dfs_inorder = dfs_inorder

In [14]:
# In-order walk of the sample tree, starting at the root.
t1.dfs_inorder()

5
3
15
1
2
87
55
200
101
1019


In [15]:
# post-order traversal
def dfs_postorder(self):
    """Post-order walk: left subtree, then right subtree, and only then this
    node's value.

    Meant to be attached to the node class, since it recurses through the
    children's own ``dfs_postorder`` method.
    """
    # Both subtrees are finished (left first) before the node is printed.
    for side in ("left", "right"):
        child = getattr(self, side)
        if child:
            child.dfs_postorder()

    print(self.val)

# Attach the function to the class so every TreeNode can call it as node.dfs_postorder().
TreeNode.dfs_postorder = dfs_postorder

In [16]:
# Post-order walk of the sample tree, starting at the root.
t1.dfs_postorder()

5
15
3
87
200
1019
101
55
2
1


In [17]:
# Draw the sample tree again for reference.
visualize_tree(t1)

In [19]:
# tree traversal - breadth first search

def bfs(self):
    """Breadth-first (level-order) walk: print node values level by level,
    left to right, starting from this node.

    Uses a FIFO queue of nodes still to visit. A ``deque`` is used instead of
    a list because ``list.pop(0)`` shifts every remaining element on each
    dequeue, which makes the whole walk quadratic in the number of nodes.
    """
    from collections import deque  # local import keeps this cell self-contained

    to_visit = deque([self])

    while to_visit:
        current = to_visit.popleft()       # get the first one out, in O(1)
        print(current.val)                 # perform any operation here on all nodes
        # Children join the back of the queue, so they are visited only after
        # everything already waiting on the current level.
        if current.left:
            to_visit.append(current.left)
        if current.right:
            to_visit.append(current.right)

# Attach the function to the class so every TreeNode can call it as node.bfs().
TreeNode.bfs = bfs

In [20]:
# Breadth-first walk of the sample tree, starting at the root.
t1.bfs()

1
3
2
5
15
55
87
101
200
1019


In [21]:
# perform arbitrary functions on all nodes
def dfs_apply(self, fn):
    """Call ``fn(node)`` on every node of the subtree rooted here, in
    pre-order (this node, then the left subtree, then the right subtree).

    Args:
        fn: Callable invoked with each node as its only argument; its return
            value is ignored.

    Meant to be attached to the node class, since it recurses through the
    children's own ``dfs_apply`` method.
    """
    fn(self)

    # Each child link is read only when its turn comes, left before right;
    # a missing (falsy) child is skipped.
    for side in ("left", "right"):
        subtree = getattr(self, side)
        if subtree:
            subtree.dfs_apply(fn)

# Attach the function to the class so every TreeNode can call it as node.dfs_apply(fn).
TreeNode.dfs_apply = dfs_apply

In [25]:
# arbitrary function (hook)
class PerformSum:
    """Accumulator hook: keeps a running total of the ``val`` of every node
    passed to ``process``."""

    def __init__(self):
        # Start from an empty total.
        self.reset_sum()

    def reset_sum(self):
        """Set the running total back to zero."""
        self.sum = 0

    def process(self, node):
        """Add ``node.val`` to the running total."""
        self.sum += node.val

    def get_sum(self):
        """Return the total accumulated since creation or the last reset."""
        return self.sum

# One accumulator instance; its total persists between uses until reset_sum() is called.
p = PerformSum()

In [26]:
# Pass the accumulator's bound method as the hook so it sees every node,
# then read back the total.
t1.dfs_apply(p.process)
print(p.get_sum())

1488


In [28]:
# The accumulator keeps its total between runs, so traversing the tree a second
# time without resetting would add every node again on top of the previous sum.
# Reset it first so the total covers exactly one pass over the tree.
p.reset_sum()
t1.dfs_apply(p.process)
print(p.get_sum())

1488
