Skip to content

Commit

Permalink
[pallas:gpu] Simplify broadcast_to, min, max lowering.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574204406
  • Loading branch information
chr1sj0nes authored and jax authors committed Oct 17, 2023
1 parent 2c9ea51 commit c16b893
Showing 1 changed file with 5 additions and 29 deletions.
34 changes: 5 additions & 29 deletions jax/_src/pallas/triton/lowering.py
Expand Up @@ -398,12 +398,15 @@ def _atomic_lowering_rule(
lax.ge_p: tl.semantic.greater_equal,
lax.lt_p: tl.semantic.less_than,
lax.le_p: tl.semantic.less_equal,
lax.max_p: tl.math.max,
lax.min_p: tl.math.min,
lax.shift_left_p: tl.semantic.shl,
lax.shift_right_arithmetic_p: tl.semantic.ashr,
lax.shift_right_logical_p: tl.semantic.lshr,
lax.nextafter_p: tl.math.nextafter,
ad_util.add_any_p: tl.semantic.add,
# Other ops.
indexing.broadcast_to_p: tl.broadcast_to,
primitives.atomic_cas_p: tl.atomic_cas,
primitives.max_contiguous_p: tl.max_contiguous,
primitives.multiple_of_p: tl.multiple_of,
Expand All @@ -424,7 +427,8 @@ def rule(ctx, *args, fn=fn, **kwargs):


def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
return _min_lowering_rule(ctx, max_lowering_rule(ctx, min, operand), max)
operand = tl.math.max(operand, min, _builder=ctx.builder)
return tl.math.min(operand, max, _builder=ctx.builder)


triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
Expand Down Expand Up @@ -476,14 +480,6 @@ def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule


def _min_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
pred = a.__lt__(b, _builder=ctx.builder)
return tl.semantic.where(pred, a, b, ctx.builder)


triton_lowering_rules[lax.min_p] = _min_lowering_rule


def _convert_element_type_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
):
Expand All @@ -497,14 +493,6 @@ def _convert_element_type_lowering_rule(
)


def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
pred = a.__gt__(b, _builder=ctx.builder)
return tl.semantic.where(pred, a, b, ctx.builder)


triton_lowering_rules[lax.max_p] = max_lowering_rule


def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
return tl.semantic.where(pred, b, a, ctx.builder)

Expand All @@ -529,18 +517,6 @@ def _broadcast_in_dim_lowering_rule(
)


def _broadcast_to_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, shape
):
shape = map(tl.constexpr, shape)
return tl.broadcast_to(a, shape, _builder=ctx.builder)


triton_lowering_rules[indexing.broadcast_to_p] = (
_broadcast_to_lowering_rule
)


def _squeeze_lowering_rule(ctx: TritonLoweringRuleContext, a, *, dimensions):
del dimensions
return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None)
Expand Down

0 comments on commit c16b893

Please sign in to comment.