# Binary Search Tree

![image](./images/bst_01.png)

#### Define Node class

In [1]:
# this code makes the tree that we'll traverse

class Node(object):
        
    def __init__(self,value = None):
        self.value = value
        self.left = None
        self.right = None
        
    def set_value(self,value):
        self.value = value
        
    def get_value(self):
        return self.value
        
    def set_left_child(self,left):
        self.left = left
        
    def set_right_child(self, right):
        self.right = right
        
    def get_left_child(self):
        return self.left
    
    def get_right_child(self):
        return self.right

    def has_left_child(self):
        return self.left != None
    
    def has_right_child(self):
        return self.right != None
    
    # define __repr_ to decide what a print statement displays for a Node object
    def __repr__(self):
        return f"Node({self.get_value()})"
    
    def __str__(self):
        return f"Node({self.get_value()})"


In [2]:
from collections import deque
class Queue():
    def __init__(self):
        self.q = deque()
        
    def enq(self,value):
        self.q.appendleft(value)
        
    def deq(self):
        if len(self.q) > 0:
            return self.q.pop()
        else:
            return None
    
    def __len__(self):
        return len(self.q)
    
    def __repr__(self):
        if len(self.q) > 0:
            s = "<enqueue here>\n_________________\n" 
            s += "\n_________________\n".join([str(item) for item in self.q])
            s += "\n_________________\n<dequeue here>"
            return s
        else:
            return "<queue is empty>"

#### Define insert

Let's assume that duplicates are overriden by the new node that is to be inserted.  Other options are to keep a counter of duplicate nodes, or to keep a list of duplicates nodes with the same value.

In [22]:
class Tree():
    def __init__(self):
        self.root = None
        
    def set_root(self,value):
        self.root = Node(value)
        
    def get_root(self):
        return self.root
    
    def compare(self,node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node 
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1
    
    """
    define insert here
    can use a for loop (try one or both ways)
    """
    def insert_with_loop(self,new_value):
        # Set initial values
        new_node = Node(new_value)
        node = self.get_root()

        # If the root is empty, set the root as the new node
        if node is None:
            self.root = new_node
            return

        # Repeat while loop until it terminates
        while (True):
            comparison = self.compare(node, new_node)

        # If new node = current node, set the value of the current node to be the new value
            if comparison == 0:
                node.set_value(new_value)
                break

        # If new node < current node, traverse until it reaches the end, and then add the new node
            elif comparison == -1:
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                break

        # If new node > current node, traverse until it reaches the end, and then add the new node
            else:
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                break

    """
    define insert here (can use recursion)
    try one or both ways
    """  
    def insert_with_recursion(self,value):
        # Set initial values
        new_node = Node(value)
        node = self.get_root()

        # If the root is empty, set the root as the new node
        if node is None:
            self.root = new_node
            return

        else:
            self.insert_recur(self.get_root(),new_node)

    def insert_recur(self,node,new_node):
        comparison = self.compare(node, new_node)

        # If new node = current node, set the value of the current node to be the new value
        if comparison == 0:
            node.set_value(new_node.get_value())

    # If new node < current node, traverse until it reaches the end, and then add the new node
        elif comparison == -1:
            if node.has_left_child():
                node = self.insert_recur(node.get_left_child(),new_node)
            else:
                node.set_left_child(new_node)

    # If new node > current node, traverse until it reaches the end, and then add the new node
        else:
            if node.has_right_child():
                node = self.insert_recur(node.get_right_child(),new_node)
            else:
                node.set_right_child(new_node)
                
    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node == None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node) 
            else:
                s += "\n" + str(node)
                previous_level = level

                
        return s



In [13]:
tree = Tree()
tree.insert_with_loop(5)
tree.insert_with_loop(6)
tree.insert_with_loop(4)
tree.insert_with_loop(2)
tree.insert_with_loop(5) # insert duplicate
print(tree)

Tree

Node(5)
Node(4) | Node(6)
<empty> | <empty> | <empty> | <empty>


In [23]:
tree = Tree()
tree.insert_with_recursion(5)
tree.insert_with_recursion(6)
tree.insert_with_recursion(4)
tree.insert_with_recursion(2)
tree.insert_with_recursion(5) # insert duplicate
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Search

Define a search function that takes a value, and returns true if a node containing that value exists in the tree, otherwise false.

In [25]:
class Tree():
    def __init__(self):
        self.root = None
        
    def set_root(self,value):
        self.root = Node(value)
        
    def get_root(self):
        return self.root
    
    def compare(self,node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node 
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1
    
    def insert(self,new_value):
        new_node = Node(new_value)
        node = self.get_root()
        if node == None:
            self.root = new_node
            return
        
        while(True):
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # override with new node
                node = new_node
                break # override node, and stop looping
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break #inserted node, so stop looping
            else: #comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break # inserted node, so stop looping
                    
    """
    implement search
    """
    def search(self,value):
        """True means the value is in the tree.  False means the value is not in the tree."""
        # Create a new node called target
        target = Node(value)

        # Set the node as root
        node = self.get_root()

        # Case where root is empty
        if node is None:
            return False

        while(True):
            comparison = self.compare(node, target)

            if comparison == 0:
                return True
            
            elif comparison == -1:
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    return False

            else:
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    return False
                    
    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node == None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node) 
            else:
                s += "\n" + str(node)
                previous_level = level

                
        return s


In [4]:
tree = Tree()
tree.insert(5)
tree.insert(6)
tree.insert(4)
tree.insert(2)

print(f"""
search for 8: {tree.search(8)}
search for 2: {tree.search(2)}
""")
print(tree)


search for 8: False
search for 2: True

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Bonus: deletion

Try implementing deletion yourself.  You can also check out this explanation [here](https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/)

In [113]:
# Create a dummy tree
bst = Tree()
for i in [15, 8, 5, 24, 19, 21, 30, 29, 25]:
    bst.insert(i)
print(bst)

Tree

Node(15)
Node(8) | Node(24)
Node(5) | <empty> | Node(19) | Node(30)
<empty> | <empty> | <empty> | Node(21) | Node(29) | <empty>
<empty> | <empty> | Node(25) | <empty>
<empty> | <empty>


In [163]:
class Tree():
    def __init__(self):
        self.root = None
        
    def set_root(self,value):
        self.root = Node(value)
        
    def get_root(self):
        return self.root
    
    def compare(self,node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node 
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1
    
    def insert(self,new_value):
        new_node = Node(new_value)
        node = self.get_root()
        if node == None:
            self.root = new_node
            return
        
        while(True):
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # override with new node
                node = new_node
                break # override node, and stop looping
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break #inserted node, so stop looping
            else: #comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break # inserted node, so stop looping
                    
    def search(self,value):
        """True means the value is in the tree.  False means the value is not in the tree."""
        # Create a new node called target
        target = Node(value)

        # Set the node as root
        node = self.get_root()

        # Case where root is empty
        if node is None:
            return False

        while(True):
            comparison = self.compare(node, target)

            if comparison == 0:
                return True
            
            elif comparison == -1:
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    return False

            else:
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    return False

    def remove(self, value):
        """Remove a node from a binary search tree.
        Arg: node is node to be removed from the BST
        Output: new BST
        """
        # If node is not in the tree / tree is empty, return False
        if self.search(value) == False:
            return False

        # Otherwise:

        parent = None
        node = self.get_root()

        while node.get_value():
            if node.get_value() == value:
                break
            elif node.get_value() > value:
                parent = node
                node = node.get_left_child()
            else:
                parent = node
                node = node.get_right_child()

        ## Node has left child only -> set parent to node's left child
        if node.has_left_child() == True and node.has_right_child() == False:
            parent.set_left_child(node.get_left_child())
        
        ## Node has right child only --> set parent to node's right child
        elif node.has_right_child() == True and node.has_left_child() == False:
            parent.set_right_child(node.get_right_child())

        ## Node is a leaf --> set parent to none
        elif node.has_right_child() == False and node.has_left_child() == False:
            parent.set_left_child(None)
            parent.set_right_child(None)
        
        ## Node has both a left child and a right child
        else:
            ## Find smallest value in right tree of Node
            delNodeParent = node        
            delNode = node.get_right_child()
            while delNode.has_left_child():
                delNodeParent = delNode
                delNode = delNode.get_left_child()      
                
            ## Change node into delnode
            node.set_value(delNode.get_value())

            ## Delete delnode
            delNodeParent.set_left_child(None)
                  
    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node == None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node) 
            else:
                s += "\n" + str(node)
                previous_level = level

                
        return s


In [166]:
# Create a dummy tree
bst = Tree()
for i in [15, 8, 5, 24, 19, 21, 30, 29, 25]:
    bst.insert(i)
print(bst)
bst.remove(24)
print(bst)

Tree

Node(15)
Node(8) | Node(24)
Node(5) | <empty> | Node(19) | Node(30)
<empty> | <empty> | <empty> | Node(21) | Node(29) | <empty>
<empty> | <empty> | Node(25) | <empty>
<empty> | <empty>
Tree

Node(15)
Node(8) | Node(25)
Node(5) | <empty> | Node(19) | Node(30)
<empty> | <empty> | <empty> | Node(21) | Node(29) | <empty>
<empty> | <empty> | <empty> | <empty>


## Solution notebook
The solution for insertion and search is [here](04 binary_search_tree_solution.ipynb)