# Binary Tree - DFS

## 1) Maximum Depth of Binary Tree

Given the root of a binary tree, return its maximum depth.

A binary tree's maximum depth is the number of nodes along the longest path from the root node down to the farthest leaf node.

<b>Example</b>

Input: root = [3, 9, 20, null, null, 15, 7] <br />
Output: 3

<b>Example</b>

Input: root = [1, null, 2] <br />
Output: 2

In [1]:
from typing import Optional

In [2]:
class TreeNode:
    """Binary-tree node: a value plus optional left and right children."""

    def __init__(self, val = 0, left = None, right = None):
        # Attributes are assigned in the order val, left, right.
        self.val, self.left, self.right = val, left, right

In [3]:
# Recursive Approach

def maxDepth(root: Optional[TreeNode]) -> int:
    """Return the number of nodes on the longest root-to-leaf path.

    An empty tree (``None``) has depth 0; otherwise the depth is one more
    than the deeper of the two subtrees.
    """
    return 0 if not root else 1 + max(maxDepth(root.left), maxDepth(root.right))

In [4]:
# Iterative Approach

def maxDepth_iter(root: Optional[TreeNode]) -> int:
    """Return the maximum depth of the tree using an explicit DFS stack.

    Each stack entry pairs a depth with a (possibly ``None``) node; the
    largest depth at which a real node is popped is the answer.
    """
    deepest = 0
    pending = [(1, root)] if root else []

    while pending:
        level, node = pending.pop()
        if not node:
            # Empty child slots are pushed too; they contribute nothing.
            continue
        deepest = max(level, deepest)
        pending.extend(((level + 1, node.left), (level + 1, node.right)))

    return deepest

In [5]:
# [3, 9, 20, null, null, 15, 7]
root_1 = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))

maxDepth(root_1)

3

In [6]:
# [1, null, 2]
root_2 = TreeNode(1, right=TreeNode(2))

maxDepth(root_2)

2

In [7]:
# [3, 9, 20, null, null, 15, 7]
root_1 = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))

maxDepth_iter(root_1)

3

In [8]:
# [1, null, 2]
root_2 = TreeNode(1, right=TreeNode(2))

maxDepth_iter(root_2)

2

## 2) Leaf-Similar Trees

Consider all the leaves of a binary tree. Read from left to right, the values of those leaves form a leaf value sequence.

Two binary trees are considered leaf-similar if their leaf value sequence is the same.

Return true if and only if the two given trees with head nodes root1 and root2 are leaf-similar.

<b>Example</b>

Input: root1 = [3, 5, 1, 6, 2, 9, 8, null, null, 7, 4], root2 = [3, 5, 1, 6, 7, 4, 2, null, null, null, null, null, null, 9, 8] <br />
Output: true

<b>Example</b>

Input: root1 = [1, 2, 3], root2 = [1, 3, 2] <br />
Output: false

In [9]:
from typing import Optional

In [10]:
class TreeNode:
    """Node of a binary tree holding ``val`` and links to two subtrees."""

    def __init__(self, val = 0, left = None, right = None):
        self.val, self.left, self.right = val, left, right

In [11]:
def leafSimilar(root1: "Optional[TreeNode]", root2: "Optional[TreeNode]") -> bool:
    """Return True if both trees have the same left-to-right leaf value sequence.

    Uses an iterative DFS per tree. An empty tree (``None``) has an empty leaf
    sequence, so two empty trees are leaf-similar.
    """

    def leaf_sequence(root):
        # Collect leaf values in left-to-right order with an explicit stack.
        leaves = []
        stack = [root]
        while stack:
            node = stack.pop()
            if node is None:
                continue
            if node.left is None and node.right is None:
                leaves.append(node.val)
            # Push right before left so the left subtree is popped (visited) first;
            # this yields leaves already in left-to-right order, no reversal needed.
            stack.append(node.right)
            stack.append(node.left)
        return leaves

    return leaf_sequence(root1) == leaf_sequence(root2)

In [12]:
def leafSimilar_rec(root1: Optional[TreeNode], root2: Optional[TreeNode]) -> bool:
    """Compare the left-to-right leaf sequences of two trees (recursive generator)."""

    def leaves(node):
        # Lazily yield leaf values, left subtree before right subtree.
        if not node:
            return
        if not (node.left or node.right):
            yield node.val
        yield from leaves(node.left)
        yield from leaves(node.right)

    first = list(leaves(root1))
    second = list(leaves(root2))
    return first == second

In [13]:
def leafSimilar_rec(root1: "Optional[TreeNode]", root2: "Optional[TreeNode]") -> bool:
    """Return True if both trees have the same left-to-right leaf sequence (recursive).

    Leaf values are appended to one accumulator list per tree, rather than
    concatenating per-subtree lists at every level. This keeps the work linear
    in the number of nodes.
    """

    def collect(node, out):
        # Pre-order walk that appends leaf values to ``out`` from left to right.
        if node is None:
            return out
        if node.left is None and node.right is None:
            out.append(node.val)
        else:
            collect(node.left, out)
            collect(node.right, out)
        return out

    return collect(root1, []) == collect(root2, [])

In [14]:
# [3, 5, 1, 6, 2, 9, 8, null, null, 7, 4]
root_1 = TreeNode(3,
                  TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                  TreeNode(1, TreeNode(9), TreeNode(8)))

# [3, 5, 1, 6, 7, 4, 2, null, null, null, null, null, null, 9, 8]
root_2 = TreeNode(3,
                  TreeNode(5, TreeNode(6), TreeNode(7)),
                  TreeNode(1, TreeNode(4), TreeNode(2, TreeNode(9), TreeNode(8))))

leafSimilar(root_1, root_2)

True

In [15]:
root_1 = TreeNode(1, TreeNode(2), TreeNode(3))
root_2 = TreeNode(1, TreeNode(3), TreeNode(2))

leafSimilar(root_1, root_2)

False

In [16]:
# [3, 5, 1, 6, 2, 9, 8, null, null, 7, 4]
root_1 = TreeNode(3,
                  TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                  TreeNode(1, TreeNode(9), TreeNode(8)))

# [3, 5, 1, 6, 7, 4, 2, null, null, null, null, null, null, 9, 8]
root_2 = TreeNode(3,
                  TreeNode(5, TreeNode(6), TreeNode(7)),
                  TreeNode(1, TreeNode(4), TreeNode(2, TreeNode(9), TreeNode(8))))

leafSimilar_rec(root_1, root_2)

True

In [17]:
root_1 = TreeNode(1, TreeNode(2), TreeNode(3))
root_2 = TreeNode(1, TreeNode(3), TreeNode(2))

leafSimilar_rec(root_1, root_2)

False

## 3) Count Good Nodes in Binary Tree

Given a binary tree root, a node X in the tree is called good if the path from the root to X contains no node with a value greater than X's value.

Return the number of good nodes in the binary tree.

<b>Example</b>

Input: root = [3, 1, 4, 3, null, 1, 5] <br />
Output: 4

Explanation: The good nodes are the root 3, the 4, the 5, and the 3 that is the left child of 1. <br />
Root Node (3) is always a good node. <br />
Node 4 -> (3,4) is the maximum value in the path starting from the root. <br />
Node 5 -> (3,4,5) is the maximum value in the path. <br />
Node 3 -> (3,1,3) is the maximum value in the path.

<b>Example</b>

Input: root = [3, 3, null, 4, 2] <br />
Output: 3

Explanation: 
Node 2 -> (3, 3, 2) is not good, because 3 is greater than 2.

<b>Example</b>

Input: root = [1] <br />
Output: 1

Explanation: Root is considered as good.

In [18]:
from typing import Optional

In [19]:
class TreeNode:
    """A binary-tree node; ``left``/``right`` are child nodes or ``None``."""

    def __init__(self, val = 0, left = None, right = None):
        self.val, self.left, self.right = val, left, right

In [20]:
def goodNodes_iter(root: "Optional[TreeNode]") -> int:
    """Count good nodes iteratively.

    A node is good when its value is >= every value on the path from the root
    to it. The DFS stack carries, for each node, the maximum value among its
    ancestors.

    Returns:
        The number of good nodes; 0 for an empty tree (``None``).
    """
    if root is None:
        return 0

    stack = [(float('-inf'), root)]
    good_nodes_num = 0
    while stack:
        max_so_far, node = stack.pop()
        if max_so_far <= node.val:
            good_nodes_num += 1
        # The path maximum seen by either child includes this node's value.
        child_max = max(max_so_far, node.val)
        if node.left:
            stack.append((child_max, node.left))
        if node.right:
            stack.append((child_max, node.right))

    return good_nodes_num

In [21]:
# [3, 1, 4, 3, null, 1, 5]
root_1 = TreeNode(3, TreeNode(1, TreeNode(3)), TreeNode(4, TreeNode(1), TreeNode(5)))

goodNodes_iter(root_1)

4

In [22]:
# [3, 3, null, 4, 2]
root_1 = TreeNode(3, TreeNode(3, TreeNode(4), TreeNode(2)))

goodNodes_iter(root_1)

3

In [23]:
root_1 = TreeNode(val=1)

goodNodes_iter(root_1)

1

In [24]:
def goodNodes(root: "Optional[TreeNode]") -> int:
    """Count good nodes recursively.

    A node is good when its value is >= every value on the path from the root
    to it.

    Returns:
        The number of good nodes; 0 for an empty tree (``None``).
    """

    def count_from(node, max_so_far):
        # Good nodes in the subtree at ``node``, given the largest ancestor value.
        if node is None:
            return 0
        here = 1 if max_so_far <= node.val else 0
        child_max = max(node.val, max_so_far)
        return here + count_from(node.right, child_max) + count_from(node.left, child_max)

    return count_from(root, float("-inf"))

In [25]:
# [3, 1, 4, 3, null, 1, 5]
root_1 = TreeNode(3, TreeNode(1, TreeNode(3)), TreeNode(4, TreeNode(1), TreeNode(5)))

goodNodes(root_1)

4

In [26]:
# [3, 3, null, 4, 2]
root_1 = TreeNode(3)
root_1.left = TreeNode(3)
root_1.left.left = TreeNode(4)
root_1.left.right = TreeNode(2)

# Check the recursive implementation on this example as well.
goodNodes(root_1)

3

In [27]:
root_1 = TreeNode(1)

# Check the recursive implementation on this example as well.
goodNodes(root_1)

1

## 4) Path Sum III

Given the root of a binary tree and an integer targetSum, return the number of paths where the sum of the values along the path equals targetSum.

The path does not need to start or end at the root or a leaf, but it must go downwards (i.e., traveling only from parent nodes to child nodes).

<b>Example</b>

Input: root = [10, 5, -3, 3, 2, null, 11, 3, -2, null, 1], targetSum = 8 <br />
Output: 3 <br />

Explanation: The paths that sum to 8 are 5 -> 3, 5 -> 2 -> 1, and -3 -> 11.

<b>Example</b>

Input: root = [5, 4, 8, 11, null, 13, 4, 7, 2, null, null, 5, 1], targetSum = 22 <br />
Output: 3

In [28]:
from typing import Optional
from collections import defaultdict

In [29]:
class TreeNode:
    """Binary-tree node storing a value and its two (optional) children."""

    def __init__(self, val = 0, left = None, right = None):
        self.val, self.left, self.right = val, left, right

In [30]:
def pathSum(root: Optional[TreeNode], targetSum: int) -> int:
    """Return how many downward paths have node values summing to ``targetSum``.

    Pre-order walk carrying the running (prefix) sum from the root. A path
    that ends at the current node sums to ``targetSum`` in two cases:
    - the running sum itself equals ``targetSum`` (the path starts at the root);
    - an earlier prefix on the same root-to-node path equals
      ``running - targetSum``.
    """
    seen = defaultdict(int)  # prefix sum -> occurrences on the active path

    def visit(node, running):
        # Number of qualifying paths that end somewhere in this subtree.
        if not node:
            return 0
        running += node.val
        hits = seen[running - targetSum]
        if running == targetSum:
            hits += 1
        # Expose this prefix to the descendants only...
        seen[running] += 1
        hits += visit(node.left, running) + visit(node.right, running)
        # ...and withdraw it before a sibling subtree is explored.
        seen[running] -= 1
        return hits

    return visit(root, 0)

In [31]:
# Same Solution (probably easier to understand)

def pathSum(root: Optional[TreeNode], targetSum: int) -> int:
    """Count downward paths whose values sum to ``targetSum`` (prefix-sum DFS).

    ``prefix_counts`` maps each prefix sum on the current root-to-node path to
    its number of occurrences. The seed entry ``{0: 1}`` lets paths that start
    at the root be counted like any other.
    """
    prefix_counts = defaultdict(int)
    prefix_counts[0] = 1

    def dfs(root, total):
        if not root:
            return 0
        total += root.val
        # Every earlier prefix equal to total - targetSum closes a valid path here.
        found = prefix_counts[total - targetSum]
        prefix_counts[total] += 1
        found += dfs(root.left, total)
        found += dfs(root.right, total)
        # Backtrack so sibling subtrees don't see this node's prefix.
        prefix_counts[total] -= 1
        return found

    return dfs(root, 0)

In [32]:
def pathSum_iter(root: Optional[TreeNode], targetSum: int) -> int:
    """Count downward paths summing to ``targetSum`` using an explicit DFS stack.

    Iterative counterpart of the recursive prefix-sum solution. Instead of
    backtracking on return, each stacked node carries its depth. Before the
    node is processed, the prefix-sum state is unwound until it covers only
    the node's ancestors.

    Args:
        root: Root of the tree (``None`` for an empty tree).
        targetSum: Required sum of a counted path.

    Returns:
        The number of parent-to-child paths whose node values sum to ``targetSum``.
    """
    
    if not root:
        return 0
    
    count = 0
    
    # To track all the prefix sums:
    # prefix sum -> number of times it occurs on the current root-to-node path
    prefix_sums = defaultdict(int)
    
    # Values of the nodes on the current root-to-node path, in order
    prefix_nodes = []
    # Running sum of prefix_nodes
    prefix_sum = 0
    
    # Stack of (node, depth) pairs; the root is at depth 1
    next_nodes = [(root, 1)]
    
    while next_nodes:
        node, depth = next_nodes.pop()
        
        # Correct the prefix_sum by subtracting the values of nodes
        # that are not a part of this node's path to the root.
        # Afterwards prefix_nodes holds at most depth - 1 entries (the ancestors).
        # Order matters: the current sum's count is removed before the value is popped.
        while len(prefix_nodes) >= depth:
            prefix_sums[prefix_sum] -= 1
            prefix_sum -= prefix_nodes.pop()
        
        prefix_sum += node.val
        
        # A path that starts at the root itself
        if prefix_sum == targetSum:
            count += 1
        
        # Paths starting below the root: each earlier prefix equal to
        # prefix_sum - targetSum means the nodes after it sum to targetSum
        count += prefix_sums[prefix_sum - targetSum]
        
        prefix_nodes.append(node.val)
        prefix_sums[prefix_sum] += 1
        
        # Push right first so the left child is popped (visited) first
        if node.right:
            next_nodes.append((node.right, depth + 1))

        if node.left:
            next_nodes.append((node.left, depth + 1))
        
    return count

In [33]:
# [10, 5, -3, 3, 2, null, 11, 3, -2, null, 1]
root_1 = TreeNode(10,
                  TreeNode(5,
                           TreeNode(3, TreeNode(3), TreeNode(-2)),
                           TreeNode(2, None, TreeNode(1))),
                  TreeNode(-3, None, TreeNode(11)))

pathSum(root_1, 8)

3

In [34]:
# [5, 4, 8, 11, null, 13, 4, 7, 2, null, null, 5, 1]
root_2 = TreeNode(5,
                  TreeNode(4, TreeNode(11, TreeNode(7), TreeNode(2))),
                  TreeNode(8, TreeNode(13), TreeNode(4, TreeNode(5), TreeNode(1))))

pathSum(root_2, 22)

3

In [35]:
# [10, 5, -3, 3, 2, null, 11, 3, -2, null, 1]
root_1 = TreeNode(10,
                  TreeNode(5,
                           TreeNode(3, TreeNode(3), TreeNode(-2)),
                           TreeNode(2, None, TreeNode(1))),
                  TreeNode(-3, None, TreeNode(11)))

pathSum_iter(root_1, 8)

3

In [36]:
# [5, 4, 8, 11, null, 13, 4, 7, 2, null, null, 5, 1]
root_2 = TreeNode(5,
                  TreeNode(4, TreeNode(11, TreeNode(7), TreeNode(2))),
                  TreeNode(8, TreeNode(13), TreeNode(4, TreeNode(5), TreeNode(1))))

pathSum_iter(root_2, 22)

3

## 5) Longest ZigZag Path in a Binary Tree

You are given the root of a binary tree.

A ZigZag path for a binary tree is defined as follows:

* Choose any node in the binary tree and a direction (right or left).
* If the current direction is right, move to the right child of the current node; otherwise, move to the left child.
* Change the direction from right to left or from left to right.
* Repeat the second and third steps until you can't move in the tree.

Zigzag length is defined as the number of nodes visited - 1. (A single node has a length of 0).

Return the longest ZigZag path contained in that tree.

<b>Example</b>

Input: root = [1, null, 1, 1, 1, null, null, 1, 1, null, 1, null, null, null, 1] <br />
Output: 3

Explanation: The longest ZigZag path starts at the root's right child and moves right -> left -> right.

<b>Example</b>

Input: root = [1, 1, 1, null, 1, null, null, 1, 1, null, 1] <br />
Output: 4

Explanation: The longest ZigZag path starts at the root and moves left -> right -> left -> right.

<b>Example</b>

Input: root = [1] <br />
Output: 0

In [37]:
class TreeNode:
    """One node of a binary tree (value, left child, right child)."""

    def __init__(self, val = 0, left = None, right = None):
        self.val, self.left, self.right = val, left, right

In [38]:
def longestZigZag(root: "Optional[TreeNode]") -> int:
    """Return the length (in edges) of the longest ZigZag path, via iterative DFS.

    Each stack entry is ``(node, via_left, via_right)``:
    - ``via_left`` is the length of the longest ZigZag path that ends at
      ``node`` with a move to the left;
    - ``via_right`` is the same for a final move to the right.

    Every node is pushed exactly once.

    Returns:
        0 for an empty tree or a single node, otherwise the longest length.
    """
    if root is None:
        return 0

    max_zig_zag = 0
    stack = [(root, 0, 0)]
    while stack:
        node, via_left, via_right = stack.pop()
        max_zig_zag = max(max_zig_zag, via_left, via_right)
        if node.left:
            # A left step extends a path whose last move was right (or starts
            # a new one of length 1); nothing ends at the child with a right move.
            stack.append((node.left, via_right + 1, 0))
        if node.right:
            stack.append((node.right, 0, via_left + 1))

    return max_zig_zag

In [39]:
# Recursive Solution

def longestZigZag(root: "Optional[TreeNode]") -> int:
    """Return the length (in edges) of the longest ZigZag path, via recursive DFS.

    ``dfs(node, goLeft, zig_zag_length)`` has two parts:
    - ``zig_zag_length`` is the length of the ZigZag path that arrived at ``node``;
    - ``goLeft`` tells whether the move that would continue that path is to the left.

    Moving in the other direction starts a new path of length 1 at ``node``.
    """
    max_zig_zag = 0

    def dfs(node, goLeft, zig_zag_length):
        nonlocal max_zig_zag
        if not node:
            return
        max_zig_zag = max(max_zig_zag, zig_zag_length)
        if goLeft:
            dfs(node.left, False, zig_zag_length + 1)  # continue the zigzag
            dfs(node.right, True, 1)                   # restart from node
        else:
            dfs(node.left, False, 1)                   # restart from node
            dfs(node.right, True, zig_zag_length + 1)  # continue the zigzag

    # At the root the length is 0, so both values of goLeft yield identical
    # child calls; a single traversal covers every starting direction.
    dfs(root, True, 0)

    return max_zig_zag

In [40]:
def longestZigZag(root: "Optional[TreeNode]") -> int:
    """Return the length (in edges) of the longest ZigZag path in the tree.

    ``dfs(node, left, right)`` has two length arguments:
    - ``left`` is the length of the longest ZigZag path ending at ``node``
      whose last move was to the left;
    - ``right`` is the same for a final move to the right.

    Returns:
        0 for an empty tree (``None``) or a single node.
    """
    if root is None:
        return 0

    maxi = 0

    def dfs(node, left, right):
        nonlocal maxi
        maxi = max(maxi, left, right)

        if node.left:
            # Stepping left extends a path whose last move was right.
            dfs(node.left, right + 1, 0)

        if node.right:
            # Stepping right extends a path whose last move was left.
            dfs(node.right, 0, left + 1)

    dfs(root, 0, 0)
    return maxi

In [41]:
# [1, null, 1, 1, 1, null, null, 1, 1, null, 1, null, null, null, 1]
root = TreeNode(1, None,
                TreeNode(1,
                         TreeNode(1),
                         TreeNode(1,
                                  TreeNode(1, None, TreeNode(1, None, TreeNode(1))),
                                  TreeNode(1))))

longestZigZag(root)

3

In [42]:
# [1, 1, 1, null, 1, null, null, 1, 1, null, 1]
root = TreeNode(1,
                TreeNode(1, None, TreeNode(1, TreeNode(1, None, TreeNode(1)), TreeNode(1))),
                TreeNode(1))

longestZigZag(root)

4

In [43]:
root = TreeNode(val=1)

longestZigZag(root)

0

## 6) Lowest Common Ancestor of a Binary Tree

Given a binary tree, find the lowest common ancestor (LCA) of two given nodes in the tree.

According to the definition of LCA on Wikipedia: “The lowest common ancestor is defined between two nodes p and q as the lowest node in T that has both p and q as descendants (where we allow a node to be a descendant of itself).”

<b>Example</b>

Input: root = [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4], p = 5, q = 1 <br />
Output: 3

Explanation: The LCA of nodes 5 and 1 is 3.

<b>Example</b>

Input: root = [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4], p = 5, q = 4 <br />
Output: 5

Explanation: The LCA of nodes 5 and 4 is 5, since a node can be a descendant of itself according to the LCA definition.

<b>Example</b>

Input: root = [1, 2], p = 1, q = 2 <br />
Output: 1

In [44]:
class TreeNode:
    """Binary-tree node; instances are hashed by identity (used as dict keys)."""

    def __init__(self, val = 0, left = None, right = None):
        self.val, self.left, self.right = val, left, right

In [45]:
def lowestCommonAncestor(root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    """Return the lowest node that has both ``p`` and ``q`` as descendants.

    A node counts as its own descendant. The function builds a child -> parent
    map with an iterative DFS that stops as soon as both targets have been
    reached. It then collects ``p``'s ancestors and walks up from ``q`` until
    it meets one of them.
    """
    parent_of = {root: None}
    pending = [root]

    # Explore only until both targets have a parent entry.
    while p not in parent_of or q not in parent_of:
        current = pending.pop()
        for child in (current.left, current.right):
            if child:
                parent_of[child] = current
                pending.append(child)

    # Every node on the path from p up to the root, including p itself.
    lineage = set()
    walker = p
    while walker:
        lineage.add(walker)
        walker = parent_of[walker]

    # The first ancestor of q (possibly q itself) inside p's lineage is the LCA.
    meeting = q
    while meeting not in lineage:
        meeting = parent_of[meeting]

    return meeting

In [46]:
# [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4]
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
p, q = root.left, root.right

lowestCommonAncestor(root, p, q).val

3

In [47]:
# [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4]
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
p, q = root.left, root.left.right.right

lowestCommonAncestor(root, p, q).val

5

In [48]:
# [1, 2]
root = TreeNode(1, TreeNode(2))
p, q = root, root.left

lowestCommonAncestor(root, p, q).val

1