Skip to content

Commit

Permalink
[shape_poly] linalg.eig: shape polymorphism with native serialization…
Browse files Browse the repository at this point in the history
… on CPU

The backwards compatibility tests to be added separately.

PiperOrigin-RevId: 541122069
  • Loading branch information
gnecula authored and jax authors committed Jun 17, 2023
1 parent 68a38c6 commit 3adfe32
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 56 deletions.
15 changes: 14 additions & 1 deletion jax/_src/interpreters/mlir.py
Expand Up @@ -553,7 +553,6 @@ def sharded_aval(aval: core.AbstractValue,

def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
# assert not core.is_constant_shape(shape)
if config.jax_dynamic_shapes:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
Expand All @@ -565,6 +564,20 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
multiple_results=True)(ctx, *ctx.dim_var_values)
return util.flatten(res) # type: ignore

def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Value, ...]:
"""Evaluates the dynamic shapes as int32 values."""
def convert_dim(d: Union[int, Value]):
if type(d) is int:
return ir_constant(np.array(d, dtype=np.int32))
else:
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
if d.type != i32_type: # type: ignore
return hlo.ConvertOp(i32_type, d).result
else:
return d
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))


class LoweringResult(NamedTuple):
module: ir.Module
Expand Down
20 changes: 14 additions & 6 deletions jax/_src/lax/linalg.py
Expand Up @@ -487,15 +487,23 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,

def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
compute_right_eigenvectors):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (eig); b/261671778")
operand_aval, = ctx.avals_in
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]

w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
if jaxlib_version < (0, 4, 13):
if any(not is_constant_shape(a.shape) for a in ctx.avals_in):
raise NotImplementedError(
"Shape polymorphism for eig is not implemented. "
"Try upgrading jaxlib")
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand, # type: ignore
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
else:
op_shape_vals = mlir.eval_dynamic_shape_as_vals(ctx, operand_aval.shape)
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)

ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/jax_export.py
Expand Up @@ -694,6 +694,8 @@ def _check_lowering(lowering) -> None:
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# eig on CPU
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
Expand Down
3 changes: 3 additions & 0 deletions jax/experimental/jax2tf/tests/back_compat_test.py
Expand Up @@ -412,6 +412,9 @@ def test_custom_call_coverage(self):
self.assertIsInstance(data, CompatTestData)
covered_targets = covered_targets.union(data.custom_call_targets)

# TODO(necula): add tests for eig on CPU
covered_targets = covered_targets.union({
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev"})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered)

Expand Down
23 changes: 21 additions & 2 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Expand Up @@ -2021,7 +2021,21 @@ def f_jax(operand, start_indices, x):
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, _", None]).both_enable_and_disable_xla(),
[
PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}",
lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right),
arg_descriptors=[RandArg((3, 5, 5), dtype),
StaticArg(left), StaticArg(right)],
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for poly in ["b, ...", "b, w, w"]
for left in ([True, False] if dtype == np.float32 else [True])
for right in ([True, False] if dtype == np.float32 else [False])
],
PolyHarness("einsum", "0",
lambda x: jnp.einsum("...i->...", x),
arg_descriptors=[RandArg((3, 4), _f32)],
Expand Down Expand Up @@ -2728,7 +2742,6 @@ def test_harness(self, harness: PolyHarness):
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"vmap_cholesky:cpu", "vmap_cholesky:gpu",
"vmap_eig:cpu",
"vmap_fft:cpu", "fft:cpu",
"householder_product:cpu", "householder_product:gpu",
"vmap_geqrf:cpu", "vmap_geqrf:gpu",
Expand Down Expand Up @@ -2795,11 +2808,17 @@ def test_harness(self, harness: PolyHarness):
# For non-native serialization the overflow behavior is different.
harness.check_result = False

if harness.group_name == "eig" and "left=True_right=True" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")

# FOR BOTH NATIVE AND GRAPH SERIALIZATION
if harness.group_name == "vmap_conv_general_dilated":
# https://github.com/openxla/stablehlo/issues/1268
raise unittest.SkipTest("Need more dynamism for DynamicConvOp")

if harness.group_name == "eig" and jtu.device_under_test() != "cpu":
raise unittest.SkipTest("JAX implements eig only on CPU.")

prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()
Expand Down
63 changes: 59 additions & 4 deletions jaxlib/hlo_helpers.py
Expand Up @@ -12,13 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Helpers for building MLIR operators
from typing import Dict, Optional, Sequence, Union
"""A small libary of helpers for use in jaxlib to build MLIR operations."""
from functools import partial
from typing import Callable, Dict, Optional, Sequence, Union

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np


_dtype_to_ir_type_factory : Dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
np.dtype(np.float64): ir.F64Type.get,
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
def dtype_to_ir_type(dtype) -> ir.Type:
return _dtype_to_ir_type_factory[np.dtype(dtype)]()

def ir_constant(x: np.ndarray) -> ir.Value:
assert isinstance(x, np.ndarray)
return hlo.ConstantOp(
ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))).result

def ir_constant_u8(x: int): return ir_constant(np.array(x, dtype=np.uint8))
def ir_constant_i32(x: int): return ir_constant(np.array(x, dtype=np.int32))

def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))


# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value:
int1d = shape_dtype_to_ir_type((1,), np.int32)
i32_type = shape_dtype_to_ir_type((), np.int32)
def dim_to_i32x1(d):
if type(d) is int:
return ir_constant(np.array([d], dtype=np.int32))
else:
if d.type != i32_type:
d = hlo.ConvertOp(i32_type, d).result
return hlo.ReshapeOp(int1d, d).result
ds = [dim_to_i32x1(sz) for sz in sizes]
if not ds:
return ir_constant(np.array([], np.int32))
elif len(ds) == 1:
return ds[0]
else:
return hlo.ConcatenateOp(
ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)).result


# TODO(necula): share this with mlir.custom_call
def custom_call(
call_target_name: Union[str, bytes],
out_types: Sequence[ir.Type],
Expand All @@ -42,12 +97,12 @@ def custom_call(
match the number of the results. They are appended to the list
of operands.
"""
i32_type = ir.IntegerType.get_signless(32)
attributes = dict(
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
backend_config=ir.StringAttr.get(backend_config),
api_version=ir.IntegerAttr.get(i32_type, api_version),
api_version=ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), api_version),
called_computations=ir.ArrayAttr.get([]),
output_operand_aliases=ir.ArrayAttr.get([
hlo.OutputOperandAlias.get(
Expand Down
131 changes: 88 additions & 43 deletions jaxlib/lapack.py
Expand Up @@ -19,9 +19,13 @@
import jaxlib.mlir.dialects.stablehlo as hlo

import numpy as np
from typing import Tuple
from jaxlib import xla_client

from .hlo_helpers import custom_call
from .hlo_helpers import (
custom_call, ir_constant_u8, ir_constant_i32,
shape_tensor
)
from .cpu import _lapack

for _name, _value in _lapack.registrations().items():
Expand Down Expand Up @@ -477,85 +481,126 @@ def mk_constant_shape_tensor(ranked_type: ir.RankedTensorType) -> ir.Value:
return out[:3]


# # geev: Nonsymmetric eigendecomposition
# # geev: Nonsymmetric eigendecomposition (eig)

def geev_hlo(dtype, a, jobvl=True, jobvr=True):
def geev_hlo(dtype, input, *,
input_shape_vals: Tuple[ir.Value, ...], # input.shape as ir.Values
jobvl=True, jobvr=True):
# input_shape_vals are used for when input has dynamic shapes.
_initialize()
dims = ir.RankedTensorType(a.type).shape
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
input_shape = ir.RankedTensorType(input.type).shape
assert len(input_shape) >= 2
n = input_shape[-1]
n_val: ir.Value = input_shape_vals[-1]
batch_dims = tuple(input_shape[:-2])
batch_dims_vals = input_shape_vals[:-2]
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d

layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

jobvl_c = ord('V' if jobvl else 'N')
jobvr_c = ord('V' if jobvr else 'N')

i32_type = ir.IntegerType.get_signless(32)
f32_type = ir.F32Type.get()
f64_type = ir.F64Type.get()
c64_type = ir.ComplexType.get(ir.F32Type.get())
c128_type = ir.ComplexType.get(ir.F64Type.get())

if n == ir.ShapedType.get_dynamic_size():
two_n = ir.ShapedType.get_dynamic_size()
else:
two_n = n + n
if dtype == np.float32:
fn = b"lapack_sgeev"
real = True
eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
workspaces = [ir.RankedTensorType.get([n, n], ir.F32Type.get()),
ir.RankedTensorType.get([n, n], ir.F32Type.get()),
ir.RankedTensorType.get([n, n], ir.F32Type.get())]
eigvecs_type = c64_type
workspace_types = [ir.RankedTensorType.get([n, n], f32_type)] * 3
workspace_result_shapes = [shape_tensor((n_val, n_val))] * 3
workspace_layouts = [[0, 1]] * 3
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get()),
ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get())]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), f32_type)] * 2
eigval_result_shapes = [
shape_tensor(batch_dims_vals + (n_val,))] * 2
eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
elif dtype == np.float64:
fn = b"lapack_dgeev"
real = True
eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
workspaces = [ir.RankedTensorType.get([n, n], ir.F64Type.get()),
ir.RankedTensorType.get([n, n], ir.F64Type.get()),
ir.RankedTensorType.get([n, n], ir.F64Type.get())]
eigvecs_type = c128_type
workspace_types = [ir.RankedTensorType.get([n, n], f64_type)] * 3
workspace_result_shapes = [shape_tensor((n_val, n_val))] * 3
workspace_layouts = [[0, 1]] * 3
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get()),
ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get())]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), f64_type)] * 2
eigval_result_shapes = [
shape_tensor(batch_dims_vals + (n_val,))] * 2
eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
elif dtype == np.complex64:
fn = b"lapack_cgeev"
real = False
eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
workspaces = [ir.RankedTensorType.get([n, n],
ir.ComplexType.get(ir.F32Type.get())),
ir.RankedTensorType.get([2 * n], ir.F32Type.get())]
eigvecs_type = c64_type
workspace_types = [
ir.RankedTensorType.get([n, n], c64_type),
ir.RankedTensorType.get([two_n], f32_type)]
workspace_result_shapes = [
shape_tensor((n_val, n_val)),
shape_tensor((hlo.AddOp(n_val, n_val).result,))]
workspace_layouts = [[0, 1], [0]]
eigvals = [ir.RankedTensorType.get(batch_dims + (n,),
ir.ComplexType.get(ir.F32Type.get()))]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), c64_type)]
eigval_result_shapes = [shape_tensor(batch_dims_vals + (n_val,))]
eigvals_layouts = [tuple(range(num_bd, -1, -1))]
elif dtype == np.complex128:
fn = b"lapack_zgeev"
real = False
eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
workspaces = [ir.RankedTensorType.get([n, n],
ir.ComplexType.get(ir.F64Type.get())),
ir.RankedTensorType.get([2 * n], ir.F64Type.get())]
eigvecs_type = c128_type
workspace_types = [
ir.RankedTensorType.get([n, n], c128_type),
ir.RankedTensorType.get([two_n], f64_type)]
workspace_result_shapes = [
shape_tensor((n_val, n_val)),
shape_tensor((hlo.AddOp(n_val, n_val).result,))]
workspace_layouts = [[0, 1], [0]]
eigvals = [ir.RankedTensorType.get(batch_dims + (n,),
ir.ComplexType.get(ir.F64Type.get()))]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), c128_type)]
eigval_result_shapes = [
shape_tensor(batch_dims_vals + (n_val,))]
eigvals_layouts = [tuple(range(num_bd, -1, -1))]
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

i32_type = ir.IntegerType.get_signless(32)
scalar_layout = []
info_layout = tuple(range(num_bd - 1, -1, -1))

batch_size_val = ir_constant_i32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, b_v).result

result_types = (
workspace_types + eigval_types + [
ir.RankedTensorType.get(input_shape, eigvecs_type),
ir.RankedTensorType.get(input_shape, eigvecs_type),
ir.RankedTensorType.get(batch_dims, i32_type),
])
if any(a == ir.ShapedType.get_dynamic_size() for a in input_shape):
result_shapes = workspace_result_shapes + eigval_result_shapes + [
shape_tensor(input_shape_vals),
shape_tensor(input_shape_vals),
shape_tensor(batch_dims_vals),
]
else:
result_shapes = None
out = custom_call(
fn,
workspaces + eigvals + [
ir.RankedTensorType.get(dims, eigvecs_type),
ir.RankedTensorType.get(dims, eigvecs_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
[_hlo_s32(b), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c), a],
result_types,
[batch_size_val, n_val,
ir_constant_u8(jobvl_c),
ir_constant_u8(jobvr_c),
input],
operand_layouts=[scalar_layout] * 4 + [layout],
result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 +
[info_layout])
[info_layout]),
result_shapes=result_shapes,
)
if real:
return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
Expand Down

0 comments on commit 3adfe32

Please sign in to comment.