Skip to content

Commit

Permalink
[MHLO] Add MHLO lowering for PRNG kernels.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 439919104
  • Loading branch information
hawkinsp authored and jax authors committed Apr 6, 2022
1 parent b9bb613 commit 3bfa6af
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 5 deletions.
36 changes: 33 additions & 3 deletions jax/_src/prng.py
Expand Up @@ -26,11 +26,14 @@
from jax.config import config
from jax.dtypes import float0
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
import jax._src.lib
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy.lax_numpy import (
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
_expand_bool_indices, _register_stackable)
Expand Down Expand Up @@ -405,6 +408,26 @@ def _broadcast(x, aval):
ctx.builder, (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))

def _threefry2x32_gpu_lowering(ctx, k1, k2, x1, x2):
  """MHLO lowering rule for threefry2x32_p using the GPU custom-call kernel.

  Args:
    ctx: lowering rule context; ``avals_in``/``avals_out`` describe the four
      operands and the two (identically shaped) results.
    k1, k2: IR values for the two uint32 key words.
    x1, x2: IR values for the two uint32 data words to hash.

  Returns:
    A list of the two result IR values.
  """
  out_aval, _ = ctx.avals_out
  key1_aval, key2_aval, data1_aval, data2_aval = ctx.avals_in
  out_rank = len(out_aval.shape)

  # Degenerate case: with a zero-sized output there is nothing to hash, and
  # the kernel need not be invoked at all.
  if 0 in out_aval.shape:
    zeros = mlir.full_like_aval(0, out_aval)
    return [zeros, zeros]

  def _bcast(operand, operand_aval):
    # Broadcast each operand up to the common output shape, aligning its
    # dimensions with the trailing dimensions of the output.
    bcast_dims = range(out_rank - len(operand_aval.shape), out_rank)
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(out_aval), operand,
        mlir.dense_int_elements(bcast_dims)).result

  # Dispatch to whichever GPU backend module was built into this jaxlib.
  lowering = (cuda_prng.threefry2x32_lowering if cuda_prng
              else hip_prng.threefry2x32_lowering)
  return lowering(
      (_bcast(k1, key1_aval), _bcast(k2, key2_aval)),
      (_bcast(x1, data1_aval), _bcast(x2, data2_aval)))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
Expand All @@ -417,11 +440,18 @@ def _broadcast(x, aval):
xla.register_translation(threefry2x32_p, xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True, new_style=True), platform='cpu')
if cuda_prng:
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True))
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True), platform='cpu')

if cuda_prng or hip_prng:
xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
platform='gpu')
if hip_prng:
xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(threefry2x32_p, _threefry2x32_gpu_lowering,
platform='gpu')

@partial(jit, inline=True)
Expand Down
37 changes: 37 additions & 0 deletions jaxlib/cuda_prng.py
Expand Up @@ -17,6 +17,9 @@
import itertools
import operator

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo

import numpy as np

from jaxlib import xla_client
Expand Down Expand Up @@ -53,3 +56,37 @@ def threefry2x32(c, keys, data):
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)



def threefry2x32_lowering(keys, data):
  """Emits an MHLO custom call to the CUDA ThreeFry2x32 kernel.

  Args:
    keys: pair of ranked uint32 tensor values holding the ThreeFry key words.
    data: pair of tensor values to hash; must have the same type as the keys.

  Returns:
    A list of the two uint32 result tensors, unpacked from the custom call's
    tuple result.
  """
  assert len(keys) == 2, keys
  assert len(data) == 2, data
  # The kernel works on 32-bit words; all four operands must share one type.
  assert (ir.RankedTensorType(keys[0].type).element_type ==
          ir.IntegerType.get_unsigned(32)), keys[0].type
  operand_type = keys[0].type
  shape = ir.RankedTensorType(operand_type).shape

  for operand in itertools.chain(keys, data):
    assert operand.type == operand_type, (operand.type, operand_type)
  rank = len(shape)

  # The opaque descriptor tells the kernel how many elements to process.
  opaque = _cuda_prng.cuda_threefry2x32_descriptor(_prod(shape))
  # Default row-major (major-to-minor) layout for all operands and results.
  minor_to_major = np.arange(rank - 1, -1, -1)
  layout = ir.DenseIntElementsAttr.get(minor_to_major,
                                       type=ir.IndexType.get())
  i32 = ir.IntegerType.get_signless(32)
  out_tuple = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([operand_type, operand_type])],
      [keys[0], keys[1], data[0], data[1]],
      call_target_name=ir.StringAttr.get("cuda_threefry2x32"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      # API version 2 is the status-returning custom-call convention.
      api_version=ir.IntegerAttr.get(i32, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout] * 4),
      result_layouts=ir.ArrayAttr.get([layout] * 2)).result
  return [mhlo.GetTupleElementOp(out_tuple, ir.IntegerAttr.get(i32, i)).result
          for i in range(2)]
2 changes: 1 addition & 1 deletion jaxlib/cusparse.py
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from jax._src.lib import xla_client
from jaxlib import xla_client

try:
from . import _cusparse
Expand Down
35 changes: 35 additions & 0 deletions jaxlib/hip_prng.py
Expand Up @@ -16,6 +16,9 @@
import itertools
import operator

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo

import numpy as np

from jaxlib import xla_client
Expand Down Expand Up @@ -54,3 +57,35 @@ def threefry2x32(c, keys, data):
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)

def threefry2x32_lowering(keys, data):
  """Emits an MHLO custom call to the HIP (ROCm) ThreeFry2x32 kernel.

  Args:
    keys: pair of ranked uint32 tensor values holding the ThreeFry key words.
    data: pair of tensor values to hash; must have the same type as the keys.

  Returns:
    A list of the two uint32 result tensors, unpacked from the custom call's
    tuple result.
  """
  assert len(keys) == 2, keys
  assert len(data) == 2, data
  # The kernel works on 32-bit words; all four operands must share one type.
  assert (ir.RankedTensorType(keys[0].type).element_type ==
          ir.IntegerType.get_unsigned(32)), keys[0].type
  typ = keys[0].type
  dims = ir.RankedTensorType(typ).shape

  for x in itertools.chain(keys, data):
    assert x.type == typ, (x.type, typ)
  ndims = len(dims)

  # The opaque descriptor tells the kernel how many elements to process.
  # Fixed: build it from the HIP module's descriptor to match the
  # "hip_threefry2x32" target below; `cuda_threefry2x32_descriptor` was a
  # copy-paste from cuda_prng.py — TODO(review): confirm against the symbols
  # exported by _hip_prng.
  opaque = _hip_prng.hip_threefry2x32_descriptor(_prod(dims))
  # Default row-major (major-to-minor) layout for all operands and results.
  layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
                                       type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  tup = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([typ, typ])],
      [keys[0], keys[1], data[0], data[1]],
      call_target_name=ir.StringAttr.get("hip_threefry2x32"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      # API version 2 is the status-returning custom-call convention.
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout] * 4),
      result_layouts=ir.ArrayAttr.get([layout] * 2)).result
  # Fixed: GetTupleElementOp infers its result type from the tuple operand
  # (as in the cuda_prng.py twin of this function); the previous extra
  # leading `typ` argument does not match the generated builder's signature.
  return [
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(2)
  ]
2 changes: 1 addition & 1 deletion jaxlib/hipsparse.py
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from jax._src.lib import xla_client
from jaxlib import xla_client

try:
from . import _hipsparse
Expand Down

0 comments on commit 3bfa6af

Please sign in to comment.