Added a GPU-specific approximate tanh to Pallas
superbobry authored and jax authors committed Apr 26, 2024
1 parent 94766b8 commit 268b39d
Showing 6 changed files with 122 additions and 2 deletions.
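
For orientation, the new primitive is exported as plgpu.approx_tanh and is meant to be called inside a Pallas kernel body on a CUDA GPU, mirroring the test added below. A minimal usage sketch, assuming a GPU build of jaxlib; the kernel name, wrapper function, and array shape are illustrative, not part of the commit:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu

# Illustrative kernel: applies the approximate tanh elementwise to one block.
def tanh_kernel(x_ref, o_ref):
  o_ref[...] = plgpu.approx_tanh(x_ref[...])

@jax.jit
def fast_tanh(x):
  # Single-program launch; the whole array is treated as one block.
  return pl.pallas_call(
      tanh_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)

x = jnp.linspace(-2.0, 2.0, 8, dtype=jnp.float32)
print(fast_tanh(x))  # Agrees with jnp.tanh(x) to within roughly 5e-3 (see the test below).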
19 changes: 19 additions & 0 deletions jax/_src/pallas/triton/BUILD
@@ -35,10 +35,29 @@ py_library_providing_imports_info(
    deps = [
        ":lowering",
        ":pallas_call_registration",
        ":primitives",
        "//jax/_src/lib",
    ],
)

pytype_strict_library(
    name = "primitives",
    srcs = ["primitives.py"],
    deps = [
        ":lowering",
        "//jax",
        "//jax:ad_util",
        "//jax:api_util",
        "//jax:core",
        "//jax:mlir",
        "//jax:partial_eval",
        "//jax:source_info_util",
        "//jax:util",
        "//jax/_src/lib",
        "//jax/_src/pallas",
    ] + py_deps("numpy"),
)

pytype_strict_library(
    name = "lowering",
    srcs = ["lowering.py"],
1 change: 1 addition & 0 deletions jax/_src/pallas/triton/__init__.py
@@ -15,6 +15,7 @@
"""Contains Triton-specific pallas modules."""

from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.pallas.triton.primitives import approx_tanh


try:
12 changes: 10 additions & 2 deletions jax/_src/pallas/triton/lowering.py
@@ -21,7 +21,7 @@
import functools
import math
import operator
from typing import Any, Callable
from typing import Any, Callable, TypeVar

import jax
from jax import lax
@@ -56,11 +56,12 @@
import jax.numpy as jnp
import numpy as np


# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors
# pytype: skip-file

_T = TypeVar("_T")

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

@@ -167,6 +168,13 @@ def _bcast(
triton_lowering_rules = {}


def register_lowering(primitive: jax_core.Primitive) -> Callable[[_T], _T]:
  def wrapper(fn):
    triton_lowering_rules[primitive] = fn
    return fn
  return wrapper


def _process_grid_to_3d_grid(grid_mapping: GridMapping):
  launch_grid = []
  launch_grid_to_pallas_grid = []
68 changes: 68 additions & 0 deletions jax/_src/pallas/triton/primitives.py
@@ -0,0 +1,68 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for GPU-specific JAX primitives."""

from __future__ import annotations

import jax
from jax import core as jax_core
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas.triton import lowering
import jax.numpy as jnp


def approx_tanh(x: jax.typing.ArrayLike) -> jax.Array:
  r"""Elementwise approximate hyperbolic tangent: :math:`\mathrm{tanh}(x)`.

  See
  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-tanh.
  """
  return approx_tanh_p.bind(x)


approx_tanh_p = jax_core.Primitive("approx_tanh")


@approx_tanh_p.def_abstract_eval
def _approx_tanh_abstract_eval(
    x_aval: jax_core.ShapedArray,
) -> jax_core.ShapedArray:
  if jnp.dtype(x_aval.dtype) not in (jnp.float16, jnp.bfloat16, jnp.float32):
    raise TypeError(f"approx_tanh does not accept {x_aval.dtype} arrays")
  return x_aval


@lowering.register_lowering(approx_tanh_p)
def _approx_tanh_lowering(ctx: lowering.LoweringRuleContext, x):
  [x_aval] = ctx.avals_in
  if x_aval.dtype == jnp.float16:
    asm = "tanh.approx.f16 $0, $1;"
    constraint = "h"
  elif x_aval.dtype == jnp.bfloat16:
    asm = "tanh.approx.bf16 $0, $1;"
    constraint = "h"
  elif x_aval.dtype == jnp.float32:
    asm = "tanh.approx.f32 $0, $1;"
    constraint = "f"
  else:
    raise NotImplementedError(f"Unsupported dtype: {x_aval.dtype}")
  return tt_dialect.elementwise_inline_asm(
      [x.type],
      asm,
      constraints=f"={constraint},{constraint}",
      pure=True,
      packed_element=1,
      args=[x],
  )
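
The register_lowering hook added to lowering.py above is what ties approx_tanh_p to this Triton lowering. As a sketch of how the same pattern could be reused for other single-instruction PTX wrappers, here is a hypothetical approx_exp2 primitive lowered to ex2.approx.f32; it is not part of this commit, the approx_exp2 names are made up, and the constraint letters follow PTX inline-asm conventions ("f" selects a 32-bit float register, "h" a 16-bit register as used for float16/bfloat16 above):

import jax
from jax import core as jax_core
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas.triton import lowering
import jax.numpy as jnp

# Hypothetical follow-on primitive, mirroring the approx_tanh definitions above.
approx_exp2_p = jax_core.Primitive("approx_exp2")

def approx_exp2(x: jax.typing.ArrayLike) -> jax.Array:
  """Elementwise approximate base-2 exponential (illustrative example)."""
  return approx_exp2_p.bind(x)

@approx_exp2_p.def_abstract_eval
def _approx_exp2_abstract_eval(
    x_aval: jax_core.ShapedArray,
) -> jax_core.ShapedArray:
  if jnp.dtype(x_aval.dtype) != jnp.float32:
    raise TypeError(f"approx_exp2 does not accept {x_aval.dtype} arrays")
  return x_aval

@lowering.register_lowering(approx_exp2_p)
def _approx_exp2_lowering(ctx: lowering.LoweringRuleContext, x):
  # Same shape as _approx_tanh_lowering: one input, one output, pure inline asm.
  return tt_dialect.elementwise_inline_asm(
      [x.type],
      "ex2.approx.f32 $0, $1;",
      constraints="=f,f",
      pure=True,
      packed_element=1,
      args=[x],
  )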
1 change: 1 addition & 0 deletions jax/experimental/pallas/gpu.py
@@ -14,5 +14,6 @@

"""Contains Triton specific Pallas functions."""
from jax._src.pallas import triton
from jax._src.pallas.triton import approx_tanh
get_compute_capability = triton.get_compute_capability
del triton
23 changes: 23 additions & 0 deletions tests/pallas/pallas_test.py
@@ -1670,6 +1670,29 @@ def kernel(o_ref):

    np.testing.assert_allclose(f(), kernel())

  @parameterized.parameters("float16", "bfloat16", "float32")
  def test_approx_tanh(self, dtype):
    if self.INTERPRET:
      self.skipTest("approx_tanh is not supported in interpreter mode")
    if dtype == "bfloat16" and not self.check_gpu_capability_at_least(90):
      self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = plgpu.approx_tanh(x_ref[...])

    x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype)
    # We upcast to float32 because NumPy <2.0 does not handle custom dtypes
    # properly. See https://github.com/google/jax/issues/11014.
    np.testing.assert_allclose(
        kernel(x).astype(jnp.float32),
        jnp.tanh(x).astype(jnp.float32),
        atol=5e-3,
        rtol=5e-3,
    )


class PallasOpsInterpretTest(PallasOpsTest):
  INTERPRET = True
