Skip to content

Commit

Permalink
Introduce class PyArray that contains the data members of python Array.
Browse files Browse the repository at this point in the history
A few key methods is implemented in C++ while the rest are still implmemented in python and added to the class later. A class decorator, @use_cpp_array, is added to add python methods to xc.Array.

PiperOrigin-RevId: 473075244
  • Loading branch information
cky9301 authored and jax authors committed Sep 8, 2022
1 parent a2e05b0 commit 0400db9
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 74 deletions.
112 changes: 77 additions & 35 deletions jax/_src/api.py
Expand Up @@ -477,10 +477,12 @@ class _BackendAndDeviceInfo(NamedTuple):
class _FastpathData(NamedTuple):
xla_executable: xla.XlaExecutable
out_pytree_def: Any
sticky_device: xc.Device
sticky_device: Optional[xc.Device]
avals: Iterable[Any]
lazy_exprs: Iterable[Any]
kept_var_bitvec: Iterable[bool]
shardings: Iterable[Any]
committed: Iterable[bool]

_cpp_jit_cache = jax_jit.CompiledFunctionCache()

Expand All @@ -489,6 +491,74 @@ def _cpp_jit_clear_cache(self):
self._clear_cache()
dispatch.xla_callable.evict_function(self._fun)

def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
use_fastpath = (
xc._version >= 92 and
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
type(execute) is pxla.ExecuteReplicated and
# No effects in computation
not execute.ordered_effects and
not execute.has_unordered_effects and
not execute.has_host_callbacks and
all(isinstance(x, xc.Array) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes
# TODO(chky): Check sharding is SingleDeviceSharding
)

if use_fastpath:
sticky_device = None
lazy_exprs = [None] * len(out_flat)
kept_var_bitvec = [i in execute.kept_var_idx for i in range(len(args_flat))]
avals = [out.aval for out in out_flat]
shardings = [out.sharding for out in out_flat]
committed = [out._committed for out in out_flat]

return _FastpathData(execute.xla_executable, out_pytree_def, sticky_device,
avals, lazy_exprs, kept_var_bitvec, shardings,
committed)

return None

def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
# TODO(sharadmv): Clean up usage of `execute.args`
use_fastpath = (
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
execute.func is dispatch._execute_compiled and # not trivial, not pmap
# No effects in computation
not execute.args[5] and not execute.args[6] and
# Has no host callbacks
not execute.args[8] and
# Not supported: ShardedDeviceArray
all(device_array.type_is_device_array(x) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes
and type(execute.args[4]) is dispatch.SimpleResultHandler)

### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
(_, xla_executable, _, _, result_handlers, _, _, kept_var_idx,
_) = execute.args # pytype: disable=attribute-error
sticky_device = None
avals = []
lazy_exprs = [None] * len(result_handlers)
for result_handler in result_handlers:
aval, sticky_device = result_handler.args
avals.append(aval)
assert len(avals) == len(out_flat)
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
shardings = []
committed = []

return _FastpathData(xla_executable, out_pytree_def, sticky_device, avals,
lazy_exprs, kept_var_bitvec, shardings, committed)

return None

def _cpp_jit(
fun: Callable,
*,
Expand Down Expand Up @@ -539,42 +609,14 @@ def cache_miss(*args, **kwargs):
# outputs that could be tracers (if f is capturing `Tracer` by closure).
execute: Optional[functools.partial] = (
dispatch.xla_callable.most_recent_entry())

fastpath_data = None

# TODO(sharadmv): Enable fast path for effectful jaxprs
# TODO(sharadmv): Clean up usage of `execute.args`
use_fastpath = (
not jax.config.jax_array and
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
execute.func is dispatch._execute_compiled and # not trivial, not pmap
# No effects in computation
not execute.args[5] and
not execute.args[6] and
# Has no host callbacks
not execute.args[8] and
# Not supported: ShardedDeviceArray
all(device_array.type_is_device_array(x) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes and
type(execute.args[4]) is dispatch.SimpleResultHandler
)
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
(_, xla_executable,
_, _, result_handlers, _, _, kept_var_idx, _) = execute.args # pytype: disable=attribute-error
sticky_device = None
avals = []
lazy_exprs = [None] * len(result_handlers)
for result_handler in result_handlers:
aval, sticky_device = result_handler.args
avals.append(aval)
assert len(avals) == len(out_flat)
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
fastpath_data = _FastpathData(xla_executable, out_pytree_def,
sticky_device, avals, lazy_exprs,
kept_var_bitvec)
if jax.config.jax_array:
fastpath_data = _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat)
else:
fastpath_data = None
fastpath_data = _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat)

return out, fastpath_data

Expand Down
10 changes: 10 additions & 0 deletions jax/_src/config.py
Expand Up @@ -652,10 +652,20 @@ def update_thread_local_jit_state(**kw):
default=False,
help='If True, pjit will output GDAs.')

def _update_jax_array_global(val):
if lib.xla_extension_version >= 92:
lib.jax_jit.global_state().jax_array = val

def _update_jax_array_thread_local(val):
if lib.xla_extension_version >= 92:
lib.jax_jit.thread_local_state().jax_array = val

jax_array = config.define_bool_state(
name='jax_array',
default=False,
upgrade=True,
update_global_hook = _update_jax_array_global,
update_thread_local_hook = _update_jax_array_thread_local,
help=('If True, new pjit behavior will be enabled and `jax.Array` will be '
'used.'))

Expand Down

0 comments on commit 0400db9

Please sign in to comment.