# In [7]:
"""
https://leetcode.com/problems/binary-tree-maximum-path-sum/

A path in a binary tree is a sequence of nodes where each pair of adjacent nodes in the sequence has an edge connecting them. 
A node can only appear in the sequence at most once.
Note that the path does not need to pass through the root.

The path sum of a path is the sum of the nodes' values in the path.

Given the root of a binary tree, return the maximum path sum of any non-empty path.

Constraints:
The number of nodes in the tree is in the range [1, 3 * 10^4].
-1000 <= Node.val <= 1000
"""

class TreeNode:
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, val=0, left=None, right=None):
        # Store the payload first, then the two child links.
        self.val, self.left, self.right = val, left, right


def readTree(vals):
    """Build a binary tree from a heap-indexed level-order list.

    The children of ``vals[i]`` are ``vals[2*i + 1]`` and ``vals[2*i + 2]``;
    ``None`` marks a missing node. Note this is *not* LeetCode's compact
    serialization (which omits the children of missing nodes): the two agree
    only when no present value sits below a ``None``. Values whose parent
    slot is ``None`` are silently dropped.

    Args:
        vals: sequence of node values (or ``None``) in heap order.

    Returns:
        The root ``TreeNode``, or ``None`` if ``vals`` is empty or
        ``vals[0]`` is ``None``.
    """
    nodes = {}
    # Walk from the end so both children are built before their parent.
    for i in range(len(vals) - 1, -1, -1):
        if vals[i] is None:
            continue
        nodes[i] = TreeNode(
            val=vals[i],
            left=nodes.get(2 * i + 1),
            right=nodes.get(2 * i + 2),
        )
    # .get avoids a KeyError when the root slot is empty/None.
    return nodes.get(0)


def maxpathsum(root):
    """Return the maximum path sum over all non-empty paths in the tree.

    A path is a chain of nodes joined by parent/child edges, each node used at
    most once; it need not pass through the root. For every node we compute
    its *gain*: the best sum of a downward path starting at that node, with
    negative contributions dropped (clamped to 0) so a parent can choose not
    to extend into that child. The best path whose topmost node is ``node``
    is ``node.val + left_gain + right_gain``; the answer is the maximum of
    that quantity over all nodes.

    Because a path must contain at least one node, the result can be
    negative (e.g. a lone node with value -3 gives -3).

    The tree is walked with an explicit-stack post-order traversal instead of
    recursion, so degenerate (list-shaped) trees with tens of thousands of
    nodes do not hit Python's recursion limit.

    Args:
        root: root node exposing ``val``, ``left`` and ``right``, or None.

    Returns:
        The maximum path sum; 0 when ``root`` is None (kept for backward
        compatibility, although the problem guarantees at least one node).
    """
    if root is None:
        return 0

    best = float("-inf")
    # id(node) -> max(0, best downward path sum starting at node)
    gains = {}
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Come back to this node once both children have been processed.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue

        # A missing child contributes gain 0 (id(None) is never stored).
        left_gain = gains.pop(id(node.left), 0)
        right_gain = gains.pop(id(node.right), 0)
        # Best path that bends at this node (uses both sides where helpful).
        best = max(best, node.val + left_gain + right_gain)
        # Best one-sided path a parent may extend; 0 means "don't extend".
        gains[id(node)] = max(0, node.val + max(left_gain, right_gain))
    return best

tests = [
    ([1, 2, 3], 6),
    ([-10, 9, 20, None, None, 15, 7], 42),
]
# Smoke check: each heap-ordered value list must yield its expected answer.
for vals, expected in tests:
    assert maxpathsum(readTree(vals)) == expected