From 8dfbf90602f92be4f936a3300ae1fd6dbb5af2e2 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 29 Nov 2023 04:02:30 -0800 Subject: [PATCH] [Pallas/Mosaic] Add support for barrier semaphores PiperOrigin-RevId: 586289340 --- jax/_src/pallas/mosaic/__init__.py | 1 + jax/_src/pallas/mosaic/lowering.py | 9 +++++++- .../pallas/mosaic/pallas_call_registration.py | 6 ++++- jax/_src/pallas/mosaic/primitives.py | 23 +++++++++++++++---- jax/experimental/pallas/tpu.py | 1 + 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py index 18d724f07424..00003a35036a 100644 --- a/jax/_src/pallas/mosaic/__init__.py +++ b/jax/_src/pallas/mosaic/__init__.py @@ -26,6 +26,7 @@ from jax._src.pallas.mosaic.primitives import async_copy from jax._src.pallas.mosaic.primitives import async_remote_copy from jax._src.pallas.mosaic.primitives import device_id +from jax._src.pallas.mosaic.primitives import get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy from jax._src.pallas.mosaic.primitives import repeat diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 37d17464d46e..13865d848aca 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1727,7 +1727,10 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, semaphore, value): sem_aval = ctx.avals_in[0] assert isinstance(sem_aval, tpu_core.AbstractSemaphore) - assert sem_aval.sem_type is tpu_core.SemaphoreType.REGULAR + assert sem_aval.sem_type in { + tpu_core.SemaphoreType.REGULAR, + tpu_core.SemaphoreType.BARRIER, + } assert ctx.avals_in[1].dtype == jnp.dtype('int32') return tpu.SemaphoreWaitOp(semaphore, value).results lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule @@ -1803,3 +1806,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str): col = _make_index(axis_names.index(axis_name)) return memref.LoadOp(l_to_m, [device_id, col]).result lowering_rules[lax.axis_index_p] = _axis_index_rule + +def _get_barrier_semaphore_rule(ctx: LoweringRuleContext): + return tpu.GetBarrierSemaphoreOp().result +lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index f6ce80fa5b0f..1d7dcde979cf 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -84,7 +84,11 @@ def _lower_fun(*args): kernel_regeneration_metadata=kernel_regeneration_metadata, cost_estimate=mosaic_params.get("cost_estimate", None), flags=mosaic_params.get("flags", None), - )(*extra_args, *args) + )( + *extra_args, + *args, + collective_id=mosaic_params.get("collective_id", None), + ) return mlir.lower_fun(_lower_fun, multiple_results=True)( ctx, *in_nodes) mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule, diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index b09e32f9f4bb..f7fb5108470d 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -149,8 +149,11 @@ def _semaphore_signal_abstract_eval( del device_id_tree if not isinstance(sem_aval, tpu_core.AbstractSemaphore): raise ValueError(f"Cannot signal on a non-semaphore value: {sem_aval}") - if sem_aval.sem_type is not tpu_core.SemaphoreType.REGULAR: - raise ValueError("Must signal a REGULAR semaphore.") + if sem_aval.sem_type not in { + tpu_core.SemaphoreType.REGULAR, + tpu_core.SemaphoreType.BARRIER, + }: + raise ValueError("Must signal a REGULAR or BARRIER semaphore.") if value.dtype != jnp.dtype("int32"): raise ValueError("Must signal an int32 value.") if has_device_id: @@ -171,8 +174,11 @@ def semaphore_wait(sem, dec: int | jax.Array = 1): def _semaphore_wait_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value): if not isinstance(sem_aval, tpu_core.AbstractSemaphore): raise ValueError(f"Cannot wait on a non-semaphore value: {sem_aval}") - if sem_aval.sem_type is not tpu_core.SemaphoreType.REGULAR: - raise ValueError("Must wait a REGULAR semaphore.") + if sem_aval.sem_type not in { + tpu_core.SemaphoreType.REGULAR, + tpu_core.SemaphoreType.BARRIER, + }: + raise ValueError("Must wait on a REGULAR or BARRIER semaphore.") if value.dtype != jnp.dtype("int32"): raise ValueError("Must signal an int32 value.") return [] @@ -389,3 +395,12 @@ def _device_id_abstract_eval(): return jax_core.ShapedArray((), jnp.dtype("int32")) device_id = device_id_p.bind + +get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore') + +@get_barrier_semaphore_p.def_abstract_eval +def _get_barrier_semaphore_abstract_eval(): + return tpu_core.AbstractSemaphore(tpu_core.SemaphoreType.BARRIER) + +def get_barrier_semaphore(): + return get_barrier_semaphore_p.bind() diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 37b3dd6700e2..9fd0e0e2ec33 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -26,6 +26,7 @@ from jax._src.pallas.mosaic import device_id from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata +from jax._src.pallas.mosaic import get_barrier_semaphore from jax._src.pallas.mosaic import make_async_copy from jax._src.pallas.mosaic import make_async_remote_copy from jax._src.pallas.mosaic import repeat