Skip to content

Commit

Permalink
[MHLO] Migrate GPU select_and_gather_add translation rule to MHLO.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 447472634
  • Loading branch information
hawkinsp authored and jax authors committed May 9, 2022
1 parent 4fa4872 commit e94221b
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 62 deletions.
137 changes: 77 additions & 60 deletions jax/_src/lax/windowed_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np

import jax._src.lib
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
Expand All @@ -34,17 +35,11 @@
import jax._src.lax.slicing as slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
import jax._src.util as util

map = util.safe_map
zip = util.safe_zip

xb = xla_bridge
xc = xla_client
xops = xla_client.ops

Array = Any


Expand Down Expand Up @@ -627,49 +622,66 @@ def _select_and_gather_add_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)

def _select_and_gather_add_translation(
ctx, avals_in, avals_out, tangents, operand, *, select_prim,
def _select_and_gather_add_lowering(
ctx, tangents, operand, *, select_prim,
window_dimensions, window_strides, padding, base_dilation, window_dilation,
max_bits=64):
c = ctx.builder
tangents_aval, operand_aval, = avals_in
_, operand_aval, = ctx.avals_in
out_aval, = ctx.avals_out
dtype = operand_aval.dtype
etype = xla.dtype_to_primitive_type(dtype)
etype = mlir.dtype_to_ir_type(dtype)
nbits = dtypes.finfo(dtype).bits

assert nbits <= max_bits
double_word_reduction = nbits * 2 <= max_bits

const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))
const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
canonicalize_types=False)

if jax._src.lib.mlir_api_version >= 9:
def _broadcast(x, dims):
return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
else:
def _broadcast(x, dims):
etype = ir.RankedTensorType(x.type).element_type
return mhlo.BroadcastOp(ir.RankedTensorType(dims, etype), x,
mlir.dense_int_elements(dims))

if double_word_reduction:
# TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
# we implement a pair-wise ReduceWindow by packing two k-bit values into
# 2k-bit unsigned integer using bit tricks.
word_dtype = lax._UINT_DTYPES[nbits]
double_word_dtype = lax._UINT_DTYPES[nbits * 2]
word_type = xla.dtype_to_primitive_type(word_dtype)
double_word_type = xla.dtype_to_primitive_type(double_word_dtype)
word_type = mlir.dtype_to_ir_type(word_dtype)
double_word_type = mlir.dtype_to_ir_type(double_word_dtype)

# Packs two values into a tuple.
def pack(a, b):
a = xops.BitcastConvertType(a, word_type)
b = xops.BitcastConvertType(b, word_type)
a = xops.ConvertElementType(a, double_word_type)
b = xops.ConvertElementType(b, double_word_type)
a = xops.ShiftLeft(a, const(c, double_word_dtype, nbits))
return xops.Or(a, b)
a_dims = ir.RankedTensorType(a.type).shape
b_dims = ir.RankedTensorType(b.type).shape
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
a = mhlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a)
b = mhlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b)
a = mhlo.ShiftLeftOp(a,
_broadcast(const(double_word_dtype, nbits), a_dims))
return mhlo.OrOp(a, b)

# Unpacks the first element of a tuple.
def fst(c, t):
st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
return xops.BitcastConvertType(xops.ConvertElementType(st, word_type),
etype)
def fst(t):
dims = ir.RankedTensorType(t.type).shape
st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
return mhlo.BitcastConvertOp(
ir.RankedTensorType.get(dims, etype),
mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result

# Unpacks the second element of a tuple.
def snd(t):
return xops.BitcastConvertType(xops.ConvertElementType(t, word_type),
etype)
dims = ir.RankedTensorType(t.type).shape
return mhlo.BitcastConvertOp(
ir.RankedTensorType.get(dims, etype),
mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result

else:
# The double-word trick above only works if we have a sufficiently large
Expand All @@ -686,46 +698,55 @@ def snd(t):
nmant = r_nbits - nexp - 1

double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
word_type = xla.dtype_to_primitive_type(word_dtype)
double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype)

# Packs two values into a tuple.
def pack(a, b):
a = xops.ReducePrecision(a, exponent_bits=nexp, mantissa_bits=nmant)
b = xops.ReducePrecision(b, exponent_bits=nexp, mantissa_bits=nmant)
a = xops.BitcastConvertType(a, word_type)
b = xops.BitcastConvertType(b, word_type)
b = xops.ShiftRightLogical(b, const(c, word_dtype, r_nbits))
return xops.Or(a, b)
a_dims = ir.RankedTensorType(a.type).shape
b_dims = ir.RankedTensorType(b.type).shape
a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
b = mhlo.ShiftRightLogicalOp(
b, _broadcast(const(word_dtype, r_nbits), b_dims))
return mhlo.OrOp(a, b)

# Unpacks the first element of a tuple.
def fst(c, t):
st = xops.And(t, const(c, word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return xops.BitcastConvertType(st, etype)
def fst(t):
st = mhlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
st).result

# Unpacks the second element of a tuple.
def snd(t):
return xops.BitcastConvertType(
xops.ShiftLeft(t, const(c, word_dtype, r_nbits)), etype)

def reducer():
c = xc.XlaBuilder("select_and_gather_pair_reducer")
x = xla.parameter(c, 0,
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
y = xla.parameter(c, 1,
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
assert select_prim is lax.ge_p or select_prim is lax.le_p
which = xops.Ge if select_prim is lax.ge_p else xops.Le
xops.Select(which(fst(c, x), fst(c, y)), x, y)
return c.build()

dims = ir.RankedTensorType(t.type).shape
return mhlo.BitcastConvertOp(
ir.RankedTensorType.get(dims, etype),
mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims))
).result

assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
init = -np.inf if select_prim is lax.ge_p else np.inf
out = xops.ReduceWindowWithGeneralPadding(
pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)),
reducer(), window_dimensions, window_strides, base_dilation,
window_dilation, padding)
return [snd(out)]
rw = mhlo.ReduceWindowOp(
[ir.RankedTensorType.get(out_aval.shape, double_word_type)],
pack(operand, tangents), pack(const(dtype, init), const(dtype, 0)),
mlir.dense_int_elements(window_dimensions),
mlir.dense_int_elements(window_strides),
mlir.dense_int_elements(base_dilation),
mlir.dense_int_elements(window_dilation),
ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
scalar_type = ir.RankedTensorType.get([], double_word_type)
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
x, y = reducer.arguments
assert select_prim is lax.ge_p or select_prim is lax.le_p
which = "GE" if select_prim is lax.ge_p else "LE"
out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
mhlo.ReturnOp(out)
return [snd(rw.result)]

# TODO(phawkins): use this translation rule on all platforms.
def _select_and_gather_add_using_variadic_reducewindow(
Expand Down Expand Up @@ -822,11 +843,7 @@ def _select_and_gather_add_batching_rule(
multiple_results=False))

# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
# The double-word-packing lowering above is only needed on GPU; other
# platforms use the variadic-reducewindow rule registered elsewhere.
mlir.register_lowering(
    select_and_gather_add_p,
    _select_and_gather_add_lowering,
    platform="gpu")
12 changes: 10 additions & 2 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,10 +1047,18 @@ def add_jaxvals_lowering(ctx, x, y):
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])


def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None):
  """Creates an mhlo.CompareOp comparing ``x`` and ``y``.

  Args:
    x: an MLIR value of ranked tensor type.
    y: an MLIR value of the same type as ``x``.
    direction: the comparison direction, one of "EQ", "NE", "GE", "GT",
      "LE", "LT".
    comparison_type: "SIGNED", "UNSIGNED", or "FLOAT". If None, it is
      inferred from the element type of ``x``: unsigned integers compare
      as "UNSIGNED", signed integers as "SIGNED", everything else as
      "FLOAT". (Renamed from ``type`` to avoid shadowing the builtin.)

  Returns:
    The constructed mhlo.CompareOp.
  """
  if comparison_type is None:
    elem_type = ir.RankedTensorType(x.type).element_type
    if ir.IntegerType.isinstance(elem_type):
      comparison_type = ("UNSIGNED" if ir.IntegerType.is_unsigned(elem_type)
                         else "SIGNED")
    else:
      comparison_type = "FLOAT"
  return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
                        mhlo.ComparisonTypeAttr.get(comparison_type))

def _minmax_mhlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
Expand Down

0 comments on commit e94221b

Please sign in to comment.