add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
mattjj committed Apr 14, 2022
1 parent 86c8446 commit d21b958
Showing 7 changed files with 102 additions and 21 deletions.
19 changes: 14 additions & 5 deletions jax/_src/dispatch.py
@@ -228,7 +228,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
 
   if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
-      backend.platform != 'iree'):
+      not _backend_supports_unbounded_dynamic_shapes(backend)):
     jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)
 
   # Computations that only produce constants and/or only rearrange their inputs,
@@ -281,6 +281,10 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
       in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
 
 
+def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
+  return backend.platform == 'iree'
+
+
 def prefetch(x):
   if isinstance(x, device_array.DeviceArray):
     x.copy_to_host_async()
@@ -408,12 +412,14 @@ def aval_to_num_buffers(aval: core.AbstractValue) -> int:
 num_buffers_handlers[core.ConcreteArray] = lambda _: 1
 
 
-def _input_handler(which_explicit: Optional[Sequence[bool]],
+def _input_handler(backend: Backend,
+                   which_explicit: Optional[Sequence[bool]],
                    in_avals: Sequence[core.AbstractValue]
                    ) -> Optional[Callable]:
   # Extract implicit inputs, and pad bounded-size inputs to their max size.
   needs_implicit = which_explicit and not all(which_explicit)
-  needs_padding = any(type(in_avals[d.val]) is core.AbstractBInt  # type: ignore
+  needs_padding = any(backend.platform != 'iree' and
+                      type(in_avals[d.val]) is core.AbstractBInt  # type: ignore
                       for a in in_avals if type(a) is core.DShapedArray
                       for d in a.shape if type(d) is pe.DBIdx)
 
@@ -448,13 +454,16 @@ def elaborate_and_pad(explicit_args):
     explicit_args_ = iter(explicit_args)
     args = [next(explicit_args_) if ex else None for ex in which_explicit]
     assert next(explicit_args_, None) is None
+    assert needs_implicit
    for i, j, k in implicit_args_from_axes:
       if args[i] is None:
         args[i] = args[j].shape[k]  # type: ignore
       else:
         if args[i] != args[j].shape[k]:
           raise Exception("inconsistent argument axis sizes for type")
-    return tuple([pad(x) if pad else x for x, pad in zip(args, padders)])
+    if needs_padding:
+      args = tuple([pad(x) if pad else x for x, pad in zip(args, padders)])
+    return args
   return elaborate_and_pad
 
 def _pad_arg(shape, x):
@@ -686,7 +695,7 @@ def from_xla_computation(
       out_avals: Sequence[core.AbstractValue],
       kept_var_idx: Set[int]) -> XlaCompiledComputation:
     sticky_device = device
-    input_handler = _input_handler(explicit_args, in_avals)
+    input_handler = _input_handler(backend, explicit_args, in_avals)
     result_handlers = map(partial(aval_to_result_handler, sticky_device),
                           out_avals)
     options = xb.get_compile_options(
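For readers skimming the diff, here is a minimal standalone sketch of the new dispatch gating: backends that cannot execute unbounded dynamic shapes get their bounded-size jaxprs padded via `pe.pad_jaxpr`, while IREE compiles the dynamic program directly. The helper below restates `_backend_supports_unbounded_dynamic_shapes` in plain Python; the platform strings are illustrative.

```python
# A sketch of the gating added in lower_xla_callable above.
def backend_supports_unbounded_dynamic_shapes(platform: str) -> bool:
  # Mirrors _backend_supports_unbounded_dynamic_shapes: only IREE today.
  return platform == 'iree'

for platform in ('cpu', 'gpu', 'iree'):
  if backend_supports_unbounded_dynamic_shapes(platform):
    print(platform, '-> compile dynamic shapes natively')
  else:
    print(platform, '-> pad bounded-size values to their maxima (pe.pad_jaxpr)')
```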
21 changes: 17 additions & 4 deletions jax/_src/iree.py
@@ -20,7 +20,9 @@
 
 # pytype: skip-file
 
-from typing import Any, List, Sequence
+from __future__ import annotations
+
+from typing import Any, List, Sequence, Optional
 
 import iree.compiler
 from iree import runtime as iree_runtime
@@ -48,7 +50,7 @@ def transfer_to_infeed(self, literal: Any):
   def transfer_from_outfeed(self, shape: xla_client.Shape):
     raise NotImplementedError("transfer_to_outfeed")
 
-  def live_buffers(self) -> List['IreeBuffer']:
+  def live_buffers(self) -> List[IreeBuffer]:
     raise NotImplementedError("live_buffers")
 
 
@@ -57,6 +59,7 @@ class IreeBuffer(xla_client.DeviceArrayBase):
   def __init__(self, client, device, npy_value):
     self.client = client
     self._device = device
+    assert device is not None
     self._npy_value = np.asarray(npy_value)
 
   def copy_to_device(self, device):
@@ -74,6 +77,9 @@ def platform(self):
   def device(self):
     return self._device
 
+  def block_until_ready(self) -> IreeBuffer:
+    return self  # no async
+
 class IreeExecutable:
 
   def __init__(self, client, devices, module_object, function_name):
@@ -136,8 +142,12 @@ def get_default_device_assignment(
 
   def compile(self, computation: str,
               compile_options: xla_client.CompileOptions) -> IreeExecutable:
+    del compile_options  # Ignored.
     iree_binary = iree.compiler.compile_str(
-        computation, target_backends=["dylib"], input_type="mhlo")
+        computation, target_backends=["dylib"], input_type="mhlo",
+        # extra_args=["--print-ir-after-all"],
+        # extended_diagnostics=True,
+    )
     # Load it into the runtime.
     vm_module = iree_runtime.binding.VmModule.from_flatbuffer(iree_binary)
     module_object = iree_runtime.load_vm_module(vm_module, self.iree_config)
@@ -146,13 +156,16 @@ def compile(self, computation: str,
   def buffer_from_pyval(
       self,
       argument: Any,
-      device: IreeDevice,
+      device: Optional[IreeDevice],
       force_copy: bool = True,
       host_buffer_semantics: xla_client.HostBufferSemantics = xla_client
      .HostBufferSemantics.ZERO_COPY
   ) -> IreeBuffer:
     # TODO(phawkins): IREE's python API will accept a numpy array directly but
     # may want to explicitly construct a lower level BufferView to avoid copies.
+    if device is None:
+      assert type(argument) is np.ndarray
+      device = self._devices[0]
     return IreeBuffer(self, device, np.array(argument, copy=True))
 
 
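The `compile` path above can be exercised on its own. The sketch below assumes the 2022-era `iree-compiler`/`iree-runtime` Python packages; the MHLO assembly and the `Config("dylib")` driver name are illustrative assumptions, not taken from this commit.

```python
# A hedged sketch of IreeClient.compile: compile an MHLO module to a dylib
# flatbuffer, then load it into the IREE runtime (API names as of 2022).
import iree.compiler
from iree import runtime as iree_runtime

MHLO_ASM = """
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "mhlo.add"(%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
"""

iree_binary = iree.compiler.compile_str(
    MHLO_ASM, target_backends=["dylib"], input_type="mhlo")
config = iree_runtime.Config("dylib")  # assumed driver name
vm_module = iree_runtime.binding.VmModule.from_flatbuffer(iree_binary)
module = iree_runtime.load_vm_module(vm_module, config)
```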
23 changes: 17 additions & 6 deletions jax/_src/lax/lax.py
@@ -2784,13 +2784,23 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
 pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
 pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
 
-def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions):
-  del shape
+def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
   aval_out, = ctx.avals_out
-  return mhlo.BroadcastInDimOp(
-      mlir.aval_to_ir_type(aval_out), x,
-      mlir.dense_int_elements(broadcast_dimensions)
-  ).results
+  if dyn_shape:
+    dyn_shape = iter(dyn_shape)
+    shape = [next(dyn_shape) if d is None else d for d in shape]
+    assert next(dyn_shape, None) is None
+    return mhlo.DynamicBroadcastInDimOp(
+        mlir.aval_to_ir_type(aval_out), x,
+        mlir.shape_tensor(shape),
+        mlir.dense_int_elements(broadcast_dimensions),
+        None, None,
+    ).results
+  else:
+    return mhlo.BroadcastInDimOp(
+        mlir.aval_to_ir_type(aval_out), x,
+        mlir.dense_int_elements(broadcast_dimensions)
+    ).results
 mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
 
 
@@ -4485,6 +4495,7 @@ def bint_abstract_eval(_, *, bd: int):
   return core.AbstractBInt(bound=bd)
 
 pe.padding_rules[bint_p] = lambda _, __, i, bd: [i]
+mlir.register_lowering(bint_p, lambda ctx, x, bd: [x])
 
 
 ### util
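The dynamic branch in `_broadcast_in_dim_lower` fires when a traced size reaches `broadcast_in_dim`, e.g. through `jnp.ones`. A sketch of the user-level trigger, assuming the `jax_dynamic_shapes` config flag seen in dispatch.py is enabled:

```python
# A sketch of code that reaches the DynamicBroadcastInDimOp branch above.
import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

@jax.jit
def f(i):
  # jnp.ones(i) stages broadcast_in_dim with a dyn_shape operand, so the
  # lowering emits mhlo.DynamicBroadcastInDimOp instead of the static op.
  return jnp.sum(jnp.ones(i, dtype='float32'))
```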
9 changes: 8 additions & 1 deletion jax/_src/lax/slicing.py
@@ -812,7 +812,6 @@ def _slice_masking_rule(
 masking.masking_rules[slice_p] = _slice_masking_rule
 
 def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
-  aval_out, = ctx.avals_out
   strides = strides or [1] * len(start_indices)
   return mhlo.SliceOp(x,
                       mlir.dense_int_elements(start_indices),
@@ -2142,3 +2141,11 @@ def _getslice_padding_rule(in_avals, out_avals, x, lo, hi):
   xx = lax.concatenate([x, x], 0)
   return [dynamic_slice_in_dim(xx, lo, x.shape[0])]
 pe.padding_rules[getslice_p] = _getslice_padding_rule
+
+def _getslice_lower(ctx, x, lo, hi):
+  aval_out, = ctx.avals_out
+  return mhlo.RealDynamicSliceOp(
+      mlir.aval_to_ir_type(aval_out), x,
+      mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
+  ).results
+mlir.register_lowering(getslice_p, _getslice_lower)
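`test_slicing_basic` in tests/api_test.py below exercises this new rule: slicing by a traced bound stages `getslice_p`, which previously had no MLIR lowering. A sketch of the user-level trigger:

```python
# x[:n] with a traced n stages getslice_p, now lowered via
# mhlo.RealDynamicSliceOp (currently skipped on IREE per the test TODO).
import jax
import jax.numpy as jnp

f = jax.jit(lambda x, n: jnp.sum(x[:n]))
```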
15 changes: 12 additions & 3 deletions jax/_src/numpy/lax_numpy.py
@@ -41,7 +41,7 @@
 from jax import core
 from jax import errors
 from jax import lax
-from jax.core import ShapedArray, DShapedArray, ConcreteArray, canonicalize_shape
+from jax.core import ShapedArray, DShapedArray, ConcreteArray
 from jax.interpreters import pxla
 from jax.tree_util import tree_leaves, tree_flatten, tree_map
 
@@ -81,6 +81,15 @@
 
 newaxis = None
 
+# Like core.canonicalize_shape, but also accept int-like (non-sequence)
+# arguments for `shape`.
+def canonicalize_shape(
+    shape: Union[core.Shape, int, core.Tracer], context: str="") -> core.Shape:
+  if isinstance(shape, core.Tracer) or ndim(shape) == 0:
+    return core.canonicalize_shape((shape,), context)
+  else:
+    return core.canonicalize_shape(shape, context)  # type: ignore
+
 # Common docstring additions:
 
 _PRECISION_DOC = """\
@@ -1923,15 +1932,15 @@ def zeros(shape, dtype=None):
   if isinstance(shape, types.GeneratorType):
     raise TypeError("expected sequence object with len >= 0 or a single integer")
   lax_internal._check_user_dtype_supported(dtype, "zeros")
-  shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
+  shape = canonicalize_shape(shape)
   return lax.full(shape, 0, _jnp_dtype(dtype))
 
 @_wraps(np.ones)
 def ones(shape, dtype=None):
   if isinstance(shape, types.GeneratorType):
     raise TypeError("expected sequence object with len >= 0 or a single integer")
+  shape = canonicalize_shape(shape)
   lax_internal._check_user_dtype_supported(dtype, "ones")
-  shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
   return lax.full(shape, 1, _jnp_dtype(dtype))
 
 
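The net effect is that `jnp.zeros` and `jnp.ones` now canonicalize int-like, sequence, and traced shapes uniformly; for example:

```python
import jax.numpy as jnp

jnp.zeros(3)        # int-like shape, same as jnp.zeros((3,))
jnp.ones((2, 3))    # sequence shapes are unchanged
# Under jit with dynamic shapes enabled, a Tracer for the size is also
# accepted; that is what the new core.Tracer branch handles.
```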
11 changes: 9 additions & 2 deletions jax/interpreters/mlir.py
@@ -83,6 +83,14 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
 def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
 def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
 
+def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
+                 ) -> ir.RankedTensorType:
+  sizes = [ir_constant(np.array(d, np.dtype('int32'))) if type(d) is int else d
+           for d in sizes]
+  int1d = aval_to_ir_type(core.ShapedArray((1,), np.dtype('int32')))
+  return mhlo.ConcatenateOp([mhlo.ReshapeOp(int1d, d) for d in sizes],
+                            i64_attr(0)).results
+
 
 # IR Types
 
@@ -666,8 +674,7 @@ def aval_to_types(aval):
   if not use_sharding_annotations and ir_arg_shardings is not None:
     flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)
 
-  unflattened_args = util.unflatten(flat_args,
-                                    map(len, input_types))
+  unflattened_args = util.unflatten(flat_args, map(len, input_types))
   args: List[List[ir.Value]] = []
   for aval, arg in zip(jaxpr.in_avals, unflattened_args):
     if replace_units_with_dummy and aval is core.abstract_unit:
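`shape_tensor` assembles a rank-1 i32 shape operand from a mix of static Python ints and dynamically computed sizes by reshaping each to `(1,)` and concatenating. A NumPy sketch of the same semantics:

```python
import numpy as np

# NumPy model of shape_tensor: each size becomes a (1,)-shaped i32 array and
# all are concatenated into the rank-1 shape tensor consumed by ops like
# mhlo.DynamicBroadcastInDimOp and mhlo.RealDynamicSliceOp.
def shape_tensor_semantics(sizes):
  parts = [np.asarray(d, dtype=np.int32).reshape(1) for d in sizes]
  return np.concatenate(parts, axis=0)

print(shape_tensor_semantics([2, 3, 4]))  # -> [2 3 4]
```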
25 changes: 25 additions & 0 deletions tests/api_test.py
@@ -8116,21 +8116,46 @@ def test_jit_abstracted_axes_return_polymorphic_shape2(self):
     self.assertIsInstance(three_, int)
     self.assertEqual(three_, 3)
 
+  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
   def test_jit_basic_iree(self):
     if not jtu.device_under_test() == 'iree':
       raise unittest.SkipTest("test only works on IREE")
 
     @jax.jit
     def f(i):
       return jnp.sum(jnp.ones(i, dtype='float32'))
 
     self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True)
 
+  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
+  def test_jit_basic_iree_2(self):
+    count = 0
+
+    @partial(jax.jit, abstracted_axes=('n',))
+    def f(x):
+      nonlocal count
+      count += 1
+      return jnp.sum(x)
+
+    x = f(jnp.arange(3))
+    y = f(jnp.arange(4))
+    self.assertAllClose(x, 3., check_dtypes=False)
+    self.assertAllClose(y, 6., check_dtypes=False)
+    self.assertEqual(count, 1)
+
+  # TODO(mattjj,dougalm,phawkins): debug iree failure, "'arith.subi' op requires
+  # the same type for all operands and results"
+  # https://github.com/google/iree/issues/8881
+  @jtu.skip_on_devices('iree')
   def test_slicing_basic(self):
     f = jax.jit(lambda x, n: jnp.sum(x[:n]))
     ans = f(jnp.arange(10), 3)
     expected = jnp.sum(jnp.arange(10)[:3])
     self.assertAllClose(ans, expected, check_dtypes=True)
 
+  # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
+  # operation 'mhlo.while' that was explicitly marked illegal"
+  @jtu.skip_on_devices('iree')
   def test_scan_basic(self):
     def cumsum(x):
       def body(i, _):
