Skip to content

Commit

Permalink
Add some 'wrapped' MPI operations that are needed to implement MinSR …
Browse files Browse the repository at this point in the history
…efficiently. (#1621)
  • Loading branch information
PhilipVinc committed Oct 23, 2023
1 parent fac0635 commit 5d8bad1
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 3 deletions.
4 changes: 4 additions & 0 deletions netket/utils/mpi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,8 @@
mpi_max_jax,
mpi_mean_jax,
mpi_sum_jax,
mpi_gather_jax,
mpi_alltoall_jax,
mpi_reduce_sum_jax,
mpi_scatter_jax,
)
45 changes: 42 additions & 3 deletions netket/utils/mpi/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import jax.numpy as jnp

from .mpi import n_nodes, MPI, MPI_py_comm, MPI_jax_comm

Expand Down Expand Up @@ -270,17 +271,55 @@ def mpi_bcast_jax(x, *, token=None, root, comm=MPI_jax_comm):
return mpi4jax.bcast(x, token=token, root=root, comm=comm)


# Gather `x` from every process through `comm.allgather` when more than one
# process is in use (`n_nodes != 1`); with a single process no communication
# is performed.
# NOTE(review): this span shows two `def` lines and two `return` lines back to
# back — presumably the before/after lines of a diff rendered without +/-
# markers. Confirm which version is current before relying on it.
# NOTE(review): in the `token=None` variant the single-process branch returns
# the tuple `(x, token)` while the multi-process branch returns only
# `comm.allgather(x)`; the two branches disagree on the return shape — confirm
# this is intended.
def mpi_allgather(x, *, comm=MPI_py_comm):
def mpi_allgather(x, *, token=None, comm=MPI_py_comm):
if n_nodes == 1:
return x
return x, token
else:
return comm.allgather(x)


def mpi_gather_jax(x, *, token=None, root: int = 0, comm=MPI_jax_comm):
    """Gather ``x`` from all processes onto ``root`` using mpi4jax.

    Args:
        x: The local array to gather.
        token: Optional mpi4jax token, forwarded to ``mpi4jax.gather``.
        root: Rank of the process that receives the gathered data.
        comm: Communicator to use; forwarded to ``mpi4jax.gather``.

    Returns:
        A ``(result, token)`` pair. When ``n_nodes == 1`` no communication
        happens and ``result`` is ``x`` with a new leading axis of length 1.
        Otherwise the return value of ``mpi4jax.gather`` is passed through.
    """
    if n_nodes == 1:
        # Mimic the stacked layout of a gather over a single process.
        return jnp.expand_dims(x, 0), token

    import mpi4jax

    # `comm` must be forwarded, otherwise a caller-supplied communicator is
    # silently replaced by mpi4jax's default one.
    return mpi4jax.gather(x, token=token, root=root, comm=comm)


# Gather `x` from every process onto every process via `mpi4jax.allgather`
# when more than one process is in use (`n_nodes != 1`); `token` and `comm`
# are forwarded to that call.
# NOTE(review): the single-process branch lists two consecutive `return`
# lines — presumably the before/after lines of a diff rendered without +/-
# markers. The `jnp.expand_dims(x, 0)` variant adds a leading axis of length 1
# to `x`; the other returns `x` unchanged. Confirm which one is current.
def mpi_allgather_jax(x, *, token=None, comm=MPI_jax_comm):
if n_nodes == 1:
return x, token
return jnp.expand_dims(x, 0), token
else:
# Imported lazily, only on the multi-process path.
import mpi4jax

return mpi4jax.allgather(x, token=token, comm=comm)


def mpi_scatter_jax(x, *, token=None, root: int = 0, comm=MPI_jax_comm):
    """Scatter slices of ``x`` from ``root`` to all processes using mpi4jax.

    Args:
        x: The array to scatter.
        token: Optional mpi4jax token, forwarded to ``mpi4jax.scatter``.
        root: Rank of the process that provides the data.
        comm: Communicator to use; forwarded to ``mpi4jax.scatter``.

    Returns:
        A ``(result, token)`` pair. When ``n_nodes == 1`` no communication
        happens and ``result`` is ``x[0]``. Otherwise the return value of
        ``mpi4jax.scatter`` is passed through.

    Raises:
        ValueError: If ``n_nodes == 1`` and the leading dimension of ``x``
            is not 1.
    """
    if n_nodes == 1:
        # With a single process the leading (per-process) axis must be 1.
        if x.shape[0] != 1:
            raise ValueError("Scatter input must have shape (nproc, ...)")
        return x[0], token

    import mpi4jax

    # `comm` must be forwarded, otherwise a caller-supplied communicator is
    # silently replaced by mpi4jax's default one.
    return mpi4jax.scatter(x, root=root, token=token, comm=comm)


def mpi_alltoall_jax(x, *, token=None, comm=MPI_jax_comm):
    """Perform an all-to-all exchange of ``x`` using mpi4jax.

    Args:
        x: The local array to exchange.
        token: Optional mpi4jax token, forwarded to ``mpi4jax.alltoall``.
        comm: Communicator to use; forwarded to ``mpi4jax.alltoall``.

    Returns:
        A ``(result, token)`` pair. When ``n_nodes == 1`` no communication
        happens and ``result`` is ``x`` itself. Otherwise the return value of
        ``mpi4jax.alltoall`` is passed through.
    """
    if n_nodes == 1:
        return x, token

    import mpi4jax

    # `comm` must be forwarded, otherwise a caller-supplied communicator is
    # silently replaced by mpi4jax's default one.
    return mpi4jax.alltoall(x, token=token, comm=comm)


def mpi_reduce_sum_jax(x, *, token=None, root: int = 0, comm=MPI_jax_comm):
    """Sum-reduce ``x`` across processes onto ``root`` using mpi4jax.

    Args:
        x: The local array to reduce.
        token: Optional mpi4jax token, forwarded to ``mpi4jax.reduce``.
        root: Rank of the process that receives the reduced result.
        comm: Communicator to use; forwarded to ``mpi4jax.reduce``.

    Returns:
        A ``(result, token)`` pair. When ``n_nodes == 1`` no communication
        happens and ``result`` is ``x`` itself. Otherwise the return value of
        ``mpi4jax.reduce`` with ``op=MPI.SUM`` is passed through.
    """
    if n_nodes == 1:
        return x, token

    import mpi4jax

    # `comm` must be forwarded, otherwise a caller-supplied communicator is
    # silently replaced by mpi4jax's default one.
    return mpi4jax.reduce(x, op=MPI.SUM, root=root, token=token, comm=comm)
74 changes: 74 additions & 0 deletions test/mpi/test_mpi_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pytest

import jax
import jax.numpy as jnp

from netket.utils import mpi


def approx(data):
    """Wrap ``data`` in ``pytest.approx`` with this module's tolerances."""
    tolerances = {"abs": 1.0e-6, "rel": 1.0e-5}
    return pytest.approx(data, **tolerances)


def test_gather_jax():
    """Check ``mpi_gather_jax`` under ``jax.jit`` with ``root=0``.

    Every rank contributes a (3, 2) array filled with its own rank. Rank 0
    must see slice ``p`` filled with ``p``; every other rank is expected to
    get its own input back.
    """
    my_rank = mpi.rank
    n_procs = mpi.n_nodes

    local = jnp.ones((3, 2)) * my_rank

    def _gather_to_root(x):
        # Keep only the result, dropping the token.
        return mpi.mpi_gather_jax(x, root=0)[0]

    out = jax.jit(_gather_to_root)(local)

    if my_rank != 0:
        np.testing.assert_allclose(out, local)
        return

    for p in range(n_procs):
        np.testing.assert_allclose(out[p], jnp.ones((3, 2)) * p)


def test_allgather_jax():
    """Check ``mpi_allgather_jax`` under ``jax.jit``.

    Every rank contributes a (3, 2) array filled with its own rank and must
    see slice ``p`` of the result filled with ``p``.
    """
    my_rank = mpi.rank
    n_procs = mpi.n_nodes

    local = jnp.ones((3, 2)) * my_rank

    def _allgather(x):
        # Keep only the result, dropping the token.
        return mpi.mpi_allgather_jax(x)[0]

    gathered = jax.jit(_allgather)(local)

    for p in range(n_procs):
        expected = jnp.ones((3, 2)) * p
        np.testing.assert_allclose(gathered[p], expected)


def test_scatter_jax():
    """Check ``mpi_scatter_jax`` under ``jax.jit`` with ``root=0``.

    Rank 0 provides a stack of (3, 2) arrays, the ``r``-th one filled with
    ``r``; other ranks pass a (3, 2) array filled with their rank. Each rank
    must end up with a (3, 2) array filled with its own rank.
    """
    my_rank = mpi.rank
    n_procs = mpi.n_nodes

    # Non-root input; replaced by the full stack on the root.
    payload = jnp.ones((3, 2)) * my_rank
    if my_rank == 0:
        payload = jnp.stack(
            [jnp.ones((3, 2)) * r for r in range(n_procs)], axis=0
        )

    def _scatter_from_root(x):
        # Keep only the result, dropping the token.
        return mpi.mpi_scatter_jax(x, root=0)[0]

    received = jax.jit(_scatter_from_root)(payload)

    np.testing.assert_allclose(received, jnp.ones((3, 2)) * my_rank)


def test_alltoall_jax():
    """Check ``mpi_alltoall_jax`` under ``jax.jit``.

    Every rank sends a (size, 3, 2) array filled with its own rank and must
    see slice ``p`` of the result filled with ``p``.
    """
    my_rank = mpi.rank
    n_procs = mpi.n_nodes

    outgoing = jnp.ones((n_procs, 3, 2)) * my_rank

    def _alltoall(x):
        # Keep only the result, dropping the token.
        return mpi.mpi_alltoall_jax(x)[0]

    incoming = jax.jit(_alltoall)(outgoing)

    for p in range(n_procs):
        expected = jnp.ones((3, 2)) * p
        np.testing.assert_allclose(incoming[p], expected)


def test_reduce_jax():
    """Check ``mpi_reduce_sum_jax`` under ``jax.jit`` with ``root=0``.

    Every rank contributes a (3, 2) array filled with its own rank. Rank 0
    must see the sum of all ranks; every other rank is expected to get its
    own input back.
    """
    my_rank = mpi.rank
    n_procs = mpi.n_nodes

    local = jnp.ones((3, 2)) * my_rank

    def _reduce_to_root(x):
        # Keep only the result, dropping the token.
        return mpi.mpi_reduce_sum_jax(x, root=0)[0]

    out = jax.jit(_reduce_to_root)(local)

    if my_rank != 0:
        np.testing.assert_allclose(out, local)
        return

    rank_total = sum(range(n_procs))
    np.testing.assert_allclose(out, jnp.ones((3, 2)) * rank_total)

0 comments on commit 5d8bad1

Please sign in to comment.