# In [None]:
class Node:
    """A binary-search-tree node holding a value and links to its relatives."""

    def __init__(self, val):
        self.l_child = None  # left subtree: values strictly less than data
        self.r_child = None  # right subtree: values greater than or equal to data
        self.parent = None
        self.data = val

class BinarySearchTree:
    """An unbalanced binary search tree of ``Node`` objects.

    A value smaller than a node's ``data`` goes into that node's left
    subtree; a value greater than or equal to it goes into the right subtree.
    """

    def __init__(self):
        self.root = None

    def insert(self, node):
        """
        Insert a node/ value of the node into the tree

        Parameters
        ----------
        node: Node/ any comparable value
            the node itself, or a value to wrap in a new ``Node``
            (``Node`` subclasses such as ``AVLNode`` are inserted as-is)

        Returns
        ----------
        None
        """
        # Wrap anything that is not already a node (ints, floats, strings, ...)
        # so that raw values never end up linked into the tree.
        if not isinstance(node, Node):
            node = Node(node)

        if self.root is None:
            self.root = node
        else:
            self._insert_node(node, self.root)

    def _insert_node(self, node, root):
        """
        Insert a node into an existing subtree

        The descent is iterative rather than recursive: an unbalanced BST
        fed already-sorted data degenerates into a chain, and a recursive
        descent would then exceed Python's recursion limit.

        Parameters
        ----------
        root: Node
            the root of the subtree
        node: Node
            the node to be inserted

        Returns
        ----------
        None
        """
        curr = root
        while True:
            if curr.data > node.data:
                if curr.l_child is None:
                    curr.l_child = node
                    break
                curr = curr.l_child
            else:
                # Equal keys go to the right subtree.
                if curr.r_child is None:
                    curr.r_child = node
                    break
                curr = curr.r_child
        node.parent = curr
print("✔️ Node and BinarySearchTree classes loaded")

# In [None]:
class AVLNode(Node):
    """A ``Node`` augmented with the bookkeeping an AVL tree needs."""

    def __init__(self, val):
        # Reuse Node's initialisation (data plus parent/child links) instead of
        # duplicating it, so the two classes cannot drift apart.
        super().__init__(val)
        self.lr_balance = 0  # height(left subtree) - height(right subtree)
        self.height = 0      # a new node is a leaf; empty subtrees count as -1

class AVLTree:
    """A self-balancing (AVL) binary search tree.

    Nodes must expose ``data``, ``l_child``, ``r_child``, ``parent``,
    ``height`` and ``lr_balance`` attributes (as ``AVLNode`` does).
    ``height`` of a leaf is 0 and of an empty subtree is -1;
    ``lr_balance`` is height(left subtree) - height(right subtree).
    Keys equal to an existing key are placed in the right subtree.
    """

    def __init__(self):
        self.root = None

    @staticmethod
    def _subtree_height(node):
        """Return the stored height of ``node``, or -1 for an empty subtree."""
        return -1 if node is None else node.height

    @classmethod
    def _refresh(cls, node):
        """Recompute ``node.height`` and ``node.lr_balance`` from its children."""
        left = cls._subtree_height(node.l_child)
        right = cls._subtree_height(node.r_child)
        node.height = max(left, right) + 1
        node.lr_balance = left - right

    def get_violating_node(self, node):
        """Refresh the ancestors of ``node`` bottom-up and return the first unbalanced one.

        Walking from ``node`` towards the root, each parent's height and
        balance are recomputed from its children. The walk stops at the first
        ancestor with ``abs(lr_balance) > 1`` and returns it; ancestors above
        that node are not refreshed. Returns None if the tree is empty,
        ``node`` is the root, or no ancestor is unbalanced.
        """
        if self.root is None:
            return None
        while node is not self.root and node.parent is not None:
            parent = node.parent
            self._refresh(parent)
            if abs(parent.lr_balance) > 1:
                return parent
            node = parent
        return None

    def _rebalance(self, x):
        """Rotate around the unbalanced node ``x`` and return the subtree's new top.

        Uses a single rotation when ``x``'s heavy child leans the same way
        as ``x`` (or is balanced), and a double rotation otherwise. Afterwards
        the heights/balances of the rotated nodes are recomputed from their
        children, whose subtrees the rotations do not alter.
        """
        if x.lr_balance < 0:  # right-heavy
            child = x.r_child
            if child.lr_balance > 0:  # right-left: turn it into right-right first
                self.root = right_rotate(child, self.root)
            self.root = left_rotate(x, self.root)
        else:  # left-heavy
            child = x.l_child
            if child.lr_balance < 0:  # left-right: turn it into left-left first
                self.root = left_rotate(child, self.root)
            self.root = right_rotate(x, self.root)

        top = x.parent
        # Children before the new top, so its refresh sees up-to-date heights.
        for kid in (top.l_child, top.r_child):
            if kid is not None:
                self._refresh(kid)
        self._refresh(top)
        return top

    def insert(self, node):
        """Insert ``node`` into the AVL tree and restore the AVL invariant.

        Parameters
        ----------
        node: AVLNode
            the node to insert; its ``data`` must support ``>`` comparison
            with the keys already in the tree

        Returns
        ----------
        the (possibly new) root of the tree, also stored in ``self.root``
        """
        if self.root is None:
            self.root = node
            return self.root

        # Plain BST descent to find where the new leaf attaches.
        curr = self.root
        while True:
            if curr.data > node.data:
                if curr.l_child is None:
                    curr.l_child = node
                    break
                curr = curr.l_child
            else:
                if curr.r_child is None:
                    curr.r_child = node
                    break
                curr = curr.r_child
        node.parent = curr

        # Fix up the tree. After an insertion one (single or double) rotation
        # restores the subtree's previous height, so the follow-up walk
        # normally finds nothing more to fix.
        violating = self.get_violating_node(node)
        while violating is not None:
            top = self._rebalance(violating)
            violating = self.get_violating_node(top)
        return self.root

def left_rotate(x, root):
    """Performs left-rotation on x, returns the root.
    This procedure does NOT update any augmented data (if any)
    of the nodes (e.g., height, left-right balance, etc.), simply
    changing the pointers and the parent-child relationship,
    and setting the new root (if any). The updating task belongs to
    the procedure that calls this function.

    Input:
    - x: a node with a non-empty right child, to be performed the rotation on
    - root: the root node of the tree.

    Output:
    - root: the (new) root of the tree
    """
    y = x.r_child
    # y's left subtree becomes x's right subtree.
    x.r_child = y.l_child
    if y.l_child is not None:
        y.l_child.parent = x
    # y takes x's place under x's former parent (or as the tree root).
    y.parent = x.parent
    if x.parent is None:
        root = y
    elif x is x.parent.l_child:
        x.parent.l_child = y
    else:
        x.parent.r_child = y
    y.l_child = x
    x.parent = y
    return root

def right_rotate(x, root):
    """Performs right-rotation on x, returns the root.
    This procedure does NOT update any augmented data (if any)
    of the nodes (e.g., height, left-right balance, etc.), simply
    changing the pointers and the parent-child relationship,
    and setting the new root (if any). The updating task belongs to
    the procedure that calls this function.

    Input:
    - x: a node with a non-empty left child, to be performed the rotation on
    - root: the root node of the tree.

    Output:
    - root: the (new) root of the tree
    """
    y = x.l_child
    # y's right subtree becomes x's left subtree.
    x.l_child = y.r_child
    if y.r_child is not None:
        y.r_child.parent = x
    # y takes x's place under x's former parent (or as the tree root).
    y.parent = x.parent
    if x.parent is None:
        root = y
    elif x is x.parent.r_child:
        x.parent.r_child = y
    else:
        x.parent.l_child = y
    y.r_child = x
    x.parent = y
    return root
    
print("✔️ Classes AVLNode and AVLTree loaded!")

# In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default visual theme to the matplotlib plots made below.
sns.set()


def height(node):
    """Finds the height of a BST rooted at a node.

    The height is the number of edges on the longest root-to-leaf path:
    an empty tree (None) has height -1 and a single node has height 0.
    The tree is walked level by level without recursion, so degenerate
    (chain-shaped) trees deeper than Python's recursion limit work too.

    Input:
    - node: a node (any object with ``l_child``/``r_child`` attributes)
      or None, the root of the BST

    Output:
    - h: int, the height of the BST"""
    h = -1
    level = [] if node is None else [node]
    while level:
        # Each non-empty level adds one to the height.
        h += 1
        level = [
            child
            for n in level
            for child in (n.l_child, n.r_child)
            if child is not None
        ]
    return h

# In [None]:
import random 

def get_expected_height_stats(iterations, n_values=range(1, 500, 10)):
    """Generate the data for plotting the expected heights of BST and AVL.

    Input:
    - iterations: int >= 1, the number of times to insert into the tree for
    each value of the number of nodes to insert. For each iteration, a height
    is computed. After all the iterations, all the computed heights are
    averaged to get an estimate of the expected height. Trial ``t`` shuffles
    the keys with a private ``random.Random(t)``, so results are reproducible
    and the global ``random`` state is left untouched.
    - n_values: iterable of int, the tree sizes N to evaluate. Defaults to
    range(1, 500, 10) (i.e., 1, 11, 21, 31, etc.)

    Output:
    - bst_expected_heights, avl_expected_heights: list of float, containing
    the expected heights for the two types of trees. Each element in each list
    corresponds to one value of N, in the order given by ``n_values``.

    Raises:
    - ValueError: if iterations < 1 (the average would be undefined)."""
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    bst_expected_heights = []
    avl_expected_heights = []

    for n in n_values:
        bst_heights = []
        avl_heights = []
        for trial in range(iterations):
            vals = list(range(int(n)))
            # A per-trial RNG produces exactly the same shuffle as
            # `random.seed(trial); random.shuffle(vals)` would, without
            # resetting the process-wide random state as a side effect.
            random.Random(trial).shuffle(vals)

            bst = BinarySearchTree()
            avl = AVLTree()
            # Insert the same key sequence into both trees.
            for val in vals:
                bst.insert(Node(val))
                avl.insert(AVLNode(val))

            bst_heights.append(height(bst.root))
            avl_heights.append(height(avl.root))
        # Average height over all trials for this N.
        bst_expected_heights.append(np.sum(bst_heights) / iterations)
        avl_expected_heights.append(np.sum(avl_heights) / iterations)
    return bst_expected_heights, avl_expected_heights


bst_expected_heights, avl_expected_heights = get_expected_height_stats(10)


# Reference averages for N in range(1, 500, 10), 10 shuffled trials each.
reference_bst_heights = [0.0, 5.3, 6.8, 7.8,
                         9.9, 10.2, 9.9, 11.2,
                         12.0, 12.2, 12.3,
                         12.2, 12.1, 13.4, 13.2,
                         13.0, 14.3, 14.7, 13.8,
                         14.1, 14.5, 15.5, 14.9,
                         16.0, 16.1, 16.2, 15.4,
                         17.5, 16.7, 16.1, 16.4,
                         17.4, 16.1, 17.3, 16.8,
                         16.2, 16.7, 17.5, 17.1,
                         18.4, 17.5, 17.0, 17.8,
                         18.4, 18.3, 17.8, 17.7,
                         17.3, 18.0, 19.5]
reference_avl_heights = [0.0, 3.0, 4.1, 5.0,
                         5.4, 6.0, 6.0, 6.1,
                         6.8, 6.9, 7.0, 7.0,
                         7.3, 7.3, 7.5, 8.0,
                         8.0, 8.0, 8.0, 8.0,
                         8.0, 8.2, 8.2, 8.0,
                         8.4, 8.5, 8.7, 8.8,
                         8.9, 8.8, 8.9, 9.0,
                         9.0, 9.0, 9.0, 9.0,
                         9.0, 9.0, 9.3, 9.1,
                         9.1, 9.3, 9.3, 9.7,
                         9.5, 9.6, 9.8, 9.8,
                         9.9, 9.8]

# Compare with a plain `if` instead of `assert` inside a bare `except:`.
# Asserts are stripped under `python -O`, which would make this check always
# report success. A bare `except:` would also swallow unrelated exceptions.
if (bst_expected_heights == reference_bst_heights
        and avl_expected_heights == reference_avl_heights):
    print("🎉 Woohoo, your code passed this test!")
else:
    print("🐞 Your code did not return the expected results, so please try again...")