In [11]:
import time
import collections
import numpy as np

The goal of this problem is to implement the "Median Maintenance" algorithm (covered in the Week 3 lecture on heap applications). The text file contains a list of the integers from 1 to 10000 in unsorted order; you should treat this as a stream of numbers, arriving one by one. Letting $x_i$ denote the $i$th number of the file, the $k$th median $m_k$ is defined as the median of the numbers $x_1,\ldots,x_k$. (So, if $k$ is odd, then $m_k$ is the $((k+1)/2)$th smallest number among $x_1,\ldots,x_k$; if $k$ is even, then $m_k$ is the $(k/2)$th smallest number among $x_1,\ldots,x_k$.)

In the box below you should type the sum of these 10000 medians, modulo 10000 (i.e., only the last 4 digits). That is, you should compute $(m_1+m_2+m_3+\cdots+m_{10000}) \bmod 10000$.
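To make the definition concrete, here is a brute-force reference that follows it literally by re-sorting the prefix at every step. It is only a sketch for checking faster implementations on small inputs; `brute_force_median_sum` and its `modulus` parameter are hypothetical names, not part of the assignment. Both cases of the definition collapse to the same 1-based rank, `(k + 1) // 2`.

```python
def brute_force_median_sum(stream, modulus=10000):
    """Sum of the prefix medians m_1..m_n, taken mod `modulus`.

    Straight from the definition: sort x_1..x_k at every step k, so this is
    O(n^2 log n) overall; only meant as a checker on small inputs.
    """
    total = 0
    for k in range(1, len(stream) + 1):
        prefix = sorted(stream[:k])
        # k odd  -> ((k+1)/2)th smallest; k even -> (k/2)th smallest.
        # Both equal the 1-based rank (k + 1) // 2, i.e. 0-based index (k + 1) // 2 - 1.
        total += prefix[(k + 1) // 2 - 1]
    return total % modulus


print(brute_force_median_sum([3, 1, 2]))     # m = 3, 1, 2    -> 6
print(brute_force_median_sum([1, 2, 3, 4]))  # m = 1, 1, 2, 2 -> 6
```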

OPTIONAL EXERCISE: Compare the performance achieved by heap-based and search-tree-based implementations of the algorithm.
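One factor worth keeping in mind for this comparison: the search tree implemented further down in this notebook is a plain BST with no rebalancing, so the cost of each step depends on the tree's height. The sketch below (the helper `bst_height` is a hypothetical name) measures that height under the same insertion rule (ties go left). On a random permutation the height stays roughly logarithmic, but on already-sorted input the tree degenerates into a path, so each operation becomes linear; the notebook's recursive `insert_node` would additionally exceed CPython's default recursion limit of 1000 on such input. Keys are assumed distinct, as in the assignment's file.

```python
import random


def bst_height(keys):
    """Height (nodes on the longest root-to-leaf path) of an unbalanced BST
    built by inserting `keys` in order, with ties going left.

    Children are kept in dicts keyed by node value, so keys are assumed distinct.
    The walk is iterative to avoid Python's recursion limit on degenerate trees.
    """
    left, right = {}, {}
    root = None
    height = 0
    for key in keys:
        if root is None:
            root, height = key, 1
            continue
        node, depth = root, 1
        while True:
            depth += 1  # depth the new key would have one level below `node`
            children = left if key <= node else right
            if node in children:
                node = children[node]
            else:
                children[node] = key
                break
        height = max(height, depth)
    return height


sorted_keys = list(range(1000))
shuffled_keys = sorted_keys[:]
random.Random(0).shuffle(shuffled_keys)

print(bst_height(sorted_keys))    # 1000: every key hangs off the previous one
print(bst_height(shuffled_keys))  # much smaller, roughly logarithmic in 1000
```

A self-balancing tree (red-black, AVL) would bound the height at O(log n) regardless of input order, which is what the lecture's logarithmic per-step budget assumes.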

So let's move on to a less obvious application of heaps, a problem I'm going to call median maintenance. The way this is going to work is that you and I are going to play a little game. On my side, I'm going to pass you index cards, one at a time, with a number written on each card. Your responsibility is to tell me, at each time step, the median of the numbers I've passed you so far. So after I've given you the first eleven numbers, you should tell me as quickly as possible the sixth smallest; after I've given you thirteen numbers, you should tell me the seventh smallest; and so on. Now, we know how to compute the median in linear time, but the last thing I want is for you to do a linear-time computation at every single time step. [inaudible] I only give you one new number? Do you really have to do linear work just to recompute the median if I just gave you one new number? So to make sure you don't run a linear-time selection algorithm every time I give you a new number, I'm going to put a budget on the amount of time you can use at each time step to tell me the median, and it's going to be logarithmic in the number of numbers I've passed you so far. I encourage you to pause the video at this point and spend some time thinking about how you would solve this problem. Alright, hopefully you've thought about this problem a little bit, so let me give you a hint: what if you use two heaps? Do you see a good way to solve the problem then? Alright, let me show you a solution that makes use of two heaps. The first heap we'll call H low, and it supports extract max. Remember, we discussed that with a heap you can pick whether it supports extract min or extract max. You don't get both, but you can get either one; it doesn't matter. Then we'll have another heap, H high, which supports extract min.
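In Python, the standard library's `heapq` module only offers a min-heap, so one common way to get the extract-max heap H low is to store negated keys; the notebook's own `median_maintenance_heap` applies the same trick to its `heap_min` class. A minimal sketch with `heapq` (the sample numbers are arbitrary):

```python
import heapq

# heapq provides only a min-heap. To give H_low "extract max", store negated keys:
# the smallest negated key is the largest original key.
h_low = []   # H_low: max-heap via negation (holds -x)
h_high = []  # H_high: ordinary min-heap

for x in [5, 1, 9, 3]:
    heapq.heappush(h_low, -x)
for x in [12, 10, 15]:
    heapq.heappush(h_high, x)

max_low = -h_low[0]              # peek at max(H_low)
min_high = h_high[0]             # peek at min(H_high)
largest = -heapq.heappop(h_low)  # extract max from H_low

print(max_low, min_high, largest)  # 9 10 9
print(-h_low[0])                   # 5, the new max(H_low)
```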
And the key idea is to maintain the invariant that the smallest half of the numbers you've seen so far are all in the low heap, and the largest half of the numbers you've seen so far are all in the high heap. So, for example, after you've seen the first ten elements, the smallest five of them should reside in H low and the biggest five in H high. After you've seen twenty elements, the smallest ten should reside in H low and the largest ten in H high. If you've seen an odd number, either heap can be the bigger one; it doesn't matter. So if you have 21, you have the smallest ten in one and the biggest eleven in the other, or vice versa; it's not important. Now, given this key idea of splitting the elements in half between the two heaps, you need two realizations, which I'll leave for you to check. First, you have to prove that you can actually maintain this invariant with only O(log i) work in step i. Second, you have to realize that this invariant allows you to solve the desired problem. Let me quickly talk through both of these points, and then you can think about them in more detail on your own time. Let's start with the first one: how can we maintain this invariant using only O(log i) work in time step i? This is a little tricky. Suppose we've already processed the first twenty numbers, and we've worked hard to put only the smallest ten of them in H low and only the biggest ten in H high. Now, here's a preliminary observation. What do we know about the maximum element in H low? Well, these are the ten smallest overall, and the maximum is the biggest of the ten smallest, so it is the tenth order statistic overall. Now what about the high heap? What's its minimum value? Well, those are the biggest ten values.
So the minimum of the ten biggest values is the eleventh order statistic. Okay, so the maximum of H low is the tenth order statistic and the minimum of H high is the eleventh order statistic; they're right next to each other, and these are in fact the two medians right now. So when the new, twenty-first element comes in, we need to know which heap to insert it into. If it's smaller than the tenth order statistic, then it's in the bottom half of the elements and needs to go in the low heap. If it's bigger than the eleventh order statistic, that is, bigger than the minimum value of the high heap, then that's where it belongs: in the high heap. If it's wedged in between the tenth and eleventh order statistics, it doesn't matter; we can put it in either one, since it is the new median anyway. Now, we're not done yet with this first point, because there's a problem with potential imbalance. Imagine the twenty-first element comes up and it's less than the maximum of the low heap, so we stick it in the low heap, which now has a population of eleven. Now imagine the twenty-second number comes up and it, again, is less than the maximum element in the low heap, so again we have to insert it in the low heap. Now we have twelve elements in the low heap but only ten in the high heap, so we no longer have a 50/50 split of the numbers. But we can easily rebalance: we just extract the max from the low heap and insert it into the high heap. And boom: now both have eleven, the low heap has the smallest eleven, and the high heap has the biggest eleven. So that's how you maintain the invariant of a 50/50 split between the two heaps in terms of small and large numbers. You check where the new number lies with respect to the max of the low heap and the min of the high heap, and you put it in the appropriate place.
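The routing and rebalancing steps just described can be sketched with `heapq` lists, again storing negated keys in H low. This is an illustration under those assumptions, not the notebook's `heap_min`-based code; `add_number` and `invariant_holds` are hypothetical names. Each insertion does at most two pushes and one pop, i.e. a constant number of O(log i) heap operations. The demo replays the lecture's scenario with the numbers 1 to 20 standing in for the first twenty numbers.

```python
import heapq


def invariant_holds(h_low, h_high):
    """Sizes differ by at most one and max(H_low) <= min(H_high); h_low holds negated keys."""
    if abs(len(h_low) - len(h_high)) > 1:
        return False
    return not (h_low and h_high and -h_low[0] > h_high[0])


def add_number(h_low, h_high, x):
    """Insert x, keeping the smaller half in h_low (negated keys, i.e. a max-heap)
    and the larger half in h_high (a min-heap), with sizes differing by at most one."""
    # Route: x goes low unless it exceeds max(H_low). A value wedged between
    # max(H_low) and min(H_high) may go to either heap; here it goes high.
    if h_low and x > -h_low[0]:
        heapq.heappush(h_high, x)
    else:
        heapq.heappush(h_low, -x)
    # Rebalance: if one heap is now two ahead, move its top to the other heap.
    if len(h_low) > len(h_high) + 1:
        heapq.heappush(h_high, -heapq.heappop(h_low))
    elif len(h_high) > len(h_low) + 1:
        heapq.heappush(h_low, -heapq.heappop(h_high))


# 1..20 already split 10/10.
h_low = [-x for x in range(1, 11)]
h_high = list(range(11, 21))
heapq.heapify(h_low)
heapq.heapify(h_high)
print(-h_low[0], h_high[0])     # 10 11: the 10th and 11th order statistics

add_number(h_low, h_high, 0)    # 21st number: below max(H_low), so H_low grows to 11
add_number(h_low, h_high, -1)   # 22nd number: H_low would hold 12, so its max (10) moves up
print(len(h_low), len(h_high))  # 11 11
print(-h_low[0], h_high[0])     # 9 10
print(invariant_holds(h_low, h_high))  # True
```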
And whenever you need to rebalance, you rebalance. This uses only a constant number of heap operations when a new number shows up, so that's O(log i) work. Given this discussion, the second point is easy to see, assuming the invariant holds at each time step: how do we compute the median? It's going to be the maximum of the low heap and/or the minimum of the high heap, depending on whether i is even or odd. If i is even, both of those are medians. If i is odd, then it's just the top of whichever heap has one more element than the other.
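Reading the median off the two heap tops takes O(1) time. The assignment's definition settles the even case that the lecture leaves open ("both of those are medians") by taking the (k/2)th smallest, which is max(H low). A sketch of that rule with `heapq` lists holding negated keys in H low; `current_median` and the demo helper `split` are hypothetical names, and `split` only exists to build valid heaps for the demo.

```python
import heapq


def current_median(h_low, h_high):
    """m_k as defined in this assignment, read from the heap tops in O(1).

    h_low stores negated keys. If both heaps have the same size (k even), m_k is the
    (k/2)th smallest, i.e. max(H_low); otherwise (k odd) it is the top of
    whichever heap holds the extra element.
    """
    if len(h_low) >= len(h_high):
        return -h_low[0]
    return h_high[0]


def split(values, cut):
    """Hypothetical demo helper: put the `cut` smallest values in H_low, the rest in H_high."""
    ordered = sorted(values)
    h_low = [-x for x in ordered[:cut]]
    h_high = ordered[cut:]
    heapq.heapify(h_low)
    heapq.heapify(h_high)
    return h_low, h_high


print(current_median(*split(range(1, 7), 3)))  # k = 6: 3rd smallest -> 3
print(current_median(*split(range(1, 8), 4)))  # k = 7, extra in H_low  -> 4
print(current_median(*split(range(1, 8), 3)))  # k = 7, extra in H_high -> 4
```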

In [2]:
class heap_min():
    """Binary min-heap stored in a Python list: the children of index i sit at 2i+1 and 2i+2."""
    
    def __init__(self):
        self.heap = []
    
    @staticmethod
    def _parent(i):
        return (i-1)//2
    
    @staticmethod
    def _left_child(i):
        return i*2+1
    
    @staticmethod 
    def _right_child(i):
        return 2*(i+1)
    
    def clear(self):
        del self.heap[:]
        
    def __len__(self):
        return len(self.heap)
    
    def _swap(self,i,j):
        self.heap[i], self.heap[j] = self.heap[j],self.heap[i]
    
    def insert(self,value):
        """insert key"""
        self.heap.append(value) #Stick the key at the end of the last level
        self._bubble_up(len(self.heap)-1) #Bubble up until the heap property is restored
        
    def _bubble_up(self, i): 
        """Sift the key at index i up until its parent is no larger (used by insert)."""
        while i: #loop until the root
            parent = self._parent(i)
            if self.heap[parent] <= self.heap[i]: #parent is already no larger than the child
                break   #bubbling up process must stop, heap property is restored
            self._swap(i,parent)
            i = parent #otherwise continue one level up
            
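    def pop(self,value):
        """Delete one occurrence of value (O(n) because of the linear list.index search;
        raises ValueError if value is absent)."""
        index = self.heap.index(value)
        # Swap the target up to the root; each displaced ancestor moves down one level
        # along the path, which keeps the rest of the heap valid.
        while index:
            parent_position = self._parent(index)
            self._swap(index,parent_position)
            index = parent_position
        self.extract_min() #the target is now the root, so a normal extract-min removes it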
    
    def __iter__(self):
        return iter(self.heap)
        
    def _min_heapify(self,i):
        """bubble down after extract min so that the heap property is restored"""
        l = self._left_child(i)
        r = self._right_child(i)
        n = len(self.heap)
        if l < n and self.heap[l] < self.heap[i]:
            low = l
        else:
            low = i #heap property is restored if the parent is no larger than both of its children
        if r < n and self.heap[r] < self.heap[low]:
            low = r
        
        if low != i: #if it is not, keep bubbling down
            self._swap(i,low)
            self._min_heapify(low)
            
    def extract_min(self):
        """Remove and return the smallest key; raises IndexError if the heap is empty."""
        root = self.heap[0]
        if len(self.heap) == 1:
            self.heap.pop()
        else:
            self.heap[0] = self.heap.pop(-1) #Delete root and move last leaf to be new root
            self._min_heapify(0)  #Recursively bubble down until heap property has been restored
        return root
    
    def peek_min(self):
        """Return the smallest key without removing it; raises IndexError if the heap is empty."""
        return self.heap[0]
        

In [3]:
class Node:
    def __init__(self,val):
        self.val = val
        self.left_child = None
        self.right_child = None
        self.parent = None
    
    def get_parent(self):
        if self.parent == None:
            parent = None
        else:
            parent = self.parent.val
        return parent
    
    def get_children(self):
        if self.left_child == None:
            left_child = None
        else:
            left_child = self.left_child.val
        if self.right_child == None:
            right_child = None
        else:
            right_child = self.right_child.val   
        return [left_child,right_child]
    
class BST:
    def __init__(self):
        self.root = None
        self.counter = 0
        
    def __len__(self):
        return self.counter
    
    def insert(self, val):
        self.counter +=1
        if(self.root is None):
            self.root = Node(val)
        else:
            self.insert_node(self.root, val)

    def insert_node(self, current_node, val):
        if(val <= current_node.val):
            if(current_node.left_child):
                self.insert_node(current_node.left_child, val)
            else: #if None pointers
                current_node.left_child = Node(val)
                current_node.left_child.parent = current_node
        elif(val > current_node.val):
            if(current_node.right_child):
                self.insert_node(current_node.right_child, val)
            else:
                current_node.right_child = Node(val)
                current_node.right_child.parent = current_node

    def find(self, val):
        return self.find_node(self.root, val)

    def find_node(self, current_node, val):
        if(current_node is None):
            return False
        elif(val == current_node.val):
            return True
        elif(val < current_node.val):
            return self.find_node(current_node.left_child, val)
        else:
            return self.find_node(current_node.right_child, val)
        
    def find_min(self):
        current_node = self.root
        while current_node != None:
            previous_node = current_node
            current_node = current_node.left_child
        return previous_node.val
    
    def find_max(self):
        current_node = self.root
        while current_node != None:
            previous_node = current_node
            current_node = current_node.right_child
        return previous_node.val
                
    def _replace_in_parent(self, node, new_child):
        """Make node's parent (or self.root, if node is the root) point to new_child."""
        if node.parent is None:
            self.root = new_child
        elif node.parent.left_child is node:
            node.parent.left_child = new_child
        else:
            node.parent.right_child = new_child
        if new_child is not None:
            new_child.parent = node.parent

    def delete_node(self,node):
        """Remove node from the tree, keeping the BST ordering intact."""
        if node.left_child is not None and node.right_child is not None:
            # Two children: copy the in-order predecessor (the maximum of the left
            # subtree) into node, then delete the predecessor instead; it has no
            # right child, so it falls into the simple case below.
            predecessor = node.left_child
            while predecessor.right_child is not None:
                predecessor = predecessor.right_child
            node.val = predecessor.val
            node = predecessor
        # Zero or one child: splice node out by linking its child (or None) to its parent.
        child = node.left_child if node.left_child is not None else node.right_child
        self._replace_in_parent(node, child)
        self.counter -= 1
    
    def in_order(self):
        """Print the stored values in ascending order."""
        self.in_order_traversal(self.root)
    
    def in_order_traversal(self,node):
        if node is None: #empty tree or missing child
            return
        self.in_order_traversal(node.left_child)
        print(node.val)
        self.in_order_traversal(node.right_child)
    
    
    def extract_min(self):
        current_node = self.root
        while current_node != None:
            previous_node = current_node
            current_node = current_node.left_child
        min_val = previous_node.val
        self.delete_node(previous_node)
        return min_val

    def extract_max(self):
        current_node = self.root
        while current_node != None:
            previous_node = current_node
            current_node = current_node.right_child
        max_val = previous_node.val
        self.delete_node(previous_node)
        return max_val

In [4]:
B = BST()
for i in range(10):
    B.insert(i)
#B.insert(-1)
print('min:',B.extract_min())
B.in_order()
print('root:',B.root.val)
print('---')
print('max:',B.extract_max())
B.in_order()
print('length:',len(B))

min: 0
1
2
3
4
5
6
7
8
9
root: 1
---
max: 9
1
2
3
4
5
6
7
8
length: 8


In [24]:
def median_maintenance_heap(filename,debug=0):
    """Return (m_1 + ... + m_n) mod 10000 for the numbers in week7_file/<filename>."""
    numbers = np.loadtxt('week7_file/'+filename)
    # heap_min only supports extract-min, so H_low stores negated keys:
    # the minimum of the negated keys is the maximum of the originals.
    h_low = heap_min()   #smaller half, used as a max-heap via negation
    h_high = heap_min()  #larger half, an ordinary min-heap
    
    h_low.insert(-numbers[0])
    medians = [0]*len(numbers)
    medians[0] = numbers[0]
    
    for counter, i in enumerate(numbers[1:], start=1):
        # 1. Route the new number to the half it belongs to.
        if i > -h_low.peek_min():
            h_high.insert(i)
        else:
            h_low.insert(-i)

        # 2. Rebalance so the two sizes differ by at most one.
        if len(h_high) - len(h_low) > 1:
            h_low.insert(-h_high.extract_min())
        elif len(h_low) - len(h_high) > 1:
            h_high.insert(-h_low.extract_min())

        # 3. k even: m_k is the (k/2)th smallest, i.e. max(H_low).
        #    k odd: m_k is the top of whichever heap holds the extra element.
        if len(h_low) >= len(h_high):
            medians[counter] = -h_low.peek_min()
        else:
            medians[counter] = h_high.peek_min()
            
        if debug:
            print('Turn:',counter)
            print('Numbers:',i)
            print('H_low:')
            for v in h_low:
                print(-v)
            print('H_high:')
            for v in h_high:
                print(v)
            print('Medians:',medians)
            print()
    return sum(medians)%10000

In [21]:
def median_maintenance_BST(filename,debug=0):
    """Same algorithm as median_maintenance_heap, but each half is kept in an
    (unbalanced) BST, which offers find_min/find_max and extract_min/extract_max."""
    numbers = np.loadtxt('week7_file/'+filename)
    h_low = BST()   #smaller half; we use find_max / extract_max
    h_high = BST()  #larger half; we use find_min / extract_min
    
    h_low.insert(numbers[0])
    medians = [0]*len(numbers)
    medians[0] = numbers[0]
    
    for counter, i in enumerate(numbers[1:], start=1):
        # 1. Route the new number to the half it belongs to.
        if i > h_low.find_max():
            h_high.insert(i)
        else:
            h_low.insert(i)

        # 2. Rebalance so the two sizes differ by at most one.
        if len(h_high) - len(h_low) > 1:
            h_low.insert(h_high.extract_min())
        elif len(h_low) - len(h_high) > 1:
            h_high.insert(h_low.extract_max())

        # 3. Read off m_k exactly as in the heap version.
        if len(h_low) >= len(h_high):
            medians[counter] = h_low.find_max()
        else:
            medians[counter] = h_high.find_min()
            
        if debug:
            print('root h - low',h_low.root.val)
            print('root h - high',h_high.root.val)
            print('Turn:',counter)
            print('Numbers:',i)
            print('H_low:')
            h_low.in_order()
            print('len h_low',len(h_low))
            print('H_high:')
            h_high.in_order()
            print('len h_high',len(h_high))
            print('Medians:',medians)
            print()
    return sum(medians)%10000

In [22]:
start_time = time.time()
print(median_maintenance_heap('week7_test1.txt'))
print("--- %s seconds ---" % (time.time() - start_time))
start_time = time.time()
print(median_maintenance_BST('week7_test1.txt'))
print("--- %s seconds ---" % (time.time() - start_time))

142.0
--- 0.009569168090820312 seconds ---
142.0
--- 0.0010700225830078125 seconds ---


In [23]:
start_time = time.time()
print(median_maintenance_heap('week7_test2.txt'))
print("--- %s seconds ---" % (time.time() - start_time))
start_time = time.time()
print(median_maintenance_BST('week7_test2.txt'))
print("--- %s seconds ---" % (time.time() - start_time))

9335.0
--- 0.0027070045471191406 seconds ---
9335.0
--- 0.0011048316955566406 seconds ---


In [19]:
start_time = time.time()
print(median_maintenance_heap('week7_test3.txt'))
print("--- %s seconds ---" % (time.time() - start_time))
start_time = time.time()
print(median_maintenance_BST('week7_test3.txt'))
print("--- %s seconds ---" % (time.time() - start_time))

5174.0
--- 0.011176824569702148 seconds ---
5174.0
--- 0.002808094024658203 seconds ---


In [20]:
start_time = time.time()
print(median_maintenance_heap('week7.txt'))
print("--- %s seconds ---" % (time.time() - start_time))
start_time = time.time()
print(median_maintenance_BST('week7.txt'))
print("--- %s seconds ---" % (time.time() - start_time))

1213.0
--- 0.19062519073486328 seconds ---
1213.0
--- 0.7716951370239258 seconds ---
