diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 523019d07420..b7d7d54450eb 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -463,19 +463,19 @@ def _cumsum_lowering_rule(
 _TRITON_FN_MAPPING = {
     # Unary ops.
     lax.neg_p: tc.semantic.minus,
-    lax.abs_p: tc.abs,
+    lax.abs_p: tc.math.abs,
     lax.ceil_p: tc.math.ceil,
     lax.floor_p: tc.math.floor,
-    lax.exp_p: tc.exp,
+    lax.exp_p: tc.math.exp,
     lax.exp2_p: tc.math.exp2,
     lax.expm1_p: tc.math.expm1,
-    lax.log_p: tc.log,
+    lax.log_p: tc.math.log,
     lax.log1p_p: tc.math.log1p,
-    lax.sqrt_p: tc.sqrt,
+    lax.sqrt_p: tc.math.sqrt,
     lax.cbrt_p: tc.math.cbrt,
     lax.rsqrt_p: tc.math.rsqrt,
-    lax.sin_p: tc.sin,
-    lax.cos_p: tc.cos,
+    lax.sin_p: tc.math.sin,
+    lax.cos_p: tc.math.cos,
     lax.tan_p: tc.math.tan,
     lax.asin_p: tc.math.asin,
     lax.acos_p: tc.math.acos,
diff --git a/jaxlib/triton/compat.py b/jaxlib/triton/compat.py
index c45dd28f5d10..b2d7fbdf71e4 100644
--- a/jaxlib/triton/compat.py
+++ b/jaxlib/triton/compat.py
@@ -19,7 +19,7 @@
 
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from functools import partial, wraps
 import threading
 
@@ -1051,22 +1051,155 @@ def set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
   op.attributes[name] = attr
 
 
+_LIBDEVICE_PATH = tl.math.libdevice_path()
+
+
+def libdevice_extern_elementwise(
+    table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
+    is_pure: bool = True,
+):
+  """Returns a function that emits a libdevice ``extern_elementwise`` call.
+
+  Args:
+    table: maps a tuple of operand dtypes (one entry per operand) to the
+      libdevice symbol to call and the dtype of its result.
+    is_pure: forwarded as ``pure`` to the emitted op.
+
+  Returns:
+    A function taking one or more tensors and returning the result tensor.
+    It raises ``NotImplementedError`` when ``table`` has no entry for the
+    operand dtypes.
+  """
+
+  def inner(arg: tensor, *rest: tensor) -> tensor:
+    args = (arg, *rest)
+    arg_dtypes = tuple(a.dtype for a in args)
+    try:
+      symbol, dtype = table[arg_dtypes]
+    except KeyError:
+      raise NotImplementedError(f"unsupported dtypes: {arg_dtypes}") from None
+
+    return_type = dtype
+    # Operands are assumed to already share one shape (TODO confirm against
+    # callers); the first block operand determines the result shape.
+    for a in args:
+      if a.type.is_block():
+        return_type = block_type(dtype, a.shape)
+        break
+    return tensor(
+        tt_dialect.extern_elementwise(
+            return_type.to_ir(builder.current),
+            [a.handle for a in args],
+            libname="libdevice",
+            libpath=_LIBDEVICE_PATH,
+            symbol=symbol,
+            pure=is_pure,
+        ),
+        return_type,
+    )
+
+  return inner
+
+
 class math:
-  acos = wrap_with_builder(tl.math.acos)
-  acosh = wrap_with_builder(tl.math.acosh)
-  asin = wrap_with_builder(tl.math.asin)
-  asinh = wrap_with_builder(tl.math.asinh)
-  atan = wrap_with_builder(tl.math.atan)
-  atan2 = wrap_with_builder(tl.math.atan2)
-  atanh = wrap_with_builder(tl.math.atanh)
-  cbrt = wrap_with_builder(tl.math.cbrt)
-  ceil = wrap_with_builder(tl.math.ceil)
-  clz = wrap_with_builder(tl.math.clz)
-  cosh = wrap_with_builder(tl.math.cosh)
-  exp2 = wrap_with_builder(tl.math.exp2)
-  expm1 = wrap_with_builder(tl.math.expm1)
-  floor = wrap_with_builder(tl.math.floor)
-  log1p = wrap_with_builder(tl.math.log1p)
+  sin = libdevice_extern_elementwise({
+      (float32,): ("__nv_sinf", float32),
+      (float64,): ("__nv_sin", float64),
+  })
+  cos = libdevice_extern_elementwise({
+      (float32,): ("__nv_cosf", float32),
+      (float64,): ("__nv_cos", float64),
+  })
+  tan = libdevice_extern_elementwise({
+      (float32,): ("__nv_tanf", float32),
+      (float64,): ("__nv_tan", float64),
+  })
+  asin = libdevice_extern_elementwise({
+      (float32,): ("__nv_asinf", float32),
+      (float64,): ("__nv_asin", float64),
+  })
+  acos = libdevice_extern_elementwise({
+      (float32,): ("__nv_acosf", float32),
+      (float64,): ("__nv_acos", float64),
+  })
+  atan = libdevice_extern_elementwise({
+      (float32,): ("__nv_atanf", float32),
+      (float64,): ("__nv_atan", float64),
+  })
+  # atan2 takes two operands, so its keys carry one dtype per operand.
+  atan2 = libdevice_extern_elementwise({
+      (float32, float32): ("__nv_atan2f", float32),
+      (float64, float64): ("__nv_atan2", float64),
+  })
+  sinh = libdevice_extern_elementwise({
+      (float32,): ("__nv_sinhf", float32),
+      (float64,): ("__nv_sinh", float64),
+  })
+  cosh = libdevice_extern_elementwise({
+      (float32,): ("__nv_coshf", float32),
+      (float64,): ("__nv_cosh", float64),
+  })
+  tanh = libdevice_extern_elementwise({
+      (float32,): ("__nv_tanhf", float32),
+      (float64,): ("__nv_tanh", float64),
+  })
+  asinh = libdevice_extern_elementwise({
+      (float32,): ("__nv_asinhf", float32),
+      (float64,): ("__nv_asinh", float64),
+  })
+  acosh = libdevice_extern_elementwise({
+      (float32,): ("__nv_acoshf", float32),
+      (float64,): ("__nv_acosh", float64),
+  })
+  atanh = libdevice_extern_elementwise({
+      (float32,): ("__nv_atanhf", float32),
+      (float64,): ("__nv_atanh", float64),
+  })
+
+  cbrt = libdevice_extern_elementwise({
+      (float32,): ("__nv_cbrtf", float32),
+      (float64,): ("__nv_cbrt", float64),
+  })
+  # NOTE(review): __nv_clzll presumably returns a 32-bit int in libdevice;
+  # confirm that the int64 result dtype declared below is intended.
+  clz = libdevice_extern_elementwise({
+      (int32,): ("__nv_clz", int32),
+      (int64,): ("__nv_clzll", int64),
+  })
+  exp = libdevice_extern_elementwise({
+      (float32,): ("__nv_expf", float32),
+      (float64,): ("__nv_exp", float64),
+  })
+  exp2 = libdevice_extern_elementwise({
+      (float32,): ("__nv_exp2f", float32),
+      (float64,): ("__nv_exp2", float64),
+  })
+  expm1 = libdevice_extern_elementwise({
+      (float32,): ("__nv_expm1f", float32),
+      (float64,): ("__nv_expm1", float64),
+  })
+  log = libdevice_extern_elementwise({
+      (float32,): ("__nv_logf", float32),
+      (float64,): ("__nv_log", float64),
+  })
+  log1p = libdevice_extern_elementwise({
+      (float32,): ("__nv_log1pf", float32),
+      (float64,): ("__nv_log1p", float64),
+  })
+  floor = libdevice_extern_elementwise({
+      (float32,): ("__nv_floorf", float32),
+      (float64,): ("__nv_floor", float64),
+  })
+  ceil = libdevice_extern_elementwise({
+      (float32,): ("__nv_ceilf", float32),
+      (float64,): ("__nv_ceil", float64),
+  })
+  abs = libdevice_extern_elementwise({
+      (int32,): ("__nv_abs", int32),
+      (int64,): ("__nv_llabs", int64),
+      (float32,): ("__nv_fabsf", float32),
+      (float64,): ("__nv_fabs", float64),
+  })
   max = partial(
       wrap_with_builder(tl.math.max),
       propagate_nan=tl.PropagateNan.NONE,
@@ -1076,12 +1209,19 @@ class math:
       propagate_nan=tl.PropagateNan.NONE,
   )
   nextafter = wrap_with_builder(tl.math.nextafter)
-  popc = wrap_with_builder(tl.math.popc)
+  popc = libdevice_extern_elementwise({
+      (int32,): ("__nv_popc", int32),
+      (int64,): ("__nv_popcll", int64),
+  })
   pow = wrap_with_builder(tl.math.pow)
-  rsqrt = wrap_with_builder(tl.math.rsqrt)
-  sinh = wrap_with_builder(tl.math.sinh)
-  tan = wrap_with_builder(tl.math.tan)
-  tanh = wrap_with_builder(tl.math.tanh)
+  sqrt = libdevice_extern_elementwise({
+      (float32,): ("__nv_sqrtf", float32),
+      (float64,): ("__nv_sqrt", float64),
+  })
+  rsqrt = libdevice_extern_elementwise({
+      (float32,): ("__nv_rsqrtf", float32),
+      (float64,): ("__nv_rsqrt", float64),
+  })
 
 
 class semantic: