Skip to content

Commit

Permalink
Fixed a typo in min/max Triton lowering rules
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604424404
  • Loading branch information
superbobry authored and jax authors committed Feb 5, 2024
1 parent b53f757 commit 9e94e6e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions jaxlib/triton/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,27 +672,27 @@ def max(x: tensor, y: tensor) -> tensor:
assert x.shape == y.shape
if x.dtype.is_floating():
# TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype)
return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.type)
if not x.dtype.is_int():
raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
elif x.dtype.is_int_signed():
return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype)
return tensor(arith_dialect.maxsi(x.handle, y.handle), x.type)
else:
return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype)
return tensor(arith_dialect.maxui(x.handle, y.handle), x.type)

@staticmethod
def min(x: tensor, y: tensor) -> tensor:
# TODO(slebedev): Consider allowing customizing nan behavior.
assert x.shape == y.shape
if x.dtype.is_floating():
# TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype)
return tensor(arith_dialect.minnumf(x.handle, y.handle), x.type)
if not x.dtype.is_int():
raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
elif x.dtype.is_int_signed():
return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype)
return tensor(arith_dialect.minsi(x.handle, y.handle), x.type)
else:
return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype)
return tensor(arith_dialect.minui(x.handle, y.handle), x.type)

sin = libdevice_extern_elementwise({
(float32,): ("__nv_sinf", float32),
Expand Down

0 comments on commit 9e94e6e

Please sign in to comment.