In [None]:
__author__ = "Nagarajan"

## LRU Cache

Requirements:
 - cache should have size n
 - should keep the most recently accessed n items in the cache
 - when new item is accessed, the last item in cache is removed and new item is inserted.
 - cache can be increased or decreased at runtime (optional)
 


In [None]:
%%prun

class LRUCache(object):
    """
    Implementation of an LRU cache using a doubly linked list
    and dict.
    It has O(1) lookups and is blazingly fast (for python).
    """
    def __init__(self, size):
        """
        Setup the size and datastore.
        Instantiate the doubly linked list and nodeMap
        """
        self.maxSize = size
        self.dlist = DoublyLinkedList()
        self.keyToNodeMap = {}
        
    def get(self, key):
        """
        Return the data if its in the cache.
        Else, get it from the datastore, store it in the cache and 
        return the data.
        """
        if key in self.keyToNodeMap:
            # hit
            node = self.keyToNodeMap[key]
            self.dlist.moveToFront(node)
        else:
            raise KeyError("Key %s not in cache" % key)
            
    def add(self, key, value):
        if self.getCurrSize() >= self.maxSize:
            oldnode = self.dlist.removeLast()
            del self.keyToNodeMap[oldnode.key]
        node = Node(key, value)
        self.dlist.insertAtFront(node)
        self.keyToNodeMap[key] = node
        
    def remove(self, key):
        if key in self.keyToNodeMap:
            node = self.keyToNodeMap[key]
            self.dlist.remove(node)
            del self.keyToNodeMap[key]
        else:
            raise KeyError("Key %s not in cache" % key)
        
    
    def set(self, key, value):
        self.db.set(key, value)
        node = self.keyToNodeMap.get(key, None)
        if node is not None:
            node.value = value
            
    def getCurrSize(self):
        return len(self.keyToNodeMap)
    
    def updateSize(self, size):
        currSize = self.getCurrSize()
        if size < currSize:
            diff = currSize - size
            for x in xrange(diff):
                node = self.dlist.removeLast()
                del self.keyToNodeMap[node.key]
        self.maxSize = size
        
    def getKeys(self):
        return self.keyToNodeMap.keys()


class Node(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = self.right = None
        

class DoublyLinkedList(object):
    def __init__(self):
        self.head = self.tail = None
        
    def removeLast(self):
        if self.tail is None:
            raise Exception("Cannot removeLast from empty list")
        if self.tail == self.head:
            last = self.tail
            self.tail = self.head = None
        else:
            last = self.tail
            self.tail = last.left
            self.tail.right = None
        return last
    
    def insertAtFront(self, node):
        if self.head is None:
            self.head = self.tail = node
        else:
            node.right = self.head
            node.left = None
            self.head.left = node
            self.head = node
        return node
    
    def moveToFront(self, node):
        if node is self.head:
            return node
        if node is self.tail:
            last = node.left
            last.right = None
            self.tail = last
        else:
            left, right = node.left, node.right
            left.right, right.left = right, left
            
        return self.insertAtFront(node)


import unittest    
import mock
import sys

class DBMock(object):
    get = mock.Mock(return_value="get mock")
    set = mock.Mock(return_value="set mock")

class TestLRUCache(unittest.TestCase):
    def setUp(self):
        self.db = DBMock()
        self.cache = LRUCache(4, self.db)

        
    def test_cacheMissesHitDisk(self):
        result = self.cache.get('key1')
        self.assertEqual(result, "get mock")
        
        callArgs = self.db.get.call_args
        self.assertEqual(callArgs, mock.call('key1'))
        
    def test_dataSetHitsDisk(self):
        self.cache.set('key2', 'value2')
        self.assertEqual(self.cache.getCurrSize(), 0)
        callArgs = self.db.set.call_args
        self.assertEqual(callArgs, mock.call('key2', 'value2'))
        
    def test_cache4Items(self):
        keys = set()
        for x in xrange(4):
            key = 'key%s' % x
            keys.add(key)
            self.cache.get(key)
            
        self.assertEqual(self.cache.getCurrSize(), 4)
        self.assertEqual(set(self.cache.getKeys()), keys)
        
    def test_cacheModifications(self):
        keys = set()
        for x in xrange(20):
            key = 'key%s' % x
            keys.add(key)
            self.cache.get(key)
            
        self.assertEqual(self.cache.getCurrSize(), 4)
        self.assertEqual(set(self.cache.getKeys()), set(['key16', 'key17', 'key18', 'key19']))
        
        self.cache.get('keya')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'key17', 'key18', 'key19']))

        self.cache.get('keyb')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'keyb', 'key18', 'key19']))

        self.cache.get('keyc')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'keyb', 'keyc', 'key19']))

        self.cache.get('keya')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'keyb', 'keyc', 'key19']))

        self.cache.get('keyd')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'keyb', 'keyc', 'keyd']))

        # keyb gets kicked out
        self.cache.get('keye')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'keye', 'keyc', 'keyd']))

        # keyc gets kicked out
        self.cache.get('keyf')
        self.assertEqual(set(self.cache.getKeys()), set(['keya', 'keye', 'keyf', 'keyd']))
        
        # keya gets kicked out
        self.cache.get('keyg')
        self.assertEqual(set(self.cache.getKeys()), set(['keyg', 'keye', 'keyf', 'keyd']))
        
        
        
suite = unittest.TestLoader().loadTestsFromTestCase( TestLRUCache )
result = unittest.TextTestRunner(verbosity=1,stream=sys.stderr).run( suite )
            
        

In [2]:
%%prun

def fib(n):
    if n <= 2:
        return 1
    return fib(n-1) + fib(n-2)

result = [fib(x) for x in xrange(28)]
print result


[1, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418]
 