Skip to content

Commit

Permalink
[Pallas/Mosaic] Add support for barrier semaphores
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586289340
  • Loading branch information
sharadmv authored and jax authors committed Nov 29, 2023
1 parent 5bcf231 commit 8dfbf90
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/__init__.py
Expand Up @@ -26,6 +26,7 @@
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
Expand Down
9 changes: 8 additions & 1 deletion jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -1727,7 +1727,10 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, semaphore,
value):
sem_aval = ctx.avals_in[0]
assert isinstance(sem_aval, tpu_core.AbstractSemaphore)
assert sem_aval.sem_type is tpu_core.SemaphoreType.REGULAR
assert sem_aval.sem_type in {
tpu_core.SemaphoreType.REGULAR,
tpu_core.SemaphoreType.BARRIER,
}
assert ctx.avals_in[1].dtype == jnp.dtype('int32')
return tpu.SemaphoreWaitOp(semaphore, value).results
lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule
Expand Down Expand Up @@ -1803,3 +1806,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
col = _make_index(axis_names.index(axis_name))
return memref.LoadOp(l_to_m, [device_id, col]).result
lowering_rules[lax.axis_index_p] = _axis_index_rule

def _get_barrier_semaphore_rule(ctx: LoweringRuleContext):
return tpu.GetBarrierSemaphoreOp().result
lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule
6 changes: 5 additions & 1 deletion jax/_src/pallas/mosaic/pallas_call_registration.py
Expand Up @@ -84,7 +84,11 @@ def _lower_fun(*args):
kernel_regeneration_metadata=kernel_regeneration_metadata,
cost_estimate=mosaic_params.get("cost_estimate", None),
flags=mosaic_params.get("flags", None),
)(*extra_args, *args)
)(
*extra_args,
*args,
collective_id=mosaic_params.get("collective_id", None),
)
return mlir.lower_fun(_lower_fun, multiple_results=True)(
ctx, *in_nodes)
mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule,
Expand Down
23 changes: 19 additions & 4 deletions jax/_src/pallas/mosaic/primitives.py
Expand Up @@ -149,8 +149,11 @@ def _semaphore_signal_abstract_eval(
del device_id_tree
if not isinstance(sem_aval, tpu_core.AbstractSemaphore):
raise ValueError(f"Cannot signal on a non-semaphore value: {sem_aval}")
if sem_aval.sem_type is not tpu_core.SemaphoreType.REGULAR:
raise ValueError("Must signal a REGULAR semaphore.")
if sem_aval.sem_type not in {
tpu_core.SemaphoreType.REGULAR,
tpu_core.SemaphoreType.BARRIER,
}:
raise ValueError("Must signal a REGULAR or BARRIER semaphore.")
if value.dtype != jnp.dtype("int32"):
raise ValueError("Must signal an int32 value.")
if has_device_id:
Expand All @@ -171,8 +174,11 @@ def semaphore_wait(sem, dec: int | jax.Array = 1):
def _semaphore_wait_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value):
if not isinstance(sem_aval, tpu_core.AbstractSemaphore):
raise ValueError(f"Cannot wait on a non-semaphore value: {sem_aval}")
if sem_aval.sem_type is not tpu_core.SemaphoreType.REGULAR:
raise ValueError("Must wait a REGULAR semaphore.")
if sem_aval.sem_type not in {
tpu_core.SemaphoreType.REGULAR,
tpu_core.SemaphoreType.BARRIER,
}:
raise ValueError("Must wait on a REGULAR or BARRIER semaphore.")
if value.dtype != jnp.dtype("int32"):
raise ValueError("Must signal an int32 value.")
return []
Expand Down Expand Up @@ -389,3 +395,12 @@ def _device_id_abstract_eval():
return jax_core.ShapedArray((), jnp.dtype("int32"))

device_id = device_id_p.bind

get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore')

@get_barrier_semaphore_p.def_abstract_eval
def _get_barrier_semaphore_abstract_eval():
return tpu_core.AbstractSemaphore(tpu_core.SemaphoreType.BARRIER)

def get_barrier_semaphore():
return get_barrier_semaphore_p.bind()
1 change: 1 addition & 0 deletions jax/experimental/pallas/tpu.py
Expand Up @@ -26,6 +26,7 @@
from jax._src.pallas.mosaic import device_id
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic import get_barrier_semaphore
from jax._src.pallas.mosaic import make_async_copy
from jax._src.pallas.mosaic import make_async_remote_copy
from jax._src.pallas.mosaic import repeat
Expand Down

0 comments on commit 8dfbf90

Please sign in to comment.