## 임계 영역 락

In [3]:
import threading

class SharedCounter:
    """Integer counter whose updates are serialized by a private lock.

    Every ``incr``/``decr`` runs its read-modify-write while holding
    ``_value_lock``, so concurrent callers cannot lose updates.
    """

    def __init__(self, initial_value=0):
        # The lock is created before the value; neither is used until
        # __init__ returns, so the order does not matter.
        self._value_lock = threading.Lock()
        self._value = initial_value

    def incr(self, delta=1):
        """Add *delta* (default 1) to the counter while holding the lock."""
        with self._value_lock:
            self._value += delta

    def decr(self, delta=1):
        """Subtract *delta* (default 1) from the counter while holding the lock."""
        with self._value_lock:
            self._value -= delta

def test(c):
    for n in range(1000000):
        c.incr()
    for n in range(1000000):
        c.decr()

if __name__ == '__main__':
    c = SharedCounter()
    # Three workers run `test` against the same shared counter.
    t1, t2, t3 = (threading.Thread(target=test, args=(c,)) for _ in range(3))
    for t in (t1, t2, t3):
        t.start()
    print('Running test')
    for t in (t1, t2, t3):
        t.join()

    # Every worker undoes its own increments, so the final value must be 0
    # if no update was lost.
    assert c._value == 0
    print('Looks good!')

Running test
Looks good!


## 락킹으로 데드락 피하기

In [4]:
import threading
from contextlib import contextmanager

# Per-thread storage; `_local.acquired` lists the locks the current thread
# holds through acquire(), in acquisition order.
_local = threading.local()

@contextmanager
def acquire(*locks):
    """Acquire *locks* in a consistent global order; release them on exit.

    The locks are sorted by ``id()``, so every call acquires a given set of
    locks in the same order regardless of argument order. Each thread records
    the locks it holds through this helper. A nested call whose lowest-id lock
    does not sort after every lock already held raises before acquiring
    anything, instead of risking a deadlock.

    Args:
        *locks: Objects with ``acquire()``/``release()`` methods, such as
            ``threading.Lock``. May be empty, in which case nothing is done.

    Raises:
        RuntimeError: ``'Lock Order Violation'`` on out-of-order nesting.
    """
    locks = sorted(locks, key=id)
    if not locks:
        # Nothing to do. This also avoids `locks[0]` below, and
        # `del acquired[-0:]`, which would clear the whole per-thread list.
        yield
        return

    acquired = getattr(_local, 'acquired', [])
    if acquired and max(id(lock) for lock in acquired) >= id(locks[0]):
        raise RuntimeError('Lock Order Violation')

    acquired.extend(locks)
    _local.acquired = acquired
    taken = []  # locks whose acquire() actually succeeded
    try:
        for lock in locks:
            lock.acquire()
            taken.append(lock)
        yield
    finally:
        # Release only what was really acquired. Releasing a lock we never
        # took would raise and hide the original exception.
        for lock in reversed(taken):
            lock.release()
        # `with` nesting is LIFO within a thread, so our entries are the tail.
        del acquired[-len(locks):]

In [None]:
# Two independent locks used by thread_1 and thread_2.
x_lock, y_lock = threading.Lock(), threading.Lock()

def thread_1():
    """Loop forever, printing "Thread-1" while holding x_lock and y_lock."""
    message = "Thread-1"
    while True:
        # Both locks go into one acquire() call, which orders them by id(),
        # so the argument order here does not matter.
        with acquire(x_lock, y_lock):
            print(message)

def thread_2():
    """Loop forever, printing "Thread-2" while holding y_lock and x_lock."""
    message = "Thread-2"
    while True:
        # Same locks as thread_1 in the opposite argument order; a single
        # acquire() call sorts them, so both threads take them in one order.
        with acquire(y_lock, x_lock):
            print(message)

input('This program runs forever. Press [return] to start, Ctrl-C to exit')

# Daemon workers: they never return, and must not keep the process alive
# once the main thread stops.
t1 = threading.Thread(target=thread_1, daemon=True)
t1.start()
t2 = threading.Thread(target=thread_2, daemon=True)
t2.start()

import time
# Park the main thread; Ctrl-C interrupts this sleep loop.
while True:
    time.sleep(1)

In [None]:
import threading
import time

# Fresh pair of locks for the nested-acquire demonstration.
x_lock, y_lock = threading.Lock(), threading.Lock()

def thread_1():
    """Loop forever taking x_lock, then y_lock, in two separate acquire() calls."""
    while True:
        # A multi-item `with` is equivalent to two nested `with` statements.
        # acquire() raises RuntimeError on the second call when y_lock's id()
        # is not greater than x_lock's.
        with acquire(x_lock), acquire(y_lock):
            print("Thread-1")
            time.sleep(1)

def thread_2():
    """Loop forever taking y_lock, then x_lock, in two separate acquire() calls."""
    while True:
        # Opposite nesting order from thread_1. acquire() raises RuntimeError
        # on the second call when x_lock's id() is not greater than y_lock's.
        with acquire(y_lock), acquire(x_lock):
            print("Thread-2")
            time.sleep(1)

input('This program crashes with an exception. Press [return] to start')

# thread_1 and thread_2 nest the two locks in opposite orders, so one of them
# is expected to trip acquire()'s lock-order check.
t1 = threading.Thread(target=thread_1, daemon=True)
t1.start()
t2 = threading.Thread(target=thread_2, daemon=True)
t2.start()

# Let the workers run briefly before this cell finishes.
time.sleep(5)

In [None]:
import threading

def philosopher(left, right):
    """Dining philosopher: repeatedly take both chopsticks and eat.

    Both chopstick locks go into one acquire() call, which orders them by
    id(). Every philosopher therefore grabs shared chopsticks in the same
    global order, and the circular wait of the naive version cannot form.

    Args:
        left: Lock for one chopstick this philosopher uses.
        right: Lock for the other chopstick.
    """
    while True:
        with acquire(left, right):
            # current_thread() is the supported spelling; the camelCase
            # currentThread() alias is deprecated since Python 3.10.
            print(threading.current_thread(), 'eating')

NSTICKS = 5
chopsticks = [threading.Lock() for _ in range(NSTICKS)]

# Philosopher n uses chopsticks n and n+1 (wrapping around at the end), so
# each chopstick is shared by two neighbors.
for n in range(NSTICKS):
    t = threading.Thread(
        target=philosopher,
        args=(chopsticks[n], chopsticks[(n + 1) % NSTICKS]),
        daemon=True,
    )
    t.start()

import time
# Keep the main thread alive; the daemon philosophers end when it exits.
while True:
    time.sleep(1)

## 특정 스레드용 상태 저장

In [None]:
from socket import socket, AF_INET, SOCK_STREAM
import threading

class LazyConnection:
    """Context manager that opens one socket per thread on entry.

    The open socket is stored in a ``threading.local``, so one instance can be
    shared by several threads and each thread sees only its own connection.
    A thread may hold at most one connection from an instance at a time.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honor the caller's arguments; these were previously hard-coded to
        # AF_INET / SOCK_STREAM regardless of what was passed.
        self.family = family
        self.type = type
        self.local = threading.local()

    def __enter__(self):
        """Connect a new socket for the current thread and return it.

        Raises:
            RuntimeError: If this thread is already connected via this instance.
            OSError: If creating or connecting the socket fails.
        """
        if hasattr(self.local, 'sock'):
            raise RuntimeError('Already connected')
        sock = socket(self.family, self.type)
        try:
            sock.connect(self.address)
        except BaseException:
            # __exit__ is not called when __enter__ raises. Close the socket
            # here and store nothing, so the thread is not stuck "connected".
            sock.close()
            raise
        self.local.sock = sock
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        # Forget the socket before closing it, so a failing close() cannot
        # leave this thread marked as connected.
        sock = self.local.sock
        del self.local.sock
        sock.close()

def test(conn):
    """Fetch /index.html from www.python.org through *conn* and print its size.

    *conn* is entered as a context manager that yields a connected socket.
    The request is HTTP/1.0 without keep-alive, so the server is expected to
    close the connection after replying. That close ends the recv loop.
    """
    from functools import partial

    with conn as s:
        # sendall() keeps writing until the whole buffer is sent; send() may
        # transmit only part of it and report the count.
        s.sendall(b'GET /index.html HTTP/1.0\r\n')
        s.sendall(b'Host: www.python.org\r\n')
        s.sendall(b'\r\n')
        # Read 8192-byte chunks until recv() returns b'' (peer closed).
        resp = b''.join(iter(partial(s.recv, 8192), b''))

    print('Got {} bytes'.format(len(resp)))

if __name__ == '__main__':
    conn = LazyConnection(('www.python.org', 80))

    # Both threads share one LazyConnection; it keeps each thread's socket in
    # a threading.local, so they do not collide.
    t1, t2 = (threading.Thread(target=test, args=(conn,)) for _ in range(2))
    for t in (t1, t2):
        t.start()
    for t in (t1, t2):
        t.join()

In [None]:
from socket import socket, AF_INET, SOCK_STREAM
import threading

class LazyConnection:
    """Context manager that opens a new socket on every entry.

    Open sockets are kept on a per-thread stack held in a ``threading.local``.
    One instance can therefore be shared by threads and also entered several
    times within one thread (``with conn as s1, conn as s2``). Each exit
    closes the current thread's most recently opened socket.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honor the caller's arguments; these were previously hard-coded to
        # AF_INET / SOCK_STREAM regardless of what was passed.
        self.family = family
        self.type = type
        self.local = threading.local()

    def __enter__(self):
        """Connect a new socket, push it on this thread's stack, return it."""
        sock = socket(self.family, self.type)
        try:
            sock.connect(self.address)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so close the
            # unconnected socket here instead of leaking it.
            sock.close()
            raise
        if not hasattr(self.local, 'connections'):
            self.local.connections = []
        self.local.connections.append(sock)
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        # Nested `with` blocks exit in reverse order of entry, so the top of
        # the stack is the socket belonging to the block being exited.
        self.local.connections.pop().close()

def test(conn):
    """Exercise *conn* with one connection, then two open at once.

    First fetches /index.html and prints its size. Then it enters *conn* twice
    at the same time and interleaves two requests (/downloads and /index.html)
    on the two sockets, printing both response sizes. The requests are
    HTTP/1.0 without keep-alive, so the server is expected to close each
    connection after replying. That close ends each recv loop.
    """
    from functools import partial

    with conn as s:
        # sendall() keeps writing until the whole buffer is sent; send() may
        # transmit only part of it.
        s.sendall(b'GET /index.html HTTP/1.0\r\n')
        s.sendall(b'Host: www.python.org\r\n')
        s.sendall(b'\r\n')
        resp = b''.join(iter(partial(s.recv, 8192), b''))

    print('Got {} bytes'.format(len(resp)))

    # Two sockets are open simultaneously in this thread.
    with conn as s1, conn as s2:
        s1.sendall(b'GET /downloads HTTP/1.0\r\n')
        s2.sendall(b'GET /index.html HTTP/1.0\r\n')
        s1.sendall(b'Host: www.python.org\r\n')
        s2.sendall(b'Host: www.python.org\r\n')
        s1.sendall(b'\r\n')
        s2.sendall(b'\r\n')
        resp1 = b''.join(iter(partial(s1.recv, 8192), b''))
        resp2 = b''.join(iter(partial(s2.recv, 8192), b''))

    print('resp1 got {} bytes'.format(len(resp1)))
    print('resp2 got {} bytes'.format(len(resp2)))

if __name__ == '__main__':

    conn = LazyConnection(('www.python.org', 80))
    # Three threads share one LazyConnection; it keeps each thread's open
    # sockets on its own threading.local stack.
    t1, t2, t3 = (threading.Thread(target=test, args=(conn,)) for _ in range(3))
    for t in (t1, t2, t3):
        t.start()
    for t in (t1, t2, t3):
        t.join()
