In [3]:
from typing import List

# How to construct a Binary Search Tree 
- from a list
- Input: root = [3,9,20,null,null,15,7]

- https://stackoverflow.com/questions/43097045/best-way-to-construct-a-binary-tree-from-a-list-in-python
    - gs: how to construct a binary tree python from list

In [1]:
input_list = [3,9,20,None,None,15,7]

In [8]:
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
    def __repr__(self):
        return f'node:{self.val}, left: {self.left.val}, right: {self.right.val}'
    
def construct_tree(node_list:List)-> Node:
    ## root index will be at 0
    def inner(index=0):
        if index > len(node_list) or node_list[index] is None:
            return None
        node_val = node_list[index]
        left_node_val = node_list[index*2 + 1] 
        right_node_val = node_list[index*2 + 2]
        root = Node(val=node_val, left=left_node_val, right = right_node_val)
        inner(index*2 +1)
        inner(index*2 +2)
        return root
        # do the same for the other nodes
        
    return inner()

construct_tree(input_list)

IndexError: list index out of range

In [8]:
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
    def __repr__(self):
        return f'node:{self.val}, left: {self.left}, right: {self.right}'
    
    def printTree_helper(self, node, level=0):
        if node != None:
            self.printTree_helper(node.left, level + 1)
            print(f"{' ' * 4 * level} -> {node.value:02}")
            self.printTree_helper(node.right, level + 1)
        
    def printTree(self):
        return self.printTree_helper(self, 0)
    
def construct_tree(node_list:List)-> Node:
    ## root index will be at 0
    def inner(index=0):
        if index > len(node_list) or node_list[index] is None:
            return None
        node_val = node_list[index]
        left_node_val = node_list[index*2 + 1] if (index *2 +1) < len(node_list) else None
        right_node_val = node_list[index*2 + 2] if (index *2 +2) < len(node_list) else None
        root = Node(val=node_val, left=left_node_val, right = right_node_val)
        inner(index*2 +1)
        inner(index*2 +2)
        return root
        # do the same for the other nodes
        
    return inner()

n = construct_tree(input_list)

In [10]:
n.printTree() 

AttributeError: 'int' object has no attribute 'left'

comments:
- The method above will not work because it is iteratively trying to create the Node on the fly.
- There is no way to reference the left and right node, since the input is an integer. Your input needs to also be a Node
- This means that we dont have a simple method to check that the index is out of range

In [15]:
def construct_tree(input_list: List) -> Node:
    def inner(index):
        if index > len(input_list) or input_list[index] is None:
            return None
        root = Node(input_list[index])
        root.left = inner(index*2+1)
        root.right = inner(index*2+2)
        return root
    return inner(0)
construct_tree(input_list)

node:3, left: 9, right: 20

- We can see that this simplistic way of printing the tree is not good.
- A better way is the following

In [53]:
class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
        
    def __repr__(self):
        return f'node:{self.value}, left: {self.left}, right: {self.right}'
    
    def printTree_helper(self, node, level=0):
        if node != None:
            self.printTree_helper(node.left, level + 1)
            print(f"{' ' * 4 * level} -> {node.value:02}")
            self.printTree_helper(node.right, level + 1)
        
    def printTree(self):
        return self.printTree_helper(self, 0)
    
    def __str__(self) -> str:
        return str(self.value)
    
def construct_tree(input_list: List) -> Node:
    def inner(index):
        if index > len(input_list) or input_list[index] is None:
            return None
        root = Node(input_list[index])
        root.left = inner(index*2+1)
        root.right = inner(index*2+2)
        return root
    return inner(0)

root = construct_tree(input_list)

### Time complexity
- The time complexity of this algorithm is O(N), we make N/2 recursive inner() calls because inner() will only be called when the index is in range of our input list(constructing our root.left and root.right accesses 3 elements per inner function call) 

In [44]:
root.printTree()

     -> 09
 -> 03
         -> 15
     -> 20
         -> 07


# How to determine if a tree is balanced
- https://leetcode.com/problems/balanced-binary-tree/description/

In [67]:
# find depth of a node
def depth(node):
    if node is None:
        return 0
    return max(depth(node.left), depth(node.right)) + 1
    
def isBalanced(node):
    if not node:
        return True
    left_depth = depth(node.left)
    right_depth = depth(node.right)
    return (abs(left_depth - right_depth) <= 1) & isBalanced(node.left)\
            & isBalanced(node.right)

isBalanced(root)

True

# Time complexity
- Time complexity is O(N^2). This is because make N isBalanced() function call since we check if the root, left and right node is balanced. 
- In each isBalanced() function call, we get the depth of its left and right node which is O(N) because traversing both left and right means that we are traversing the entire node.
    - TLDR: each isBalanced() call is O(N)
-  If we call isBalanced() on each of the N elements, time complexity will be O(N^2)

# Better solution
- A better solution will be to make use of the definition of binary search tree
- We can just make use of the DFS method

In [54]:
root

node:3, left: 9, right: 20

In [55]:
root.left

node:9, left: None, right: None

In [56]:
root.right

node:20, left: 15, right: 7

In [58]:
root.right

node:20, left: 15, right: 7

In [63]:
# def dfs(node, level):
#     if node is None:
#         return 0
    
#     left_depth = dfs(node.left, level+1)
#     right_depth = dfs(node.right, level+1)
    
#     print(f'left_depth: {left_depth}')
#     print(f'right_depth: {right_depth}')
    
#     if (left_depth <0) or (right_depth < 0):
#         return -1
    
#     if abs(left_depth - right_depth) > 1:
#         return -1
    
#     return max(left_depth, right_depth) + 1

def dfs(node):
    if (node.left is None) or (node.right is None):
        return 0
    
    left_depth = dfs(node.left)
    right_depth = dfs(node.right)
    
    print(f'left_depth: {left_depth}')
    print(f'right_depth: {right_depth}')
    
    if (left_depth <0) or (right_depth < 0):
        return -1
    
    if abs(left_depth - right_depth) > 1:
        return -1
    
    return max(left_depth, right_depth) + 1

root = construct_tree(input_list)

dfs(root)

left_depth: 0
right_depth: 0
left_depth: 0
right_depth: 1


2

# Time complexity
- In this more efficient version, time complexity is O(N). This is because we are only calling depth() for each node Once