Skip to content

Commit

Permalink
[XLA:Python] Split DevicePut out of jax_jit and refactor it.
Browse files Browse the repository at this point in the history
* Creates a new py_values.cc/h file to contain device_put.
* Moves some of the type helpers into the existing types module.
* Change `PyClient::BufferFromPyval` to call DevicePut. There's no reason to have two similar but subtly different methods for copying a buffer-like object to a device.
* Refactor and optimize some of the handler functions. In particular, avoid creating a number of unnecessary intermediate objects.

PiperOrigin-RevId: 361430648
  • Loading branch information
hawkinsp authored and jax authors committed Mar 7, 2021
1 parent 12c2d0d commit 2722837
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/jax_jit_test.py
Expand Up @@ -27,10 +27,13 @@


# It covers all JAX numpy types types except bfloat16 and numpy array.
# TODO(jblespiau): Add support for float0 and bfloat16 in the C++ path.
# TODO(jblespiau): Add support for float0 in the C++ path.
_EXCLUDED_TYPES = [np.ndarray]
if jax.lib._xla_extension_version < 6:
_EXCLUDED_TYPES.append(jax.dtypes.bfloat16)

_SCALAR_NUMPY_TYPES = [
x for x in jax.abstract_arrays.array_types
if x not in [np.ndarray, jax.dtypes.bfloat16]
x for x in jax.abstract_arrays.array_types if x not in _EXCLUDED_TYPES
]


Expand All @@ -45,9 +48,6 @@ def test_is_float_0(self):
jaxlib.jax_jit._is_float0(np.zeros((5, 5), dtype=jax.float0)))
self.assertFalse(jaxlib.jax_jit._is_float0(np.zeros((5, 5))))

def test_DtypeTo32BitDtype(self):
self.assertEqual(np.float32, jaxlib.jax_jit._DtypeTo32BitDtype(np.float64))

@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_numpy_scalars(self, device_put_function):

Expand Down Expand Up @@ -140,7 +140,9 @@ def test_device_put_on_python_scalars(self):

@unittest.skipIf(jax.lib._xla_extension_version < 3, "jaxlib too old")
def test_convert_int_overflow(self):
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
with self.assertRaisesRegex(
RuntimeError if jax.lib._xla_extension_version >= 6 else OverflowError,
"(Python int too large|Unable to convert Python scalar).*"):
jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0])

def test_arg_signature_of_value(self):
Expand Down

0 comments on commit 2722837

Please sign in to comment.