# Self-Balancing BST

A self-balancing binary search tree stays balanced after every insertion or deletion. Several decades of research have gone into self-balancing search trees, and many approaches have been devised, e.g. B-trees, red-black trees and AVL (Adelson-Velsky and Landis) trees.


## Adelson-Velsky Landis (AVL) Tree

An AVL tree is a self-balancing binary search tree that readjusts itself after every insertion or deletion so that it stays balanced.

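It uses a `balance factor` to decide whether a node is balanced. The `balance factor` of a node is the height of its left subtree minus the height of its right subtree.

A balanced node has a `balance factor` of 0, 1 or -1. Any other value means that the node is unbalanced.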
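
As a minimal sketch of the definition above, the balance factor can be computed on plain `(left, key, right)` tuples, the same shape that `parse_tuple` accepts later in this notebook. The helper names `height` and `balance_factor` are chosen here for illustration only, and an empty tree is assumed to have height 0.

```python
def height(tree):
    """Height of a (left, key, right) tuple tree; an empty tree (None) has height 0."""
    if tree is None:
        return 0
    left, _key, right = tree
    return 1 + max(height(left), height(right))


def balance_factor(tree):
    """Height of the left subtree minus height of the right subtree."""
    left, _key, right = tree
    return height(left) - height(right)


balanced = ((None, 10, None), 20, (None, 30, None))
left_heavy = (((None, 5, None), 10, None), 20, None)

print(balance_factor(balanced))    # 0
print(balance_factor(left_heavy))  # 2
```

A result of 2 falls outside the allowed range of -1 to 1, so the root of `left_heavy` is unbalanced.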

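There are 4 possible cases of imbalance during insertion:

1. **Left-Left (LL) Imbalance**: the new node is inserted into the left subtree of the left child.

2. **Left-Right (LR) Imbalance**: the new node is inserted into the right subtree of the left child.

3. **Right-Right (RR) Imbalance**: the new node is inserted into the right subtree of the right child.

4. **Right-Left (RL) Imbalance**: the new node is inserted into the left subtree of the right child.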


![imbalances](../static/imbalance.png)


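Whenever there is an imbalance, an AVL rotation rearranges only three nodes at a time (the unbalanced node, its child and its grandchild on the heavier side), regardless of the size of the tree.

Imbalances are resolved from the bottom up, starting at the deepest unbalanced ancestor of the inserted node.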
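
To illustrate how a rotation touches only three nodes, here is a small sketch of the two single rotations on `(left, key, right)` tuples. The function names `rotate_right` and `rotate_left` are assumptions made for this example and are not part of the notebook's `BSTNode` code.

```python
def rotate_right(tree):
    """Single right rotation: the left child becomes the new subtree root (fixes LL)."""
    (ll, left_key, lr), key, right = tree
    return (ll, left_key, (lr, key, right))


def rotate_left(tree):
    """Single left rotation: the right child becomes the new subtree root (fixes RR)."""
    left, key, (rl, right_key, rr) = tree
    return ((left, key, rl), right_key, rr)


ll_case = (((None, 5, None), 10, None), 20, None)
rr_case = (None, 20, (None, 30, (None, 40, None)))

print(rotate_right(ll_case))  # ((None, 5, None), 10, (None, 20, None))
print(rotate_left(rr_case))   # ((None, 20, None), 30, (None, 40, None))
```

Only the three keys along the heavy path change position; any other subtrees are reattached unchanged.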


## Problem Statement

> Create a self-balancing binary search tree using the AVL method.

In [209]:
class BSTNode():
    def __init__(self, key, value=None):
        self.key = key
        self.value = value
        self.parent = None
        self.left = None
        self.right = None
        
    def __str__(self):
        return f"BSTNode(key = {str(self.key)}, value = {(str(self.value))})"
        
    def __eq__(self, tree: "BSTNode"):
        # Called as BSTNode.__eq__(a, b) during recursion, so either side may be None.
        if self is None and tree is None:
            return True
        if self is None or tree is None:
            return False
        return (self.key == tree.key
                and BSTNode.__eq__(self.left, tree.left)
                and BSTNode.__eq__(self.right, tree.right))
    
    @staticmethod    
    def parse_tuple(data):
        if data is None:
            node = None
        elif isinstance(data, tuple) and len(data) == 3:
            node = BSTNode(data[1])
            node.left = BSTNode.parse_tuple(data[0])
            node.right = BSTNode.parse_tuple(data[2])
        else:
            node = BSTNode(data)
        return node
    
    def height(self):
        if self is None:
            return 0
        return 1 + max(BSTNode.height(self.left), BSTNode.height(self.right))
    
    
    def calculate_balance_factor(self):
        # Balance factor = height of left subtree - height of right subtree.
        left_height = 0 if self.left is None else self.left.height()
        right_height = 0 if self.right is None else self.right.height()
        return left_height - right_height
    
    def display_keys(self, space='\t', level=0):
        # If the node is empty
        if self is None:
            print(space*level + '∅')
            return   

        # If the node is a leaf 
        if self.left is None and self.right is None:
            print(space*level + str(self.key))
            return

        # If the node has children
        BSTNode.display_keys(self.right, space, level+1)
        print(space*level + str(self.key))
        BSTNode.display_keys(self.left,space, level+1) 

## Express the problem in plain English

We need to create a binary search tree that rebalances itself whenever a new node is inserted, whatever the order of insertion.

## Come up with test cases and cover exceptions as much as possible

1. Insert to left of tree and trigger LL imbalance
2. Insert to left of tree and trigger LR imbalance
3. Insert to right of tree and trigger RR imbalance
4. Insert to right of tree and trigger RL imbalance
5. Multiple imbalance
6. Imbalanced tree of height > 4
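
When writing test trees by hand, it helps to check that each input is a valid BST and that each expected output is balanced. The sketch below does this on `(left, key, right)` tuples. The helper names `is_bst` and `is_avl_balanced` are hypothetical, and duplicate keys are assumed to be disallowed.

```python
def is_bst(tree, low=None, high=None):
    """True if every key lies strictly between low and high (None means unbounded)."""
    if tree is None:
        return True
    left, key, right = tree
    if low is not None and key <= low:
        return False
    if high is not None and key >= high:
        return False
    return is_bst(left, low, key) and is_bst(right, key, high)


def height(tree):
    """Height of a tuple tree; an empty tree (None) has height 0."""
    if tree is None:
        return 0
    left, _key, right = tree
    return 1 + max(height(left), height(right))


def is_avl_balanced(tree):
    """True if every node has a balance factor of -1, 0 or 1."""
    if tree is None:
        return True
    left, _key, right = tree
    return (abs(height(left) - height(right)) <= 1
            and is_avl_balanced(left)
            and is_avl_balanced(right))


print(is_bst(((None, 10, None), 20, None)))                      # True
print(is_bst(((None, 10, None), 5, None)))                       # False: 10 is left of 5
print(is_avl_balanced(((None, 5, None), 10, (None, 20, None))))  # True
print(is_avl_balanced((((None, 5, None), 10, None), 20, None)))  # False
```

Running these checks over every `input` and `output` tree would catch a malformed test case before it is blamed on the insertion code.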


In [218]:
tests = [
    ## LL
    {
        "input": {
            "root": BSTNode.parse_tuple(((None, 10, None), 20, None)),
            "node": BSTNode(5)
        },
        "output": BSTNode.parse_tuple(((None, 5, None), 10, (None, 20, None)))
    },
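    ## LR (this input is not a valid BST: 10 sits to the left of 5, and the 20 in
    ## the expected output is never inserted; a valid LR input would be
    ## ((None, 10, None), 20, None))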
    {
        "input": {
            "root": BSTNode.parse_tuple(((None, 10, None), 5, None)),
            "node": BSTNode(15)
        },
        "output": BSTNode.parse_tuple(((None, 10, None), 15, (None, 20, None)))
    },
    ## RR
    {
        "input": {
            "root": BSTNode.parse_tuple((None, 20, (None, 30, None))),
            "node": BSTNode(40)
        },
        "output": BSTNode.parse_tuple(((None, 20, None), 30, (None, 40, None)))
    },
    ## RL
    {
        "input": {
            "root": BSTNode.parse_tuple((None, 20, (None, 30, None))),
            "node": BSTNode(25)
        },
        "output": BSTNode.parse_tuple(((None, 20, None), 25, (None, 30, None)))
    },
    ## Multiple
    {
        "input": {
            "root": BSTNode.parse_tuple(((((None, 4, None), 5, (None, 8, None)), 10, (None, 11, None)), 13, (None, 15, (None, 16, None)))),
            "node": BSTNode(3)
        },
        "output": BSTNode.parse_tuple(((((None, 3, None), 4, None), 5, ((None, 8, None), 10, (None, 11, None))), 13, (None, 15, (None, 16, None))))
    },
]

In [212]:
for test in tests:
    test["input"]["root"].display_keys()
    print("\n")
    test["output"].display_keys()
    print("\n"*2)

	∅
20
	10


	20
10
	5



	∅
5
	10


	20
15
	10



	30
20
	∅


	40
30
	20



	30
20
	∅


	30
25
	20



		16
	15
		∅
13
		11
	10
			8
		5
			4


		16
	15
		∅
13
			11
		10
			8
	5
			∅
		4
			3





In [213]:
def insert(root, node: BSTNode):
    pass

In [214]:
for i, test in enumerate(tests):
    print(f"test case no {i}")
    print(insert(**test["input"]) == test["output"])

test case no 0
False
test case no 1
False
test case no 2
False
test case no 3
False
test case no 4
False


## State the solution in plain english, implement it and evaluate against test cases

We use ordinary binary search tree insertion to find where the new node belongs, then recompute the balance factor of every node on the path back to the root, from the bottom up.

If the balance factor of the current node is > 1, we have a left imbalance:
- if the balance factor of the left child is -1, we have an LR imbalance
- otherwise we have an LL imbalance

If the balance factor of the current node is < -1, we have a right imbalance:
- if the balance factor of the right child is 1, we have an RL imbalance
- otherwise we have an RR imbalance
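
The decision rules above can be sketched as a small pure function that takes balance factors rather than nodes. The name `imbalance_type` and its parameters are hypothetical and exist only for this illustration.

```python
def imbalance_type(bf, left_bf=0, right_bf=0):
    """Classify an imbalance from the node's balance factor and its children's.

    bf       -- balance factor of the current node
    left_bf  -- balance factor of its left child (only used when bf > 1)
    right_bf -- balance factor of its right child (only used when bf < -1)
    """
    if bf > 1:
        # Left-heavy: a right-leaning left child means LR, otherwise LL.
        return "LR" if left_bf < 0 else "LL"
    if bf < -1:
        # Right-heavy: a left-leaning right child means RL, otherwise RR.
        return "RL" if right_bf > 0 else "RR"
    return None  # balanced


print(imbalance_type(2, left_bf=1))     # LL
print(imbalance_type(2, left_bf=-1))    # LR
print(imbalance_type(-2, right_bf=-1))  # RR
print(imbalance_type(-2, right_bf=1))   # RL
print(imbalance_type(1))                # None
```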

If we have an LL imbalance in the current node,
```
A = node
B = A.left

b_right = B.right

B.right = A
A.left = b_right

node = B
```

If we have an LR imbalance in the current node,
```
A = node
B = A.left
C = B.right

c_left = C.left
c_right = C.right

C.left = B
C.right = A
B.right = c_left
A.left = c_right

node = C
```

If we have an RR imbalance in the current node,
```
D = node
E = D.right

e_left = E.left

E.left = D
D.right = e_left

node = E
```

If we have an RL imbalance in the current node,
```
D = node
E = D.right
F = E.left

f_left = F.left
f_right = F.right

F.left = D
F.right = E
D.right = f_left
E.left = f_right

node = F
```

In [215]:
def insert(root: BSTNode, node: BSTNode):
    # An empty subtree is where the new node belongs.
    if root is None:
        return node

    if node.key > root.key:
        root.right = insert(root.right, node)
        root.right.parent = root
    elif node.key < root.key:
        root.left = insert(root.left, node)
        root.left.parent = root
    # A duplicate key falls through and leaves the tree unchanged.

    # Rebalance on the way back up, so imbalances are fixed bottom-up.
    return balance(root)

def balance(node: BSTNode):
    imbalance = get_imbalance_type(node)

    if imbalance:
        # Debug aid: show the unbalanced subtree before it is rotated.
        node.display_keys()
        if imbalance == "LR":
            node = handle_lr_imbalance(node)
        elif imbalance == "LL":
            node = handle_ll_imbalance(node)
        elif imbalance == "RL":
            node = handle_rl_imbalance(node)
        else:
            node = handle_rr_imbalance(node)
    return node

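def get_imbalance_type(node):
    # Returns "LL", "LR", "RR" or "RL" for an unbalanced node, else None.
    bf = node.calculate_balance_factor()

    if bf > 1:
        # Left-heavy: a right-leaning left child means LR, otherwise LL.
        bf_left = node.left.calculate_balance_factor()
        return "LR" if bf_left < 0 else "LL"
    elif bf < -1:
        # Right-heavy: a left-leaning right child means RL, otherwise RR.
        bf_right = node.right.calculate_balance_factor()
        return "RL" if bf_right > 0 else "RR"
    return None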

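def handle_lr_imbalance(node: BSTNode):
    # Double rotation: C, the left child's right child, becomes the subtree root.
    A = node
    B = node.left
    C = B.right
    c_left = C.left
    c_right = C.right

    C.left = B
    C.right = A
    B.right = c_left
    A.left = c_right

    # Keep parent pointers consistent with the new shape.
    C.parent = A.parent
    A.parent = C
    B.parent = C
    if c_left is not None:
        c_left.parent = B
    if c_right is not None:
        c_right.parent = A

    return C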

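def handle_ll_imbalance(node: BSTNode):
    # Single right rotation: the left child B becomes the subtree root.
    A = node
    B = node.left
    b_right = B.right

    B.right = A
    A.left = b_right

    # Keep parent pointers consistent with the new shape.
    B.parent = A.parent
    A.parent = B
    if b_right is not None:
        b_right.parent = A

    return B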

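def handle_rl_imbalance(node: BSTNode):
    # Double rotation: F, the right child's left child, becomes the subtree root.
    D = node
    E = node.right
    F = E.left
    f_left = F.left
    f_right = F.right

    F.left = D
    F.right = E
    D.right = f_left
    E.left = f_right

    # Keep parent pointers consistent with the new shape.
    F.parent = D.parent
    D.parent = F
    E.parent = F
    if f_left is not None:
        f_left.parent = D
    if f_right is not None:
        f_right.parent = E

    return F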

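def handle_rr_imbalance(node: BSTNode):
    # Single left rotation: the right child E becomes the subtree root.
    D = node
    E = node.right
    e_left = E.left

    E.left = D
    D.right = e_left

    # Keep parent pointers consistent with the new shape.
    E.parent = D.parent
    D.parent = E
    if e_left is not None:
        e_left.parent = D

    return E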




In [219]:
for i, test in enumerate(tests):
    title = f"TEST CASE #{i}"
    print("\n")
    print(title)
    print("="*len(title))
    print("\n")
    print("INPUT")
    test["input"]["root"].display_keys()
    print(f"insert {str(test['input']['node'])}")
    print("\n")
    print("OUTPUT")
    result = insert(**test["input"])
    result.display_keys()
    print("\n")
    print("TEST STATUS")
    print(result == test["output"])



TEST CASE #0


INPUT
	∅
20
	10
insert BSTNode(key = 5, value = None)


OUTPUT
	∅
20
		∅
	10
		5
	20
10
	5


TEST STATUS
True


TEST CASE #1


INPUT
	∅
5
	10
insert BSTNode(key = 15, value = None)


OUTPUT
	15
5
	10


TEST STATUS
False


TEST CASE #2


INPUT
	30
20
	∅
insert BSTNode(key = 40, value = None)


OUTPUT
		40
	30
		∅
20
	∅
	40
30
	20


TEST STATUS
True


TEST CASE #3


INPUT
	30
20
	∅
insert BSTNode(key = 25, value = None)


OUTPUT
		∅
	30
		25
20
	∅
	30
25
	20


TEST STATUS
True


TEST CASE #4


INPUT
		16
	15
		∅
13
		11
	10
			8
		5
			4
insert BSTNode(key = 3, value = None)


OUTPUT
	11
10
		8
	5
			∅
		4
			3
		16
	15
		∅
13
			11
		10
			8
	5
			∅
		4
			3


TEST STATUS
True
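
Test case #1 fails because of its data rather than the rotation logic. Its input places 10 to the left of 5, which is not a valid BST, so inserting 15 never produces an LR imbalance, and the expected output contains 20, which is never inserted. As a sketch of what a valid LR check could look like, the block below uses a compact tuple-based AVL insert. It is a stand-in written for this illustration, not the `BSTNode` implementation above, and every name in it is hypothetical. It also inserts 31 sorted keys to cover the "height > 4" case from the test list, which the test cells do not exercise.

```python
def height(tree):
    """Height of a (left, key, right) tuple tree; None counts as height 0."""
    if tree is None:
        return 0
    left, _key, right = tree
    return 1 + max(height(left), height(right))


def rotate_right(tree):
    (ll, left_key, lr), key, right = tree
    return (ll, left_key, (lr, key, right))


def rotate_left(tree):
    left, key, (rl, right_key, rr) = tree
    return ((left, key, rl), right_key, rr)


def avl_insert(tree, key):
    """Return a new balanced tree containing key (duplicates are ignored)."""
    if tree is None:
        return (None, key, None)
    left, root_key, right = tree
    if key < root_key:
        left = avl_insert(left, key)
    elif key > root_key:
        right = avl_insert(right, key)
    else:
        return tree

    bf = height(left) - height(right)
    if bf > 1:
        if height(left[2]) > height(left[0]):    # LR: straighten into LL first
            left = rotate_left(left)
        return rotate_right((left, root_key, right))
    if bf < -1:
        if height(right[0]) > height(right[2]):  # RL: straighten into RR first
            right = rotate_right(right)
        return rotate_left((left, root_key, right))
    return (left, root_key, right)


# Corrected LR case: 20 with left child 10, then insert 15.
lr_result = avl_insert(((None, 10, None), 20, None), 15)
print(lr_result == ((None, 10, None), 15, (None, 20, None)))  # True

# Taller case: 31 keys cannot fit in 4 levels, so the height must exceed 4.
big = None
for k in range(1, 32):
    big = avl_insert(big, k)
print(height(big))  # 5
```

Like `calculate_balance_factor` above, this sketch recomputes heights on every call, which keeps it short but makes each insertion slower than an implementation that caches the height in each node.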
