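### Python Parallel Processing Experiments

Compares the wall-clock time of running the same CPU-bound JAX workload `f` once per CPU core using three approaches: a process pool, a thread pool, and asyncio.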

In [1]:
import asyncio
import concurrent.futures
import multiprocessing
import os
import time

import jax
import jax.numpy as jnp

# Start method for ProcessPoolExecutor workers. `fork` is presumably chosen so
# workers inherit `f`, which lives in the notebook's __main__ and cannot be
# re-imported by a `spawn` child — confirm.
# `force=True` keeps this cell re-runnable: without it, a second call raises
# RuntimeError("context has already been set").
# NOTE(review): forking after JAX has started its own threads can deadlock;
# this is presumably safe only while no JAX computation has run in this kernel
# before the process pool forks — confirm.
multiprocessing.set_start_method('fork', force=True)

# Run every JAX computation on the CPU backend.
jax.config.update("jax_platform_name", "cpu")

# os.cpu_count() returns None when the count cannot be determined; fall back to
# 1 so `range(num_cpus)` / `max_workers=num_cpus` in later cells stay valid.
num_cpus = os.cpu_count() or 1
print(f"Number of CPUs: {num_cpus}")

def f(n, n_iters=10, size=6000):
    """CPU-bound JAX workload used as the benchmark task.

    Seeds a PRNG key with ``n``. Then, ``n_iters`` times, it draws a fresh
    ``size`` x ``size`` standard-normal matrix and multiplies it by itself.

    Args:
        n: integer PRNG seed; the timing cells pass each worker a distinct value.
        n_iters: number of matrix products to compute (default 10, as before).
        size: side length of the square matrices (default 6000, as before).

    Returns:
        True, once every matrix product has actually finished computing.
    """
    rng = jax.random.key(n)
    for _ in range(n_iters):
        rng, sub_rng = jax.random.split(rng)
        A = jax.random.normal(key=sub_rng, shape=(size, size))
        # JAX dispatches work asynchronously, so `A @ A` returns before the
        # product is computed. Block here so the timing cells measure the
        # matmul itself rather than only its dispatch, and so no work is
        # still in flight after `f` returns.
        (A @ A).block_until_ready()
    return True



Number of CPUs: 12


### Spawn multiple processes


In [2]:
# Run `f` once per CPU, each call in its own OS process.
# NOTE(review): each worker process may also start its own multi-threaded XLA
# CPU runtime, so `num_cpus` workers could oversubscribe the cores — keep in
# mind when comparing timings.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus) as executor:
    results = list(executor.map(f, range(num_cpus)))
assert all(results), "a worker did not report success"
print(f"parallel process time: {time.perf_counter()-start:0.4}")

parallel process time: 21.07


### Spawn multiple threads

In [2]:
# Run `f` once per CPU, each call on its own thread inside this one process.
# Whether the threads overlap depends on whether the work inside `f` releases
# the GIL (presumably JAX does while XLA computes — confirm).
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpus) as executor:
    results = list(executor.map(f, range(num_cpus)))
assert all(results), "a worker did not report success"
print(f"parallel thread time: {time.perf_counter()-start:0.4}")

parallel thread time: 82.89


### Asyncio

In [2]:
start = time.time()
tasks = [f(n) for n in range(num_cpus)]
result = await asyncio.gather(*tasks)
print(f"asyncio time: {time.time()-start:0.4}")

asyncio time: 74.26
