In [None]:
class Node:
    """Binary tree node: a key, two child links and a cached subtree height."""

    def __init__(self, height, left, right, key):
        # Payload first, then the structural links, then the cached height.
        self.key = key
        self.left, self.right = left, right  # child nodes, or None
        self.height = height  # height of the subtree rooted at this node


class AVL_tree:
    """AVL tree: a binary search tree of unique keys kept height-balanced.

    Every node caches the height of its subtree (a newly inserted leaf has
    height 1, an empty subtree counts as 0). After each insertion or
    removal the heights along the affected path are recomputed, and any
    node whose two subtrees differ in height by more than one is fixed with
    a single or double rotation.
    """

    def __init__(self, root):
        # root: the root node of an existing tree, or None for an empty tree.
        self.root = root

    @staticmethod
    def _height(t):
        """Return the cached height of node *t*, or 0 for an empty subtree."""
        return 0 if t is None else t.height

    def search_node(self, value):
        """Return the root-to-target path followed while looking for *value*.

        The path is a list of ``[node, direction]`` pairs. Direction is
        ``'o'`` for the root entry, ``'r'`` for a step to a right child and
        ``'l'`` for a step to a left child. If *value* is in the tree, the
        last entry holds its node. Otherwise the last entry is
        ``[None, direction]``: the empty slot where *value* would be
        inserted. For an empty tree the result is ``[[None, 'o']]``.
        """
        t = self.root
        stack = [[t, 'o']]
        while t is not None:
            if value > t.key:
                t = t.right
                stack.append([t, 'r'])
            elif value < t.key:
                t = t.left
                stack.append([t, 'l'])
            else:
                break  # found the key
        return stack

    def compute_height(self, t):
        """Recompute ``t.height`` from its children and report imbalance.

        The children's cached heights are used as-is, so they must already
        be up to date. Returns True if the two child subtrees differ in
        height by more than one, False otherwise.
        """
        left_height = self._height(t.left)
        right_height = self._height(t.right)
        t.height = max(left_height, right_height) + 1
        return abs(left_height - right_height) > 1

    def rotation_tree(self, a, z, y, x):
        """Rebalance the subtree rooted at the unbalanced node *z*.

        Args:
            a: parent of *z*, or None when *z* is the root.
            z: the unbalanced node.
            y: the child of *z* on the side chosen by the caller.
            x: the child of *y* chosen by the caller.

        The rotation is selected from the shape of the z -> y -> x chain:
        single left, single right, right-left or left-right. Heights of the
        moved nodes are recomputed bottom-up. The new subtree root is then
        linked into *a* (or becomes ``self.root``) and *a*'s height is
        recomputed.

        Raises:
            ValueError: if *y* is not a child of *z* or *x* is not a child
                of *y* on a matching side.
        """
        if z.right is y and y.right is x:    # single (left) rotation
            z.right = y.left
            y.left = z
            new_top = y
            # update heights starting from child nodes
            self.compute_height(z)
            self.compute_height(x)
            self.compute_height(y)
        elif z.left is y and y.left is x:    # single (right) rotation
            z.left = y.right
            y.right = z
            new_top = y
            self.compute_height(z)
            self.compute_height(x)
            self.compute_height(y)
        elif z.right is y and y.left is x:   # double (right-left) rotation
            y.left = x.right
            z.right = x.left
            x.left = z
            x.right = y
            new_top = x
            self.compute_height(z)
            self.compute_height(y)
            self.compute_height(x)
        elif z.left is y and y.right is x:   # double (left-right) rotation
            y.right = x.left
            z.left = x.right
            x.left = y
            x.right = z
            new_top = x
            self.compute_height(z)
            self.compute_height(y)
            self.compute_height(x)
        else:
            raise ValueError("z -> y -> x is not a parent/child chain")

        # Hang the rotated subtree back where z used to be. new_top's height
        # was just recomputed above, so only a needs an update.
        if a is None:          # z was the root
            self.root = new_top
        elif a.right is z:
            a.right = new_top
            self.compute_height(a)
        elif a.left is z:
            a.left = new_top
            self.compute_height(a)

    def backtrack_height_from_add(self, path):
        """Update heights along *path* after an insertion and rebalance.

        *path* is the search path from ``search_node`` whose last entry has
        been set to the newly inserted node. Heights are recomputed from the
        new node upward. At the first unbalanced node z, the next two nodes
        of the path serve as y and x for a rotation. The loop stops there:
        after an insertion one rotation is enough (standard AVL property),
        and it leaves the ancestors' cached heights valid.
        """
        for i in range(len(path) - 1, -1, -1):
            z = path[i][0]
            if self.compute_height(z):
                a = path[i - 1][0] if i > 0 else None
                self.rotation_tree(a, z, path[i + 1][0], path[i + 2][0])
                break

    def backtrack_height_from_remove(self, path):
        """Update heights along *path* after a removal and rebalance.

        *path* runs from the root to the parent of the node that was
        physically unlinked (it may be empty). Every unbalanced node z met on
        the way up is rotated:

        - y is the taller child of z.
        - x is the taller child of y. On a tie, x is taken on the same side
          as y so that a single rotation is used.

        Unlike insertion, a removal may need rotations at several levels, so
        the whole path is processed.
        """
        for i in range(len(path) - 1, -1, -1):
            z = path[i][0]
            if not self.compute_height(z):
                continue
            a = path[i - 1][0] if i > 0 else None
            # z is unbalanced, so its children's heights differ by >= 2.
            if self._height(z.right) > self._height(z.left):
                y = z.right
            else:
                y = z.left
            right_h, left_h = self._height(y.right), self._height(y.left)
            if right_h > left_h:
                x = y.right
            elif left_h > right_h:
                x = y.left
            else:
                # equal heights: stay on y's side -> single rotation
                x = y.right if z.right is y else y.left
            self.rotation_tree(a, z, y, x)

    def add_node(self, value):
        """Insert *value* into the tree and rebalance it.

        Returns the newly created node, or None if *value* is already
        present (the tree is left unchanged in that case).
        """
        node = Node(1, None, None, value)
        if self.root is None:
            self.root = node
            return node

        path = self.search_node(value)
        if path[-1][0] is not None:
            return None  # key already exists

        parent = path[-2][0]
        if path[-1][1] == 'r':
            parent.right = node
        else:
            parent.left = node
        path[-1][0] = node  # let the backtrack start at the new node

        self.backtrack_height_from_add(path)
        return node

    def remove_node(self, value):
        """Remove *value* from the tree and rebalance it.

        Returns None if *value* is not in the tree. Otherwise returns the
        node that held *value*:

        - With at most one child, that node is unlinked and replaced by its
          child (or None).
        - With two children, it stays in place and takes its in-order
          successor's key; the successor node is unlinked instead.
        """
        path = self.search_node(value)
        node = path[-1][0]
        if node is None:
            return None  # key does not exist

        if node.left is None or node.right is None:
            # Zero or one child: splice that child (possibly None) into
            # node's slot. Handles the root uniformly, so a root with only
            # a right child no longer falls through to path[-2].
            child = node.left if node.left is not None else node.right
            if node is self.root:
                self.root = child
            elif path[-1][1] == 'r':
                path[-2][0].right = child
            else:
                path[-2][0].left = child
            path.pop()  # backtrack starts at the removed node's parent
            self.backtrack_height_from_remove(path)
        else:
            # Two children: walk to the in-order successor (leftmost node of
            # the right subtree), recording the path to it.
            succ_path = [[node.right, 'r']]
            while succ_path[-1][0].left is not None:
                succ_path.append([succ_path[-1][0].left, 'l'])
            successor = succ_path.pop()[0]

            node.key = successor.key
            # The successor has no left child; replace it by its right one.
            if succ_path:
                succ_path[-1][0].left = successor.right
            else:  # the successor is node.right itself
                node.right = successor.right

            self.backtrack_height_from_remove(path + succ_path)

        return node


In [None]:
def _build_tree_string(root, curr_index, index=False, delimiter='-'):
    if root is None:
        return [], 0, 0, 0

    line1 = []
    line2 = []
    if index:
        node_repr = '{}{}{}'.format(curr_index, delimiter, root.key)
    else:
        try:
            node_repr = str(root.key) + "(" + str(root.height) + ")"
        except AttributeError:
            node_repr = str(root.key)

    new_root_width = gap_size = len(node_repr)

    if root.left == root or root.right == root: 
        print("ouch")
        input("frefref")
    # Get the left and right sub-boxes, their widths, and root repr positions
    l_box, l_box_width, l_root_start, l_root_end = \
        _build_tree_string(root.left, 2 * curr_index + 1, index, delimiter)
    r_box, r_box_width, r_root_start, r_root_end = \
        _build_tree_string(root.right, 2 * curr_index + 2, index, delimiter)

    # Draw the branch connecting the current root node to the left sub-box
    # Pad the line with whitespaces where necessary
    if l_box_width > 0:
        l_root = (l_root_start + l_root_end) // 2 + 1
        line1.append(' ' * (l_root + 1))
        line1.append('_' * (l_box_width - l_root))
        line2.append(' ' * l_root + '/')
        line2.append(' ' * (l_box_width - l_root))
        new_root_start = l_box_width + 1
        gap_size += 1
    else:
        new_root_start = 0

    # Draw the representation of the current root node
    line1.append(node_repr)
    line2.append(' ' * new_root_width)

    # Draw the branch connecting the current root node to the right sub-box
    # Pad the line with whitespaces where necessary
    if r_box_width > 0:
        r_root = (r_root_start + r_root_end) // 2
        line1.append('_' * r_root)
        line1.append(' ' * (r_box_width - r_root + 1))
        line2.append(' ' * r_root + '\\')
        line2.append(' ' * (r_box_width - r_root))
        gap_size += 1
    new_root_end = new_root_start + new_root_width - 1

    # Combine the left and right sub-boxes with the branches drawn above
    gap = ' ' * gap_size
    new_box = [''.join(line1), ''.join(line2)]
    for i in range(max(len(l_box), len(r_box))):
        l_line = l_box[i] if i < len(l_box) else ' ' * l_box_width
        r_line = r_box[i] if i < len(r_box) else ' ' * r_box_width
        new_box.append(l_line + gap + r_line)

    # Return the new box, its width and its root repr positions
    return new_box, len(new_box[0]), new_root_start, new_root_end

def print_tree(tree):
    """Print an ASCII drawing of *tree* (each node shown as key(height)).

    A leading blank line separates the drawing from any preceding text, and
    trailing padding is stripped from every row.
    """
    box = _build_tree_string(tree.root, 0, False, "-")[0]
    rows = [row.rstrip() for row in box]
    print("\n" + "\n".join(rows))

In [None]:
import random

# 1st step
print("\n\n ******************* 1st STEP: search node *******************")
# code that creates the same tree as in the lecture notes on AVL trees
N54 = Node(1, None, None, 54)
N39 = Node(1, None, None, 39)
N24 = Node(1, None, None, 24)
N71 = Node(1, None, None, 71)
N45 = Node(2, N39, N54, 45)
N6 = Node(2, None, N24, 6)
N67 = Node(3, N45, N71, 67)
N33 = Node(4, N6, N67, 33)
my_AVL_tree = AVL_tree(N33)
print_tree(my_AVL_tree)

# code that searches node 54 (in order to test the search_node function)
print("searching " + str(54) + " ... ", end="")
stack = my_AVL_tree.search_node(54)
print(stack)
if stack is None:
    print("ERROR: None returned by search_node")
elif not stack:
    print("ERROR: empty list returned by search_node")
elif stack[-1][0] is not None:
    print("found this: ", end="")
    print(stack[-1][0].key)
    print("Here is the path to find it: ")
    for node, direction in stack:
        print("[" + str(node.key) + " , " + direction + "]")
    if stack[0][0].key != 33 or stack[0][1] != 'o': print("ERROR: wrong output of search_node function")
    if stack[1][0].key != 67 or stack[1][1] != 'r': print("ERROR: wrong output of search_node function")
    if stack[2][0].key != 45 or stack[2][1] != 'l': print("ERROR: wrong output of search_node function")
    if stack[3][0].key != 54 or stack[3][1] != 'r': print("ERROR: wrong output of search_node function")
else:
    print("ERROR: found nothing :-/")
print("")

# code that searches node 55 (in order to test the search_node function)
print("\nsearching " + str(55) + " ... ", end="")
stack = my_AVL_tree.search_node(55)
print(stack)


# 2nd step
print("\n\n ******************* 2nd STEP: add node *******************")
# code that adds node 62 (in order to test the add_node function)
print("adding " + str(62))
t = my_AVL_tree.add_node(62)
if t is None: print("ERROR: no node is returned from add_node function")
print_tree(my_AVL_tree)
print("")


# 3rd step
print("\n\n ******************* 3rd STEP: remove node *******************")
# code that removes node 24 (in order to test the remove_node function)
print("removing " + str(24))
t = my_AVL_tree.remove_node(24)
if t is None: print("ERROR: no node is returned from remove_node function")
print_tree(my_AVL_tree)
print("")


# 4th step
print("\n\n ******************* 4th STEP: general test *******************")
# code that randomly adds and removes nodes in the AVL tree (to test that all is working fine)
my_AVL_tree = AVL_tree(None)
L = []  # keys currently stored in the tree
for j in range(3):
    for i in range(10):
        v = random.randint(0, 99)
        print("adding " + str(v), end="")
        if my_AVL_tree.add_node(v) is not None:
            L.append(v)
        print_tree(my_AVL_tree)

    for i in range(2):
        # Only an empty key list means there is nothing to remove
        # (the old `len(L) > 1` guard refused to remove the last key).
        if not L:
            print("Empty tree, can't remove a node !")
            break
        v = random.randrange(len(L))
        print("removing " + str(L[v]), end="")
        if my_AVL_tree.remove_node(L[v]) is not None:
            L.pop(v)  # keys are unique, so this equals L.remove(L[v])
        print_tree(my_AVL_tree)



 ******************* 1st STEP: search node *******************

   ________33(4)____________________
  /                                 \
6(2)__                     ________67(3)__
      \                   /               \
     24(1)           __45(2)__           71(1)
                    /         \
                 39(1)       54(1)

searching 54 ... [[<__main__.Node object at 0x7f48029a4c90>, 'o'], [<__main__.Node object at 0x7f48029a4c50>, 'r'], [<__main__.Node object at 0x7f48029a4b90>, 'l'], [<__main__.Node object at 0x7f48029a4910>, 'r']]
found this: 54
Here is the path to find it: 
[33 , o]
[67 , r]
[45 , l]
[54 , r]


searching 55 ... [[<__main__.Node object at 0x7f48029a4c90>, 'o'], [<__main__.Node object at 0x7f48029a4c50>, 'r'], [<__main__.Node object at 0x7f48029a4b90>, 'l'], [<__main__.Node object at 0x7f48029a4910>, 'r'], [None, 'r']]


 ******************* 2nd STEP: add node *******************
adding 62

   ________33(4)______________
  /                          