From f26b866e08a14a92072e477fc8602b90628e0cc6 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 21 Dec 2021 20:55:03 +0000 Subject: [PATCH] Add `jax.default_device` context manager This currently only supports setting a specific Device object, not a platform like "cpu". That should be added in the future. Bumps the minimum jaxlib version in order to include https://github.com/tensorflow/tensorflow/pull/53656 --- docs/faq.rst | 29 +++++++------- docs/jax.config.rst | 1 + jax/__init__.py | 1 + jax/_src/api.py | 9 ++++- jax/_src/config.py | 90 ++++++++++++++++++++++++++++++++++++++------ jax/_src/dispatch.py | 20 ++++++---- jax/linear_util.py | 7 ++-- tests/api_test.py | 49 ++++++++++++++++++++++++ 8 files changed, 170 insertions(+), 36 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index ff680106d1cf..813d83391817 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -122,14 +122,17 @@ Let's first look at the principles of data and computation placement in JAX. In JAX, the computation follows data placement. JAX arrays have two placement properties: 1) the device where the data resides; -and 2) whether it is **committed** to the device or not (the data is sometimes +and 2) whether it is **committed** to the device or not (the data is sometimes referred to as being *sticky* to the device). By default, JAX arrays are placed uncommitted on the default device -(``jax.devices()[0]``), which is the first GPU by default. If no GPU is -present, ``jax.devices()[0]`` is the first CPU. The default device can -be set to "cpu" or "gpu" manually by setting the environment variable -``JAX_PLATFORM_NAME`` or the absl flag ``--jax_platform_name``. +(``jax.devices()[0]``), which is the first GPU or TPU by default. If no GPU or +TPU is present, ``jax.devices()[0]`` is the CPU. The default device can +temporarily overridden with the :func:`jax.default_device` context manager, or +set for the whole process by setting the environment variable ``JAX_PLATFORMS`` +or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu" +(``JAX_PLATFORMS`` can also be a list of platforms, which determines which +platforms are available in priority order). >>> from jax import numpy as jnp >>> print(jnp.ones(3).device_buffer.device()) # doctest: +SKIP @@ -148,15 +151,15 @@ gpu:2 Computations involving some committed inputs will happen on the committed device and the result will be committed on the -same device. Invoking an operation on arguments that are committed +same device. Invoking an operation on arguments that are committed to more than one device will raise an error. -You can also use :func:`jax.device_put` without a ``device`` parameter. If the data -is already on a device (committed or not), it's left as-is. If the data isn't on any -device—that is, it's a regular Python or NumPy value—it's placed uncommitted on the default +You can also use :func:`jax.device_put` without a ``device`` parameter. If the data +is already on a device (committed or not), it's left as-is. If the data isn't on any +device—that is, it's a regular Python or NumPy value—it's placed uncommitted on the default device. -Jitted functions behave like any other primitive operations—they will follow the +Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device. ``jnp.device_put(jnp.zeros(...), jax.devices()[1])`` or similar will actually create the @@ -164,8 +167,8 @@ array of zeros on ``jax.devices()[1]``, instead of creating the array on the def device then moving it. This is thanks to some laziness in array creation, which holds for all the constant creation operations (``ones``, ``full``, ``eye``, etc). -(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device -placement. That parameter is experimental, is likely to be removed or changed, +(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device +placement. That parameter is experimental, is likely to be removed or changed, and its use is not recommended.) For a worked-out example, we recommend reading through @@ -259,7 +262,7 @@ Broadly speaking: they are dispatched asynchronously (see :ref:`async-dispatch`); and they can be executed on CPU, GPU, or TPU, each of which have vastly different and continuously evolving performance characteristics. - + These architectural differences make meaningful direct benchmark comparisons between NumPy and JAX difficult. diff --git a/docs/jax.config.rst b/docs/jax.config.rst index b5221da08ab6..198086a8fa7a 100644 --- a/docs/jax.config.rst +++ b/docs/jax.config.rst @@ -11,6 +11,7 @@ JAX configuration checking_leaks debug_nans debug_infs + default_device default_matmul_precision default_prng_impl enable_checks diff --git a/jax/__init__.py b/jax/__init__.py index c009f4756013..88a2333bc7d2 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -53,6 +53,7 @@ transfer_guard_host_to_device as transfer_guard_host_to_device, transfer_guard_device_to_device as transfer_guard_device_to_device, transfer_guard_device_to_host as transfer_guard_device_to_host, + default_device as default_device, ) from .core import eval_context as ensure_compile_time_eval from jax._src.api import ( diff --git a/jax/_src/api.py b/jax/_src/api.py index c3909d77564d..3a9aaf38a671 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -456,6 +456,13 @@ def _cpp_jit( raise ValueError("can't specify both a device and a backend for jit, " f"got device={device} and backend={backend}.") + if device is not None: + jit_device = device + elif backend is not None: + jit_device = xb.get_backend(backend).get_default_device_assignment(1)[0] + else: + jit_device = None + @api_boundary def cache_miss(*args, **kwargs): ### This first part is basically the same code as in _python_jit. @@ -531,7 +538,7 @@ def get_device_info(): static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, - cache=_cpp_jit_cache) + cache=_cpp_jit_cache, jit_device=jit_device) f_jitted = wraps(fun)(cpp_jitted_f) f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device, diff --git a/jax/_src/config.py b/jax/_src/config.py index c8ac7c0d8a35..5cda4b436f33 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -27,6 +27,7 @@ from jax._src import lib from jax._src.lib import jax_jit from jax._src.lib import transfer_guard_lib +from jax._src.lib import xla_client def bool_env(varname: str, default: bool) -> bool: """Read an environment variable and interpret it as a boolean. @@ -321,6 +322,45 @@ def define_string_state( updated value of the thread-local state when it is altered or set initially. + Returns: + A contextmanager to control the thread-local state value. + """ + def validate(new_val): + if new_val is not None and not isinstance(new_val, str): + raise ValueError(f"new string config value must be None or of type str," + f" got {new_val} of type {type(new_val)}.") + + return self.define_string_or_object_state( + name, default, help, update_global_hook, update_thread_local_hook, + validate) + + def define_string_or_object_state( + self, name: str, default: Any, help: str, + update_global_hook: Optional[Callable[[Any], None]] = None, + update_thread_local_hook: Optional[Callable[[Any], None]] = None, + validate_new_val_hook: Optional[Callable[[Any], None]] = None): + """Set up thread-local state and return a contextmanager for managing it. + + Similar to ``define_string_state``, except the context manager will accept + any object, not just a string. Any value passed via commandline flag or + environment variable will be treated as a string. + + Args: + name: string, converted to lowercase to define the name of the config + option (and absl flag). It is converted to uppercase to define the + corresponding shell environment variable. + default: string, a default value for the option. + help: string, used to populate the flag help information as well as the + docstring of the returned context manager. + update_global_hook: an optional callback that is called with the updated + value of the global state when it is altered or set initially. + update_thread_local_hook: an optional callback that is called with the + updated value of the thread-local state when it is altered or set + initially. + validate_new_val_hook: an optional callback that is called with the new + value on any update, and should raise an error if the new value is + invalid. + Returns: A contextmanager to control the thread-local state value. """ @@ -335,12 +375,8 @@ def get_state(self): return val if val is not unset else self._read(name) setattr(Config, name, property(get_state)) - def validate(new_val): - if new_val is not None and not isinstance(new_val, str): - raise ValueError(f"new string config value must be None or of type str," - f" got {new_val} of type {type(new_val)}.") - - return _StateContextManager(name, help, update_thread_local_hook, validate) + return _StateContextManager(name, help, update_thread_local_hook, + validate_new_val_hook) def _trace_context(self): """Returns a tuple of configuration values that affect tracing. @@ -435,7 +471,7 @@ def __setattr__(self, name, val): # a global/thread-local state. These methods allow updates to part of the # state when a configuration value changes. -class GlobalJitState(NamedTuple): +class GlobalExtraJitContext(NamedTuple): numpy_rank_promotion: Optional[str] = None default_matmul_precision: Optional[Any] = None dynamic_shapes: bool = False @@ -443,11 +479,12 @@ class GlobalJitState(NamedTuple): def update_global_jit_state(**kw): gs = jax_jit.global_state() - context = gs.extra_jit_context or GlobalJitState() - gs.extra_jit_context = context._replace(**kw) + if gs.extra_jit_context is None: + gs.extra_jit_context = GlobalExtraJitContext() + gs.extra_jit_context = gs.extra_jit_context._replace(**kw) -class ThreadLocalJitState(NamedTuple): +class ThreadLocalExtraJitContext(NamedTuple): dynamic_trace_state: Optional[Any] = None numpy_rank_promotion: Optional[str] = None default_matmul_precision: Optional[Any] = None @@ -456,8 +493,9 @@ class ThreadLocalJitState(NamedTuple): def update_thread_local_jit_state(**kw): tls = jax_jit.thread_local_state() - context = tls.extra_jit_context or ThreadLocalJitState() - tls.extra_jit_context = context._replace(**kw) + if tls.extra_jit_context is None: + tls.extra_jit_context = ThreadLocalExtraJitContext() + tls.extra_jit_context = tls.extra_jit_context._replace(**kw) # TODO(mattjj): remove all uses of this flag @@ -639,6 +677,34 @@ def _update_x64_thread_local(val): Config.x64_enabled = Config.jax_enable_x64 # type: ignore + +def _update_default_device_global(val): + lib.jax_jit.global_state().default_device = val + +def _update_default_device_thread_local(val): + lib.jax_jit.thread_local_state().default_device = val + +def _validate_default_device(val): + if val is not None and not isinstance(val, xla_client.Device): + raise ValueError("jax.default_device must be passed a Device object (e.g. " + f"`jax.devices('cpu')[0]`), got: {repr(val)}") + +# TODO(skye): default_device only accepts devices for now. Make it work with +# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). +default_device = config.define_string_or_object_state( + name='jax_default_device', + default=None, + help=( + 'Configure the default device for JAX operations. ' + 'Set to a Device object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the ' + 'default device for JAX operations and jit\'d function calls (there is ' + 'no effect on multi-device computations, e.g. pmapped function calls). ' + 'Set to None to use the system default device. See ' + ':ref:`faq-data-placement` for more information on device placement.'), + update_global_hook=_update_default_device_global, + update_thread_local_hook=_update_default_device_thread_local, + validate_new_val_hook=_validate_default_device) + def _update_disable_jit_global(val): lib.jax_jit.global_state().disable_jit = val diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 9dddb08f8d94..5d2f6e2ae95b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -438,21 +438,27 @@ def initial_style_primitive_replicas(params): return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1) -def _xla_callable_device(nreps, backend, device, arg_devices): +def _xla_callable_device(nreps, backend, device, arg_devices + ) -> Optional[Device]: if nreps > 1: if device is not None or backend is not None: raise ValueError(f"can't specify device or backend for jit-of-pmap, " f"got device={device} and backend={backend}") return None else: - if device is None and backend is None: - return _device_from_arg_devices(arg_devices) - elif device is not None and backend is None: + # TODO(skye): dedup with C++ jit logic for determining jit device? + if device is not None: + assert backend is None return device - elif device is None and backend is not None: + + if backend is not None: return xb.get_backend(backend).get_default_device_assignment(1)[0] - else: - assert False # Unreachable given the error check in _xla_callable + + arg_device = _device_from_arg_devices(arg_devices) + if arg_device is not None: + return arg_device + + return config.jax_default_device # Argument and result handlers diff --git a/jax/linear_util.py b/jax/linear_util.py index 266334864112..c0ad791490e8 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -273,10 +273,11 @@ def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, {}) if config.jax_check_tracer_leaks: key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, - config.x64_enabled, config._trace_context()) + config.x64_enabled, config.jax_default_device, + config._trace_context()) else: - key = (fun.transforms, fun.params, fun.in_type, args, - config.x64_enabled, config._trace_context()) + key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled, + config.jax_default_device, config._trace_context()) result = cache.get(key, None) if result is not None: ans, stores = result diff --git a/tests/api_test.py b/tests/api_test.py index 1b691e666c7d..29a00c46498e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -228,6 +228,55 @@ def test_jit_device(self): self.assertIsInstance(x, jnp.DeviceArray) self.assertEqual(x.device_buffer.device(), device) + @jtu.skip_on_devices("cpu") + def test_jit_default_device(self): + if jax.device_count() == 1: + raise unittest.SkipTest("Test requires multiple devices") + + system_default_device = jax.numpy.add(1, 1).device() + test_device = jax.devices()[-1] + self.assertNotEqual(system_default_device, test_device) + + f = jax.jit(lambda x: x + 1) + self.assertEqual(f(1).device(), system_default_device) + + with jax.default_device(test_device): + self.assertEqual(jax.numpy.add(1, 1).device(), test_device) + self.assertEqual(f(1).device(), test_device) + + self.assertEqual(jax.numpy.add(1, 1).device(), system_default_device) + self.assertEqual(f(1).device(), system_default_device) + + with jax.default_device(test_device): + # Explicit `device` or `backend` argument to jit overrides default_device + self.assertEqual(jax.jit(f, device=system_default_device)(1).device(), + system_default_device) + self.assertEqual(jax.jit(f, backend="cpu")(1).platform(), "cpu") + + # Sticky input device overrides default_device + sticky = jax.device_put(1, system_default_device) + self.assertEqual(jax.numpy.add(sticky, 1).device(), system_default_device) + self.assertEqual(f(sticky).device(), system_default_device) + + # Test nested default_devices + with jax.default_device(system_default_device): + self.assertEqual(f(1).device(), system_default_device) + self.assertEqual(f(1).device(), test_device) + + # Test a few more non-default_device calls for good luck + self.assertEqual(jax.numpy.add(1, 1).device(), system_default_device) + self.assertEqual(f(sticky).device(), system_default_device) + self.assertEqual(f(1).device(), system_default_device) + + # TODO(skye): make this work! + def test_jit_default_platform(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + "jax.default_device must be passed a Device object " + "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"): + with jax.default_device("cpu"): + jax.jit(lambda x: x + 1)(1) + def test_complex_support(self): self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)