In [59]:
import collections

In [60]:
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

In [61]:
def create_example_tree():
    nodes = [TreeNode(i) for i in range(1, 11)]

    nodes[0].left = nodes[1]
    nodes[0].right = nodes[2]
    nodes[1].left = nodes[3]
    nodes[1].right = nodes[4]
    nodes[2].left = nodes[5]
    nodes[2].right = nodes[6]
    nodes[3].left = nodes[7]
    nodes[3].right = nodes[8]
    nodes[4].left = nodes[9]

    return nodes[0]


tree = create_example_tree()

In [62]:
# 104


def max_depth1(root: TreeNode | None) -> int:
    if not root:
        return 0
    return 1 + max(max_depth(root.left), max_depth(root.right))

In [63]:
# 104 iterative


def max_depth(root: TreeNode | None) -> int:
    stk = [(root, 1)]
    res = 0

    while stk:
        node, depth = stk.pop()
        if node:
            res = max(res, depth)
            stk.append((node.left, depth + 1))
            stk.append((node.right, depth + 1))
    return res

In [64]:
max_depth1(tree)

4

In [65]:
# 872


def leaf_similar(root1: TreeNode | None, root2: TreeNode | None) -> bool:
    def collect_leaves(node: TreeNode, leaves: list) -> list:
        if not node:
            return []
        if not (node.left or node.right):
            leaves.append(node.val)

        if node.left:
            collect_leaves(node.left, leaves)
        if node.right:
            collect_leaves(node.right, leaves)
        return leaves

    return collect_leaves(root1, []) == collect_leaves(root2, [])

In [66]:
# 872 iterative


def leaf_similar(root1: TreeNode | None, root2: TreeNode | None) -> bool:
    def collect_leaves(root: TreeNode) -> list:
        stk, leaves = [root], []

        while stk:
            node = stk.pop()
            if not (node.left or node.right):
                leaves.append(node.val)
            if node.left:
                stk.append(node.left)
            if node.right:
                stk.append(node.right)
        return leaves

    return collect_leaves(root1) == collect_leaves(root2)

In [67]:
# 1448


def good_nodes(root: TreeNode) -> int:
    count: int = 0

    def dfs(node: TreeNode, max_val: int):
        nonlocal count
        if node.val >= max_val:
            count += 1
            max_val = max(node.val, max_val)
        if node.left:
            dfs(node.left, max_val)
        if node.right:
            dfs(node.right, max_val)
        return count

    return dfs(root, root.val)

In [68]:
good_nodes(tree)

10

In [69]:
# 437 - pretty not efficient, exponential big o


def path_sum(root: TreeNode | None, target_sum: int) -> int:
    count = 0

    def helper(node: TreeNode, curr_sum: int):
        nonlocal count
        if not node:
            return
        helper(node.left, curr_sum + node.val)
        helper(node.right, curr_sum + node.val)
        if curr_sum + node.val == target_sum:
            count += 1

    def dfs(node: TreeNode):
        if not node:
            return
        helper(node, 0)
        dfs(node.left)
        dfs(node.right)

    dfs(root)
    return count

In [70]:
# O(n), check prefix in the list
def path_sum(root: TreeNode | None, target_sum: int) -> int:
    prefixes = collections.defaultdict(int)
    prefixes[0] = 1
    total = 0

    def dfs(node: TreeNode, curr_sum: int):
        nonlocal total

        if not node:
            return

        curr_sum += node.val
        total += prefixes[curr_sum - target_sum]  # curr_sum - prev_sum = target_sum
        prefixes[curr_sum] += 1

        dfs(node.left, curr_sum)
        dfs(node.right, curr_sum)

        prefixes[curr_sum] -= 1

    dfs(root, 0)
    return total

In [71]:
def longest_zig_zag(root: TreeNode | None) -> int:
    # start, left, right = -1, 0, 1
    longest = 0

    def dfs(node: TreeNode, child: int, curr_sum: int = 0):
        nonlocal longest
        if not node:
            return

        longest = max(longest, curr_sum)
        if node.left:
            dfs(node.left, 0, curr_sum + 1 if child == 1 else 1)
        if node.right:
            dfs(node.right, 1, curr_sum + 1 if child == 1 else 1)

    dfs(root, -1)
    return longest

In [72]:
longest_zig_zag(tree)

2

In [91]:
# 236


def lowest_common_ancestor(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    ancestor: TreeNode = TreeNode(-1)

    def dfs(node: TreeNode) -> bool:
        if not node:
            return False

        left = dfs(node.left)
        right = dfs(node.right)
        curr_node = node == p or node == q

        if (left and right) or (curr_node and left) or (curr_node and right):
            nonlocal ancestor
            ancestor = node

        return left or right or curr_node

    dfs(root)
    return ancestor

In [92]:
x, y = TreeNode(2), TreeNode(9)
lowest_common_ancestor(tree, x, y)