In [16]:
# self balancing tree - AVL
from collections import deque

class TreeNode:
    """One node of the binary tree: a value, two child links and a height."""

    def __init__(self, data):
        # A newly created node is a leaf: it has no children and height 1.
        self.data = data
        self.left = self.right = None
        self.height = 1
       
class AVLTree:
    """Operations on an AVL (self-balancing binary search) tree.

    The class keeps no state of its own: a tree is identified by its root
    node, and an empty tree is ``None``.  Nodes are objects exposing the
    attributes ``data``, ``left``, ``right`` and ``height``.

    ``add_child`` and ``delete`` return the (possibly different) root of the
    modified subtree; callers must store that return value.
    """

    def _update_height(self, node):
        """Recompute ``node.height`` from the cached heights of its children."""
        node.height = 1 + max(self.get_height(node.left),
                              self.get_height(node.right))

    def add_child(self, root, data):
        """Insert ``data`` into the subtree at ``root`` and rebalance it.

        Returns the root of the resulting subtree.  A value that is already
        present leaves the tree unchanged.
        """
        if root is None:
            return TreeNode(data)

        if data < root.data:
            root.left = self.add_child(root.left, data)
        elif data > root.data:
            root.right = self.add_child(root.right, data)
        else:
            # Duplicate value: nothing was inserted, so nothing to rebalance.
            return root

        self._update_height(root)
        balance = self.get_balance(root)

        if balance > 1:
            # Left-heavy.  If the new value went into the left child's right
            # subtree (left-right case), straighten it first.
            if data > root.left.data:
                root.left = self.left_rotate(root.left)
            return self.right_rotate(root)

        if balance < -1:
            # Right-heavy.  Mirror image of the case above.
            if data < root.right.data:
                root.right = self.right_rotate(root.right)
            return self.left_rotate(root)

        return root

    def get_balance(self, root):
        """Return height(left) - height(right) for ``root`` (0 for ``None``)."""
        if root is None:
            return 0
        return self.get_height(root.left) - self.get_height(root.right)

    def left_rotate(self, z):
        """Rotate the subtree at ``z`` to the left and return its new root.

        ``z`` must have a right child; that child becomes the new root.
        """
        y = z.right
        T2 = y.left

        # rotation
        y.left = z
        z.right = T2

        # z is now a child of y, so z must be refreshed first; otherwise y
        # would be computed from z's stale (pre-rotation) height.
        self._update_height(z)
        self._update_height(y)

        return y

    def right_rotate(self, z):
        """Rotate the subtree at ``z`` to the right and return its new root.

        ``z`` must have a left child; that child becomes the new root.
        """
        y = z.left
        T3 = y.right

        # rotation
        y.right = z
        z.left = T3

        # Same ordering constraint as in left_rotate: child before parent.
        self._update_height(z)
        self._update_height(y)

        return y

    def search(self, root, data):
        """Return True if ``data`` is stored in the subtree at ``root``."""
        node = root
        # Walk a single root-to-leaf path using the search-tree ordering.
        while node is not None:
            if data == node.data:
                return True
            node = node.left if data < node.data else node.right
        return False

    def inorder_traversal(self, root):
        """Return the values in left, node, right order ([] for an empty tree)."""
        elements = []

        def visit(node):
            if node is None:
                return
            visit(node.left)
            elements.append(node.data)
            visit(node.right)

        visit(root)
        return elements

    def preorder_traversal(self, root):
        """Return the values in node, left, right order ([] for an empty tree)."""
        elements = []

        def visit(node):
            if node is None:
                return
            elements.append(node.data)
            visit(node.left)
            visit(node.right)

        visit(root)
        return elements

    def postorder_traversal(self, root):
        """Return the values in left, right, node order ([] for an empty tree)."""
        elements = []

        def visit(node):
            if node is None:
                return
            visit(node.left)
            visit(node.right)
            elements.append(node.data)

        visit(root)
        return elements

    def calculate_sum(self, root):
        """Return the sum of all values in the subtree (0 for an empty tree)."""
        if root is None:
            return 0
        left = self.calculate_sum(root.left)
        right = self.calculate_sum(root.right)
        return left + right + root.data

    def get_min(self, root):
        """Return the leftmost value of the subtree, or None if it is empty."""
        if not root:
            return None
        node = root
        while node.left:
            node = node.left
        return node.data

    def get_max(self, root):
        """Return the rightmost value of the subtree, or None if it is empty."""
        if not root:
            return None
        node = root
        while node.right:
            node = node.right
        return node.data

    def get_height(self, root):
        """Return the height cached on ``root`` (0 for ``None``)."""
        if not root:
            return 0
        return root.height

    def _depth_and_diameter(self, node):
        """Return ``(depth, diameter)`` of the subtree at ``node`` in one pass."""
        if node is None:
            return 0, 0
        ldepth, ldiameter = self._depth_and_diameter(node.left)
        rdepth, rdiameter = self._depth_and_diameter(node.right)
        depth = max(ldepth, rdepth) + 1
        # Longest path either passes through this node or lies in a child.
        diameter = max(ldepth + rdepth + 1, ldiameter, rdiameter)
        return depth, diameter

    def get_diameter(self, root):
        """Return the node count of the longest path in the subtree.

        A single node has diameter 1; an empty tree has diameter 0.
        """
        return self._depth_and_diameter(root)[1]

    def get_depth(self, root):
        """Return the depth in nodes, computed by walking the subtree.

        Unlike ``get_height`` this ignores the cached ``height`` attribute.
        An empty tree has depth 0.
        """
        if root is None:
            return 0
        return max(self.get_depth(root.left), self.get_depth(root.right)) + 1

    def get_count_node(self, root):
        """Return the number of nodes in the subtree (0 for an empty tree)."""
        if root is None:
            return 0
        return (self.get_count_node(root.left)
                + self.get_count_node(root.right) + 1)

    def bfs(self, root, data):
        """Return True if a breadth-first scan of the subtree finds ``data``."""
        if not root:
            return False

        # deque gives O(1) removal from the front, unlike list.pop(0).
        queue = deque([root])
        while queue:
            node = queue.popleft()

            if node.data == data:
                return True

            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)

        return False

    def delete(self, root, data):
        """Remove ``data`` from the subtree at ``root`` and rebalance it.

        Returns the root of the resulting subtree (``None`` if it became
        empty).  If ``data`` is not present the tree keeps the same values.
        """
        if not root:
            return None

        if data < root.data:
            root.left = self.delete(root.left, data)
        elif data > root.data:
            root.right = self.delete(root.right, data)
        else:
            # Node found.  With at most one child, splice that child in.
            if root.left is None:
                return root.right
            if root.right is None:
                return root.left

            # Two children: take over the in-order successor's value, then
            # remove that successor from the right subtree.
            min_val = self.get_min(root.right)
            root.data = min_val
            root.right = self.delete(root.right, min_val)

        self._update_height(root)
        balance = self.get_balance(root)

        if balance > 1:
            # Left-heavy; a right-leaning left child needs a double rotation.
            if self.get_balance(root.left) < 0:
                root.left = self.left_rotate(root.left)
            return self.right_rotate(root)

        if balance < -1:
            # Right-heavy; a left-leaning right child needs a double rotation.
            if self.get_balance(root.right) > 0:
                root.right = self.right_rotate(root.right)
            return self.left_rotate(root)

        return root
    
def build_tree(arr, tree, root):
    """Insert every element of ``arr`` into the tree rooted at ``root``.

    ``tree`` supplies the insertion routine via ``tree.add_child(root, value)``,
    whose return value becomes the root for the next insertion.

    Returns the root after all insertions.  When ``arr`` is empty a notice is
    printed and ``root`` is returned unchanged, so an existing tree is not
    lost by a caller that rebinds its root to the result.
    """
    if not arr:
        print("No elements to build the tree!")
        return root

    for value in arr:
        root = tree.add_child(root, value)
    return root
        
        
# Demo: build a tree from sample values and exercise each operation.
arr = [17, 4, 1, 20, 9, 23, 18, 34, -15]
root = None
tree = AVLTree()
root = build_tree(arr, tree, root)
print(tree.inorder_traversal(root))
print(tree.preorder_traversal(root))
print(tree.postorder_traversal(root))
print(tree.search(root, 12))
print(tree.search(root, 1))
print(tree.calculate_sum(root))
print(tree.get_height(root))
print(tree.get_depth(root))
print(tree.get_diameter(root))
print(tree.get_count_node(root))
print(tree.inorder_traversal(root))
# delete() returns the subtree's new root, which can differ from the old one
# after rebalancing (or be None), so the result must be kept.
root = tree.delete(root, 17)
print(tree.inorder_traversal(root))
print(tree.bfs(root, 18))
print(tree.bfs(root, 21))

[-15, 1, 4, 9, 17, 18, 20, 23, 34]
[17, 4, 1, -15, 9, 20, 18, 23, 34]
[-15, 1, 9, 4, 18, 34, 23, 20, 17]
False
True
111
4
4
7
9
[-15, 1, 4, 9, 17, 18, 20, 23, 34]
[-15, 1, 4, 9, 18, 20, 23, 34]
True
False
