# JAX pmap issue

References:

* [Pmap Cookbook via Google Colab](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=4DYY4Yyhq8vG)

In [1]:
# Append $NCCL_HOME/lib to LD_LIBRARY_PATH.
# Note: each `!` line runs in its own subshell, so this export is not visible to later cells.
! echo "add $NCCL_HOME/lib to ${LD_LIBRARY_PATH}"
! export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NCCL_HOME/lib

add /usr/local/nccl/lib to /usr/local/cuda/lib:/Developer/NVIDIA/CUDA-10.1/lib:/usr/local/cuda/extras/CUPTI/lib:/usr/local/opt/boost-python3/lib:/usr/local/opt/open-mpi/lib:/usr/local/Cellar/libomp/10.0.0/lib:/usr/local/Cellar/rdkit20210304/lib:/Users/llv23/opt/miniconda3/lib:/usr/local/lib
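
A minimal sketch, assuming the goal is for the notebook kernel itself (not only a `!` subshell) to see the updated search path: modify `os.environ` from Python. The fallback `/usr/local/nccl` is an assumption taken from the echoed output above. Whether a running process uses the new value when loading shared libraries depends on the platform, so it may still be necessary to set the variable before starting Jupyter.

```python
import os

# Fallback path is an assumption based on the echoed output, not a known default.
nccl_home = os.environ.get("NCCL_HOME", "/usr/local/nccl")
nccl_lib = os.path.join(nccl_home, "lib")

# Split the current value, dropping empty entries, and append only if missing.
current = os.environ.get("LD_LIBRARY_PATH", "")
paths = [p for p in current.split(":") if p]
if nccl_lib not in paths:
    paths.append(nccl_lib)

# This updates the kernel's own environment; child processes inherit it.
os.environ["LD_LIBRARY_PATH"] = ":".join(paths)
print(os.environ["LD_LIBRARY_PATH"])
```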


In [2]:
from jax import grad, pmap
import jax.numpy as jnp

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

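# NOTE: the nested pmap maps over both axes of x, so this shape requests
# 5000 * 5000 = 25,000,000 devices and triggers the ValueError shown below.
x = jnp.ones((5000, 5000))
print(f(x))
# The values below have shape (4, 2), so they cannot come from the
# (5000, 5000) input above; they appear to be carried over from the
# example this cell was adapted from.
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# Same caveat: (4, 2) reference values, not output of this run.
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]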

ValueError: compiling computation that requires 25000000 logical devices, but only 2 XLA devices are available (num_replicas=25000000, num_partitions=1)
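
One way around the device-count limit, sketched under assumptions: put only as many slices on the `pmap` axis as there are devices, and vectorize the remaining rows with `vmap`. The names `row_fn` and `f_sharded` and the small shapes are hypothetical. The per-row math is meant to mirror the nested-`pmap` body above, not to be the cookbook's own solution.

```python
import jax
import jax.numpy as jnp
from jax import grad, pmap, vmap

n_dev = jax.local_device_count()  # 2 on the machine in the error message
print(jax.devices())

def row_fn(row):
    # Same expression as the nested-pmap body, written for a single row.
    y = jnp.sin(row)
    g = lambda w: jnp.cos(w) * jnp.tan(y.sum()) * jnp.tanh(row).sum()
    return grad(lambda w: jnp.sum(g(w)))(row)

# pmap maps over the leading axis (one slice per device);
# vmap vectorizes over the rows held by each device.
f_sharded = pmap(vmap(row_fn))

rows_per_dev, cols = 4, 2  # hypothetical small sizes for illustration
x = jnp.ones((n_dev * rows_per_dev, cols))

# Reshape so the leading axis has exactly n_dev entries.
x_sharded = x.reshape((n_dev, rows_per_dev, cols))
out = f_sharded(x_sharded).reshape(x.shape)
print(out.shape)
```

For the (5000, 5000) array, the same reshape needs the row count to be divisible by the number of devices, which holds for 2 devices.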