In [2]:
# AVL Tree Node
class Node:
    """A single node of an AVL tree.

    A freshly created node is a leaf: its height is 0 and it has no
    left or right child.
    """

    def __init__(self, data):
        self.data = data
        self.left_child = self.right_child = None
        self.height = 0
        

In [36]:
# AVL Tree
class AVL(object):
    """Self-balancing (AVL) binary search tree.

    Each node stores its own height (a leaf has height 0; an empty subtree
    counts as -1). After every insert/remove, rotations keep the heights of
    any node's two subtrees within one of each other. Values equal to a
    node's value are stored in its right subtree. Operations print a trace
    of inserts, removals and rotations.
    """

    def __init__(self):
        self.root = None

    def insert(self, data):
        """Insert ``data`` into the tree, rebalancing as needed."""
        print('Insert: %s' % data)
        self.root = self.insert_node(data, self.root)

    def insert_node(self, data, node):
        """Insert ``data`` into the subtree rooted at ``node``.

        Returns the (possibly new) root of that subtree. Relies on a
        module-level ``Node`` class to create the new leaf.
        """
        if node is None:
            return Node(data)

        if data < node.data:
            node.left_child = self.insert_node(data, node.left_child)
        else:
            # equal values go right
            node.right_child = self.insert_node(data, node.right_child)

        self._update_height(node)

        return self.settle_violations(data, node)

    def remove(self, data):
        """Remove one occurrence of ``data`` (no-op if absent), rebalancing as needed."""
        print('Remove: %s' % data)
        if self.root:
            self.root = self.remove_node(data, self.root)

    # O(logN)
    def remove_node(self, data, node):
        """Remove one node holding ``data`` from the subtree rooted at ``node``.

        Returns the new root of that subtree (``None`` if it became empty).
        """
        if node is None:
            return None

        if data < node.data:
            node.left_child = self.remove_node(data, node.left_child)
        elif data > node.data:
            node.right_child = self.remove_node(data, node.right_child)
        else:
            # leaf, or only a right child: splice the right subtree (possibly None) in
            if node.left_child is None:
                return node.right_child

            # only a left child: splice the left subtree in
            if node.right_child is None:
                return node.left_child

            # both children: copy the in-order predecessor's value into this
            # node, then delete the predecessor from the left subtree
            predecessor = self.get_predecessor(node.left_child)
            node.data = predecessor.data
            node.left_child = self.remove_node(predecessor.data, node.left_child)

        self._update_height(node)

        return self._rebalance(node)

    def get_predecessor(self, node):
        """Return the node holding the largest value in the subtree rooted at ``node``.

        Called on a node's left subtree this yields that node's in-order
        predecessor. Returns ``None`` for an empty subtree.
        """
        if node is None:
            return None

        # the maximum of a BST is its right-most node
        while node.right_child is not None:
            node = node.right_child

        return node

    def settle_violations(self, data, node):
        """Restore the AVL invariant at ``node`` after an insert; return the subtree root.

        ``data`` (the inserted value) is accepted for backward compatibility
        but is not consulted. The rotation case is chosen from the child's
        balance factor, which also handles duplicates: a value equal to the
        left child's value lands in that child's right subtree and must
        trigger a left-right rotation.
        """
        return self._rebalance(node)

    def _rebalance(self, node):
        """Apply the rotation(s) needed at ``node`` and return the new subtree root.

        Expects the heights of ``node`` and its children to be current.
        """
        balance = self.calc_balance(node)

        if balance > 1:
            # left-right heavy: turn the left child into a left-left shape first
            if self.calc_balance(node.left_child) < 0:
                node.left_child = self.rotate_left(node.left_child)
            # doubly-left heavy
            return self.rotate_right(node)

        if balance < -1:
            # right-left heavy: turn the right child into a right-right shape first
            if self.calc_balance(node.right_child) > 0:
                node.right_child = self.rotate_right(node.right_child)
            # doubly-right heavy
            return self.rotate_left(node)

        # no rotations needed
        return node

    def _update_height(self, node):
        """Recompute ``node.height`` from its children's stored heights."""
        node.height = max(self.calc_height(node.left_child), self.calc_height(node.right_child)) + 1

    def calc_height(self, node):
        """Return the stored height of ``node``; an empty subtree has height -1."""
        if node is None:
            return -1

        return node.height

    # if > 1 => left heavy => right rotation
    # if < -1 => right heavy => left rotation
    def calc_balance(self, node):
        """Return left-subtree height minus right-subtree height (0 for ``None``)."""
        if node is None:
            return 0

        return self.calc_height(node.left_child) - self.calc_height(node.right_child)

    # O(1)
    def rotate_right(self, node):
        """Rotate the subtree rooted at ``node`` to the right; return its new root.

        ``node.left_child`` must not be ``None``.
        """
        print('Right rotation')

        new_root = node.left_child
        node.left_child = new_root.right_child
        new_root.right_child = node

        # node now sits below new_root, so its height is refreshed first
        self._update_height(node)
        self._update_height(new_root)

        return new_root

    # O(1)
    def rotate_left(self, node):
        """Rotate the subtree rooted at ``node`` to the left; return its new root.

        ``node.right_child`` must not be ``None``.
        """
        print('Left rotation')

        new_root = node.right_child
        node.right_child = new_root.left_child
        new_root.left_child = node

        # node now sits below new_root, so its height is refreshed first
        self._update_height(node)
        self._update_height(new_root)

        return new_root

    def traverse(self):
        """Print every value in the tree in ascending (in-order) order."""
        print('Traverse:')
        if self.root:
            self.traverse_in_order(self.root)

    def traverse_in_order(self, node):
        """Print the values of the non-empty subtree rooted at ``node`` in order."""
        if node.left_child:
            self.traverse_in_order(node.left_child)

        print('  %s ' % node.data)

        if node.right_child:
            self.traverse_in_order(node.right_child)
        
    

In [37]:
# doubly right heavy: ascending inserts keep forcing left rotations
avl = AVL()
for value in (10, 20, 30, 40, 50):
    avl.insert(value)
avl.traverse()

Insert: 10
Insert: 20
Insert: 30
Left rotation
Insert: 40
Insert: 50
Left rotation
Traverse:
  10 
  20 
  30 
  40 
  50 


In [38]:
# doubly left heavy: descending inserts keep forcing right rotations
avl = AVL()
for value in (500, 400, 300, 200, 100):
    avl.insert(value)
avl.traverse()

Insert: 500
Insert: 400
Insert: 300
Right rotation
Insert: 200
Insert: 100
Right rotation
Traverse:
  100 
  200 
  300 
  400 
  500 


In [39]:
# right-left heavy: 6 lands left of 7, which sits right of 5
avl = AVL()
for value in (5, 7, 6):
    avl.insert(value)
avl.traverse()

Insert: 5
Insert: 7
Insert: 6
Right rotation
Left rotation
Traverse:
  5 
  6 
  7 


In [40]:
# left-right heavy: 4 lands right of 3, which sits left of 5
avl = AVL()
for value in (5, 3, 4):
    avl.insert(value)
avl.traverse()

Insert: 5
Insert: 3
Insert: 4
Left rotation
Right rotation
Traverse:
  3 
  4 
  5 


In [41]:
# remove - triggers rebalance: deleting 9 leaves 10 right-heavy (15 -> 20), fixed by a left rotation
avl = AVL()
for value in (10, 9, 15, 20):
    avl.insert(value)
avl.remove(9)
avl.traverse()

Insert: 10
Insert: 9
Insert: 15
Insert: 20
Remove: 9
Left rotation
Traverse:
  10 
  15 
  20 


In [42]:
# remove a node with a single child: 10 is replaced by its right child 15 and the tree stays balanced (no rotation)
avl = AVL()
for value in (10, 9, 8, 7, 15):
    avl.insert(value)
avl.remove(10)
avl.traverse()

Insert: 10
Insert: 9
Insert: 8
Right rotation
Insert: 7
Insert: 15
Remove: 10
Traverse:
  7 
  8 
  9 
  15 
