Inlined some Triton-specific abstractions
* tensor is now just a container for dtype/shape with no extra methods;
* constexpr is no longer used;
* all APIs assume that their arguments are already tensors and do not perform implicit ->tensor conversion.

PiperOrigin-RevId: 603894852
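In practice this means arithmetic in the lowering code is no longer written with operator overloads on the compat tensor; every operation is spelled out through the semantic module, with explicit tc._to_tensor calls for Python scalars. A minimal, self-contained Python analogy of the design shift (toy classes and function names, not the real jaxlib ones):

from dataclasses import dataclass

@dataclass
class Tensor:                     # "after": a bare container, no __add__/__mul__/...
    handle: object
    shape: tuple
    dtype: str

def mul(a: Tensor, b: Tensor) -> Tensor:   # stands in for tc.semantic.mul
    return Tensor(("mul", a.handle, b.handle), a.shape, a.dtype)

i = Tensor("i", (), "int32")
b = Tensor("b", (), "int32")
out = mul(i, b)                   # previously written as i * b via operator overloads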
superbobry authored and jax authors committed Feb 3, 2024
1 parent a44c0e0 commit 06d3280
Showing 2 changed files with 59 additions and 151 deletions.
63 changes: 34 additions & 29 deletions jax/_src/pallas/triton/lowering.py
@@ -148,7 +148,7 @@ def _eval_index_map(
return tuple(
i
if b is pallas_core.mapped
else i * tc._to_tensor(b, i.dtype)
else tc.semantic.mul(i, tc._to_tensor(b, i.dtype))
for i, b in zip(block_indices, block_mapping.block_shape)
)

@@ -228,7 +228,9 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping):
grid0 = tc.program_id(0)
for i, s in enumerate(collapse_dims):
out_idx = launch_grid_to_pallas_grid[i]
grid0, out_indices[out_idx] = grid0 // s, grid0 % s
s = tc._to_tensor(s, grid0.dtype)
out_indices[out_idx] = tc.semantic.mod(grid0, s)
grid0 = tc.semantic.floordiv(grid0, s)

for i in range(len(prog_id_dims)):
out_idx = launch_grid_to_pallas_grid[num_collapse + i]
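For context on what the collapse loop above computes: leading Pallas grid dimensions are folded into a single launch dimension, and the per-dimension indices are recovered with repeated mod/floordiv, which is exactly the tc.semantic.mod / tc.semantic.floordiv pair now used. A plain-Python sketch with made-up sizes (not the jaxlib code):

# Recover the indices of two collapsed grid dimensions from a flat launch index.
collapse_dims = (3, 4)        # hypothetical sizes of the collapsed leading dims
grid0 = 7                     # flat launch index, 0 <= grid0 < 3 * 4
out_indices = []
for s in collapse_dims:
    out_indices.append(grid0 % s)
    grid0 //= s
assert out_indices == [1, 2]  # 7 == 1 + 2 * 3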
@@ -567,12 +569,9 @@ def _pow_lowering_rule(ctx: TritonLoweringRuleContext, x: tc.tensor, y: tc.tenso
}

for prim, fn in _JAX_TO_TRITON_OTHER.items():

def rule(ctx: TritonLoweringRuleContext, *args, fn=fn, **kwargs):
kwargs = tree_util.tree_map(tc.constexpr, kwargs)
return fn(*args, **kwargs)

triton_lowering_rules[prim] = rule
triton_lowering_rules[prim] = lambda ctx, *args, fn=fn, **kwargs: fn(
*args, **kwargs
)
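The new registration binds fn through a keyword default (fn=fn) so each lambda keeps the function from its own loop iteration; a closure that referenced fn directly would see only the last value once the loop finishes. A small standalone illustration of that Python behaviour (toy names, unrelated to the real rules table):

fns = {"a": str.upper, "b": str.lower}

late = {k: (lambda s: f(s)) for k, f in fns.items()}        # bug: late binding, both entries call str.lower
bound = {k: (lambda s, f=f: f(s)) for k, f in fns.items()}  # ok: f frozen per lambda

print(late["a"]("Hi"))   # hi
print(bound["a"]("Hi"))  # HI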


def _integer_pow(a, *, y):
@@ -670,7 +669,6 @@ def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
def _broadcast_in_dim_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, broadcast_dimensions, shape
):
shape = map(tc.constexpr, shape)
if not a.type.is_block():
return tc.broadcast_to(a, shape)
expand_dims = [i for i in range(len(shape)) if i not in broadcast_dimensions]
@@ -699,25 +697,25 @@ def _reshape_lowering_rule(
if dimensions is not None:
return ValueError("`dimensions` is not supported.")

dst_shape = map(tc.constexpr, ctx.avals_out[0].shape)
dst_shape = ctx.avals_out[0].shape
if not a.type.is_block():
assert all(dim_size.value == 1 for dim_size in dst_shape)
assert all(dim_size == 1 for dim_size in dst_shape)
return tc.broadcast_to(a, dst_shape)

# Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not
# currently implemented.
i = 0
while a.shape != dst_shape:
dim_size = a.shape[i].value if i < len(a.shape) else None
dst_dim_size = dst_shape[i].value if i < len(dst_shape) else None
dim_size = a.shape[i] if i < len(a.shape) else None
dst_dim_size = dst_shape[i] if i < len(dst_shape) else None
if dim_size == dst_dim_size:
i += 1
elif dst_dim_size == 1:
a = tc.expand_dims(a, axis=i)
i += 1
elif dim_size == 1:
in_shape = tuple(d.value for d in a.shape)
out_shape = tuple(d.value for di, d in enumerate(a.shape) if di != i)
in_shape = a.shape
out_shape = tuple(d for di, d in enumerate(a.shape) if di != i)
reduce_ctx = ctx.replace(
avals_in=[ctx.avals_in[0].update(shape=in_shape)],
avals_out=[ctx.avals_in[0].update(shape=out_shape)],
@@ -773,8 +771,9 @@ def _compute_pointers_from_indices(
if isinstance(index.start, int):
ptr_dim_offset = tc.arange(index.start, index.start + index.size)
else:
ptr_dim_offset = tc.broadcast_to(index.start, [index.size])
ptr_dim_offset += tc.arange(0, index.size)
ptr_dim_offset = tc.semantic.add(
tc.broadcast_to(index.start, [index.size]), tc.arange(0, index.size)
)
# We need to add broadcastable dimensions for the advanced int indexing
# and for previous slices
num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
@@ -808,15 +807,15 @@ def _compute_pointers_from_indices(
ndim = len(ptr_dim_offset.shape)
ptr_dim_offset = tc.expand_dims(ptr_dim_offset, ndim)
if start_offset is not None:
ptr_dim_offset += tc.broadcast_to(
tc.semantic.cast(start_offset, ptr_dim_offset.dtype),
ptr_dim_offset.shape,
start_offset = tc.semantic.cast(start_offset, ptr_dim_offset.dtype)
ptr_dim_offset = tc.semantic.add(
ptr_dim_offset, tc.broadcast_to(start_offset, ptr_dim_offset.shape)
)

stride_size = tc.broadcast_to(
tc._to_tensor(dim_stride, ptr_dim_offset.dtype), ptr_dim_offset.shape
)
bcast_indices.append(ptr_dim_offset * stride_size)
bcast_indices.append(tc.semantic.mul(ptr_dim_offset, stride_size))
block_shapes = [
() if not index.type.is_block() else tuple(index.type.get_block_shapes())
for index in bcast_indices
@@ -827,7 +826,9 @@ def _compute_pointers_from_indices(
else index
for index, block_shape in zip(bcast_indices, block_shapes)
]
return sum(bcast_indices, tc.broadcast_to(root_ptr, indexer_shape))
return functools.reduce(
tc.semantic.add, bcast_indices, tc.broadcast_to(root_ptr, indexer_shape)
)
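The built-in sum() starts from the integer 0 and relies on __add__/__radd__, which the slimmed-down tensor no longer defines, so the accumulation becomes an explicit functools.reduce over tc.semantic.add seeded with the broadcast base pointer. A plain-Python sketch of the equivalence (toy values standing in for the pointer arithmetic):

import functools

def add(a, b):                 # stands in for tc.semantic.add
    return a + b

bcast_indices = [2, 3, 4]
base = 10                      # stands in for tc.broadcast_to(root_ptr, indexer_shape)
assert functools.reduce(add, bcast_indices, base) == base + 2 + 3 + 4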


def _pack_indices(non_slice_idx, indexed_dims):
@@ -1078,7 +1079,7 @@ def _reduce_lowering(body, ctx: TritonLoweringRuleContext, a, *, axes):
# reduces, which seems necessary for correctness.
# TODO(bjp): Get rid of the double negation.
# https://github.com/openai/triton/issues/1776
a = -(-a)
a = tc.semantic.minus(tc.semantic.minus(a))
ctx = ctx.replace(avals_in=dst_avals)
axes = tuple(ax for ax in axes if ax != axis)
return _reduction_lowering(body, ctx, a, axes=axes)[0]
@@ -1107,9 +1108,9 @@ def _argreduce_lowering(
index = tc.arange(0, n)
if len(a.shape) > 1:
# Broadcast index across the non-reduced axes
expand_dims_index = [tc.constexpr(None)] * len(a.shape)
expand_dims_index[axis] = slice(None)
index = index[expand_dims_index]
for i in range(len(a.shape)):
if i != axis:
index = tc.expand_dims(index, i)
index = tc.broadcast_to(index, a.shape)
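The old code broadcast the index by indexing with a list of Nones plus a slice at the reduced axis; the new code gets the same shape by calling tc.expand_dims on every non-reduced axis and then broadcasting. A NumPy sketch of the shape bookkeeping (toy shapes; assumes NumPy broadcasting mirrors what the lowering needs):

import numpy as np

a_shape, axis = (2, 5, 3), 1
index = np.arange(a_shape[axis])

idx = index
for i in range(len(a_shape)):
    if i != axis:
        idx = np.expand_dims(idx, i)             # insert singleton dims one by one
expanded = np.broadcast_to(idx, a_shape)

assert expanded.shape == a_shape
assert (expanded == index[None, :, None]).all()  # matches the old None-indexing trick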
ctx = ctx.replace(
avals_in=[
@@ -1318,7 +1319,7 @@ def _scan_lowering_rule(
if has_loop_index:
lb, *args = args
lower_bound = lb.handle
ub = lb + tc._to_tensor(length, lb.dtype)
ub = tc.semantic.add(lb, tc._to_tensor(length, lb.dtype))
upper_bound = ub.handle
bound_type = ub.type
else:
@@ -1509,7 +1510,7 @@ def to_type(out_aval):
out_types = [to_type(out) for out in ctx.avals_out]
out_ir_types = [t.to_ir(ctx.builder) for t in out_types]

use_branch0 = index == 0
use_branch0 = tc.semantic.equal(index, tc._to_tensor(0, index.dtype))
# TODO(bjp): Switch to scf.index_switch once exposed in triton.cc
if_op = ctx.builder.create_if_op(out_ir_types, use_branch0.handle, with_else=True)
with ir.InsertionPoint.at_block_begin(if_op.then_block):
@@ -1523,7 +1524,11 @@ def to_type(out_aval):
# TODO(bjp): Instead of linear nest of 'if's, partition into halves.
if len(branches) > 2:
outs1 = _cond_lowering_rule(
ctx, index - 1, *args, branches=branches[1:], linear=linear
ctx,
tc.semantic.sub(index, tc._to_tensor(1, index.dtype)),
*args,
branches=branches[1:],
linear=linear,
)
else:
outs1 = lower_jaxpr_to_triton_ir(
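For the cond lowering hunk above: branch dispatch stays a linear nest of ifs — branch 0 runs when index == 0, otherwise the rule recurses on index - 1 with the remaining branches — and only the comparison and the decrement now go through tc.semantic.equal / tc.semantic.sub. A plain-Python sketch of that recursion (toy branches, hypothetical helper name):

def dispatch(index, branches, *args):
    # Linear nest: take branch 0 on index == 0, otherwise recurse with index - 1.
    if index == 0:
        return branches[0](*args)
    if len(branches) > 2:
        return dispatch(index - 1, branches[1:], *args)
    return branches[1](*args)

branches = [lambda x: x + 1, lambda x: x * 2, lambda x: -x]
assert dispatch(2, branches, 5) == -5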
147 changes: 25 additions & 122 deletions jaxlib/triton/compat.py
@@ -622,8 +622,13 @@ def _infer_reduce_op_return_types(

dtype = tl.core.dtype

block_type = tl.core.block_type
function_type = tl.core.function_type

class block_type(tl.core.block_type):
def __init__(self, element_ty: dtype, shape: list[Any]) -> Any:
super().__init__(element_ty, shape)
self.shape = tuple(self.shape)
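The block_type subclass exists only to normalize shape from a list to a tuple, presumably so shape comparisons elsewhere (e.g. a.shape != dst_shape in the reshape rule, where dst_shape comes from a JAX aval) compare like with like. In plain Python a list never equals a tuple, even with identical elements:

assert [16, 16] != (16, 16)   # list vs tuple: never equal
assert (16, 16) == (16, 16)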


pointer_type = tl.core.pointer_type

void = tl.core.void
@@ -639,13 +644,10 @@ def _infer_reduce_op_return_types(
uint64 = tl.core.uint64


def _bool_block_like(v: tensor) -> block_type:
def _bool_block_like(v: tensor) -> dtype:
if not v.type.is_block():
return int1
return block_type(int1, v.type.shape)


constexpr = tl.core.constexpr
return block_type(int1, v.shape)


def _to_tensor(v, dtype: dtype | None = None) -> "tensor":
@@ -660,72 +662,16 @@ def _to_tensor(v, dtype: dtype | None = None) -> "tensor":
return tensor(t.handle, t.type)


class tensor(tl.core.tensor):

def __add__(self, other):
return semantic.add(self, _to_tensor(other, self.dtype))

def __radd__(self, other):
return self + other

def __sub__(self, other):
return semantic.sub(self, _to_tensor(other, self.dtype))

def __rsub__(self, other):
return semantic.sub(_to_tensor(other, self.dtype), self)

def __mul__(self, other):
return semantic.mul(self, _to_tensor(other, self.dtype))

def __rmul__(self, other):
return self * other

def __truediv__(self, other):
return semantic.truediv(self, _to_tensor(other, self.dtype))

def __rtruediv__(self, other):
return semantic.truediv(_to_tensor(other, self.dtype), self)

def __floordiv__(self, other):
return semantic.floordiv(self, _to_tensor(other, self.dtype))

def __rfloordiv__(self, other):
return semantic.floordiv(_to_tensor(other, self.dtype), self)

def __mod__(self, other):
return semantic.mod(self, _to_tensor(other, self.dtype))
class tensor:

def __rmod__(self, other):
return semantic.mod(_to_tensor(other, self.dtype), self)
def __init__(self, handle: ir.Value, type: dtype):
self.handle = handle
self.shape = tuple(type.shape) if type.is_block() else ()
self.type = type
self.dtype = type.scalar

def __neg__(self):
return semantic.minus(self)

def __invert__(self):
return semantic.invert(self)

# TODO(slebedev): Override other comparison methods.
def __eq__(self, other):
return semantic.equal(self, _to_tensor(other, self.dtype))

def __getitem__(self, slices) -> tensor:
if isinstance(slices, (slice, constexpr)):
slices = [slices]
t = self
for axis, s in enumerate(slices):
if s is None or isinstance(s, constexpr) and s.value is None:
t = expand_dims(t, axis)
elif (
isinstance(s, slice)
and s.start is s.stop is s.step is None
):
pass
else:
raise IndexError(f"unsupported tensor index: {s}")
return t

def to(self, *args, **kwargs) -> tensor:
raise NotImplementedError
def __str__(self) -> str:
return f"{self.dtype}[{', '.join(map(str, self.shape))}]"


def program_id(axis: int) -> tensor:
@@ -901,13 +847,11 @@ def arange(start: int, end: int) -> tensor:
return tensor(tt_dialect.make_range(ir_ty, start, end), ty)


def broadcast_to(x: object, shape: Sequence[int | constexpr]) -> tensor:
x = _to_tensor(x)
def broadcast_to(x: tensor, shape: Sequence[int]) -> tensor:
if not x.type.is_block():
return splat(x, shape)
elif x.shape == shape:
return x
shape = [dim.__index__() for dim in shape]
x_ir_type = ir.RankedTensorType(x.handle.type)
result_ir_type = ir.RankedTensorType.get(
shape, x_ir_type.element_type, x_ir_type.encoding
@@ -918,22 +862,19 @@ def broadcast_to(x: object, shape: Sequence[int | constexpr]) -> tensor:
)


def splat(x: object, shape: Sequence[int | constexpr]) -> tensor:
x = _to_tensor(x)
def splat(x: tensor, shape: Sequence[int]) -> tensor:
if x.type.is_block():
raise ValueError("cannot splat a block tensor")
if len(shape) == 0:
return x
shape = [dim.__index__() for dim in shape]
result_ir_type = ir.RankedTensorType.get(shape, x.handle.type)
return tensor(
tt_dialect.splat(result_ir_type, x.handle), block_type(x.dtype, shape)
)


def expand_dims(x: object, axis: int) -> tensor:
x = _to_tensor(x)
dst_shape = [dim.__index__() for dim in x.shape]
def expand_dims(x: tensor, axis: int) -> tensor:
dst_shape = list(x.shape)
dst_shape.insert(axis, 1)
if not x.type.is_block():
return splat(input, dst_shape)
@@ -967,9 +908,7 @@ def dot(
max_num_imprecise_acc: int | None = None,
out_dtype: dtype = float32,
) -> tensor:
x_dims = [dim.__index__() for dim in x.shape]
y_dims = [dim.__index__() for dim in y.shape]
if min(*x_dims, *y_dims) < 16:
if min(*x.shape, *y.shape) < 16:
raise ValueError("all dimensions of x and y must be >= 16 ")
if out_dtype.is_bf16():
raise ValueError(f"out_dtype={out_dtype} is unsupported")
@@ -993,8 +932,8 @@ def dot(
if element_type != out_dtype:
raise TypeError(f"out_dtype={out_dtype} does not match element type {element_type}")

m, _ = x_dims
_, n = y_dims
m, _ = x.shape
_, n = y.shape
result_type = block_type(element_type, [m, n])

if acc is None:
@@ -1096,8 +1035,7 @@ def atomic_add(
return _atomic_rmw(op, ptr, val, mask, semantic, sync_scope)


def abs(x: object) -> tensor:
x = _to_tensor(x)
def abs(x: tensor) -> tensor:
dtype = x.dtype
if dtype.is_floating():
return tensor(math_dialect.absf(x.handle), x.type)
@@ -1109,41 +1047,6 @@ def abs(x: object) -> tensor:
raise ValueError(f"unsupported dtype: {dtype}")


def exp(x: object) -> tensor:
x = _to_tensor(x)
if x.dtype != float32 and x.dtype != float64:
raise ValueError(f"unsupported dtype: {x.dtype}")
return tensor(math_dialect.exp(x.handle), x.type)


def log(x: object) -> tensor:
x = _to_tensor(x)
if x.dtype != float32 and x.dtype != float64:
raise ValueError(f"unsupported dtype: {x.dtype}")
return tensor(math_dialect.log(x.handle), x.type)


def sqrt(x: object) -> tensor:
x = _to_tensor(x)
if x.dtype != float32 and x.dtype != float64:
raise ValueError(f"unsupported dtype: {x.dtype}")
return tensor(math_dialect.sqrt(x.handle), x.type)


def sin(x: object) -> tensor:
x = _to_tensor(x)
if x.dtype != float32 and x.dtype != float64:
raise ValueError(f"unsupported dtype: {x.dtype}")
return tensor(math_dialect.sin(x.handle), x.type)


def cos(x: object) -> tensor:
x = _to_tensor(x)
if x.dtype != float32 and x.dtype != float64:
raise ValueError(f"unsupported dtype: {x.dtype}")
return tensor(math_dialect.cos(x.handle), x.type)


def multiple_of(x: tensor, values: Sequence[int]) -> tensor:
assert max(1, len(x.shape)) == len(values)
set_attr(
