Make lax.linalg.qr robust to zero-dimensional inputs
jakevdp committed Apr 29, 2022
1 parent b81f57b commit f6dca14
Showing 2 changed files with 42 additions and 24 deletions.
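For orientation, a minimal sketch (not part of the commit) of the calls this change is meant to support, with the shapes implied by the QR shape rule in the diff below:

import jax.numpy as jnp

# A matrix with a zero-sized trailing dimension. Before this change the QR
# lowering handed such inputs straight to the backend kernels; afterwards it
# short-circuits and returns well-shaped (possibly empty) factors.
a = jnp.zeros((3, 0), dtype=jnp.float32)
q, r = jnp.linalg.qr(a, mode="reduced")     # q.shape == (3, 0), r.shape == (0, 0)
q2, r2 = jnp.linalg.qr(a, mode="complete")  # q2.shape == (3, 3), r2.shape == (3, 0)
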
jax/_src/lax/linalg.py (26 changes: 21 additions & 5 deletions)
@@ -1150,18 +1150,22 @@ def qr_impl(operand, full_matrices):
   return q, r
 
 def _qr_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices):
+  operand_aval, = avals_in
+  shape = operand_aval.shape
+  m, n = shape[-2:]
+  if m == 0 or n == 0:
+    return [_eye_like_xla(ctx.builder, avals_out[0]),
+            _zeros_like_xla(ctx.builder, avals_out[1])]
   return xops.QR(operand, full_matrices)
 
 def qr_abstract_eval(operand, full_matrices):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2:
       raise ValueError("Argument to QR decomposition must have ndims >= 2")
-    batch_dims = operand.shape[:-2]
-    m = operand.shape[-2]
-    n = operand.shape[-1]
+    *batch_dims, m, n = operand.shape
     k = m if full_matrices else min(m, n)
-    q = operand.update(shape=batch_dims + (m, k))
-    r = operand.update(shape=batch_dims + (k, n))
+    q = operand.update(shape=(*batch_dims, m, k))
+    r = operand.update(shape=(*batch_dims, k, n))
   else:
     q = operand
     r = operand
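As an illustration (not from the commit), tracing the shape rule above by hand for an empty operand shows why both settings of full_matrices stay well defined:

# qr_abstract_eval's shape logic applied to a single (2, 0) operand.
m, n = 2, 0
for full_matrices in (True, False):
  k = m if full_matrices else min(m, n)  # 2 when full_matrices, otherwise 0
  q_shape, r_shape = (m, k), (k, n)      # (2, 2)/(2, 0) or (2, 0)/(0, 0)
  print(full_matrices, q_shape, r_shape)
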
@@ -1193,13 +1197,25 @@ def qr_batching_rule(batched_args, batch_dims, full_matrices):
   x = batching.moveaxis(x, bd, 0)
   return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
 
+def _empty_qr(a, *, full_matrices):
+  *batch_shape, m, n = a.shape
+  k = m if full_matrices else min(m, n)
+  q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_shape, m, k))
+  r = jnp.empty((*batch_shape, k, n), dtype=a.dtype)
+  return [q, r]
+
 def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,
                          full_matrices):
   operand_aval, = ctx.avals_in
   q_aval, r_aval = ctx.avals_out
   dims = operand_aval.shape
   m, n = dims[-2:]
   batch_dims = dims[:-2]
+
+  if m == 0 or n == 0:
+    return mlir.lower_fun(_empty_qr, multiple_results=True)(
+        ctx, operand, full_matrices=full_matrices)
+
   r, tau, info_geqrf = geqrf_impl(operand_aval.dtype, operand)
   if m < n:
     q = mhlo.SliceOp(r,
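For reference, a rough standalone sketch (the helper name here is made up; the real fallback above is lowered through mlir.lower_fun) of the factors _empty_qr produces for a batched empty input:

import jax.numpy as jnp

def empty_qr_sketch(a, full_matrices=False):
  # Same shape logic as _empty_qr above: Q is an identity-like basis and R
  # has zero elements, so its contents never matter.
  *batch_shape, m, n = a.shape
  k = m if full_matrices else min(m, n)
  q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_shape, m, k))
  r = jnp.zeros((*batch_shape, k, n), dtype=a.dtype)
  return q, r

q, r = empty_qr_sketch(jnp.zeros((2, 3, 0)), full_matrices=True)
print(q.shape, r.shape)  # (2, 3, 3) (2, 3, 0)
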
tests/linalg_test.py (40 changes: 21 additions & 19 deletions)
@@ -654,22 +654,7 @@ def testJspSVDBasic(self):
       {"testcase_name": "_shape={}_mode={}".format(
           jtu.format_shape_dtype_string(shape, dtype), mode),
        "shape": shape, "dtype": dtype, "mode": mode}
-      for shape in [(3, 4), (3, 3), (4, 3)]
-      for dtype in [np.float32]
-      for mode in ["full", "r", "economic"]))
-  def testScipyQrModes(self, shape, dtype, mode):
-    rng = jtu.rand_default(self.rng())
-    jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
-    sp_func = partial(scipy.linalg.qr, mode=mode)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
-    self._CompileAndCheck(jsp_func, args_maker)
-
-  @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_mode={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), mode),
-       "shape": shape, "dtype": dtype, "mode": mode}
-      for shape in [(3, 4), (3, 3), (4, 3)]
+      for shape in [(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)]
       for dtype in [np.float32]
       for mode in ["reduced", "r", "full", "complete"]))
   def testNumpyQrModes(self, shape, dtype, mode):
@@ -686,7 +671,7 @@ def testNumpyQrModes(self, shape, dtype, mode):
       {"testcase_name": "_shape={}_fullmatrices={}".format(
           jtu.format_shape_dtype_string(shape, dtype), full_matrices),
        "shape": shape, "dtype": dtype, "full_matrices": full_matrices}
-      for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
+      for shape in [(0, 0), (2, 0), (0, 2), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
       for dtype in float_types + complex_types
       for full_matrices in [False, True]))
   def testQr(self, shape, dtype, full_matrices):
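To spell out one of the new empty cases (an illustration, not part of the test): for shape (0, 2) in reduced mode, Q is (0, 0) and R is (0, 2), so the reconstruction check is trivially satisfied on empty arrays.

import jax.numpy as jnp

a = jnp.zeros((0, 2))
q, r = jnp.linalg.qr(a)        # reduced mode by default
print(q.shape, r.shape)        # (0, 0) (0, 2)
print(jnp.allclose(q @ r, a))  # True: both sides are empty (0, 2) arrays
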
@@ -713,11 +698,12 @@ def testQr(self, shape, dtype, full_matrices):
     # Norm, adjusted for dimension and type.
     def norm(x):
       n = np.linalg.norm(x, axis=(-2, -1))
-      return n / (max_rank * jnp.finfo(dtype).eps)
+      return n / (max(1, max_rank) * jnp.finfo(dtype).eps)
 
     def compare_orthogonal(q1, q2):
       # Q is unique up to sign, so normalize the sign first.
-      sum_of_ratios = np.sum(np.divide(q1, q2), axis=-2, keepdims=True)
+      ratio = np.divide(np.where(q2 == 0, 0, q1), np.where(q2 == 0, 1, q2))
+      sum_of_ratios = ratio.sum(axis=-2, keepdims=True)
       phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
       q1 *= phases
       self.assertTrue(np.all(norm(q1 - q2) < 30))
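A small numpy-only sketch (example values, not from the test) of why the sign normalization above needs the np.where guard: a zero entry in q2 would otherwise produce inf or NaN ratios that poison the phase estimate, which matters now that the empty-input Q factors are exact identity matrices with zero entries.

import numpy as np

q1 = np.array([[1.0], [0.0]])
q2 = np.array([[-1.0], [0.0]])   # same column up to sign, with an exact zero
ratio = np.divide(np.where(q2 == 0, 0, q1), np.where(q2 == 0, 1, q2))
phases = ratio.sum(axis=-2, keepdims=True)
phases /= np.abs(phases)
print(q1 * phases)               # [[-1.], [0.]], matching q2 after the sign fix
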
@@ -1334,6 +1320,22 @@ def testExpm(self, n, dtype):
     self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
     self._CompileAndCheck(jsp_fun_triu, args_maker_triu)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_mode={}".format(
+          jtu.format_shape_dtype_string(shape, dtype), mode),
+       "shape": shape, "dtype": dtype, "mode": mode}
+      # Skip empty shapes because scipy fails: https://github.com/scipy/scipy/issues/1532
+      for shape in [(3, 4), (3, 3), (4, 3)]
+      for dtype in [np.float32]
+      for mode in ["full", "r", "economic"]))
+  def testScipyQrModes(self, shape, dtype, mode):
+    rng = jtu.rand_default(self.rng())
+    jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
+    sp_func = partial(scipy.linalg.qr, mode=mode)
+    args_maker = lambda: [rng(shape, dtype)]
+    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
+    self._CompileAndCheck(jsp_func, args_maker)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
           "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),