This issue was moved to a discussion.

Problems with multi-GPU MPI setup #1607

Closed
therooler opened this issue Oct 12, 2023 · 0 comments

Hey,

I'm trying to get NetKet to work with multiple GPUs and I'm running into some issues. All my code runs fine with MPI on multiple CPUs, but I cannot get it to work on multiple GPUs.

Versions:

Python 3.9.6
CUDA=11.4
NetKet==3.8
jax==0.4.9
jaxlib==0.4.7+cuda11.cudnn82
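Because every rank needs a consistent stack, it can help to collect these versions programmatically, for example to spot that jax (0.4.9) and jaxlib (0.4.7) are at different versions here. The following is a minimal sketch using only the standard library; the package list is an assumption chosen to match the versions listed above.

```python
# Collect the installed versions of the packages relevant to this report.
# Uses only the standard library; packages that are not installed are
# reported as None instead of raising.
import sys
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages):
    """Return a dict mapping each distribution name to its installed version (or None)."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("python  ", sys.version.split()[0])
    for pkg, ver in collect_versions(["netket", "jax", "jaxlib", "mpi4py", "mpi4jax"]).items():
        print(f"{pkg:8s} {ver}")
```

Running it under mpirun prints one block per rank, which makes mismatched environments between ranks easy to see.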

I'm on a cluster node with 2 GPUs, running 2 MPI tasks. My MPI distribution is CUDA-enabled:

ompi_info --parsable --all | grep mpi_built_with_cuda_support:value
mca:mpi:base:param:mpi_built_with_cuda_support:value:true
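The ompi_info check above can be scripted, which is handy in job scripts. A minimal sketch, assuming Open MPI's ompi_info is on PATH and prints the parsable key shown above. Note that this key reflects how Open MPI was built; whether GPU buffers actually take the CUDA-aware path at runtime also depends on the transport layer in use (UCX in this trace).

```python
# Check whether Open MPI reports CUDA support, based on the output of
#   ompi_info --parsable --all
# The line of interest looks like:
#   mca:mpi:base:param:mpi_built_with_cuda_support:value:true
import shutil
import subprocess

KEY = "mpi_built_with_cuda_support:value:"


def cuda_support_from_ompi_info(text):
    """Return True/False from ompi_info --parsable output, or None if the key is absent."""
    for line in text.splitlines():
        if KEY in line:
            value = line.rsplit(":", 1)[-1].strip().lower()
            return value == "true"
    return None


def query_ompi_cuda_support():
    """Run ompi_info if it is on PATH; return None when it is not available."""
    if shutil.which("ompi_info") is None:
        return None
    out = subprocess.run(
        ["ompi_info", "--parsable", "--all"],
        capture_output=True, text=True, check=False,
    ).stdout
    return cuda_support_from_ompi_info(out)
```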

Running mpirun -n 2 python -m netket.tools.check_mpi gives:

mpi4py_available             : True
mpi4jax_available            : True
available_cpus (rank 0)      : 1
n_nodes                      : 2
mpi4py | MPI version         : (3, 1)
mpi4py | MPI library_version : Open MPI v4.0.3, package: Open MPI ebuser@build-node.computecanada.ca Distribution, ident: 4.0.3, repo rev: v4.0.3, Mar 03, 2020

I tested mpi4jax on its own by running XLA_PYTHON_CLIENT_PREALLOCATE=false MPI4JAX_DEBUG=1 mpirun -n 2 python test.py with the following script:

test.py

# importing mpi4py automatically initializes MPI
import os
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
# set only one visible device
os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"

# force to use gpu
os.environ["JAX_PLATFORM_NAME"] = "gpu"

print('Hello from process %d of %d' % (rank, size))
import jax
import jax.numpy as jnp
import mpi4jax


@jax.jit
def foo(arr):
    arr = arr + rank
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum


a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)

This works as intended:

Hello from process 0 of 2
Hello from process 1 of 2
r0 | ur5tvJRI | MPI_Allreduce with 9 items
r1 | g4gEXose | MPI_Allreduce with 9 items
r0 | ur5tvJRI | MPI_Allreduce done with code 0 (1.45e-01s)
r1 | g4gEXose | MPI_Allreduce done with code 0 (1.17e-03s)
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
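The all-ones result is what the script should produce: each rank adds its own rank to a 3x3 zero array, and MPI_SUM then adds the per-rank arrays, so every entry equals 0 + 1 + ... + (size - 1) = size(size - 1)/2, which is 1 for two ranks. A small pure-Python emulation (no MPI involved, written only to illustrate the arithmetic) is:

```python
# Emulate, without MPI, what foo() in test.py computes:
# every rank adds its own rank to a 3x3 zero array, then the arrays are
# summed elementwise across ranks (the effect of allreduce with MPI.SUM).
def emulate_allreduce_sum(size, shape=(3, 3)):
    rows, cols = shape
    per_rank = [[[float(rank)] * cols for _ in range(rows)] for rank in range(size)]
    return [[sum(arr[i][j] for arr in per_rank) for j in range(cols)] for i in range(rows)]


# Two ranks: every entry is 0 + 1 = 1, matching the printed output.
print(emulate_allreduce_sum(2))
```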

Note: without setting XLA_PYTHON_CLIENT_PREALLOCATE=false, I get:

r0 | VwUSW798 | MPI_Allreduce with 9 items
cudaStreamSynchronize failed with the following error:
        Error 2 cudaErrorMemoryAllocation: out of memory
--------------------------------------------------------------------------
MPI_ABORT was invoked on rank 1 in communicator MPI_COMM_WORLD
with errorcode 0.

NOTE: invoking MPI_ABORT causes Open MPI to kill all MPI processes.
You may or may not see output from other processes, depending on
exactly when Open MPI kills them.
--------------------------------------------------------------------------
[warn] Epoll MOD(1) on fd 23 failed. Old events were 6; read change was 0 (none); write change was 2 (del); close change was 0 (none): Bad file descriptor
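JAX's GPU memory documentation lists running several JAX processes concurrently as a common cause of out-of-memory errors, and suggests either disabling preallocation or lowering the preallocated fraction per process. Instead of prefixing every mpirun call, the same variables can be set from inside the script. A minimal sketch, assuming the defaults below suit this setup; the variables are read when JAX initialises its GPU backend, so setting them before `import jax` is the simplest way to make sure they apply.

```python
# Configure JAX's GPU memory allocator from inside the script instead of on
# the mpirun command line. Values already present in the environment win,
# so a command-line override still takes precedence.
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set XLA_PYTHON_CLIENT_* variables and return their resulting values."""
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "true" if preallocate else "false")
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", f"{mem_fraction:.2f}")
    return {
        key: os.environ.get(key)
        for key in ("XLA_PYTHON_CLIENT_PREALLOCATE", "XLA_PYTHON_CLIENT_MEM_FRACTION")
    }


configure_jax_gpu_memory(preallocate=False)
# import jax  # only after the call above
```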

However, as soon as I bring NetKet into the mix, things break. Here is my current MWE:

test_imports.py

import os

try:
    from mpi4py import MPI

    rank = MPI.COMM_WORLD.Get_rank()

    # set only one visible device
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"
    # force to use gpu
    os.environ["JAX_PLATFORM_NAME"] = "gpu"

except ModuleNotFoundError:
    print("MPI disabled")
    rank = 0
import jax
import netket as nk

print(f"{rank} -> {jax.devices()}")
print(jax.devices())

L = 4
alpha = 3
bias = True
nsamples = 100
### 1.) VARIATIONAL WAVE FUNCTION ###
hilbert = nk.hilbert.Spin(0.5, L)
model = nk.models.RBM(param_dtype=complex, alpha=alpha, use_hidden_bias=bias)
sampler = nk.sampler.MetropolisLocal(hilbert, n_chains=10)
phi = nk.vqs.MCState(sampler=sampler, model=model, n_samples=nsamples)

Running XLA_PYTHON_CLIENT_PREALLOCATE=false MPI4JAX_DEBUG=1 mpirun -n 2 python test_imports.py then gives the following stack trace:

0 -> [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]
1 -> [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]
r1 | IVZR2vbQ | MPI_Bcast -> 0 with 2 items
r0 | IZXUL4f9 | MPI_Bcast -> 0 with 2 items
[gra985:9629 :0:9629] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x2b771ac00100)
==== backtrace (tid:   9629) ====
 0 0x0000000000020243 ucs_debug_print_backtrace()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/ucs/debug/debug.c:653
 1 0x00000000000130f0 __funlockfile()  :0
 2 0x0000000000159433 __nss_database_lookup()  /cvmfs/soft.computecanada.ca/gentoo/2020/usr/src/debug/sys-libs/glibc-2.30-r8/glibc-2.30/string/../sysdeps/x86_64/multiarch/memmove-vec-unaligned-erms.S:308
 3 0x0000000000015de9 uct_am_short_fill_data()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/uct/base/uct_iface.h:725
 4 0x0000000000015de9 uct_mm_ep_am_short()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/uct/sm/mm/base/mm_ep.c:333
 5 0x000000000003932b uct_ep_am_short()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/uct/api/uct.h:2424
 6 0x000000000003932b ucp_tag_eager_contig_short()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/ucp/tag/eager_snd.c:125
 7 0x000000000004b1ca ucp_request_try_send()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/ucp/core/ucp_request.inl:171
 8 0x000000000004b1ca ucp_request_send()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/ucp/core/ucp_request.inl:206
 9 0x000000000004b1ca ucp_tag_send_req()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/ucp/tag/tag_send.c:109
10 0x000000000004b1ca ucp_tag_send_nb()  /tmp/ebuser/avx2/UCX/1.8.0/gcccorecuda-2020.1.114/ucx-1.8.0/src/ucp/tag/tag_send.c:224
11 0x000000000000725f mca_pml_ucx_isend()  ???:0
12 0x00000000000ed8f1 ompi_coll_base_bcast_intra_generic()  ???:0
13 0x00000000000eea7f ompi_coll_base_bcast_intra_binomial()  ???:0
14 0x0000000000108d03 ompi_coll_tuned_bcast_intra_dec_fixed()  ???:0
15 0x000000000009f1fa MPI_Bcast()  ???:0
16 0x000000000000b223 __pyx_f_7mpi4jax_4_src_10xla_bridge_14mpi_xla_bridge_mpi_bcast()  /project/6019671/roeland/dwave/experiments/graham/mpi4jax/mpi4jax/_src/xla_bridge/mpi_xla_bridge.c:5260
17 0x000000000000cf19 __pyx_f_7mpi4jax_4_src_10xla_bridge_18mpi_xla_bridge_gpu_mpi_bcast_gpu()  /project/6019671/roeland/dwave/experiments/graham/mpi4jax/mpi4jax/_src/xla_bridge/mpi_xla_bridge_gpu.c:7245
18 0x000000000000cf19 __pyx_f_7mpi4jax_4_src_10xla_bridge_18mpi_xla_bridge_gpu_mpi_bcast_gpu()  /project/6019671/roeland/dwave/experiments/graham/mpi4jax/mpi4jax/_src/xla_bridge/mpi_xla_bridge_gpu.c:7245
19 0x00000000029cab9a xla::runtime::CustomCallHandler<(xla::runtime::CustomCall::RuntimeChecks)1, xla::runtime::CustomCall::FunctionWrapper<&xla::gpu::XlaCustomCallImpl>, xla::runtime::internal::UserData<xla::ServiceExecutableRunOptions const*>, xla::runtime::internal::UserData<xla::DebugOptions const*>, xla::runtime::CustomCall::RemainingArgs, xla::runtime::internal::Attr<std::basic_string_view<char, std::char_traits<char> > >, xla::runtime::internal::Attr<int>, xla::runtime::internal::Attr<std::basic_string_view<char, std::char_traits<char> > > >::call()  custom_call.cc:0
20 0x00000000029cb872 xla::gpu::XlaCustomCall()  custom_call.cc:0
=================================
[gra985:09629] *** Process received signal ***
[gra985:09629] Signal: Segmentation fault (11)
[gra985:09629] Signal code:  (-6)
[gra985:09629] Failing at address: 0x2f425f0000259d
[gra985:09629] [ 0] /cvmfs/soft.computecanada.ca/gentoo/2020/lib64/libpthread.so.0(+0x130f0)[0x2b763b88d0f0]
[gra985:09629] [ 1] /cvmfs/soft.computecanada.ca/gentoo/2020/lib64/libc.so.6(+0x159433)[0x2b763b9fa433]
[gra985:09629] [ 2] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/cuda11.4/ucx/1.8.0/lib/libuct.so.0(uct_mm_ep_am_short+0xc9)[0x2b76434e1de9]
[gra985:09629] [ 3] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/cuda11.4/ucx/1.8.0/lib/libucp.so.0(+0x3932b)[0x2b764347832b]
[gra985:09629] [ 4] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/cuda11.4/ucx/1.8.0/lib/libucp.so.0(ucp_tag_send_nb+0x5da)[0x2b764348a1ca]
[gra985:09629] [ 5] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/intel2020/cuda11.4/openmpi/4.0.3/lib/openmpi/mca_pml_ucx.so(mca_pml_ucx_isend+0xcf)[0x2b76425a325f]
[gra985:09629] [ 6] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/intel2020/cuda11.4/openmpi/4.0.3/lib/libmpi.so.40(ompi_coll_base_bcast_intra_generic+0x181)[0x2b763c9ba8f1]
[gra985:09629] [ 7] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/intel2020/cuda11.4/openmpi/4.0.3/lib/libmpi.so.40(ompi_coll_base_bcast_intra_binomial+0xaf)[0x2b763c9bba7f]
[gra985:09629] [ 8] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/intel2020/cuda11.4/openmpi/4.0.3/lib/libmpi.so.40(ompi_coll_tuned_bcast_intra_dec_fixed+0xf3)[0x2b763c9d5d03]
[gra985:09629] [ 9] /cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/CUDA/intel2020/cuda11.4/openmpi/4.0.3/lib/libmpi.so.40(MPI_Bcast+0x6a)[0x2b763c96c1fa]
[gra985:09629] [10] /project/6019671/roeland/virtenvs/dwave/lib/python3.9/site-packages/mpi4jax-0.3.15.post2-py3.9-linux-x86_64.egg/mpi4jax/_src/xla_bridge/mpi_xla_bridge.cpython-39-x86_64-linux-gnu.so(+0xb223)[0x2b76917c1223]
[gra985:09629] [11] /project/6019671/roeland/virtenvs/dwave/lib/python3.9/site-packages/mpi4jax-0.3.15.post2-py3.9-linux-x86_64.egg/mpi4jax/_src/xla_bridge/mpi_xla_bridge_gpu.cpython-39-x86_64-linux-gnu.so(+0xcf19)[0x2b7691adff19]
[gra985:09629] [12] /project/6019671/roeland/virtenvs/dwave/lib/python3.9/site-packages/jaxlib/xla_extension.so(+0x29cab9a)[0x2b76769cab9a]
[gra985:09629] [13] /project/6019671/roeland/virtenvs/dwave/lib/python3.9/site-packages/jaxlib/xla_extension.so(+0x29cb872)[0x2b76769cb872]
[gra985:09629] [14] [0x2b765fffe63c]
[gra985:09629] *** End of error message ***
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun noticed that process rank 0 with PID 9629 on node gra985 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------
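At the top of this log, both "0 -> [...]" and "1 -> [...]" list two devices (id=0 and id=1), although each rank sets CUDA_VISIBLE_DEVICES to its own rank before importing jax. That suggests the per-rank restriction did not take effect in this run; the report does not establish why. A fail-fast check right after importing jax would surface this before any MPI collective runs. A minimal sketch, with a hypothetical helper name, intended to be called as check_visible_devices(jax.local_devices(), rank):

```python
# Fail fast if a rank can see more (or fewer) GPUs than expected, before any
# collective communication is attempted. Intended usage (requires jax):
#     import jax
#     check_visible_devices(jax.local_devices(), rank, expected=1)
import os


def check_visible_devices(devices, rank, expected=1):
    """Raise RuntimeError unless exactly `expected` devices are visible on this rank."""
    devices = list(devices)
    if len(devices) != expected:
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")
        raise RuntimeError(
            f"rank {rank}: expected {expected} visible device(s), found "
            f"{len(devices)} (CUDA_VISIBLE_DEVICES={visible}): {devices}"
        )
    return devices
```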

I've tried a bunch of things, but I'm running out of ideas. Do you have any suggestions for what to look for here?

Best,
Roeland

netket locked and limited conversation to collaborators Oct 12, 2023
PhilipVinc converted this issue into discussion #1608 Oct 12, 2023
