Skip to content

Commit

Permalink
[Pallas][Mosaic] Expose semaphore read.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623593440
  • Loading branch information
bythew3i authored and rajasekharporeddy committed Apr 12, 2024
1 parent ac1c1da commit 6f88294
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 22 deletions.
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
from jax._src.pallas.mosaic.primitives import trace
Expand Down
14 changes: 14 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,20 @@ def _linearize_mesh_indices(*indices):
return device_id
raise NotImplementedError(f"Unsupported device id type: {device_id_type}")


def _semaphore_read_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
):
  """Lower `semaphore_read_p` to a `tpu.SemaphoreReadOp`.

  Args:
    ctx: lowering context; only `ctx.avals_in` is consulted here.
    *args: flattened lowered operands (the semaphore ref followed by the
      leaves of its indexers).
    args_tree: treedef used to rebuild `(semaphore, indexers)` from the flat
      operands.

  Returns:
    The `result` of the emitted `tpu.SemaphoreReadOp`.
  """
  # The same treedef unpacks both the abstract values and the lowered
  # operands; only the semaphore's aval is needed from the former.
  sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem_ref, sem_indexers = tree_util.tree_unflatten(args_tree, args)
  # Apply the indexers so the op receives the selected semaphore.
  indexed_sem, _ = _index_ref(sem_ref, sem_aval, sem_aval.shape, sem_indexers)
  read_op = tpu.SemaphoreReadOp(indexed_sem)
  return read_op.result


lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule

def _semaphore_signal_lowering_rule(
ctx: LoweringRuleContext,
*args,
Expand Down
60 changes: 38 additions & 22 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,42 @@ class DeviceIdType(enum.Enum):
LOGICAL = "logical"


def check_sem_avals(sem_aval, sem_indexers_avals, name):
  """Validate the semaphore operand of a semaphore primitive.

  Args:
    sem_aval: abstract value of the semaphore operand; must be a
      `state.AbstractRef`.
    sem_indexers_avals: sequence of indexers applied to the ref (may be
      empty). When non-empty, the last indexer's shape is the one checked.
    name: verb for the operation (e.g. "read", "signal", "wait"); it is
      interpolated into the error messages.

  Raises:
    ValueError: if `sem_aval` is not an `AbstractRef`, if the (indexed) shape
      is not `()`, or if the dtype is neither a regular nor a barrier
      semaphore dtype.
  """
  if not isinstance(sem_aval, state.AbstractRef):
    raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}")

  # Start from the ref's own shape; if it has been indexed, the shape left
  # over after the last indexer is what must be scalar.
  effective_shape = sem_aval.shape
  if sem_indexers_avals:
    last_indexer = sem_indexers_avals[-1]
    effective_shape = last_indexer.get_indexer_shape()
  if effective_shape:
    raise ValueError(
        f"Cannot {name} on a non-()-shaped semaphore: {effective_shape}"
    )

  dtype = sem_aval.dtype
  accepted_kinds = (tpu_core.semaphore, tpu_core.barrier_semaphore)
  if not any(jnp.issubdtype(dtype, kind) for kind in accepted_kinds):
    raise ValueError(f"Must {name} a REGULAR or BARRIER semaphore: {dtype}")


semaphore_read_p = jax_core.Primitive("semaphore_read")
# Binding this primitive yields a single value rather than a list of outputs.
semaphore_read_p.multiple_results = False


def semaphore_read(sem_or_view):
  """Bind `semaphore_read_p` on a semaphore ref or a view of one.

  Args:
    sem_or_view: a semaphore ref, or an indexed view of one; it is split into
      the underlying ref and its indexers by `_get_ref_and_indexers`.

  Returns:
    The result of binding `semaphore_read_p` to the flattened
    `[ref, indexers]` operands.
  """
  sem_ref, sem_indexers = _get_ref_and_indexers(sem_or_view)
  # The treedef travels as a primitive parameter so that rules can rebuild
  # the (ref, indexers) structure from the flat operands.
  leaves, treedef = tree_util.tree_flatten([sem_ref, sem_indexers])
  return semaphore_read_p.bind(*leaves, args_tree=treedef)

@semaphore_read_p.def_abstract_eval
def _semaphore_read_abstract_eval(*avals, args_tree):
  """Abstract evaluation rule for `semaphore_read_p`.

  Args:
    *avals: flattened abstract values of the operands.
    args_tree: treedef that rebuilds `(sem_aval, sem_indexers_avals)` from
      `avals`.

  Returns:
    A `ShapedArray` with shape `()` and dtype int32.

  Raises:
    ValueError: propagated from `check_sem_avals` when the operand is not a
      valid semaphore.
  """
  sem_aval, sem_indexers_avals = tree_util.tree_unflatten(args_tree, avals)
  check_sem_avals(sem_aval, sem_indexers_avals, "read")
  scalar_shape = ()
  result_dtype = jnp.dtype("int32")
  return jax_core.ShapedArray(scalar_shape, result_dtype)


semaphore_signal_p = jax_core.Primitive('semaphore_signal')
semaphore_signal_p.multiple_results = True

Expand Down Expand Up @@ -254,17 +290,7 @@ def _semaphore_signal_abstract_eval(
sem_aval, sem_indexers_avals, value_aval, device_id_avals = (
tree_util.tree_unflatten(args_tree, avals)
)
if not isinstance(sem_aval, state.AbstractRef):
raise ValueError(f"Cannot signal on a non-Ref: {sem_aval}")
sem_shape = sem_aval.shape
if sem_indexers_avals:
sem_shape = sem_indexers_avals[-1].get_indexer_shape()
if sem_shape:
raise ValueError(f"Cannot signal on a non-()-shaped semaphore: {sem_shape}")
sem_dtype = sem_aval.dtype
if not (jnp.issubdtype(sem_dtype, tpu_core.semaphore) or jnp.issubdtype(
sem_dtype, tpu_core.barrier_semaphore)):
raise ValueError(f"Must signal a REGULAR or BARRIER semaphore: {sem_dtype}")
check_sem_avals(sem_aval, sem_indexers_avals, "signal")
if value_aval.dtype != jnp.dtype("int32"):
raise ValueError("Must signal an int32 value.")
if device_id_avals is not None:
Expand Down Expand Up @@ -319,17 +345,7 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
@semaphore_wait_p.def_abstract_eval
def _semaphore_wait_abstract_eval(*avals, args_tree):
  """Abstract evaluation rule for `semaphore_wait_p`.

  Args:
    *avals: flattened abstract values of the operands.
    args_tree: treedef that rebuilds
      `(sem_aval, sem_indexers_avals, value_aval)` from `avals`.

  Returns:
    An empty list.

  Raises:
    ValueError: if the semaphore operand fails `check_sem_avals`, or if the
      wait value is not int32.
  """
  sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(
      args_tree, avals
  )
  # The Ref / scalar-shape / semaphore-dtype validation lives in the shared
  # helper; it raises the same messages the former inline copy did, so the
  # duplicated checks are not repeated here.
  check_sem_avals(sem_aval, sem_indexers_avals, "wait")
  if value_aval.dtype != jnp.dtype("int32"):
    raise ValueError("Must wait an int32 value.")
  return []
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/pallas/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from jax._src.pallas.mosaic import roll
from jax._src.pallas.mosaic import run_scoped
from jax._src.pallas.mosaic import semaphore
from jax._src.pallas.mosaic import semaphore_read
from jax._src.pallas.mosaic import semaphore_signal
from jax._src.pallas.mosaic import semaphore_wait
from jax._src.pallas.mosaic import trace
Expand Down
6 changes: 6 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@ def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> {
let assemblyFormat = [{ attr-dict `:` type($result) }];
}

// `sem_read` op: takes a memref whose element type is a TPU semaphore and
// produces a single i32 result. No traits are attached to this op.
def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> {
  let arguments = (ins
    MemRefOf<[TPU_SemaphoreType]>:$semaphore
  );
  let results = (outs I32:$result);
  let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}];
}

def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> {
let arguments = (ins
MemRefOf<[TPU_SemaphoreType]>:$semaphore,
Expand Down
25 changes: 25 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,31 @@ def body(sems):
debug=True,
)())

def test_can_read_semaphore(self):
  """Signal every semaphore in an (m, n) array and read each value back."""
  m, n = 2, 3

  def kernel(y_ref):
    def body(sems):
      # Visit the semaphores in row-major order; the value signalled on each
      # one is its flat index, so every entry of the output is distinct.
      for v in range(m * n):
        r, c = divmod(v, n)
        pltpu.semaphore_signal(sems.at[r, c], v)
        y_ref[r, c] = pltpu.semaphore_read(sems.at[r, c])
        # Wait with the same amount that was just signalled.
        pltpu.semaphore_wait(sems.at[r, c], v)

    pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n)))

  read_back = pl.pallas_call(
      kernel,
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
  )()
  y = jax.block_until_ready(read_back)

  expected = jnp.arange(m * n).astype(jnp.int32).reshape((m, n))
  np.testing.assert_array_equal(y, expected)

def test_hbm_hbm_dma(self):
def kernel(x_hbm_ref, y_hbm_ref):
def body(sem):
Expand Down

0 comments on commit 6f88294

Please sign in to comment.