# Decorators

This section draws on the blog post [Finally understanding decorators in Python](https://pouannes.github.io/blog/decorators/).

Contents:

1. Decorator basics
2. A small exercise
3. Using `*args` and `**kwargs`
4. Higher-order decorators
5. `@functools.lru_cache()`


## 1. Decorator basics

A function with a decorator looks like this:

In [1]:
def decorator_name(func):
    def f():
        rv = func()
        return rv
    return f

@decorator_name
def func_name():
    pass

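In the code above, `@decorator_name` is the decorator.

As the name suggests, a decorator decorates something. What does it decorate?

It decorates **functions**. In effect, it takes a function, "decorates" it, and puts the result back under the same name.

The examples below show how decorators are used and why they are useful.

First, we write an `add` function whose parameter `y` defaults to 10.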

In [2]:
def add(x, y = 10):
    return x + y

Test the function:

In [3]:
def add(x, y = 10):
    return x + y

print('add(10)', add(10))
print('add(20, 30)', add(20, 30))
print('add("a", "b")', add("a", "b"))

add(10) 20
add(20, 30) 50
add("a", "b") ab


While we are at it, let's print some attributes of the `add` function:

In [4]:
add.__name__

'add'

In [5]:
add.__module__

'__main__'

In [6]:
add.__defaults__

(10,)

In [7]:
add.__code__.co_varnames  # the variable names of the 'add' function

('x', 'y')

In [8]:
from inspect import getsource
print(getsource(add))

def add(x, y = 10):
    return x + y



Next, we use the `time` module to print how long the `add` function takes to run:

In [9]:
from time import time

def add(x, y = 10):
    return x + y

before = time()
print('add(10)', add(10))
after = time()
print('time taken: ', after - before)
before = time()
print('add(20, 30)', add(20, 30))
after = time()
print('time taken: ', after - before)
before = time()
print('add("a", "b")', add("a", "b"))
after = time()
print('time taken: ', after - before)

add(10) 20
time taken:  0.00033783912658691406
add(20, 30) 50
time taken:  0.0005986690521240234
add("a", "b") ab
time taken:  0.00010180473327636719


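Reading the code above, we see that every call requires recording the start and end times and then calling `print`.

That is a lot of code for one small task.

Surely Python cannot be this tedious. The repeated code can be wrapped inside a function and reused: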

In [10]:
from time import time

def add(x, y = 10):
    before = time()
    rv = x + y
    after = time()
    print('time taken: ', after - before)
    return rv

print('add(10)', add(10))
print('add(20, 30)', add(20, 30))
print('add("a", "b")', add("a", "b"))

time taken:  9.5367431640625e-07
add(10) 20
time taken:  0.0
add(20, 30) 50
time taken:  0.0
add("a", "b") ab


With this version, every call to `add` prints the function's running time.

So far everything works nicely. But what if we now have a `sub` function whose running time we also want to print?

Here is a `sub` function:

In [11]:
def sub(x, y = 10):
    return x - y

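The most direct approach is, of course, to modify `sub` the same way we modified `add` above.

But what if the functions that need timing are not just `add` and `sub`, but many others as well? Modifying them one by one quickly becomes tedious.

Python cannot be this tedious! There must be a simpler way to achieve this.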

In [12]:
def add(x, y = 10):
    return x + y

def sub(x, y = 10):
    return x - y

def timer(func, x, y = 10):
    before = time()
    rv = func(x, y)
    after = time()
    print('time taken: ', after - before)
    return rv

print('add(10)', timer(add, 10))
print('sub(20)', timer(sub, 20))

time taken:  0.0
add(10) 20
time taken:  0.0
sub(20) 10


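The code above defines a function `timer`.

It takes `add` or `sub`, together with their arguments, as its input, and returns their output as its own output.

In other words, `timer` stands in for `add` and `sub`, which are reduced to mere arguments of `timer`.

This works, but it departs from our original goal of **modifying the function**.

We want the modified function to still be a function in its own right, not an argument to some other function.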

In [13]:
from time import time


def timer(func):
    def f(x, y = 10):
        before = time()
        rv = func(x, y)
        after = time()
        print('time taken: ', after - before)
        return rv
    return f

def add(x, y = 10):
    return x + y

def sub(x, y = 10):
    return x - y

add = timer(add)
sub = timer(sub)

print('add(10)', add(10))
print('add(10, 20)', add(10, 20))
print('sub(20)', sub(20))

time taken:  0.0
add(10) 20
time taken:  9.5367431640625e-07
add(10, 20) 30
time taken:  7.152557373046875e-07
sub(20) 10


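In essence, the code above uses the function `timer` to modify the functions `add` and `sub`.

Calling `timer(add)` returns a new, modified `add` function.

This code already does everything a decorator does. The `@` form shown at the start of this article is simply shorthand for the process above.

Let's now write the `timer` decorator in `@` form. To avoid clobbering the existing name, we call the new decorator `ntimer`.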

In [14]:
from time import time


def ntimer(func):
    def f(x, y = 10):
        before = time()
        rv = func(x, y)
        after = time()
        print('time taken: ', after - before)
        return rv
    return f

@ntimer
def add(x, y = 10):
    return x + y

print('add(10)', add(10))

time taken:  0.0
add(10) 20
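
To make the equivalence concrete, the following minimal sketch applies one decorator both ways and checks that the results agree. The names `shout`, `greet_manual`, and `greet_sugar` are hypothetical and used only for illustration.

```python
def shout(func):
    """Hypothetical decorator: upper-case whatever the wrapped function returns."""
    def wrapper(name):
        return func(name).upper()
    return wrapper

# Manual form: define the function, then rebind the name to the wrapped version.
def greet_manual(name):
    return 'hello, ' + name
greet_manual = shout(greet_manual)

# `@` form: Python performs the same rebinding right after the definition.
@shout
def greet_sugar(name):
    return 'hello, ' + name

print(greet_manual('world'))  # HELLO, WORLD
print(greet_sugar('world'))   # HELLO, WORLD
```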


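## 2. A small exercise

A quick try:

Now that we know how to write a decorator, let's freestyle a little.

Modify the `new_print` function so that every time it prints something, it automatically prefixes the content with "You just typed: ".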

In [15]:
def change_print(func):
    def f(x):
        rv = func('You just typed: ' + x)
        return rv
    return f

@change_print
def new_print(x):
    print(x)

In [16]:
new_print('a')

You just typed: a


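## 3. Using `*args` and `**kwargs`

Use `*args` and `**kwargs` to accept the wrapped function's positional and keyword arguments, so the wrapper no longer depends on a fixed signature such as `(x, y = 10)`.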

In [17]:
# *args and **kwargs
from time import time

def ntimer(func):
    def f(*args, **kwargs):
        before = time()
        rv = func(*args, **kwargs)
        after = time()
        print('time taken: ', after - before)
        return rv
    return f

@ntimer
def add(x, y=10):
    return x + y

@ntimer
def sub(x, y=10):
    return x - y

print('add(10)', add(10))
print('add(20, 30)', add(20, 30))
print('sub(20, 5)', sub(20, 5))

time taken:  9.5367431640625e-07
add(10) 20
time taken:  9.5367431640625e-07
add(20, 30) 50
time taken:  1.1920928955078125e-06
sub(20, 5) 15
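
One side effect of wrapping is that the decorated function now reports the wrapper's metadata: with the `ntimer` above, `add.__name__` is `'f'` rather than `'add'`. A common remedy, sketched below, is `functools.wraps`, which copies attributes such as `__name__` and `__doc__` from the original function onto the wrapper. The sketch also swaps `time()` for `time.perf_counter()`, a higher-resolution clock intended for measuring short intervals. Both changes are suggestions layered on top of the original example, and the name `wtimer` is hypothetical.

```python
from functools import wraps
from time import perf_counter

def wtimer(func):
    """Hypothetical variant of ntimer that preserves the wrapped function's metadata."""
    @wraps(func)  # copy __name__, __doc__, etc. from func onto f
    def f(*args, **kwargs):
        before = perf_counter()
        rv = func(*args, **kwargs)
        after = perf_counter()
        print('time taken: ', after - before)
        return rv
    return f

@wtimer
def add(x, y=10):
    """Return x + y."""
    return x + y

print(add(10))             # prints the elapsed time, then 20
print(add.__name__)        # add
print(add.__wrapped__(1))  # 11, calling the undecorated function directly
```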


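## 4. Higher-order decorators

Same style, same flavor. The essence has not changed: we simply wrap one more function around the outside and pass in one more parameter.

Calling it a "higher-order decorator" is really giving it too much credit.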


In [18]:
# Higher-order decorator

def ntimes(n):
    def inner(f):
        def wrapper(*args, **kwargs):
            for _ in range(n):
                rv = f(*args, **kwargs)
            return rv
        return wrapper
    return inner

In [19]:
@ntimes(3)
def add(x, y):
    print(x + y)
    return x + y

rv = add(1, 2)

3
3
3



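Strip away the outer function and keep the middle part, and you are left with the ordinary decorator from earlier sections (here `n` is supplied by the enclosing `ntimes` call):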

```python
def inner(f):
    def wrapper(*args, **kwargs):
        for _ in range(n):
            rv = f(*args, **kwargs)
        return rv
    return wrapper
```
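
Written out without the `@` syntax, the two layers become explicit: `ntimes(3)` is called first and returns an ordinary decorator, which is then applied to the function. A minimal sketch, with `ntimes` repeated from above so the block is self-contained, and with a hypothetical `bump` function that records its calls:

```python
def ntimes(n):
    def inner(f):
        def wrapper(*args, **kwargs):
            for _ in range(n):
                rv = f(*args, **kwargs)
            return rv
        return wrapper
    return inner

calls = []

def bump(x):
    calls.append(x)   # record every call so we can count them
    return x + 1

repeat3 = ntimes(3)    # step 1: build a decorator with n fixed to 3
bump = repeat3(bump)   # step 2: apply it, same as writing @ntimes(3)

print(bump(1))     # 2
print(len(calls))  # 3
```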

## 5. `@functools.lru_cache()`

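`@functools.lru_cache(maxsize=128, typed=False)` wraps a function so that it remembers recently used inputs and their corresponding outputs, which saves time on repeated calls.

The LRU cache holds at most `maxsize` entries. If `maxsize` is set to `None`, the cache can grow without bound. If `typed` is set to `True`, arguments of different types are cached separately. For example, `f(3)` and `f(3.0)` are treated as distinct calls, each with its own result.

Note that this decorator is only suitable for functions whose output is fully determined by their input. If the same input can produce different outputs, the cached results cannot be reused and the decorator should not be applied. A random-number function is one example.

For more information, see the [functools docs](https://docs.python.org/3.7/library/functools.html).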
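
The effect of `typed` can be observed through `cache_info()`. The sketch below uses two hypothetical two-argument functions, `scale_untyped` and `scale_typed`. Two arguments are used on purpose: the functools documentation notes that some types, such as `int` and `str`, may be cached separately even when `typed` is false, and that shortcut applies to single-argument calls.

```python
from functools import lru_cache

@lru_cache(maxsize=None, typed=False)
def scale_untyped(x, factor):
    return x * factor

@lru_cache(maxsize=None, typed=True)
def scale_typed(x, factor):
    return x * factor

# typed=False: (3, 2) and (3.0, 2) compare equal, so the second call is a cache hit
scale_untyped(3, 2)
scale_untyped(3.0, 2)
print(scale_untyped.cache_info())

# typed=True: the argument types become part of the key, so both calls miss
scale_typed(3, 2)
scale_typed(3.0, 2)
print(scale_typed.cache_info())
```

The first report shows one hit and one miss, the second shows two misses. With `typed=False`, `scale_untyped(3.0, 2)` returns the integer `6` cached by the earlier call, which is worth keeping in mind when the result type matters.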

Example 1: using an LRU cache to speed up computing Fibonacci numbers.

In [20]:
from time import time
from functools import lru_cache


def timer(func):
    """Decorator that reports the running time of the wrapped function."""
    def f(*args, **kwargs):
        before = time()
        rv = func(*args, **kwargs)
        after = time()
        print('time taken: ', after - before)
        return rv
    return f

@timer
def main(num, maxsize):
    @lru_cache(maxsize)
    def fib(n):
        if n < 2:
            return n
        return fib(n-1) + fib(n-2)
    
    return fib(num), fib.cache_info()

print(main(30, None))
print(main(30, 0))

time taken:  6.723403930664062e-05
(832040, CacheInfo(hits=28, misses=31, maxsize=None, currsize=31))
time taken:  0.35408902168273926
(832040, CacheInfo(hits=0, misses=2692537, maxsize=0, currsize=0))


Example 2: an LRU cache for stable web content.

In [21]:
from functools import lru_cache
import urllib.error
import urllib.request


@lru_cache(maxsize=32)
def get_pep(num):
    'Retrieve text of a Python Enhancement Proposal'
    resource = 'http://www.python.org/dev/peps/pep-%04d/' % num
    try:
        with urllib.request.urlopen(resource) as s:
            return s.read()
    except urllib.error.HTTPError:
        return 'Not Found'
    
for n in 8, 290, 308, 320, 8, 218, 320, 279, 289, 320, 9991:
    pep = get_pep(n)
    print(n, len(pep))

8 106914
290 59806
308 57012
320 49591
8 106914
218 46835
320 49591
279 48593
289 50922
320 49591
9991 9


In [22]:
get_pep.cache_info()

CacheInfo(hits=3, misses=8, maxsize=32, currsize=8)
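
Besides `cache_info()`, a function wrapped by `lru_cache` also exposes `cache_clear()` to reset the cache and `__wrapped__` to reach the undecorated function. A small offline sketch, using a hypothetical `square` function in place of the network call:

```python
from functools import lru_cache

@lru_cache(maxsize=32)
def square(n):
    return n * n

for n in (2, 3, 2, 2):
    square(n)

print(square.cache_info())    # CacheInfo(hits=2, misses=2, maxsize=32, currsize=2)

square.cache_clear()          # drop every entry and reset the statistics
print(square.cache_info())    # CacheInfo(hits=0, misses=0, maxsize=32, currsize=0)

print(square.__wrapped__(4))  # 16, computed without touching the cache
print(square.cache_info().misses)  # 0
```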

As further reading, below is an excerpt of the `functools` source code behind `@lru_cache()`.

```python
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    """Update a wrapper function to look like the wrapped function

       wrapper is the function to be updated
       wrapped is the original function
       assigned is a tuple naming the attributes assigned directly
       from the wrapped function to the wrapper function (defaults to
       functools.WRAPPER_ASSIGNMENTS)
       updated is a tuple naming the attributes of the wrapper that
       are updated with the corresponding attribute from the wrapped
       function (defaults to functools.WRAPPER_UPDATES)
    """
    for attr in assigned:
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            pass
        else:
            setattr(wrapper, attr, value)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
    # Issue #17482: set __wrapped__ last so we don't inadvertently copy it
    # from the wrapped function when updating __dict__
    wrapper.__wrapped__ = wrapped
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper

def _make_key(args, kwds, typed,
             kwd_mark = (object(),),
             fasttypes = {int, str},
             tuple=tuple, type=type, len=len):
    """Make a cache key from optionally typed positional and keyword arguments

    The key is constructed in a way that is flat as possible rather than
    as a nested structure that would take more memory.

    If there is only a single argument and its data type is known to cache
    its hash value, then that argument is returned without a wrapper.  This
    saves space and improves lookup speed.

    """
    # All of code below relies on kwds preserving the order input by the user.
    # Formerly, we sorted() the kwds before looping.  The new way is *much*
    # faster; however, it means that f(x=1, y=2) will now be treated as a
    # distinct call from f(y=2, x=1) which will be cached separately.
    key = args
    if kwds:
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    # Early detection of an erroneous call to @lru_cache without any arguments
    # resulting in the inner function being passed to maxsize instead of an
    # integer or None.  Negative maxsize is treated as 0.
    if isinstance(maxsize, int):
        if maxsize < 0:
            maxsize = 0
    elif maxsize is not None:
        raise TypeError('Expected maxsize to be an integer or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper
```
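
The size-limited branch above keeps entries in a circular doubly linked list so that the least recently used one can be evicted in constant time. The same idea can be sketched far more compactly with `collections.OrderedDict`. This is a simplified teaching version, not how `functools` implements it: it ignores keyword arguments, `typed`, and thread safety, and the name `simple_lru_cache` is hypothetical.

```python
from collections import OrderedDict
from functools import wraps

def simple_lru_cache(maxsize):
    """Hypothetical, simplified LRU cache decorator (positional, hashable args only)."""
    def decorating_function(user_function):
        cache = OrderedDict()   # oldest entry first, most recently used last

        @wraps(user_function)
        def wrapper(*args):
            if args in cache:
                cache.move_to_end(args)      # mark as most recently used
                return cache[args]
            result = user_function(*args)
            cache[args] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)    # evict the least recently used entry
            return result

        wrapper.cache = cache   # exposed only so the example can inspect it
        return wrapper
    return decorating_function

@simple_lru_cache(maxsize=2)
def double(n):
    return n * 2

double(1)
double(2)
double(1)   # hit: 1 becomes the most recently used key
double(3)   # cache is full, so key 2 (least recently used) is evicted
print(list(double.cache))   # [(1,), (3,)]
```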