# JAX pmap issue

References:

* [Pmap Cookbook via Google Colab](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=4DYY4Yyhq8vG)

Prepare:

* Add the NCCL library directory (`$NCCL_HOME/lib`, e.g. `/usr/local/nccl/lib`) to `LD_LIBRARY_PATH` before launching the notebook:

```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NCCL_HOME/lib
```

In [1]:
import os

# Sanity check for the preparation step: show the library search path this
# kernel was started with. The variable is read with .get() so that a missing
# LD_LIBRARY_PATH produces a readable message instead of a bare KeyError.
ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
print(ld_library_path or 'LD_LIBRARY_PATH is not set')
# Left disabled: the preparation note asks for the path to be exported in the
# shell before the notebook is started, not patched from inside the kernel.
# os.environ['LD_LIBRARY_PATH'] += f":{os.environ['NCCL_HOME']}/lib"

/usr/local/cuda/lib:/Developer/NVIDIA/CUDA-10.1/lib:/usr/local/cuda/extras/CUPTI/lib:/usr/local/opt/boost-python3/lib:/usr/local/opt/open-mpi/lib:/usr/local/Cellar/libomp/10.0.0/lib:/usr/local/Cellar/rdkit20210304/lib:/Users/llv23/opt/miniconda3/lib:/usr/local/lib:/usr/local/nccl/lib


In [2]:
import jax

# Number of devices attached to this process; the cells below use it to size
# the leading (mapped) axis of every pmap input.
n_devices = len(jax.local_devices())

In [3]:
# Display the devices JAX reports for the "gpu" backend.
jax.devices("gpu")

[GpuDevice(id=0), GpuDevice(id=1)]

In [4]:
# Display the devices JAX reports for the "cpu" backend.
jax.devices("cpu")

[CpuDevice(id=0)]

In [5]:
from jax import random, pmap
import jax.numpy as jnp


def _draw_matrix(key):
    """Sample one 5000 x 6000 standard-normal matrix from a single PRNG key."""
    return random.normal(key, (5000, 6000))


def _self_product(x):
    """Multiply a matrix by its own transpose."""
    return jnp.dot(x, x.T)


# Create n_devices random 5000 x 6000 matrices, one per device
keys = random.split(random.PRNGKey(0), n_devices)
mats = pmap(_draw_matrix)(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(_self_product)(mats)  # result.shape is (n_devices, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# with 2 devices this prints [1.1608742 1.230151 ]

[1.1608742 1.230151 ]


In [6]:
from jax import random, pmap
import jax.numpy as jnp

from functools import partial
from jax import lax


def normalize(x):
  """Divide this device's value by the sum of the values across axis 'i'."""
  total = lax.psum(x, 'i')
  return x / total


# Map over the leading axis and name it 'i' so the psum above can refer to it.
normalize = pmap(normalize, axis_name='i')

print(normalize(jnp.arange(float(n_devices))))
# with 2 devices this prints [0. 1.]

[0. 1.]


In [7]:
from functools import partial
from jax import grad
import jax.numpy as jnp
from jax import random, pmap

# axis_size follows n_devices instead of a hard-coded 2, so it always agrees
# with the device list and with the leading axis of the input built below.
@partial(pmap, axis_size=n_devices, devices=jax.devices()[:n_devices])
def f(x):
    """Per-device gradient of sum(g(w)) with respect to w, evaluated at w = x.

    g closes over this device's x and over y = sin(x).
    """
    y = jnp.sin(x)
    def g(z):
        return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
    return grad(lambda w: jnp.sum(g(w)))(x)

# One row per device.
x = jnp.ones((n_devices, 1))
print(f(x))
# with 2 devices:
# [[-0.7170831]
#  [-0.7170831]]

# Differentiate through the pmapped function itself.
print(grad(lambda x: jnp.sum(f(x)))(x))
# with 2 devices:
# [[-1.6356444]
#  [-1.6356444]]

[[-0.7170831]
 [-0.7170831]]
[[-1.6356444]
 [-1.6356444]]


In [8]:
import jax
n_devices = jax.local_device_count()

from functools import partial
from jax import grad
import jax.numpy as jnp
from jax import random, pmap


def f(x):
    """Gradient of sum(g(w)) with respect to w at w = x, where g is itself pmapped.

    g closes over the x seen by the outer pmap and over y = sin(x).
    """
    y = jnp.sin(x)

    def g(z):
        return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()

    # Inner pmap: maps g over the leading axis of its argument.
    mapped_g = pmap(g)

    def _total(w):
        return jnp.sum(mapped_g(w))

    return grad(_total)(x)


# Outer pmap: maps f over the leading axis of its input.
f = pmap(f)

# One row per device.
x = jnp.ones((n_devices, 1))
print(f(x))
# with 2 devices:
# [[-0.7170831]
#  [-0.7170831]]


def _summed_f(v):
    return jnp.sum(f(v))


# Differentiate through the nested pmap.
print(grad(_summed_f)(x))
# with 2 devices:
# [[-1.6356444]
#  [-1.6356444]]

[[-0.7170831]
 [-0.7170831]]
[[-1.6356444]
 [-1.6356444]]
