## Search in a BST

You are given the root of a binary search tree (BST) and an integer val. Your task is to find the node in the BST whose value equals val and return the subtree rooted at that node.

If no such node exists, return None.

A binary search tree (BST) is a binary tree in which for every node, all elements in the left subtree are smaller, and all elements in the right subtree are larger than the node's value.

**Input Parameters:**

`root (TreeNode):` The root node of the binary search tree.

`val (int):` The value to search for in the tree.

**Output:**

The node whose value matches val, or None if the node does not exist in the tree.

**Example:**

```
Input:
        4
       / \
      2   7
     / \
    1   3
val = 2
 
Output:
      2
     / \
    1   3
 
 
Input: 
        4
       / \
      2   7
     / \
    1   3
val = 5
 
Output: None
```


In [1]:
def search_bst(root, data):
    """
    Function to search for a node in a Binary Search Tree (BST) whose value equals data.
    :param root: TreeNode -> root of the binary search tree
    :param data: int -> the value to search for
    :return: TreeNode -> the node whose value equals data, or None if it doesn't exist
    """

    if root is None:
        return None
    if root.data == data:
        return root
    elif root.data < data:
        return search_bst(root.right, data)
    else:
        return search_bst(root.left, data)


In [3]:
from bst_class import BinarySearchTree, BSTNode, build_tree_from_level_order, print_bst, print_binary_search_tree
bst = BinarySearchTree()
for value in [20, 10, 30, 5, 15, 35]:
    bst.insert(value)

print("\nSearching for 30 in BST:", bst.search(30))
print("Searching for 90 in BST:", bst.search(90))


Searching for 30 in BST: True
Searching for 90 in BST: False
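
The cell above calls `bst.search` from `bst_class`, which is not shown here, rather than `search_bst` itself. The following is a minimal self-contained sketch of the same lookup written iteratively, which avoids recursion depth limits on tall trees. The `Node` class and the name `search_bst_iterative` are assumptions made for illustration; only the attribute names `data`, `left`, and `right` are taken from `search_bst`.

```python
class Node:
    """Hypothetical stand-in for the node type in bst_class (assumed attributes: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def search_bst_iterative(root, data):
    """Walk down from root by comparison; return the matching node or None."""
    current = root
    while current is not None and current.data != data:
        # Smaller targets live in the left subtree, larger ones in the right.
        current = current.left if data < current.data else current.right
    return current


# The tree from the example: 4 at the root, 2 and 7 as children, 1 and 3 under 2.
root = Node(4, Node(2, Node(1), Node(3)), Node(7))

found = search_bst_iterative(root, 2)
print(found.data, found.left.data, found.right.data)  # 2 1 3
print(search_bst_iterative(root, 5))                  # None
```

Both versions visit at most one node per level, so the cost is proportional to the height of the tree.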


## Successor and Predecessor in a BST

You are given a binary search tree (BST) with N nodes and an integer KEY representing the data of a node in this BST. Your task is to find and return the predecessor and successor of the node with the given KEY.

**Predecessor:** The node that would be visited immediately before the node with KEY in an inorder traversal of the BST. If the given node is the first node in the inorder traversal, the predecessor is None.

**Successor:** The node that would be visited immediately after the node with KEY in an inorder traversal of the BST. If the given node is the last node in the inorder traversal, the successor is None.

**Input Parameters:**

`root (TreeNode)`: The root of the binary search tree.

`KEY (int)`: The data value of the node for which to find the predecessor and successor.

**Output:**

A tuple `(predecessor, successor)`, where both predecessor and successor are integers. If the predecessor or successor does not exist, return None for that value.

**Example:**

```
Input:
      20
     /  \
    10   30
   / \    \
  5  15   35
KEY = 35
Output: (30, None)
Explanation: In the inorder traversal [5, 10, 15, 20, 30, 35], the predecessor of 35 is 30 and there is no successor.
 
 
Input:
      20
     /  \
    10   30
   / \    \
  5  15   35
KEY = 10
Output: (5, 15)
Explanation: In the inorder traversal [5, 10, 15, 20, 30, 35], the predecessor of 10 is 5 and the successor is 15.
```

In [4]:
def find_predecessor_successor(root, key):
    """
    Function to find the predecessor and successor of a node with the given key in a BST.
    
    :param root: TreeNode -> The root of the binary search tree
    :param key: int -> The value of the node for which to find the predecessor and successor
    :return: Tuple[Optional[int], Optional[int]] -> A tuple containing the predecessor and successor
    """
    current = root
    predecessor = None
    successor = None
    
    # Find the node with the given key
    while current:
        if current.data == key:
            # Found the node, now find predecessor and successor
            if current.left:
                # Find the maximum in the left subtree for predecessor
                pred = current.left
                while pred.right:
                    pred = pred.right
                predecessor = pred.data
            
            if current.right:
                # Find the minimum in the right subtree for successor
                succ = current.right
                while succ.left:
                    succ = succ.left
                successor = succ.data
            break
        
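        elif key < current.data:
            # Going left: the current node is larger than key, so it is the best successor so far
            successor = current.data
            current = current.left
        else:
            # Going right: the current node is smaller than key, so it is the best predecessor so far
            predecessor = current.data
            current = current.right

    # Return the values as-is; a truthiness test would wrongly turn a valid key of 0 into None
    return predecessor, successor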


In [5]:
bst = BinarySearchTree()
for value in [20, 10, 30, 5, 15, 35]:
    bst.insert(value)


In [6]:
find_predecessor_successor(bst.root, 35)

(30, None)

In [7]:
find_predecessor_successor(bst.root, 10)

(5, 15)
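
The definitions of predecessor and successor are stated in terms of the inorder sequence, so a direct (if slower) way to cross-check the results is to build that sequence and look at the neighbours of the key. This is a sketch only; `Node`, `inorder_values`, and `pred_succ_by_inorder` are hypothetical names, and `Node` merely mirrors the `data`/`left`/`right` attributes used above.

```python
class Node:
    """Hypothetical stand-in for the node type in bst_class (assumed attributes: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def inorder_values(node):
    """Yield the values of the tree in inorder (sorted) order."""
    if node is not None:
        yield from inorder_values(node.left)
        yield node.data
        yield from inorder_values(node.right)


def pred_succ_by_inorder(root, key):
    """Reference answer: the neighbours of key in the inorder sequence (None where absent)."""
    values = list(inorder_values(root))
    i = values.index(key)  # raises ValueError if key is not in the tree
    predecessor = values[i - 1] if i > 0 else None
    successor = values[i + 1] if i + 1 < len(values) else None
    return predecessor, successor


# The tree from the example above.
root = Node(20, Node(10, Node(5), Node(15)), Node(30, None, Node(35)))

print(pred_succ_by_inorder(root, 35))  # (30, None)
print(pred_succ_by_inorder(root, 10))  # (5, 15)
print(pred_succ_by_inorder(root, 5))   # (None, 10)
```

This takes O(N) time and space, whereas the walk in `find_predecessor_successor` only follows one root-to-leaf path, so it is best kept as a test oracle.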

## Recover a BST

You are given the root of a binary search tree (BST), where the values of exactly two nodes of the tree were swapped by mistake. 

Your task is to recover the BST by swapping the values of these two nodes back to their correct positions. The structure of the tree should remain unchanged.


**Input Parameters:** root (TreeNode): The root of the binary search tree.

**Output:** The root of the corrected binary search tree.

```
Input:
      3
     / \
    1   4
       /
      2
```

```
Output:
      2
     / \
    1   4
       /
      3
```

**Explanation:** In the input tree, the values 2 and 3 are swapped. Swapping them back yields a valid BST with the same structure.

In [31]:
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
 
def recover_tree(root):
    """
    Function to recover a BST where two nodes were swapped by mistake.
    
    :param root: TreeNode -> The root of the binary search tree
    :return: TreeNode -> The root of the corrected binary search tree
    """
    def inorder_traversal(node):
        if node is None:
            return []
        return inorder_traversal(node.left) + [node] + inorder_traversal(node.right)
    
    # In-order traversal to get nodes in sorted order
    nodes = inorder_traversal(root)
    
    # Find the two nodes that are swapped
    first = second = None
    prev = TreeNode(float('-inf')) # Initialize with a value less than any node's value
    
    for node in nodes:
        if node.val < prev.val:
            if first is None:
                first = prev
            second = node
        prev = node
    
    # Swap the values of the two nodes
    if first and second:
        first.val, second.val = second.val, first.val
 
    return root
 
# Helper function for debugging (can be removed for production)
def display_recovered_tree(root):
    result = inorder_traversal(root)
    print(result)
 
def inorder_traversal(node):
    """Helper function to get in-order traversal of the tree."""
    if node is None:
        return []
    return inorder_traversal(node.left) + [node.val] + inorder_traversal(node.right)
 


In [32]:
# Example usage (can be removed)
# BST 3 -> left 1 -> right 2, with the values 1 and 3 swapped
tree = TreeNode(1)
tree.left = TreeNode(3)
tree.left.right = TreeNode(2)
recovered_tree = recover_tree(tree)
print(inorder_traversal(recovered_tree))  # Output: [1, 2, 3]

[1, 2, 3]


In [33]:
# Example usage (can be removed)
tree = TreeNode(3)
tree.left = TreeNode(1)
tree.right = TreeNode(4)
tree.right.left = TreeNode(2)
recovered_tree = recover_tree(tree)
print(inorder_traversal(recovered_tree))  # Output: [1, 2, 3, 4]

[1, 2, 3, 4]
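
`recover_tree` first materialises every node in a list. As a sketch of an alternative, the same two misplaced nodes can be detected during the traversal itself by remembering only the previously visited node, which keeps the extra memory proportional to the tree height (the recursion stack). The name `recover_tree_single_pass` and the helper `values_inorder` are assumptions for illustration; `TreeNode` is repeated so the block runs on its own.

```python
class TreeNode:
    """Same shape as the TreeNode class above, repeated so this block is self-contained."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def recover_tree_single_pass(root):
    """Swap back the two misplaced values, tracking only the previously visited node."""
    first = second = prev = None

    def visit(node):
        nonlocal first, second, prev
        if node is None:
            return
        visit(node.left)
        # An inorder walk of a BST is increasing, so a drop marks a misplaced node.
        if prev is not None and node.val < prev.val:
            if first is None:
                first = prev   # the larger node of the first drop
            second = node      # the smaller node of the latest drop
        prev = node
        visit(node.right)

    visit(root)
    if first is not None and second is not None:
        first.val, second.val = second.val, first.val
    return root


def values_inorder(node):
    """Return the node values in inorder as a list."""
    if node is None:
        return []
    return values_inorder(node.left) + [node.val] + values_inorder(node.right)


# The tree from the problem statement: 3 at the root, 1 and 4 as children, 2 under 4.
tree = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))
print(values_inorder(recover_tree_single_pass(tree)))  # [1, 2, 3, 4]
```

When the swapped nodes are adjacent in the inorder sequence there is only one drop, which is why `second` is assigned on every drop rather than only on the second one.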


## Kth smallest element in BST

You are given the root of a binary search tree (BST) and an integer k. Your task is to return the k-th smallest value among the values of the nodes in the tree.

**Input:**

`root`: The root of the binary search tree (BST).

`k`: The 1-indexed rank of the element to find (k = 1 is the smallest value).

**Output:**

Return the k-th smallest value from the BST.

**Example:**

    Input:
    root = [3,1,4,null,2], k = 1
    Output: 1
**Explanation:** The in-order traversal of the tree is [1, 2, 3, 4], and the 1st smallest element is 1.

    Input:
    root = [5,3,6,2,4,null,null,1], k = 3
    Output: 3
**Explanation:** The in-order traversal of the tree is [1, 2, 3, 4, 5, 6], and the 3rd smallest element is 3.
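
The examples above describe trees as level-order lists with `null` for a missing child. To try them directly, the sketch below builds a tree from such a list and reads off the k-th value from a lazy inorder traversal. All names here (`Node`, `from_level_order`, `inorder_values`, `kth_smallest_lazy`) are hypothetical; in particular, `from_level_order` is not claimed to match `build_tree_from_level_order` from `bst_class`. It assumes a well-formed list and k >= 1.

```python
from collections import deque
from itertools import islice


class Node:
    """Hypothetical node type (assumed attributes: data, left, right)."""

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def from_level_order(values):
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    it = iter(values)
    root = Node(next(it))
    queue = deque([root])  # nodes still waiting for their children
    for left_val in it:
        parent = queue.popleft()
        if left_val is not None:
            parent.left = Node(left_val)
            queue.append(parent.left)
        right_val = next(it, None)  # the list may end after a left child
        if right_val is not None:
            parent.right = Node(right_val)
            queue.append(parent.right)
    return root


def inorder_values(node):
    """Yield values in inorder, one at a time."""
    if node is not None:
        yield from inorder_values(node.left)
        yield node.data
        yield from inorder_values(node.right)


def kth_smallest_lazy(root, k):
    """Return the k-th smallest value (1-indexed), or None if the tree has fewer than k nodes."""
    return next(islice(inorder_values(root), k - 1, None), None)


print(kth_smallest_lazy(from_level_order([3, 1, 4, None, 2]), 1))           # 1
print(kth_smallest_lazy(from_level_order([5, 3, 6, 2, 4, None, None, 1]), 3))  # 3
```

Because the traversal is a generator, it stops as soon as the k-th value has been produced instead of visiting the whole tree.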

In [None]:
def kth_smallest(root, k):
    stack = [] # Initialize an empty stack
    while root or stack: # Continue until all nodes are processed
        while root: # Traverse to the leftmost node
            stack.append(root)
            root = root.left
        root = stack.pop() # Process the node
        k -= 1 # Decrement k for each node processed
        if k == 0: # If k reaches 0, we found the kth smallest element
            return root.data
        root = root.right
    return None  # k is larger than the number of nodes in the tree


In [40]:
bst = BinarySearchTree()
for value in [20, 10, 30, 5, 15, 35]:
    bst.insert(value)

# Find the 3rd smallest element
k = 3
print(f"The {k}rd smallest element is: {kth_smallest(bst.root, k)}")

The 3rd smallest element is: 15


## BST Queries

You are given an arbitrary binary search tree (BST) with N nodes numbered from 1 to N. Each node has a value, and you are given Q queries. 

Each query is of the form [L, R], where L and R are integers representing the range. 

Your task is to find the number of nodes in the BST whose values lie within the range [L, R] for each query.


**Input:**

`root`: The root node of the binary search tree (BST).

`queries`: A list of Q queries where each query is a list [L, R] representing the range.

**Output:**

A list of integers where each integer represents the count of nodes within the given range for each query.

**Example:**

    Input: root = [10,5,15,1,7,null,20]
    queries = [[1, 5], [6, 10], [10, 20]]
    Output: [2, 2, 3]

**Explanation:**

- For query [1, 5], the values in range are 1 and 5 (2 nodes).

- For query [6, 10], the values in range are 7 and 10 (2 nodes).

- For query [10, 20], the values in range are 10, 15, and 20 (3 nodes).

In [45]:
def count_nodes_in_range(root, queries):
    sorted_values = []
    
    def inorder_traversal(node):
        if node is None:
            return
        inorder_traversal(node.left)
        sorted_values.append(node.data)
        inorder_traversal(node.right)

    inorder_traversal(root)
    result = []
    for L, R in queries:
        count = sum(1 for x in sorted_values if L <= x <= R)
        result.append(count)
    return result


In [46]:
bst = BinarySearchTree()
for value in [20, 10, 30, 5, 15, 35]:
    bst.insert(value)

queries = [(10, 20), (15, 30), (5, 35)]
result = count_nodes_in_range(bst.root, queries)
print("Counts of nodes in range for each query:", result)

Counts of nodes in range for each query: [3, 3, 6]


In [47]:
bst = BinarySearchTree()
for value in [10, 5, 15, 1, 7, 20]:
    bst.insert(value)

queries = [(1, 5), (6, 10), (10, 20)]
result = count_nodes_in_range(bst.root, queries)
print("Counts of nodes in range for each query:", result)

Counts of nodes in range for each query: [2, 2, 3]
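
`count_nodes_in_range` scans the whole sorted list for every query, which costs O(N) per query. Since an inorder traversal of a BST is already sorted, each range can instead be counted with two binary searches. The sketch below uses the standard-library `bisect` module on a plain sorted list; the function name `count_in_ranges` is an assumption for illustration, and the input lists are the inorder values of the two trees built above.

```python
from bisect import bisect_left, bisect_right


def count_in_ranges(sorted_values, queries):
    """Count the values v with lo <= v <= hi for each (lo, hi) query using binary search."""
    counts = []
    for lo, hi in queries:
        # bisect_left finds the first index with value >= lo,
        # bisect_right finds the first index with value > hi.
        count = bisect_right(sorted_values, hi) - bisect_left(sorted_values, lo)
        counts.append(max(0, count))  # guard against an inverted range (lo > hi)
    return counts


print(count_in_ranges([1, 5, 7, 10, 15, 20], [(1, 5), (6, 10), (10, 20)]))      # [2, 2, 3]
print(count_in_ranges([5, 10, 15, 20, 30, 35], [(10, 20), (15, 30), (5, 35)]))  # [3, 3, 6]
```

With this change the traversal is still O(N) once, but each query drops to O(log N).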
