Skip to content

Commit

Permalink
Use XLA extension tokens instead of output tokens
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 465389589
  • Loading branch information
sharadmv authored and jax authors committed Aug 4, 2022
1 parent 9185509 commit c5d4eb5
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 26 deletions.
56 changes: 40 additions & 16 deletions jax/_src/dispatch.py
Expand Up @@ -50,6 +50,7 @@
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath
Expand Down Expand Up @@ -107,10 +108,14 @@ def apply_primitive(prim, *args, **params):
class RuntimeTokenSet(threading.local):
tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
output_tokens: Dict[Device, RuntimeToken]
output_runtime_tokens: Dict[Device, RuntimeToken]

def __init__(self):
self.tokens = {}
# TODO(sharadmv): remove redundant output token dictionary when minimum
# jaxlib version is bumped to 0.3.16.
self.output_tokens = {}
self.output_runtime_tokens = {}

def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
if eff not in self.tokens:
Expand All @@ -131,17 +136,22 @@ def set_output_token(self, device: Device, token: RuntimeToken):
# we'd need to store a set of output tokens.
self.output_tokens[device] = token

def set_output_runtime_token(self, device: Device, token: RuntimeToken):
# TODO(sharadmv): remove this method when minimum jaxlib version is bumped
self.output_runtime_tokens[device] = token

def clear(self):
self.tokens = {}
self.output_tokens = {}
self.output_runtime_tokens = {}

def block_until_ready(self):
for t, _ in self.tokens.values():
t[0].block_until_ready()
# TODO(sharadmv): use a runtime mechanism to block on computations instead
# of using output tokens.
for t in self.output_tokens.values():
t[0].block_until_ready()
for token, _ in self.tokens.values():
token[0].block_until_ready()
for token in self.output_tokens.values():
token[0].block_until_ready()
for token in self.output_runtime_tokens.values():
token.block_until_ready()

runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()

Expand Down Expand Up @@ -703,12 +713,17 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
tokens_flat = flatten(tokens)
input_bufs = [*tokens_flat, *input_bufs]
def _remove_tokens(output_bufs):
token_bufs, output_bufs = util.split_list(
output_bufs, [has_unordered_effects + len(ordered_effects)])
def _remove_tokens(output_bufs, runtime_token):
# TODO(sharadmv): simplify when minimum jaxlib version is bumped
num_output_tokens = len(ordered_effects) + (xla_extension_version < 81 and
has_unordered_effects)
token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
if has_unordered_effects:
output_token_buf, *token_bufs = token_bufs
runtime_tokens.set_output_token(device, output_token_buf)
if xla_extension_version >= 81:
runtime_tokens.set_output_runtime_token(device, runtime_token)
else:
output_token_buf, *token_bufs = token_bufs
runtime_tokens.set_output_token(device, output_token_buf)
for eff, token_buf in zip(ordered_effects, token_bufs):
runtime_tokens.update_token(eff, token_buf)
return output_bufs
Expand All @@ -727,13 +742,19 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
if has_unordered_effects or ordered_effects:
in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
device, in_flat)
out_flat = compiled.execute(in_flat)
in_flat, token_handler = _add_tokens(
has_unordered_effects, ordered_effects, device, in_flat)
if xla_extension_version >= 81:
out_flat, runtime_token = compiled.execute_with_token(in_flat)
else:
out_flat = compiled.execute(in_flat)
runtime_token = None
else:
out_flat = compiled.execute(in_flat)
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if ordered_effects or has_unordered_effects:
out_bufs = token_handler(out_bufs)
out_bufs = token_handler(out_bufs, runtime_token)
return result_handler(env, out_bufs)


Expand Down Expand Up @@ -934,7 +955,10 @@ def from_xla_computation(name: str, xla_computation: Optional[ir.Module],
host_callbacks)
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
if ordered_effects or has_unordered_effects:
num_output_tokens = len(ordered_effects) + has_unordered_effects
num_output_tokens = len(ordered_effects)
# TODO(sharadmv): remove check when minimum jaxlib version is bumped
if xla_extension_version < 81:
num_output_tokens += has_unordered_effects
buffer_counts = ([1] * num_output_tokens) + buffer_counts
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts, # type: ignore # noqa: F811
Expand Down
4 changes: 3 additions & 1 deletion jax/interpreters/mlir.py
Expand Up @@ -41,6 +41,7 @@
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import source_info_util
import jax._src.util as util
from jax.config import config
Expand Down Expand Up @@ -615,7 +616,8 @@ def lower_jaxpr_to_module(
lower_jaxpr_to_fun(
ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True,
replace_tokens_with_dummy=True,
num_output_tokens=1 if unordered_effects else 0,
num_output_tokens=(
1 if (unordered_effects and xla_extension_version < 81) else 0),
replicated_args=replicated_args,
arg_shardings=arg_shardings, result_shardings=result_shardings,
input_output_aliases=input_output_aliases)
Expand Down
24 changes: 19 additions & 5 deletions jax/interpreters/pxla.py
Expand Up @@ -1707,12 +1707,26 @@ def __init__(self, xla_executable, backend, in_handler: InputsHandler,
@profiler.annotate_function
def __call__(self, *args):
input_bufs = self.in_handler(args)
out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs)
if self.has_unordered_effects:
token_bufs, *out_bufs = out_bufs
for i, device in enumerate(self.xla_executable.local_devices()):
token = (token_bufs[i],)
dispatch.runtime_tokens.set_output_token(device, token)
# TODO(sharadmv): simplify this logic when minimum jaxlib version is
# bumped
if xla_extension_version >= 81:
out_bufs, runtime_tokens = (
self.xla_executable.execute_sharded_on_local_devices_with_tokens(
input_bufs))
for device, token in zip(
self.xla_executable.local_devices(), runtime_tokens):
dispatch.runtime_tokens.set_output_runtime_token(device, token)
else:
out_bufs = self.xla_executable.execute_sharded_on_local_devices(
input_bufs)
token_bufs, *out_bufs = out_bufs
for i, device in enumerate(self.xla_executable.local_devices()):
token = (token_bufs[i],)
dispatch.runtime_tokens.set_output_token(device, token)
else:
out_bufs = self.xla_executable.execute_sharded_on_local_devices(
input_bufs)
if dispatch.needs_check_special():
for bufs in out_bufs:
dispatch.check_special("parallel computation", bufs)
Expand Down
5 changes: 5 additions & 0 deletions tests/debugging_primitives_test.py
Expand Up @@ -29,6 +29,7 @@
from jax.experimental import maps
from jax.experimental import pjit
from jax._src import debugging
from jax._src import dispatch
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
import jax.numpy as jnp
Expand Down Expand Up @@ -67,6 +68,10 @@ def tearDownModule():

class DebugPrintTest(jtu.JaxTestCase):

def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()

@jtu.skip_on_devices(*disabled_backends)
def test_simple_debug_print_works_in_eager_mode(self):
def f(x):
Expand Down
12 changes: 8 additions & 4 deletions tests/jaxpr_effects_test.py
Expand Up @@ -456,10 +456,14 @@ def f(x):

# First output should be output token
result_types = mhlo.body.operations[0].type.results
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[1]), 'tensor<f32>')
if jaxlib.version < (0, 3, 16):
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[1]), 'tensor<f32>')
else:
self.assertLen(list(result_types), 1)
self.assertEqual(str(result_types[0]), 'tensor<f32>')

def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self):
@jax.jit
Expand Down

0 comments on commit c5d4eb5

Please sign in to comment.