Skip to content

Commit

Permalink
[pallas] dot fixes.
Browse files Browse the repository at this point in the history
- Check that operands are 2D.
- Set `preferred_element_type`.
- Fix dot output type on GPU.

PiperOrigin-RevId: 573895904
  • Loading branch information
chr1sj0nes authored and jax authors committed Oct 16, 2023
1 parent 675cb15 commit dcc92e3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 27 deletions.
9 changes: 7 additions & 2 deletions jax/_src/pallas/primitives.py
Expand Up @@ -432,13 +432,18 @@ def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None:

def dot(a, b, trans_a: bool = False, trans_b: bool = False,
allow_tf32: bool | None = None, precision=None):
if (a.ndim != 2) or (b.ndim != 2):
raise ValueError("`a` and `b` must be 2D arrays.")
lhs_contract_dim = 0 if trans_a else 1
rhs_contract_dim = 0 if not trans_b else 1
if allow_tf32 is not None:
if precision is not None:
raise ValueError("Only one of allow_tf32 and precision can be specified")
precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
return jax.lax.dot_general(
a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
a,
b,
dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
precision=precision,
preferred_element_type=None).astype(jnp.float32)
preferred_element_type=jnp.float32,
)
57 changes: 34 additions & 23 deletions jax/_src/pallas/triton/lowering.py
Expand Up @@ -144,6 +144,22 @@ def _eval_index_map(
)


def _convert_dtype(dtype: jnp.dtype) -> tl.dtype:
    """Return the Triton dtype corresponding to a JAX dtype.

    Raises:
      ValueError: If `dtype` is not one of the handled dtypes.
    """
    # Each handled dtype has the same attribute name on `jnp` and on `tl`,
    # so a single ordered list of names drives both lookups. Attributes are
    # resolved lazily, in the same order the comparisons were written.
    handled_names = (
        "float32",
        "float64",
        "float16",
        "bfloat16",
        "int32",
        "int64",
    )
    for type_name in handled_names:
        if dtype == getattr(jnp, type_name):
            return getattr(tl, type_name)
    raise ValueError(f"Unhandled dtype: {dtype}")


triton_lowering_rules = {}


Expand Down Expand Up @@ -473,21 +489,7 @@ def _convert_element_type_lowering_rule(
):
if new_dtype == ctx.avals_in[0].dtype:
return a
if new_dtype == jnp.float32:
new_dtype = tl.float32
elif new_dtype == jnp.float64:
new_dtype = tl.float64
elif new_dtype == jnp.float16:
new_dtype = tl.float16
elif new_dtype == jnp.bfloat16:
new_dtype = tl.bfloat16
elif new_dtype == jnp.int32:
new_dtype = tl.int32
elif new_dtype == jnp.int64:
new_dtype = tl.int64
else:
raise ValueError(f"Unhandled dtype: {new_dtype}")
return tl.semantic.cast(a, new_dtype, ctx.builder)
return tl.semantic.cast(a, _convert_dtype(new_dtype), ctx.builder)


triton_lowering_rules[lax.convert_element_type_p] = (
Expand Down Expand Up @@ -868,23 +870,32 @@ def _dot_general_lowering(
precision,
preferred_element_type
):
contract_dims, batch_dims = dimension_numbers
del preferred_element_type # Unused.
((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
assert batch_dims == ((), ())
(a_contract_dim,) = contract_dims[0]
(b_contract_dim,) = contract_dims[1]
trans_a = a_contract_dim == 0
trans_b = b_contract_dim == 1
if trans_a:

if a_contract_dim == 0:
a = tl.trans(a, _builder=ctx.builder)
if trans_b:
if b_contract_dim == 1:
b = tl.trans(b, _builder=ctx.builder)

if precision is None:
allow_tf32 = True
else:
prec_a, prec_b = precision
allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32)

out_dtype = acc_dtype = _convert_dtype(ctx.avals_out[0].dtype)
if acc_dtype not in (tl.int32, tl.float16):
acc_dtype = tl.float32

return tl.dot(
a,
b,
allow_tf32=allow_tf32,
out_dtype=acc_dtype,
_builder=ctx.builder,
).to(out_dtype, _builder=ctx.builder)


triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering
Expand Down
4 changes: 2 additions & 2 deletions tests/pallas/pallas_test.py
Expand Up @@ -86,7 +86,7 @@ def body(i, acc_ref):
jax.lax.broadcast_in_dim(idx_k, (bk, bn), (0,)),
jax.lax.broadcast_in_dim(idx_n, (bk, bn), (1,)))
x_block, y_block = x_ref[x_idx], y_ref[y_idx]
out = jnp.dot(x_block, y_block)
out = pl.dot(x_block, y_block)
acc_ref[:, :] += out
acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
o_idx = (
Expand Down Expand Up @@ -115,7 +115,7 @@ def matmul_kernel(x_ref, y_ref, o_ref):
def body(i, acc_ref):
x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk)))
y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None)))
acc_ref[:, :] += jnp.dot(x_block, y_block)
acc_ref[:, :] += pl.dot(x_block, y_block)
acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
o_ref[:, :] = acc
return matmul_kernel(x, y)
Expand Down

0 comments on commit dcc92e3

Please sign in to comment.