In [12]:
import threading
import torch
import time

Incorrect result: the critical section (`total += 1`) is not protected!

In [11]:
# Shared counter that both threads update with no synchronization at all.
total = torch.tensor(0, dtype=torch.int32)

def inc(count):
    """Add 1 to the global ``total``, ``count`` times, without taking any lock.

    Deliberately unsafe: two threads running this concurrently can lose updates.
    """
    global total  # `+=` rebinds the name, so it must be declared global here
    for _ in range(count):
        total += 1

# Run two threads over the same unprotected counter and show the (usually wrong) result.
t1, t2 = (threading.Thread(target=inc, args=[1_000_000]) for _ in range(2))
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):
    worker.join()
print(f"Incorrect total due to race condition on total between t1 and t2: {total}") # should be 2_000_000, but we get less

Incorrect total due to race condition on total between t1 and t2: 1999695


Examine the bytecode to see how the race condition arises

In [8]:
import dis
# Disassemble a module-level `total += i`. Loading `total`, the add (BINARY_OP +=) and
# storing the result back are separate bytecode instructions, so a thread switch can
# happen between them and one thread's update can overwrite another's.
# NOTE(review): inside `inc` the statement is `total += 1` on a `global`, so the opcodes
# there will differ somewhat (check with `dis.dis(inc)`). Also, since `total` is a torch
# tensor, `+=` presumably dispatches to an in-place tensor op, so lost updates may also
# arise inside that op rather than only between these bytecodes. Confirm.
dis.dis("total += i")

  0           0 RESUME                   0

  1           2 LOAD_NAME                0 (total)
              4 LOAD_NAME                1 (i)
              6 BINARY_OP               13 (+=)
             10 STORE_NAME               0 (total)
             12 LOAD_CONST               0 (None)
             14 RETURN_VALUE


Let's use a lock to protect the critical section, so that only one thread can add to the total at a time

In [14]:
# Fresh counter, plus a lock that must be held for every update to it.
total = torch.tensor(0, dtype=torch.int32)
lock = threading.Lock()

def inc(count):
    """Add 1 to the global ``total``, ``count`` times, holding ``lock`` around each increment.

    Fine-grained locking: the lock is acquired and released once per iteration. Threads
    may interleave between increments but never in the middle of one. The result is
    correct, but the lock cost is paid ``count`` times. Is this too fine-grained? It hurts
    performance.
    """
    global total
    for _ in range(count):
        lock.acquire()
        total += 1  # critical section: only one thread at a time executes this
        lock.release()

# Time with time.perf_counter(): a monotonic, high-resolution clock intended for measuring
# elapsed time. time.time() is wall-clock time and can jump if the system clock is adjusted.
INCREMENTS_PER_THREAD = 1_000_000

start = time.perf_counter()
t1 = threading.Thread(target=inc, args=[INCREMENTS_PER_THREAD])
t2 = threading.Thread(target=inc, args=[INCREMENTS_PER_THREAD])
t1.start()
t2.start()
t1.join()
t2.join()
end = time.perf_counter()
print(f"Correct total after protecting critical section with locks: {total}, took {end - start} seconds.") # 1_000_000 + 1_000_000 = 2_000_000
# Every increment is guarded by the lock, so no update may be lost.
assert int(total) == 2 * INCREMENTS_PER_THREAD, f"lost updates: {int(total)}"

Correct total after protecting critical section with locks: 2000000, took 36.96721053123474 seconds.


We can make the locking more coarse-grained: acquire the lock once around the whole loop instead of once per increment

In [15]:
# Fresh counter, plus a lock that each thread holds for its entire loop.
total = torch.tensor(0, dtype=torch.int32)
lock = threading.Lock()

def inc(count):
    """Add 1 to the global ``total``, ``count`` times, holding ``lock`` for the whole loop.

    Coarse-grained locking: one acquire/release per call instead of one per increment,
    which makes it run faster. Because the entire loop is inside the critical section,
    the threads' loops never overlap; they effectively run one after the other.
    """
    global total
    lock.acquire()
    for _ in range(count):
        total += 1
    lock.release()

# Time with time.perf_counter(): a monotonic, high-resolution clock intended for measuring
# elapsed time. time.time() is wall-clock time and can jump if the system clock is adjusted.
INCREMENTS_PER_THREAD = 1_000_000

start = time.perf_counter()
t1 = threading.Thread(target=inc, args=[INCREMENTS_PER_THREAD])
t2 = threading.Thread(target=inc, args=[INCREMENTS_PER_THREAD])
t1.start()
t2.start()
t1.join()
t2.join()
end = time.perf_counter()
print(f"Correct total after protecting critical section with locks: {total}, took {end - start} seconds.") # 1_000_000 + 1_000_000 = 2_000_000
# Each thread's whole loop runs under the lock, so no update may be lost.
assert int(total) == 2 * INCREMENTS_PER_THREAD, f"lost updates: {int(total)}"

Correct total after protecting critical section with locks: 2000000, took 8.609662532806396 seconds.


Bank Account Example (Critical Sections; Deadlock)

In [22]:
bank_accounts = {"x": 25, "y": 100, "z": 200} # in dollars
lock = threading.Lock() # protects bank_accounts

def transfer(src, dst, amount):
    """Move ``amount`` from ``src`` to ``dst`` if ``src`` can cover it, printing the outcome.

    The lock is taken with explicit acquire()/release() and no try/finally. If anything
    between them raises (e.g. a KeyError for an unknown account), release() is never
    reached and the lock stays held.
    """
    lock.acquire()
    enough = bank_accounts[src] >= amount
    if enough:
        bank_accounts[src] -= amount
        bank_accounts[dst] += amount
    print("transferred" if enough else "denied")
    lock.release()

#transfer("w", "x", 10) # raises KeyError because "w" is not a key in bank_accounts!

# The lock is acquired before the KeyError is raised, and release() is never reached.
# The lock therefore stays held forever: every later transfer() blocks in lock.acquire(),
# so no other thread can make progress --> DEADLOCK.
# An unknown dst is even worse: src has already been debited when the KeyError is raised.

Using `with lock:` so Python releases the lock automatically, even when an exception is raised

In [None]:
bank_accounts = {"x": 25, "y": 100, "z": 200} # in dollars
lock = threading.Lock() # protects bank_accounts

def transfer(src, dst, amount):
    """Move ``amount`` from ``src`` to ``dst`` if ``src`` can cover it, printing the outcome.

    Prints "transferred" or "denied" and returns None.

    ``with lock`` releases the lock however the block is left, including via an exception,
    so a bad account name no longer leaves the lock held. The exception itself is NOT
    swallowed:

    Raises:
        KeyError: if ``src`` is unknown, or if ``dst`` is unknown on a transfer that would
            otherwise be approved. In both cases no balance is changed.
    """
    with lock: # acquires on entry and releases on exit, even if an exception escapes
        # lock.acquire() <-- not needed anymore: entering the with-block acquires the lock
        success = False
        if bank_accounts[src] >= amount: # KeyError for an unknown src, before any change
            # Validate dst before debiting src; otherwise a KeyError from
            # `bank_accounts[dst] += amount` would leave src debited and the money lost.
            if dst not in bank_accounts:
                raise KeyError(dst)
            bank_accounts[src] -= amount
            bank_accounts[dst] += amount
            success = True
        print("transferred" if success else "denied")
        # lock.release() <-- not needed anymore: leaving the with-block releases the lock