In [None]:
class Node:
    """
    Node of a binary tree.

    ``key`` is the node's value. It may be ``None`` for the placeholder
    nodes that ``BST.build_from_concise_list`` creates for gaps in its
    input; those placeholders are dropped before the tree is returned.
    """
    def __init__(self,
                 key: float,
                 left=None,
                 right=None):
        self.key = key
        self.left = left    # left child Node, or None
        self.right = right  # right child Node, or None

    def __repr__(self):
        # Children are Nodes too, so this prints the whole subtree.
        return str(self.__dict__)


class BST:
    """
    Binary tree built from a level-order ("concise") list.

    Despite the name, no ordering of keys is enforced.

    Attributes:
        root: the root Node, or None for an empty tree.
        nodelist: the tree's nodes as passed in; ``build_from_concise_list``
            supplies the non-placeholder nodes in input-list order.
    """
    def __init__(self, root, nodelist):
        self.root = root
        self.nodelist = nodelist

    @staticmethod
    def build_from_concise_list(l):
        """
        Build a tree from a level-order list where ``None`` marks a missing
        child.

        Only non-None entries own child slots: walking the list in order,
        each non-None entry takes the next two not-yet-consumed entries as
        its left and right child. ``None`` entries consume a slot but get no
        children of their own. Trailing ``None`` values may be omitted, so
        the last parent may have only a left slot.

        Args:
            l: sequence of keys, ``None`` for absent children.

        Returns:
            A BST whose ``root`` is the node for ``l[0]`` and whose
            ``nodelist`` holds every non-None node in list order. An empty
            list, or one starting with ``None``, gives ``BST(None, [])``.

        Raises:
            ValueError: if a non-None entry is never consumed as a child
                slot, i.e. it has no parent in the list.
        """
        nodes = [Node(el) for el in l]
        if not nodes or nodes[0].key is None:
            return BST(None, [])
        slot = 1  # index of the next entry still waiting for a parent
        for i, parent in enumerate(nodes):
            if slot >= len(nodes):
                break  # every entry has been given a parent (or skipped)
            if parent.key is None:
                continue  # placeholders own no child slots
            if slot <= i:
                # This entry was never consumed as anyone's child; linking
                # it would make it (or a later node) its own descendant.
                raise ValueError(
                    f"entry {i} ({parent.key!r}) has no parent in the list")
            left = nodes[slot]
            if left.key is not None:
                parent.left = left
            # The right slot may be missing when trailing Nones are omitted.
            if slot + 1 < len(nodes):
                right = nodes[slot + 1]
                if right.key is not None:
                    parent.right = right
            slot += 2
        nodelist = [nd for nd in nodes if nd.key is not None]
        return BST(nodelist[0], nodelist)

In [None]:
import functools
import time

def timing(f):
    """
    Decorator that prints how long each call of *f* took.

    The elapsed time is measured with ``time.perf_counter`` (seconds) and
    printed after *f* returns; *f*'s return value is passed through
    unchanged. ``functools.wraps`` keeps *f*'s name and docstring.
    """
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        t0 = time.perf_counter()
        val = f(*args, **kwargs)
        t1 = time.perf_counter()
        # Both pieces of the implicitly-concatenated message need the
        # f-prefix: the second one holds the elapsed-time placeholder.
        print(f"Time elapsed for {f.__name__}: "
              f"{t1-t0:5.2f}")
        return val
    return wrap

In [None]:
# Sample input: a tree in level order, with None marking a missing child.
# One row per tree level. Smaller cases also tried: [1, 2, 4] and [1].
l = [
    10,
    5, 3,
    None, None, None, 4,
    1, 2,
]

In [None]:
# Build the sample tree from the level-order list `l` defined earlier.
bt = BST.build_from_concise_list(l=l)

In [None]:
# Notebook display: the tree's nodes in input-list order (entries that were
# None in the input list are not included).
bt.nodelist

In [None]:
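# Dynamic-programming solution (sketch)

# memory: n x n table (e.g. np.empty((n, n))), where
#   memory[a][b] = (sum of keys on the path a..b, number of edges on it)
# previously_visited = []
# next_batch = []
# Begin with the root: memory[root][root] = (root.key, 0).
# Get its nearest neighbours (its children):
# for neigh in nearest_neighbours:
#     memory[root][neigh] = (root.key + neigh.key, 1)
#     memory[neigh][root] = (root.key + neigh.key, 1)
# Then the next-nearest neighbours:
#     memory[nneigh][neigh] = (..., 1)
#     for prev in previously_visited:
#         memory[nneigh][prev] = memory[nneigh][neigh] "+" memory[neigh][prev]
#         (add the sums and subtract neigh.key once, since it is counted in
#          both halves; add the edge counts)
# Is this matrix multiplication? Not quite: it needs an extra loop to take the minimum over intermediate nodes.
# Adding one node at a time works: loop over the non-null entries and keep the combination with the fewest edges (second entry).
# Then repeat for the next level.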

def main(root):
    """
    Dynamic-programming sweep over all pairs of nodes in the tree at *root*.

    Nodes are numbered in breadth-first order (root = 0). The table entry
    ``memory[i][j]`` holds ``(path_sum, n_edges)`` for the path between
    nodes i and j. ``path_sum`` adds the keys of every node on the path,
    both ends included. Each newly reached node is paired with every
    earlier-numbered node, including earlier members of its own batch. The
    pairing goes through an intermediate node k. Among the candidates, the
    one with the fewest edges is the actual tree path.

    Args:
        root: root node (an object with ``key``, ``left``, ``right``), or
            None.

    Returns:
        The ``(path_sum, n_edges)`` entry with the smallest ``path_sum``
        over all node pairs, single-node paths included. Returns None for an
        empty tree.

    NOTE(review): this picks the *minimum* sum, while ``find_max_length``
    elsewhere in this notebook reports a maximum. Confirm which was intended.
    """
    if root is None:
        return None
    n = _tree_size(root)
    memory = [[None] * n for _ in range(n)]
    memory[0][0] = (root.key, 0)  # single-node path at the root
    previous = [root]
    next_batch, memory = get_next_batch(previous, memory)
    while next_batch:
        nodes = previous + next_batch  # node number -> node, BFS order
        for i, nxt in enumerate(next_batch):
            idx = i + len(previous)
            memory[idx][idx] = (nxt.key, 0)
            # Pair with every earlier node, not only those from previous
            # batches, so sibling/cousin paths in this batch are covered.
            for j in range(idx):
                if memory[idx][j] is not None:
                    continue  # direct parent edge, already recorded
                # Join idx..k and k..j; k's key is counted in both halves,
                # so subtract it once. The parent of idx is always a
                # candidate, so this list is never empty.
                memory[idx][j] = min(
                    [(memory[idx][k][0] + memory[k][j][0] - nodes[k].key,
                      memory[idx][k][1] + memory[k][j][1])
                     for k in range(len(nodes))
                     if (memory[idx][k] is not None
                         and memory[k][j] is not None)],
                    key=lambda x: x[1])
                memory[j][idx] = memory[idx][j]
        previous = nodes
        next_batch, memory = get_next_batch(previous, memory)
    # Every cell is now filled, so no None entries reach the comparison.
    return min((cell for row in memory for cell in row),
               key=lambda x: x[0])


def get_next_batch(previous, memory):
    """
    Collect the children of the nodes in *previous* that are not yet in it.

    New nodes are numbered after the nodes of *previous*, in the order they
    are found. Each parent/child edge is recorded in both directions as
    ``(parent.key + child.key, 1)``.

    Args:
        previous: nodes numbered so far (list position = node number).
        memory: square table, mutated in place; it must have room for the
            new node numbers.

    Returns:
        ``(new_nodes, memory)``.
    """
    # Compare by identity; O(1) lookups instead of scanning lists.
    known = {id(nd) for nd in previous}
    ret = []
    for i, node in enumerate(previous):
        for child in (node.left, node.right):
            if child is None or id(child) in known:
                continue
            known.add(id(child))
            ret.append(child)
            idx = len(previous) + len(ret) - 1
            edge = (node.key + child.key, 1)
            memory[i][idx] = edge
            memory[idx][i] = edge
    return ret, memory


def _tree_size(root):
    """Return the number of nodes reachable from *root* (0 for None)."""
    count = 0
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        count += 1
        stack.extend(c for c in (node.left, node.right) if c is not None)
    return count

In [None]:
def _f1(node, prev_val):
    new_val = prev_val + node.key
    branches = [(None, new_val)]
    if node.left is not None:
        branches.append((node.left, new_val))
    if node.right is not None:
        branches.append((node.right, new_val))
    return branches


def _f2(root):
    lens = []
    tape = []
    tape.extend(_f1(root, 0))
    while len(tape) > 0:
        node, val = tape.pop()
        if node is None:
            lens.append(val)
        else:
            tape.extend(_f1(node, val))
    if len(lens) > 0:
        retval = max(lens)
    else:
        retval = None
    return retval
    
                        
def parse_tree_from_root(root):
    lens = []
    if root.left is not None:
        x = _f2(root.left)
        if x is not None:
            lens.append(x + root.key)
    if root.right is not None:
        x = _f2(bt.root.right)
        if x is not None:
            lens.append(x + root.key)
    if len(lens) == 2:
        lens.append(lens[0] + lens[1] - root.key)
    retval = max(lens) if len(lens) > 0 else None
    return retval
    
    
def find_max_length(bt):
    lens = []
    tape = [bt.root]
    pointer = 0
    while pointer < len(tape):
        root = tape[pointer]
        if root is not None:
            x = parse_tree_from_root(root)
            if x is not None:
                lens.append(x)
            lf = root.left
            if lf is not None:
                tape.append(lf)
            rf = root.right
            if rf is not None:
                tape.append(rf)
        else:
            pass
        pointer += 1
    return max(lens) if len(lens) > 0 else None

In [None]:
# Notebook display: run find_max_length on the sample tree `bt` built above.
find_max_length(bt)