
Commit

[pallas:gpu] Lower more complex primitives using JAX functions in terms of more basic primitives.

PiperOrigin-RevId: 575883386
chr1sj0nes authored and jax authors committed Oct 23, 2023
1 parent 8fa287e commit b61af5a
Showing 1 changed file with 28 additions and 27 deletions.
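
The change replaces three hand-written Triton lowering rules (for lax.clamp_p, lax.integer_pow_p and lax.logistic_p) with rules derived from plain JAX functions: the new lower_fun helper traces such a function to a jaxpr on the rule's input avals and reuses the existing closed-call lowering. As a rough sketch of the resulting extension pattern (hypothetical example, not part of this commit; it assumes lax.expm1_p has no dedicated rule and that the primitives it decomposes into already lower):

    # Hypothetical: register a Triton lowering for lax.expm1_p purely in
    # terms of more basic primitives that already have lowering rules.
    _JAX_FN_MAPPING[lax.expm1_p] = lambda a: jnp.exp(a) - 1.0
    triton_lowering_rules[lax.expm1_p] = lower_fun(
        _JAX_FN_MAPPING[lax.expm1_p], multiple_results=False
    )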
55 changes: 28 additions & 27 deletions jax/_src/pallas/triton/lowering.py
@@ -18,7 +18,7 @@
 import dataclasses
 import functools
 import operator
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Callable, Dict, Sequence, Tuple
 import zlib
 
 import jax
@@ -432,23 +432,39 @@ def rule(ctx, *args, fn=fn, **kwargs):
   triton_lowering_rules[primitive] = rule
 
 
-def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
-  operand = tl.math.max(operand, min, _builder=ctx.builder)
-  return tl.math.min(operand, max, _builder=ctx.builder)
+def _integer_pow(a, *, y):
+  if y == 2:
+    return a * a
+  if y == 3:
+    return a * a * a
+  if y == -2:
+    return 1.0 / (a * a)
+  return jax.lax.pow(a, y)
 
 
-triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
+def lower_fun(
+    fun: Callable[..., Any], *, multiple_results: bool
+) -> Callable[..., Any]:
+  fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
 
+  def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
+    wrapped_fun = lu.wrap_init(fn, params)
+    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
+    out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
+    return out if multiple_results else out[0]
 
-def _logistic_lowering_rule(ctx: TritonLoweringRuleContext, a):
-  one_ = tl.core._to_tensor(1.0, ctx.builder)
-  x = tl.exp(a.__neg__(_builder=ctx.builder), _builder=ctx.builder)
-  x = x.__add__(one_, _builder=ctx.builder)
-  x = one_.__truediv__(x, _builder=ctx.builder)
-  return x
+  return f_lowered
 
 
-triton_lowering_rules[lax.logistic_p] = _logistic_lowering_rule
+_JAX_FN_MAPPING = {
+    lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
+    lax.integer_pow_p: _integer_pow,
+    lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
+}
 
+for primitive, fn in _JAX_FN_MAPPING.items():
+  triton_lowering_rules[primitive] = lower_fun(fn, multiple_results=False)
+
 
 def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
@@ -471,21 +487,6 @@ def _iota_lowering_rule(
 triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
 
 
-def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
-  if y == 2:
-    return a.__mul__(a, _builder=ctx.builder)
-  if y == 3:
-    return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
-  if y == -2:
-    one_ = tl.core._to_tensor(1.0, ctx.builder)
-    a_sq = a.__mul__(a, _builder=ctx.builder)
-    return one_.__truediv__(a_sq, _builder=ctx.builder)
-  return tl.math.pow(a, y, _builder=ctx.builder)
-
-
-triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
-
-
 def _convert_element_type_lowering_rule(
     ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
 ):

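As a quick sanity check of the substituted implementations (illustrative only; evaluated with plain JAX rather than through the Pallas Triton lowering), the functions in _JAX_FN_MAPPING agree numerically with the primitives they stand in for:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.linspace(-3.0, 3.0, 7)
    # logistic_p: 1 / (1 + exp(-x)) matches lax.logistic.
    assert jnp.allclose(1 / (1 + jnp.exp(-x)), lax.logistic(x))
    # clamp_p: the min/max composition matches lax.clamp(min, x, max).
    assert jnp.allclose(jnp.minimum(jnp.maximum(-1.0, x), 1.0),
                        lax.clamp(-1.0, x, 1.0))
    # integer_pow_p: the special-cased small powers match lax.integer_pow.
    assert jnp.allclose(x * x * x, lax.integer_pow(x, 3))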