# Binary Search Trees (BSTs)

**What's a Binary Search Tree (BST)?** 🕵️‍♂️
A BST is a [binary tree](./trees.ipynb#binary-trees) with **ordering rules** that hold at every node:
- **Everything in the left subtree < Parent**
- **Everything in the right subtree > Parent**

This makes it super easy to search for stuff, since every comparison rules out one whole side of the tree! Imagine organizing books on a shelf by their titles:

1. Go left if the title comes earlier alphabetically.
2. Go right if it comes later.
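The shelf analogy can be sketched with titles as keys. Python strings compare alphabetically with `<`, so the same left/right rule applies. The `Node` class, the `find_path` helper, and the sample titles are illustrative assumptions, not part of this notebook's code.

```python
class Node:
    """A minimal BST node keyed by a book title (illustrative only)."""
    def __init__(self, title, left=None, right=None):
        self.title = title
        self.left = left
        self.right = right

def find_path(root, title):
    """Hypothetical helper: return the titles visited while looking for `title`."""
    path = []
    node = root
    while node is not None:
        path.append(node.title)
        if title == node.title:
            break
        # Earlier alphabetically -> go left; later -> go right
        node = node.left if title < node.title else node.right
    return path

# A small shelf with "Moby Dick" at the root
shelf = Node("Moby Dick",
             Node("Dune", Node("Beloved"), Node("Emma")),
             Node("Ulysses", Node("Rebecca"), Node("Walden")))

print(find_path(shelf, "Emma"))  # ['Moby Dick', 'Dune', 'Emma']
```

Each comparison skips half of the remaining shelf, which is the whole point of the ordering rules.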

## BST Rules

A **Binary Search Tree (BST)** is a binary tree where each node follows these rules:

1. **Left Subtree:** All values are smaller than the node's value.
2. **Right Subtree:** All values are greater than the node's value.
3. **No Duplicates:** No two nodes have the same value.

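A BST supports **search, insertion, and deletion** in O(h) time, where h is the height of the tree: O(log n) when the tree is balanced, but O(n) when it degenerates into a chain. That efficiency makes it a favorite topic in interviews.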
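How balanced the tree ends up depends on the insertion order. A minimal sketch, assuming a plain recursive insert and counting height as the number of nodes on the longest root-to-leaf path; the `Node`, `insert`, `height`, and `middle_first` names are illustrative. Inserting 15 sorted values produces a chain, while inserting each range's middle value first produces a perfectly balanced tree.

```python
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def insert(root, val):
    """Recursive BST insert; duplicates are ignored."""
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    elif val > root.val:
        root.right = insert(root.right, val)
    return root

def height(root):
    """Nodes on the longest root-to-leaf path (empty tree = 0)."""
    if root is None:
        return 0
    return 1 + max(height(root.left), height(root.right))

def middle_first(values):
    """Reorder sorted values so each subrange's middle comes first."""
    if not values:
        return []
    mid = len(values) // 2
    return [values[mid]] + middle_first(values[:mid]) + middle_first(values[mid + 1:])

values = list(range(1, 16))  # 15 sorted values

skewed = None
for v in values:  # sorted order: every new node goes to the right
    skewed = insert(skewed, v)

balanced = None
for v in middle_first(values):
    balanced = insert(balanced, v)

print(height(skewed))    # 15
print(height(balanced))  # 4
```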

## BST Basic Operations

1. **Insert a Node**
2. **Search for a Value**
3. **Delete a Node**
4. **Find the Min and Max**
5. **Validate a BST**


# TreeNode
```
        A(5)
       /   \
    B(3)   C(8)
   /   \   /   \
D(1) E(4) F(7) G(9)
```

In [3]:
# Binary Search Trees (BSTs)
# Definition for a binary tree node.
# class TreeNode:
#       def __init__(self, val=0, left=None, right=None):
#           self.val = val
#           self.left = left
#           self.right = right

class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

    def __str__(self):
        return str(self.val)

# Creating a BST
A = TreeNode(5)
B = TreeNode(3)
C = TreeNode(8)
D = TreeNode(1)
E = TreeNode(4)
F = TreeNode(7)
G = TreeNode(9)

A.left, A.right = B, C
B.left, B.right = D, E
C.left, C.right = F, G

In [4]:
# Prints the tree's values using an in-order traversal (left, node, right)
# Expected: 1, 3, 4, 5, 7, 8, 9
def print_bst(node: TreeNode):
    if not node:
        return
    print_bst(node.left)
    print(node.val)
    print_bst(node.right)

print_bst(A)

1
3
4
5
7
8
9
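Because an in-order traversal of a BST visits values in ascending order, it is often handier to collect them into a list than to print them. A minimal iterative sketch with an explicit stack, assuming the same `val`/`left`/`right` node shape; `Node` and `inorder_list` are illustrative names.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder_list(root):
    """Return the values of a binary tree in in-order (left, node, right)."""
    result, stack = [], []
    node = root
    while stack or node:
        # Walk as far left as possible, remembering the path
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        result.append(node.val)  # visit the node
        node = node.right        # then traverse its right subtree
    return result

# Same shape as the A..G tree above
root = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7), Node(9)))
print(inorder_list(root))  # [1, 3, 4, 5, 7, 8, 9]
```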


# Problems
- [Insert a Node](#insert-a-node)
- [Search for a Value](#search-for-a-value)
- [Delete a Node](#delete-a-node)
- [Find the Min and Max](#find-the-min-and-max)
- [Validate a BST](#validate-a-bst)
- [Kth Smallest Element in BST](#kth-smallest-element-in-bst)

# Insert a Node

**Algorithm:**

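- Starting at the root, compare the new value with the current node.
- Go left if it is smaller, right if it is larger, and repeat recursively.
- When you reach an empty spot (`None`), create the new node there. Values already in the tree are ignored.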

Our goal is to take a list like the following:

`nums = [8, 3, 10, 1, 6, 4, 7, 14, 13]`

and insert each value, in order, so that the tree ends up looking like this:

```
        8
      /   \
     3    10
    / \      \
   1   6     14
      / \    /
     4   7  13
```

**Notice:**

Our root is 8 because it is the first value inserted.
All values on the left are smaller, and all values on the right are larger.
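One way to confirm the shape drawn above is to print the tree level by level after inserting `nums`. A minimal sketch, assuming a recursive insert that ignores duplicates and a breadth-first traversal with `collections.deque`; `Node`, `insert`, and `levels` are illustrative names.

```python
from collections import deque

class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def insert(root, val):
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    elif val > root.val:
        root.right = insert(root.right, val)
    return root

def levels(root):
    """Breadth-first traversal, grouping values by depth."""
    if root is None:
        return []
    out, queue = [], deque([root])
    while queue:
        level = []
        for _ in range(len(queue)):  # process exactly one depth per pass
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        out.append(level)
    return out

tree = None
for n in [8, 3, 10, 1, 6, 4, 7, 14, 13]:
    tree = insert(tree, n)

print(levels(tree))  # [[8], [3, 10], [1, 6, 14], [4, 7, 13]]
```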

In [5]:
from typing import List, Optional

# Insert a Node
# The tree structure should resemble the following:
#         8
#       /   \
#      3    10
#     / \      \
#    1   6     14
#       / \    /
#      4   7  13

class Solution:
    def insert(self, node: Optional[TreeNode], value: int) -> TreeNode:
        # Empty spot found: the new value becomes a node here
        if node is None:
            return TreeNode(value)

        if value < node.val:
            node.left = self.insert(node.left, value)
        elif value > node.val:
            node.right = self.insert(node.right, value)
        # value == node.val: duplicate, leave the tree unchanged

        return node

nums = [8, 3, 10, 1, 6, 4, 7, 14, 13]
sol = Solution()
tree: Optional[TreeNode] = None  # start with an empty tree

for n in nums:
    tree = sol.insert(tree, n)  # update the tree reference after each insertion

# Test
# Should output [1,3,4,6,7,8,10,13,14]
print_bst(tree)


1
3
4
6
7
8
10
13
14


In [6]:
# Iterative solution
from typing import Optional

class Solution:
    def insert_iterative(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
        if root is None:
            # If the tree is empty, the new node becomes the root
            return TreeNode(val)

        cur = root
        while True:
            if val > cur.val:
                # Move to the right child
                if not cur.right:
                    # Insert new node if right child is None
                    cur.right = TreeNode(val)
                    return root
                cur = cur.right
            elif val < cur.val:
                # Move to the left child
                if not cur.left:
                    # Insert new node if left child is None
                    cur.left = TreeNode(val)
                    return root
                cur = cur.left
            else:
                # Duplicate value: BST rules forbid it, so leave the tree unchanged
                return root
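A quick way to sanity-check an iterative insert is to feed it the same `nums` plus a repeated value and confirm the in-order traversal comes back sorted with no repeats. This sketch carries its own copy of the iterative logic (skipping duplicates, per the No Duplicates rule); `Node`, `insert_iterative`, and `inorder` are illustrative names.

```python
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def insert_iterative(root, val):
    """Walk down from the root and attach `val` at the first empty slot."""
    if root is None:
        return Node(val)
    cur = root
    while True:
        if val < cur.val:
            if cur.left is None:
                cur.left = Node(val)
                return root
            cur = cur.left
        elif val > cur.val:
            if cur.right is None:
                cur.right = Node(val)
                return root
            cur = cur.right
        else:
            return root  # duplicate: leave the tree unchanged

def inorder(node):
    return inorder(node.left) + [node.val] + inorder(node.right) if node else []

tree = None
for n in [8, 3, 10, 1, 6, 4, 7, 14, 13, 6]:  # note the repeated 6
    tree = insert_iterative(tree, n)

print(inorder(tree))  # [1, 3, 4, 6, 7, 8, 10, 13, 14]
```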


# Search for a Value

In [7]:
# Time: O(h), which is O(log n) in a balanced tree and O(n) in the worst case
# Space: O(h) for the recursion stack

def search_bst(node, target):
    if node is None:
        return False

    if node.val == target:
        return True
    
    if target < node.val:
        return search_bst(node.left, target)
    else: 
        return search_bst(node.right, target)
# Test
# Expected Result: True 
search_bst(A, 8)

True
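The recursive search keeps one stack frame per level. An iterative sketch does the same walk with O(1) extra space and returns the matching node instead of a bool; `Node` and `search_iterative` are illustrative names, assuming the same `val`/`left`/`right` node shape.

```python
from typing import Optional

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def search_iterative(node: Optional[Node], target: int) -> Optional[Node]:
    """Return the node holding `target`, or None if it is absent."""
    while node is not None and node.val != target:
        # Skip the side of the tree that cannot contain `target`
        node = node.left if target < node.val else node.right
    return node

root = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7), Node(9)))
print(search_iterative(root, 7).val)  # 7
print(search_iterative(root, 6))      # None
```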

# Delete a Node

In [8]:
# Delete a Node
# Time: O(h), i.e. O(log n) when balanced and O(n) when unbalanced
# Space: O(h) for the recursion stack

class Solution:
    def delete_node(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:
        if root is None:
            return root
        
        if key > root.val:
            root.right = self.delete_node(root.right, key)
        elif key < root.val:
            root.left = self.delete_node(root.left, key)
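        else:
            # Found the node to delete.
            # Zero or one child: replace the node with its only child (or None)
            if root.left is None:
                return root.right
            elif root.right is None:
                return root.left

            # Two children: copy the in-order successor (the min of the
            # right subtree) into this node, then delete that successor
            cur = root.right
            while cur.left:
                cur = cur.left
            root.val = cur.val
            root.right = self.delete_node(root.right, root.val)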

        return root

# Test
# Example: Delete node with value 3
sol = Solution()
tree = sol.delete_node(tree, 3)
print_bst(tree)  # Should print the BST without the node containing value 3

1
4
6
7
8
10
13
14
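Deletion has three cases: a leaf is simply removed, a node with one child is replaced by that child, and a node with two children takes the value of its in-order successor (the minimum of its right subtree), after which the successor is deleted from the right subtree. A self-contained sketch that exercises each case on the 50/30/70 tree; `Node`, `delete`, `inorder`, and `build` are illustrative names, and the delete logic mirrors the cell above.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def delete(root, key):
    if root is None:
        return None
    if key < root.val:
        root.left = delete(root.left, key)
    elif key > root.val:
        root.right = delete(root.right, key)
    else:
        # Zero or one child: splice the node out
        if root.left is None:
            return root.right
        if root.right is None:
            return root.left
        # Two children: copy the in-order successor, then delete it
        succ = root.right
        while succ.left:
            succ = succ.left
        root.val = succ.val
        root.right = delete(root.right, succ.val)
    return root

def inorder(node):
    return inorder(node.left) + [node.val] + inorder(node.right) if node else []

def build():
    return Node(50, Node(30, Node(20), Node(40)), Node(70, Node(60), Node(80)))

leaf = delete(build(), 20)                   # leaf
print(inorder(leaf))                         # [30, 40, 50, 60, 70, 80]

one_child = delete(delete(build(), 20), 30)  # 30 now has only a right child
print(inorder(one_child))                    # [40, 50, 60, 70, 80]

two_children = delete(build(), 50)           # root with two children
print(two_children.val)                      # 60
print(inorder(two_children))                 # [20, 30, 40, 60, 70, 80]
```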



# Find the Min and Max

In a **Binary Search Tree (BST)**:
- **The minimum value is always in the leftmost node**.
- **The maximum value is always in the rightmost node**.

This means to find the min/max, **we just keep going left/right until we can't anymore.**

## How It Works

### Finding the Minimum:
- Start at the root.
- Keep moving **left** until you hit a node with no left child.
- That node is the minimum.

### Finding the Maximum:
- Start at the root.
- Keep moving **right** until you hit a node with no right child.
- That node is the maximum.

This works because:
- Every step left moves to a smaller value, so the leftmost node is the **smallest**.
- Every step right moves to a larger value, so the rightmost node is the **largest**.

## Example BST:

```
         50
        /  \
      30    70
     / \    / \
   20  40  60  80
```

- **Minimum Value?** `20` (leftmost node)
- **Maximum Value?** `80` (rightmost node)

## Time and Space Complexity

### Time Complexity: **O(h)**
- `h` is the height of the tree.
- In a balanced BST, `h = log(n)`, so in the **best case**, it's **O(log n)**.
- In a **worst-case (unbalanced tree)**, `h = n`, so it's **O(n)**.

### Space Complexity: **O(1)**
- The iterative version keeps only one pointer, so space is **constant**. A recursive version uses O(h) stack space instead.


## Edge Cases:
1. **Empty Tree**: If the tree is empty (`root == None`), return `None`.
2. **Single Node**: If there's only one node, it's both the min and max.
3. **Skewed Trees**: 
   - If all nodes are on the left (`left-skewed`), the max is the root.
   - If all nodes are on the right (`right-skewed`), the min is the root.
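The edge cases above can be checked directly. A minimal sketch using iterative `find_min`/`find_max` on an empty tree, a single node, and both kinds of skewed tree; the `Node` class and the sample trees are illustrative.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def find_min(node):
    if node is None:
        return None  # empty tree
    while node.left:
        node = node.left
    return node.val

def find_max(node):
    if node is None:
        return None  # empty tree
    while node.right:
        node = node.right
    return node.val

single = Node(42)
left_skewed = Node(30, Node(20, Node(10)))              # only left children
right_skewed = Node(10, None, Node(20, None, Node(30))) # only right children

print(find_min(None), find_max(None))                  # None None
print(find_min(single), find_max(single))              # 42 42
print(find_max(left_skewed), find_min(left_skewed))    # 30 10
print(find_min(right_skewed), find_max(right_skewed))  # 10 30
```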

## Key Takeaways:
- **Min = leftmost node**; **Max = rightmost node**.
- **Time Complexity: O(h), Space Complexity: O(1)**.
- **Iterative vs Recursive Approach**: Iterative is simpler and avoids extra stack memory.

In [9]:
# Find the Min and Max

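# Use the same node shape as the earlier cells (val, left, right).
# `value` is a read-only alias for cells below that read `node.value`.
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

    @property
    def value(self):
        return self.val

    def __str__(self):
        return str(self.val)

def find_min(node):
    """Finds the minimum value in a BST."""
    if not node:
        return None  # Edge case: empty tree
    while node.left:
        node = node.left  # Keep moving left
    return node.val

def find_max(node):
    """Finds the maximum value in a BST."""
    if not node:
        return None  # Edge case: empty tree
    while node.right:
        node = node.right  # Keep moving right
    return node.val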

# Build example tree
root = TreeNode(50)
root.left = TreeNode(30)
root.right = TreeNode(70)
root.left.left = TreeNode(20)
root.left.right = TreeNode(40)
root.right.left = TreeNode(60)
root.right.right = TreeNode(80)

# Run min/max functions
print("Minimum value in BST:", find_min(root))  # Output: 20
print("Maximum value in BST:", find_max(root))  # Output: 80

Minimum value in BST: 20
Maximum value in BST: 80


In [10]:
# Recursive solution: same O(h) time, but O(h) call-stack space

def find_min_recursive(node):
    if not node:
        return None
    if not node.left:
        return node.value
    return find_min_recursive(node.left)

def find_max_recursive(node):
    if not node:
        return None
    if not node.right:
        return node.value
    return find_max_recursive(node.right)

# Validate a BST

Given the root of a binary tree, determine if it is a valid binary search tree (BST).

A valid BST is defined as follows:
- The left subtree of a node contains only nodes with keys less than the node's key.
- The right subtree of a node contains only nodes with keys greater than the node's key.
- Both the left and right subtrees must also be binary search trees.
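The last rule is the easy one to miss: comparing each node only with its direct children is not enough. A sketch contrasting that naive check with one that passes down bounds from the ancestors; `Node`, `naive_is_bst`, and `bounded_is_bst` are illustrative names.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def naive_is_bst(node):
    """INCORRECT: only compares each node with its direct children."""
    if node is None:
        return True
    if node.left and node.left.val >= node.val:
        return False
    if node.right and node.right.val <= node.val:
        return False
    return naive_is_bst(node.left) and naive_is_bst(node.right)

def bounded_is_bst(node, low=float("-inf"), high=float("inf")):
    """Each node must lie strictly inside the (low, high) range set by its ancestors."""
    if node is None:
        return True
    if not (low < node.val < high):
        return False
    return bounded_is_bst(node.left, low, node.val) and bounded_is_bst(node.right, node.val, high)

# 4 is a valid child of 7, but it sits in 5's right subtree, so this is not a BST
tricky = Node(5, Node(3), Node(7, Node(4), Node(8)))
print(naive_is_bst(tricky))    # True  (wrong)
print(bounded_is_bst(tricky))  # False (correct)
```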

In [11]:
# 5. Validate a BST

# Example
#       2
#      / \
#     1   3
#
# Input: root = [2, 1, 3]
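# Output: True

#           5
#          / \
#         3   7
#            / \
#           4   8
#
# Output: False. 4 is a valid child of 7, but it sits in 5's right
# subtree and 4 < 5.

# Time: O(n), every node is visited once
# Space: O(h) for the recursion stack

# Idea: pass down an open interval (low, high) that each node must fit in.
#   Going left, the parent's value becomes the new upper bound.
#   Going right, the parent's value becomes the new lower bound.
#   node 3: -inf < 3 < 5  (ok)
#   node 7:  5 < 7 < inf  (ok)
#   node 4:  5 < 4 < 7    (fails, so the tree is not a BST)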


from typing import Optional

class Solution:
    def validate_bst(self, root: Optional[TreeNode]) -> bool:

        def valid(node, low, high):
            if not node:
                return True  # an empty tree is a valid BST

            # Every node must lie strictly between the bounds set by its ancestors
            if not (low < node.val < high):
                return False

            # Left subtree: values must be < node.val; right subtree: > node.val
            return valid(node.left, low, node.val) and valid(node.right, node.val, high)

        return valid(root, float("-inf"), float("inf"))
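A different technique reaches the same answer: a tree is a valid BST exactly when its in-order traversal is strictly increasing. This sketch swaps in that in-order check (iterative, with an explicit stack) in place of the bounds approach; `Node` and `is_bst_inorder` are illustrative names.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_bst_inorder(root):
    """Valid BST if and only if in-order values are strictly increasing."""
    prev = None
    stack, node = [], root
    while stack or node:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        if prev is not None and node.val <= prev:
            return False  # out of order, or a duplicate
        prev = node.val
        node = node.right
    return True

print(is_bst_inorder(Node(2, Node(1), Node(3))))                    # True
print(is_bst_inorder(Node(5, Node(3), Node(7, Node(4), Node(8)))))  # False
```

It can stop at the first out-of-order pair, so an invalid tree is often rejected before every node is visited.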






# Kth Smallest Element in BST

Given the root of a binary search tree and an integer k, return the kth smallest value (1-indexed) among all the node values in the tree.

In [15]:
# Definition for a binary tree node.
# class TreeNode:
#   def __init__(self, val=0, left=None, right=None):
#       self.val = val
#       self.left = left
#       self.right = right
#           3
#         /   \
#       1       4
#        \
#         2
# Input: root = [3,1,4,null,2], k=1 (1st smallest element)
# Output: 1

#                   5
#                  / \
#                 3   6
#                / \
#               2   4
#              /
#             1
# 
# Input: root = [5,3,6,2,4,null,null,1], k=3 (3rd smallest element)
# Output: 3 


from typing import Optional

# An in-order traversal visits BST values in ascending order, so count k down
# as each node is visited; the node where k reaches 0 is the kth smallest.
# Time: O(h + k), since the traversal stops once the kth node is found
# (O(n) worst case). Space: O(h) for the recursion stack.
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.k = k
        self.result = None

        def dfs(node):
            if not node or self.result is not None:
                return

            dfs(node.left)
            if self.result is not None:
                return  # already found in the left subtree

            self.k -= 1
            if self.k == 0:
                self.result = node.value
                return

            dfs(node.right)

        dfs(root)
        return self.result
    
##
# Test cases
def test_kth_smallest():
    # Test case 1: Simple tree
    root1 = TreeNode(3)
    root1.left = TreeNode(1)
    root1.right = TreeNode(4)
    root1.left.right = TreeNode(2)
    
    # Test case 2: Larger tree
    root2 = TreeNode(5)
    root2.left = TreeNode(3)
    root2.right = TreeNode(6)
    root2.left.left = TreeNode(2)
    root2.left.right = TreeNode(4)
    root2.left.left.left = TreeNode(1)
    
    sol = Solution()
    
    # Verify results
    assert sol.kthSmallest(root1, 1) == 1, "Test 1 failed"
    assert sol.kthSmallest(root1, 2) == 2, "Test 2 failed"
    assert sol.kthSmallest(root2, 3) == 3, "Test 3 failed"
    
    print("All test cases passed!")

test_kth_smallest()



All test cases passed!
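The same count-down idea also works without recursion. A minimal iterative sketch with an explicit stack that returns as soon as the kth node is popped, and returns `None` when k is larger than the tree; `Node` and `kth_smallest_iterative` are illustrative names.

```python
from typing import Optional

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest_iterative(root: Optional[Node], k: int) -> Optional[int]:
    """Return the kth smallest value (1-indexed), or None if k is out of range."""
    stack, node = [], root
    while stack or node:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        k -= 1
        if k == 0:
            return node.val  # stop as soon as the kth node is visited
        node = node.right
    return None

root2 = Node(5, Node(3, Node(2, Node(1)), Node(4)), Node(6))
print([kth_smallest_iterative(root2, k) for k in range(1, 7)])  # [1, 2, 3, 4, 5, 6]
print(kth_smallest_iterative(root2, 7))                         # None
```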
