### Balance a Binary Search Tree

In [1]:
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
    
    @staticmethod
    def parse_tuple(tup):
        if tup is None:
            node = None
        elif (isinstance(tup, tuple)) and (len(tup) == 3):
            node = TreeNode(tup[1])
            node.left = TreeNode.parse_tuple(tup[0])
            node.right = TreeNode.parse_tuple(tup[2])
        else:
            node = TreeNode(tup)
        return node
        
    def to_tuple(self):
        if self is None:
            return None
        elif self.left is None and self.right is None:
            return self.val
        else:
            return TreeNode.to_tuple(self.left), self.val, TreeNode.to_tuple(self.right)
        

    def __str__(self):
        return "BinaryTree <{}>".format(self.to_tuple())
    
    def __repr__(self):
        return "BinaryTree <{}>".format(self.to_tuple())
    
    def display(self, space='\t', level=0):
        if self is None:
            print((space*level)+'∅')
            return
        if self.left is None and self.right is None:
            print((space*level)+str(self.val))
            return
        
        TreeNode.display(self.right, space, level+1)
        print((space*level)+str(self.val))
        TreeNode.display(self.left, space, level+1)
        
        
    def height(self):
        return 0 if not self else 1+max(TreeNode.height(self.left), TreeNode.height(self.right))
    
    def size(self):
        return 0 if not self else 1+TreeNode.size(self.left) + TreeNode.size(self.right)
    
    def traverse_in_order(self):
        if self is None:
            return []
        return (TreeNode.traverse_in_order(self.left) + [self.val] + TreeNode.traverse_in_order(self.right))
    
    def traverse_pre_order(self):
        if self is None:
            return []
        return ([self.val] + TreeNode.traverse_in_order(self.left) + TreeNode.traverse_in_order(self.right))
    
    def traverse_post_order(self):
        if self is None:
            return []
        return (TreeNode.traverse_in_order(self.left) + TreeNode.traverse_in_order(self.right) + [self.val])
    


In [93]:
def balancedBST(node):
    
    def height(root):
        if root is None:
            return 0
        return 1+max(height(root.left), height(root.right))
    
    def inorder_traversal(node):
        if node is None:
            return []
        return inorder_traversal(node.left) + [node.val] + inorder_traversal(node.right)
    
    def add_element(root, node_to_add):
        if root is None or root.val == node_to_add:
            return
        elif root.val > node_to_add:
            if root.left:
                return add_element(root.left, node_to_add)
            root.left = TreeNode(node_to_add)
        elif root.val < node_to_add:
            if root.right:
                return add_element(root.right, node_to_add)
            root.right = TreeNode(node_to_add)
       
    if node is None:
        return
    if abs(height(node.left)-height(node.right)) > 1:
        # Action need to be taken to balance the BST
        inordered = inorder_traversal(node)
        print(f'inordered: {inordered}')
        root = TreeNode(inordered[len(inordered)//2])
        for i in inordered:
            add_element(root, i)
        return root
    return node

In [94]:
data = (None, 1,(None, 3, (None, 5, None)))
data = TreeNode.parse_tuple(data)
data.display()
print('\n Balanced BST \n')

		5
	3
		∅
1
	∅

 Balanced BST 



In [95]:
balancedBST(data)

inordered: [1, 3, 5]


BinaryTree <(1, 3, 5)>

In [96]:
data = (((None, 0, None), 1, (None, 2, 3)), 4, (5, 6, (None, 7, 8)))

data = TreeNode.parse_tuple(data)

data.display()

print('\n Balanced BST \n')
result = balancedBST(data)
result.display()

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0

 Balanced BST 

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0


In [None]:
[1,null,15,14,17,7,null,null,null,2,12,null,3,9,null,null,null,null,11]