Inlined some Triton-specific abstractions
* tensor is now just a container for dtype/shape with no extra methods;
* constexpr is no longer used;
* all APIs assume that the arguments are tensors and perform no ->tensor conversion.

PiperOrigin-RevId: 603332629
superbobry authored and jax authors committed Feb 2, 2024
1 parent 16636f9 commit 28ed607
Showing 4 changed files with 303 additions and 288 deletions.
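In short: arithmetic that previously went through operator overloads on the lowering-side tensor (which silently wrapped Python scalars and constexprs) is now written as explicit `tc.semantic` calls on explicitly converted operands. A minimal before/after sketch of the pattern, distilled from the hunks below (`lb`, `length`, and `index` stand for a lowering-time tensor, a Python int, and a branch index):

    # Before: tensor carried operator overloads that converted scalars implicitly.
    ub = lb + length          # __add__ wrapped `length` into a tensor
    use_branch0 = index == 0  # __eq__ did the same for `0`

    # After: tensor is a plain dtype/shape container, so conversion and
    # arithmetic are spelled out at every call site.
    ub = tc.semantic.add(lb, tc._to_tensor(length, lb.dtype))
    use_branch0 = tc.semantic.equal(index, tc._to_tensor(0, index.dtype))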
8 changes: 4 additions & 4 deletions jax/_src/pallas/primitives.py
@@ -417,8 +417,8 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
 state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
 
 
-def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
-         eviction_policy="", volatile=False) -> jax.Array:
+def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
+         eviction_policy=None, volatile=False) -> jax.Array:
   x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
   args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
   return load_p.bind(
@@ -429,15 +429,15 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
       is_volatile=volatile,
   )
 
-def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="",
+def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
          _function_name="swap") -> Any:
   x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
   args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
   return swap_p.bind(
       *args_flat, args_tree=args_tree, eviction_policy=eviction_policy
   )
 
-def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy="") -> None:
+def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
   _ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
            _function_name="store")
 
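For callers of `load`/`swap`/`store`, the visible change is only in the defaults: the tuning knobs now default to `None` instead of `""` and are simply omitted. A hypothetical Pallas kernel for illustration (the kernel and names are invented; it assumes the public `jax.experimental.pallas` re-exports of these entry points):

    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      idx = (pl.dslice(0, 8),)
      # cache_modifier and eviction_policy are left at their None defaults.
      x = pl.load(x_ref, idx)
      pl.store(o_ref, idx, x + 1.0)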
88 changes: 50 additions & 38 deletions jax/_src/pallas/triton/lowering.py
@@ -148,7 +148,7 @@ def _eval_index_map(
   return tuple(
       i
       if b is pallas_core.mapped
-      else i * tc._to_tensor(b, i.dtype)
+      else tc.semantic.mul(i, tc._to_tensor(b, i.dtype))
       for i, b in zip(block_indices, block_mapping.block_shape)
   )

@@ -228,7 +228,9 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping):
   grid0 = tc.program_id(0)
   for i, s in enumerate(collapse_dims):
     out_idx = launch_grid_to_pallas_grid[i]
-    grid0, out_indices[out_idx] = grid0 // s, grid0 % s
+    s = tc._to_tensor(s, grid0.dtype)
+    out_indices[out_idx] = tc.semantic.mod(grid0, s)
+    grid0 = tc.semantic.floordiv(grid0, s)
 
   for i in range(len(prog_id_dims)):
     out_idx = launch_grid_to_pallas_grid[num_collapse + i]
@@ -567,12 +569,9 @@ def _pow_lowering_rule(ctx: TritonLoweringRuleContext, x: tc.tensor, y: tc.tensor
 }
 
 for prim, fn in _JAX_TO_TRITON_OTHER.items():
-
-  def rule(ctx: TritonLoweringRuleContext, *args, fn=fn, **kwargs):
-    kwargs = tree_util.tree_map(tc.constexpr, kwargs)
-    return fn(*args, **kwargs)
-
-  triton_lowering_rules[prim] = rule
+  triton_lowering_rules[prim] = lambda ctx, *args, fn=fn, **kwargs: fn(
+      *args, **kwargs
+  )
 
 
 def _integer_pow(a, *, y):
@@ -670,7 +669,6 @@ def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
 def _broadcast_in_dim_lowering_rule(
     ctx: TritonLoweringRuleContext, a, *, broadcast_dimensions, shape
 ):
-  shape = map(tc.constexpr, shape)
   if not a.type.is_block():
     return tc.broadcast_to(a, shape)
   expand_dims = [i for i in range(len(shape)) if i not in broadcast_dimensions]
@@ -699,25 +697,25 @@ def _reshape_lowering_rule(
   if dimensions is not None:
     return ValueError("`dimensions` is not supported.")
 
-  dst_shape = map(tc.constexpr, ctx.avals_out[0].shape)
+  dst_shape = ctx.avals_out[0].shape
   if not a.type.is_block():
-    assert all(dim_size.value == 1 for dim_size in dst_shape)
+    assert all(dim_size == 1 for dim_size in dst_shape)
     return tc.broadcast_to(a, dst_shape)
 
   # Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not
   # currently implemented.
   i = 0
   while a.shape != dst_shape:
-    dim_size = a.shape[i].value if i < len(a.shape) else None
-    dst_dim_size = dst_shape[i].value if i < len(dst_shape) else None
+    dim_size = a.shape[i] if i < len(a.shape) else None
+    dst_dim_size = dst_shape[i] if i < len(dst_shape) else None
     if dim_size == dst_dim_size:
       i += 1
     elif dst_dim_size == 1:
       a = tc.expand_dims(a, axis=i)
       i += 1
     elif dim_size == 1:
-      in_shape = tuple(d.value for d in a.shape)
-      out_shape = tuple(d.value for di, d in enumerate(a.shape) if di != i)
+      in_shape = a.shape
+      out_shape = tuple(d for di, d in enumerate(a.shape) if di != i)
       reduce_ctx = ctx.replace(
           avals_in=[ctx.avals_in[0].update(shape=in_shape)],
           avals_out=[ctx.avals_in[0].update(shape=out_shape)],
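The expand-dims/reduce-sum workaround above rests on two identities: summing over a length-1 axis removes it without changing any values, and expand_dims inserts one just as harmlessly. An illustrative NumPy check, not part of the diff:

    import numpy as np

    x = np.arange(12, dtype=np.float32).reshape(4, 1, 3)
    # Dropping a singleton dim via reduce-sum: (4, 1, 3) -> (4, 3).
    assert np.array_equal(x.sum(axis=1), x.reshape(4, 3))
    # Re-inserting it via expand_dims: (4, 3) -> (4, 1, 3).
    assert np.array_equal(np.expand_dims(x.sum(axis=1), 1), x)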
@@ -773,8 +771,9 @@ def _compute_pointers_from_indices(
       if isinstance(index.start, int):
         ptr_dim_offset = tc.arange(index.start, index.start + index.size)
       else:
-        ptr_dim_offset = tc.broadcast_to(index.start, [index.size])
-        ptr_dim_offset += tc.arange(0, index.size)
+        ptr_dim_offset = tc.semantic.add(
+            tc.broadcast_to(index.start, [index.size]), tc.arange(0, index.size)
+        )
       # We need to add broadcastable dimensions for the advanced int indexing
       # and for previous slices
       num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
@@ -808,15 +807,15 @@ def _compute_pointers_from_indices(
       ndim = len(ptr_dim_offset.shape)
       ptr_dim_offset = tc.expand_dims(ptr_dim_offset, ndim)
     if start_offset is not None:
-      ptr_dim_offset += tc.broadcast_to(
-          tc.semantic.cast(start_offset, ptr_dim_offset.dtype),
-          ptr_dim_offset.shape,
+      start_offset = tc.semantic.cast(start_offset, ptr_dim_offset.dtype)
+      ptr_dim_offset = tc.semantic.add(
+          ptr_dim_offset, tc.broadcast_to(start_offset, ptr_dim_offset.shape)
       )
 
     stride_size = tc.broadcast_to(
         tc._to_tensor(dim_stride, ptr_dim_offset.dtype), ptr_dim_offset.shape
     )
-    bcast_indices.append(ptr_dim_offset * stride_size)
+    bcast_indices.append(tc.semantic.mul(ptr_dim_offset, stride_size))
   block_shapes = [
       () if not index.type.is_block() else tuple(index.type.get_block_shapes())
       for index in bcast_indices
@@ -827,7 +826,9 @@ def _compute_pointers_from_indices(
       else index
       for index, block_shape in zip(bcast_indices, block_shapes)
   ]
-  return sum(bcast_indices, tc.broadcast_to(root_ptr, indexer_shape))
+  return functools.reduce(
+      tc.semantic.add, bcast_indices, tc.broadcast_to(root_ptr, indexer_shape)
+  )


def _pack_indices(non_slice_idx, indexed_dims):
@@ -880,16 +881,18 @@ def _masked_load_lowering_rule(
   ptr = _compute_pointers_from_indices(
       ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
   )
+  if other is not None and mask is not None:
+    other = tc.broadcast_to(other, mask.shape)
   val = tc.load(
       ptr,
       mask=mask,
       other=other,
       cache_modifier=cache_modifier,
-      volatile=is_volatile,
+      is_volatile=is_volatile,
       eviction_policy=eviction_policy,
   )
   # `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
-  return val.to(ptr.dtype.element_ty)
+  return tc.semantic.cast(val, ptr.dtype.element_ty)
 
 
 triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
@@ -931,7 +934,9 @@ def _masked_swap_lowering_rule(
   ptr = _compute_pointers_from_indices(
       ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
   )
-  other = None if mask is None else value
+  other = None
+  if value is not None and mask is not None:
+    other = tc.broadcast_to(value, mask.shape)
   old_value = tc.load(ptr, mask=mask, other=other)
   tc.store(
       ptr,
@@ -1005,12 +1010,15 @@ def _dot_general_lowering(
   if acc_dtype not in (tc.int32, tc.float16):
     acc_dtype = tc.float32
 
-  return tc.dot(
-      a,
-      b,
-      allow_tf32=allow_tf32,
-      out_dtype=acc_dtype,
-  ).to(out_dtype)
+  return tc.semantic.cast(
+      tc.dot(
+          a,
+          b,
+          allow_tf32=allow_tf32,
+          out_dtype=acc_dtype,
+      ),
+      out_dtype,
+  )
 
 
 triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering
@@ -1071,7 +1079,7 @@ def _reduce_lowering(body, ctx: TritonLoweringRuleContext, a, *, axes):
     # reduces, which seems necessary for correctness.
     # TODO(bjp): Get rid of the double negation.
     # https://github.com/openai/triton/issues/1776
-    a = -(-a)
+    a = tc.semantic.minus(tc.semantic.minus(a))
   ctx = ctx.replace(avals_in=dst_avals)
   axes = tuple(ax for ax in axes if ax != axis)
   return _reduction_lowering(body, ctx, a, axes=axes)[0]
@@ -1100,9 +1108,9 @@ def _argreduce_lowering(
   index = tc.arange(0, n)
   if len(a.shape) > 1:
     # Broadcast index across the non-reduced axes
-    expand_dims_index = [tc.constexpr(None)] * len(a.shape)
-    expand_dims_index[axis] = slice(None)
-    index = index[expand_dims_index]
+    for i in range(len(a.shape)):
+      if i != axis:
+        index = tc.expand_dims(index, i)
     index = tc.broadcast_to(index, a.shape)
   ctx = ctx.replace(
       avals_in=[
@@ -1311,7 +1319,7 @@ def _scan_lowering_rule(
   if has_loop_index:
     lb, *args = args
     lower_bound = lb.handle
-    ub = lb + tc._to_tensor(length, lb.dtype)
+    ub = tc.semantic.add(lb, tc._to_tensor(length, lb.dtype))
     upper_bound = ub.handle
     bound_type = ub.type
   else:
@@ -1502,7 +1510,7 @@ def to_type(out_aval):
   out_types = [to_type(out) for out in ctx.avals_out]
   out_ir_types = [t.to_ir(ctx.builder) for t in out_types]
 
-  use_branch0 = index == 0
+  use_branch0 = tc.semantic.equal(index, tc._to_tensor(0, index.dtype))
   # TODO(bjp): Switch to scf.index_switch once exposed in triton.cc
   if_op = ctx.builder.create_if_op(out_ir_types, use_branch0.handle, with_else=True)
   with ir.InsertionPoint.at_block_begin(if_op.then_block):
@@ -1516,7 +1524,11 @@ def to_type(out_aval):
     # TODO(bjp): Instead of linear nest of 'if's, partition into halves.
     if len(branches) > 2:
       outs1 = _cond_lowering_rule(
-          ctx, index - 1, *args, branches=branches[1:], linear=linear
+          ctx,
+          tc.semantic.sub(index, tc._to_tensor(1, index.dtype)),
+          *args,
+          branches=branches[1:],
+          linear=linear,
       )
     else:
       outs1 = lower_jaxpr_to_triton_ir(
