Hey,

I'm trying to get NetKet to work with multiple GPUs and I'm running into some issues. All my code runs fine with MPI for multiple CPUs. However, I cannot get multiple GPUs to work.

Versions:
Python 3.9.6
CUDA=11.4
NetKet==3.8
jax==0.4.9
jaxlib==0.4.7+cuda11.cudnn82

I'm on a cluster, with 2 GPUs and 2 tasks. My MPI distribution is CUDA-enabled:

ompi_info --parsable --all | grep mpi_built_with_cuda_support:value
mca:mpi:base:param:mpi_built_with_cuda_support:value:true

Running

mpirun -n 2 python -m netket.tools.check_mpi

gives the expected output.
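(For reference, a minimal hand-rolled version of that per-rank check — a sketch using only mpi4py and jax, not NetKet's tool — which should show each rank pinned to exactly one GPU:)

# minimal per-rank sanity check (a sketch, not netket.tools.check_mpi itself)
# run with: mpirun -n 2 python check.py
import os
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()
size = MPI.COMM_WORLD.Get_size()

# pin each rank to its own GPU *before* jax initializes CUDA
os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"

import jax

# each rank should report exactly one GpuDevice
print(f"rank {rank}/{size} -> {jax.devices()}")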
I tested mpi4jax with XLA_PYTHON_CLIENT_PREALLOCATE=false MPI4JAX_DEBUG=1 mpirun -n 2 python test.py and this code:
test.py
# importing automatically initializes MPI
import os
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# set only one visible device
os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"
# force to use gpu
os.environ["JAX_PLATFORM_NAME"] = "gpu"

print('Hello from process %d of %d' % (rank, size))

import jax
import jax.numpy as jnp
import mpi4jax


@jax.jit
def foo(arr):
    arr = arr + rank
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum


a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
This works as intended (rank 0 adds 0 and rank 1 adds 1 to a 3x3 zero matrix, so the allreduce sum is a matrix of ones):
Hello from process 0 of 2
Hello from process 1 of 2
r0 | ur5tvJRI | MPI_Allreduce with 9 items
r1 | g4gEXose | MPI_Allreduce with 9 items
r0 | ur5tvJRI | MPI_Allreduce done with code 0 (1.45e-01s)
r1 | g4gEXose | MPI_Allreduce done with code 0 (1.17e-03s)
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
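(The discarded second return value of allreduce is mpi4jax's token; as far as I understand, it is how mpi4jax keeps MPI calls ordered inside jit-compiled functions. A sketch of what chaining two reductions would look like, assuming mpi4jax's documented token keyword:)

# sketch: threading mpi4jax's token through consecutive MPI calls under jit,
# so XLA cannot reorder them (assumes the documented `token` keyword)
import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD


@jax.jit
def two_sums(x):
    # the first reduction returns a token...
    s1, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    # ...which the second reduction consumes, fixing their order under jit
    s2, _ = mpi4jax.allreduce(x * x, op=MPI.SUM, comm=comm, token=token)
    return s1, s2


print(two_sums(jnp.ones(3)))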
Note: without XLA_PYTHON_CLIENT_PREALLOCATE=false, I get:
r0 | VwUSW798 | MPI_Allreduce with 9 items
cudaStreamSynchronize failed with the following error:
Error 2 cudaErrorMemoryAllocation: out of memory
--------------------------------------------------------------------------
MPI_ABORT was invoked on rank 1 in communicator MPI_COMM_WORLD
with errorcode 0.
NOTE: invoking MPI_ABORT causes Open MPI to kill all MPI processes.
You may or may not see output from other processes, depending on
exactly when Open MPI kills them.
--------------------------------------------------------------------------
[warn] Epoll MOD(1) on fd 23 failed. Old events were 6; read change was 0 (none); write change was 2 (del); close change was 0 (none): Bad file descriptor
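(My working assumption about this OOM: by default JAX preallocates most of each GPU's memory, which can leave the CUDA-aware Open MPI transport unable to allocate its own device buffers. Besides passing the flag on the command line, the same thing can be set in the script, sketched here:)

# sketch: same effect as the command-line flag, set before jax is imported;
# XLA_PYTHON_CLIENT_MEM_FRACTION is an alternative that caps rather than
# disables preallocation (my suggestion, not from the runs above)
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# or, alternatively:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax  # must come after the environment variables are set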
Now, as soon as I bring NetKet into the mix, things break. I currently have the MWE:
test_imports.py
import os

try:
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
    # set only one visible device
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"
    # force to use gpu
    os.environ["JAX_PLATFORM_NAME"] = "gpu"
except ModuleNotFoundError:
    print("MPI disabled")
    rank = 0

import jax
import netket as nk

print(f"{rank} -> {jax.devices()}")
print(jax.devices())

L = 4
alpha = 3
bias = True
nsamples = 100

### 1.) VARIATIONAL WAVE FUNCTION ###
hilbert = nk.hilbert.Spin(0.5, L)
model = nk.models.RBM(param_dtype=complex, alpha=alpha, use_hidden_bias=bias)
sampler = nk.sampler.MetropolisLocal(hilbert, n_chains=10)
phi = nk.vqs.MCState(sampler=sampler, model=model, n_samples=nsamples)
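(Aside, to help localize the failure: my understanding is that NetKet's first real MPI traffic happens when an expectation value is sampled and averaged across ranks, so appending something like the following hypothetical lines to the MWE would separate import-time crashes from communication crashes.)

# sketch: appended to test_imports.py after the MCState is built;
# sigmax on site 0 is an arbitrary, hypothetical observable for illustration
op = nk.operator.spin.sigmax(hilbert, 0)
# .expect() draws samples and averages statistics over MPI ranks, so this
# should be the first place actual cross-GPU communication happens
print(phi.expect(op))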
Executing XLA_PYTHON_CLIENT_PREALLOCATE=false MPI4JAX_DEBUG=1 mpirun -n 2 python test_imports.py then gives the stacktrace:
I've tried a bunch of things, but I'm running out of ideas. Do you have any suggestions on what to look for here?
Best,
Roeland