In [21]:
class BSTnode:
    def __init__(self, key = None, parent = None):
        '''Initialize a BST node'''
        self.key = key
        self.parent = parent
        self.left = None
        self.right = None
        
    def find(self, key):
        if not self or self.key == key:
            return self
        elif self.key < key:
            if self.right:
                return self.right.find(key)
        else:
            if self.left:
                return self.left.find(key)
            
    def successor(self):
        if self.right:
            return self.right.minimum()
        else:
            while self.parent and self is not self.parent.left:
                self = self.parent
            return self.parent
        
    def minimum(self):
        while self.left:
            self = self.left            
        return self
    
    def maximum(self):
        while self.right:
            self = self.right
        return self


    def __str__(self):
        '''Return ASCII drawing of the tree'''
        s = str(self.key)
        if self.left is None and self.right is None:
            return s
        sl, sr = [''], ['']
        if self.left:
            s = '_' + s
            sl = str(self.left).split('\n')
        if self.right:
            s = s + '_'
            sr = str(self.right).split('\n')
        wl, cl = len(sl[0]), len(sl[0].lstrip(' _'))
        wr, cr = len(sr[0]), len(sr[0].rstrip(' _'))
        a = [(' ' * (wl - cl)) + ('_' * cl) + s +
             ('_' * cr) + (' ' * (wr - cr))]
        for i in range(max(len(sl), len(sr))):
            ls = sl[i] if i < len(sl) else ' ' * wl
            rs = sr[i] if i < len(sr) else ' ' * wr
            a.append(ls + ' ' * len(s) + rs) 
        return '\n'.join(a)
        
    
    

In [22]:
class BST:
    def __init__(self):
        self.root = None
                    
        
    def insert(self, node):
        parent = None
        child = self.root
        
        while child:
            parent = child
            if node.key == child.key:
                return
            elif node.key < child.key:
                child = child.left       
            else:
                child = child.right
                
        if not parent:
            self.root = node
        elif not parent.left:
            parent.left = node
        else:
            parent.right = node
        node.parent = parent
        
    def insert_rec(self, key, node):
        if not node.key:
            node.key = key
            return
        if key < node.key:
            if not node.left:
                node.left = node.__class__(None, node)
            self.insert_rec(key, node.left)
        elif key > node.key:
            if not node.right:
                node.right = node.__class__(None, node)
            self.insert_rec(key, node.right)
          
    def transplant(self, x, y):
        if not x.parent:
            self.root = y         
        elif x is x.parent.right:
            x.parent.right = y
        else:
            x.parent.left = y
        if y:    
            y.parent = x.parent
            
    def delete(self, node):
        if not node.left:
            self.transplant(node, node.right)
        elif not node.right:
            self.transplant(node, node.left)
        else:
            y = node.right.minimum()
            if y.parent is not node:
                self.transplant(y, y.right)
                y.right = node.right
                y.right.parent = y
            self.transplant(node, y)
            y.left = node.left
            y.left.parent = y
       
            
        




In [27]:
class AVLnode(BSTnode):
    def __init__(self, key=None, parent=None):
        super().__init__(key, parent)
        self.height = None
        self.skew = None
    
            
    def update(self):
        right_height = self.right.height if self.right else -1
        left_height = self.left.height if self.left else -1
        self.height = 1 + max(right_height, left_height)
        self.skew = right_height - left_height
    
    def __str__(self):
        '''Return ASCII drawing of the tree (visualize skew)'''
        key = self.key
        self.key = str(key) + (
            '=' if self.skew == 0 else
            '>' if self.skew < 0 else 
            '<')
        s = super().__str__()
        self.key = key
        return s



In [31]:
class AVL(BST):
        
    
    def insert(self, node):
        super().insert(node)
        node.maintain()
        
    def insert_rec(self, key, node):
        if not node.key:
            node.key = key
            tree.maintain(node)
            return
        if key < node.key:
            if not node.left:
                node.left = node.__class__(None, node)
            self.insert_rec(key, node.left)
        elif key > node.key:
            if not node.right:
                node.right = node.__class__(None, node)
            self.insert_rec(key, node.right)

    def delete(self, node):
        super().delete(node)
        self.maintain(node)
    
    def maintain(self, node):
        node.update()
        self.balance(node)
        if node.parent:
            self.maintain(node.parent)
    
    
    def balance(self, node):
        if node.skew == 2:
            if node.right.skew == -1:
                self.rotate_right(node.right)
            self.rotate_left(node)
        elif node.skew == -2:
            if node.left.skew == 1:
                self.rotate_left(node.left)
            self.rotate_right(node)
            
        

    def rotate_left(self, x):
        y = x.right
        
        x.right = y.left
        if y.left:
            y.left.parent = x
            
        y.parent = x.parent
        if not x.parent:
            self.root = y
        elif x is x.parent.right:
            x.parent.right = y
        else:
            x.parent.left = y
            
        x.parent = y
        y.left = x
        
        x.update()
        y.update()

        
    def rotate_right(self, x):
        y = x.left
        
        x.left = y.right
        if y.right:
            y.right.parent = x
            
        y.parent = x.parent
        if not x.parent:
            self.root = y
        elif x is x.parent.right:
            x.parent.right = y
        else:
            x.parent.left = y
            
        x.parent = y
        y.right = x
        
        x.update()
        y.update()
        

        
    


   

In [28]:
from random import sample, choice

keys = sample(range(50), 15)

tree = AVL()
tree.root = AVLnode(9)

print(keys)

for key in keys:
    tree.insert_rec(key, tree.root)
    
print(tree.root)

[23, 8, 43, 28, 0, 25, 14, 5, 19, 49, 17, 44, 35, 45, 16]
      ___________________23>________________            
   ___9<____________             ________43=________    
___8>    ________17>____     ____28=____     ____45=____
5=       14<____     19=     25=     35=     44=     49=
             16=                                        


In [30]:
node = tree.root.find(23) 
tree.delete(node)
print(tree.root)

      ___________________25=____________            
   ___9<____________         ________43=________    
___8>    ________17>____     28=____     ____45=____
5=       14<____     19=         35=     44=     49=
             16=                                    
