From dcc92e3c5da7e6ed955bedbe45fc307b88fc57d6 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Mon, 16 Oct 2023 12:35:06 -0700
Subject: [PATCH] [pallas] `dot` fixes.

- Check that operands are 2D.
- Set `preferred_element_type`.
- Fix dot output type on GPU.

PiperOrigin-RevId: 573895904
---
 jax/_src/pallas/primitives.py      |  9 +++--
 jax/_src/pallas/triton/lowering.py | 57 ++++++++++++++++++------------
 tests/pallas/pallas_test.py        |  4 +--
 3 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index 25e4f9f80ca8..8ceee708f406 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -432,6 +432,8 @@ def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None:
 
 def dot(a, b, trans_a: bool = False, trans_b: bool = False,
         allow_tf32: bool | None = None, precision=None):
+  if (a.ndim != 2) or (b.ndim != 2):
+    raise ValueError("`a` and `b` must be 2D arrays.")
   lhs_contract_dim = 0 if trans_a else 1
   rhs_contract_dim = 0 if not trans_b else 1
   if allow_tf32 is not None:
@@ -439,6 +441,9 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
       raise ValueError("Only one of allow_tf32 and precision can be specified")
     precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
   return jax.lax.dot_general(
-      a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
+      a,
+      b,
+      dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
       precision=precision,
-      preferred_element_type=None).astype(jnp.float32)
+      preferred_element_type=jnp.float32,
+  )
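To make the first two fixes concrete: `pl.dot` now raises for non-2D operands and requests a float32 result from `lax.dot_general` via `preferred_element_type`, rather than computing at low precision and casting afterwards. A minimal caller-side sketch (not part of the patch; the kernel, shapes, and dtypes are illustrative, and it assumes a GPU backend with Triton support):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def matmul_kernel(x_ref, y_ref, o_ref):
      # Operands must be 2D; the result dtype is float32 regardless of the
      # input dtype, because of `preferred_element_type=jnp.float32`.
      o_ref[:, :] = pl.dot(x_ref[:, :], y_ref[:, :])

    x = jnp.ones((32, 64), jnp.float16)
    y = jnp.ones((64, 128), jnp.float16)
    out = pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((32, 128), jnp.float32),
    )(x, y)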
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 91e018b8fc78..2ebbe9e19d8b 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -144,6 +144,22 @@ def _eval_index_map(
   )
 
 
+def _convert_dtype(dtype: jnp.dtype) -> tl.dtype:
+  if dtype == jnp.float32:
+    return tl.float32
+  elif dtype == jnp.float64:
+    return tl.float64
+  elif dtype == jnp.float16:
+    return tl.float16
+  elif dtype == jnp.bfloat16:
+    return tl.bfloat16
+  elif dtype == jnp.int32:
+    return tl.int32
+  elif dtype == jnp.int64:
+    return tl.int64
+  raise ValueError(f"Unhandled dtype: {dtype}")
+
+
 triton_lowering_rules = {}
 
 
@@ -473,21 +489,7 @@ def _convert_element_type_lowering_rule(
 ):
   if new_dtype == ctx.avals_in[0].dtype:
     return a
-  if new_dtype == jnp.float32:
-    new_dtype = tl.float32
-  elif new_dtype == jnp.float64:
-    new_dtype = tl.float64
-  elif new_dtype == jnp.float16:
-    new_dtype = tl.float16
-  elif new_dtype == jnp.bfloat16:
-    new_dtype = tl.bfloat16
-  elif new_dtype == jnp.int32:
-    new_dtype = tl.int32
-  elif new_dtype == jnp.int64:
-    new_dtype = tl.int64
-  else:
-    raise ValueError(f"Unhandled dtype: {new_dtype}")
-  return tl.semantic.cast(a, new_dtype, ctx.builder)
+  return tl.semantic.cast(a, _convert_dtype(new_dtype), ctx.builder)
 
 
 triton_lowering_rules[lax.convert_element_type_p] = (
@@ -868,15 +870,13 @@ def _dot_general_lowering(
     precision,
     preferred_element_type
 ):
-  contract_dims, batch_dims = dimension_numbers
+  del preferred_element_type  # Unused.
+  ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())
-  (a_contract_dim,) = contract_dims[0]
-  (b_contract_dim,) = contract_dims[1]
-  trans_a = a_contract_dim == 0
-  trans_b = b_contract_dim == 1
-  if trans_a:
+
+  if a_contract_dim == 0:
     a = tl.trans(a, _builder=ctx.builder)
-  if trans_b:
+  if b_contract_dim == 1:
     b = tl.trans(b, _builder=ctx.builder)
+
   if precision is None:
@@ -884,7 +884,18 @@
   else:
     prec_a, prec_b = precision
     allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
-  return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32)
+
+  out_dtype = acc_dtype = _convert_dtype(ctx.avals_out[0].dtype)
+  if acc_dtype not in (tl.int32, tl.float16):
+    acc_dtype = tl.float32
+
+  return tl.dot(
+      a,
+      b,
+      allow_tf32=allow_tf32,
+      out_dtype=acc_dtype,
+      _builder=ctx.builder,
+  ).to(out_dtype, _builder=ctx.builder)
 
 
 triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index b11c3d92b516..7ead7449a576 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -86,7 +86,7 @@ def body(i, acc_ref):
           jax.lax.broadcast_in_dim(idx_k, (bk, bn), (0,)),
           jax.lax.broadcast_in_dim(idx_n, (bk, bn), (1,)))
       x_block, y_block = x_ref[x_idx], y_ref[y_idx]
-      out = jnp.dot(x_block, y_block)
+      out = pl.dot(x_block, y_block)
       acc_ref[:, :] += out
     acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
     o_idx = (
@@ -115,7 +115,7 @@ def matmul_kernel(x_ref, y_ref, o_ref):
     def body(i, acc_ref):
       x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk)))
       y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None)))
-      acc_ref[:, :] += jnp.dot(x_block, y_block)
+      acc_ref[:, :] += pl.dot(x_block, y_block)
     acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
     o_ref[:, :] = acc
   return matmul_kernel(x, y)
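The GPU fix above decouples `tl.dot`'s accumulator type from its output type: the output now follows the result aval, while accumulation stays in float32 unless the result itself is int32 or float16. A standalone sketch of that selection rule (the helper `_acc_dtype` is hypothetical, shown only to illustrate the branch added in `_dot_general_lowering`):

    import triton.language as tl

    def _acc_dtype(out_dtype: tl.dtype) -> tl.dtype:
      # int32 and float16 results accumulate natively in that type; any
      # other result accumulates in float32, and the patched lowering
      # casts the accumulator back with `.to(out_dtype, ...)`.
      if out_dtype in (tl.int32, tl.float16):
        return out_dtype
      return tl.float32

    assert _acc_dtype(tl.float16) == tl.float16   # half-precision accumulator
    assert _acc_dtype(tl.int32) == tl.int32       # integer matmul
    assert _acc_dtype(tl.bfloat16) == tl.float32  # accumulate wide, cast down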