# Decorators

A decorator, sometimes called a wrapper, is a function that takes another function as its input and returns a modified version of it. For example, the following decorator modifies a numeric function so that it returns its original value plus one:

In [1]:
def plus_one_decorator(func):
    def new_func(n):
        return func(n) + 1
    return new_func

To see it in action, let's use it on a simple function:

In [2]:
def square(n):
    return n * n

You can use the decorator like a normal function,

In [3]:
square_output = square(3)
print('Original function output: {}'.format(square_output))

square_plus_one = plus_one_decorator(square)
square_plus_one_output = square_plus_one(3)
print('Modified function output: {}'.format(square_plus_one_output))

Original function output: 9
Modified function output: 10


but the idiomatic way to apply a decorator is with the @ syntax, placed directly above the function definition:

In [4]:
@plus_one_decorator
def square_plus_one_decorated(n):
    return n * n

square_plus_one_decorated_output = square_plus_one_decorated(3)
print('Decorated function output: {}'.format(square_plus_one_decorated_output))

Decorated function output: 10


This way, the undecorated function never gets a name of its own, so there is no intermediate identifier cluttering the namespace.
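
The @ syntax is shorthand for rebinding the function name to whatever the decorator returns. The sketch below illustrates that equivalence; _cube_ and _cube_manual_ are hypothetical names chosen for this example and are not used elsewhere in the notebook.

```python
def plus_one_decorator(func):
    def new_func(n):
        return func(n) + 1
    return new_func

# Decorated with the @ syntax
@plus_one_decorator
def cube(n):
    return n ** 3

# The same thing written out by hand: define the function, then rebind its name
def cube_manual(n):
    return n ** 3

cube_manual = plus_one_decorator(cube_manual)

print(cube(2), cube_manual(2))  # both print 9
```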

Here is an example of a decorator that is actually useful:

In [5]:
# Create a caching decorator
def cache(func):
    def new_func(*args):
        # If these arguments are already in the cache, skip the computation
        # and return the stored result
        if args in new_func.cache:
            return new_func.cache[args]
        # Otherwise compute the result and store it under these arguments
        out = func(*args)
        new_func.cache[args] = out
        return out

    new_func.cache = {}  # Create the cache for the modified function
    return new_func
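
The decorator above only accepts positional arguments, and the function it returns reports its name as _new_func_. The sketch below is one possible variant, not the only way to handle this: it uses _functools.wraps_ to copy the original name and docstring onto the wrapper and folds keyword arguments into the cache key. The names _cache_kw_ and _add_ are made up for illustration, and the key construction assumes every argument value is hashable.

```python
import functools

def cache_kw(func):
    @functools.wraps(func)  # copy the name, docstring, etc. from func
    def new_func(*args, **kwargs):
        # Build a hashable key from the positional and keyword arguments.
        # Assumption: every argument value is itself hashable.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in new_func.cache:
            new_func.cache[key] = func(*args, **kwargs)
        return new_func.cache[key]

    new_func.cache = {}
    return new_func

@cache_kw
def add(a, b=0):
    """Add two numbers."""
    return a + b

print(add(1, b=2), add.__name__, len(add.cache))
```

Note that this simple key treats _add(1, 2)_ and _add(1, b=2)_ as different calls, so each gets its own cache entry.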

Here is a recursive function that computes the (signed) Stirling numbers of the first kind:

In [6]:
def stirling(n, k):
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling(n - 1, k) + stirling(n - 1, k - 1)

print('s(15, 7) = {}'.format(stirling(15, 7)))

s(15, 7) = 14409322928


Notice how long each call takes when we time it over many runs:

In [7]:
%timeit -n 1000 stirling(15, 7)

1.75 ms ± 4.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


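However, if you apply the cache decorator to the same function body, each distinct pair of arguments is computed only once. Because the recursive calls also go through the cache, even the first call does far less work, and any later call with the same arguments is just a dictionary lookup. Note that the timing below measures those later calls, since the print call has already filled the cache: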

In [8]:
@cache
def stirling_cache(n, k):
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling_cache(n - 1, k) + stirling_cache(n - 1, k - 1)

print('s(15, 7) = {}'.format(stirling_cache(15, 7)))
%timeit -n 1000 stirling_cache(15, 7)

s(15, 7) = 14409322928
246 ns ± 21.8 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
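
To get a feel for where the speedup comes from, one can count how often the function body actually runs. The following is a rough sketch: the _calls_ counter and the names _stirling_plain_ and _stirling_memo_ are hypothetical, and the _cache_ decorator is repeated in compact form so the block runs on its own.

```python
calls = {'plain': 0, 'memo': 0}

def cache(func):
    def new_func(*args):
        if args not in new_func.cache:
            new_func.cache[args] = func(*args)
        return new_func.cache[args]
    new_func.cache = {}
    return new_func

def stirling_plain(n, k):
    calls['plain'] += 1  # count every execution of the body
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling_plain(n - 1, k) + stirling_plain(n - 1, k - 1)

@cache
def stirling_memo(n, k):
    calls['memo'] += 1  # only runs on a cache miss
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling_memo(n - 1, k) + stirling_memo(n - 1, k - 1)

assert stirling_plain(15, 7) == stirling_memo(15, 7)
print(calls)
```

The uncached version re-evaluates the same subproblems over and over, while the cached version runs the body once per distinct pair of arguments.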


In fact, a similar decorator is already available in the standard library's _functools_ module (applying it without parentheses, as below, requires Python 3.8 or later):

In [9]:
import functools

@functools.lru_cache
def stirling_fcache(n, k):
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling_fcache(n - 1, k) + stirling_fcache(n - 1, k - 1)

print('s(15, 7) = {}'.format(stirling_fcache(15, 7)))
%timeit -n 1000 stirling_fcache(15, 7)

s(15, 7) = 14409322928
89.4 ns ± 3.24 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
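
Functions decorated with _functools.lru_cache_ also expose _cache_info()_ and _cache_clear()_ methods, which make it possible to inspect and reset the cache. A minimal sketch, reusing the function from the cell above:

```python
import functools

@functools.lru_cache
def stirling_fcache(n, k):
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling_fcache(n - 1, k) + stirling_fcache(n - 1, k - 1)

stirling_fcache(15, 7)
stirling_fcache(15, 7)  # answered from the cache

# Named tuple with the fields hits, misses, maxsize and currsize
info = stirling_fcache.cache_info()
print(info)

stirling_fcache.cache_clear()  # empty the cache and reset the counters
print(stirling_fcache.cache_info())
```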


## Decorator factories

A decorator factory is a function that returns a decorator. As an example, let's generalize the plus-one decorator from the previous section so that it adds any number to the function output:

In [10]:
def plus_k_decorator_factory(k):
    def plus_k_decorator(func):
        def new_func(n):
            return func(n) + k
        return new_func
    return plus_k_decorator

You can use this decorator factory like this:

In [11]:
@plus_k_decorator_factory(10)
def square_plus_ten(n):
    return n * n

square_plus_ten_out = square_plus_ten(3)
print('Decorated function output: {}'.format(square_plus_ten_out))

Decorated function output: 19
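
It can help to see the two steps that the @ line performs: the factory is called first, and the decorator it returns is then applied to the function. The sketch below spells this out by hand; _plus_ten_ and _square_plus_ten_manual_ are hypothetical names for the intermediate decorator and the result.

```python
def plus_k_decorator_factory(k):
    def plus_k_decorator(func):
        def new_func(n):
            return func(n) + k
        return new_func
    return plus_k_decorator

def square(n):
    return n * n

plus_ten = plus_k_decorator_factory(10)     # step 1: call the factory to get a decorator
square_plus_ten_manual = plus_ten(square)   # step 2: apply that decorator to the function

print(square_plus_ten_manual(3))  # prints 19
```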


The _functools.lru_cache_ decorator mentioned in the previous section can also be used as a decorator factory. One potential issue with caching function calls is that the cache can take up a lot of memory. We can mitigate this problem by evicting the entries that have gone unused the longest. This is known as a Least Recently Used (LRU) cache. By passing a _maxsize_ argument to _functools.lru_cache_, we can set the maximum number of entries the cache may hold before it starts to evict the least recently used ones. When used without arguments, as in the previous section, _maxsize_ defaults to 128.

In [12]:
@functools.lru_cache(maxsize=1000)  # Keep at most 1000 entries, then evict the least recently used
def stirling_lru_cache(n, k):
    if k == n:
        return 1
    if k == 0:
        return 0
    return (1 - n) * stirling_lru_cache(n - 1, k) + stirling_lru_cache(n - 1, k - 1)

print('s(15, 7) = {}'.format(stirling_lru_cache(15, 7)))

s(15, 7) = 14409322928
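
With 1000 slots, the Stirling example never fills the cache, so no eviction takes place. To make the eviction visible, here is a small sketch with a deliberately tiny cache; the function _double_ and the choice of _maxsize=2_ are made up for illustration.

```python
import functools

@functools.lru_cache(maxsize=2)
def double(n):
    return 2 * n

double(1)
double(2)
double(3)  # third distinct argument: the entry for 1 is evicted

info = double.cache_info()
print(info.currsize, info.misses)  # prints: 2 3

double(1)  # no longer cached, so it is computed again
print(double.cache_info().misses)  # prints: 4
```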
