Skip to content

Commit

Permalink
Implement the JAX transfer guard API
Browse files Browse the repository at this point in the history
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
  • Loading branch information
hyeontaek authored and jax authors committed Feb 14, 2022
1 parent f229a70 commit beaa00c
Show file tree
Hide file tree
Showing 6 changed files with 429 additions and 18 deletions.
12 changes: 10 additions & 2 deletions build/build_wheel.py
Expand Up @@ -96,6 +96,10 @@ def copy_file(src_file, dst_dir, dst_filename=None):
"pmap_lib.pyi",
"profiler.pyi",
"pytree.pyi",
"transfer_guard_lib.pyi",
]
_OPTIONAL_XLA_EXTENSION_STUBS = [
"transfer_guard_lib.pyi", # Will be required on xla_extension_version >= 58.
]


Expand All @@ -107,8 +111,12 @@ def patch_copy_xla_extension_stubs(dst_dir):
xla_extension_dir = os.path.join(dst_dir, "xla_extension")
os.makedirs(xla_extension_dir)
for stub_name in _XLA_EXTENSION_STUBS:
with open(r.Rlocation(
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)) as f:
stub_path = r.Rlocation(
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
continue
with open(stub_path) as f:
src = f.read()
src = src.replace(
"from tensorflow.compiler.xla.python import xla_extension",
Expand Down
6 changes: 5 additions & 1 deletion jax/__init__.py
Expand Up @@ -49,7 +49,11 @@
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,
numpy_rank_promotion as numpy_rank_promotion,
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
transfer_guard as transfer_guard,
transfer_guard_host_to_device as transfer_guard_host_to_device,
transfer_guard_device_to_device as transfer_guard_device_to_device,
transfer_guard_device_to_host as transfer_guard_device_to_host,
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.api import (
Expand Down
36 changes: 22 additions & 14 deletions jax/_src/api.py
Expand Up @@ -84,11 +84,14 @@
from jax.custom_transpose import custom_transpose
from jax.ad_checkpoint import checkpoint_policies

from jax._src.config import (flags, config, bool_env,
disable_jit as _disable_jit,
debug_nans as config_debug_nans,
debug_infs as config_debug_infs,
_thread_local_state as config_thread_local_state)
from jax._src.config import (
flags, config, bool_env,
disable_jit as _disable_jit,
debug_nans as config_debug_nans,
debug_infs as config_debug_infs,
_thread_local_state as config_thread_local_state,
explicit_device_put_scope as config_explicit_device_put_scope,
explicit_device_get_scope as config_explicit_device_get_scope)


traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -2750,7 +2753,8 @@ def device_put(x, device: Optional[xc.Device] = None):
Returns:
A copy of ``x`` that resides on ``device``.
"""
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
with config_explicit_device_put_scope():
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)


def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
Expand Down Expand Up @@ -2819,7 +2823,8 @@ def _device_put_sharded(*xs):
for buf in dispatch.device_put(x, d)]
return pxla.make_sharded_device_array(stacked_aval, None, buffers)

return tree_multimap(_device_put_sharded, *shards)
with config_explicit_device_put_scope():
return tree_multimap(_device_put_sharded, *shards)


def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
Expand Down Expand Up @@ -2862,7 +2867,9 @@ def _device_put_replicated(x):
buf, = dispatch.device_put(x, devices[0])
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
return tree_map(_device_put_replicated, x)

with config_explicit_device_put_scope():
return tree_map(_device_put_replicated, x)


# TODO(mattjj): consider revising
Expand Down Expand Up @@ -2907,12 +2914,13 @@ def device_get(x: Any):
- device_put_sharded
- device_put_replicated
"""
for y in tree_leaves(x):
try:
y.copy_to_host_async()
except AttributeError:
pass
return tree_map(_device_get, x)
with config_explicit_device_get_scope():
for y in tree_leaves(x):
try:
y.copy_to_host_async()
except AttributeError:
pass
return tree_map(_device_get, x)

def _check_arg(arg):
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
Expand Down
141 changes: 140 additions & 1 deletion jax/_src/config.py
Expand Up @@ -21,11 +21,13 @@
import os
import sys
import threading
from typing import Any, List, Callable, NamedTuple, Optional
from typing import Any, List, Callable, NamedTuple, Iterator, Optional
import warnings

from jax._src import lib
from jax._src.lib import jax_jit
if lib.xla_extension_version >= 58:
from jax._src.lib import transfer_guard_lib

def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
Expand Down Expand Up @@ -685,3 +687,140 @@ def _update_disable_jit_thread_local(val):
default=False,
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'))

if lib.xla_extension_version < 58:
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
yield

@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_get() call."""
yield

@contextlib.contextmanager
def _transfer_guard(new_val: str) -> Iterator[None]:
raise NotImplementedError("jaxlib version is too low for transfer guards")

transfer_guard_host_to_device = _transfer_guard
transfer_guard_device_to_device = _transfer_guard
transfer_guard_device_to_host = _transfer_guard
transfer_guard = _transfer_guard

else:
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
state = transfer_guard_lib.thread_local_state()
prev = state.explicit_device_put
state.explicit_device_put = True
try:
yield
finally:
state.explicit_device_put = prev

@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_get() call."""
state = transfer_guard_lib.thread_local_state()
prev = state.explicit_device_get
state.explicit_device_get = True
try:
yield
finally:
state.explicit_device_get = prev

def _update_transfer_guard(state, key, val):
"""Applies the transfer guard level within transfer_guard_lib."""
if val is None:
setattr(state, key, None)
elif val == 'allow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
elif val == 'log':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
elif val == 'disallow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
elif val == 'log_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
elif val == 'disallow_explicit':
setattr(state, key,
transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
else:
assert False, f"Invalid transfer guard level {val}"

transfer_guard_host_to_device = config.define_enum_state(
name='jax_transfer_guard_host_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for host-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'host_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'host_to_device', val))

transfer_guard_device_to_device = config.define_enum_state(
name='jax_transfer_guard_device_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_device', val))

transfer_guard_device_to_host = config.define_enum_state(
name='jax_transfer_guard_device_to_host',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-host transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_host', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_host', val))

def _update_all_transfer_guard_global(val):
for name in ('jax_transfer_guard_host_to_device',
'jax_transfer_guard_device_to_device',
'jax_transfer_guard_device_to_host'):
config.update(name, val)

_transfer_guard = config.define_enum_state(
name='jax_transfer_guard',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard_*.
default=None,
help=(
'Select the transfer guard level for all transfers. This option is '
'set-only; the transfer guard level for a specific direction should '
'be read using the per-transfer direction option. '
'Default is "allow".'),
update_global_hook=_update_all_transfer_guard_global)

@contextlib.contextmanager
def transfer_guard(new_val: str) -> Iterator[None]:
"""Set up thread-local state and return a contextmanager for managing it."""
with contextlib.ExitStack() as stack:
stack.enter_context(transfer_guard_host_to_device(new_val))
stack.enter_context(transfer_guard_device_to_device(new_val))
stack.enter_context(transfer_guard_device_to_host(new_val))
stack.enter_context(_transfer_guard(new_val))
yield
3 changes: 3 additions & 0 deletions jax/_src/lib/__init__.py
Expand Up @@ -152,3 +152,6 @@ def _parse_version(v: str) -> Tuple[int, ...]:
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):
cuda_path = None

if xla_extension_version >= 58:
transfer_guard_lib = xla_client._xla.transfer_guard_lib

0 comments on commit beaa00c

Please sign in to comment.