Commit

Add jax.default_device context manager
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
tensorflow/tensorflow#53656
skye committed May 7, 2022
1 parent 212edd6 commit f26b866
Showing 8 changed files with 170 additions and 36 deletions.
29 changes: 16 additions & 13 deletions docs/faq.rst
@@ -122,14 +122,17 @@ Let's first look at the principles of data and computation placement in JAX.

 In JAX, the computation follows data placement. JAX arrays
 have two placement properties: 1) the device where the data resides;
-and 2) whether it is **committed** to the device or not (the data is sometimes
+and 2) whether it is **committed** to the device or not (the data is sometimes
 referred to as being *sticky* to the device).

 By default, JAX arrays are placed uncommitted on the default device
-(``jax.devices()[0]``), which is the first GPU by default. If no GPU is
-present, ``jax.devices()[0]`` is the first CPU. The default device can
-be set to "cpu" or "gpu" manually by setting the environment variable
-``JAX_PLATFORM_NAME`` or the absl flag ``--jax_platform_name``.
+(``jax.devices()[0]``), which is the first GPU or TPU by default. If no GPU or
+TPU is present, ``jax.devices()[0]`` is the CPU. The default device can be
+temporarily overridden with the :func:`jax.default_device` context manager, or
+set for the whole process by setting the environment variable ``JAX_PLATFORMS``
+or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
+(``JAX_PLATFORMS`` can also be a list of platforms, which determines which
+platforms are available in priority order).

 >>> from jax import numpy as jnp
 >>> print(jnp.ones(3).device_buffer.device()) # doctest: +SKIP
@@ -148,24 +151,24 @@ gpu:2

 Computations involving some committed inputs will happen on the
 committed device and the result will be committed on the
-same device. Invoking an operation on arguments that are committed
+same device. Invoking an operation on arguments that are committed
 to more than one device will raise an error.

-You can also use :func:`jax.device_put` without a ``device`` parameter. If the data
-is already on a device (committed or not), it's left as-is. If the data isn't on any
-device—that is, it's a regular Python or NumPy value—it's placed uncommitted on the default
+You can also use :func:`jax.device_put` without a ``device`` parameter. If the data
+is already on a device (committed or not), it's left as-is. If the data isn't on any
+device—that is, it's a regular Python or NumPy value—it's placed uncommitted on the default
 device.

-Jitted functions behave like any other primitive operations—they will follow the
+Jitted functions behave like any other primitive operations—they will follow the
 data and will show errors if invoked on data committed on more than one device.

 ``jnp.device_put(jnp.zeros(...), jax.devices()[1])`` or similar will actually create the
 array of zeros on ``jax.devices()[1]``, instead of creating the array on the default
 device then moving it. This is thanks to some laziness in array creation, which holds
 for all the constant creation operations (``ones``, ``full``, ``eye``, etc).

-(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device
-placement. That parameter is experimental, is likely to be removed or changed,
+(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device
+placement. That parameter is experimental, is likely to be removed or changed,
 and its use is not recommended.)

 For a worked-out example, we recommend reading through
Expand Down Expand Up @@ -259,7 +262,7 @@ Broadly speaking:
they are dispatched asynchronously (see :ref:`async-dispatch`); and they can
be executed on CPU, GPU, or TPU, each of which has vastly different and continuously
evolving performance characteristics.

These architectural differences make meaningful direct benchmark comparisons between
NumPy and JAX difficult.

Expand Down
1 change: 1 addition & 0 deletions docs/jax.config.rst
@@ -11,6 +11,7 @@ JAX configuration
     checking_leaks
     debug_nans
     debug_infs
+    default_device
     default_matmul_precision
     default_prng_impl
     enable_checks
Expand Down
1 change: 1 addition & 0 deletions jax/__init__.py
@@ -53,6 +53,7 @@
   transfer_guard_host_to_device as transfer_guard_host_to_device,
   transfer_guard_device_to_device as transfer_guard_device_to_device,
   transfer_guard_device_to_host as transfer_guard_device_to_host,
+  default_device as default_device,
 )
 from .core import eval_context as ensure_compile_time_eval
 from jax._src.api import (
Expand Down
9 changes: 8 additions & 1 deletion jax/_src/api.py
@@ -456,6 +456,13 @@ def _cpp_jit(
     raise ValueError("can't specify both a device and a backend for jit, "
                      f"got device={device} and backend={backend}.")

+  if device is not None:
+    jit_device = device
+  elif backend is not None:
+    jit_device = xb.get_backend(backend).get_default_device_assignment(1)[0]
+  else:
+    jit_device = None
+
   @api_boundary
   def cache_miss(*args, **kwargs):
     ### This first part is basically the same code as in _python_jit.
@@ -531,7 +538,7 @@ def get_device_info():
       static_argnums=static_argnums,
       static_argnames=static_argnames,
       donate_argnums=donate_argnums,
-      cache=_cpp_jit_cache)
+      cache=_cpp_jit_cache, jit_device=jit_device)
   f_jitted = wraps(fun)(cpp_jitted_f)

   f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
Expand Down
90 changes: 78 additions & 12 deletions jax/_src/config.py
@@ -27,6 +27,7 @@
 from jax._src import lib
 from jax._src.lib import jax_jit
 from jax._src.lib import transfer_guard_lib
+from jax._src.lib import xla_client

 def bool_env(varname: str, default: bool) -> bool:
   """Read an environment variable and interpret it as a boolean.
@@ -321,6 +322,45 @@ def define_string_state(
         updated value of the thread-local state when it is altered or set
         initially.

+    Returns:
+      A contextmanager to control the thread-local state value.
+    """
+    def validate(new_val):
+      if new_val is not None and not isinstance(new_val, str):
+        raise ValueError(f"new string config value must be None or of type str,"
+                         f" got {new_val} of type {type(new_val)}.")
+
+    return self.define_string_or_object_state(
+        name, default, help, update_global_hook, update_thread_local_hook,
+        validate)
+
+  def define_string_or_object_state(
+      self, name: str, default: Any, help: str,
+      update_global_hook: Optional[Callable[[Any], None]] = None,
+      update_thread_local_hook: Optional[Callable[[Any], None]] = None,
+      validate_new_val_hook: Optional[Callable[[Any], None]] = None):
+    """Set up thread-local state and return a contextmanager for managing it.
+
+    Similar to ``define_string_state``, except the context manager will accept
+    any object, not just a string. Any value passed via commandline flag or
+    environment variable will be treated as a string.
+
+    Args:
+      name: string, converted to lowercase to define the name of the config
+        option (and absl flag). It is converted to uppercase to define the
+        corresponding shell environment variable.
+      default: string, a default value for the option.
+      help: string, used to populate the flag help information as well as the
+        docstring of the returned context manager.
+      update_global_hook: an optional callback that is called with the updated
+        value of the global state when it is altered or set initially.
+      update_thread_local_hook: an optional callback that is called with the
+        updated value of the thread-local state when it is altered or set
+        initially.
+      validate_new_val_hook: an optional callback that is called with the new
+        value on any update, and should raise an error if the new value is
+        invalid.
+
     Returns:
       A contextmanager to control the thread-local state value.
     """
@@ -335,12 +375,8 @@ def get_state(self):
       return val if val is not unset else self._read(name)
     setattr(Config, name, property(get_state))

-    def validate(new_val):
-      if new_val is not None and not isinstance(new_val, str):
-        raise ValueError(f"new string config value must be None or of type str,"
-                         f" got {new_val} of type {type(new_val)}.")
-
-    return _StateContextManager(name, help, update_thread_local_hook, validate)
+    return _StateContextManager(name, help, update_thread_local_hook,
+                                validate_new_val_hook)

   def _trace_context(self):
     """Returns a tuple of configuration values that affect tracing.
@@ -435,19 +471,20 @@ def __setattr__(self, name, val):
 # a global/thread-local state. These methods allow updates to part of the
 # state when a configuration value changes.

-class GlobalJitState(NamedTuple):
+class GlobalExtraJitContext(NamedTuple):
   numpy_rank_promotion: Optional[str] = None
   default_matmul_precision: Optional[Any] = None
   dynamic_shapes: bool = False


 def update_global_jit_state(**kw):
   gs = jax_jit.global_state()
-  context = gs.extra_jit_context or GlobalJitState()
-  gs.extra_jit_context = context._replace(**kw)
+  if gs.extra_jit_context is None:
+    gs.extra_jit_context = GlobalExtraJitContext()
+  gs.extra_jit_context = gs.extra_jit_context._replace(**kw)


-class ThreadLocalJitState(NamedTuple):
+class ThreadLocalExtraJitContext(NamedTuple):
   dynamic_trace_state: Optional[Any] = None
   numpy_rank_promotion: Optional[str] = None
   default_matmul_precision: Optional[Any] = None
@@ -456,8 +493,9 @@ class ThreadLocalJitState(NamedTuple):

 def update_thread_local_jit_state(**kw):
   tls = jax_jit.thread_local_state()
-  context = tls.extra_jit_context or ThreadLocalJitState()
-  tls.extra_jit_context = context._replace(**kw)
+  if tls.extra_jit_context is None:
+    tls.extra_jit_context = ThreadLocalExtraJitContext()
+  tls.extra_jit_context = tls.extra_jit_context._replace(**kw)


 # TODO(mattjj): remove all uses of this flag
@@ -639,6 +677,34 @@ def _update_x64_thread_local(val):

 Config.x64_enabled = Config.jax_enable_x64  # type: ignore

+
+def _update_default_device_global(val):
+  lib.jax_jit.global_state().default_device = val
+
+def _update_default_device_thread_local(val):
+  lib.jax_jit.thread_local_state().default_device = val
+
+def _validate_default_device(val):
+  if val is not None and not isinstance(val, xla_client.Device):
+    raise ValueError("jax.default_device must be passed a Device object (e.g. "
+                     f"`jax.devices('cpu')[0]`), got: {repr(val)}")
+
+# TODO(skye): default_device only accepts devices for now. Make it work with
+# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
+default_device = config.define_string_or_object_state(
+    name='jax_default_device',
+    default=None,
+    help=(
+        'Configure the default device for JAX operations. '
+        'Set to a Device object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the '
+        'default device for JAX operations and jit\'d function calls (there is '
+        'no effect on multi-device computations, e.g. pmapped function calls). '
+        'Set to None to use the system default device. See '
+        ':ref:`faq-data-placement` for more information on device placement.'),
+    update_global_hook=_update_default_device_global,
+    update_thread_local_hook=_update_default_device_thread_local,
+    validate_new_val_hook=_validate_default_device)
+
 def _update_disable_jit_global(val):
   lib.jax_jit.global_state().disable_jit = val

Expand Down
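Note on the config.py change: the new state is a thread-local value guarded by a validation hook and exposed as a context manager. The pure-Python sketch below illustrates that pattern only. It is not JAX's `_StateContextManager`; `FakeDevice`, `current_default_device`, and the module-level `_state` are hypothetical stand-ins introduced for illustration.

```python
import contextlib
import threading

# Hypothetical stand-in for JAX's thread-local config storage.
_state = threading.local()


class FakeDevice:
    """Hypothetical stand-in for xla_client.Device."""

    def __init__(self, name):
        self.name = name


def _validate(val):
    # Mirrors the idea of the validation hook: only None or a device object.
    if val is not None and not isinstance(val, FakeDevice):
        raise ValueError(
            f"default_device must be passed a Device object, got: {val!r}")


def current_default_device():
    # None means "use the system default device".
    return getattr(_state, "default_device", None)


@contextlib.contextmanager
def default_device(val):
    _validate(val)                      # reject bad values before changing state
    prev = current_default_device()
    _state.default_device = val
    try:
        yield
    finally:
        _state.default_device = prev    # restore on exit, so nesting works
```

Because the previous value is restored in the `finally` clause, nested `with` blocks unwind in order, which is the behaviour the "nested default_devices" case in the new test exercises.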
20 changes: 13 additions & 7 deletions jax/_src/dispatch.py
@@ -438,21 +438,27 @@ def initial_style_primitive_replicas(params):
   return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)


-def _xla_callable_device(nreps, backend, device, arg_devices):
+def _xla_callable_device(nreps, backend, device, arg_devices
+                         ) -> Optional[Device]:
   if nreps > 1:
     if device is not None or backend is not None:
       raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                        f"got device={device} and backend={backend}")
     return None
   else:
-    if device is None and backend is None:
-      return _device_from_arg_devices(arg_devices)
-    elif device is not None and backend is None:
+    # TODO(skye): dedup with C++ jit logic for determining jit device?
+    if device is not None:
+      assert backend is None
       return device
-    elif device is None and backend is not None:
+
+    if backend is not None:
       return xb.get_backend(backend).get_default_device_assignment(1)[0]
-    else:
-      assert False  # Unreachable given the error check in _xla_callable
+
+    arg_device = _device_from_arg_devices(arg_devices)
+    if arg_device is not None:
+      return arg_device
+
+    return config.jax_default_device


 # Argument and result handlers
Expand Down
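Note on the dispatch.py change: the rewritten `_xla_callable_device` resolves the device in a fixed priority order: explicit `device`, then explicit `backend`, then the device of committed arguments, then `jax_default_device`. The sketch below models only that ordering for the single-replica case. `pick_device` and its parameters are hypothetical names, and devices are represented by plain strings rather than real `Device` objects.

```python
from typing import Hashable, Iterable, Optional


def pick_device(device: Optional[Hashable] = None,
                backend_device: Optional[Hashable] = None,
                committed_arg_devices: Iterable[Hashable] = (),
                default_device: Optional[Hashable] = None) -> Optional[Hashable]:
    """Return the device a single-device computation would run on.

    A result of None means "fall back to the system default device".
    """
    if device is not None and backend_device is not None:
        raise ValueError("can't specify both a device and a backend")
    # 1. An explicit device argument wins over everything else.
    if device is not None:
        return device
    # 2. Otherwise use the device implied by an explicit backend.
    if backend_device is not None:
        return backend_device
    # 3. Otherwise follow the committed ("sticky") inputs, if any.
    committed = set(committed_arg_devices)
    if len(committed) > 1:
        raise ValueError("arguments are committed to more than one device")
    if committed:
        return committed.pop()
    # 4. Only then consult the configured default device.
    return default_device
```

For example, `pick_device(committed_arg_devices=["gpu:0"], default_device="gpu:1")` returns `"gpu:0"`, matching the "sticky input device overrides default_device" assertion in the new test.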
7 changes: 4 additions & 3 deletions jax/linear_util.py
@@ -273,10 +273,11 @@ def memoized_fun(fun: WrappedFun, *args):
     cache = fun_caches.setdefault(fun.f, {})
     if config.jax_check_tracer_leaks:
       key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
-             config.x64_enabled, config._trace_context())
+             config.x64_enabled, config.jax_default_device,
+             config._trace_context())
     else:
-      key = (fun.transforms, fun.params, fun.in_type, args,
-             config.x64_enabled, config._trace_context())
+      key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
+             config.jax_default_device, config._trace_context())
     result = cache.get(key, None)
     if result is not None:
       ans, stores = result
Expand Down
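Note on the linear_util.py change: adding `config.jax_default_device` to the memoization key means a result cached under one default device is not reused after the default changes. The sketch below shows the idea with a toy cache. `config`, `lower`, `memoized_lower`, and `stats` are hypothetical names, and the "lowering" is just a string.

```python
# Hypothetical stand-in for jax.config; None means "system default device".
config = {"default_device": None}

_cache = {}
stats = {"lowerings": 0}


def lower(fn_name, device):
    # Pretend this is expensive and bakes the target device into its result.
    stats["lowerings"] += 1
    return f"{fn_name} lowered for {device}"


def memoized_lower(fn_name):
    # The current default device is part of the key, so changing it forces a
    # fresh lowering instead of returning a result built for another device.
    key = (fn_name, config["default_device"])
    if key not in _cache:
        _cache[key] = lower(fn_name, config["default_device"])
    return _cache[key]
```

If the key were only `fn_name`, the second call after changing `config["default_device"]` would return the entry built for the previous device.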
49 changes: 49 additions & 0 deletions tests/api_test.py
@@ -228,6 +228,55 @@ def test_jit_device(self):
     self.assertIsInstance(x, jnp.DeviceArray)
     self.assertEqual(x.device_buffer.device(), device)

+  @jtu.skip_on_devices("cpu")
+  def test_jit_default_device(self):
+    if jax.device_count() == 1:
+      raise unittest.SkipTest("Test requires multiple devices")
+
+    system_default_device = jax.numpy.add(1, 1).device()
+    test_device = jax.devices()[-1]
+    self.assertNotEqual(system_default_device, test_device)
+
+    f = jax.jit(lambda x: x + 1)
+    self.assertEqual(f(1).device(), system_default_device)
+
+    with jax.default_device(test_device):
+      self.assertEqual(jax.numpy.add(1, 1).device(), test_device)
+      self.assertEqual(f(1).device(), test_device)
+
+    self.assertEqual(jax.numpy.add(1, 1).device(), system_default_device)
+    self.assertEqual(f(1).device(), system_default_device)
+
+    with jax.default_device(test_device):
+      # Explicit `device` or `backend` argument to jit overrides default_device
+      self.assertEqual(jax.jit(f, device=system_default_device)(1).device(),
+                       system_default_device)
+      self.assertEqual(jax.jit(f, backend="cpu")(1).platform(), "cpu")
+
+      # Sticky input device overrides default_device
+      sticky = jax.device_put(1, system_default_device)
+      self.assertEqual(jax.numpy.add(sticky, 1).device(), system_default_device)
+      self.assertEqual(f(sticky).device(), system_default_device)
+
+      # Test nested default_devices
+      with jax.default_device(system_default_device):
+        self.assertEqual(f(1).device(), system_default_device)
+      self.assertEqual(f(1).device(), test_device)
+
+    # Test a few more non-default_device calls for good luck
+    self.assertEqual(jax.numpy.add(1, 1).device(), system_default_device)
+    self.assertEqual(f(sticky).device(), system_default_device)
+    self.assertEqual(f(1).device(), system_default_device)
+
+  # TODO(skye): make this work!
+  def test_jit_default_platform(self):
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        "jax.default_device must be passed a Device object "
+        "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
+      with jax.default_device("cpu"):
+        jax.jit(lambda x: x + 1)(1)
+
   def test_complex_support(self):
     self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

Expand Down
