From ae2d6e21481dd6a4849e7974bb3dbc0db9774756 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Fri, 13 Oct 2023 01:59:00 -0700
Subject: [PATCH] [pallas:gpu] Implement `get` and `swap` using `load` and
 `masked_swap` lowering rules.

PiperOrigin-RevId: 573146382
---
 jax/_src/pallas/triton/lowering.py | 49 +++++++++++++++---------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index f208f762bfbe..e1e27fcf64c8 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -719,30 +719,34 @@ def _pack_indices(non_slice_idx, indexed_dims):
 def _get_lowering_rule(
     ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
 ):
-  ref_block_info, *_ = ctx.block_infos
-  idx = _pack_indices(non_slice_idx, indexed_dims)
-  avals_in = ctx.avals_in
-  idx_avals = _pack_indices(avals_in[1:], indexed_dims)
   if not isinstance(ptr.type, tl.pointer_type):
-    assert len(avals_in) == 1
+    assert not non_slice_idx
     return ptr
+
+  ref_aval, *idx_avals = ctx.avals_in
+  idx_avals = _pack_indices(idx_avals, indexed_dims)
   if non_slice_idx:
     (int_indexer_shape,) = {
         i.shape for i in idx_avals if not isinstance(i, slice)
     }
   else:
     int_indexer_shape = ()
+
+  idx = _pack_indices(non_slice_idx, indexed_dims)
   idx = tuple(
       primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
-      for s, slc in zip(avals_in[0].shape, idx)
+      for s, slc in zip(ref_aval.shape, idx)
   )
-  idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
-  ptr = _compute_pointers_from_indices(
-      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
+  idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
+  args_flat, args_tree = tree_util.tree_flatten((ptr, idx, None, None))
+  return _masked_load_lowering_rule(
+      ctx,
+      *args_flat,
+      args_tree=args_tree,
+      eviction_policy=None,
+      cache_modifier=None,
+      is_volatile=False,
   )
-  val = tl.load(ptr, _builder=ctx.builder)
-  # `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
-  return val.to(ptr.dtype.element_ty, _builder=ctx.builder)


 triton_lowering_rules[sp.get_p] = _get_lowering_rule
@@ -782,28 +786,25 @@ def _masked_load_lowering_rule(
 def _swap_lowering_rule(
     ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
 ):
-  ref_block_info, *_ = ctx.block_infos
-  avals_in = ctx.avals_in
-  idx = _pack_indices(non_slice_idx, indexed_dims)
-  idx_avals = _pack_indices(avals_in[2:], indexed_dims)
+  ref_aval, _, *idx_avals = ctx.avals_in
+  idx_avals = _pack_indices(idx_avals, indexed_dims)
   if non_slice_idx:
     (int_indexer_shape,) = {
         i.shape for i in idx_avals if not isinstance(i, slice)
     }
   else:
     int_indexer_shape = ()
+
+  idx = _pack_indices(non_slice_idx, indexed_dims)
   idx = tuple(
       primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
-      for s, slc in zip(avals_in[0].shape, idx)
+      for s, slc in zip(ref_aval.shape, idx)
   )
-  idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
-  ptr = _compute_pointers_from_indices(
-      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
+  idx = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
+  args_flat, args_tree = tree_util.tree_flatten((ptr, idx, value, None))
+  return _masked_swap_lowering_rule(
+      ctx, *args_flat, args_tree=args_tree, eviction_policy=None
   )
-  mask = None
-  old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
-  tl.store(ptr, value, mask=mask, _builder=ctx.builder)
-  return old_value


 triton_lowering_rules[sp.swap_p] = _swap_lowering_rule
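
Note on the forwarding pattern used above: both rewritten rules build an
NDIndexer, flatten the structured argument tuple with `tree_util.tree_flatten`
(`(ptr, idx, None, None)` for load, `(ptr, idx, value, None)` for swap), and
pass the flat leaves plus the treedef to the general masked rule. Because
`None` is an empty pytree in JAX, the absent mask and fill value contribute no
leaves; the callee recovers them as `None` when it unflattens. The snippet
below is a stand-alone sketch of that mechanism, not code from the patch:
`masked_load_rule` and `get_rule` are simplified stand-ins invented here for
`_masked_load_lowering_rule` and `_get_lowering_rule`, with strings in place
of Triton pointers.

from jax import tree_util

def masked_load_rule(*args_flat, args_tree, is_volatile):
  # Recover the structured arguments. `mask` and `other` come back as None
  # because the caller's flattened tuple contained no mask or fill value.
  ptr, idx, mask, other = args_tree.unflatten(args_flat)
  assert mask is None and other is None
  return f"load({ptr}, idx={idx}, volatile={is_volatile})"

def get_rule(ptr, *indices):
  # A `get` is just a masked load with no mask: flatten the argument tuple
  # and delegate, mirroring the shape of the patch's forwarding call.
  args_flat, args_tree = tree_util.tree_flatten((ptr, indices, None, None))
  return masked_load_rule(*args_flat, args_tree=args_tree, is_volatile=False)

print(get_rule("ref_ptr", 0, 3))  # load(ref_ptr, idx=(0, 3), volatile=False)

The swap rule follows the same shape, flattening `(ptr, idx, value, None)` and
returning the old value from `_masked_swap_lowering_rule`, so the pointer
arithmetic and unmasked load/store previously duplicated across both rules
collapse into the single masked implementation.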