# Challenge: Implement AVL

In [31]:
# first, define a node
class BSTNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
        self.parent = None

In [105]:
# next, define a bst
class BST:
    def __init__(self, head, node_class = BSTNode):
        self.node_class = node_class
        self.head = head
    
    def insert(self, value):
        '''
        takes in a value and return the inserted node
        '''
        print('inserting', value)
        if self.head is None:
            return None
        node = self.node_class(value)
        curr = self.head
        while curr:
            if node.value <= curr.value:
                if curr.left is None:
                    curr.left, node.parent = node, curr
                    return node
                else:
                    curr = curr.left
            else:
                if curr.right is None:
                    curr.right, node.parent = node, curr
                    return node
                else:
                    curr = curr.right
    
    def search(self, root, value):
        if not root:
            return None
        if root.value is value:
            return root
        elif root.value < value:
            return self.search(root.right, value)
        else:
            return self.search(root.left, value)
    
    def delete(self, node):
        # case 1: when node has no children -> simply delete it
        # case 2: when node has 1 children -> delete it and link up the child
        if node.left is None or node.right is None:
            if node.parent.left is node:
                node.parent.left = node.left or node.right
                if node.parent.left is not None:
                    node.parent.left.parent = node.parent
            else:
                node.parent.right = node.left or node.right
                if node.parent.right is not None:
                    node.parent.right.parent = node.parent
            node.parent = None
            return node
        else:
        # case 3: when there are two children -> swap the node's value with the next successor then delete the successor
        # this works because after the value swap, the successor can only have a right child (no left child)
            next_successor = self.next_successor(node)
            node.value, next_successor.value = next_successor.value, node.value
            return self.delete(next_successor)
    
    def next_successor(self, node):
        if node.right:
            return self.find_min(node.right)
        else:
            return None
    
    def find_min(self, node):
        if node.left is None:
            return node
        else:
            return self.find_min(node.left)
                    
    def __str__(self):
        ret = ''
        level = [self.head]
        while level:
            for node in level:
                ret += str(node.value)
            ret += '\n'
            level = [child for node in level for child in (node.left, node.right) if child]
        return ret

In [110]:
# do the same for AVL
class AVLNode(BSTNode):
    def __init__(self, value):
        BSTNode.__init__(self, value)
        self.height = 0
    
    def update_height(self):
        if self.left:
            left_child_height = self.left.height
        else:
            left_child_height = -1
        if self.right:
            right_child_height = self.right.height
        else:
            right_child_height = -1
        height = 1+max(left_child_height, right_child_height)
        self.height = height

class AVL(BST):
    def __init__(self, head, node_class = AVLNode):
        BST.__init__(self, head, node_class = node_class)
    
    def insert(self, value):
        # 1. simple BST insert
        inserted_node = BST.insert(self, value)
        # 2. fix AVL property from changed node up
        self._rebalance(inserted_node)
        return inserted_node
    
    def delete(self, node):
        # standard BST delete
        deleted_node = BST.delete(self, node)
        # rebalance
        self._rebalance(deleted_node)
        return deleted_node
    
    def _rebalance(self, node):
        print('rebalancing...')
        # find the lowest node violating AVL RI
        # then do the correct rotation
        while node:
            node.update_height()
            if self._get_height(node.right) >= 2 + self._get_height(node.left):
                if self._get_height(node.right.left) > self._get_height(node.right.right):
                    self._right_rotate(node.right)
                self._left_rotate(node)
            elif self._get_height(node.left) >= 2 + self._get_height(node.right):
                if self._get_height(node.left.right) > self._get_height(node.left.left):
                    self._left_rotate(node.left)
                self._right_rotate(node)
            node = node.parent
    
    @staticmethod
    def _get_height(node):
        if not node:
            return -1
        return node.height
    
    def _left_rotate(self, node):
        print('left rotating...', node.value)
        new_root = node.right
        new_root.parent, node.parent = node.parent, new_root
        if new_root.parent is None:
            self.head = new_root
        else:
            if new_root.parent.left is node:
                new_root.parent.left = new_root
            else:
                new_root.parent.right = new_root
        node.right = new_root.left
        if node.right is not None:
            node.right.parent = node
        new_root.left = node
        node.update_height()
        print('node height', node.height)
        new_root.update_height()
        print('new root height', new_root.height)
    
    def _right_rotate(self, node):
        print('right rotating...', node.value)
        # mirrors left rotate
        new_root = node.left
        new_root.parent, node.parent = node.parent, new_root
        if new_root.parent is None:
            self.head = new_root
        else:
            if new_root.parent.left is node:
                new_root.parent.left = new_root
            else:
                new_root.parent.right = new_root
        node.left = new_root.right
        if node.left is not None:
            node.left.parent = node
        new_root.right = node
        node.update_height()
        new_root.update_height()

In [107]:
# test
head = BSTNode(3)
bst = BST(head)
bst.insert(1)
bst.insert(4)
bst.insert(2)
bst.delete(head)
print(bst)

inserting 1
inserting 4
inserting 2
4
1
2



In [111]:
# test
head = AVLNode(3)
avl = AVL(head)
avl.insert(1)
avl.insert(5)
avl.insert(4)
avl.insert(6)
avl.insert(3.5)
avl.insert(4.5)
avl.delete(head)
print(avl)

inserting 1
rebalancing...
inserting 5
rebalancing...
inserting 4
rebalancing...
inserting 6
rebalancing...
inserting 3.5
rebalancing...
right rotating... 5
left rotating... 3
node height 1
new root height 2
inserting 4.5
rebalancing...
rebalancing...
rebalancing...
4
3.55
14.56



In [70]:
# test
avl = AVL(AVLNode(3))
avl.insert(1)
avl.insert(5)
avl.insert(4)
avl.insert(6)
avl.insert(5.5)
avl.insert(6.5)
print(avl)

inserting 1
rebalancing...
inserting 5
rebalancing...
inserting 4
rebalancing...
inserting 6
rebalancing...
inserting 5.5
rebalancing...
left rotating... 3
node height 1
new root height 2
inserting 6.5
rebalancing...
5
36
145.56.5

