Return jax.Array from GDA's callback APIs if jax.config.jax_array is True.
PiperOrigin-RevId: 510268071
yashk2810 authored and jax authors committed Feb 17, 2023
1 parent 2b9ad0d commit eea1fef
Showing 3 changed files with 25 additions and 5 deletions.
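For orientation, a hedged sketch of the call site this commit affects (the mesh shape, data, and 8-device setup are illustrative assumptions, not part of the change): with the jax_array config flag on, GDA's callback constructors now return a jax.Array assembled from the same per-device buffers.

    # Illustrative sketch only -- assumes 8 devices (e.g. CPU with
    # XLA_FLAGS=--xla_force_host_platform_device_count=8) and the
    # jax.experimental.global_device_array module of this era.
    import numpy as np
    import jax
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.global_device_array import GlobalDeviceArray

    global_shape = (8, 2)
    mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
    data = np.arange(np.prod(global_shape)).reshape(global_shape)

    def cb(index):
      # Called once per local device with that device's slice of the global array.
      return data[index]

    out = GlobalDeviceArray.from_callback(global_shape, mesh, P('x', 'y'), cb)
    # Before this commit: always a GlobalDeviceArray.
    # After it: a jax.Array whenever jax.config.jax_array is True.
    print(type(out))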
9 changes: 9 additions & 0 deletions jax/_src/global_device_array.py
@@ -523,6 +523,9 @@ def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
         jax.device_put(data_callback(global_indices_rid[device][0]), device)
         for device in local_devices
     ]
+    if config.jax_array:
+      return jax.make_array_from_single_device_arrays(
+          global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs)
     return cls(global_shape, global_mesh, mesh_axes, dbs,
                _gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))

@@ -570,6 +573,9 @@ def from_batched_callback(cls, global_shape: Shape,
     local_indices = [global_indices_rid[d][0] for d in local_devices]
     local_arrays = data_callback(local_indices)
     dbs = pxla.device_put(local_arrays, local_devices)
+    if config.jax_array:
+      return jax.make_array_from_single_device_arrays(
+          global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs)  # type: ignore
     return cls(global_shape, global_mesh, mesh_axes, dbs,
                _gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))

@@ -633,6 +639,9 @@ def from_batched_callback_with_devices(
         (index, tuple(devices)) for index, devices in index_to_device.values()
     ]
     dbs = data_callback(cb_inp)
+    if config.jax_array:
+      return jax.make_array_from_single_device_arrays(
+          global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs)  # type: ignore
     return cls(global_shape, global_mesh, mesh_axes, dbs,
                _gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))

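All three hunks above route to the same public API. A minimal sketch of that path in isolation, assuming the same 8-device setup (the shard-placement loop illustrates what the GDA code already holds in dbs; it is not copied from the source):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    global_shape = (8, 2)
    mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    data = np.arange(np.prod(global_shape)).reshape(global_shape)

    # One explicitly placed single-device buffer per addressable device.
    dbs = [jax.device_put(data[index], device)
           for device, index in
           sharding.addressable_devices_indices_map(global_shape).items()]
    arr = jax.make_array_from_single_device_arrays(global_shape, sharding, dbs)
    assert arr.shape == global_shape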
11 changes: 9 additions & 2 deletions tests/global_device_array_test.py
@@ -83,7 +83,6 @@ def cb(index):
                          mesh_axes, cb)
     self.assertEqual(gda.ndim, 2)
     self.assertEqual(gda.size, 16)
-    self.assertEqual(gda.mesh_axes, mesh_axes)
     self.assertEqual(gda.addressable_shards[0].index, expected_index[0])
     self.assertArraysEqual(gda.addressable_data(0),
                            global_input_data[expected_index[0]])
@@ -283,6 +282,8 @@ def cb(cb_inp):
                            expected_second_shard_value)

   def test_gda_str_repr(self):
+    if jax.config.jax_array:
+      self.skipTest('jax.Array repr already has a test')
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = P(('x', 'y'))
@@ -300,6 +301,8 @@ def cb(index):
                      "mesh_axes=PartitionSpec(('x', 'y'),))"))

   def test_gda_equality_raises_not_implemented(self):
+    if jax.config.jax_array:
+      self.skipTest('jax.Array has __eq__.')
     global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = P(None,)
@@ -385,8 +388,12 @@ def test_gda_delete(self):
     gda, _ = create_gda(input_shape, global_mesh, P("x", "y"))
     gda._check_if_deleted()
     gda.delete()
+    if jax.config.jax_array:
+      arr_type = 'Array'
+    else:
+      arr_type = 'GlobalDeviceArray'
     with self.assertRaisesRegex(RuntimeError,
-                                "GlobalDeviceArray has been deleted."):
+                                f"{arr_type} has been deleted."):
       gda._check_if_deleted()


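The test_gda_delete change keys the expected error message off the array type. A standalone sketch of the jax.Array side of that check (assuming that any use after delete raises the RuntimeError the updated regex matches):

    import numpy as np
    import jax

    x = jax.device_put(np.arange(4.0))
    x.delete()
    try:
      np.asarray(x)  # any materialization of the deleted buffer
    except RuntimeError as e:
      print(e)  # message matches the test's 'Array has been deleted.' regex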
10 changes: 7 additions & 3 deletions tests/pjit_test.py
@@ -1615,10 +1615,10 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,


   @parameterized.named_parameters(
-      ('gda', parallel_functions_output_gda, create_gda, 'GDA'),
-      ('array', jax_array, create_array, 'Array'),
+      ('gda', parallel_functions_output_gda, create_gda),
+      ('array', jax_array, create_array),
   )
-  def test_xla_arr_sharding_mismatch(self, ctx, create_fun, arr_type):
+  def test_xla_arr_sharding_mismatch(self, ctx, create_fun):
     if xla_bridge.get_backend().runtime_type == 'stream_executor':
       raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
     global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@@ -1637,6 +1637,10 @@ def test_xla_arr_sharding_mismatch(self, ctx, create_fun, arr_type):
              else P('x', 'y'))
     arr, _ = create_fun(global_input_shape, global_mesh, different_pspec,
                         input_data)
+    if jax.config.jax_array:
+      arr_type = 'Array'
+    else:
+      arr_type = 'GDA'
     with self.assertRaisesRegex(
         ValueError,
         f"{arr_type} sharding does not match the input sharding."):
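All three files branch on the same runtime flag. A sketch of how that flag was toggled in user code at this revision (assuming the standard jax.config.update spelling):

    import jax

    jax.config.update('jax_array', True)
    print(jax.config.jax_array)   # True  -> callback APIs return jax.Array
    jax.config.update('jax_array', False)
    print(jax.config.jax_array)   # False -> callback APIs return GlobalDeviceArray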
