<a href="https://colab.research.google.com/github/newmantic/avl_tree/blob/main/avl_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
class Node:
    """A single AVL tree node.

    Holds a key, links to the left/right children (``None`` when absent) and
    the cached height of the subtree rooted here; a freshly created leaf has
    height 1.
    """

    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.height = 1


class AVLTree:
    """Operations on a self-balancing (AVL) binary search tree of ``Node``s.

    The tree object holds no state: every mutating method takes the current
    subtree root and returns the (possibly different) root, so callers must
    reassign, e.g. ``root = tree.insert(root, key)``.

    Ordering invariant: keys in a node's left subtree are <= its key and keys
    in its right subtree are >= its key. Duplicate keys are allowed; ``insert``
    sends an equal key to the right, and later rotations may move an equal key
    to the left side, which the invariant above permits.
    """

    def get_height(self, node):
        """Return the cached height of *node*, or 0 for an empty subtree."""
        if node is None:
            return 0
        return node.height

    def update_height(self, node):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))

    def get_balance(self, node):
        """Return height(left) - height(right) for *node* (0 for an empty subtree).

        Positive means left-heavy, negative means right-heavy; an AVL tree keeps
        this within [-1, 1] at every node.
        """
        if node is None:
            return 0
        return self.get_height(node.left) - self.get_height(node.right)

    def rotate_right(self, y):
        """Rotate the subtree rooted at *y* to the right and return the new root.

        *y* must have a left child ``x``; ``x`` becomes the root, *y* becomes
        ``x``'s right child, and ``x``'s former right subtree moves under *y*.
        """
        x = y.left
        T2 = x.right

        # Perform rotation
        x.right = y
        y.left = T2

        # y is now below x, so its height must be refreshed first.
        self.update_height(y)
        self.update_height(x)

        return x

    def rotate_left(self, x):
        """Rotate the subtree rooted at *x* to the left and return the new root.

        *x* must have a right child ``y``; ``y`` becomes the root, *x* becomes
        ``y``'s left child, and ``y``'s former left subtree moves under *x*.
        """
        y = x.right
        T2 = y.left

        # Perform rotation
        y.left = x
        x.right = T2

        # x is now below y, so its height must be refreshed first.
        self.update_height(x)
        self.update_height(y)

        return y

    def _rebalance(self, node):
        """Refresh ``node.height``, restore the AVL property at *node*, return the root.

        Expects both children to already be valid AVL subtrees, which holds
        after a single insert or delete below *node*.
        """
        self.update_height(node)
        balance = self.get_balance(node)

        if balance > 1:
            # Left-heavy. Choose LL vs LR from the child's own balance rather
            # than by comparing keys: key comparisons cannot tell the cases
            # apart when the inserted key equals the child's key (duplicates).
            if self.get_balance(node.left) < 0:
                # Left Right case: straighten into Left Left first.
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)

        if balance < -1:
            # Right-heavy; mirror image of the branch above.
            if self.get_balance(node.right) > 0:
                # Right Left case: straighten into Right Right first.
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)

        return node

    def insert(self, root, key):
        """Insert *key* into the subtree rooted at *root* and return the new root.

        Duplicate keys are kept; an equal key descends into the right subtree.
        The tree is rebalanced on the way back up the recursion.
        """
        if root is None:
            return Node(key)
        if key < root.key:
            root.left = self.insert(root.left, key)
        else:
            root.right = self.insert(root.right, key)
        return self._rebalance(root)

    def min_value_node(self, node):
        """Return the leftmost (minimum-key) node of the subtree, or None if empty."""
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def delete(self, root, key):
        """Remove one node holding *key* from the subtree and return the new root.

        If *key* is absent the keys are left unchanged (heights along the search
        path are simply recomputed). A node with two children takes its in-order
        successor's key, and that successor key is then deleted from the right
        subtree.
        """
        if root is None:
            return root
        if key < root.key:
            root.left = self.delete(root.left, key)
        elif key > root.key:
            root.right = self.delete(root.right, key)
        else:
            # Zero or one child: splice this node out directly.
            if root.left is None:
                return root.right
            if root.right is None:
                return root.left
            successor = self.min_value_node(root.right)
            root.key = successor.key
            root.right = self.delete(root.right, successor.key)
        return self._rebalance(root)

    def inorder_traversal(self, root):
        """Return the subtree's keys in in-order (ascending) order as a new list.

        Uses an explicit stack, so it runs in O(n) with no recursion and no
        repeated list concatenation.
        """
        keys = []
        stack = []
        node = root
        while stack or node is not None:
            # Walk down the left spine, remembering the path.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            keys.append(node.key)
            node = node.right
        return keys


In [2]:
# Example 1: Insertion
# Build a tree from keys inserted in an order that forces several rotations.
avl_tree = AVLTree()
root = None
keys = [10, 20, 30, 40, 50, 25]
for key in keys:
    root = avl_tree.insert(root, key)

# The in-order traversal of any BST lists its keys in sorted order; the
# asserts turn the expected results into checks instead of comments.
inorder = avl_tree.inorder_traversal(root)
print("Inorder traversal after insertion:", inorder)
assert inorder == [10, 20, 25, 30, 40, 50]

# Example 2: Deletion
root = avl_tree.delete(root, 10)
inorder = avl_tree.inorder_traversal(root)
print("Inorder traversal after deleting 10:", inorder)
assert inorder == [20, 25, 30, 40, 50]

root = avl_tree.delete(root, 30)
inorder = avl_tree.inorder_traversal(root)
print("Inorder traversal after deleting 30:", inorder)
assert inorder == [20, 25, 40, 50]

root = avl_tree.delete(root, 50)
inorder = avl_tree.inorder_traversal(root)
print("Inorder traversal after deleting 50:", inorder)
assert inorder == [20, 25, 40]

# Example 3: Edge Case - Single Node
# Deleting the only node must leave an empty tree (root becomes None).
single_node_tree = AVLTree()
single_root = single_node_tree.insert(None, 10)
single_root = single_node_tree.delete(single_root, 10)
inorder = single_node_tree.inorder_traversal(single_root)
print("Inorder traversal after deleting the only node:", inorder)
assert inorder == []

Inorder traversal after insertion: [10, 20, 25, 30, 40, 50]
Inorder traversal after deleting 10: [20, 25, 30, 40, 50]
Inorder traversal after deleting 30: [20, 25, 40, 50]
Inorder traversal after deleting 50: [20, 25, 40]
Inorder traversal after deleting the only node: []
