# Binary Search Tree

In [2]:
class BSTNode(object):
    """A single node of a binary search tree.

    Attributes:
        value: the key stored in this node.
        left: left child (a ``BSTNode``) or ``None``.
        right: right child (a ``BSTNode``) or ``None``.
        parent: parent node (a ``BSTNode``) or ``None`` for the root.
    """

    def __init__(self, value, left=None, right=None, parent=None):
        self.value = value
        # Validate every link up front. Previously an invalid argument meant
        # the attribute was silently never created, which surfaced later as
        # a confusing AttributeError far from the real mistake.
        self.left = self._check_link('left', left)
        self.right = self._check_link('right', right)
        self.parent = self._check_link('parent', parent)

    @staticmethod
    def _check_link(name, node):
        """Return *node* if it is a ``BSTNode`` or ``None``; raise TypeError otherwise."""
        if node is not None and not isinstance(node, BSTNode):
            raise TypeError(
                f'{name} must be a BSTNode or None, got {type(node).__name__}'
            )
        return node

    def __repr__(self):
        return f'Value: {self.value}'

class BST(object):
    """Unbalanced binary search tree.

    Values smaller than a node go into its left subtree; values greater than
    or equal to it go into its right subtree, so duplicates are allowed.
    """

    def __init__(self, items=None):
        self.head = None  # root node, or None for an empty tree
        if items is not None:
            self.insert_many(items)

    def insert(self, value):
        """Insert *value* as a new leaf, keeping BST order and parent links."""
        new_node = BSTNode(value)
        parent = None
        current = self.head
        while current is not None:
            parent = current
            current = current.left if value < current.value else current.right
        # Every new node must point at its real parent: transplant/delete_node
        # walk these links. (The old code made right children their own parent.)
        new_node.parent = parent
        if parent is None:
            self.head = new_node
        elif value < parent.value:
            parent.left = new_node
        else:
            parent.right = new_node

    def insert_many(self, items):
        """Insert every element of the iterable *items*, in iteration order."""
        for item in items:
            self.insert(item)

    @staticmethod
    def pretty_print(node, level=0):
        """Print the subtree at *node* in in-order sequence, one value per line.

        Each line is indented 4 spaces per depth *level*, so the tree reads
        sideways (root at the left margin when called with level 0).
        """
        if node is None:
            return
        BST.pretty_print(node.left, level + 1)
        print(' ' * 4 * level + '->', node.value)
        BST.pretty_print(node.right, level + 1)

In [3]:
import random
def find_value(bst, value):
    """Return the first node on the search path whose value equals *value*.

    Walks down from ``bst.head``, going left for smaller values and right
    otherwise. Returns ``None`` when the value is not present.
    """
    node = bst.head
    while node is not None and node.value != value:
        node = node.left if value < node.value else node.right
    return node


# testing: build a tree from a shuffled 1..5, draw it, then look up the value 5
import random
tree_nodes = list(range(1, 6))
random.shuffle(tree_nodes)
bst = BST(tree_nodes)
BST.pretty_print(bst.head)
find_value(bst, 5)

    -> 1
        -> 2
-> 3
    -> 4
        -> 5


Value: 5

## Delete a node from a Binary Search Tree

In [4]:
def transplant(bst, node_u, node_v):
    """Put the subtree rooted at *node_v* in the place of the one at *node_u*.

    The parent of *node_u* (or ``bst.head`` if *node_u* is the root) is
    re-pointed at *node_v*. When *node_v* is not None, its ``parent`` is
    updated too. *node_u*'s own links are left untouched.
    """
    parent = node_u.parent
    if parent is None:
        bst.head = node_v
    elif parent.left is node_u:
        parent.left = node_v
    else:
        parent.right = node_v
    if node_v:
        node_v.parent = parent

def minimum(node):
    """Return the leftmost node (smallest value) of the subtree at *node*."""
    while True:
        nxt = node.left
        if not nxt:
            return node
        node = nxt

def delete_node(bst, value):
    """Remove the first node holding *value* (as found by find_value) from *bst*.

    Does nothing when the value is absent. Uses transplant-based deletion:
    - A node with at most one child is replaced by that child (or by None).
    - A node with two children is replaced by its in-order successor.
    """
    target = find_value(bst, value)
    if target is None:
        return

    # No left child: this covers both the leaf case (right is None too, so
    # the node is replaced by None) and the "only a right child" case.
    if target.left is None:
        transplant(bst, target, target.right)
        return

    # Only a left child.
    if target.right is None:
        transplant(bst, target, target.left)
        return

    # Two children: splice in the in-order successor (leftmost of the right subtree).
    successor = minimum(target.right)
    if successor is not target.right:
        # Detach the successor from deep inside the right subtree first,
        # then let it adopt target's whole right subtree.
        transplant(bst, successor, successor.right)
        successor.right = target.right
        successor.right.parent = successor

    # The successor is now target's right child; move it into target's slot.
    transplant(bst, target, successor)
    successor.left = target.left
    successor.left.parent = successor

# testing: delete 3 from a tree built from a shuffled 1..5, then redraw it
tree_nodes = list(range(1, 6))
random.shuffle(tree_nodes)
bst = BST(tree_nodes)
BST.pretty_print(bst.head)
delete_node(bst, 3)
print('-----------------')
BST.pretty_print(bst.head)

        -> 1
    -> 2
-> 3
        -> 4
    -> 5
-----------------
        -> 1
    -> 2
-> 4
    -> 5


## Find min and max in a BST

In [5]:
def minimum(node):
    """Return the node with the smallest value in the subtree at *node*.

    Follows ``left`` links until there are none; *node* must not be None.
    """
    smallest = node
    while smallest.left:
        smallest = smallest.left
    return smallest

def maximum(node):
    """Return the node with the largest value in the subtree at *node*.

    Follows ``right`` links until there are none; *node* must not be None.
    """
    while True:
        nxt = node.right
        if not nxt:
            return node
        node = nxt

# testing: smallest and largest node of a tree built from a shuffled 1..5
tree_nodes = list(range(1, 6))
random.shuffle(tree_nodes)
bst = BST(tree_nodes)
BST.pretty_print(bst.head)
print(minimum(bst.head), maximum(bst.head), sep='\n')

            -> 1
                -> 2
        -> 3
    -> 4
-> 5
Value: 1
Value: 5


## Inorder successor and inorder predecessor in BST

In [6]:
def inorder_suc_pre(root, value):
    """Find the in-order successor and predecessor of *value* in the BST at *root*.

    The results are stored on the function object as
    ``inorder_suc_pre.successor`` and ``inorder_suc_pre.predecessor`` (each a
    node or ``None``) and are also returned as ``(successor, predecessor)``.
    If *value* is not in the tree, the nodes holding the nearest larger and
    smaller values met on the search path are reported.

    Both attributes are reset on every call. A result from an earlier call
    can no longer leak into this one, and reading them after searching an
    empty tree gives ``None`` instead of raising AttributeError.
    """
    successor = None
    predecessor = None
    current = root
    while current is not None:
        if current.value == value:
            # Successor: leftmost node of the right subtree, if there is one;
            # otherwise keep the last ancestor we turned left at.
            node = current.right
            if node:
                while node.left:
                    node = node.left
                successor = node
            # Predecessor: rightmost node of the left subtree, if there is one.
            node = current.left
            if node:
                while node.right:
                    node = node.right
                predecessor = node
            break
        if value < current.value:
            # current is larger than value: best successor candidate so far.
            successor = current
            current = current.left
        else:
            # current is smaller than value: best predecessor candidate so far.
            predecessor = current
            current = current.right
    inorder_suc_pre.successor = successor
    inorder_suc_pre.predecessor = predecessor
    return successor, predecessor
        
# testing: in-order neighbours of 3 in the tree from the previous cell
BST.pretty_print(bst.head)
inorder_suc_pre(bst.head, 3)
for neighbour in ('successor', 'predecessor'):
    print(getattr(inorder_suc_pre, neighbour))

            -> 1
                -> 2
        -> 3
    -> 4
-> 5
Value: 4
Value: 2


## Check whether a binary tree is a valid BST

In [7]:
from collections import deque

def check_bst(root):
    """Return True if an in-order walk of *root* yields non-decreasing values.

    Iterative (explicit stack), so deep degenerate trees do not hit the
    recursion limit. Equal neighbouring values are accepted, which matches
    BST.insert sending duplicates to the right. An empty tree is valid.
    """
    pending = deque()
    node = root
    last = float('-inf')
    while node or pending:
        # Descend as far left as possible, remembering the path.
        while node:
            pending.append(node)
            node = node.left
        node = pending.pop()
        if not last <= node.value:
            return False
        last = node.value
        node = node.right
    return True

# testing: the tree from the previous cells must satisfy the BST ordering
check_bst(root=bst.head)

True

## Populate inorder successor for all nodes

In [46]:
class BSTSNode(BSTNode):
    """BSTNode with an extra ``next`` link for threading nodes in in-order sequence."""

    def __init__(self, value, left=None, right=None, parent=None):
        super(BSTSNode, self).__init__(value, left=left, right=right, parent=parent)
        # Following in-order node; starts unset (None) and is filled in by change_pointers.
        self.next = None
    
class BSTS(BST):
    """BST whose nodes are BSTSNode, i.e. carry a ``next`` in-order link.

    ``in_head`` is the first node of the in-order thread once it has been
    built (e.g. by change_pointers); it starts as ``None``.
    """

    def __init__(self, items=None):
        super().__init__(items)
        self.in_head = None

    def insert(self, value):
        """Insert *value* as a new BSTSNode leaf, keeping BST order and parent links.

        Overridden because BST.insert creates plain BSTNode objects, which
        lack ``next``. A tree built with ``BSTS(items)`` would then break
        code that walks ``next`` links.
        """
        new_node = BSTSNode(value)
        parent = None
        current = self.head
        while current is not None:
            parent = current
            current = current.left if value < current.value else current.right
        new_node.parent = parent
        if parent is None:
            self.head = new_node
        elif value < parent.value:
            parent.left = new_node
        else:
            parent.right = new_node

def change_pointers(bst: "BSTS", node: "BSTSNode | None"):
    """Thread the subtree rooted at *node* into an in-order linked list.

    Each visited node's ``next`` is set to the following node in in-order
    sequence, and the last node's ``next`` is set to ``None``.
    ``bst.in_head`` is set to the first node, or ``None`` for an empty tree.
    Nodes must allow assignment to a ``next`` attribute (BSTSNode does).

    Every call builds a fresh list, so callers no longer need to reset
    ``change_pointers.prev`` beforehand (doing so is harmless). On return,
    ``change_pointers.prev`` holds the last threaded node, as before.
    The traversal is iterative, so deep trees cannot exceed the recursion limit.
    """
    prev = None
    stack = []
    current = node
    while current is not None or stack:
        if current is not None:
            stack.append(current)
            current = current.left
            continue
        current = stack.pop()
        if prev is None:
            bst.in_head = current
        else:
            prev.next = current
        prev = current
        current = current.right
    if prev is None:
        bst.in_head = None
    else:
        # Terminate the chain so a stale link from an earlier threading is not kept.
        prev.next = None
    change_pointers.prev = prev

def print_inorder(bst: BSTS):
    """Print the values of the in-order thread that starts at ``bst.in_head``.

    Each value is followed by a single space; no trailing newline is printed.
    """
    def _walk(start):
        # Yield nodes by following ``next`` links until a falsy link.
        while start:
            yield start
            start = start.next

    for node in _walk(bst.in_head):
        print(node.value, end=' ')

# testing: thread a three-node tree (3 <- 5 -> 8) and print it via the thread
bst = BSTS()
bst.head = BSTSNode(5, left=BSTSNode(3), right=BSTSNode(8))
change_pointers.prev = None
change_pointers(bst, bst.head)
print_inorder(bst)

3 5 8 

## Lowest common ancestor of two values in a BST

In [50]:
def lowest_common_ancestor(bst, value1, value2):
    """Return the node where the search paths for *value1* and *value2* split.

    Descends from ``bst.head``. Go right while both values are larger than
    the current node and left while both are smaller; the first node where
    neither holds is returned. Returns ``None`` for an empty tree.
    """
    node = bst.head
    while node:
        if node.value < value1 and node.value < value2:
            node = node.right
            continue
        if node.value > value1 and node.value > value2:
            node = node.left
            continue
        return node
    return None
# testing: lowest common ancestor of 5 and 3 in a tree from a shuffled 1..5
tree_nodes = list(range(1, 6))
random.shuffle(tree_nodes)
bst = BST(tree_nodes)
BST.pretty_print(bst.head)
lowest_common_ancestor(bst, 5, 3)

    -> 1
        -> 2
            -> 3
-> 4
    -> 5


Value: 4

## Construct a BST from a given preorder traversal

In [97]:
def construct_preorder(pre, key, mini, maxi, size):
    """Rebuild the BST subtree rooted at *key* from the preorder list *pre*.

    Reads and advances the shared cursor ``construct_preorder.pre_index``.
    The caller must initialise it (e.g. to 0) before the top-level call.
    *key* is accepted only if it lies strictly between *mini* and *maxi*.
    The next entries of *pre* then become its left subtree (bounded above
    by *key*) and right subtree (bounded below by *key*).
    *size* is the number of entries of *pre* to consume.
    Returns the subtree root, or ``None`` if *key* is out of range or input is exhausted.
    """
    if construct_preorder.pre_index >= size:
        return None
    if not (key > mini and key < maxi):
        return None

    node = BSTNode(key)
    construct_preorder.pre_index += 1
    # Left subtree takes values in (mini, key), right subtree values in (key, maxi).
    for side, low, high in (('left', mini, key), ('right', key, maxi)):
        if construct_preorder.pre_index < size:
            next_key = pre[construct_preorder.pre_index]
            setattr(node, side, construct_preorder(pre, next_key, low, high, size))
    return node

# testing: rebuild the tree whose preorder walk is 10 5 1 7 40 50
preorder = [10, 5, 1, 7, 40, 50]
mini, maxi = float('-inf'), float('inf')
construct_preorder.pre_index = 0
bst = BST()
bst.head = construct_preorder(preorder, preorder[0], mini, maxi, len(preorder))
BST.pretty_print(bst.head)

        -> 1
    -> 5
        -> 7
-> 10
    -> 40
        -> 50


## Convert Binary Tree to Binary Search Tree

In [102]:
from collections import deque
def convert_to_bst(bst):
    """Rewrite the values of the binary tree *bst* in place so it obeys BST order.

    The tree's shape is kept. Its values are sorted and written back into
    the nodes in in-order position, so an in-order walk becomes
    non-decreasing. Only one traversal is needed: nodes are collected once
    and paired with the sorted values. An empty tree is left unchanged.
    """
    nodes = _inorder_nodes(bst.head)
    for node, value in zip(nodes, sorted(n.value for n in nodes)):
        node.value = value


def _inorder_nodes(root):
    """Return the nodes of the tree at *root* as a list in in-order sequence (iterative)."""
    ordered = []
    stack = deque()
    current = root
    while current or stack:
        if current:
            stack.append(current)
            current = current.left
        else:
            current = stack.pop()
            ordered.append(current)
            current = current.right
    return ordered

# testing: an unordered binary tree is rewritten into a BST in place
root = BSTNode(10,
               left=BSTNode(30, left=BSTNode(20)),
               right=BSTNode(15, right=BSTNode(5)))
bst = BST()
bst.head = root
BST.pretty_print(bst.head)
convert_to_bst(bst)
print('-------------')
BST.pretty_print(bst.head)

        -> 20
    -> 30
-> 10
    -> 15
        -> 5
-------------
        -> 5
    -> 10
-> 15
    -> 20
        -> 30


## Convert a BST into a balanced BST

In [104]:
def bst_balanced(bst):
    """Build and return a height-balanced BST holding the same values as *bst*.

    Prints *bst* first, then the balanced tree, using BST.pretty_print.
    *bst* itself is not modified. The new tree is returned; previously it
    was bound to a rebound local name and discarded, so callers could not use it.
    """
    bst.pretty_print(bst.head)
    # In-order walk of a BST yields its values in sorted order.
    values = []
    stack = deque()
    current = bst.head
    while current or stack:
        if current:
            stack.append(current)
            current = current.left
        else:
            current = stack.pop()
            values.append(current.value)
            current = current.right
    balanced = BST()
    balanced.head = bst_balanced_helper(values, 0, len(values) - 1)
    balanced.pretty_print(balanced.head)
    return balanced

def bst_balanced_helper(array, left, right):
    """Build a balanced subtree from the sorted slice ``array[left..right]`` (inclusive).

    The middle element becomes the root, and each half is built recursively.
    Returns the root BSTNode, or ``None`` for an empty range. Children get
    their ``parent`` link set. Without it, parent-based routines such as
    transplant/delete_node misbehave on the resulting tree.
    """
    if left > right:
        return None
    mid = (left + right) // 2
    root = BSTNode(array[mid])
    root.left = bst_balanced_helper(array, left, mid - 1)
    root.right = bst_balanced_helper(array, mid + 1, right)
    for child in (root.left, root.right):
        if child is not None:
            child.parent = root
    return root

# testing: balance a degenerate (right-leaning) tree built from sorted input
bst = BST(list(range(1, 6)))
bst_balanced(bst)

-> 1
    -> 2
        -> 3
            -> 4
                -> 5
    -> 1
        -> 2
-> 3
    -> 4
        -> 5


## Find the k-th largest element in a BST

In [107]:
def k_largest(bst, k):
    """Return the node holding the k-th largest value (1-based) in *bst*.

    Uses a reverse in-order walk (right, node, left) with an explicit stack.
    Returns ``None`` if the tree has fewer than *k* nodes or *k* < 1.
    """
    pending = deque()
    node = bst.head
    rank = 0
    while node or pending:
        # Push the path toward the largest remaining value.
        while node:
            pending.append(node)
            node = node.right
        node = pending.pop()
        rank += 1
        if rank == k:
            return node
        node = node.left
    return None

# testing: second-largest value of a tree holding 1..6
bst = BST(list(range(1, 7)))
k_largest(bst, k=2)

Value: 5

## Find the k-th smallest element in a BST

In [112]:
def k_smallest(bst, k):
    """Return the node holding the k-th smallest value (1-based) in *bst*.

    Uses an in-order walk (left, node, right) with an explicit stack.
    Returns ``None`` if the tree has fewer than *k* nodes or *k* < 1.
    """
    pending = deque()
    node = bst.head
    rank = 0
    while node or pending:
        # Push the path toward the smallest remaining value.
        while node:
            pending.append(node)
            node = node.left
        node = pending.pop()
        rank += 1
        if rank == k:
            return node
        node = node.right
    return None

# testing: third-smallest value of the tree from the previous cell
k_smallest(bst, k=3)

Value: 3