# In [139]:
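# Single right rotation (left-left case): B becomes the subtree root.
#    A  ==>    B
#  B         C    A
# C

# Moving B below D, the leftmost node of C's subtree; the in-order
# sequence (B, D, C, A) is unchanged.
#     A  ==>      A
#  B            C
#     C       D
#  D        B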


class Node:
    """A single binary-tree node: a stored ``value`` plus two child links."""

    def __init__(self, value):
        self.value = value
        # Both child links start out empty.
        self.left = self.right = None


class BST:
    def __init__(self):
        self.root = None

    def see(self, root_node=None):
        if not root_node:
            root_node = self.root
        if not root_node:
            raise ValueError('No root found!')
        queue = [root_node]
        while queue:
            popped = queue.pop(0)
            print(popped.value)
            if popped.left: queue.append(popped.left)
            if popped.right: queue.append(popped.right)

    def insert_travel(self, value, current):
        if not current:
            return Node(value)
        if value < current.value:
            current.left = self.insert_travel(value, current.left)
        if value > current.value:
            current.right = self.insert_travel(value, current.right)
        return current

    def height(self, root_node):
        if not root_node:
            return -1
        return 1 + max(self.height(root_node.left), self.height(root_node.right))

    def find_min(self, root_node=None):
        if not root_node:
            root_node = self.root
        while root_node.left:
            root_node = root_node.left
        return root_node

    def find_max(self, root_node=None):
        if not root_node:
            root_node = self.root
        while root_node.right:
            root_node = root_node.right
        return root_node

    def check_balance(self, root_node=None):
        if not root_node:
            root_node = self.root
        return self.balance_travel(root_node)

    def balance_travel(self, current):
        def inner(current, parent=None, side=None):
            if not current:
                return True, -1, None, None, None
            left_ok, left_height, node, node_parent, node_side = inner(current.left, current, 'left')
            if not left_ok:
                return False, None, node, node_parent, node_side
            right_ok, right_height, node, node_parent, node_side = inner(current.right, current, 'right')
            if not right_ok:
                return False, None, node, node_parent, node_side
            if not abs(left_height - right_height) <= 1:
                return False, None, current, parent, side
            return True, max(left_height, right_height) + 1, None, None, None
        is_balanced, _, current, node_parent, side = inner(current)
        return is_balanced, current, node_parent, side

    def insert(self, value):
        self.root = self.insert_travel(value, self.root)
        is_balanced, current, parent, side = self.check_balance(self.root)
        while not is_balanced:
            root_node_rotate = self.rotate(current)
            if parent:
                if side == 'left': parent.left = root_node_rotate
                if side == 'right': parent.right = root_node_rotate
            else:
                self.root = root_node_rotate
            is_balanced, current, parent, side = self.check_balance(self.root)

    def rotate(self, root_node):
        left_height = self.height(root_node.left)
        right_height = self.height(root_node.right)
        if left_height - right_height > 0:
            return self.heavy_left(root_node)
        if left_height - right_height < 0:
            return self.heavy_right(root_node)

    def heavy_left(self, root_node):
        count = 0
        while root_node.left.right:
            node_left, node_right = root_node.left, root_node.left.right
            root_node.left = node_right
            node_left.right = None
            node_min = self.find_min(node_right)
            node_min.left = node_left
            count += 1
        count = 1 if not count else count
        for _ in range(count):
            node_left = root_node.left
            root_node.left = None
            node_left.right = root_node
            root_node = node_left
        return root_node

    def heavy_right(self, root_node):
        count = 0
        while root_node.right.left:
            node_right, node_left = root_node.right, root_node.right.left
            root_node.right = node_left
            node_right.left = None
            node_max = self.find_max(node_left)
            node_max.right = node_right
            count += 1
        count = 1 if not count else count
        for _ in range(count):
            node_right = root_node.right
            root_node.right = None
            node_right.left = root_node
            root_node = node_right
        return root_node


bst = BST()

# Insertion orders used to exercise the rebalancing logic. Only the last
# sequence is inserted; the others are kept as alternative inputs.
sample_sequences = [
    [30, 20, 10],
    [20, 30, 10, 9, 8],
    [16, 10, 20, 8, 12, 17, 25, 6, 5],
    [15, 10, 20, 8, 12, 17, 25, 27, 29, 31, 32, 55, 77],
    [16, 10, 20, 8, 12, 17, 25, 6, 5, 9, 11, 13, 30, 15, 14],
    [16, 10, 20, 8, 12, 17, 25, 6, 5, 9],
]
numbers = sample_sequences[-1]
for number in numbers:
    bst.insert(number)

print('---')
bst.see()  # level-order dump of the tree

print('---')
print('Balanced?', bst.check_balance()[0])


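# Output:
# NOTE: the next nine lines look like leftover debug output; nothing in the
# code above prints them.
# 1 --> 10
# left
# 2 --> 8
# 1 --> 16
# left
# 2 --> 10
# 1 --> 8
# right
# 2 --> 9
# ---
# 16
# 8
# 20
# 6
# 10
# 17
# 25
# 5
# 9
# 12
# ---
# Balanced? True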
