# 11.0 Advanced Data Structures

## 11.1 Binary Indexed Tree aka Fenwick Tree

### Problem Statement
A [binary indexed tree](https://en.wikipedia.org/wiki/Fenwick_tree), also known as a Fenwick tree, efficiently computes prefix sums over a sequence of values, such as the counts in the bins of a histogram. The structure supports the `add` and `cumsum` operations in $O(\log n)$ time each, where $n$ is the number of bins. It stores only a single array of $n + 1$ values ($O(n)$ space) and needs no auxiliary storage.
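The logarithmic bound comes from how indices are visited. Bin $i$ holds the sum of the values in $(i - \ell(i), i]$, where $\ell(i)$ = `i & -i` is the lowest set bit of $i$. Each step of an update moves to an index whose lowest set bit is strictly larger, and each step of a prefix-sum query clears one set bit, so at most about $\log_2 n + 1$ bins are touched per operation. The sketch below uses hypothetical helper names (`lowbit`, `update_path`, `query_path`, `covered_range`) that are not part of the implementation in this section. It only lists the bins each operation visits, assuming 1-based bins.

```python
def lowbit(i):
    """Lowest set bit of i, e.g. lowbit(12) == 4 because 12 == 0b1100."""
    return i & -i


def update_path(t, nbins):
    """Indices of the bins that change when a value is added to bin t."""
    if t < 1:
        raise ValueError("bins are numbered from 1")  # lowbit(0) == 0 would loop forever.
    path = []
    while t <= nbins:
        path.append(t)
        t += lowbit(t)  # Move to the next bin whose range contains t.
    return path


def query_path(t):
    """Indices of the bins summed to get the prefix sum over bins [1, t]."""
    path = []
    while t > 0:
        path.append(t)
        t -= lowbit(t)  # Same as t & (t - 1): clear the lowest set bit.
    return path


def covered_range(i):
    """Inclusive range of original bins whose sum is stored in bin i."""
    return (i - lowbit(i) + 1, i)


print(update_path(5, 16))                          # [5, 6, 8, 16]
print(query_path(13))                              # [13, 12, 8]
print([covered_range(i) for i in query_path(13)])  # [(13, 13), (9, 12), (1, 8)]
```

The three bins read for the prefix sum up to 13 cover $[13, 13]$, $[9, 12]$, and $[1, 8]$, which tile $[1, 13]$ exactly once.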

Implement a binary indexed tree data structure with the following interface:
* A constructor that takes the total number of bins, $n$.
* `add(t, x)` adds the value $x$ to the bin at index $t$, where bins are numbered from 1 to $n$.
* `cumsum(t0, t1)` returns the cumulative sum of the values in bins $[t0, t1]$, where $1 \leq t0 \leq t1 \leq n$.
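A range sum is the difference of two prefix sums, so it is enough to support sums that start at bin 1. Writing $a_i$ for the value in bin $i$ and $P(t) = \sum_{i=1}^{t} a_i$, with $P(0) = 0$:

$$
\texttt{cumsum}(t_0, t_1) = \sum_{i=t_0}^{t_1} a_i = P(t_1) - P(t_0 - 1).
$$

For example, with the bin values used in the unit test, $\texttt{cumsum}(5, 7) = P(7) - P(4) = 27 - 11 = 16$.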

The operations should satisfy the time and space requirements described above.
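For comparison, a plain list meets the same interface with $O(1)$ `add` but $O(n)$ `cumsum`, because each query walks every bin in the range. The sketch below is a hypothetical reference class, `NaiveBins`, that is not part of the original notebook. It uses the same 1-based bins and could serve as a brute-force oracle when testing the tree.

```python
class NaiveBins:
    """Hypothetical reference: O(1) add, O(n) cumsum. Bins are 1-indexed."""

    def __init__(self, nbins):
        self.values = [0] * (nbins + 1)  # Index 0 unused, to match the tree.

    def add(self, t, x):
        if not 1 <= t < len(self.values):
            raise IndexError(f"bin {t} out of range")
        self.values[t] += x

    def cumsum(self, t0, t1):
        if not 1 <= t0 <= t1 < len(self.values):
            raise IndexError(f"invalid range [{t0}, {t1}]")
        return sum(self.values[t0:t1 + 1])  # Visits t1 - t0 + 1 bins.


naive = NaiveBins(14)
for t, x in [(1, 1), (2, 7), (3, 3), (5, 5)]:
    naive.add(t, x)
print(naive.cumsum(1, 3))  # 11
print(naive.cumsum(2, 5))  # 15
```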

```python
import unittest


class BinaryIndexedTree:
    """A binary indexed tree (Fenwick tree) for computing prefix sums."""

    def __init__(self, nbins):
        self.bins = [0] * (nbins + 1)  # Bins are indexed from 1; index 0 is unused.

    def add(self, t, x):
        """Add the value x to the bin at index t, where 1 <= t <= nbins."""
        # t = 0 must be rejected: 0 & -0 == 0, so the loop below would never end.
        assert 1 <= t < len(self.bins), 'invalid: t out of range [1, nbins]'

        # Bin i holds the sum of the values in (i - (i & -i), i], where
        # i & -i is the lowest set bit of i. Increment every bin whose range
        # contains t by repeatedly adding the lowest set bit to the index.
        #
        # For example, assume t = 5 (0b0101).
        # Increment bin 5 by x; 5 + 1 = 6 (0b0110).
        # Increment bin 6 by x; 6 + 2 = 8 (0b1000).
        # Increment bin 8 by x; 8 + 8 = 16 (0b10000).
        #
        # Stop when the index is larger than the largest bin.
        ind = t
        while ind < len(self.bins):
            self.bins[ind] += x
            ind += ind & (-ind)  # Add the lowest set bit.

    def cumsum(self, t0, t1):
        """Return the cumulative sum of the values in bins [t0, t1]."""
        assert t0 <= t1, 'invalid: t0 > t1'
        assert 1 <= t0 and t1 < len(self.bins), 'invalid: bin out of range [1, nbins]'

        # Compose a range sum that does not start at bin 1 as the
        # difference of 2 prefix sums starting from bin 1.
        if t0 != 1:
            return self.cumsum(1, t1) - self.cumsum(1, t0 - 1)

        # Accumulate the values of the bins whose indices are computed by
        # repeatedly clearing the lowest set bit.
        #
        # For example, assume t1 = 9.
        # Start with the value at index 9 (9 = 2^3 + 2^0 = 0b1001).
        #
        # Clearing the lowest set bit of 9 yields 8.
        # Add the value at index 8 (8 = 2^3 = 0b1000).
        #
        # Clearing the lowest set bit of 8 yields 0.
        # Index 0 is not a valid bin and ends the traversal.
        result, ind = 0, t1
        while ind != 0:
            result += self.bins[ind]
            ind = ind & (ind - 1)  # Clear the lowest set bit.
        return result


class BinaryIndexedTreeTest(unittest.TestCase):

    def test_binary_indexed_tree(self):
        nbins = 14
        bit = BinaryIndexedTree(nbins)

        # Initialize the tree.
        ops = [(1, 1), (2, 7), (3, 3), (4, 0), (5, 5), (6, 8), (7, 3),
               (8, 2), (9, 6), (10, 2), (11, 1), (12, 1), (13, 4), (14, 5)]
        for t, x in ops:
            bit.add(t, x)

        # Compare the expected prefix sum starting from 1 at each value of t.
        init_sums = {1: 1, 2: 8, 3: 11, 4: 11, 5: 16, 6: 24, 7: 27,
                     8: 29, 9: 35, 10: 37, 11: 38, 12: 39, 13: 43, 14: 48}
        for t1, expected in init_sums.items():
            self.assertEqual(bit.cumsum(1, t1), expected)

        # Compare some ranges not starting from 1.
        init_range_sums = [(5, 7, 16), (3, 4, 3), (2, 14, 47), (9, 9, 6)]
        for t0, t1, expected in init_range_sums:
            self.assertEqual(bit.cumsum(t0, t1), expected)

        # Add 1 to each bin in turn and verify the sums are updated.
        for t in sorted(init_sums):
            bit.add(t, 1)
            # Every prefix sum ending at or after bin t grows by 1.
            for ti in range(t, nbins + 1):
                init_sums[ti] += 1
            # Compare the expected prefix sum at each value of t.
            for t1, expected in init_sums.items():
                self.assertEqual(bit.cumsum(1, t1), expected)


unittest.main(argv=[''], verbosity=2, exit=False)
```

```
test_binary_indexed_tree (__main__.BinaryIndexedTreeTest) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK
```
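The test initializes the tree with one `add` call per bin, which costs $O(n \log n)$ in total. When all initial values are known up front, a commonly used alternative builds the bins in $O(n)$. It processes indices in increasing order and adds each completed bin into its parent `i + (i & -i)`. The sketch below uses hypothetical helper names `build_bins` and `prefix_sum`, works on a plain list rather than the `BinaryIndexedTree` class, and reuses the bin values from the test.

```python
def build_bins(values):
    """Build Fenwick bins from 1-indexed values[1:] in O(n); values[0] is ignored."""
    bins = list(values)
    bins[0] = 0
    n = len(bins) - 1
    for i in range(1, n + 1):
        # All children of bin i have smaller indices, so bin i is complete here.
        parent = i + (i & -i)  # Next bin whose range contains bin i's range.
        if parent <= n:
            bins[parent] += bins[i]
    return bins


def prefix_sum(bins, t):
    """Sum of the original values in bins [1, t]."""
    total = 0
    while t > 0:
        total += bins[t]
        t &= t - 1  # Clear the lowest set bit.
    return total


values = [0, 1, 7, 3, 0, 5, 8, 3, 2, 6, 2, 1, 1, 4, 5]  # Index 0 unused.
bins = build_bins(values)
print([prefix_sum(bins, t) for t in range(1, 15)])
# [1, 8, 11, 11, 16, 24, 27, 29, 35, 37, 38, 39, 43, 48]
```

The resulting list is identical to the one produced by fourteen successive `add` calls, since each bin's contents are fully determined by the values in its range.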