From 28ed6078208beea0744c7e9e4ed0465cbc632053 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 1 Feb 2024 04:41:07 -0800 Subject: [PATCH] Inlined some Triton-specific abstractions * tensor is now just a container for dtype/shape with no extra methods; * constexpr is not used; * all APIs assume that the arguments are tensors and do no ->tensor conversion. PiperOrigin-RevId: 603332629 --- jax/_src/pallas/primitives.py | 8 +- jax/_src/pallas/triton/lowering.py | 88 +++--- jaxlib/triton/compat.py | 486 ++++++++++++++--------------- tests/pallas/pallas_test.py | 9 + 4 files changed, 303 insertions(+), 288 deletions(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 603370d2329f..4acde8462ecb 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -417,8 +417,8 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule) -def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="", - eviction_policy="", volatile=False) -> jax.Array: +def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, + eviction_policy=None, volatile=False) -> jax.Array: x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load") args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other)) return load_p.bind( @@ -429,7 +429,7 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="", is_volatile=volatile, ) -def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="", +def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name="swap") -> Any: x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name) args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask)) @@ -437,7 +437,7 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="", *args_flat, args_tree=args_tree, eviction_policy=eviction_policy ) -def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy="") -> None: +def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None: _ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy, _function_name="store") diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 39a24c502c8a..4c92ee3f03e7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -148,7 +148,7 @@ def _eval_index_map( return tuple( i if b is pallas_core.mapped - else i * tc._to_tensor(b, i.dtype) + else tc.semantic.mul(i, tc._to_tensor(b, i.dtype)) for i, b in zip(block_indices, block_mapping.block_shape) ) @@ -228,7 +228,9 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): grid0 = tc.program_id(0) for i, s in enumerate(collapse_dims): out_idx = launch_grid_to_pallas_grid[i] - grid0, out_indices[out_idx] = grid0 // s, grid0 % s + s = tc._to_tensor(s, grid0.dtype) + out_indices[out_idx] = tc.semantic.mod(grid0, s) + grid0 = tc.semantic.floordiv(grid0, s) for i in range(len(prog_id_dims)): out_idx = launch_grid_to_pallas_grid[num_collapse + i] @@ -567,12 +569,9 @@ def _pow_lowering_rule(ctx: TritonLoweringRuleContext, x: tc.tensor, y: tc.tenso } for prim, fn in _JAX_TO_TRITON_OTHER.items(): - - def rule(ctx: TritonLoweringRuleContext, *args, fn=fn, **kwargs): - kwargs = tree_util.tree_map(tc.constexpr, kwargs) - return fn(*args, **kwargs) - - triton_lowering_rules[prim] = rule + triton_lowering_rules[prim] = lambda ctx, *args, 
fn=fn, **kwargs: fn( + *args, **kwargs + ) def _integer_pow(a, *, y): @@ -670,7 +669,6 @@ def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b): def _broadcast_in_dim_lowering_rule( ctx: TritonLoweringRuleContext, a, *, broadcast_dimensions, shape ): - shape = map(tc.constexpr, shape) if not a.type.is_block(): return tc.broadcast_to(a, shape) expand_dims = [i for i in range(len(shape)) if i not in broadcast_dimensions] @@ -699,25 +697,25 @@ def _reshape_lowering_rule( if dimensions is not None: return ValueError("`dimensions` is not supported.") - dst_shape = map(tc.constexpr, ctx.avals_out[0].shape) + dst_shape = ctx.avals_out[0].shape if not a.type.is_block(): - assert all(dim_size.value == 1 for dim_size in dst_shape) + assert all(dim_size == 1 for dim_size in dst_shape) return tc.broadcast_to(a, dst_shape) # Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not # currently implemented. i = 0 while a.shape != dst_shape: - dim_size = a.shape[i].value if i < len(a.shape) else None - dst_dim_size = dst_shape[i].value if i < len(dst_shape) else None + dim_size = a.shape[i] if i < len(a.shape) else None + dst_dim_size = dst_shape[i] if i < len(dst_shape) else None if dim_size == dst_dim_size: i += 1 elif dst_dim_size == 1: a = tc.expand_dims(a, axis=i) i += 1 elif dim_size == 1: - in_shape = tuple(d.value for d in a.shape) - out_shape = tuple(d.value for di, d in enumerate(a.shape) if di != i) + in_shape = a.shape + out_shape = tuple(d for di, d in enumerate(a.shape) if di != i) reduce_ctx = ctx.replace( avals_in=[ctx.avals_in[0].update(shape=in_shape)], avals_out=[ctx.avals_in[0].update(shape=out_shape)], @@ -773,8 +771,9 @@ def _compute_pointers_from_indices( if isinstance(index.start, int): ptr_dim_offset = tc.arange(index.start, index.start + index.size) else: - ptr_dim_offset = tc.broadcast_to(index.start, [index.size]) - ptr_dim_offset += tc.arange(0, index.size) + ptr_dim_offset = tc.semantic.add( + tc.broadcast_to(index.start, [index.size]), tc.arange(0, index.size) + ) # We need to add broadcastable dimensions for the advanced int indexing # and for previous slices num_left_expand_dims = len(int_indexer_shape) + other_shape_idx @@ -808,15 +807,15 @@ def _compute_pointers_from_indices( ndim = len(ptr_dim_offset.shape) ptr_dim_offset = tc.expand_dims(ptr_dim_offset, ndim) if start_offset is not None: - ptr_dim_offset += tc.broadcast_to( - tc.semantic.cast(start_offset, ptr_dim_offset.dtype), - ptr_dim_offset.shape, + start_offset = tc.semantic.cast(start_offset, ptr_dim_offset.dtype) + ptr_dim_offset = tc.semantic.add( + ptr_dim_offset, tc.broadcast_to(start_offset, ptr_dim_offset.shape) ) stride_size = tc.broadcast_to( tc._to_tensor(dim_stride, ptr_dim_offset.dtype), ptr_dim_offset.shape ) - bcast_indices.append(ptr_dim_offset * stride_size) + bcast_indices.append(tc.semantic.mul(ptr_dim_offset, stride_size)) block_shapes = [ () if not index.type.is_block() else tuple(index.type.get_block_shapes()) for index in bcast_indices @@ -827,7 +826,9 @@ def _compute_pointers_from_indices( else index for index, block_shape in zip(bcast_indices, block_shapes) ] - return sum(bcast_indices, tc.broadcast_to(root_ptr, indexer_shape)) + return functools.reduce( + tc.semantic.add, bcast_indices, tc.broadcast_to(root_ptr, indexer_shape) + ) def _pack_indices(non_slice_idx, indexed_dims): @@ -880,16 +881,18 @@ def _masked_load_lowering_rule( ptr = _compute_pointers_from_indices( ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape ) + if other is not None and mask is 
not None: + other = tc.broadcast_to(other, mask.shape) val = tc.load( ptr, mask=mask, other=other, cache_modifier=cache_modifier, - volatile=is_volatile, + is_volatile=is_volatile, eviction_policy=eviction_policy, ) # `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type. - return val.to(ptr.dtype.element_ty) + return tc.semantic.cast(val, ptr.dtype.element_ty) triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule @@ -931,7 +934,9 @@ def _masked_swap_lowering_rule( ptr = _compute_pointers_from_indices( ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape ) - other = None if mask is None else value + other = None + if value is not None and mask is not None: + other = tc.broadcast_to(value, mask.shape) old_value = tc.load(ptr, mask=mask, other=other) tc.store( ptr, @@ -1005,12 +1010,15 @@ def _dot_general_lowering( if acc_dtype not in (tc.int32, tc.float16): acc_dtype = tc.float32 - return tc.dot( - a, - b, - allow_tf32=allow_tf32, - out_dtype=acc_dtype, - ).to(out_dtype) + return tc.semantic.cast( + tc.dot( + a, + b, + allow_tf32=allow_tf32, + out_dtype=acc_dtype, + ), + out_dtype, + ) triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering @@ -1071,7 +1079,7 @@ def _reduce_lowering(body, ctx: TritonLoweringRuleContext, a, *, axes): # reduces, which seems necessary for correctness. # TODO(bjp): Get rid of the double negation. # https://github.com/openai/triton/issues/1776 - a = -(-a) + a = tc.semantic.minus(tc.semantic.minus(a)) ctx = ctx.replace(avals_in=dst_avals) axes = tuple(ax for ax in axes if ax != axis) return _reduction_lowering(body, ctx, a, axes=axes)[0] @@ -1100,9 +1108,9 @@ def _argreduce_lowering( index = tc.arange(0, n) if len(a.shape) > 1: # Broadcast index across the non-reduced axes - expand_dims_index = [tc.constexpr(None)] * len(a.shape) - expand_dims_index[axis] = slice(None) - index = index[expand_dims_index] + for i in range(len(a.shape)): + if i != axis: + index = tc.expand_dims(index, i) index = tc.broadcast_to(index, a.shape) ctx = ctx.replace( avals_in=[ @@ -1311,7 +1319,7 @@ def _scan_lowering_rule( if has_loop_index: lb, *args = args lower_bound = lb.handle - ub = lb + tc._to_tensor(length, lb.dtype) + ub = tc.semantic.add(lb, tc._to_tensor(length, lb.dtype)) upper_bound = ub.handle bound_type = ub.type else: @@ -1502,7 +1510,7 @@ def to_type(out_aval): out_types = [to_type(out) for out in ctx.avals_out] out_ir_types = [t.to_ir(ctx.builder) for t in out_types] - use_branch0 = index == 0 + use_branch0 = tc.semantic.equal(index, tc._to_tensor(0, index.dtype)) # TODO(bjp): Switch to scf.index_switch once exposed in triton.cc if_op = ctx.builder.create_if_op(out_ir_types, use_branch0.handle, with_else=True) with ir.InsertionPoint.at_block_begin(if_op.then_block): @@ -1516,7 +1524,11 @@ def to_type(out_aval): # TODO(bjp): Instead of linear nest of 'if's, partition into halves. 
if len(branches) > 2: outs1 = _cond_lowering_rule( - ctx, index - 1, *args, branches=branches[1:], linear=linear + ctx, + tc.semantic.sub(index, tc._to_tensor(1, index.dtype)), + *args, + branches=branches[1:], + linear=linear, ) else: outs1 = lower_jaxpr_to_triton_ir( diff --git a/jaxlib/triton/compat.py b/jaxlib/triton/compat.py index e55ae8d61f99..d3a496a9e1ea 100644 --- a/jaxlib/triton/compat.py +++ b/jaxlib/triton/compat.py @@ -20,8 +20,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from functools import partial, wraps +from functools import partial import threading +from typing import Any from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith as arith_dialect @@ -453,120 +454,6 @@ def create_xor(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value: def create_or(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value: return arith_dialect.ori(lhs, rhs) - def create_load( - self, - ptr: ir.Value, - cache_modifier: tt_dialect.CacheModifier, - eviction_policy: tt_dialect.EvictionPolicy, - is_volatile: bool, - ) -> ir.Value: - if ir.RankedTensorType.isinstance(ptr.type): - ptr_type = ir.RankedTensorType(ptr.type) - element_type = tt_dialect.PointerType(ptr_type.element_type) - result_type = ir.RankedTensorType.get( - ptr_type.shape, - element_type.pointee_type, - ptr_type.encoding, - ) - else: - ptr_type = tt_dialect.PointerType(ptr.type) - result_type = ptr_type.pointee_type - return tt_dialect.load( - result_type, ptr, cache_modifier, eviction_policy, is_volatile - ) - - def create_store( - self, - ptr: ir.Value, - value: ir.Value, - cache_modifier: tt_dialect.CacheModifier, - eviction_policy: tt_dialect.EvictionPolicy, - ) -> ir.Value: - return tt_dialect.store( - ptr, value, cache=cache_modifier, evict=eviction_policy - ) - - def create_tensor_pointer_load( - self, - ptr: ir.Value, - boundary_check: Sequence[int], - padding_option: Sequence[tt_dialect.PaddingOption], - cache_modifier: tt_dialect.CacheModifier, - eviction_policy: tt_dialect.EvictionPolicy, - is_volatile: bool, - ) -> ir.Value: - return tt_dialect.load( - ptr.type, - ptr, - cache_modifier, - eviction_policy, - is_volatile, - boundary_check=boundary_check, - padding=padding_option, - ) - - def create_tensor_pointer_store( - self, - ptr: ir.Value, - value: ir.Value, - boundary_check: Sequence[int], - cache_modifier: tt_dialect.CacheModifier, - eviction_policy: tt_dialect.EvictionPolicy, - ) -> ir.Value: - return tt_dialect.store( - ptr, - value, - boundary_check=boundary_check, - cache=cache_modifier, - evict=eviction_policy, - ) - - def create_masked_load( - self, - ptr: ir.Value, - mask: ir.Value, - other: ir.Value | None, - cache_modifier: tt_dialect.CacheModifier, - eviction_policy: tt_dialect.EvictionPolicy, - is_volatile: bool, - ) -> ir.Value: - if ir.RankedTensorType.isinstance(ptr.type): - ptr_type = ir.RankedTensorType(ptr.type) - element_type = tt_dialect.PointerType(ptr_type.element_type) - result_type = ir.RankedTensorType.get( - ptr_type.shape, - element_type.pointee_type, - ptr_type.encoding, - ) - else: - ptr_type = tt_dialect.PointerType(ptr.type) - result_type = ptr_type.pointee_type - return tt_dialect.load( - result_type, - ptr, - cache_modifier, - eviction_policy, - is_volatile, - mask=mask, - other=other, - ) - - def create_masked_store( - self, - ptr: ir.Value, - value: ir.Value, - mask: ir.Value, - cache_modifier: tt_dialect.CacheModifier, - eviction_policy: tt_dialect.EvictionPolicy, - ) -> ir.Value: - return tt_dialect.store( - ptr, - value, - mask=mask, - 
cache=cache_modifier, - evict=eviction_policy, - ) - def create_cat(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value: assert ir.RankedTensorType.isinstance(lhs.type) assert ir.RankedTensorType.isinstance(rhs.type) @@ -735,42 +622,32 @@ def _infer_reduce_op_return_types( dtype = tl.core.dtype -block_type = tl.core.block_type -function_type = tl.core.function_type + +class block_type(tl.core.block_type): + def __init__(self, element_ty: dtype, shape: list[Any]) -> Any: + super().__init__(element_ty, shape) + self.shape = tuple(self.shape) + + pointer_type = tl.core.pointer_type +void = tl.core.void bfloat16 = tl.core.bfloat16 float16 = tl.core.float16 float32 = tl.core.float32 float64 = tl.core.float64 int1 = tl.core.int1 +int8 = tl.core.int8 int32 = tl.core.int32 int64 = tl.core.int64 uint32 = tl.core.uint32 uint64 = tl.core.uint64 -def _bool_block_like(v: tensor) -> block_type: +def _bool_block_like(v: tensor) -> dtype: if not v.type.is_block(): return int1 - return block_type(int1, v.type.shape) - - -def wrap_with_builder(fn): - @wraps(fn) - def inner(*args, **kwargs): - if tl.core.is_builtin(fn): - v = fn(*args, **kwargs, _builder=builder.current) - else: - v = fn(*args, **kwargs, builder=builder.current) - if isinstance(v, tl.core.tensor): - return _to_tensor(v) - return v - - return inner - - -constexpr = tl.core.constexpr + return block_type(int1, v.shape) def _to_tensor(v, dtype: dtype | None = None) -> "tensor": @@ -785,81 +662,175 @@ def _to_tensor(v, dtype: dtype | None = None) -> "tensor": return tensor(t.handle, t.type) -class tensor(tl.core.tensor): - - def __add__(self, other): - return semantic.add(self, _to_tensor(other, self.dtype)) - - def __radd__(self, other): - return self + other - - def __sub__(self, other): - return semantic.sub(self, _to_tensor(other, self.dtype)) +class tensor: - def __rsub__(self, other): - return semantic.sub(_to_tensor(other, self.dtype), self) + def __init__(self, handle: ir.Value, type: dtype): + self.handle = handle + self.shape = tuple(type.shape) if type.is_block() else () + self.type = type + self.dtype = type.scalar - def __mul__(self, other): - return semantic.mul(self, _to_tensor(other, self.dtype)) + def __str__(self) -> str: + return f"{self.dtype}[{', '.join(map(str, self.shape))}]" - def __rmul__(self, other): - return self * other - def __truediv__(self, other): - return semantic.truediv(self, _to_tensor(other, self.dtype)) - - def __rtruediv__(self, other): - return semantic.truediv(_to_tensor(other, self.dtype), self) - - def __floordiv__(self, other): - return semantic.floordiv(self, _to_tensor(other, self.dtype)) - - def __rfloordiv__(self, other): - return semantic.floordiv(_to_tensor(other, self.dtype), self) - - def __mod__(self, other): - return semantic.mod(self, _to_tensor(other, self.dtype)) - - def __rmod__(self, other): - return semantic.mod(_to_tensor(other, self.dtype), self) +def program_id(axis: int) -> tensor: + if axis not in range(3): + raise ValueError(f"axis must be in [0, 3), but got: {axis}") + return tensor(tt_dialect.get_program_id(axis), int32) - def __neg__(self): - return semantic.minus(self) - def __invert__(self): - return semantic.invert(self) +_STR_TO_EVICTION_POLICY = {str(e): e for e in tt_dialect.EvictionPolicy} +_STR_TO_CACHE_MODIFIER = {str(c): c for c in tt_dialect.CacheModifier} - # TODO(slebedev): Override other comparison methods. 
- def __eq__(self, other): - return semantic.equal(self, _to_tensor(other, self.dtype)) - def __getitem__(self, slices) -> tensor: - if isinstance(slices, (slice, constexpr)): - slices = [slices] - t = self - for axis, s in enumerate(slices): - if s is None or isinstance(s, constexpr) and s.value is None: - t = expand_dims(t, axis) - elif ( - isinstance(s, slice) - and s.start is s.stop is s.step is None - ): - pass - else: - raise IndexError(f"unsupported tensor index: {s}") - return t +def _infer_load_return_type(ptr: ir.Value) -> ir.Type: + if ir.RankedTensorType.isinstance(ptr.type): + ptr_type = ir.RankedTensorType(ptr.type) + element_type = tt_dialect.PointerType(ptr_type.element_type) + return ir.RankedTensorType.get( + ptr_type.shape, + element_type.pointee_type, + ptr_type.encoding, + ) + else: + ptr_type = tt_dialect.PointerType(ptr.type) + return ptr_type.pointee_type - to = wrap_with_builder(tl.tensor.to) +def load( + ptr: tensor, + mask: tensor | None = None, + other: tensor | None = None, + *, + cache_modifier: str | None = None, + eviction_policy: str | None = None, + is_volatile: bool = False, +) -> tensor: + if cache_modifier is None: + cache_modifier = tt_dialect.CacheModifier.NONE + elif cache_modifier == ".ca" or cache_modifier == ".cg": + cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier] + else: + raise ValueError(f"unsupported cache modifier: {cache_modifier}") + if eviction_policy is None: + eviction_policy = tt_dialect.EvictionPolicy.NORMAL + else: + try: + eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy] + except KeyError: + raise ValueError( + f"unsupported eviction policy: {eviction_policy}" + ) from None + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # TODO(slebedev): Support load from a block pointer. + raise NotImplementedError("loading from a block pointer is not supported") + if not ptr.dtype.is_ptr(): + raise ValueError(f"unsupported pointer dtype: {ptr.dtype}") + if other is not None: + if mask is None: + raise ValueError("other requires mask to be provided") + assert mask.shape == other.shape == ptr.shape, ( + mask.shape, + other.shape, + ptr.shape, + ) + elif mask is not None: + assert mask.shape == ptr.shape + if not ptr.type.is_block(): + if other is not None and other.type.is_block(): + raise ValueError("other cannot be a block if pointer is not a block") + if mask is not None and mask.type.is_block(): + raise ValueError("mask cannot be a block if pointer is not a block") + + ptr_type = ptr.dtype + element_type = ptr_type.element_ty + + if element_type == int1: + # TODO(slebedev): Cast the result back to int1 before returning. 
+ element_type = int8 + ptr_type = pointer_type(element_type, ptr_type.address_space) + ptr = semantic.cast(ptr, ptr_type) + + if other is not None: + other = semantic.cast(other, element_type) + + result_handle = tt_dialect.load( + _infer_load_return_type(ptr.handle), + ptr.handle, + mask=mask.handle if mask is not None else None, + other=other.handle if other is not None else None, + cache=cache_modifier, + evict=eviction_policy, + is_volatile=is_volatile, + ) + if ptr.type.is_block(): + return tensor(result_handle, block_type(element_type, ptr.type.shape)) + else: + return tensor(result_handle, element_type) -def program_id(axis: int) -> tensor: - if axis not in range(3): - raise ValueError(f"axis must be in [0, 3), but got: {axis}") - return tensor(tt_dialect.get_program_id(axis), int32) +def store( + ptr: tensor, + value: tensor, + mask: tensor | None = None, + *, + cache_modifier: str | None = None, + eviction_policy: str | None = None, +) -> tensor: + if cache_modifier is None: + cache_modifier = tt_dialect.CacheModifier.NONE + elif cache_modifier != ".ca": + cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier] + else: + raise ValueError(f"unsupported cache modifier: {cache_modifier}") + if eviction_policy is None: + eviction_policy = tt_dialect.EvictionPolicy.NORMAL + else: + try: + eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy] + except KeyError: + raise ValueError( + f"unsupported eviction policy: {eviction_policy}" + ) from None + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # TODO(slebedev): Support load from a block pointer. + raise NotImplementedError("storing to a block pointer is not supported") + + if not ptr.dtype.is_ptr(): + raise ValueError(f"unsupported pointer dtype: {ptr.dtype}") + assert value.shape == ptr.shape + if mask is not None: + assert mask.shape == ptr.shape + if not ptr.type.is_block(): + if value.type.is_block(): + raise ValueError("other cannot be a block if pointer is not a block") + if mask is not None and mask.type.is_block(): + raise ValueError("mask cannot be a block if pointer is not a block") + + ptr_type = ptr.dtype + element_type = ptr_type.element_ty + + if element_type == int1: + # TODO(slebedev): Cast the result back to int1 before returning. 
+ element_type = int8 + ptr_type = pointer_type(element_type, ptr_type.address_space) + ptr = semantic.cast(ptr, ptr_type) + + value = semantic.cast(value, element_type) -load = wrap_with_builder(tl.core.load) -store = wrap_with_builder(tl.core.store) + return tensor( + tt_dialect.store( + ptr.handle, + value.handle, + mask=mask.handle if mask is not None else None, + cache=cache_modifier, + evict=eviction_policy, + ), + void, + ) def arange(start: int, end: int) -> tensor: @@ -876,13 +847,11 @@ def arange(start: int, end: int) -> tensor: return tensor(tt_dialect.make_range(ir_ty, start, end), ty) -def broadcast_to(x: object, shape: Sequence[int | constexpr]) -> tensor: - x = _to_tensor(x) +def broadcast_to(x: tensor, shape: Sequence[int]) -> tensor: if not x.type.is_block(): return splat(x, shape) elif x.shape == shape: return x - shape = [dim.__index__() for dim in shape] x_ir_type = ir.RankedTensorType(x.handle.type) result_ir_type = ir.RankedTensorType.get( shape, x_ir_type.element_type, x_ir_type.encoding @@ -893,22 +862,19 @@ def broadcast_to(x: object, shape: Sequence[int | constexpr]) -> tensor: ) -def splat(x: object, shape: Sequence[int | constexpr]) -> tensor: - x = _to_tensor(x) +def splat(x: tensor, shape: Sequence[int]) -> tensor: if x.type.is_block(): raise ValueError("cannot splat a block tensor") if len(shape) == 0: return x - shape = [dim.__index__() for dim in shape] result_ir_type = ir.RankedTensorType.get(shape, x.handle.type) return tensor( tt_dialect.splat(result_ir_type, x.handle), block_type(x.dtype, shape) ) -def expand_dims(x: object, axis: int) -> tensor: - x = _to_tensor(x) - dst_shape = [dim.__index__() for dim in x.shape] +def expand_dims(x: tensor, axis: int) -> tensor: + dst_shape = list(x.shape) dst_shape.insert(axis, 1) if not x.type.is_block(): return splat(input, dst_shape) @@ -929,7 +895,68 @@ def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor: ) -dot = wrap_with_builder(tl.core.dot) +def _check_dot_operands(x_dtype: dtype, y_dtype: dtype, options: Any): + # TODO(slebedev): Ensure that the dtypes are supported by CUDA. 
+ return + + +def dot( + x: tensor, + y: tensor, + acc: tensor | None = None, + allow_tf32: bool = True, + max_num_imprecise_acc: int | None = None, + out_dtype: dtype = float32, +) -> tensor: + if min(*x.shape, *y.shape) < 16: + raise ValueError("all dimensions of x and y must be >= 16 ") + if out_dtype.is_bf16(): + raise ValueError(f"out_dtype={out_dtype} is unsupported") + b: builder = builder.current + _check_dot_operands(x.dtype, y.dtype, b.options) + if x.dtype.is_int(): + if x.dtype != int8: + raise TypeError(f"unsupported dtype: {x.dtype}") + zero = tensor(b.get_int32(0), int32) + element_type = int32 + elif x.dtype.is_fp32() or x.dtype.is_bf16(): + zero = tensor(b.get_fp32(0), float32) + element_type = float32 + else: + if out_dtype.is_fp16(): + zero = tensor(b.get_fp16(0), float16) + else: + zero = tensor(b.get_fp32(0), float32) + element_type = out_dtype + + if element_type != out_dtype: + raise TypeError(f"out_dtype={out_dtype} does not match element type {element_type}") + + m, _ = x.shape + _, n = y.shape + result_type = block_type(element_type, [m, n]) + + if acc is None: + acc = splat(zero, [m, n]) + else: + assert acc.type == result_type + + if max_num_imprecise_acc is None: + if x.dtype.is_fp8() and y.dtype.is_fp8(): + max_num_imprecise_acc = b.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + + return tensor( + tt_dialect.dot( + x.handle, + y.handle, + acc.handle if acc is not None else None, + allow_tf32, + max_num_imprecise_acc, + ), + result_type, + ) def atomic_cas( @@ -1008,8 +1035,7 @@ def atomic_add( return _atomic_rmw(op, ptr, val, mask, semantic, sync_scope) -def abs(x: object) -> tensor: - x = _to_tensor(x) +def abs(x: tensor) -> tensor: dtype = x.dtype if dtype.is_floating(): return tensor(math_dialect.absf(x.handle), x.type) @@ -1021,41 +1047,6 @@ def abs(x: object) -> tensor: raise ValueError(f"unsupported dtype: {dtype}") -def exp(x: object) -> tensor: - x = _to_tensor(x) - if x.dtype != float32 and x.dtype != float64: - raise ValueError(f"unsupported dtype: {x.dtype}") - return tensor(math_dialect.exp(x.handle), x.type) - - -def log(x: object) -> tensor: - x = _to_tensor(x) - if x.dtype != float32 and x.dtype != float64: - raise ValueError(f"unsupported dtype: {x.dtype}") - return tensor(math_dialect.log(x.handle), x.type) - - -def sqrt(x: object) -> tensor: - x = _to_tensor(x) - if x.dtype != float32 and x.dtype != float64: - raise ValueError(f"unsupported dtype: {x.dtype}") - return tensor(math_dialect.sqrt(x.handle), x.type) - - -def sin(x: object) -> tensor: - x = _to_tensor(x) - if x.dtype != float32 and x.dtype != float64: - raise ValueError(f"unsupported dtype: {x.dtype}") - return tensor(math_dialect.sin(x.handle), x.type) - - -def cos(x: object) -> tensor: - x = _to_tensor(x) - if x.dtype != float32 and x.dtype != float64: - raise ValueError(f"unsupported dtype: {x.dtype}") - return tensor(math_dialect.cos(x.handle), x.type) - - def multiple_of(x: tensor, values: Sequence[int]) -> tensor: assert max(1, len(x.shape)) == len(values) set_attr( @@ -1276,7 +1267,10 @@ def min(x: tensor, y: tensor) -> tensor: class semantic: - cast = wrap_with_builder(tl.semantic.cast) + + @staticmethod + def cast(x: tensor, dst_ty: dtype) -> tensor: + return _to_tensor(tl.semantic.cast(x, dst_ty, builder.current)) @staticmethod def where(cond: tensor, x: tensor, y: tensor) -> tensor: diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 2695ff8f2d16..af7b633a3ee2 100644 --- a/tests/pallas/pallas_test.py +++ 
b/tests/pallas/pallas_test.py @@ -193,6 +193,15 @@ def add_one(x_ref, o_ref): x = jnp.arange(64).reshape((8, 8)) np.testing.assert_allclose(add_one(x), x + 1) + def test_bool_array(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.bool_)) + def logical_and(x_ref, o_ref): + o_ref[()] = jnp.logical_and(x_ref[()], True) + + x = jnp.array(True) + self.assertTrue(jnp.all(logical_and(x))) + def test_vector_indexing(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
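
Note (illustrative, not part of the patch): a minimal sketch of the calling convention after this change, assuming the `tc` alias used in jax/_src/pallas/triton/lowering.py refers to jaxlib.triton.compat, and using only names that appear in the diff above (`tc._to_tensor`, `tc.semantic.mul`, `tc.tensor`). With the operator overloads and `constexpr` wrapping removed, lowering code converts Python scalars explicitly and goes through the functional `tc.semantic` API. This only runs inside an active lowering context (i.e. with a current `builder`), so it is a sketch rather than a standalone program.

    # Illustrative sketch only; not code from the patch itself.
    from jaxlib.triton import compat as tc

    def scale_block_index(i: tc.tensor, block_size: int) -> tc.tensor:
      # Before this change: `i * block_size` relied on tensor.__mul__ and an
      # implicit ->tensor conversion of the Python int.
      # After this change: convert the scalar explicitly, then call the
      # functional API on two tc.tensor arguments.
      b = tc._to_tensor(block_size, i.dtype)
      return tc.semantic.mul(i, b)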