
add multihost pjit tests
sudhakarsingh27 committed Sep 23, 2022
1 parent a6b24b3 commit 4dd0d85
Showing 1 changed file with 289 additions and 1 deletion.
290 changes: 289 additions & 1 deletion tests/multiprocess_gpu_test.py
@@ -17,16 +17,23 @@
import sys
import threading
import unittest
import functools

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import jax
from jax import experimental
from jax.config import config
from jax._src import distributed
import jax.numpy as jnp
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import global_device_array
from jax.experimental import maps
from jax.experimental import pjit

try:
import portpicker
@@ -40,7 +47,6 @@

config.parse_flags_with_absl()


@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):

@@ -170,6 +176,49 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
if pytest is not None:
pytestmark = pytest.mark.SlurmMultiNodeGpuTest

def sorted_devices(self):
devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
if len(devices) != 16:
raise unittest.SkipTest(
"Test assumes that it runs on 16 devices (2 nodes)")
return devices

def create_2d_non_contiguous_mesh(self):
devices = self.sorted_devices()
device_mesh = np.array([[devices[0], devices[2]],
[devices[4], devices[6]],
[devices[1], devices[3]],
[devices[5], devices[7]],
[devices[8], devices[10]],
[devices[12], devices[14]],
[devices[9], devices[11]],
[devices[13], devices[15]]])
# The mesh looks like this (the integers are process indices):
# 0 2
# 4 6
# 1 3
# 5 7
# 8 10
# 12 14
# 9 11
# 13 15
assert [d.id for d in device_mesh.flat
] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
return maps.Mesh(device_mesh, ("x", "y"))

def setUp(self):
super().setUp()
self.xmap_spmd_lowering_enabled = jax.config.experimental_xmap_spmd_lowering
jax.config.update("experimental_xmap_spmd_lowering", True)
self.gda_enabled = jax.config.jax_parallel_functions_output_gda
jax.config.update('jax_parallel_functions_output_gda', True)

def tearDown(self):
jax.config.update("experimental_xmap_spmd_lowering",
self.xmap_spmd_lowering_enabled)
jax.config.update('jax_parallel_functions_output_gda', self.gda_enabled)
super().tearDown()

def test_gpu_multi_node_initialize_and_psum(self):

# Hook up the env vars expected to be set already in the SLURM environment
@@ -224,5 +273,244 @@ def test_gpu_multi_node_transparent_initialize_and_psum(self):
self.assertEqual(y[0], jax.device_count())
print(y)

# TODO(sudhakarsingh27): Change/omit this test in favor of using `Array`,
# since `GlobalDeviceArray` is going to be deprecated in the future (a
# sketch of the `Array`-based variant follows this test).
def test_pjit_gda_multi_input_multi_output(self):
jax.distributed.initialize()
global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
global_input_shape = (16, 2)
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)

def cb(index):
return global_input_data[index]

mesh_axes1 = experimental.PartitionSpec("x", "y")
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes1, cb)
mesh_axes2 = experimental.PartitionSpec("x")
gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes2, cb)
mesh_axes3 = experimental.PartitionSpec(("x", "y"))
gda3 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)

with maps.Mesh(global_mesh.devices, global_mesh.axis_names):

@functools.partial(
pjit.pjit,
# The single `FROM_GDA` value is replicated across all the inputs.
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=(mesh_axes1, None, mesh_axes2))
def f(x, y, z):
return x @ x.T, y, z

out1, out2, out3 = f(gda1, gda2, gda3)

self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (16, 16))
self.assertEqual(out1.local_shards[0].data.shape, (2, 8))
self.assertDictEqual(out1.mesh.shape, {"x": 8, "y": 2})
expected_matrix_mul = global_input_data @ global_input_data.T
for s in out1.local_shards:
np.testing.assert_array_equal(np.asarray(s.data),
expected_matrix_mul[s.index])

self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (16, 2))
self.assertEqual(out2.local_shards[0].data.shape, (16, 2))
for s in out2.local_shards:
np.testing.assert_array_equal(np.asarray(s.data), global_input_data)

self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
self.assertEqual(out3.shape, (16, 2))
self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
for s in out3.local_shards:
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[s.index])
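
  # Illustrative sketch only, not collected by the test runner (no `test_`
  # prefix): the `Array`-based variant that the TODO above refers to. It
  # assumes a JAX version that provides `jax.sharding.NamedSharding` and
  # `jax.make_array_from_callback`; the method name is hypothetical.
  def _sketch_pjit_array_input(self):
    from jax.sharding import NamedSharding  # assumed newer-API import
    jax.distributed.initialize()  # as in the tests above
    global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
    global_input_shape = (16, 2)
    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)
    spec = experimental.PartitionSpec("x", "y")

    # As with `GlobalDeviceArray.from_callback` above, each process supplies
    # only the shards addressable by its local devices via the callback.
    arr = jax.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, spec),
        lambda index: global_input_data[index])

    with global_mesh:
      out = pjit.pjit(lambda x: x @ x.T, out_axis_resources=spec)(arr)
    self.assertEqual(out.shape, (16, 16))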

# TODO(sudhakarsingh27): Change/omit this test in favor of using `Array`,
# since `GlobalDeviceArray` is going to be deprecated in the future.
def test_pjit_gda_non_contiguous_mesh(self):
jax.distributed.initialize()
devices = self.sorted_devices()
mesh_devices = np.array(devices[0:8:2] + devices[1:8:2] + devices[8:16:2] +
devices[9:16:2])
# The device order in the mesh below is:
# [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
# and the global data is:
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
# The per-device process indices are not listed here: the process-GPU
# mapping is random (@sudhakarsingh27 to figure out why).
global_mesh = maps.Mesh(mesh_devices, ("x",))
global_input_shape = (16,)
mesh_axes = experimental.PartitionSpec("x")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)

def cb(index):
return global_input_data[index]

gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

# device_id -> (index, replica_id)
expected_idx_rid = {
0: ((slice(0, 1),), 0),
1: ((slice(4, 5),), 0),
2: ((slice(1, 2),), 0),
3: ((slice(5, 6),), 0),
4: ((slice(2, 3),), 0),
5: ((slice(6, 7),), 0),
6: ((slice(3, 4),), 0),
7: ((slice(7, 8),), 0),
8: ((slice(8, 9),), 0),
9: ((slice(12, 13),), 0),
10: ((slice(9, 10),), 0),
11: ((slice(13, 14),), 0),
12: ((slice(10, 11),), 0),
13: ((slice(14, 15),), 0),
14: ((slice(11, 12),), 0),
15: ((slice(15, 16),), 0),
}
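    # Example derivation of an entry: device 1 sits at position 4 along the
    # mesh's "x" axis (order [0, 2, 4, 6, 1, 3, 5, 7, ...] above); 16
    # elements over 16 devices is one element per device, so device 1 holds
    # global_input_data[4:5], i.e. index (slice(4, 5),).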

with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
f = pjit.pjit(lambda x: x,
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=mesh_axes)
out = f(gda1)
for s in out.local_shards:
device_id = s.device.id
expected_index = expected_idx_rid[device_id][0]
expected_replica_id = expected_idx_rid[device_id][1]
self.assertEqual(s.index, expected_index)
self.assertEqual(s.replica_id, expected_replica_id)
self.assertEqual(s.data.shape, (1,))
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[expected_index])

# TODO(sudhakarsingh27): Change/omit this test in favor of using `Array`,
# since `GlobalDeviceArray` is going to be deprecated in the future.
def test_pjit_gda_non_contiguous_mesh_2d(self):
jax.distributed.initialize()
global_mesh = self.create_2d_non_contiguous_mesh()
global_input_shape = (16, 2)
mesh_axes = experimental.PartitionSpec("x", "y")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)

def cb(index):
return global_input_data[index]

gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

# device_id -> (index, replica_id)
expected_idx_rid = {
0: ((slice(0, 2), slice(0, 1)), 0),
1: ((slice(4, 6), slice(0, 1)), 0),
2: ((slice(0, 2), slice(1, 2)), 0),
3: ((slice(4, 6), slice(1, 2)), 0),
4: ((slice(2, 4), slice(0, 1)), 0),
5: ((slice(6, 8), slice(0, 1)), 0),
6: ((slice(2, 4), slice(1, 2)), 0),
7: ((slice(6, 8), slice(1, 2)), 0),
8: ((slice(8, 10), slice(0, 1)), 0),
9: ((slice(12, 14), slice(0, 1)), 0),
10: ((slice(8, 10), slice(1, 2)), 0),
11: ((slice(12, 14), slice(1, 2)), 0),
12: ((slice(10, 12), slice(0, 1)), 0),
13: ((slice(14, 16), slice(0, 1)), 0),
14: ((slice(10, 12), slice(1, 2)), 0),
15: ((slice(14, 16), slice(1, 2)), 0),
}
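    # Example derivation of an entry: device 1 is at mesh position
    # (row 2, column 0); the 8 mesh rows split dim 0 (16) into chunks of 2
    # and the 2 mesh columns split dim 1 (2) into chunks of 1, so device 1
    # holds global_input_data[4:6, 0:1].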

with global_mesh:
f = pjit.pjit(lambda x: x,
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=mesh_axes)
out = f(gda1)

for s in out.local_shards:
device_id = s.device.id
expected_index = expected_idx_rid[device_id][0]
expected_replica_id = expected_idx_rid[device_id][1]
self.assertEqual(s.index, expected_index)
self.assertEqual(s.replica_id, expected_replica_id)
self.assertEqual(s.data.shape, (2, 1))
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[expected_index])

with global_mesh:
f = pjit.pjit(lambda x: x,
in_axis_resources=experimental.PartitionSpec(None),
out_axis_resources=mesh_axes)
# Fully replicated values allow a non-contiguous mesh.
out = f(global_input_data)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)

with global_mesh:
f = pjit.pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=mesh_axes)
# Fully replicated values allow a non-contiguous mesh.
out = f(global_input_data)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)

gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, experimental.PartitionSpec(None), cb)

with global_mesh:
f = pjit.pjit(lambda x, y: (x, y),
in_axis_resources=(None, None),
out_axis_resources=(mesh_axes, mesh_axes))
# Fully replicated values + GDA allow a non-contiguous mesh.
out1, out2 = f(global_input_data, gda2)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)

# TODO(sudhakarsingh27): Change/omit this test in favor of using `Array`,
# since `GlobalDeviceArray` is going to be deprecated in the future.
def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
jax.distributed.initialize()
global_mesh = self.create_2d_non_contiguous_mesh()
global_input_shape = (8, 2)
mesh_axes = experimental.PartitionSpec("x", "y")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes,
lambda idx: global_input_data[idx])

with global_mesh:
f = pjit.pjit(lambda x, y: (x, y),
in_axis_resources=experimental.PartitionSpec("x", "y"),
out_axis_resources=experimental.PartitionSpec("x", "y"))
inp_aval = jax.ShapedArray((8, 2), jnp.int32)
# `ShapedArray` is considered global when lowered and compiled.
# Hence it can bypass the contiguous mesh restriction.
compiled = f.lower(inp_aval, gda1, _global_avals=True).compile()
out1, out2 = compiled(gda1, gda1)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8, 2))
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8, 2))

# TODO(sudhakarsingh27): Change/omit this test in favor of using `Array`,
# since `GlobalDeviceArray` is going to be deprecated in the future.
def test_pjit_gda_eval_shape(self):
jax.distributed.initialize()

with jtu.create_global_mesh((16,), ("x",)):

@functools.partial(pjit.pjit,
in_axis_resources=experimental.PartitionSpec(None),
out_axis_resources=experimental.PartitionSpec("x"))
def f():
return jnp.zeros([32, 10])

self.assertEqual(f().shape, (32, 10))
self.assertEqual(jax.eval_shape(f).shape, (32, 10))

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
