Skip to content

Commit

Permalink
Lower a subset of math primitives directly to Triton IR
Browse files Browse the repository at this point in the history
Note that all primitives are now lowered to libdevice calls. Previously,
some of them were lowered to the MLIR arith dialect, and some to libdevice
calls, without any apparent reason for doing so.

PiperOrigin-RevId: 601259707
  • Loading branch information
superbobry authored and jax authors committed Jan 24, 2024
1 parent cfb6250 commit f15cad4
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 27 deletions.
12 changes: 6 additions & 6 deletions jax/_src/pallas/triton/lowering.py
Expand Up @@ -463,19 +463,19 @@ def _cumsum_lowering_rule(
_TRITON_FN_MAPPING = {
# Unary ops.
lax.neg_p: tc.semantic.minus,
lax.abs_p: tc.abs,
lax.abs_p: tc.math.abs,
lax.ceil_p: tc.math.ceil,
lax.floor_p: tc.math.floor,
lax.exp_p: tc.exp,
lax.exp_p: tc.math.exp,
lax.exp2_p: tc.math.exp2,
lax.expm1_p: tc.math.expm1,
lax.log_p: tc.log,
lax.log_p: tc.math.log,
lax.log1p_p: tc.math.log1p,
lax.sqrt_p: tc.sqrt,
lax.sqrt_p: tc.math.sqrt,
lax.cbrt_p: tc.math.cbrt,
lax.rsqrt_p: tc.math.rsqrt,
lax.sin_p: tc.sin,
lax.cos_p: tc.cos,
lax.sin_p: tc.math.sin,
lax.cos_p: tc.math.cos,
lax.tan_p: tc.math.tan,
lax.asin_p: tc.math.asin,
lax.acos_p: tc.math.acos,
Expand Down
160 changes: 139 additions & 21 deletions jaxlib/triton/compat.py
Expand Up @@ -19,7 +19,7 @@

from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from functools import partial, wraps
import threading

Expand Down Expand Up @@ -1051,22 +1051,133 @@ def set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
op.attributes[name] = attr


_LIBDEVICE_PATH = tl.math.libdevice_path()


def libdevice_extern_elementwise(
table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
is_pure: bool = True,
):
def inner(arg: tensor):
try:
symbol, dtype = table[(arg.dtype,)]
except KeyError:
raise NotImplementedError(f"unsupported dtypes: {(arg.dtype,)}") from None

return_type = dtype
if arg.type.is_block():
return_type = block_type(dtype, arg.shape)
return tensor(
tt_dialect.extern_elementwise(
return_type.to_ir(builder.current),
[arg.handle],
libname="libdevice",
libpath=_LIBDEVICE_PATH,
symbol=symbol,
pure=is_pure,
),
return_type,
)

return inner


class math:
acos = wrap_with_builder(tl.math.acos)
acosh = wrap_with_builder(tl.math.acosh)
asin = wrap_with_builder(tl.math.asin)
asinh = wrap_with_builder(tl.math.asinh)
atan = wrap_with_builder(tl.math.atan)
atan2 = wrap_with_builder(tl.math.atan2)
atanh = wrap_with_builder(tl.math.atanh)
cbrt = wrap_with_builder(tl.math.cbrt)
ceil = wrap_with_builder(tl.math.ceil)
clz = wrap_with_builder(tl.math.clz)
cosh = wrap_with_builder(tl.math.cosh)
exp2 = wrap_with_builder(tl.math.exp2)
expm1 = wrap_with_builder(tl.math.expm1)
floor = wrap_with_builder(tl.math.floor)
log1p = wrap_with_builder(tl.math.log1p)
sin = libdevice_extern_elementwise({
(float32,): ("__nv_sinf", float32),
(float64,): ("__nv_sin", float64),
})
cos = libdevice_extern_elementwise({
(float32,): ("__nv_cosf", float32),
(float64,): ("__nv_cos", float64),
})
tan = libdevice_extern_elementwise({
(float32,): ("__nv_tanf", float32),
(float64,): ("__nv_tan", float64),
})
asin = libdevice_extern_elementwise({
(float32,): ("__nv_asinf", float32),
(float64,): ("__nv_asin", float64),
})
acos = libdevice_extern_elementwise({
(float32,): ("__nv_acosf", float32),
(float64,): ("__nv_acos", float64),
})
atan = libdevice_extern_elementwise({
(float32,): ("__nv_atanf", float32),
(float64,): ("__nv_atan", float64),
})
atan2 = libdevice_extern_elementwise({
(float32,): ("__nv_atan2f", float32),
(float64,): ("__nv_atan2", float64),
})
sinh = libdevice_extern_elementwise({
(float32,): ("__nv_sinhf", float32),
(float64,): ("__nv_sinh", float64),
})
cosh = libdevice_extern_elementwise({
(float32,): ("__nv_coshf", float32),
(float64,): ("__nv_cosh", float64),
})
tanh = libdevice_extern_elementwise({
(float32,): ("__nv_tanhf", float32),
(float64,): ("__nv_tanh", float64),
})
asinh = libdevice_extern_elementwise({
(float32,): ("__nv_asinhf", float32),
(float64,): ("__nv_asinh", float64),
})
acosh = libdevice_extern_elementwise({
(float32,): ("__nv_acosf", float32),
(float64,): ("__nv_acosh", float64),
})
atanh = libdevice_extern_elementwise({
(float32,): ("__nv_atanhf", float32),
(float64,): ("__nv_atanh", float64),
})

cbrt = libdevice_extern_elementwise({
(float32,): ("__nv_cbrtf", float32),
(float64,): ("__nv_cbrt", float64),
})
clz = libdevice_extern_elementwise({
(int32,): ("__nv_clz", int32),
(int64,): ("__nv_clzll", int64),
})
exp = libdevice_extern_elementwise({
(float32,): ("__nv_expf", float32),
(float64,): ("__nv_exp", float64),
})
exp2 = libdevice_extern_elementwise({
(float32,): ("__nv_exp2f", float32),
(float64,): ("__nv_exp2", float64),
})
expm1 = libdevice_extern_elementwise({
(float32,): ("__nv_expm1f", float32),
(float64,): ("__nv_expm1", float64),
})
log = libdevice_extern_elementwise({
(float32,): ("__nv_logf", float32),
(float64,): ("__nv_log", float64),
})
log1p = libdevice_extern_elementwise({
(float32,): ("__nv_log1pf", float32),
(float64,): ("__nv_log1p", float64),
})
floor = libdevice_extern_elementwise({
(float32,): ("__nv_floorf", float32),
(float64,): ("__nv_floor", float64),
})
ceil = libdevice_extern_elementwise({
(float32,): ("__nv_ceilf", float32),
(float64,): ("__nv_ceil", float64),
})
abs = libdevice_extern_elementwise({
(int32,): ("__nv_abs", int32),
(int64,): ("__nv_llabs", int64),
(float32,): ("__nv_fabsf", float32),
(float64,): ("__nv_fabs", float64),
})
max = partial(
wrap_with_builder(tl.math.max),
propagate_nan=tl.PropagateNan.NONE,
Expand All @@ -1076,12 +1187,19 @@ class math:
propagate_nan=tl.PropagateNan.NONE,
)
nextafter = wrap_with_builder(tl.math.nextafter)
popc = wrap_with_builder(tl.math.popc)
popc = libdevice_extern_elementwise({
(int32,): ("__nv_popc", int32),
(int64,): ("__nv_popcll", int64),
})
pow = wrap_with_builder(tl.math.pow)
rsqrt = wrap_with_builder(tl.math.rsqrt)
sinh = wrap_with_builder(tl.math.sinh)
tan = wrap_with_builder(tl.math.tan)
tanh = wrap_with_builder(tl.math.tanh)
sqrt = libdevice_extern_elementwise({
(float32,): ("__nv_sqrtf", float32),
(float64,): ("__nv_sqrt", float64),
})
rsqrt = libdevice_extern_elementwise({
(float32,): ("__nv_rsqrtf", float32),
(float64,): ("__nv_rsqrt", float64),
})


class semantic:
Expand Down

0 comments on commit f15cad4

Please sign in to comment.