In [None]:
### 装饰器
- 是什么: 对某个功能(函数)的额外包装(wrapper)
- 应用场景:日志、缓存
- 函数装饰器:
    - @:语法糖
    - 带有参数的装饰器: 在装饰器内部的 wrapper 函数上加上对应的参数;参数数量不确定时使用 `*args`(任意数量的位置参数)和 `**kwargs`(任意数量的关键字参数)
    - 带有自定义参数:定义在最外层
    - 被装饰以后,函数的元信息(如 `__name__`、`__doc__`)会变成 wrapper 的--在装饰器内部的 wrapper 函数上使用内置的 `@functools.wraps(func)`,即可保留原函数的元信息--其作用是将原函数的元信息拷贝到 wrapper 函数上
- 类装饰器:
    - 主要依赖 `__call__()` 方法:每调用一次类的实例,`__call__()` 就会被执行一次
- 装饰器嵌套:
    - 执行顺序:调用时从上到下(写在最上面的装饰器最先执行);装饰本身则是从下到上进行的,即 `@d1` `@d2` `def f` 等价于 `f = d1(d2(f))`

In [52]:
# 朴素的装饰器
def my_decorator(func):
    """Wrap a zero-argument callable so a banner line is printed before it runs.

    This is the most basic form of a decorator: it forwards no arguments and
    copies no metadata, so the returned function's ``__name__`` is
    ``'wrapper'`` rather than the name of ``func``.

    Args:
        func: A callable that takes no arguments.

    Returns:
        A zero-argument function that prints ``'wrapper of decorator'``,
        calls ``func`` and returns whatever ``func`` returned.
    """
    def wrapper():
        print('wrapper of decorator')
        # Hand the result back to the caller; otherwise every decorated
        # function would silently return None.
        return func()
    return wrapper

def greet():
    """Print the demo greeting."""
    print('hello world')


# Decorate by hand: rebind the name to the wrapped function, then call it.
greet = my_decorator(greet)
greet()

wrapper of decorator
hello world


In [53]:
# 优雅的装饰器demo
def my_decorator(func):
    """Decorator for zero-argument callables that prints a banner first.

    Intended to be applied with the ``@my_decorator`` syntax. No metadata is
    copied from ``func``, so the decorated name reports ``'wrapper'`` as its
    ``__name__``.

    Args:
        func: A callable that takes no arguments.

    Returns:
        A zero-argument function that prints ``'wrapper of decorator'``,
        then calls ``func`` and returns its result.
    """
    def wrapper():
        print('wrapper of decorator')
        # Propagate the wrapped function's result instead of dropping it.
        return func()
    return wrapper

@my_decorator
def greet():
    """Print the demo greeting; the @ line above applies my_decorator to it."""
    print('hello world')


# Calling the decorated name runs the wrapper, which then runs the body above.
greet()

wrapper of decorator
hello world


In [80]:
def my_decorator(func):
    """Decorator for one-argument callables that prints a banner first.

    The wrapper accepts exactly one parameter, ``message``, and forwards it
    positionally to ``func``.

    Args:
        func: A callable that accepts a single argument.

    Returns:
        A function taking ``message`` that prints ``'wrapper of decorator'``,
        calls ``func(message)`` and returns its result.
    """
    def wrapper(message):
        print('wrapper of decorator')
        # Return the result so decorating a value-returning function
        # does not turn its result into None.
        return func(message)
    return wrapper

@my_decorator
def greet(message):
    """Print the given message."""
    print(message)


# The argument travels through the wrapper into the function body.
greet('Hello World')

wrapper of decorator
Hello World


In [99]:
def my_decorator(func):
    """Decorator that accepts any arguments and echoes the ``message`` argument.

    The wrapper forwards ``*args`` and ``**kwargs`` unchanged to ``func``.
    Before doing so it prints a banner that includes the message, which is
    taken from the ``message`` keyword argument if present, otherwise from
    the first positional argument, otherwise ``None``.

    Args:
        func: The callable to wrap.

    Returns:
        A function that prints ``'wrapper of decorator <message>'``, calls
        ``func(*args, **kwargs)`` and returns its result.
    """
    def wrapper(*args, **kwargs):
        # Indexing kwargs['message'] unconditionally raised KeyError whenever
        # the message was passed positionally, so fall back to args[0].
        if 'message' in kwargs:
            message = kwargs['message']
        elif args:
            message = args[0]
        else:
            message = None
        print('wrapper of decorator {}'.format(message))
        return func(*args, **kwargs)
    return wrapper

@my_decorator
def greet(message):
    """Print the given message."""
    print(message)


# Passed by keyword, so the message shows up in the wrapper's **kwargs.
greet(message='Hello World')

wrapper of decorator Hello World
Hello World


In [173]:
import functools
def repeat(num):
    """Build a decorator that runs the decorated function ``num`` times per call.

    The outermost function exists only to receive the decorator's own
    argument; the actual decorator is the inner ``my_decorator``.

    Args:
        num: How many times each call should invoke the wrapped function;
            it is passed to ``range``.

    Returns:
        A decorator. The function it produces prints
        ``'wrapper of decorator'`` before every invocation and returns the
        result of the last invocation, or ``None`` if the function was not
        invoked at all.
    """
    def my_decorator(func):
        @functools.wraps(func)  # keep func's __name__, __doc__, etc. on the wrapper
        def wrapper(*args, **kwargs):
            result = None
            for _ in range(num):
                print('wrapper of decorator')
                # Keep the latest result so it can be handed back to the caller.
                result = func(*args, **kwargs)
            return result
        return wrapper
    return my_decorator

@repeat(4)
def greet(message):
    """Print the given message."""
    print(message)


# One call, four executions of the body.
greet("Hello")
# functools.wraps inside repeat keeps the original name visible here.
print(greet.__name__)

wrapper of decorator
Hello
wrapper of decorator
Hello
wrapper of decorator
Hello
wrapper of decorator
Hello
greet


In [259]:
#类装饰器demo
class Count:
    """Class-based decorator that counts and reports calls to a function.

    Applying ``@Count`` replaces the function with a ``Count`` instance;
    calling that instance goes through ``__call__``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, func):
        """Store the wrapped callable and start the call counter at zero.

        Args:
            func: The callable being decorated.
        """
        # Local import: this cell does not import functools itself.
        import functools

        # Copy __name__, __doc__, etc. onto the instance so the decorated
        # name still looks like the original function. Done first so that
        # attributes copied from func.__dict__ cannot overwrite the fields
        # assigned below.
        functools.update_wrapper(self, func)
        self.func = func        # the wrapped callable
        self.num_calls = 0      # number of times the instance has been called

    def __call__(self, *args, **kwargs):
        """Increment and print the call count, then delegate to the function.

        Returns:
            Whatever the wrapped callable returns for the given arguments.
        """
        self.num_calls += 1
        print('num of calls is: {}'.format(self.num_calls))
        return self.func(*args, **kwargs)

@Count
def example():
    """Print a fixed marker string."""
    print("example")


# `example` is now a Count instance, so each call below goes through __call__.
example()
example()

num of calls is: 1
example
num of calls is: 2
example


In [325]:
import functools

def my_decorator1(func):
    """Decorator that prints ``'execute decorator1'`` before calling ``func``.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``) that
        forwards all arguments to ``func`` and returns its result.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print('execute decorator1')
        # Return the result so stacking decorators does not lose it.
        return func(*args, **kwargs)
    return wrapper

def my_decorator2(func):
    """Decorator that prints ``'execute decorator2'`` before calling ``func``.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``) that
        forwards all arguments to ``func`` and returns its result.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print('execute decorator2')
        # Return the result so stacking decorators does not lose it.
        return func(*args, **kwargs)
    return wrapper

# Stacked decorators: my_decorator2 wraps greet first, then my_decorator1
# wraps that result, so my_decorator1's wrapper is the outermost one.
@my_decorator1
@my_decorator2
def greet(message):
    """Print the given message."""
    print(message)


# The outermost wrapper runs first, then the inner one, then the body.
greet("Hello")

execute decorator1
execute decorator2
Hello


### 典型应用demo

In [None]:
# 身份认证
import functools

def authenticate(func):
    """Decorator that only runs ``func`` when the request's user is logged in.

    The request is taken from the first positional argument or, when there
    are no positional arguments, from the ``request`` keyword argument
    (``None`` if neither is given). The login check is delegated to
    ``check_user_logged_in``, which is not defined in this cell and must be
    provided elsewhere.

    Args:
        func: The callable to protect; it is expected to receive the request.

    Returns:
        A wrapper carrying ``func``'s metadata that returns ``func``'s result
        when ``check_user_logged_in(request)`` is truthy.

    Raises:
        Exception: With the message ``'Authentication failed'`` when the
            login check is falsy.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Accept the request positionally or by keyword; indexing args[0]
        # unconditionally raised IndexError for keyword-only calls.
        request = args[0] if args else kwargs.get('request')
        if check_user_logged_in(request):  # the user is logged in
            return func(*args, **kwargs)  # run the protected function, e.g. post_comment()
        raise Exception('Authentication failed')
    return wrapper
    
@authenticate
def post_comment(request, *args, **kwargs):
    """Placeholder for an operation that requires a logged-in user.

    The original sketch used ``...`` inside the parameter list and omitted
    the colon, which is not valid Python; ``*args, **kwargs`` stands in for
    the unspecified extra parameters.

    Args:
        request: The incoming request, inspected by the decorator above.
    """
    ...
 

In [None]:
# 日志
import time
import functools

def log_execution_time(func):
    """Decorator that reports how long each call to ``func`` takes.

    Timing uses ``time.perf_counter``. After ``func`` returns, one line of
    the form ``'<name> took <elapsed> ms'`` is printed, where ``<name>`` is
    ``func.__name__`` and ``<elapsed>`` is the elapsed time in milliseconds.
    Nothing is printed if ``func`` raises.

    Args:
        func: The callable to time.

    Returns:
        A wrapper carrying ``func``'s metadata that returns ``func``'s result.
    """
    @functools.wraps(func)
    def timed(*args, **kwargs):
        started_at = time.perf_counter()
        outcome = func(*args, **kwargs)
        # perf_counter works in seconds; scale to milliseconds for the report.
        elapsed_ms = (time.perf_counter() - started_at) * 1000
        print('{} took {} ms'.format(func.__name__, elapsed_ms))
        return outcome
    return timed
    
@log_execution_time
def calculate_similarity(items):
    """Placeholder for a computation whose run time should be logged.

    The body is intentionally left unimplemented in this demo.
    """
    pass

In [None]:
# 输入合理性检查
import functools

def validation_check(func):
    """Decorator that validates the call arguments before running ``func``.

    Validation is delegated to ``check_params``, which is not defined in this
    cell and must be provided elsewhere; it receives the same positional and
    keyword arguments as the call.

    The parameter was previously named ``input`` while the body referred to
    ``func`` (a NameError at decoration time), and the wrapper was never
    returned, so the decorated name would have become ``None``.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper carrying ``func``'s metadata that returns ``func``'s result
        when ``check_params(*args, **kwargs)`` is truthy.

    Raises:
        Exception: With the message ``'Parameter validation failed'`` when
            the check is falsy.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if check_params(*args, **kwargs):
            return func(*args, **kwargs)
        raise Exception('Parameter validation failed')
    return wrapper
    
@validation_check
def neural_network_training(param1, param2, *args, **kwargs):
    """Placeholder for a training routine whose inputs are validated first.

    The original sketch used ``...`` inside the parameter list, which is not
    valid Python; ``*args, **kwargs`` stands in for the unspecified extra
    parameters.
    """
    ...

In [377]:
# 缓存
from functools import lru_cache 
# 自动缓存函数的返回值,避免重复计算

@lru_cache
def fibonacci(n):
    """Return the n-th Fibonacci number, with results memoized by lru_cache.

    Any ``n`` below 2 is returned unchanged; larger values are the sum of the
    two preceding Fibonacci numbers.
    """
    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)


print(fibonacci(30))

832040
