### Binary Search Tree to Greater Sum Tree

In [1]:
class TreeNode:
    """A binary tree node holding a value and optional left/right children.

    Most methods are written so they can also be called unbound on ``None``
    (e.g. ``TreeNode.size(None)``); that is how the recursion handles a
    missing child without special-casing it at every call site.
    """

    def __init__(self, val):
        self.val = val
        self.left = None   # left child TreeNode, or None
        self.right = None  # right child TreeNode, or None

    @staticmethod
    def parse_tuple(tup):
        """Build a tree from its nested-tuple form.

        A 3-tuple ``(left, val, right)`` becomes a node whose children are
        parsed recursively, ``None`` becomes an empty subtree, and any other
        value becomes a leaf holding that value.

        Returns:
            The root TreeNode, or None when *tup* is None.
        """
        if tup is None:
            return None
        if isinstance(tup, tuple) and len(tup) == 3:
            node = TreeNode(tup[1])
            node.left = TreeNode.parse_tuple(tup[0])
            node.right = TreeNode.parse_tuple(tup[2])
            return node
        return TreeNode(tup)

    def to_tuple(self):
        """Return the nested-tuple form accepted by ``parse_tuple``.

        Leaves are returned as their bare value; inner nodes as
        ``(left, val, right)`` with missing children shown as None.
        """
        if self is None:
            return None
        if self.left is None and self.right is None:
            return self.val
        return TreeNode.to_tuple(self.left), self.val, TreeNode.to_tuple(self.right)

    def __str__(self):
        return "BinaryTree <{}>".format(self.to_tuple())

    def __repr__(self):
        # Same text as __str__ so notebook echoes and print() agree.
        return self.__str__()

    def display(self, space='\t', level=0):
        """Print the tree sideways: right subtree above, left subtree below.

        Each level is indented by one more copy of *space*; a missing child
        of an inner node is printed as '∅'.
        """
        if self is None:
            print((space*level)+'∅')
            return
        if self.left is None and self.right is None:
            print((space*level)+str(self.val))
            return

        TreeNode.display(self.right, space, level+1)
        print((space*level)+str(self.val))
        TreeNode.display(self.left, space, level+1)

    def height(self):
        """Return the number of nodes on the longest root-to-leaf path (0 for None)."""
        return 0 if not self else 1+max(TreeNode.height(self.left), TreeNode.height(self.right))

    def size(self):
        """Return the total number of nodes in the tree (0 for None)."""
        return 0 if not self else 1+TreeNode.size(self.left) + TreeNode.size(self.right)

    def traverse_in_order(self):
        """Return values in left, node, right order."""
        if self is None:
            return []
        return (TreeNode.traverse_in_order(self.left) + [self.val] + TreeNode.traverse_in_order(self.right))

    def traverse_pre_order(self):
        """Return values in node, left, right order."""
        if self is None:
            return []
        # Children must be visited pre-order too (previously used in-order).
        return ([self.val] + TreeNode.traverse_pre_order(self.left) + TreeNode.traverse_pre_order(self.right))

    def traverse_post_order(self):
        """Return values in left, right, node order."""
        if self is None:
            return []
        # Children must be visited post-order too (previously used in-order).
        return (TreeNode.traverse_post_order(self.left) + TreeNode.traverse_post_order(self.right) + [self.val])
    


In [102]:
def bstToGst(node):
    """Convert a binary search tree, in place, into a Greater Sum Tree.

    Each node's value becomes itself plus the sum of every value that comes
    after it in in-order sequence (for a valid BST, all greater keys).
    Nodes are visited in reverse in-order (right, node, left) with an
    explicit stack while a running total is carried along.

    Args:
        node: root of the tree (an object with ``val``, ``left`` and
            ``right`` attributes), or None.

    Returns:
        The same root object, whose values now hold the accumulated sums.
    """
    running_total = 0
    pending = []
    current = node
    while pending or current is not None:
        # The largest not-yet-visited key lies down the right spine.
        while current is not None:
            pending.append(current)
            current = current.right
        current = pending.pop()
        current.val = current.val + running_total
        running_total = current.val
        # Everything in the left subtree is smaller, so it comes next.
        current = current.left
    return node

In [103]:
# Three-node BST: root 3 with children 1 and 5.
data = TreeNode.parse_tuple((1, 3, 5))
data.display()

print('\n GST of BST \n')
# bstToGst rewrites the tree in place and hands back the same root.
result = bstToGst(data)
result.display()

	5
3
	1

 GST of BST 

	5
8
	9


In [104]:
# Nine-node BST rooted at 4: keys 0..3 on the left, 5..8 on the right.
data = TreeNode.parse_tuple(
    (((None, 0, None), 1, (None, 2, 3)), 4, (5, 6, (None, 7, 8)))
)
data.display()

print('\n GST of BST \n')
# Conversion happens in place; `result` is the same root object as `data`.
result = bstToGst(data)
result.display()

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0

 GST of BST 

			8
		15
			∅
	21
		26
30
			33
		35
			∅
	36
		36


In [None]:
# NOTE(review): `bstToGst` overwrites node values in place, and `data` was
# already converted by the previous cell, so running this again adds the
# greater-key sums a second time instead of reproducing the GST of the
# original BST. This cell is marked "In [None]" (never executed); the output
# listed under it appears to be a stale copy of the original tree's display
# rather than this call's result — confirm before relying on it.
bstToGst(data)

			8
		7
			∅
	6
		5
4
			3
		2
			∅
	1
		0
