jax.scipy.linalg.schur: error on 16-bit floats
Fixes #10530

PiperOrigin-RevId: 446279906
Jake VanderPlas authored and jax authors committed May 3, 2022
1 parent 37ea024 commit c6343dd
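
For orientation, a minimal sketch of the user-visible effect of this change (assumes a CPU backend and jaxlib >= 0.3.8; the exact exception type and message differ by backend, as the new test below shows): a 16-bit input now raises a clear error instead of being dispatched to a LAPACK kernel for the wrong dtype.

  import numpy as np
  import jax.scipy.linalg

  a = np.eye(3, dtype=np.float16)
  try:
    jax.scipy.linalg.schur(a)
  except (NotImplementedError, ValueError) as e:
    print(e)  # e.g. "Unsupported dtype float16" on CPU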
Showing 3 changed files with 44 additions and 22 deletions.
25 changes: 13 additions & 12 deletions jax/_src/lax/linalg.py
@@ -1607,18 +1607,19 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
   operand_aval, = ctx.avals_in
   batch_dims = operand_aval.shape[:-2]

-  if sort_eig_vals:
-    T, vs, _sdim, info = lapack.gees_mhlo(
-        operand,
-        jobvs=compute_schur_vectors,
-        sort=sort_eig_vals,
-        select=select_callable)
-  else:
-    T, vs, info = lapack.gees_mhlo(
-        operand,
-        jobvs=compute_schur_vectors,
-        sort=sort_eig_vals,
-        select=select_callable)
+  # TODO(jakevdp): remove this try/except when minimum jaxlib >= 0.3.8
+  try:
+    gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
+                                   jobvs=compute_schur_vectors,
+                                   sort=sort_eig_vals,
+                                   select=select_callable)
+  except TypeError:  # API for jaxlib <= 0.3.7
+    gees_result = lapack.gees_mhlo(operand,  # pytype: disable=missing-parameter
+                                   jobvs=compute_schur_vectors,
+                                   sort=sort_eig_vals,
+                                   select=select_callable)
+  # Number of return values depends on value of sort_eig_vals.
+  T, vs, *_, info = gees_result

   ok = mlir.compare_mhlo(
       info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
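
A note on the new unpacking line: `T, vs, *_, info = gees_result` uses extended unpacking so a single statement handles both return shapes of gees_mhlo. A standalone sketch of the pattern (placeholder values, not the commit's code):

  def unpack(gees_result):
    # Absorbs the optional sdim output that is only present
    # when sort_eig_vals is true.
    T, vs, *_, info = gees_result
    return T, vs, info

  assert unpack(('T', 'vs', 'sdim', 'info')) == ('T', 'vs', 'info')
  assert unpack(('T', 'vs', 'info')) == ('T', 'vs', 'info')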
26 changes: 16 additions & 10 deletions jaxlib/lapack.py
@@ -672,7 +672,7 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True):

 # # gees : Schur factorization

-def gees_mhlo(a, jobvs=True, sort=False, select=None):
+def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
   a_type = ir.RankedTensorType(a.type)
   etype = a_type.element_type
   dims = a_type.shape
@@ -695,22 +695,28 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
   jobvs = ord('V' if jobvs else 'N')
   sort = ord('S' if sort else 'N')

-  if not ir.ComplexType.isinstance(etype):
-    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
-    schurvecs_type = etype
-    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
+  if dtype == np.float32:
+    fn = "lapack_sgees"
+  elif dtype == np.float64:
+    fn = "lapack_dgees"
+  elif dtype == np.complex64:
+    fn = "lapack_cgees"
+  elif dtype == np.complex128:
+    fn = "lapack_zgees"
+  else:
+    raise NotImplementedError(f"Unsupported dtype {dtype}")
+
+  if not np.issubdtype(dtype, np.complexfloating):
+    workspaces = [ir.RankedTensorType.get(dims, etype)]
     workspace_layouts = [layout]
     eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
     eigvals_layouts = [
         ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                     type=ir.IndexType.get())
     ] * 2
   else:
-    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
-          else "lapack_zgees")
-    schurvecs_type = etype
     workspaces = [
-        ir.RankedTensorType.get(dims, schurvecs_type),
+        ir.RankedTensorType.get(dims, etype),
         ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
     ]
     workspace_layouts = [
@@ -729,7 +735,7 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
                                     type=ir.IndexType.get())
   out = mhlo.CustomCallOp(
       [ir.TupleType.get_tuple(workspaces + eigvals + [
-          ir.RankedTensorType.get(dims, schurvecs_type),
+          ir.RankedTensorType.get(dims, etype),
           ir.RankedTensorType.get(batch_dims, i32_type),
           ir.RankedTensorType.get(batch_dims, i32_type),
       ])],
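
The key change above is that kernel selection now keys on the NumPy dtype rather than the MLIR element type: previously a float16 element type failed the F32Type check and silently fell through to lapack_dgees, whereas it now hits the explicit NotImplementedError. An equivalent table-driven sketch of the new dispatch (illustrative only, not the commit's code):

  import numpy as np

  _GEES_KERNELS = {
      np.dtype(np.float32): 'lapack_sgees',
      np.dtype(np.float64): 'lapack_dgees',
      np.dtype(np.complex64): 'lapack_cgees',
      np.dtype(np.complex128): 'lapack_zgees',
  }

  def select_gees_kernel(dtype):
    # Unknown dtypes (e.g. float16, bfloat16) raise instead of
    # falling through to a mismatched kernel.
    kernel = _GEES_KERNELS.get(np.dtype(dtype))
    if kernel is None:
      raise NotImplementedError(f'Unsupported dtype {dtype}')
    return kernel

  assert select_gees_kernel(np.float64) == 'lapack_dgees'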
15 changes: 15 additions & 0 deletions tests/linalg_test.py
@@ -41,6 +41,8 @@
 float_types = jtu.dtypes.floating
 complex_types = jtu.dtypes.complex

+jaxlib_version = tuple(map(int, jax.lib.__version__.split('.')))
+

 class NumpyLinalgTest(jtu.JaxTestCase):

@@ -719,6 +721,19 @@ def compare_orthogonal(q1, q2):
     qr = partial(jnp.linalg.qr, mode=mode)
     jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)

+  @unittest.skipIf(jaxlib_version < (0, 3, 8), "test requires jaxlib>=0.3.8")
+  @jtu.skip_on_devices("tpu")
+  def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
+    # Regression test for https://github.com/google/jax/issues/10530
+    rng = jtu.rand_default(self.rng())
+    arr = rng(shape, dtype)
+    if jtu.device_under_test() == 'cpu':
+      err, msg = NotImplementedError, "Unsupported dtype float16"
+    else:
+      err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)"
+    with self.assertRaisesRegex(err, msg):
+      jnp.linalg.qr(arr)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
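The new jaxlib_version gate compares the version string as a tuple of ints; a standalone sketch of that comparison (assumes a plain 'X.Y.Z' release string with no dev/rc suffix):

  def parse_version(v):
    # '0.3.8' -> (0, 3, 8); Python compares tuples elementwise.
    return tuple(map(int, v.split('.')))

  assert parse_version('0.3.8') >= (0, 3, 8)
  assert parse_version('0.3.7') < (0, 3, 8)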
