Skip to content

Commit

Permalink
jax_array_test: set config once & fix X64 failure
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 13, 2022
1 parent ad326b9 commit 8eb44fd
Showing 1 changed file with 81 additions and 106 deletions.
187 changes: 81 additions & 106 deletions tests/array_test.py
Expand Up @@ -21,7 +21,6 @@
import jax
import jax.numpy as jnp
from jax._src import dispatch
from jax._src import config as jax_config
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
Expand Down Expand Up @@ -66,6 +65,7 @@ def create_array(shape, sharding, global_data=None):
shape, sharding, lambda idx: global_data[idx]), global_data


@jtu.with_config(jax_array=True)
class JaxArrayTest(jtu.JaxTestCase):

@parameterized.named_parameters(
Expand All @@ -77,16 +77,15 @@ class JaxArrayTest(jtu.JaxTestCase):
("mesh_fully_replicated", P()),
)
def test_jax_array_value(self, mesh_axes):
with jax_config.jax_array(True):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, global_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, mesh_axes))
for s in arr.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
self.assertArraysEqual(arr._value, global_data)
self.assertArraysEqual(arr._npy_value, global_data)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, global_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, mesh_axes))
for s in arr.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
self.assertArraysEqual(arr._value, global_data)
self.assertArraysEqual(arr._npy_value, global_data)

@parameterized.named_parameters(
("mesh_x_y", P("x", "y"),
Expand Down Expand Up @@ -137,108 +136,96 @@ def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
self.assertArraysEqual(s.data, global_input_data[s.index])

def test_array_delete(self):
with jax_config.jax_array(True):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
arr.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
arr._check_if_deleted()
self.assertIsNone(arr._npy_value)
self.assertIsNone(arr._arrays)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
arr.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
arr._check_if_deleted()
self.assertIsNone(arr._npy_value)
self.assertIsNone(arr._arrays)

def test_device_put(self):
with jax_config.jax_array(True):
numpy_array = np.array([1, 2, 3])
arr = jax.device_put(numpy_array, jax.devices()[0])
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
self.assertArraysEqual(arr, numpy_array)
self.assertEqual(arr._committed, True)
for i in arr.addressable_shards:
self.assertArraysEqual(i.data, numpy_array)
self.assertEqual(i.device, jax.devices()[0])
self.assertEqual(i.index, (slice(None),))
self.assertEqual(i.replica_id, 0)
numpy_array = np.array([1, 2, 3])
arr = jax.device_put(numpy_array, jax.devices()[0])
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
self.assertArraysEqual(arr, numpy_array)
self.assertEqual(arr._committed, True)
for i in arr.addressable_shards:
self.assertArraysEqual(i.data, numpy_array)
self.assertEqual(i.device, jax.devices()[0])
self.assertEqual(i.index, (slice(None),))
self.assertEqual(i.replica_id, 0)

def test_device_put_array_delete(self):
with jax_config.jax_array(True):
arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
arr.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
arr._check_if_deleted()
self.assertIsNone(arr._npy_value)
self.assertIsNone(arr._arrays)
arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
arr.delete()
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
arr._check_if_deleted()
self.assertIsNone(arr._npy_value)
self.assertIsNone(arr._arrays)

def test_array_device_get(self):
with jax_config.jax_array(True):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
self.assertArraysEqual(jax.device_get(arr), input_data)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
self.assertArraysEqual(jax.device_get(arr), input_data)

def test_repr(self):
with jax_config.jax_array(True):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
repr(arr) # doesn't crash
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, _ = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
repr(arr) # doesn't crash

def test_jnp_array(self):
with jax_config.jax_array(True):
arr = jnp.array([1, 2, 3])
self.assertIsInstance(arr, array.Array)
self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
self.assertEqual(arr._committed, False)
arr = jnp.array([1, 2, 3])
self.assertIsInstance(arr, array.Array)
self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
self.assertEqual(arr._committed, False)

def test_jnp_array_jit_add(self):
with jax_config.jax_array(True):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = jax.jit(lambda x, y: x + y)(a, b)
self.assertIsInstance(arr, array.Array)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = jax.jit(lambda x, y: x + y)(a, b)
self.assertIsInstance(arr, array.Array)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

def test_jnp_array_jnp_add(self):
with jax_config.jax_array(True):
arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
self.assertIsInstance(arr, array.Array)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
self.assertIsInstance(arr, array.Array)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

def test_jnp_array_normal_add(self):
with jax_config.jax_array(True):
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = a + b
self.assertIsInstance(arr, array.Array)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
arr = a + b
self.assertIsInstance(arr, array.Array)
self.assertArraysEqual(arr, np.array([5, 7, 9]))
self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

def test_array_sharded_astype(self):
with jax_config.jax_array(True):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
arr_float32 = arr.astype(jnp.float32)
self.assertEqual(arr_float32.dtype, np.float32)
self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
self.assertLen(arr_float32.addressable_shards, 8)
for i in arr_float32.addressable_shards:
self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
arr_float32 = arr.astype(jnp.float32)
self.assertEqual(arr_float32.dtype, np.float32)
self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
self.assertLen(arr_float32.addressable_shards, 8)
for i in arr_float32.addressable_shards:
self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))

def test_jnp_array_astype(self):
with jax_config.jax_array(True):
arr = jnp.array([1, 2, 3])
arr_float32 = arr.astype(jnp.float32)
self.assertEqual(arr_float32.dtype, np.float32)
self.assertArraysEqual(arr_float32, arr.astype(np.float32))
arr = jnp.array([1, 2, 3])
arr_float32 = arr.astype(jnp.float32)
self.assertEqual(arr_float32.dtype, np.float32)
self.assertArraysEqual(arr_float32, arr.astype(np.float32))

@jax_config.jax_array(True)
def test_sharded_add(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
Expand All @@ -253,28 +240,25 @@ def test_sharded_add(self):
for i in out.addressable_shards:
self.assertArraysEqual(i.data, expected[i.index])

@jax_config.jax_array(True)
def test_sharded_zeros_like(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
a, input_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
out = jnp.zeros_like(a)
expected = jnp.zeros(input_data.shape, dtype=np.int64)
expected = jnp.zeros(input_data.shape, dtype=int)
self.assertArraysEqual(out, expected)
self.assertLen(out.addressable_shards, 8)
for i in out.addressable_shards:
self.assertArraysEqual(i.data, expected[i.index])

@jax_config.jax_array(True)
def test_zeros_like(self):
a = jnp.array([1, 2, 3], dtype=np.int32)
out = jnp.zeros_like(a)
expected = np.zeros(a.shape, dtype=np.int32)
self.assertArraysEqual(out, expected)
self.assertTrue(dispatch.is_single_device_sharding(out.sharding))

@jax_config.jax_array(True)
def test_wrong_num_arrays(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
Expand All @@ -294,7 +278,6 @@ def test_wrong_num_arrays(self):
r'by the sharding\), but got 16'):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)

@jax_config.jax_array(True)
def test_arrays_not_in_device_assignment(self):
if jax.device_count() < 4:
self.skipTest('Requires more than 4 devices')
Expand Down Expand Up @@ -331,7 +314,6 @@ def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shap
"buffer shape"):
array.make_array_from_callback(shape, mps, lambda idx: inp_data)

@jax_config.jax_array(True)
def test_mismatch_dtype(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
Expand All @@ -345,7 +327,6 @@ def test_mismatch_dtype(self):
"Got int32, expected float32"):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

@jax_config.jax_array(True)
def test_array_iter_pmap_sharding(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
Expand All @@ -360,7 +341,6 @@ def test_array_iter_pmap_sharding(self):
self.assertIsInstance(i, array.Array)
self.assertArraysAllClose(i, j)

@jax_config.jax_array(True)
def test_array_iter_pmap_sharding_last_dim_sharded(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
Expand All @@ -371,7 +351,6 @@ def test_array_iter_pmap_sharding_last_dim_sharded(self):
for i, j in zip(iter(y), iter(np.sin(x).T)):
self.assertArraysAllClose(i, j)

@jax_config.jax_array(True)
def test_array_iter_mesh_pspec_sharding_multi_device(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
Expand All @@ -381,7 +360,6 @@ def test_array_iter_mesh_pspec_sharding_multi_device(self):
for i, j in zip(iter(arr), iter(input_data)):
self.assertArraysEqual(i, j)

@jax_config.jax_array(True)
def test_array_getitem_mesh_pspec_sharding_multi_device(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
Expand All @@ -394,7 +372,6 @@ def test_array_getitem_mesh_pspec_sharding_multi_device(self):
self.assertArraysEqual(s, np.array([[4], [6]]))
self.assertArraysEqual(arr[:2], input_data[:2])

@jax_config.jax_array(True)
def test_array_iter_mesh_pspec_sharding_single_device(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
Expand All @@ -409,7 +386,6 @@ def test_array_iter_mesh_pspec_sharding_single_device(self):
self.assertArraysEqual(i, j)
self.assertEqual(i.device(), single_dev[0])

@jax_config.jax_array(True)
def test_array_shards_committed(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')
Expand All @@ -424,7 +400,6 @@ def test_array_shards_committed(self):
self.assertEqual(s.data._committed, y._committed)
self.assertTrue(s.data._committed)

@jax_config.jax_array(True)
def test_array_jnp_array_copy_multi_device(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
Expand All @@ -443,6 +418,7 @@ def test_array_jnp_array_copy_multi_device(self):
c.data.unsafe_buffer_pointer())


@jtu.with_config(jax_array=True)
class ShardingTest(jtu.JaxTestCase):

def test_mesh_pspec_sharding_interface(self):
Expand Down Expand Up @@ -504,7 +480,6 @@ def test_uneven_shard_error(self):
r"factors: \[4, 2\] should evenly divide the shape\)"):
mps.shard_shape((8, 3))

@jax_config.jax_array(True)
def test_pmap_sharding_hash_eq(self):
if jax.device_count() < 2:
self.skipTest('Test needs >= 2 devices.')
Expand Down

0 comments on commit 8eb44fd

Please sign in to comment.