Validate BST: Implement a function to check if a binary tree is a binary search tree.

In [1]:
from binarytree import Node

# Example 1: a tree whose values respect the BST ordering (0 < 2 < 3 < 4 < 5 < 8).
root_bst = Node(3)
root_bst.left, root_bst.right = Node(2), Node(5)
root_bst.left.left = Node(0)
root_bst.right.left, root_bst.right.right = Node(4), Node(8)
print(root_bst)

# Example 2: a plain binary tree that is NOT a BST (left child 2 > root 1).
root_bt = Node(1)
root_bt.left, root_bt.right = Node(2), Node(3)
print(root_bt)


    3__
   /   \
  2     5
 /     / \
0     4   8


  1
 / \
2   3



In [11]:
from collections import deque


def find_min(root: Node) -> float:
    """Return the minimum value in the subtree rooted at ``root`` (root included).

    The subtree is walked breadth-first; every node is visited exactly once.

    Args:
        root: the subtree root, or ``None`` for an empty subtree.

    Returns:
        The smallest ``value`` found, or ``float("inf")`` when ``root`` is
        ``None`` (the identity element for ``min``).
    """
    q = deque([] if root is None else [root])
    min_ = float("inf")
    while q:
        node = q.popleft()
        min_ = min(node.value, min_)
        # Compare against None explicitly: binarytree.Node implements __len__
        # (subtree size), so plain truthiness would traverse the child subtree.
        if node.left is not None:
            q.append(node.left)
        if node.right is not None:
            q.append(node.right)
    return min_

def find_max(root: Node) -> float:
    """Return the maximum value in the subtree rooted at ``root`` (root included).

    The subtree is walked breadth-first; every node is visited exactly once.

    Args:
        root: the subtree root, or ``None`` for an empty subtree.

    Returns:
        The largest ``value`` found, or ``float("-inf")`` when ``root`` is
        ``None`` (the identity element for ``max``).
    """
    q = deque([] if root is None else [root])
    max_ = float("-inf")
    while q:
        node = q.popleft()
        max_ = max(node.value, max_)
        # Compare against None explicitly: binarytree.Node implements __len__
        # (subtree size), so plain truthiness would traverse the child subtree.
        if node.left is not None:
            q.append(node.left)
        if node.right is not None:
            q.append(node.right)
    return max_

def validate_BST(root: Node) -> bool:
    """Check whether the tree under ``root`` is a binary search tree (BST).

    Brute-force approach: for every node, scan its entire left subtree for the
    maximum and its entire right subtree for the minimum. Each node is scanned
    once per ancestor, so this is O(N^2) in the worst case (degenerate tree).

    The ordering enforced is ``left descendants <= node < right descendants``.

    Args:
        root: the tree root, or ``None`` (an empty tree is a valid BST).

    Returns:
        True if the tree under ``root`` is a valid BST, False otherwise.
    """
    if root is None:
        return True

    nodes_to_validate = deque([root])
    while nodes_to_validate:
        node = nodes_to_validate.popleft()

        # Every value in the left subtree must be <= node.value ...
        valid_l = node.left is None or find_max(node.left) <= node.value
        # ... and every value in the right subtree must be > node.value.
        valid_r = node.right is None or find_min(node.right) > node.value

        if not (valid_l and valid_r):
            # A single failing node makes the whole tree invalid, so exit early.
            return False

        # Explicit None checks: binarytree.Node implements __len__ (subtree
        # size), so truthiness would traverse the child subtree each time.
        if node.left is not None:
            nodes_to_validate.append(node.left)
        if node.right is not None:
            nodes_to_validate.append(node.right)

    return True

# Expected False: root_bt's left child (2) is greater than the root (1).
validate_BST(root_bt)

False

# Gayle's solution: O(N), much better than the O(N²) worst case of the brute force above

- Go down the tree starting from the root
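- Start with the range `(-inf, +inf]` (exclusive lower bound, inclusive upper bound). The root always falls within it.
- Go left. When branching left, replace the upper bound with the current node's value; e.g. for a root of 20, the left child's range becomes `(-inf, 20]`.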
  - Is the left node within the valid range? OK.
- Go right. When branching right, replace the lower bound with the current node's value; for the same root of 20, the right child's range becomes `(20, +inf]`.
  - Is the right node within the valid range? OK
- Keep branching in every non-null direction until the leaves.
- If *any* node fails to fall within the valid range, the whole tree fails and you can exit early.
- To conclude that the tree is a valid BST, you must visit every node at least once, because any single node can cause a failure.



In [31]:
class InvalidBstException(Exception):
    """Signals that a tree violates the BST ordering.

    Kept for backward compatibility; the validators below report failure by
    returning False instead of raising it.
    """


def _validate_BST(node: Node, range_: tuple = (float("-inf"), float("+inf"))) -> bool:
    """Return True if the subtree under ``node`` is a BST with values inside ``range_``.

    ``range_`` is ``(low, high)`` with an exclusive lower bound and an inclusive
    upper bound: every value ``v`` must satisfy ``low < v <= high``. Branching
    left caps the upper bound at the parent's value (left descendants <= parent);
    branching right raises the lower bound to it (right descendants > parent).

    An explicit stack replaces recursion so deep, degenerate trees cannot hit
    Python's recursion limit. Each node is visited at most once (O(N)), and the
    walk stops at the first out-of-range node. ``None`` is treated as a valid,
    empty tree.
    """
    stack = [] if node is None else [(node, range_[0], range_[1])]
    while stack:
        current, low, high = stack.pop()
        if not (low < current.value <= high):
            return False
        # Explicit None checks: binarytree.Node implements __len__ (subtree
        # size), so truthiness would traverse the child subtree.
        if current.left is not None:
            stack.append((current.left, low, current.value))
        if current.right is not None:
            stack.append((current.right, current.value, high))
    return True


def validate_BST(node: Node) -> bool:
    """Return True if the binary tree under ``node`` is a valid BST.

    Uses the O(N) range-narrowing approach: left descendants must be <= each
    ancestor and right descendants must be > it. An empty tree (``None``) is valid.
    """
    return _validate_BST(node)

# root_bst respects the BST ordering, so the bare expression below should display True.
print(root_bst)
validate_BST(root_bst)


    3__
   /   \
  2     5
 /     / \
0     4   8



True