# Binary Search Tree

In [2]:
# Binary Tree
class Tree():
    """A single binary-tree node: a key plus links to its neighbours.

    ``parent``, ``left`` and ``right`` are the adjacent nodes, or ``None``
    where there is no such neighbour.
    """

    def __init__(self, key, parent=None, left=None, right=None):
        self.key, self.parent = key, parent
        self.left, self.right = left, right

    def __repr__(self):
        # A node is displayed as the string form of the key it stores.
        return f'{self.key!s}'

In [128]:
class BinarySearchTree():
    """Unbalanced binary search tree holding unique, mutually comparable keys.

    Nodes are ``Tree`` instances linked through ``left``/``right``/``parent``.
    ``self.top`` is the root node, or ``None`` while the tree is empty.
    """

    def __init__(self, key=None):
        """Create a tree whose root holds ``key``; empty when ``key`` is None."""
        self.top = None if (key is None) else Tree(key)

    def get_height(self, tree):
        """Return the node count of the longest downward path from ``tree``.

        An empty subtree (``None``) has height 0.
        """
        if tree is None:
            return 0
        # The recursion has to go through ``self``: a bare ``get_height`` is
        # not a name in scope inside a method body.
        return 1 + max(self.get_height(tree.left), self.get_height(tree.right))

    def get_size(self, tree):
        """Return the number of nodes in the subtree rooted at ``tree``."""
        if tree is None:
            return 0
        return 1 + self.get_size(tree.left) + self.get_size(tree.right)

    def assign_parent(self, parent):
        """Point the ``parent`` link of each existing child back at ``parent``.

        Does nothing when ``parent`` is None.
        """
        if parent is None:
            return
        for child in (parent.left, parent.right):
            if child is not None:
                child.parent = parent

    def find(self, key):
        """Return the node holding ``key``, or the best place to insert it.

        When ``key`` is absent, the last node visited on the search path is
        returned; a new node for ``key`` would become one of its children.
        Returns ``None`` when the tree is empty.
        """
        node = self.top
        if node is None:
            return None
        while True:
            if key == node.key:
                return node
            child = node.right if key > node.key else node.left
            if child is None:
                # Fell off the tree: this node is the insertion point.
                return node
            node = child

    def insert(self, key):
        """Insert ``key`` as a new leaf; report and ignore duplicates."""
        if self.top is None:
            # First key of an empty tree becomes the root.
            self.top = Tree(key)
            return
        # Find the candidate place
        tree = self.find(key)
        if key == tree.key:
            print(f'Key {key} already exists')
            return
        child = Tree(key, parent=tree)
        if key > tree.key:
            tree.right = child
        else:
            tree.left = child

    def find_next(self, tree, report=True):
        """Return the node with the smallest key greater than ``tree.key``.

        Args:
            tree: node to start from; ``None`` is accepted and yields ``None``.
            report: when true, print a message if no greater key exists.

        Returns:
            The successor node, or ``None`` when ``tree`` is ``None`` or it
            already holds the greatest key.
        """
        if tree is None:
            return None
        val = tree.key
        if tree.right is not None:
            # Successor is the leftmost node of the right subtree.
            successor = tree.right
            while successor.left is not None:
                successor = successor.left
        else:
            # Otherwise it is the nearest ancestor holding a greater key.
            successor = tree.parent
            while successor is not None and not (successor.key > val):
                successor = successor.parent
        if successor is not None and successor.key > val:
            return successor
        if report:
            print(f'Cannot find key greater than {val}')
        return None

    def range_search(self, k1, k2):
        """Return, in ascending order, every key ``k`` with ``k1 <= k <= k2``.

        Requires ``k1 < k2``; otherwise a message is printed and ``None`` is
        returned. An empty tree yields an empty list.
        """
        if k2 <= k1:
            print(f'k1 ({k1}) should be less than k2 ({k2})')
            return

        arr = []
        # ``find`` lands on k1 itself or on a node adjacent to where k1 would
        # be inserted, so stepping through successors from there visits the
        # keys around the range in ascending order.
        node = self.find(k1)
        while node is not None and node.key <= k2:
            if node.key >= k1:
                arr.append(node.key)
            node = self.find_next(node, False)
        return arr

In [129]:
# Build a sample tree rooted at 3. The keys 3 and 5 are offered a second
# time, so each of them triggers the "already exists" message.
bst = BinarySearchTree(3)
for key in (5, 8, 6, 4, 9, 1, 2, -1, 0, 3, 10, 5):
    bst.insert(key)

Key 3 already exists
Key 5 already exists


In [130]:
# Look up the successor of the smallest key, the largest key, and a key in
# between; the largest key has no successor.
for probe in (-1, 10, 9):
    print(bst.find_next(bst.find(probe)))

0
Cannot find key greater than 10
None
10


In [133]:
# List every stored key between 1 and 8 inclusive, in ascending order.
bst.range_search(1, 8)

[1, 2, 3, 4, 5, 6, 8]