In [37]:
class TreeNode:
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, val=0, left=None, right=None):
        # children default to None, i.e. the node is a leaf
        self.val, self.left, self.right = val, left, right

In [38]:
def initTree():
    r"""Build and return the root of the sample perfect tree

            1
          /   \
         2     3
        / \   / \
       4   5 6   7
    """
    return TreeNode(
        1,
        TreeNode(2, TreeNode(4), TreeNode(5)),
        TreeNode(3, TreeNode(6), TreeNode(7)),
    )

In [39]:
def printTree(root=None):
    """Print node values in pre-order (node, left, right) on a single line.

    Every value is followed by one space; no newline is written. An empty
    tree prints nothing.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if not node:
            continue
        print(node.val, end=" ")
        # push right before left so the left subtree is printed first (LIFO)
        pending.append(node.right)
        pending.append(node.left)

In [40]:
# build the sample tree and print it in pre-order
t = initTree()
printTree(t)

1 2 4 5 3 6 7 

## 226. Invert Binary Tree

In [41]:
def invertTree(root):
    """Mirror the tree in place by swapping the children of every node.

    Returns the same ``root`` object, or None for an empty tree.
    """
    if not root:
        return None
    # the right-hand side (both recursive inversions) is evaluated fully
    # before either child link is reassigned
    root.left, root.right = invertTree(root.right), invertTree(root.left)
    return root

In [42]:
# fresh sample tree, printed before inversion for comparison
root = initTree()
printTree(root)

1 2 4 5 3 6 7 

In [43]:
# invertTree mutates `root` in place (its return value is not displayed), so print again to see the mirror
invertTree(root)
printTree(root)

1 3 7 6 2 5 4 

## 104. Maximum Depth of Binary Tree

In [44]:
def maxDepth(root):
    """Return the number of nodes on the longest root-to-leaf path (0 for an empty tree)."""
    if not root:
        return 0
    return 1 + max(maxDepth(root.left), maxDepth(root.right))

In [45]:
# the sample tree has three levels
root = initTree()
maxDepth(root)

3

## 110. Balanced Binary Tree

In [46]:
# recursive solution using maxDepth, O(n^2)
def isBalanced(root):
    """Return True if, at every node, the subtree depths differ by at most 1.

    Depths are recomputed with ``maxDepth`` at each node, hence the quadratic cost.
    """
    if not root:
        return True
    if abs(maxDepth(root.left) - maxDepth(root.right)) > 1:
        return False
    return isBalanced(root.left) and isBalanced(root.right)

In [47]:
# fresh copy of the sample tree; it is perfect, so it is balanced
root = initTree()
isBalanced(root)

True

In [48]:
# calculate height using dfs (bottom up)
# if tree is not balanced, return negative int
# otherwise return maximum height
def height(root):
    """Return the height of ``root``'s subtree, or -1 if any node in it is unbalanced.

    A node is balanced when its two subtree heights differ by at most 1.
    An empty tree has height 0.
    """
    if not root:
        return 0
    lh = height(root.left)
    if lh < 0:
        # left side already unbalanced: the answer is -1 regardless of the right side
        return -1
    rh = height(root.right)
    if rh < 0 or abs(lh - rh) > 1:
        return -1
    return max(lh, rh) + 1

In [49]:
# recursive dfs height, O(N) complexity
def isBalanced(root):
    """Return True if the tree is height-balanced, via a single bottom-up ``height`` pass."""
    UNBALANCED = -1  # sentinel that height() returns for an unbalanced subtree
    return height(root) != UNBALANCED

In [50]:
# reuses `root` from an earlier cell; this calls the O(N) isBalanced defined just above
isBalanced(root)

True

## 100. Same Tree

In [51]:
def isSameTree(p, q):
    """Return True if trees ``p`` and ``q`` have the same shape and the same values."""
    if not p or not q:
        # at least one side is empty: they match only if both are empty
        return not p and not q
    return (p.val == q.val
            and isSameTree(p.left, q.left)
            and isSameTree(p.right, q.right))

In [52]:
# two separately built, structurally identical trees
p, q = initTree(), initTree()
isSameTree(p, q)

True

## 572. Subtree of Another Tree

In [53]:
# solution using the previous isSameTree function
def isSubtree(root, subRoot):
    """Return True if some node of ``root`` starts a subtree identical to ``subRoot``.

    Two empty trees match; if exactly one of the arguments is empty the answer is False.
    """
    if not root or not subRoot:
        return not root and not subRoot
    return (isSameTree(root, subRoot)
            or isSubtree(root.left, subRoot)
            or isSubtree(root.right, subRoot))

In [54]:
# a tree is a subtree of an identical copy of itself
p, q = initTree(), initTree()
isSubtree(p, q)

True

## 235. Lowest Common Ancestor of a Binary Search Tree

In [55]:
def lowestCommonAncestor(root, p, q):
    """Return the lowest common ancestor (LCA) of nodes ``p`` and ``q`` under ``root``.

    Nodes are matched by identity. This is a generic post-order DFS that does not
    use BST ordering, so it works on any binary tree. A node counts as its own
    ancestor: if ``p`` is an ancestor of ``q`` the answer is ``p``, and if ``p`` and
    ``q`` are the same node the answer is that node.

    If the targets are not both found, ``root`` itself is returned.
    """
    res = root

    def dfs(node):
        # Returns True if p or q lies in the subtree rooted at ``node``, and records
        # in ``res`` the deepest node at which both targets are accounted for.
        nonlocal res
        if not node:
            return False
        mid = node is p or node is q
        left = dfs(node.left)
        right = dfs(node.right)
        # both targets meet here: one in each subtree, this node plus one subtree,
        # or p and q are this very node (previously missed: it returned ``root``)
        if (left and right) or (mid and (left or right or p is q)):
            res = node
        return mid or left or right

    dfs(root)
    return res

In [56]:
# Build p and q before lca: in a single tuple assignment the right-hand side is
# evaluated before any name is bound, so TreeNode(0, p, q) would wrap the stale
# p/q from an earlier cell and the check would only pass via the fallback.
p, q = initTree(), initTree()
lca = TreeNode(0, p, q)
lowestCommonAncestor(lca, p, q) is lca

True

## 102. Binary Tree Level Order Traversal

In [57]:
# dfs traversal, bucketing node values by depth
def solution():
    """Return a ``levelOrder(root)`` function implemented with a depth-first traversal.

    Each call builds its result from scratch. The previous version kept a shared
    level counter that was left at -1 after ``levelOrder(None)``, which broke every
    later call.
    """
    def levelOrder(root):
        """Return node values grouped by depth (top level first, left to right); ``[]`` if empty."""
        levels = []

        def visit(node, depth):
            if not node:
                return
            # pre-order visits the first node of a new depth before any deeper node,
            # so depth == len(levels) exactly when a new level begins
            if depth == len(levels):
                levels.append([])
            levels[depth].append(node.val)
            visit(node.left, depth + 1)
            visit(node.right, depth + 1)

        visit(root, 0)
        return levels

    return levelOrder

In [58]:
# solution() returns a levelOrder closure; bind it to a top-level name and run it on the sample tree
levelOrder = solution()
levelOrder(initTree())

[[1], [2, 3], [4, 5, 6, 7]]

In [59]:
# bfs, level by level: each level list is built from the previous one, O(N)
def levelOrder(root):
    """Return node values grouped by depth, top level first, each level left to right.

    An empty tree yields ``[]``. The old version used ``list.pop(0)``, which is O(n)
    per pop and made the traversal quadratic; here each level is built once.
    """
    res = []
    level = [root] if root else []
    while level:
        res.append([node.val for node in level])
        # next level: children of the current level, left before right, skipping None
        level = [child for node in level for child in (node.left, node.right) if child]
    return res

In [60]:
# BFS version: this levelOrder is the function defined above, replacing the earlier DFS closure binding
levelOrder(initTree())

[[1], [2, 3], [4, 5, 6, 7]]

## 98. Validate Binary Search Tree

In [61]:
# dfs solution with dynamic boundaries, O(N)
def solution():
    """Return an ``isValidBST(root)`` function that checks the strict BST property."""

    def in_bounds(node, low, high):
        """True if every value in ``node``'s subtree lies strictly inside (low, high).

        ``None`` for a bound means "unbounded" on that side.
        """
        if not node:
            return True
        if low is not None and not (node.val > low):
            return False
        if high is not None and not (node.val < high):
            return False
        # the left subtree is capped by this value, the right subtree is floored by it
        return in_bounds(node.left, low, node.val) and in_bounds(node.right, node.val, high)

    def isValidBST(root):
        """Return True if the tree rooted at ``root`` is a strict binary search tree."""
        return in_bounds(root, None, None)

    return isValidBST

In [62]:
isValidBST = solution()

# inValid: the sample tree, whose left child (2) is greater than the root (1)
inValid = initTree()
# valid: root 1 with its subtrees replaced by two leaves, 0 (left) and 2 (right)
valid = initTree()
valid.left = TreeNode(valid.val - 1)
valid.right = TreeNode(valid.val + 1)

print(isValidBST(valid), isValidBST(inValid))

True False


## 1448. Count Good Nodes in Binary Tree

In [63]:
# dfs passing down the greatest value seen on the root-to-node path, O(N)
def solution():
    """Return a ``goodNodes(root)`` function.

    A node is "good" when no node on the path from the root to it has a greater
    value. Each call is independent. The previous version kept ``total`` and
    ``greatest`` in shared closure state, so a second call returned accumulated
    or wrong counts.
    """
    def goodNodes(root):
        """Return the number of good nodes in the tree rooted at ``root`` (0 if empty)."""

        def count(node, greatest):
            # ``greatest`` is the maximum value on the path above ``node`` (None at the root)
            if not node:
                return 0
            good = greatest is None or node.val >= greatest
            if good:
                greatest = node.val
            return int(good) + count(node.left, greatest) + count(node.right, greatest)

        return count(root, None)

    return goodNodes

In [64]:
# every node in the sample tree is larger than all of its ancestors, so all 7 nodes are good
goodNodes = solution()
goodNodes(initTree()) == 7

True

## 230. Kth Smallest Element in a BST

In [70]:
# dfs, collect all values, sort and index: O(N log N) time, O(N) space
def solution():
    """Return a ``kthSmallest(root, k)`` function."""

    def kthSmallest(root, k):
        """Return the k-th smallest value (1-indexed) in the tree rooted at ``root``.

        All values are collected and sorted, so BST ordering is not relied upon and
        any binary tree works. Returns None for an empty tree.

        Raises:
            IndexError: if ``k`` is not in ``1..number_of_nodes``. Previously a k that
                was too large left stale shared state, so every later call returned
                None, and ``k <= 0`` silently indexed from the end.
        """
        if not root:
            return None
        values = []
        stack = [root]
        while stack:
            node = stack.pop()
            if node:
                values.append(node.val)
                stack.append(node.right)
                stack.append(node.left)
        if not 1 <= k <= len(values):
            raise IndexError(f"k={k} out of range for a tree with {len(values)} nodes")
        return sorted(values)[k - 1]

    return kthSmallest

In [72]:
# the sample tree is not a BST; this kthSmallest sorts all values, so k=1 still yields the minimum (1)
kthSmallest = solution()
kthSmallest(initTree(), 1) == 1

True