## Context manager

The most basic **context manager**:

In [13]:
# A file object returned by ``open`` is itself a context manager: the ``with``
# statement binds it to ``f`` and closes the file when the block is left.
with open('context_managers.ipynb') as f:
    pass

### How to write our own context manager

In [29]:
# simple context manager:

# with ctx() as x:
#     pass

# under the hood this CM would look roughly like this:

# cm = ctx()
# x = cm.__enter__()
# try:
#     pass
# finally:
#     cm.__exit__(None, None, None)  # receives the exception info, if any

In [32]:
from sqlite3 import connect

# Manual setup/teardown: create the table, use it, and drop it by hand.
# NOTE: the sqlite3 connection context manager commits or rolls back the
# transaction on exit; it does not close the connection.
with connect('test_db') as conn:
    cur = conn.cursor()
    setup_statements = (
        'create table points(x int, y int)',
        'insert into points (x, y) values(3, 3)',
        'insert into points (x, y) values(2, 2)',
    )
    for statement in setup_statements:
        cur.execute(statement)

    rows = cur.execute('select x, y from points')
    for row in rows:
        print(row)

    # Teardown: remove the table so the cell can be re-run.
    cur.execute('drop table points')

(3, 3)
(2, 2)


In [42]:
# creating a context manager
class Temptable:
    """Context manager that creates the ``points`` table on entry and drops it on exit.

    Parameters
    ----------
    cur
        A DB-API cursor (e.g. from ``sqlite3``) used to run the
        ``create table`` and ``drop table`` statements.
    """

    def __init__(self, cur):
        self.cur = cur

    def __enter__(self):
        """Create the ``points`` table; returns ``None``."""
        print('__enter__')
        # Use the cursor handed to the constructor. Relying on a module-level
        # ``cur`` only works by accident when a global of that name exists.
        self.cur.execute('create table points(x int, y int)')

    def __exit__(self, *args):
        """Drop the ``points`` table.

        The exception details in ``args`` are ignored and ``None`` is
        returned, so an exception raised in the ``with`` body propagates.
        """
        self.cur.execute('drop table points')
        print('__exit__')

In [43]:
# Nested context managers: the outer one handles the connection's
# transaction, the inner one creates and drops the temporary table.
with connect('test_db') as conn:
    cur = conn.cursor()
    with Temptable(cur):
        inserts = (
            'insert into points (x, y) values(3, 3)',
            'insert into points (x, y) values(2, 2)',
        )
        for statement in inserts:
            cur.execute(statement)

        rows = cur.execute('select x, y from points')
        for row in rows:
            print(row)

__enter__
(3, 3)
(2, 2)
__exit__


### Should `__exit__` be called before `__enter__`?  -- Nope, thus we need a generator for sequencing

In [45]:
# create generator
def temptable2(cur):
    """Generator that creates the ``points`` table, pauses, then drops it.

    The first ``next()`` runs the setup up to the ``yield``; resuming the
    generator runs the teardown, so setup always precedes teardown.
    """
    setup_sql = 'create table points(x int, y int)'
    teardown_sql = 'drop table points'

    # Setup phase.
    cur.execute(setup_sql)
    print('table created')

    yield  # pause here while the caller works with the table

    # Teardown phase.
    cur.execute(teardown_sql)
    print('table removed')

# create context manager
class Contextmanager:
    """Drive the ``temptable2`` generator through the context-manager protocol.

    ``__enter__`` runs the generator up to its ``yield`` (setup) and
    ``__exit__`` resumes it so the code after the ``yield`` (teardown) runs.
    """

    def __init__(self, cur):
        self.cur = cur

    def __enter__(self):
        # Start a fresh generator and advance it to the ``yield``.
        generator = temptable2(self.cur)
        self.gen = generator
        next(generator)

    def __exit__(self, *args):
        # Resume after the ``yield``; the generator finishing is expected.
        try:
            next(self.gen)
        except StopIteration:
            pass

#using nested context managers
# The generator-backed manager takes care of creating and dropping the table.
with connect('test_db') as conn:
    cur = conn.cursor()
    with Contextmanager(cur):
        inserts = (
            'insert into points (x, y) values(3, 3)',
            'insert into points (x, y) values(2, 2)',
        )
        for statement in inserts:
            cur.execute(statement)

        rows = cur.execute('select x, y from points')
        for row in rows:
            print(row)

table created
(3, 3)
(2, 2)
table removed


### We could generalize it

In [51]:
# create generator
def temptable2(cur):
    """Create the ``points`` table, yield once, then drop the table again."""
    create_sql = 'create table points(x int, y int)'
    drop_sql = 'drop table points'

    cur.execute(create_sql)
    print('table created')

    yield  # the ``with`` body runs while the generator is suspended here

    cur.execute(drop_sql)
    print('table removed')

# create context manager
class Contextmanager:
    """Wrap a generator function so that calling it yields a context manager.

    Usage: ``with Contextmanager(gen_func)(*args, **kwargs): ...``. The code
    before the generator's ``yield`` runs on entry, the code after it on exit.

    Parameters
    ----------
    gen
        A generator function that yields exactly once.
    """

    def __init__(self, gen):
        self.gen = gen

    def __call__(self, *a, **kw):
        """Record the arguments for the generator function and return ``self``."""
        self.a, self.kw = a, kw
        return self

    def __enter__(self):
        # ``**self.kw`` forwards keyword arguments as keywords; ``*self.kw``
        # would pass the keyword *names* as extra positional arguments.
        # The running generator is kept on the manager rather than as an
        # attribute of the wrapped function, which not every callable allows.
        self._instance = self.gen(*self.a, **self.kw)
        next(self._instance)

    def __exit__(self, *args):
        # Resume after the ``yield``; the default stops StopIteration escaping.
        next(self._instance, None)

#using nested context managers
# The generic manager is built from the generator function, then called with
# the generator's arguments.
with connect('test_db') as conn:
    cur = conn.cursor()
    with Contextmanager(temptable2)(cur): # this line is ugly
        inserts = (
            'insert into points (x, y) values(3, 3)',
            'insert into points (x, y) values(2, 2)',
        )
        for statement in inserts:
            cur.execute(statement)

        rows = cur.execute('select x, y from points')
        for row in rows:
            print(row)

table created
(3, 3)
(2, 2)
table removed


What we could do:

In [53]:
# create context manager
class Contextmanager:
    """Wrap a generator function so that calling it yields a context manager.

    The code before the generator's ``yield`` runs on entry to the ``with``
    block, the code after it on exit.

    Parameters
    ----------
    gen
        A generator function that yields exactly once.
    """

    def __init__(self, gen):
        self.gen = gen

    def __call__(self, *a, **kw):
        """Record the arguments for the generator function and return ``self``."""
        self.a, self.kw = a, kw
        return self

    def __enter__(self):
        # ``**self.kw`` forwards keyword arguments as keywords; ``*self.kw``
        # would pass the keyword *names* as extra positional arguments.
        # The running generator is kept on the manager rather than as an
        # attribute of the wrapped function, which not every callable allows.
        self._instance = self.gen(*self.a, **self.kw)
        next(self._instance)

    def __exit__(self, *args):
        # Resume after the ``yield``; the default stops StopIteration escaping.
        next(self._instance, None)

# create generator
def temptable2(cur):
    """Set up the ``points`` table, yield once, then tear the table down."""
    table_sql = {
        'setup': 'create table points(x int, y int)',
        'teardown': 'drop table points',
    }

    cur.execute(table_sql['setup'])
    print('table created')
    yield  # suspended for the duration of the ``with`` body
    cur.execute(table_sql['teardown'])
    print('table removed')

# Rebinding the name to the wrapped function by hand is what the ``@``
# decorator syntax does.
temptable2 = Contextmanager(temptable2) # added this

# using nested context managers
with connect('test_db') as conn:
    cur = conn.cursor()
    with temptable2(cur): # removed "contextmanager"
        inserts = (
            'insert into points (x, y) values(3, 3)',
            'insert into points (x, y) values(2, 2)',
        )
        for statement in inserts:
            cur.execute(statement)

        rows = cur.execute('select x, y from points')
        for row in rows:
            print(row)

table created
(3, 3)
(2, 2)
table removed


**So we could write it as a decorator**

In [54]:
# create context manager
class Contextmanager:
    """Decorator class that turns a generator function into a context-manager factory.

    The code before the generator's ``yield`` runs on entry to the ``with``
    block, the code after it on exit.

    Parameters
    ----------
    gen
        A generator function that yields exactly once.
    """

    def __init__(self, gen):
        self.gen = gen

    def __call__(self, *a, **kw):
        """Record the arguments for the generator function and return ``self``."""
        self.a, self.kw = a, kw
        return self

    def __enter__(self):
        # ``**self.kw`` forwards keyword arguments as keywords; ``*self.kw``
        # would pass the keyword *names* as extra positional arguments.
        # The running generator is kept on the manager rather than as an
        # attribute of the wrapped function, which not every callable allows.
        self._instance = self.gen(*self.a, **self.kw)
        next(self._instance)

    def __exit__(self, *args):
        # Resume after the ``yield``; the default stops StopIteration escaping.
        next(self._instance, None)

# create generator
@Contextmanager
def temptable2(cur):
    """Create the ``points`` table, yield once, then drop the table again."""
    create_sql = 'create table points(x int, y int)'
    drop_sql = 'drop table points'

    cur.execute(create_sql)
    print('table created')
    yield  # the ``with`` body runs while the generator is suspended here
    cur.execute(drop_sql)
    print('table removed')

#using nested context managers
# The decorated generator function is now used directly as a context manager.
with connect('test_db') as conn:
    cur = conn.cursor()
    with temptable2(cur):
        inserts = (
            'insert into points (x, y) values(3, 3)',
            'insert into points (x, y) values(2, 2)',
        )
        for statement in inserts:
            cur.execute(statement)

        rows = cur.execute('select x, y from points')
        for row in rows:
            print(row)

table created
(3, 3)
(2, 2)
table removed


**It turns out that we do not need to define contextmanager ourselves; it's already done in the standard library:**

In [56]:
from contextlib import contextmanager

# create generator, for completeness adding a try-finally statement
@contextmanager
def temptable2(cur):
    """Context manager providing a temporary ``points`` table.

    The table is created on entry and dropped on exit; the ``finally``
    clause makes the drop run even if the ``with`` body raises.
    """
    create_sql = 'create table points(x int, y int)'
    drop_sql = 'drop table points'

    cur.execute(create_sql)
    print('table created')
    try:
        yield  # the ``with`` body runs here
    finally:
        cur.execute(drop_sql)
        print('table removed')

#using nested context managers
# Same usage as before, now backed by ``contextlib.contextmanager``.
with connect('test_db') as conn:
    cur = conn.cursor()
    with temptable2(cur):
        inserts = (
            'insert into points (x, y) values(3, 3)',
            'insert into points (x, y) values(2, 2)',
        )
        for statement in inserts:
            cur.execute(statement)

        rows = cur.execute('select x, y from points')
        for row in rows:
            print(row)

table created
(3, 3)
(2, 2)
table removed
