From c16b8936002bb8efc02cba28d477591830097c8e Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 17 Oct 2023 11:00:13 -0700 Subject: [PATCH] [pallas:gpu] Simplify `broadcast_to`, `min`, `max` lowering. PiperOrigin-RevId: 574204406 --- jax/_src/pallas/triton/lowering.py | 34 +++++------------------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 2ebbe9e19d8b..cf337e5efb99 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -398,12 +398,15 @@ def _atomic_lowering_rule( lax.ge_p: tl.semantic.greater_equal, lax.lt_p: tl.semantic.less_than, lax.le_p: tl.semantic.less_equal, + lax.max_p: tl.math.max, + lax.min_p: tl.math.min, lax.shift_left_p: tl.semantic.shl, lax.shift_right_arithmetic_p: tl.semantic.ashr, lax.shift_right_logical_p: tl.semantic.lshr, lax.nextafter_p: tl.math.nextafter, ad_util.add_any_p: tl.semantic.add, # Other ops. + indexing.broadcast_to_p: tl.broadcast_to, primitives.atomic_cas_p: tl.atomic_cas, primitives.max_contiguous_p: tl.max_contiguous, primitives.multiple_of_p: tl.multiple_of, @@ -424,7 +427,8 @@ def rule(ctx, *args, fn=fn, **kwargs): def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max): - return _min_lowering_rule(ctx, max_lowering_rule(ctx, min, operand), max) + operand = tl.math.max(operand, min, _builder=ctx.builder) + return tl.math.min(operand, max, _builder=ctx.builder) triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule @@ -476,14 +480,6 @@ def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y): triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule -def _min_lowering_rule(ctx: TritonLoweringRuleContext, a, b): - pred = a.__lt__(b, _builder=ctx.builder) - return tl.semantic.where(pred, a, b, ctx.builder) - - -triton_lowering_rules[lax.min_p] = _min_lowering_rule - - def _convert_element_type_lowering_rule( ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type ): @@ -497,14 +493,6 @@ def _convert_element_type_lowering_rule( ) -def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b): - pred = a.__gt__(b, _builder=ctx.builder) - return tl.semantic.where(pred, a, b, ctx.builder) - - -triton_lowering_rules[lax.max_p] = max_lowering_rule - - def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b): return tl.semantic.where(pred, b, a, ctx.builder) @@ -529,18 +517,6 @@ def _broadcast_in_dim_lowering_rule( ) -def _broadcast_to_lowering_rule( - ctx: TritonLoweringRuleContext, a, *, shape -): - shape = map(tl.constexpr, shape) - return tl.broadcast_to(a, shape, _builder=ctx.builder) - - -triton_lowering_rules[indexing.broadcast_to_p] = ( - _broadcast_to_lowering_rule -) - - def _squeeze_lowering_rule(ctx: TritonLoweringRuleContext, a, *, dimensions): del dimensions return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None)