# Binary Search Tree

![image](./images/bst_01.png)

#### Define Node class

In [1]:
# this code makes the tree that we'll traverse

class Node(object):
    """A binary-tree node holding a value and optional left/right children."""

    def __init__(self, value=None):
        self.value = value
        self.left = None   # left child Node, or None when absent
        self.right = None  # right child Node, or None when absent

    def set_value(self, value):
        """Replace the value stored in this node."""
        self.value = value

    def get_value(self):
        """Return the value stored in this node."""
        return self.value

    def set_left_child(self, left):
        """Attach ``left`` (a Node or None) as the left child."""
        self.left = left

    def set_right_child(self, right):
        """Attach ``right`` (a Node or None) as the right child."""
        self.right = right

    def get_left_child(self):
        """Return the left child, or None if there is none."""
        return self.left

    def get_right_child(self):
        """Return the right child, or None if there is none."""
        return self.right

    def has_left_child(self):
        """Return True if this node has a left child."""
        # identity check against the None singleton (PEP 8)
        return self.left is not None

    def has_right_child(self):
        """Return True if this node has a right child."""
        return self.right is not None

    # __repr__ decides what print() and the REPL show for a Node object
    def __repr__(self):
        return f"Node({self.get_value()})"

    # print() and str() should show the same text as repr()
    __str__ = __repr__


In [3]:
from collections import deque
class Queue():
    """First-in, first-out queue built on a deque.

    New items are pushed onto the left end and removed from the right end,
    so the oldest item is always the next one returned by ``deq``.
    """

    def __init__(self):
        self.q = deque()

    def enq(self, value):
        """Add ``value`` to the back of the queue."""
        self.q.appendleft(value)

    def deq(self):
        """Remove and return the oldest item; return None if the queue is empty."""
        if not self.q:
            return None
        return self.q.pop()

    def __len__(self):
        return len(self.q)

    def __repr__(self):
        if not self.q:
            return "<queue is empty>"
        divider = "\n_________________\n"
        # items are listed newest (enqueue end) first, oldest (dequeue end) last
        body = divider.join(map(str, self.q))
        return "<enqueue here>" + divider + body + divider + "<dequeue here>"

#### Define insert

Let's assume that duplicates are overridden by the new node that is to be inserted.  Other options are to keep a counter of duplicate nodes, or to keep a list of duplicate nodes with the same value.

In [2]:
class Tree():
    """Binary search tree: smaller values go to the left, larger to the right.

    Inserting a value equal to one already stored overwrites that node's
    value instead of adding a second node.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        """Replace the tree's root with a new leaf node holding ``value``."""
        self.root = Node(value)

    def get_root(self):
        """Return the root node, or None for an empty tree."""
        return self.root

    def compare(self, node, new_node):
        """Compare ``new_node``'s value against ``node``'s value.

        Returns:
            0 if new_node equals node,
            -1 if new_node is less than the existing node,
            1 otherwise (greater than, or unordered such as NaN).
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def insert_with_loop(self, new_value):
        """Insert ``new_value`` by walking down from the root in a loop.

        Follows the same ordering as ``compare``: equal values overwrite,
        smaller values go left, everything else goes right.
        """
        node = self.get_root()
        if node is None:
            self.set_root(new_value)
            return

        while True:
            if new_value == node.get_value():
                # duplicate: override the stored value with the new one
                node.set_value(new_value)
                return
            if new_value < node.get_value():
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(Node(new_value))
                    return
            else:
                # A plain ``else`` rather than another comparison: a value
                # that is neither equal, smaller nor larger (e.g. NaN) must
                # still move down the tree, or this loop would never end.
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(Node(new_value))
                    return

    def insert_with_recursion(self, value):
        """Insert ``value`` using a recursive descent from the root.

        Recursion depth equals the depth of the insertion point, so a very
        unbalanced tree (e.g. built from sorted input) can exceed Python's
        recursion limit; ``insert_with_loop`` has no such limit.
        """
        node = self.get_root()
        if node is None:
            self.set_root(value)
            return

        def insert_recursively(node, new_node):
            c = self.compare(node, new_node)
            if c == 0:
                # duplicate: override the stored value with the new one
                node.set_value(new_node.get_value())
            elif c == -1:
                if node.has_left_child():
                    insert_recursively(node.get_left_child(), new_node)
                else:
                    node.set_left_child(new_node)
            else:  # c == 1
                if node.has_right_child():
                    insert_recursively(node.get_right_child(), new_node)
                else:
                    node.set_right_child(new_node)

        insert_recursively(node, Node(value))

    def __repr__(self):
        """Render the tree level by level (breadth-first).

        Missing children are shown as ``<empty>`` so the shape is visible.
        """
        q = Queue()
        visit_order = []
        q.enq((self.get_root(), 0))
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # Enqueue both child slots; absent children are None and render
            # as <empty> on the next level.
            q.enq((node.get_left_child(), level + 1))
            q.enq((node.get_right_child(), level + 1))

        parts = ["Tree\n"]
        previous_level = -1
        for node, level in visit_order:
            if level == previous_level:
                parts.append(" | " + str(node))
            else:
                # breadth-first order means levels only increase: start a new line
                parts.append("\n" + str(node))
                previous_level = level
        return "".join(parts)


In [15]:
# Build a small tree with the loop-based insert; the trailing 5 is a duplicate.
tree = Tree()
for value_to_insert in (5, 6, 4, 2, 5):
    tree.insert_with_loop(value_to_insert)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


In [19]:
# Same tree as above, built with the recursive insert; the trailing 5 is a duplicate.
tree = Tree()
for value_to_insert in (5, 6, 4, 2, 5):
    tree.insert_with_recursion(value_to_insert)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Search

Define a search function that takes a value and returns `True` if a node containing that value exists in the tree, and `False` otherwise.

In [4]:
class Tree():
    """Binary search tree with insert and search.

    Smaller values go to the left, larger to the right; inserting a value
    equal to one already stored overwrites that node's value.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        """Replace the tree's root with a new leaf node holding ``value``."""
        self.root = Node(value)

    def get_root(self):
        """Return the root node, or None for an empty tree."""
        return self.root

    def compare(self, node, new_node):
        """Compare ``new_node``'s value against ``node``'s value.

        Returns:
            0 if new_node equals node,
            -1 if new_node is less than the existing node,
            1 otherwise (greater than, or unordered such as NaN).
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def insert(self, new_value):
        """Insert ``new_value``; an equal value already in the tree is overwritten."""
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while True:
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # Override the existing node's value in place. (Rebinding the
                # local ``node`` would leave the tree itself untouched.)
                node.set_value(new_value)
                break
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break  # inserted node, so stop looping
            else:  # comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break  # inserted node, so stop looping

    def search(self, value):
        """Return True if a node holding ``value`` exists in the tree, else False.

        Walks a single root-to-leaf path, going left for smaller values and
        right for larger ones. An empty tree contains nothing, so returns False.
        """
        target = Node(value)  # compare() works on nodes, so wrap the value once
        node = self.get_root()
        while node is not None:
            c = self.compare(node, target)
            if c == 0:
                return True
            # -1: target is smaller -> left subtree; 1: larger -> right subtree
            node = node.get_left_child() if c == -1 else node.get_right_child()
        return False

    def __repr__(self):
        """Render the tree level by level (breadth-first).

        Missing children are shown as ``<empty>`` so the shape is visible.
        """
        q = Queue()
        visit_order = []
        q.enq((self.get_root(), 0))
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # Enqueue both child slots; absent children are None and render
            # as <empty> on the next level.
            q.enq((node.get_left_child(), level + 1))
            q.enq((node.get_right_child(), level + 1))

        parts = ["Tree\n"]
        previous_level = -1
        for node, level in visit_order:
            if level == previous_level:
                parts.append(" | " + str(node))
            else:
                # breadth-first order means levels only increase: start a new line
                parts.append("\n" + str(node))
                previous_level = level
        return "".join(parts)


In [21]:
# Build the demo tree, then look up one absent and one present value.
tree = Tree()
for value_to_insert in (5, 6, 4, 2):
    tree.insert(value_to_insert)

found_8 = tree.search(8)
found_2 = tree.search(2)
print(f"""
search for 8: {found_8}
search for 2: {found_2}
""")
print(tree)


search for 8: False
search for 2: True

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Bonus: deletion

Try implementing deletion yourself.  You can also check out this explanation [here](https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/)

In [44]:
class Tree():
    """Binary search tree with insert, search and delete.

    Smaller values go to the left, larger to the right; inserting a value
    equal to one already stored overwrites that node's value.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        """Replace the tree's root with a new leaf node holding ``value``."""
        self.root = Node(value)

    def get_root(self):
        """Return the root node, or None for an empty tree."""
        return self.root

    def compare(self, node, new_node):
        """Compare ``new_node``'s value against ``node``'s value.

        Returns:
            0 if new_node equals node,
            -1 if new_node is less than the existing node,
            1 otherwise (greater than, or unordered such as NaN).
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def delete(self, value):
        """Remove the node holding ``value`` from the tree.

        Cases for the node being removed:
        1: no children (leaf) -> unlink it.
        2: one child -> put that child in the node's place.
        3: two children -> copy the in-order successor's value (the leftmost
           node of the right subtree) into the node, then unlink the successor.
        See https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/
        """
        # The root itself may be removed or replaced, so store the result back.
        self.root = self.delete_node(self.root, value)

    def delete_node(self, node, value):
        """Delete ``value`` from the subtree rooted at ``node``.

        Prints a message describing what happened. Returns the (possibly new)
        root of the subtree, which the caller must store back into the parent
        link (or ``self.root``).
        """
        # Base case: walked off the tree without finding the value
        if node is None:
            print(f'Node with value {value} not found.')
            return node

        if node.value > value:
            # value can only be in the left subtree
            node.left = self.delete_node(node.left, value)
        elif node.value < value:
            # value can only be in the right subtree
            node.right = self.delete_node(node.right, value)
        elif node.value == value:
            print('Delete:', node)
            # At most one child: that child (possibly None) takes this node's place.
            if node.left is None:
                return node.right
            if node.right is None:
                return node.left

            # Two children: the in-order successor (smallest value in the
            # right subtree) is larger than everything on the left and smaller
            # than everything else on the right, so it keeps the BST ordering
            # when moved into this position.
            successor = self.min_value_node(node.right)
            node.value = successor.value
            # Unlink the successor directly rather than via delete_node, which
            # would print a misleading second "Delete:" line.
            node.right = self._remove_min(node.right)

        return node

    def _remove_min(self, node):
        """Unlink the leftmost node of a non-empty subtree; return the subtree's new root."""
        if node.left is None:
            return node.right
        node.left = self._remove_min(node.left)
        return node

    def min_value_node(self, node):
        """Return the node with the smallest value in the non-empty subtree ``node``."""
        # the smallest value is always reached by following left children
        current = node
        while current.left is not None:
            current = current.left
        return current

    def insert(self, new_value):
        """Insert ``new_value``; an equal value already in the tree is overwritten."""
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while True:
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # Override the existing node's value in place. (Rebinding the
                # local ``node`` would leave the tree itself untouched.)
                node.set_value(new_value)
                break
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break  # inserted node, so stop looping
            else:  # comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break  # inserted node, so stop looping

    def search(self, value):
        """Return True if a node holding ``value`` exists in the tree, else False.

        Walks a single root-to-leaf path. An empty tree (including one emptied
        by ``delete``) contains nothing, so returns False.
        """
        target = Node(value)  # compare() works on nodes, so wrap the value once
        node = self.get_root()
        while node is not None:
            c = self.compare(node, target)
            if c == 0:
                return True
            # -1: target is smaller -> left subtree; 1: larger -> right subtree
            node = node.get_left_child() if c == -1 else node.get_right_child()
        return False

    def __repr__(self):
        """Render the tree level by level (breadth-first).

        Missing children are shown as ``<empty>`` so the shape is visible.
        """
        q = Queue()
        visit_order = []
        q.enq((self.get_root(), 0))
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # Enqueue both child slots; absent children are None and render
            # as <empty> on the next level.
            q.enq((node.get_left_child(), level + 1))
            q.enq((node.get_right_child(), level + 1))

        parts = ["Tree\n"]
        previous_level = -1
        for node, level in visit_order:
            if level == previous_level:
                parts.append(" | " + str(node))
            else:
                # breadth-first order means levels only increase: start a new line
                parts.append("\n" + str(node))
                previous_level = level
        return "".join(parts)


In [47]:
# Build a left-leaning tree, then delete two present values and one absent one.
tree = Tree()
for value_to_insert in (5, 6, 4, 3, 1, 2):
    tree.insert(value_to_insert)

print(tree, '\n')
for value_to_delete in (3, 1, 8):  # 8 is absent: exercises the "not found" path
    tree.delete(value_to_delete)
print('\n', tree)

Tree

Node(5)
Node(4) | Node(6)
Node(3) | <empty> | <empty> | <empty>
Node(1) | <empty>
<empty> | Node(2)
<empty> | <empty> 

Delete: Node(3)
Delete: Node(1)
Node with value 8 not found.

 Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Solution notebook
The solution for insertion and search is [here](04%20binary_search_tree_solution.ipynb)