### Use Caching to Optimize your Apps

Sometimes, we have functions that are called repeatedly during the lifetime of our application.

If those function calls are expensive (either CPU or memory), get called often with the same parameters, and are invariant, then we might have the option of caching the results of these calls instead of re-computing the entire call.

There are a few things we usually need to ensure before going down that road:

- The function is invariant (i.e. given the same inputs, it always returns the same output). So a function that is time sensitive, or that uses random numbers, for example, will not work. You probably also do not want to cache a function that has side effects, so in fact you should probably insist on a pure function.
- The function's arguments need to be hashable (this is because Python's caching implementations use the argument values as keys in a dictionary).
- Caching is only useful if the function is called repeatedly with the same values during the life of our app, and running the function is expensive in terms of CPU, resource utilization, latency, etc.
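
To illustrate the hashability requirement, here is a small sketch using `functools.lru_cache` (covered later in this section). The `total()` function is a made-up example. Passing a list fails because lists are not hashable, while an equivalent tuple works:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def total(values):
    # Hypothetical function, used only for illustration
    return sum(values)

try:
    total([1, 2, 3])        # a list is not hashable
except TypeError as ex:
    error_message = str(ex)
else:
    error_message = None

result = total((1, 2, 3))   # a tuple of ints is hashable, so this works
```

The failure happens when the cache tries to build its dictionary key, before the function body ever runs.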

#### Example - From First Principles

Let's take a quick look at an example of how caching benefits generating the Fibonacci sequence.

For this example we'll build our own cache mechanism, so you get an understanding of what Python's caching mechanism actually does.

In [1]:
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)

In [2]:
for i in range(10):
    print(fib(i), end=", ")

1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 

Now let's time this function as `n` increases:

In [3]:
from timeit import timeit

In [4]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed:.2f} s")

25 - 0.01 s
26 - 0.02 s
27 - 0.02 s
28 - 0.03 s
29 - 0.06 s
30 - 0.09 s
31 - 0.14 s
32 - 0.23 s
33 - 0.37 s
34 - 0.60 s
35 - 0.97 s
36 - 1.57 s
37 - 2.55 s
38 - 4.10 s
39 - 6.63 s


As you can see, as `n` gets larger, our computation times increase dramatically. 

In fact, the complexity of this algorithm is exponential, O(2^n).
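
A rough way to see this growth is to count how many calls the naive recursion makes. The `count_calls()` helper below is a hypothetical name, and simply mirrors the structure of `fib()`:

```python
def count_calls(n):
    # Number of times fib() is invoked when computing fib(n) naively
    if n <= 1:
        return 1
    return 1 + count_calls(n - 1) + count_calls(n - 2)

calls = {n: count_calls(n) for n in (10, 20, 25)}
# calls == {10: 177, 20: 21891, 25: 242785}
```

The count grows by a factor of roughly 1.6 per step, which matches the ratio between consecutive timings above. O(2^n) is a convenient upper bound for that growth.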

Why is this happening?

The `fib()` function is called with the same arguments over and over again.

Let's add a print statement to see this:

In [5]:
def fib(n):
    if n <= 1:
        return 1
    print(f"calculating fib({n})")
    return fib(n-1) + fib(n-2)

In [6]:
fib(10)

calculating fib(10)
calculating fib(9)
calculating fib(8)
calculating fib(7)
calculating fib(6)
calculating fib(5)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(3)
calculating fib(2)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(5)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(3)
calculating fib(2)
calculating fib(6)
calculating fib(5)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(3)
calculating fib(2)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(7)
calculating fib(6)
calculating fib(5)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(3)
calculating fib(2)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating fib(2)
calculating fib(5)
calculating fib(4)
calculating fib(3)
calculating fib(2)
calculating

89

So we could alter our algorithm to be more efficient (which is what we should do in this particular case). 
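
As a minimal sketch of what such a change could look like, here is an iterative version that runs in linear time. The name `fib_iterative()` is an assumption for illustration, and it keeps the same convention as above (`fib(0)` and `fib(1)` both return 1):

```python
def fib_iterative(n):
    # a and b hold two consecutive values of the sequence
    a, b = 1, 1
    for _ in range(n):
        a, b = b, a + b
    return a

first_ten = [fib_iterative(i) for i in range(10)]
# first_ten == [1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
```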

But suppose improving our function's algorithm was not an option. What then?

Well, if we observe these repeated calls to the `fib()` function, we see that the arguments often have the same values - so, instead of re-computing them, we can cache the results and save the computations.

Let's start simplistically by establishing a global dictionary that will hold the value of `n` as the key, and the result `fib(n)` as the value. Then, when we perform the calculation, we first check to see if `n` is in the dictionary.

If it is, we simply return that value (remember that lookups in Python dictionaries are fast, O(1)). If the value is not in the dictionary, we calculate the result, store it in the dictionary, and then return the result.

In [7]:
cache = {}

def fib(n):
    if n <= 1:
        return 1
    if n in cache:
        return cache[n]
    print(f"calculating fib({n})")
    result = fib(n-1) + fib(n-2)
    cache[n] = result
    return result

In [8]:
fib(10)

calculating fib(10)
calculating fib(9)
calculating fib(8)
calculating fib(7)
calculating fib(6)
calculating fib(5)
calculating fib(4)
calculating fib(3)
calculating fib(2)


89

See how many fewer calls we have now?

How about our timings?

In [9]:
cache = {}

def fib(n):
    if n <= 1:
        return 1
    if n in cache:
        return cache[n]
    result = fib(n-1) + fib(n-2)
    cache[n] = result
    return result

In [10]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 5.875015631318092e-06 s
26 - 5.420297384262085e-07 s
27 - 3.33995558321476e-07 s
28 - 3.750319592654705e-07 s
29 - 2.9098009690642357e-07 s
30 - 3.3300602808594704e-07 s
31 - 2.9098009690642357e-07 s
32 - 2.919696271419525e-07 s
33 - 2.9103830456733704e-07 s
34 - 3.3294782042503357e-07 s
35 - 3.3300602808594704e-07 s
36 - 2.92027834802866e-07 s
37 - 3.33995558321476e-07 s
38 - 2.92027834802866e-07 s
39 - 2.92027834802866e-07 s


Much better!

Now, doing it this way is not very good code - it works, but has a number of drawbacks. That cache is a global variable (which we tend to avoid whenever possible) and the caching is "baked" into the `fib()` function itself.

To apply this caching mechanism to other functions in our code, we would therefore have to alter the functions themselves - which is far from ideal. And we would have to create an additional global variable for each cached function, cluttering our code needlessly, and also opening ourselves up to inadvertent bugs if something outside the caching mechanism tampers with the cache dictionary.

Instead, we can actually use a closure to perform all this work, and by setting it up as a decorator we end up with a solution that is reusable and properly applies the decomposition concept I covered in an earlier video.

For simplicity, I will only handle caching functions that use positional arguments, but the same idea can be extended to functions that also use keyword arguments.

In [11]:
def cache(fn):
    data_cache = {}

    def inner(*args):
        key = tuple(args)
        if key in data_cache:
            return data_cache[key]
        result = fn(*args)
        data_cache[key] = result
        return result
        
    return inner

Now, let's try using this decorator for our `fib()` function.

This was our original function whose implementation we ended up modifying in order to add caching.

In [12]:
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)    

To implement the cache now, we only need to decorate it:

In [13]:
@cache
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)     

And now let's time things again:

In [14]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 9.374984074383974e-06 s
26 - 7.500057108700275e-07 s
27 - 8.330098353326321e-07 s
28 - 7.080379873514175e-07 s
29 - 5.839974619448185e-07 s
30 - 5.410402081906796e-07 s
31 - 4.5797787606716156e-07 s
32 - 5.409820005297661e-07 s
33 - 5.839974619448185e-07 s
34 - 7.079797796905041e-07 s
35 - 5.830079317092896e-07 s
36 - 5.839974619448185e-07 s
37 - 4.5797787606716156e-07 s
38 - 4.5797787606716156e-07 s
39 - 5.00003807246685e-07 s


So, this is how we can implement caching from first principles. You'll also see why I said earlier that the arguments to the function being cached need to be hashable - they end up (as a tuple) as the **keys** in our cache dictionary.
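
As a sketch of how the decorator could be extended to keyword arguments, one option is to fold the keyword arguments into the key as well. The names `cache_with_kwargs()` and `scale()` are hypothetical, and this version assumes all argument values are hashable:

```python
from functools import wraps

def cache_with_kwargs(fn):
    data_cache = {}

    @wraps(fn)  # keep the decorated function's name and docstring
    def inner(*args, **kwargs):
        # Sort keyword arguments so that their order does not affect the key
        key = (args, tuple(sorted(kwargs.items())))
        if key not in data_cache:
            data_cache[key] = fn(*args, **kwargs)
        return data_cache[key]

    return inner

call_log = []

@cache_with_kwargs
def scale(value, factor=1, offset=0):
    call_log.append((value, factor, offset))
    return value * factor + offset

first = scale(2, factor=3, offset=1)
second = scale(2, offset=1, factor=3)   # same key, served from the cache
```

Note that this sketch still treats `scale(2, 3)` and `scale(2, factor=3)` as different keys, since it does not inspect the function's signature.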

#### Python's LRU Cache

We don't have to implement this cache mechanism ourselves. In fact, the way we did it has some serious limitations.

For example, we do not handle functions called with keyword arguments. Also, our cache size is unlimited, which may not be what we want. A common approach is to limit the cache size and, once the cache grows beyond that limit, start removing the least recently used (LRU) items from the cache.
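
To make the eviction idea concrete, here is a minimal sketch of an LRU container built on `collections.OrderedDict`. The `LRUCache` class is a hypothetical illustration, not how `functools` implements it:

```python
from collections import OrderedDict

class LRUCache:
    def __init__(self, maxsize):
        self.maxsize = maxsize
        self._data = OrderedDict()   # oldest entries first

    def get(self, key, default=None):
        if key not in self._data:
            return default
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # evict least recently used

lru = LRUCache(maxsize=2)
lru.put("a", 1)
lru.put("b", 2)
lru.get("a")        # "a" is now the most recently used entry
lru.put("c", 3)     # cache is full, so "b" gets evicted
remaining = list(lru._data)   # ['a', 'c']
```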

If you want to know more about LRU caching in general terms, and other cache replacement algorithms, here is the Wikipedia link for it:

[https://en.wikipedia.org/wiki/Cache_replacement_policies](https://en.wikipedia.org/wiki/Cache_replacement_policies)

Python implements this LRU (least recently used) cache replacement algorithm in the `functools` module.

In [15]:
from functools import lru_cache

And let's use it with our example:

In [16]:
@lru_cache
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)   

This establishes an LRU cache for our `fib()` function, with a default max size of 128.
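
Functions decorated with `lru_cache` also expose `cache_info()` and `cache_clear()`, which are handy for checking that the cache behaves as expected. A small sketch, using a fresh copy of `fib()`:

```python
from functools import lru_cache

@lru_cache(maxsize=128)
def fib(n):
    if n <= 1:
        return 1
    return fib(n - 1) + fib(n - 2)

fib(10)
info = fib.cache_info()
# CacheInfo(hits=8, misses=11, maxsize=128, currsize=11)

fib.cache_clear()            # empties the cache and resets the counters
cleared = fib.cache_info()
```

The 11 misses correspond to the 11 distinct arguments (0 through 10), each computed exactly once.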

In [17]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 1.2166041415184736e-05 s
26 - 4.169996827840805e-07 s
27 - 2.92027834802866e-07 s
28 - 2.919696271419525e-07 s
29 - 2.9103830456733704e-07 s
30 - 2.500019036233425e-07 s
31 - 2.500019036233425e-07 s
32 - 2.0797597244381905e-07 s
33 - 2.500019036233425e-07 s
34 - 2.500019036233425e-07 s
35 - 2.9098009690642357e-07 s
36 - 2.0803418010473251e-07 s
37 - 2.0902371034026146e-07 s
38 - 2.500019036233425e-07 s
39 - 2.0797597244381905e-07 s


We can easily change the cache size to any specific value, including unlimited, this way:

In [18]:
@lru_cache(maxsize=None)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

In [19]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 5.62501372769475e-06 s
26 - 4.159519448876381e-07 s
27 - 2.500019036233425e-07 s
28 - 2.500019036233425e-07 s
29 - 2.500019036233425e-07 s
30 - 2.0902371034026146e-07 s
31 - 2.919696271419525e-07 s
32 - 2.500019036233425e-07 s
33 - 2.0797597244381905e-07 s
34 - 2.0797597244381905e-07 s
35 - 2.500019036233425e-07 s
36 - 2.08965502679348e-07 s
37 - 2.0902371034026146e-07 s
38 - 2.9098009690642357e-07 s
39 - 2.500019036233425e-07 s


And for a limited cache size, we can do this:

In [20]:
@lru_cache(maxsize=20)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

In [21]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 6.417045369744301e-06 s
26 - 1.00000761449337e-06 s
27 - 4.1601015254855156e-07 s
28 - 5.00003807246685e-07 s
29 - 2.919696271419525e-07 s
30 - 2.92027834802866e-07 s
31 - 2.9098009690642357e-07 s
32 - 2.500019036233425e-07 s
33 - 3.3300602808594704e-07 s
34 - 2.500019036233425e-07 s
35 - 3.750319592654705e-07 s
36 - 2.9098009690642357e-07 s
37 - 2.500019036233425e-07 s
38 - 2.919696271419525e-07 s
39 - 2.500019036233425e-07 s


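For our `fib()` example, if you look at the code closely, and the trace outputs we had earlier, you might expect that we only ever need the last two results, so that a cache of just two elements would be enough. As the timings below show, this is still far faster than no cache at all, but noticeably slower than a larger cache, because entries (including the base cases) get evicted before they are reused: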

In [22]:
@lru_cache(maxsize=2)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

In [23]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 0.000801375019364059 s
26 - 0.00030200002947822213 s
27 - 0.0004160410026088357 s
28 - 0.0005925829755142331 s
29 - 0.00079158297739923 s
30 - 0.0010942919761873782 s
31 - 0.0016196249634958804 s
32 - 0.002049707982223481 s
33 - 0.0027969160000793636 s
34 - 0.00412495800992474 s
35 - 0.005773082957603037 s
36 - 0.007761000015307218 s
37 - 0.01030524994712323 s
38 - 0.013624083949252963 s
39 - 0.018959375040140003 s


With a cache size of 1, however, we lose essentially all of the benefit:

In [24]:
@lru_cache(maxsize=1)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

In [25]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 0.020892457978334278 s
26 - 0.012763917038682848 s
27 - 0.02029274997767061 s
28 - 0.032628499960992485 s
29 - 0.05313491600099951 s
30 - 0.08700279204640538 s
31 - 0.13940141699276865 s
32 - 0.22385312500409782 s
33 - 0.3612438749987632 s
34 - 0.5826868750154972 s
35 - 0.9421670840238221 s
36 - 1.524265291984193 s
37 - 2.4616684170323424 s
38 - 3.987549334007781 s
39 - 6.533424041990656 s


#### Unbounded LRU Cache

We just saw that we can use `maxsize=None` to create an unbounded LRU cache.

Python (3.9 and later) also provides a simpler syntax to do the same thing:

In [26]:
from functools import cache

In [27]:
@lru_cache(maxsize=None)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

The same function can be equivalently defined this way:

In [28]:
@cache
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2) 

In [29]:
for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

25 - 5.957961548119783e-06 s
26 - 3.7497375160455704e-07 s
27 - 2.500019036233425e-07 s
28 - 3.750319592654705e-07 s
29 - 2.0902371034026146e-07 s
30 - 2.0797597244381905e-07 s
31 - 2.0797597244381905e-07 s
32 - 2.0902371034026146e-07 s
33 - 1.66997779160738e-07 s
34 - 2.0797597244381905e-07 s
35 - 2.500019036233425e-07 s
36 - 2.0797597244381905e-07 s
37 - 2.500019036233425e-07 s
38 - 2.0902371034026146e-07 s
39 - 2.0797597244381905e-07 s


For `fib()`, a small bounded cache would also work, at some cost in speed, as the earlier timings showed. For cases where you do want an unlimited cache, using `@cache` is the simplest option.
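
To compare how a bounded and an unbounded cache behave for `fib()`, we can look at the hit and miss counters rather than timings. This is a sketch, and the `make_fib()` factory is a hypothetical helper that builds a fresh cached function for a given size:

```python
from functools import lru_cache

def make_fib(maxsize):
    @lru_cache(maxsize=maxsize)
    def fib(n):
        if n <= 1:
            return 1
        return fib(n - 1) + fib(n - 2)
    return fib

unbounded = make_fib(None)
tiny = make_fib(2)

unbounded_result = unbounded(20)
tiny_result = tiny(20)

unbounded_info = unbounded.cache_info()   # one miss per distinct argument
tiny_info = tiny.cache_info()             # more misses, due to evictions
```

Both versions return the same value, but the tiny cache recomputes some results after they have been evicted, which shows up as extra misses.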

#### Some Finer Details and Caveats on LRU Cache

There are a few finer points that you may want to be aware of when using the LRU cache mechanism provided by Python.

Let's say we have this function:

In [30]:
def func(a, b, c):
    print(f"computing result for {a=}, {b=}, {c=}")
    return a + b + c

In [31]:
func(1, 2, 3)

computing result for a=1, b=2, c=3


6

In [32]:
func(1, 2, 3)

computing result for a=1, b=2, c=3


6

Now let's apply an LRU cache to this:

In [33]:
@cache
def func(a, b, c):
    print(f"computing result for {a=}, {b=}, {c=}")
    return a + b + c

In [34]:
func(1, 2, 3)

computing result for a=1, b=2, c=3


6

In [35]:
func(1, 2, 3)

6

As you can see, we got a cache "hit", and the computation was not performed.

However, as you know, we can choose to pass values using named arguments as well.

Let's see what happens when we do:

In [36]:
func(a=1, b=2, c=3)

computing result for a=1, b=2, c=3


6

As you can see, this was a cache "miss", and we incurred the re-computation.

Now the second time around, we'll get a cache hit:

In [37]:
func(a=1, b=2, c=3)

6

But, if we change the order of the arguments, we'll get a cache miss again:

In [38]:
func(b=2, a=1, c=3)

computing result for a=1, b=2, c=3


6

But now, both ways of calling the function (with the same values essentially) are cached:

In [39]:
func(a=1, b=2, c=3)

6

In [40]:
func(b=2, a=1, c=3)

6
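
The same behavior can be confirmed from the cache counters. A small sketch, where `add3()` is a hypothetical stand-in for `func()` without the print statement (the counts shown are what CPython produces):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def add3(a, b, c):
    return a + b + c

add3(1, 2, 3)            # positional call
add3(a=1, b=2, c=3)      # keyword call: separate cache entry
add3(b=2, a=1, c=3)      # different keyword order: yet another entry

info = add3.cache_info()  # three misses, no hits, three entries
```

If this matters in your code, calling the function consistently in one style keeps the cache from filling up with duplicate entries.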

#### Caching Class Properties

Sometimes we need to apply the same cache principle to properties (especially calculated properties) in our classes.

We can certainly do it from first principles. Let's start with a plain calculated property, without any caching:

In [41]:
from math import pi

class Circle():
    def __init__(self, radius):
        self._radius = radius

    @property
    def area(self):
        return pi * (self._radius ** 2)

So, the calculated `area` property is completely uncached, and this is the performance we get if we need to access this property repeatedly in our app:

In [42]:
c = Circle(3)

timeit("c.area", globals=globals(), number=1_000_000)

0.07789866701932624

Now, let's implement caching from first principles:

In [43]:
class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None

    @property
    def area(self):
        if self._area is None:
            self._area = pi * (self._radius ** 2)
        return self._area

And let's time it now:

In [44]:
c = Circle(3)

timeit("c.area", globals=globals(), number=1_000_000)

0.05554583400953561

Although this solution works, it suffers from the same problem we saw earlier - we need to "bake" the caching into the property itself. Not the end of the world, but it would be nice if we could just decorate our property just like we did with the LRU cache.

And in fact, Python provides this with the `@cached_property` decorator in the `functools` module also.

In [45]:
from functools import cached_property

class Circle():
    def __init__(self, radius):
        self._radius = radius

    @cached_property
    def area(self):
        return pi * (self._radius ** 2)


In [46]:
c = Circle(3)

timeit("c.area", globals=globals(), number=1_000_000)

0.024079292023088783

#### Caveats of Cached Properties and Mutability

Now, this approach works because our `Circle` class is immutable (by convention). We assume that, since we have made our radius private (by convention, using that leading underscore), the radius will never change over the lifetime of our `Circle(3)` instance.

If the radius does change for some reason, we'll run into issues (both with our own approach, and with the `cached_property` approach).

In [47]:
c = Circle(3)
c.area

28.274333882308138

Now let's change that radius (even though we should not):

In [48]:
c._radius = 1

In [49]:
c.area

28.274333882308138

As you can see, we get the wrong value for the area.

So can we deal with this, and how?

Let's look at our first-principles approach first.

In [50]:
class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None

    @property
    def area(self):
        if self._area is None:
            self._area = pi * (self._radius ** 2)
        return self._area

What we are going to do here is make the Circle class mutable by controlling how the radius value gets set, so we'll implement both a getter and a setter for our radius. 

In the setter, we'll detect whether the radius has changed, and if so, we'll "invalidate" the cache (i.e. clear the cache) for the `area` property.

In [51]:
class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None

    @property
    def radius(self):
        return self._radius

    @radius.setter
    def radius(self, value):
        if self._radius != value:
            self._area = None
        self._radius = value
    
    @property
    def area(self):
        if self._area is None:
            self._area = pi * (self._radius ** 2)
        return self._area

Now let's see what happens:

In [52]:
c = Circle(1)
c.area

3.141592653589793

In [53]:
c.radius = 2
c.area

12.566370614359172

As you can see, we now have a way to control when the cache for area gets cleared, so we get the correct result.

How about clearing the cache when using `cached_property`?

We do so by deleting the property - kind of weird, but this works because the `cached_property` decorator will re-create the cached value when it is next needed.

This is much easier to understand with an example.

This was our original implementation:

In [54]:
class Circle():
    def __init__(self, radius):
        self._radius = radius

    @cached_property
    def area(self):
        return pi * (self._radius ** 2)

Let's add the radius property and invalidate the cache as needed in the radius setter, just like we did in our first-principles example:

In [55]:
class Circle():
    def __init__(self, radius):
        self._radius = radius

    @property
    def radius(self):
        return self._radius

    @radius.setter
    def radius(self, value):
        if self._radius != value:
            # `del` raises AttributeError if nothing has been cached yet
            try:
                del self.area
            except AttributeError:
                pass
        self._radius = value
        
    @cached_property
    def area(self):
        return pi * (self._radius ** 2)

In [56]:
c = Circle(1)
c.area

3.141592653589793

In [57]:
c.radius = 2
c.area

12.566370614359172

As you can see, we get the same effect now.
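
To see why deleting works, it helps to look at where `cached_property` keeps its value: in the instance's own `__dict__`, under the property's name. A small sketch (the `Square` class is a hypothetical example):

```python
from functools import cached_property

class Square:
    def __init__(self, side):
        self.side = side

    @cached_property
    def area(self):
        return self.side ** 2

s = Square(3)
cached_before = "area" in vars(s)      # nothing cached yet
first_area = s.area                    # computes and stores the value
cached_after = "area" in vars(s)       # value now lives in the instance dict

del s.area                             # removes the stored value
cached_after_del = "area" in vars(s)

s.side = 4
second_area = s.area                   # recomputed with the new side
```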

#### Caching Class Methods

What about caching methods in our classes?

This can certainly be done as well using either `cache` or `lru_cache`.

However, we need to understand that unlike `cached_property`, which establishes a cache at the instance level, caching a method using `cache` or `lru_cache` establishes a cache at the **class** level.

It's mostly transparent to us, as the end-users, but it does mean we have to handle cache invalidation a bit differently.

Instead of using a property for the area, let's just make it a method, and see how we can deal with it.

In [58]:
class Circle:
    def __init__(self, r):
        self.r = r

    @cache
    def area(self):
        print(f"calculating area for {self.r=}")
        return pi * (self.r ** 2)

In [59]:
c = Circle(1)
c.area()

calculating area for self.r=1


3.141592653589793

In [60]:
c.r = 3

In [61]:
c.area()

3.141592653589793

As you can see, we have the wrong result.

Observe what happens if we have two instances with the same radius:

In [62]:
c1 = Circle(1)
c2 = Circle(1)

In [63]:
c1 == c2

False

So, the two circles are not equal to each other (by default, equality is based on object identity).

And now let's see how the cache mechanism works:

In [64]:
c1.area()

calculating area for self.r=1


3.141592653589793

In [65]:
c2.area()

calculating area for self.r=1


3.141592653589793

As might be expected, we get two cache misses.

So, we can do something about not only this (recalculating the area when in fact the radii are the same), but also invalidating the cache when the radius has changed.

We can do this by implementing the `__eq__` and `__hash__` methods to define what constitutes equality, and to make sure that a radius change (in this case) results in objects no longer being equal - something the cache will pick up on.

In [66]:
class Circle:
    def __init__(self, r):
        self.r = r

    @cache
    def area(self):
        print(f"calculating area for {self.r=}")
        return pi * (self.r ** 2)

    def __eq__(self, other):
        # Only compare against other circles
        if not isinstance(other, Circle):
            return NotImplemented
        return self.r == other.r

    def __hash__(self):
        return hash(self.r)

Now we can see that we have equality the way we probably want it:

In [67]:
c1 = Circle(1)
c2 = Circle(1)

In [68]:
c1 == c2

True

And let's see how the caching works:

In [69]:
c1.area()

calculating area for self.r=1


3.141592653589793

In [70]:
c2.area()

3.141592653589793

As you can see, we had a cache hit when calculating the area on the second instance. 

That is because the first instance is **equal** to the second instance - so, from the cache perspective, it already had the cached value.

Let's mutate the instance now:

In [71]:
c1.r = 2

In [72]:
c1 == c2

False

In [73]:
c1.area()

calculating area for self.r=2


12.566370614359172

As you can see, we got a cache miss and the correct result.

And, if we call this on the second instance:

In [74]:
c2.area()

calculating area for self.r=1


3.141592653589793

We also get a cache miss.

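No entry was actually removed from the cache here. The original entry for a radius of `1` was stored with `c1` as its key. Once we changed `c1.r`, that key no longer compares equal to `c2`, so the lookup for `c2.area()` fails, hence the cache miss. The stale entry stays in the cache, which is one reason to be careful with hashes that depend on mutable state.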
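
Since the cache lives at the class level, we can inspect it through the class. The sketch below replays the calls above without the print statement, using `lru_cache(maxsize=None)` (equivalent to `cache`); the counts are what CPython produces:

```python
from functools import lru_cache
from math import pi

class Circle:
    def __init__(self, r):
        self.r = r

    @lru_cache(maxsize=None)
    def area(self):
        return pi * (self.r ** 2)

    def __eq__(self, other):
        return self.r == other.r

    def __hash__(self):
        return hash(self.r)

c1 = Circle(1)
c2 = Circle(1)
c1.area()    # miss: first entry, keyed by c1
c2.area()    # hit: c2 equals c1 and has the same hash
c1.r = 2
c1.area()    # miss: c1 now hashes differently
c2.area()    # miss: the old entry's key (c1) no longer equals c2

info = Circle.area.cache_info()   # hits=1, misses=3, currsize=3
```

Note that the cache holds references to the instances used as keys, so cached instances are kept alive for as long as their entries remain in the cache.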

#### Other Types of Caches

If the LRU replacement policy is not what you are looking for, then you'll either have to implement something yourself from first principles, or you could make use of third-party libraries which have already done the hard work for you.

For example, this popular library:

[https://github.com/tkem/cachetools/](https://github.com/tkem/cachetools/)
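
For the do-it-yourself route, here is a minimal sketch of a different policy: entries that expire after a fixed time. The `ttl_cache()` decorator and its `clock` parameter are hypothetical names, and the sketch handles positional, hashable arguments only:

```python
import time
from functools import wraps

def ttl_cache(ttl_seconds, clock=time.monotonic):
    def decorator(fn):
        data = {}  # maps args -> (value, expiry time)

        @wraps(fn)
        def inner(*args):
            now = clock()
            entry = data.get(args)
            if entry is not None and now < entry[1]:
                return entry[0]            # still fresh: cache hit
            value = fn(*args)
            data[args] = (value, now + ttl_seconds)
            return value

        return inner
    return decorator

# A fake clock makes the expiry easy to demonstrate without sleeping
fake_now = [0.0]
calls = []

@ttl_cache(ttl_seconds=10, clock=lambda: fake_now[0])
def lookup(key):
    calls.append(key)
    return key.upper()

lookup("a")          # computed
lookup("a")          # served from the cache
fake_now[0] = 11.0   # pretend 11 seconds have passed
lookup("a")          # entry expired, so it is computed again
```

Unlike the LRU cache, this sketch never removes expired entries on its own, so a real implementation would also need some form of cleanup.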