### LRU Cache

We saw in the lecture what an LRU cache is.

Although Python provides a decorator that applies an LRU caching mechanism to a function (caching a function's results like this is known as **memoization**), we are going to try writing one ourselves first, as another excellent example of decorators.

We are not going to worry about cache size - for simplicity we'll allow our cache to always grow unbounded.

Another simplification we are going to make is that we will only handle positional arguments: the functions returned by our decorator will not accept keyword arguments at all.

Python's `lru_cache` decorator has neither of these two simplifications.
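Lifting the keyword-argument simplification mostly comes down to building a cache key from the keyword arguments too. Here is a minimal sketch of such a key function; `make_key` and `_KWD_MARK` are hypothetical names. Sorting the keyword arguments is an assumption of this sketch (the `functools` documentation notes that `lru_cache` may treat `f(a=1, b=2)` and `f(b=2, a=1)` as distinct calls with separate cache entries).

```python
# Hypothetical sentinel separating positional arguments from keyword arguments,
# so a positional tuple like ('b', 2) can never collide with the keyword b=2.
_KWD_MARK = (object(),)

def make_key(args, kwargs):
    """Hypothetical helper: build a hashable cache key from args and kwargs."""
    if not kwargs:
        return args  # positional-only calls keep the plain args tuple
    # Sort by name so keyword order does not matter (an assumption of this sketch)
    return args + _KWD_MARK + tuple(sorted(kwargs.items()))

k1 = make_key((3,), {'exp': 3})
k2 = make_key((), {'exp': 3, 'base': 3})
k3 = make_key((), {'base': 3, 'exp': 3})

print(k2 == k3)  # True: keyword order does not matter
print(k1 == k2)  # False: positional and keyword calls get different keys
print(make_key((1, ('b', 2)), {}) == make_key((1,), {'b': 2}))  # False, thanks to the marker
```

Note that this key does not inspect the function's signature, so `f(3, exp=3)` and `f(base=3, exp=3)` would still be cached separately even if they mean the same thing.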

```python
def cache(func):
    def inner(*args):
        result = func(*args)
        return result
    return inner
```

This is our standard pattern for creating a decorator (albeit only considering positional arguments).

We want to create a cache where we can store/recall the result of calling `func(*args)`.

The first thing we'll want to do is create a dictionary, but we don't want to create that dictionary inside `inner`, because then every call to `inner` (the decorated function) would start with an empty cache dictionary.

Instead, we're going to define the cache in the outer `cache` function, and access it as a free variable in `inner`.

So a new dictionary (also named `cache`, which shadows the decorator's own name inside its body) is created every time the `cache` decorator is called, i.e. once per decorated function, but the returned `inner` function reuses that same dictionary on every call.

Next, we need a key that represents all the arguments passed to `inner` (which in turn passes them on to `func`), and that's simply the tuple `args`.

This does mean that `args` must be hashable, which in turn requires every argument to be hashable (a tuple is only hashable if all of its elements are), just like with Python's own `lru_cache`.
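A quick way to see the hashability requirement is to try using such tuples as dictionary keys directly. In this sketch the names `ok_key`, `bad_key` and `lookup` are just for illustration:

```python
# A tuple is hashable only if all of its elements are hashable.
ok_key = (1, 'a', (2, 3))   # ints, strings and tuples of ints: hashable
bad_key = (1, [2, 3])       # contains a list: not hashable

lookup = {}
lookup[ok_key] = 'stored'

error = None
try:
    lookup[bad_key] = 'never stored'
except TypeError as exc:    # hashing the tuple fails on the list inside it
    error = exc

print(lookup)  # {(1, 'a', (2, 3)): 'stored'}
print(error)   # a TypeError message mentioning the unhashable list
```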

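```python
def cache(func):
    print('initialize cache')
    cache = {}  # one dictionary per decorated function, a free variable of inner
    def inner(*args):
        key = args  # the tuple of positional arguments is the cache key
        if key in cache:
            print('Cache hit')
            return cache[key]
        else:
            result = func(*args)
            cache[key] = result
            return result
    return inner
```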

And now we can start using it:

```python
@cache
def my_func(a, b):
    print(f'evaluating my_func({a}, {b})...')
    return a + b
```

```
initialize cache
```


```python
my_func(1, 2)
```

```
evaluating my_func(1, 2)...
3
```

As you can see, the original `my_func` was called.

But if we call it with the same arguments a second time:

```python
my_func(1, 2)
```

```
Cache hit
3
```

As you can see, we get the result back without the function actually executing: the result was obtained from the cache.

Now we can decorate another function with the same decorator, and it will have its own cache:

```python
@cache
def add(a, b):
    return a + b
```

```
initialize cache
```

```python
add(1, 2)
```

```
3
```

```python
add(1, 2)
```

```
Cache hit
3
```
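Each decorated function really does get its own dictionary, and one way to check is to look inside the closures. This sketch uses a quiet copy of our decorator (the `print` calls removed) and a hypothetical helper `cached_dict`, which reads the `cache` cell through the standard `__code__.co_freevars` and `__closure__` function attributes:

```python
def cache(func):
    cache = {}  # created once per decorated function
    def inner(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return inner

@cache
def my_func(a, b):
    return a + b

@cache
def add(a, b):
    return a + b

my_func(1, 2)
add(1, 2)
add(3, 4)

def cached_dict(decorated):
    """Hypothetical helper: return the 'cache' dict captured in a closure."""
    names = decorated.__code__.co_freevars  # names of inner's free variables
    return decorated.__closure__[names.index('cache')].cell_contents

print(cached_dict(my_func))                      # {(1, 2): 3}
print(cached_dict(add))                          # {(1, 2): 3, (3, 4): 7}
print(cached_dict(my_func) is cached_dict(add))  # False
```

Because `cache(func)` runs once per `@cache` line, `my_func` and `add` each close over a different dictionary.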

That's the basic idea behind the `lru_cache` decorator.

Let's use that one instead:

```python
from functools import lru_cache
```

```python
@lru_cache(maxsize=2)
def add(a, b):
    print(f'Calling add({a}, {b})...')
    return a + b
```

```python
add(1, 1)
```

```
Calling add(1, 1)...
2
```

```python
add(1, 1)
```

```
2
```

```python
add(2, 2)
```

```
Calling add(2, 2)...
4
```

```python
add(1, 1)
```

```
2
```

```python
add(2, 2)
```

```
4
```

But we set the cache size to `2`, which means that if we now call `add` with a new set of arguments:

```python
add(3, 3)
```

```
Calling add(3, 3)...
6
```

Not only was the function evaluated (these arguments were not in the cache), but the cache also evicted its least recently used entry, `(1, 1)`. `(2, 2)` was used more recently, so it is still there.

```python
add(2, 2)
```

```
4
```

```python
add(1, 1)
```

```
Calling add(1, 1)...
2
```
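To see the eviction rule in isolation, here is a minimal sketch of a bounded LRU decorator built on `collections.OrderedDict`, whose `move_to_end` and `popitem(last=False)` methods make recency tracking easy. `bounded_cache` and `add_lru` are hypothetical names, the sketch keeps our positional-arguments-only simplification, and it is only an illustration, not how `functools.lru_cache` is implemented:

```python
from collections import OrderedDict
from functools import wraps

def bounded_cache(maxsize):
    """Hypothetical LRU decorator: positional arguments only, like our cache."""
    def decorator(func):
        cache = OrderedDict()  # iteration order = recency order, least recent first

        @wraps(func)
        def inner(*args):
            if args in cache:
                cache.move_to_end(args)  # a hit makes this entry the most recent
                return cache[args]
            result = func(*args)
            cache[args] = result  # new entries go to the most recent end
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict the least recently used entry
            return result

        inner.cache = cache  # exposed only so we can inspect it
        return inner
    return decorator

@bounded_cache(maxsize=2)
def add_lru(a, b):
    print(f'Calling add_lru({a}, {b})...')
    return a + b

# Same call sequence as the lru_cache example above, up to add(3, 3)
for args in [(1, 1), (1, 1), (2, 2), (1, 1), (2, 2), (3, 3)]:
    add_lru(*args)

print(list(add_lru.cache))  # [(2, 2), (3, 3)]: (1, 1) was evicted
```

As with `lru_cache(maxsize=2)`, only the calls with `(1, 1)`, `(2, 2)` and `(3, 3)` that miss the cache print a `Calling ...` line.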

When you have a function that takes a long time to run, and you often call it with the same arguments, consider using an LRU cache: it can greatly speed up your code.

Let's take a look at a very simple example to calculate the Fibonacci numbers:

```
0, 1, 1, 2, 3, 5, 8, 13, 21, ...
```

This sequence starts with `0` and `1`, and every element thereafter is the sum of the previous two, so there is a recursive relationship, and we could define a mathematical function to produce the `n`th number (indexing from `0`) this way:

```
Fib(0) = 0
Fib(1) = 1
Fib(n) = Fib(n-1) + Fib(n-2), n > 1
```

We can express this same recursive definition using a Python function:

```python
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
```

This function may seem odd as it is calling itself - this is known as a recursive function.

The way it works is that if we call `fib(0)` or `fib(1)` it just returns `0` and `1` respectively.

If we call `fib(2)` it will call `fib(1)` and `fib(0)`, which return `1` and `0`.

If we call `fib(3)` it will call `fib(2)` and `fib(1)`: `fib(1)` just returns `1`, but `fib(2)` calls `fib(1)` and `fib(0)`.

If we call `fib(4)` it will call `fib(2)` and `fib(3)`, and the processing path for those two calls will follow what we just saw.

So recursion is a very simple way to implement certain algorithms (like Fibonacci, factorials, etc.).

But recursive implementations can often become computationally intensive.

Let's put a print statement to indicate when the `fib` function gets called, and see what happens:

```python
def fib(n):
    print(f'fib({n}) called...')
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
```

```python
fib(0)
```

```
fib(0) called...
0
```

```python
fib(1)
```

```
fib(1) called...
1
```

```python
fib(2)
```

```
fib(2) called...
fib(1) called...
fib(0) called...
1
```

```python
fib(3)
```

```
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
2
```

```python
fib(4)
```

```
fib(4) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
fib(2) called...
fib(1) called...
fib(0) called...
3
```

```python
fib(5)
```

```
fib(5) called...
fib(4) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
5
```
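Rather than counting printed lines, we can count the calls directly. This sketch wraps the same naive recursive `fib` inside a hypothetical helper `count_fib_calls`, using a `nonlocal` counter:

```python
def count_fib_calls(n):
    """Hypothetical helper: return (fib(n), total number of calls to the naive fib)."""
    calls = 0

    def fib(k):
        nonlocal calls
        calls += 1  # count every call, including the top-level one
        if k <= 1:
            return k
        return fib(k - 1) + fib(k - 2)

    result = fib(n)
    return result, calls

for n in (5, 10, 15, 20, 25):
    result, calls = count_fib_calls(n)
    print(f'fib({n}) = {result} took {calls} calls')
# fib(5) = 5 took 15 calls
# fib(10) = 55 took 177 calls
# fib(15) = 610 took 1973 calls
# fib(20) = 6765 took 21891 calls
# fib(25) = 75025 took 242785 calls
```

These counts equal `2 * fib(n + 1) - 1`, so the number of calls grows exponentially, at the same rate as the Fibonacci numbers themselves.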

And the higher we go, the more function calls are made. In fact, the number of calls grows so fast that the time needed to calculate the `n`th Fibonacci number becomes prohibitive even for relatively small `n`:

```python
from time import perf_counter
```

Let's remove that `print` statement first:

```python
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
```

```python
for n in range(30, 38):
    start = perf_counter()
    result = fib(n)
    end = perf_counter()
    print(f'fib({n})={result}, elapsed: {end - start}')
```

```
fib(30)=832040, elapsed: 0.27679114000000005
fib(31)=1346269, elapsed: 0.41020096800000005
fib(32)=2178309, elapsed: 0.6584909400000001
fib(33)=3524578, elapsed: 1.1819654540000002
fib(34)=5702887, elapsed: 1.7600036860000001
fib(35)=9227465, elapsed: 2.8420717659999992
fib(36)=14930352, elapsed: 4.518468769999999
fib(37)=24157817, elapsed: 7.279664487999998
```


The problem, of course, is that when we call `fib(37)` it calls `fib(36)` and `fib(35)`. In turn, `fib(36)` calls `fib(35)` and `fib(34)`, and `fib(35)` calls `fib(34)` and `fib(33)`, and so on, so we end up calling the same `fib(n)` over and over again.
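To see just how much work is repeated, we can tally how many times each `fib(k)` is evaluated during a single naive call. A sketch using `collections.Counter` (the helper name `fib_call_counts` is hypothetical), shown for `fib(10)` to keep the numbers readable:

```python
from collections import Counter

def fib_call_counts(n):
    """Hypothetical helper: how often each fib(k) is evaluated by a naive fib(n)."""
    counts = Counter()

    def fib(k):
        counts[k] += 1  # record every evaluation of fib(k)
        if k <= 1:
            return k
        return fib(k - 1) + fib(k - 2)

    fib(n)
    return counts

counts = fib_call_counts(10)
for k in sorted(counts, reverse=True):
    print(f'fib({k}) evaluated {counts[k]} times')
```

For `fib(10)`, `fib(2)` alone is evaluated 34 times and `fib(1)` 55 times, and the repetition gets far worse for `fib(37)`.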

What if we could cache the results of calling `fib(n)`? Then, whenever we need a Fibonacci number we have already computed, we could look it up in the cache instead of recalculating it.

And that's precisely what the LRU cache can do for us.

Let's put that `print` statement back and see the sequence of calls made when we call `fib(6)`:

```python
def fib(n):
    print(f'fib({n}) called...')
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
```

```python
fib(6)
```

```
fib(6) called...
fib(5) called...
fib(4) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
fib(4) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
fib(1) called...
fib(2) called...
fib(1) called...
fib(0) called...
8
```

Now let's apply that LRU cache:

```python
@lru_cache  # bare form (no parentheses) needs Python 3.8+; uses the default maxsize=128
def fib(n):
    print(f'fib({n}) called...')
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
```

```python
fib(6)
```

```
fib(6) called...
fib(5) called...
fib(4) called...
fib(3) called...
fib(2) called...
fib(1) called...
fib(0) called...
8
```
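The decorated function also exposes hit and miss statistics through its `cache_info()` method, which confirms that each value is computed only once. A sketch on a freshly decorated `fib` without the `print` (using `maxsize=None` so nothing is ever evicted):

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded, so nothing is ever evicted
def fib(n):
    if n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

print(fib(6))            # 8
print(fib.cache_info())  # CacheInfo(hits=4, misses=7, maxsize=None, currsize=7)
```

The 7 misses are the first evaluations of `fib(0)` through `fib(6)`; the 4 hits are the second-operand calls `fib(1)`, `fib(2)`, `fib(3)` and `fib(4)`, which found their values already cached.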

As you can see, the number of function calls decreased dramatically (each `fib(n)` is now evaluated only once), and we can redo our timings:

```python
@lru_cache(maxsize=None)  # unbounded; a bare @lru_cache would default to maxsize=128
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)

for n in range(30, 38):
    start = perf_counter()
    result = fib(n)
    end = perf_counter()
    print(f'fib({n})={result}, elapsed: {end - start}')
```

```
fib(30)=832040, elapsed: 1.913599999880944e-05
fib(31)=1346269, elapsed: 1.1970000031169548e-06
fib(32)=2178309, elapsed: 7.240000030606097e-07
fib(33)=3524578, elapsed: 6.570000010697186e-07
fib(34)=5702887, elapsed: 6.900000002474371e-07
fib(35)=9227465, elapsed: 6.040000002371926e-07
fib(36)=14930352, elapsed: 6.429999999113534e-07
fib(37)=24157817, elapsed: 5.9699999965801e-07
```


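Much faster. The timings after the first one are even smaller because each iteration reuses the values cached by the previous ones. The speedup also holds for much larger `n`: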

```python
for n in range(100, 110):
    start = perf_counter()
    result = fib(n)
    end = perf_counter()
    print(f'fib({n})={result}, elapsed: {end - start}')
```

```
fib(100)=354224848179261915075, elapsed: 3.7771999998881256e-05
fib(101)=573147844013817084101, elapsed: 9.4600000011269e-07
fib(102)=927372692193078999176, elapsed: 7.40000000831742e-07
fib(103)=1500520536206896083277, elapsed: 9.709999986284856e-07
fib(104)=2427893228399975082453, elapsed: 7.569999986856146e-07
fib(105)=3928413764606871165730, elapsed: 6.95000000661139e-07
fib(106)=6356306993006846248183, elapsed: 6.570000010697186e-07
fib(107)=10284720757613717413913, elapsed: 6.480000003250552e-07
fib(108)=16641027750620563662096, elapsed: 6.639999980961875e-07
fib(109)=26925748508234281076009, elapsed: 6.95000000661139e-07
```


In fact, we only need to cache the three most recently used Fibonacci numbers to get most of that speedup:

```python
@lru_cache(maxsize=3)
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)

for n in range(30, 38):
    start = perf_counter()
    result = fib(n)
    end = perf_counter()
    print(f'fib({n})={result}, elapsed: {end - start}')
```

```
fib(30)=832040, elapsed: 1.1567000001377892e-05
fib(31)=1346269, elapsed: 1.89200000022538e-06
fib(32)=2178309, elapsed: 1.069999999714355e-06
fib(33)=3524578, elapsed: 7.12999998597752e-07
fib(34)=5702887, elapsed: 7.310000000870787e-07
fib(35)=9227465, elapsed: 7.350000004180401e-07
fib(36)=14930352, elapsed: 1.018000002517283e-06
fib(37)=24157817, elapsed: 9.580000011055745e-07
```
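Why three rather than two? `cache_info()` can show the difference on a fresh `fib(20)` call. The factory `make_fib` is a hypothetical helper that builds an independently cached `fib` for each `maxsize`:

```python
from functools import lru_cache

def make_fib(maxsize):
    """Hypothetical factory: a fresh fib with its own LRU cache of the given size."""
    @lru_cache(maxsize=maxsize)
    def fib(n):
        if n <= 1:
            return n
        return fib(n - 1) + fib(n - 2)  # resolves to the cached wrapper
    return fib

infos = {}
for size in (2, 3):
    fib = make_fib(size)
    fib(20)
    infos[size] = fib.cache_info()
    print(f'maxsize={size}: {infos[size]}')
```

With `maxsize=3` the report should show 21 misses, one for each of `fib(0)` through `fib(20)`. With `maxsize=2`, the entry for `fib(n-2)` can already have been evicted by the time `fib(n-1)` returns, so the second call misses and triggers recomputation, and the miss count goes up.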


The speedup is not quite as large as with an unbounded cache, but it is nonetheless perfectly acceptable, given that the cache no longer grows without bound as we calculate larger and larger Fibonacci numbers.

Let's use `timeit` to see this:

```python
from timeit import timeit
```

```python
@lru_cache(maxsize=3)
def fib_3(n):
    if n <= 1:
        return n
    return fib_3(n-1) + fib_3(n-2)
```

```python
@lru_cache(maxsize=None)  # truly unbounded; a bare @lru_cache() defaults to maxsize=128
def fib_unbounded(n):
    if n <= 1:
        return n
    return fib_unbounded(n-1) + fib_unbounded(n-2)
```

```python
timeit(
    '[fib_3(n) for n in range(100, 200)]',
    globals=globals(),
    number=10_000
)
```

```
0.678542482000001
```

```python
timeit(
    '[fib_unbounded(n) for n in range(100, 200)]',
    globals=globals(),
    number=10_000
)
```

```
0.07769970500000056
```

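So the unbounded cache is almost 9 times faster in this run. Each repetition of the list comprehension starts over at `fib_3(100)`, which was evicted from the 3-entry cache long ago, so `fib_3` has to recurse all the way back down to `fib_3(0)` on every repetition, whereas `fib_unbounded` can answer every call in the list directly from its cache. So, as is often the case, we need to balance performance against memory usage.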
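The cache statistics make this visible. This sketch redefines the two functions from above (with `maxsize=None` for the unbounded one) and runs two passes over the same range instead of 10,000:

```python
from functools import lru_cache

@lru_cache(maxsize=3)
def fib_3(n):
    if n <= 1:
        return n
    return fib_3(n - 1) + fib_3(n - 2)

@lru_cache(maxsize=None)
def fib_unbounded(n):
    if n <= 1:
        return n
    return fib_unbounded(n - 1) + fib_unbounded(n - 2)

for pass_number in (1, 2):
    for n in range(100, 200):
        fib_3(n)
        fib_unbounded(n)
    print(f'after pass {pass_number}:')
    print('  fib_3        ', fib_3.cache_info())
    print('  fib_unbounded', fib_unbounded.cache_info())
```

After the second pass, `fib_unbounded` should still report 200 misses (one per value `fib(0)` to `fib(199)`), while `fib_3` reports twice that, because it recomputed the whole chain from scratch.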