# AVL tree

Solutions from other ppl:
* http://www.cs.toronto.edu/~rdanek/csc148h_09/lectures/8/bst.py
* https://github.com/bfaure/Python3_Data_Structures/blob/master/Binary_Search_Tree/main.py (he also has a [video]( https://www.youtube.com/watch?v=Zaf8EOVa72I) coding all that)

In [123]:
class Tree():
    class Node():
        def __init__(self, key=None, parent=None):
            self.left = None
            self.right = None
            self.key = key
            self.parent = parent
            
        def __str__(self):
            return str(self.key) + ':(' + str(self.left) + ',' + str(self.right) + ')'

    def __init__(self):
        self.root = None
        
    def __str__(self):
        return str(self.root)
    
    def inorder(self):
        return self._inorder(self.root)
    def _inorder(self,node):
        if node is None:
            return []
        return self._inorder(node.left) + [node.key] + self._inorder(node.right)
    
    def height(self,node):
        """Hight of a node"""
        if node is None: return 0
        return 1 + max(self.height(node.left), self.height(node.right))
    
    def draw(self):
        """Attempts to draw a tree."""
        h = self.height(self.root)
        old = [self.root]
        for i in range(h):
            gap = 2**(h-i)-1
            print(' '*(gap//2) + (' '*gap).join([str(node.key) if node is not None else '.' for node in old]))            
            new = []
            for node in old:
                if node is None: 
                    new += [None,None]
                    continue
                if node.left is not None:  
                    new.append(node.left)
                else:
                    new.append(None)
                if node.right is not None: 
                    new.append(node.right)
                else:
                    new.append(None)
            old = new.copy()

    def get(self,key):
        return self._get(self.root,key)
    def _get(self, key, root=None):
        if root is None:    return None
        if key==root.key:   return root
        if key>root.key:    return self._get(key, root.right)
        else:               return self._get(key, root.left)
        
    def balance(self,root):
        if root is None: return
        if abs(self.height(root.left)-self.height(root.right))<=1: # Already balanced
            return root

        def rotate_right(y):
            x = y.left
            t2 = x.right
            x.right = y
            y.left = t2
            return x
        def rotate_left(x):
            y = x.right
            t2 = y.left
            y.left = x
            x.right = t2
            return y            
        
        # For node insertion, if we still remember the key that was added,
        # these tests can be replaced with a comparison of the new key with old keys 
        # (like, where the new node is, relative to the root. Is it LL, LR, ...?
        if self.height(root.left)>self.height(root.right):
            if self.height(node.left.left)<self.height(node.left.right): # LR
                root.left = rotate_left(root.left)
            return rotate_right(root) # For both LL and LR
        else:
            if self.height(root.right.left)>self.height(root.right.right): # RL
                root.right = rotate_right(root.right)
            else:
                print('RR')
            return rotate_left(root) # For both RL and RR
        return

    def add(self, key, root=None):
        if self.root is None: # Special case for the very first node
            self.root = self.Node(key)
            return self.root
        if root is None: 
            self.root = self.add(key,self.root) # Non-recurrent call to recurrent
            return
            
        if key==root.key:   return root # The key exists already
        if key< root.key:
            if root.left is not None:
                root.left = self.add(key, root.left)
                return self.balance(root)
            else:
                root.left = self.Node(key)
                root.left.parent = root
                return root
        else:
            if root.right is not None:
                root.right = self.add(key, root.right)
                return self.balance(root)
            else:
                root.right = self.Node(key)
                root.right.parent = root
                return root
            
    def delete(self, key):
        def find(key, node):
            """Recursive finder."""
            if node is None: return None          # Nothing to delete
            if key==node.key: return node
            if key<node.key: return find(key, node.left)
            else: return find(key, node.right)
        doomed = find(key, self.root)
        self._delete(doomed) 
    def _delete(self, doomed):
        parent = doomed.parent
        def find_min(node): # Find the successor node and its parent
            current = node.right
            while current.left is not None:
                current = current.left
            return current  
        
        if doomed is None: return False
        if doomed.left is None and doomed.right is None:
            child = None
        elif doomed.left is None and doomed.right is not None:
            child = doomed.right
            child.parent = parent
        elif doomed.left is not None and doomed.right is None:
            child = doomed.left
            child.parent = parent
        else: # The doomed node has 2 real branches:
            successor = find_min(doomed) # Successor node and its parent
            child = doomed               # Formerly known as doomed, is now a child
            child.key = successor.key    # ..coz we copy the data
            # No need to set the parent as we never deleted this node, just copied the data
            self._delete(successor) # Recursive deletion
        # Now delete the doomed node by replacing it with a child
        if parent is None:
            root = child
        elif parent.left == doomed:
            parent.left = child
        else:
            parent.right = child

# Test
t = Tree()
for i in [1,2,3,4,5,6,7]:
    t.add(i)
print(t)
t.draw()

RR
RR
RR
RR
4:(2:(1:(None,None),3:(None,None)),6:(5:(None,None),7:(None,None)))
   4
 2   6
1 3 5 7


In [124]:
# Test
t = Tree()
for i in [3,1,5,0,2,4,6,9,8]: # Balanced
    t.add(i)
print(t)
t.draw()
print(t.inorder())
t.delete(1)
t.delete(4)
print('\n',t)
print(t.inorder())
t.draw()

3:(1:(0:(None,None),2:(None,None)),5:(4:(None,None),8:(6:(None,None),9:(None,None))))
       3
   1       5
 0   2   4   8
. . . . . . 6 9
[0, 1, 2, 3, 4, 5, 6, 8, 9]

 3:(2:(0:(None,None),None),5:(None,8:(6:(None,None),9:(None,None))))
[0, 2, 3, 5, 6, 8, 9]
       3
   2       5
 0   .   .   8
. . . . . . 6 9
