Skip to content

Commit

Permalink
[pallas:gpu] Implement get and swap using load and `masked_swap` lowering rules.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 573146382
  • Loading branch information
chr1sj0nes authored and jax authors committed Oct 13, 2023
1 parent 2bc2e17 commit ae2d6e2
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions jax/_src/pallas/triton/lowering.py
Expand Up @@ -719,30 +719,34 @@ def _pack_indices(non_slice_idx, indexed_dims):
def _get_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
):
ref_block_info, *_ = ctx.block_infos
idx = _pack_indices(non_slice_idx, indexed_dims)
avals_in = ctx.avals_in
idx_avals = _pack_indices(avals_in[1:], indexed_dims)
if not isinstance(ptr.type, tl.pointer_type):
assert len(avals_in) == 1
assert not non_slice_idx
return ptr

ref_aval, *idx_avals = ctx.avals_in
idx_avals = _pack_indices(idx_avals, indexed_dims)
if non_slice_idx:
(int_indexer_shape,) = {
i.shape for i in idx_avals if not isinstance(i, slice)
}
else:
int_indexer_shape = ()

idx = _pack_indices(non_slice_idx, indexed_dims)
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(avals_in[0].shape, idx)
for s, slc in zip(ref_aval.shape, idx)
)
idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
ptr = _compute_pointers_from_indices(
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None))
return _masked_load_lowering_rule(
ctx,
*args_flat,
args_tree=args_tree,
eviction_policy=None,
cache_modifier=None,
is_volatile=False,
)
val = tl.load(ptr, _builder=ctx.builder)
# `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
return val.to(ptr.dtype.element_ty, _builder=ctx.builder)


triton_lowering_rules[sp.get_p] = _get_lowering_rule
Expand Down Expand Up @@ -782,28 +786,25 @@ def _masked_load_lowering_rule(
def _swap_lowering_rule(
ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
):
ref_block_info, *_ = ctx.block_infos
avals_in = ctx.avals_in
idx = _pack_indices(non_slice_idx, indexed_dims)
idx_avals = _pack_indices(avals_in[2:], indexed_dims)
ref_aval, _, *idx_avals = ctx.avals_in
idx_avals = _pack_indices(idx_avals, indexed_dims)
if non_slice_idx:
(int_indexer_shape,) = {
i.shape for i in idx_avals if not isinstance(i, slice)
}
else:
int_indexer_shape = ()

idx = _pack_indices(non_slice_idx, indexed_dims)
idx = tuple(
primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
for s, slc in zip(avals_in[0].shape, idx)
for s, slc in zip(ref_aval.shape, idx)
)
idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
ptr = _compute_pointers_from_indices(
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None))
return _masked_swap_lowering_rule(
ctx, *args_flat, args_tree=args_tree, eviction_policy=None
)
mask = None
old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
tl.store(ptr, value, mask=mask, _builder=ctx.builder)
return old_value


triton_lowering_rules[sp.swap_p] = _swap_lowering_rule
Expand Down

0 comments on commit ae2d6e2

Please sign in to comment.