### Balance a Binary Search Tree

In [48]:
class TreeNode:
    """A binary-tree node holding ``val`` with optional ``left``/``right`` children.

    Most methods check ``self is None`` and recurse via ``TreeNode.method(child)``,
    so they can also be called unbound on ``None`` (e.g. ``TreeNode.size(None) == 0``)
    and treat a missing child as an empty tree.
    """

    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

    @staticmethod
    def parse_tuple(tup):
        """Build a tree from nested ``(left, val, right)`` tuples.

        ``None`` becomes an empty tree, a 3-tuple becomes a node whose children
        are parsed from its first and last items, and any other value becomes a
        leaf holding that value.
        """
        if tup is None:
            return None
        if isinstance(tup, tuple) and len(tup) == 3:
            node = TreeNode(tup[1])
            node.left = TreeNode.parse_tuple(tup[0])
            node.right = TreeNode.parse_tuple(tup[2])
            return node
        return TreeNode(tup)

    def to_tuple(self):
        """Inverse of ``parse_tuple``: leaves become bare values, others 3-tuples."""
        if self is None:
            return None
        elif self.left is None and self.right is None:
            return self.val
        else:
            return TreeNode.to_tuple(self.left), self.val, TreeNode.to_tuple(self.right)

    def __str__(self):
        return "BinaryTree <{}>".format(self.to_tuple())

    def __repr__(self):
        return "BinaryTree <{}>".format(self.to_tuple())

    def display(self, space='\t', level=0):
        """Print the tree sideways: right subtree above, left subtree below.

        Each level is indented by ``space``; a missing child of a non-leaf node
        is printed as ``∅``.
        """
        if self is None:
            print((space*level)+'∅')
            return
        if self.left is None and self.right is None:
            print((space*level)+str(self.val))
            return

        TreeNode.display(self.right, space, level+1)
        print((space*level)+str(self.val))
        TreeNode.display(self.left, space, level+1)

    def height(self):
        """Return the number of nodes on the longest root-to-leaf path (0 if empty)."""
        return 0 if not self else 1+max(TreeNode.height(self.left), TreeNode.height(self.right))

    def size(self):
        """Return the total number of nodes in the tree (0 if empty)."""
        return 0 if not self else 1+TreeNode.size(self.left) + TreeNode.size(self.right)

    def traverse_in_order(self):
        """Return values in in-order: left subtree, node, right subtree."""
        if self is None:
            return []
        return (TreeNode.traverse_in_order(self.left)
                + [self.val]
                + TreeNode.traverse_in_order(self.right))

    def traverse_pre_order(self):
        """Return values in pre-order: node, left subtree, right subtree."""
        if self is None:
            return []
        # Subtrees must be walked in pre-order too (previously used in-order).
        return ([self.val]
                + TreeNode.traverse_pre_order(self.left)
                + TreeNode.traverse_pre_order(self.right))

    def traverse_post_order(self):
        """Return values in post-order: left subtree, right subtree, node."""
        if self is None:
            return []
        # Subtrees must be walked in post-order too (previously used in-order).
        return (TreeNode.traverse_post_order(self.left)
                + TreeNode.traverse_post_order(self.right)
                + [self.val])
    


#### First we check whether the tree is height-balanced, i.e. whether at every node the heights of the left and right subtrees differ by at most one. If it is not, the nodes have to be repositioned. One way to do that is to collect the values in ascending order with an in-order traversal, make the middle element the root, and then recursively build the left and right subtrees from the lower and upper halves of the list.

In [104]:
def balancedBST(node):
    """Return a height-balanced BST containing the same values as ``node``.

    If the tree is already height-balanced (at every node the subtree heights
    differ by at most one), the original root is returned unchanged. Otherwise
    the values are collected with an in-order traversal and a new tree of fresh
    ``TreeNode`` objects is built by repeatedly taking the middle value as root.

    The balance check and the in-order traversal use explicit stacks instead of
    recursion. An unbalanced input can be a list-like chain thousands of nodes
    deep, which would exceed Python's default recursion limit. The rebuild stays
    recursive because its depth is only about log2(n).

    Args:
        node: root of a binary search tree, or ``None``.

    Returns:
        The balanced root, or ``None`` when ``node`` is ``None``.
    """

    def isBalanced(root):
        # Iterative post-order walk; returns (balanced, height).
        if root is None:
            return True, 0
        balanced = True
        heights = {}  # id(node) -> height of the subtree rooted at node
        stack = [(root, False)]
        while stack:
            current, children_done = stack.pop()
            if not children_done:
                # Revisit this node once both children have been measured.
                stack.append((current, True))
                if current.right is not None:
                    stack.append((current.right, False))
                if current.left is not None:
                    stack.append((current.left, False))
                continue
            left_height = heights.pop(id(current.left)) if current.left is not None else 0
            right_height = heights.pop(id(current.right)) if current.right is not None else 0
            if abs(left_height - right_height) > 1:
                balanced = False
            heights[id(current)] = 1 + max(left_height, right_height)
        return balanced, heights[id(root)]

    def inorder_traversal(node):
        # Iterative in-order walk appending into one list (no per-level copies).
        values = []
        stack = []
        current = node
        while stack or current is not None:
            while current is not None:
                stack.append(current)
                current = current.left
            current = stack.pop()
            values.append(current.val)
            current = current.right
        return values

    def make_balance_bst(data, lo=0, hi=None):
        # Middle element of data[lo..hi] becomes the root; recursion depth ~log2(n).
        if hi is None:
            hi = len(data)-1
        if lo > hi:
            return None
        mid = (lo + hi) // 2
        root = TreeNode(data[mid])
        root.left = make_balance_bst(data, lo, mid-1)
        root.right = make_balance_bst(data, mid+1, hi)
        return root

    if node is None:
        return None
    if not isBalanced(node)[0]:
        # Unbalanced: rebuild from the sorted (in-order) values.
        inordered = inorder_traversal(node)
        return make_balance_bst(inordered)
    return node

In [105]:
# Parse a right-leaning chain 1 -> 3 -> 5, show it, then show its balanced form.
data = TreeNode.parse_tuple((None, 1, (None, 3, (None, 5, None))))
data.display()
print('\n Balanced BST \n')
balancedBST(data).display()

		5
	3
		∅
1
	∅

 Balanced BST 

	5
3
	1


In [106]:
# Parse a nine-node tree, display it, then display the balancedBST result.
data = TreeNode.parse_tuple(
    (((None, 0, None), 1, (None, 2, 3)), 4, (5, 6, (None, 7, 8)))
)
data.display()

print('\n Balanced BST \n')
result = balancedBST(data)
result.display()

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0

 Balanced BST 

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0


In [None]:
# Example input copied from a problem statement. It looks like a LeetCode-style
# level-order serialization where None marks a missing child. `null` is not a
# Python name, so it is spelled None here. TreeNode.parse_tuple does not read
# this format.
leetcode_example = [1, None, 15, 14, 17, 7, None, None, None, 2, 12, None, 3, 9, None, None, None, None, 11]

In [12]:
# Check whether a binary tree is height-balanced.

def isBalanced(root):
    """Report whether a binary tree is height-balanced, and its height.

    A tree is height-balanced when, at every node, the heights of the left and
    right subtrees differ by at most one. The walk is an explicit-stack
    post-order traversal rather than recursion, so very deep (degenerate,
    list-like) trees do not exceed Python's recursion limit.

    Args:
        root: root node (anything with ``left``/``right`` attributes) or ``None``.

    Returns:
        ``(balanced, height)`` where ``height`` is the number of nodes on the
        longest root-to-leaf path (0 for an empty tree).
    """
    if root is None:
        return True, 0
    balanced = True
    heights = {}  # id(node) -> height of the subtree rooted at node
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children have been measured.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        # Children are finished; their entries are no longer needed.
        left_height = heights.pop(id(node.left)) if node.left is not None else 0
        right_height = heights.pop(id(node.right)) if node.right is not None else 0
        if abs(left_height - right_height) > 1:
            balanced = False
        heights[id(node)] = 1 + max(left_height, right_height)
    return balanced, heights[id(root)]

In [13]:
# Check the current `data` tree; isBalanced returns (is_balanced, height).
isBalanced(data)

(True, 4)

In [96]:
def make_balance(data, lo=0, hi=None):
    """Build a height-balanced BST from the sorted slice ``data[lo..hi]``.

    ``hi`` defaults to the last index of ``data``. The middle element of each
    range becomes the subtree root, with the lower and upper halves forming
    its left and right subtrees. Returns ``None`` for an empty range.
    """
    def build(first, last):
        # Empty range -> no subtree.
        if first > last:
            return None
        middle = (first + last) // 2
        node = TreeNode(data[middle])
        node.left = build(first, middle - 1)
        node.right = build(middle + 1, last)
        return node

    return build(lo, len(data) - 1 if hi is None else hi)

In [97]:
def make_balanced_bst(data, lo=0, hi=None):
    """Build a height-balanced BST from the sorted slice ``data[lo..hi]``.

    ``hi`` defaults to the last index of ``data``. Each range's middle element
    becomes the subtree root. The construction uses an explicit work list of
    (parent, start, end, is_left) entries instead of recursion. Returns ``None``
    for an empty range.
    """
    last = len(data) - 1 if hi is None else hi
    if lo > last:
        return None

    middle = (lo + last) // 2
    root = TreeNode(data[middle])
    pending = [(root, lo, middle - 1, True), (root, middle + 1, last, False)]
    while pending:
        parent, start, end, is_left = pending.pop()
        if start > end:
            continue  # empty range: the child stays None from TreeNode.__init__
        centre = (start + end) // 2
        child = TreeNode(data[centre])
        if is_left:
            parent.left = child
        else:
            parent.right = child
        pending.append((child, start, centre - 1, True))
        pending.append((child, centre + 1, end, False))
    return root

In [98]:
# Re-parse the nine-node example tree and print it sideways.
data = TreeNode.parse_tuple(
    (((None, 0, None), 1, (None, 2, 3)), 4, (5, 6, (None, 7, 8)))
)
data.display()

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0


In [99]:
def inorder_traversal(node):
    """Return the values of a binary tree in in-order (left, node, right).

    For a binary search tree this yields the values in ascending order. An
    explicit stack replaces recursion, so degenerate (list-like) trees deeper
    than Python's recursion limit are handled. Values are appended to a single
    list instead of concatenating sub-lists at every level, which was quadratic
    on such trees.

    Args:
        node: root node (anything with ``val``/``left``/``right``) or ``None``.

    Returns:
        A new list of values; empty when ``node`` is ``None``.
    """
    values = []
    stack = []
    current = node
    while stack or current is not None:
        # Descend as far left as possible, remembering the path back up.
        while current is not None:
            stack.append(current)
            current = current.left
        current = stack.pop()
        values.append(current.val)
        current = current.right
    return values


In [100]:
# Flatten the tree into its in-order value list; note this rebinds `data`
# from a TreeNode to a plain list.
data = inorder_traversal(data)

In [101]:
# Echo the in-order value list produced above.
data

[0, 1, 2, 3, 4, 5, 6, 7, 8]

In [102]:
# Rebuild a balanced tree from the in-order list (middle element becomes the root).
make_balanced_bst(data)

BinaryTree <((0, 1, (None, 2, 3)), 4, (5, 6, (None, 7, 8)))>

In [103]:
# make_balance uses the same midpoint construction, so it yields the same tree.
make_balance(data)

BinaryTree <((0, 1, (None, 2, 3)), 4, (5, 6, (None, 7, 8)))>