In [5]:
"""
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree/


Given a binary tree, find the lowest common ancestor (LCA) of two given nodes in the tree.

According to the definition of LCA on Wikipedia: 
“The lowest common ancestor is defined between two nodes p and q as the lowest node in T 
that has both p and q as descendants (where we allow a node to be a descendant of itself).”


Constraints:
The number of nodes in the tree is in the range [2, 10^5].
-10^9 <= Node.val <= 10^9
All Node.val are unique.
p != q
p and q will exist in the tree.
"""

class TreeNode:
    """A binary-tree node: a value plus optional left and right children."""

    def __init__(self, val):
        self.val = val
        # A fresh node is a leaf; children are attached afterwards by callers.
        self.left = self.right = None


def readTree(vals):
    """Build a binary tree from a heap-indexed list of values.

    The list uses array (heap) indexing: the children of index ``i`` are at
    ``2*i + 1`` (left) and ``2*i + 2`` (right), and ``None`` marks a missing
    node.  NOTE: this is not LeetCode's compact level-order format, which
    omits the children of missing nodes.  The two formats agree only when
    every ``None`` sits on the last level.  A node whose parent slot is
    ``None`` is created but left unattached.

    Args:
        vals: list of node values (or None for "no node").

    Returns:
        The root TreeNode, or None if ``vals`` is empty or ``vals[0]`` is None.
    """
    nodes = {}
    # Walk from the back: children always have larger indices than their
    # parent, so both children of index i already exist when i is created.
    for i in range(len(vals) - 1, -1, -1):
        if vals[i] is None:
            continue
        node = TreeNode(vals[i])
        node.left = nodes.get(2 * i + 1)
        node.right = nodes.get(2 * i + 2)
        nodes[i] = node
    # .get avoids a KeyError when there is no root (empty list / None root).
    return nodes.get(0)


def printTree(root):
    """Print the tree sideways, one node per line, in in-order sequence.

    Each line shows the node's value preceded by one tab per level of depth
    (the root is at depth 0).  A left subtree therefore appears above its
    parent and a right subtree below it.  ``print`` inserts a single space
    between the tab prefix and the value.

    An explicit stack replaces recursion so that very deep (e.g. fully
    skewed) trees do not exceed Python's recursion limit.

    Args:
        root: root node, or None (prints nothing).
    """
    stack = []  # (node, depth) pairs whose left subtree is still pending
    node, depth = root, 0
    while stack or node is not None:
        # Descend as far left as possible, remembering each ancestor.
        while node is not None:
            stack.append((node, depth))
            node, depth = node.left, depth + 1
        node, depth = stack.pop()
        print('\t' * depth, node.val)
        # Then visit the right subtree, one level deeper.
        node, depth = node.right, depth + 1


def findpath(n, p):
    """Return the values on the path from ``n`` down to the node valued ``p``.

    This is an iterative depth-first search.  The original version copied
    the path for every stacked node, which costs O(depth) per node and
    O(n^2) on skewed trees.  Here a single ``path`` list is kept instead.
    Before each popped node is appended, the list is truncated back to that
    node's depth.

    Args:
        n: root node of the (sub)tree to search, or None.
        p: value to look for.  Values are assumed unique.  If there are
            duplicates, the first match in this search order (right subtree
            before left) is used.

    Returns:
        A list ``[n.val, ..., p]`` of values from ``n`` to the matching node,
        or None if ``n`` is None or no node has value ``p``.
    """
    if n is None:
        return None
    path = []
    stack = [(n, 0)]  # (node, depth below n)
    while stack:
        node, depth = stack.pop()
        # Anything at or below this depth belongs to a finished subtree.
        # Drop it so that path[:depth] is exactly this node's ancestors.
        del path[depth:]
        path.append(node.val)
        if node.val == p:
            return path
        # Push left then right, so the right child is explored first.
        if node.left is not None:
            stack.append((node.left, depth + 1))
        if node.right is not None:
            stack.append((node.right, depth + 1))
    return None


def LCA(root, p, q):
    """Return the value of the lowest common ancestor of values ``p`` and ``q``.

    Approach: build the root-to-node path (a list of values) for ``p`` and
    for ``q``.  Then walk ``q``'s path from the bottom up and return the
    first value that also lies on ``p``'s path.  For the tree

                 3
              5     1
             6  2  0  8
               7 4

    * p=5, q=4: p-path 3->5, q-path 3->5->2->4.  Walking up from 4: 4 no,
      2 no, 5 yes -> LCA is 5.
    * p=6, q=8: p-path 3->5->6, q-path 3->1->8.  Walking up from 8: 8 no,
      1 no, 3 yes -> LCA is 3.

    Args:
        root: root node of the tree.
        p, q: node values to look up (assumed unique, as the problem
            constraints guarantee).

    Returns:
        The ``val`` of the LCA node (not the node itself).

    Raises:
        ValueError: if ``p`` or ``q`` does not occur in the tree, or if
            ``root`` is None.
        Exception: if no common ancestor is found.  This cannot happen when
            both paths exist, since both start at the root.
    """
    ppath = findpath(root, p)
    qpath = findpath(root, q)
    # Fail with a clear message instead of a TypeError from set(None)/len(None).
    if ppath is None or qpath is None:
        missing = p if ppath is None else q
        raise ValueError(f'value {missing!r} not found in tree')

    ppath_set = set(ppath)  # O(1) membership tests while walking q's path
    for val in reversed(qpath):
        if val in ppath_set:
            return val

    # this should never happen
    raise Exception('no LCA found')


# Each case: (heap-indexed value list, p, q, expected LCA value).
tests = [
    ([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], 5, 1, 3),
    #     3
    #  5    1
    # 6 2  0 8
    #  7 4
    # Explanation: The LCA of nodes 5 and 1 is 3.
    ([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], 5, 4, 5),
    # p and q deep in different subtrees: the LCA is the root.
    ([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], 6, 8, 3),
    ([1, 2], 1, 2, 1),
]
for vals, p, q, expected in tests:
    root = readTree(vals)
    # Uncomment to visualise the tree while debugging:
    # printTree(root)
    got = LCA(root, p, q)
    assert got == expected, f'LCA({vals}, {p}, {q}) = {got}, expected {expected}'
