# 6.0 Trees

Add some common functions for representing binary trees and converting between binary trees and python lists.

In [1]:
import collections


class TreeNode(object):
    """Node in a binary tree."""
    
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def make_binary_tree(elems):
    """Return a binary tree initialized from a python list."""
    
    # Since this is a binary tree, there is no ordering wrt value.
    # Values are inserted in level-order as they appear in the
    # python list; starting at root and growing downward adding
    # siblings from left-to-right.
    root, queue = None, collections.deque()
    for x in elems:
        if len(queue) < 1:
            root = TreeNode(x)
            queue.append(root)
            continue
        node = queue[0]  # Access, but do not remove.
        if node.left is None:
            node.left = TreeNode(x)
            queue.append(node.left)
        elif node.right is None:
            node.right = TreeNode(x)
            queue.append(node.right)
            queue.popleft()  # Remove node from queue.
    
    return root


def make_list(root):
    """Return a python list initialized from binary tree."""
    
    # Since this is a binary tree, there is no ordering wrt value.
    # Values are inserted into python list in the order they appear
    # during a level order traversal of the binary tree.
    elems, queue = [], collections.deque([root])
    while len(queue) > 0:
        node = queue.popleft()
        if node:
            elems.append(node.data)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
    
    return elems

## 6.1 Count unival trees

### Problem Statement
Count the number of universal nodes in a binary tree.  A node is considered universal when any of the following is true:
* Node is a leaf node.
* All nodes below the current node have the same value.

#### Examples
Unival count: `5`
```
               0
       1                0
                    1       0
                1       1
```

Unival count: `4`
```
               a
       a                a
                    a       a
                               b
                            b
```

Unival count: `2`
```
               a
       b                c
                    d
```

In [2]:
import collections
import functools
import unittest


def is_universal(univals, node):
    """Return True when a node is a universal node."""
    is_leaf = node.left is None and node.right is None
    lunival, runival = True, True
    if node.left:
        lunival = node.left in univals and node.left.data == node.data
    if node.right:
        runival = node.right in univals and node.right.data == node.data
    return is_leaf or (lunival and runival)


def count_unival_nodes_visitor(univals, node):
    """Visitor function that counts universal nodes."""
    if node.left:
        count_unival_nodes_visitor(univals, node.left)
    if node.right:
        count_unival_nodes_visitor(univals, node.right)
    if is_universal(univals, node):
        univals.add(node)


def count_unival_nodes(root):
    """Count the number of universal nodes in a binary tree."""
    
    # Accumulate unival nodes during post-order traversal.
    univals = set()
    visitor_func = functools.partial(count_unival_nodes_visitor, 
                                     univals)
    visitor_func(root)
        
    return len(univals)


class CountUnivalNodesTest(unittest.TestCase):
    
    def setUp(self):
        # Create treeA with unival count of 5.
        self.treeA_expected = 5
        self.treeA = TreeNode(0)
        self.treeA.left = TreeNode(1)
        self.treeA.right = TreeNode(0)
        self.treeA.right.left = TreeNode(1)
        self.treeA.right.right = TreeNode(0)
        self.treeA.right.left.left = TreeNode(1)
        self.treeA.right.left.right = TreeNode(1)
        # Create treeB with unival count of 4.
        self.treeB_expected = 4
        self.treeB = TreeNode('a')
        self.treeB.left = TreeNode('a')
        self.treeB.right = TreeNode('a')
        self.treeB.right.left = TreeNode('a')
        self.treeB.right.right = TreeNode('a')
        self.treeB.right.right.right = TreeNode('b')
        self.treeB.right.right.right.left = TreeNode('b')
        # Create treeC with unival count of 2.
        self.treeC_expected = 2
        self.treeC = TreeNode('a')
        self.treeC.left = TreeNode('b')
        self.treeC.right = TreeNode('c')
        self.treeC.right.left = TreeNode('d')

    def test_count_unival_nodes(self):
        case = collections.namedtuple('case', ['input','expected'])
        cases = [
            case(self.treeA, self.treeA_expected),
            case(self.treeB, self.treeB_expected),
            case(self.treeC, self.treeC_expected),
        ]
        for c in cases:
            rcv = count_unival_nodes(c.input)
            self.assertEqual(rcv, c.expected)


unittest.main(CountUnivalNodesTest(), argv=[''], verbosity=2, exit=False)

test_count_unival_nodes (__main__.CountUnivalNodesTest) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.003s

OK


<unittest.main.TestProgram at 0x7f9794498cc0>

## 6.2 Reconstruct tree from traversals

### Problem Statement
Reconstruct the binary tree from a list of nodes returned from the pre-order and in-order traversals.

In [3]:
import collections
import unittest


def make_tree_from_traversals(preorder, inorder):
    """Reconstruct a binary tree from preorder and inorder traversals."""
    if len(preorder) < 1:
        return None
    # Root node is the first element in the preorder traversal.
    # All root's left children are left of root in inorder traversal.
    # All root's right children are right of root in inorder traversal.
    root = TreeNode(preorder[0])
    # Root's left child is right of root in preorder traversal.
    left_preorder_ind = 1
    # Root's right child in preorder traversal is found by adding
    # the number of nodes in the left subtree to the left child index.
    root_inorder_ind = inorder.index(root.data)  # Count in left subtree.
    right_preorder_ind = left_preorder_ind + root_inorder_ind
    root.left = make_tree_from_traversals(
        preorder[left_preorder_ind:right_preorder_ind],
        inorder[:root_inorder_ind])
    root.right = make_tree_from_traversals(
        preorder[right_preorder_ind:],
        inorder[root_inorder_ind+1:])
    return root


class MakeTreeFromTraversalsTest(unittest.TestCase):
    
    def setUp(self):
        # Create treeA.
        self.treeA_preorder = ['a','b','d','e','c','f','g']
        self.treeA_inorder = ['d','b','e','a','f','c','g']
        self.treeA = TreeNode('a')
        self.treeA.left = TreeNode('b')
        self.treeA.left.left = TreeNode('d')
        self.treeA.left.right = TreeNode('e')
        self.treeA.right = TreeNode('c')
        self.treeA.right.left = TreeNode('f')
        self.treeA.right.right = TreeNode('g')
        # Create treeB.
        self.treeB_preorder = [1,2,4,3,5]
        self.treeB_inorder = [4,2,1,3,5]
        self.treeB = TreeNode(1)
        self.treeB.left = TreeNode(2)
        self.treeB.left.left = TreeNode(4)
        self.treeB.right = TreeNode(3)
        self.treeB.right.right = TreeNode(5)
        # Create treeC.
        self.treeC_preorder = [1,2,3,4,5,7,6]
        self.treeC_inorder = [2,3,1,5,7,4,6]
        self.treeC = TreeNode(1)
        self.treeC.left = TreeNode(2)
        self.treeC.left.right = TreeNode(3)
        self.treeC.right = TreeNode(4)
        self.treeC.right.left = TreeNode(5)
        self.treeC.right.left.right = TreeNode(7)
        self.treeC.right.right = TreeNode(6)

    def test_make_tree_from_traversals(self):
        case = collections.namedtuple('case', 
                                      ['preorder','inorder','expected'])
        cases = [
            case(self.treeA_preorder, self.treeA_inorder, self.treeA),
            case(self.treeB_preorder, self.treeB_inorder, self.treeB),
            case(self.treeC_preorder, self.treeC_inorder, self.treeC),
        ]
        for c in cases:
            rcv = make_tree_from_traversals(c.preorder, c.inorder)            
            self.assertEqual(make_list(rcv), make_list(c.expected))


unittest.main(MakeTreeFromTraversalsTest(), argv=[''], verbosity=2, 
              exit=False)

test_make_tree_from_traversals (__main__.MakeTreeFromTraversalsTest) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.004s

OK


<unittest.main.TestProgram at 0x7f9794455278>

## 6.3 Evaluate arithmetic tree

### Problem Statement
Implement a function for evaluating a [binary expression tree](https://en.wikipedia.org/wiki/Binary_expression_tree).

binary expression tree (def.)
> Leaf nodes in a binary expression tree consist of one or more operands and non-leaf nodes consist of operators.

In [4]:
import collections
import unittest


def evaluate_expression_tree(node):
    """Evaluate the binary expression tree."""
    if node.left is None and node.right is None:
        return node.data  # Operand.
    # Recursively evaluate the subtrees via post-order traversal.
    lhs = evaluate_expression_tree(node.left) if node.left else 0
    rhs = evaluate_expression_tree(node.right) if node.right else 0
    if node.data == '+':
        return lhs + rhs
    elif node.data == '-':
        return lhs - rhs
    elif node.data == '*':
        return lhs * rhs
    elif node.data == '/':
        return lhs / rhs
    else:
        raise ValueError('invalid operator: {0}'.format(node.data))


class EvaluateExpressionTreeTest(unittest.TestCase):
    
    def setUp(self):
        # Create treeA.
        self.treeA_expected = 13
        self.treeA = TreeNode('+')
        self.treeA.left = TreeNode('*')
        self.treeA.left.left = TreeNode('+')
        self.treeA.left.left.left = TreeNode(1)
        self.treeA.left.left.right = TreeNode(2)
        self.treeA.left.right = TreeNode(3)
        self.treeA.right = TreeNode(4)
        # Create treeB.
        self.treeB_expected = 45
        self.treeB = TreeNode('*')
        self.treeB.left = TreeNode('+')
        self.treeB.left.left = TreeNode(3)
        self.treeB.left.right = TreeNode(2)
        self.treeB.right = TreeNode('+')
        self.treeB.right.left = TreeNode(4)
        self.treeB.right.right = TreeNode(5)

    def test_evaluate_expression_tree(self):
        case = collections.namedtuple('case', ['input','expected'])
        cases = [
            case(self.treeA, self.treeA_expected),
            case(self.treeB, self.treeB_expected),
        ]
        for c in cases:
            rcv = evaluate_expression_tree(c.input)
            self.assertEqual(rcv, c.expected)


unittest.main(EvaluateExpressionTreeTest(), argv=[''], verbosity=2, 
              exit=False)

test_evaluate_expression_tree (__main__.EvaluateExpressionTreeTest) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.004s

OK


<unittest.main.TestProgram at 0x7f9794534828>

## 6.4 Get tree level with minimum sum

### Problem Statement
Given a binary tree of integers, find the level in the tree having the minimum sum and return both level and sum.

#### Examples
Minimum sum is `1` at level `2`.
```
			3
	-3				7
-10		-2		5		8
```

In [5]:
import collections
import unittest


def level_with_minimum_sum(root):
    """Return the level in the tree having the minimum sum."""
    assert not(root is None), 'invalid: root is None'
    
    # Perform a level order traversal accumulating the sum of
    # values at each level.  Recognize transitions between levels
    # by recording the level at the time a node inserted into queue.
    best, level, minsum = None, 0, 0
    queue = collections.deque([(root, level)])
    while len(queue) > 0:
        (node, nodelevel) = queue.popleft()
        if nodelevel != level:  # Level transition.
            best = (level, minsum) \
                if best is None or minsum < best[1] else best
            level, minsum = nodelevel, node.data
        else:
            minsum += node.data
        if node.left:
            queue.append((node.left, level+1))
        if node.right:
            queue.append((node.right, level+1))

    # Handle transition for the last node in the tree.
    best = (level, minsum) if best is None or minsum < best[1] else best

    return best


class LevelWithMinimumSumTest(unittest.TestCase):

    def test_solution(self):
        case = collections.namedtuple('case', ['input','expected'])
        cases = [
            # Edge case: one-node tree.
            case([-10], (0,-10)),
            # 3-level tree, level 0 is minimum.
            case([1,-4,7,-6,-2,5,8], (0,1)),
            # 3-level tree, level 1 is minimum.
            case([3,-6,7,-7,-2,5,8], (1,1)),
            # 3-level tree, level 2 is minimum.
            case([3,-3,7,-10,-2,5,8], (2,1)),
            # 4-level tree, last level is partial and minimum.
            case([3,-3,7,-10,-2,5,8,-20], (3,-20)),
        ]
        for c in cases:
            tree = make_binary_tree(c.input)
            rcv = level_with_minimum_sum(tree)
            self.assertEqual(rcv, c.expected)


unittest.main(LevelWithMinimumSumTest(), argv=[''], verbosity=2, 
              exit=False)

test_solution (__main__.LevelWithMinimumSumTest) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.005s

OK


<unittest.main.TestProgram at 0x7f979445a0b8>