# Decorators and context managers

ASPP Reading 2016

Zbigniew Jędrzejewski-Szmek

https://beta.etherpad.org/p/aspp

In [85]:
import numpy as np

In [18]:
def logger(f):
    def wrapper(*args, **kwargs):
        print("function arguments", *args, **kwargs)
        return f(*args, **kwargs)
    return wrapper

In [19]:
@logger
def f1(x, y): return x, y

In [20]:
f1(1,2)

function arguments 1 2


(1, 2)


#### Exercise — wrapping a function to do something on every invocation

Write a decorator which prints the arguments
and the return value of the wrapped function.

>>> @logger
>>> def f(x, y):
...     return g(x) + g(y)

>>> @logger
>>> def g(x):
...     return 2 * x

>>> f(5, 6)
f is called with args [5, 6] kwargs {}
g is called with args [5] kwargs {}
g returns 10
g is called with args [6] kwargs {}
g returns 12
f returns 22


In [29]:
import time

def printtime(f):
    def wrapper(*args, **kwargs):
        begin = time.time()
        x = f(*args, **kwargs)
        tf = time.time()
        print(tf - begin)
        return x

    return wrapper


In [30]:
@printtime
def f1(x, y): return x, y

In [31]:
f1(10,2)

9.5367431640625e-07


(10, 2)

#### Exercise — keeping state in decorators

Write a decorator which prints a warning the first
time a given function is executed. This is a modification
of deprecate() from previous exercise.

@deprecate('do not use')
def f(): pass

>>> f()
f is deprecated, do not use
>>> f()
>>> f()

The trick is how to store the state!

In [None]:
def deprecated(msg):
    return logger

In [74]:
import warnings

def deprecate(msg):
    def deprecated(func1):
        '''
        This is a decorator which can be used to mark functions
        as deprecated. '''

        def wrapper(*args, **kwargs):
            print (msg)
            warnings.warn("Call to deprecated function {}.".format(func1.__name__),
                          category=DeprecationWarning)
            return func1(*args, **kwargs)
        return wrapper
    return deprecated

In [72]:
@deprecate('DO NOT USE')
def f(x): return x

In [73]:
print(f(1))

DO NOT USE
1




## Generators

"""
Exercise: listize decorator

When a function returns a list of results, we might need
to gather those results in a list:

def lucky_numbers(n):
    ans = []
    for i in range(n):
        if i % 7 != 0:
            continue
        if sum(int(digit) for digit in str(i)) % 3 != 0:
            continue
        ans.append(i)
    return ans

This looks much nicer when written as a generator.

① Convert lucky_numbers to be a generator.

② Write a 'listize' decorator which gathers th results from a
generator and returns a list and use it to wrap the new lucky_numbers().

Subexercise: ③ Write an 'arrayize' decorator which returns the results
in a numpy array instead of a list.

>>> @listize
... def f():
...     yield 1
...     yield 2
>>> f()
[1, 2]
"""

def listize():
    ...

In [64]:
def lucky_numbers_li(n):
    '''generate lucky numbers as a list'''
    ans = []
    for i in range(n):
        if i % 7 != 0:
            continue
        if sum(int(digit) for digit in str(i)) % 3 != 0:
            continue
        ans.append(i)
    return ans

def lucky_numbers(n):
    '''generate lucky numbers generator'''

    ans = []
    for i in range(n):
        if i % 7 != 0: # remove all not exactly divisible by 7
            continue
        if sum(int(digit) for digit in str(i)) % 3 != 0: # rm all elements whose sum of the digits is not div by 3
            continue
        yield i


In [70]:
x = lucky_numbers(390)
for item in x:
    print(item)

In [82]:
def listize(func1):
    '''transforms a generator into a list'''
    def wrapper(*args, **kwargs):
        return list(func1(*args, **kwargs))
    return wrapper

def arrayize(func1):
    '''transforms a list into a numpy array'''
    def wrapper(*args, **kwargs):
        return np.array((func1(*args, **kwargs)))
    return wrapper

In [88]:
@arrayize
@listize
def lucky_numbers(n):
    ans = []
    for i in range(n):
        if i % 7 != 0: # remove all not exactly divisible by 7
            continue
        if sum(int(digit) for digit in str(i)) % 3 != 0: # rm all elements whose sum of the digits is not div by 3
            continue
        yield i

In [89]:

lucky_numbers(100)

array([ 0, 21, 42, 63, 84])

## Context managers


#### Exercise — printtime context manager

We had this:
>>> t = time.time()
>>> ans = do_calculations()
>>> t2 = time.time()
>>> print('calculations took {} s'.format(t2-t))
calculations took 3.4 s

Implement this:
>>> with printtime_cm():
...     time.sleep(3)
calculations took 3.40001 s


In [90]:
import contextlib

In [96]:
@contextlib.contextmanager
def printtime():
    t1 = time.time()
    yield
    t2 = time.time()
    print(t2-t1, 's')
    

In [97]:
with printtime():
    [_ for _ in range(100000)]

0.005064725875854492 s


#### Exercise: matplotlib!

Write a context manager which gives you a matplotlib figure object,
and either saves the plot to a file or pops it up on screen,
depending on a *global* parameter SAVEFIGS (in a real program
this parameter would be settable by a commandline options).



In [3]:
import matplotlib.pyplot as plt
import contextlib

In [12]:
@contextlib.contextmanager
def save_or_plot():
    fig = plt.figure()
    yield fig
    if SAVEFIGS:
        fig.savefig(SAVEFIGS)
    else:
        plt.show()

In [13]:
SAVEFIGS = None#'/home/student/Desktop/testplt.png'

In [14]:
with save_or_plot() as f:
    f.add_subplot(111).plot([0, 3, 2, 5])