Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.

In essence, this lifts the implementation of QR decomposition out of the lowering rules and up to the JAX level.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors themselves rather than their product. Currently neither geqrf nor orgqr is differentiable in isolation.
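
A hedged illustration of that raw access (an editor's sketch, not part of this change; it imports the private module in which this diff defines the new primitives and assumes an unbatched float32 input):

    import jax.numpy as jnp
    from jax._src.lax import linalg as lax_linalg

    a = jnp.array([[1., 2., 3.],
                   [4., 5., 6.],
                   [7., 8., 10.],
                   [2., 0., 1.]])
    h, taus = lax_linalg.geqrf(a)   # R in h's upper triangle; Householder reflectors below it, with taus
    q = lax_linalg.orgqr(h, taus)   # explicit reduced Q, shape (4, 3)
    r = jnp.triu(h[:3, :])          # (3, 3) R factor
    assert jnp.allclose(q @ r, a, atol=1e-4)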

This change is in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.

Fixes #2322

PiperOrigin-RevId: 449033350
hawkinsp authored and jax authors committed May 16, 2022
1 parent 744f6b4 commit 909c032
Showing 5 changed files with 201 additions and 89 deletions.
276 changes: 191 additions & 85 deletions jax/_src/lax/linalg.py
@@ -38,6 +38,7 @@
_input_dtype)
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
import jax._src.lib
from jax._src.lib import lapack

from jax._src.lib import gpu_linalg
@@ -1265,6 +1266,173 @@ def lu_solve(lu, permutation, b, trans=0):

# QR decomposition

# QR decomposition is implemented as a composition of two lower-level primitives
# geqrf and orgqr. The names, while cryptic Fortran alphabet soup, are LAPACK's
# names for the primitives, and we stick with them for consistency.

def geqrf(a):
"""Computes the QR decomposition of a matrix.
Args:
a: an ``[..., m, n]`` batch of matrices, with floating-point or complex type.
Returns:
An ``(a, taus)`` pair where ``r`` is in the upper triangle of ``a``,
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
elementary Householder reflectors.
"""
a_out, taus = geqrf_p.bind(a)
return a_out, taus

def _geqrf_abstract_eval(operand):
if not isinstance(operand, ShapedArray):
raise NotImplementedError("Unsupported aval in geqrf_abstract_eval: "
f"{operand.aval}")
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
taus = operand.update(shape=(*batch_dims, min(m, n)))
return operand, taus
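
# [Editor's note, not part of the diff] Illustrative shapes (an assumption for
# concreteness): for an input with a.shape == (4, 5, 3), the abstract eval
# above yields outputs shaped (4, 5, 3) for the packed matrix and (4, 3) for
# taus, since min(5, 3) == 3.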

def _geqrf_batching_rule(batched_args, batch_dims):
x, = batched_args
bd, = batch_dims
return geqrf(batching.moveaxis(x, bd, 0)), (0, 0)

def _geqrf_translation_rule(ctx, avals_in, avals_out, operand):
return xops.QrDecomposition(operand)

def _geqrf_cpu_gpu_lowering(geqrf_impl, ctx, a):
a_aval, taus_aval = ctx.avals_out
*batch_dims, m, n = a_aval.shape

if m == 0 or n == 0:
return mlir.full_like_aval(0, a_aval), mlir.full_like_aval(0, taus_aval)

a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
ok_a = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get((*batch_dims, 1, 1),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
a_out = _broadcasting_select_mhlo(ok_a, a_out, _nan_like_mhlo(a_aval))
ok_taus = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get((*batch_dims, 1,),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
taus = _broadcasting_select_mhlo(ok_taus, taus, _nan_like_mhlo(taus_aval))
return a_out, taus
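
# [Editor's note, not part of the diff] The block above is the usual LAPACK
# "info" check: each call returns a per-batch status, and any batch entry with
# info != 0 has its outputs replaced with NaN so failures are visible to the
# caller. A rough JAX-level analogue of the masking (an illustration only, not
# the MHLO actually emitted here):
#
#   ok = (info == 0).reshape(info.shape + (1,) * (x.ndim - info.ndim))
#   x = jnp.where(ok, x, jnp.full_like(x, jnp.nan))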

geqrf_p = Primitive('geqrf')
geqrf_p.multiple_results = True
geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p))
geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
xla.register_translation(geqrf_p, _geqrf_translation_rule)

mlir.register_lowering(
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo),
platform='cpu')
if gpu_solver is not None:
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf),
platform='cuda')
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf),
platform='rocm')

if solver_apis is not None:
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo),
platform='gpu')


# orgqr: product of elementary Householder reflectors

def orgqr(a, taus):
"""Product of elementary Householder reflectors.
Args:
a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
elementary Householder reflectors.
taus: A vector with shape ``[..., k]``, where ``k <= min(m, n)``, containing
the scalar factors of the elementary Householder reflectors.
Returns:
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors.
"""
return orgqr_p.bind(a, taus)


def _orgqr_abstract_eval(a, taus):
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: "
f"{a.aval} {taus.aval}")
if a.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = a.shape
*taus_batch_dims, k = taus.shape
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}")
return a

def _orgqr_batching_rule(batched_args, batch_dims):
a, taus = batched_args
b_a, b_taus, = batch_dims
return orgqr(batching.moveaxis(a, b_a, 0),
batching.moveaxis(taus, b_taus, 0)), (0,)

def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus):
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]

def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
a_aval, _ = ctx.avals_in
*batch_dims, m, n = a_aval.shape

if m == 0 or n == 0:
return [mlir.full_like_aval(0, a_aval)]

a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED")
ok = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get((*batch_dims, 1, 1),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
a = _broadcasting_select_mhlo(ok, a, _nan_like_mhlo(a_aval))
return [a]


orgqr_p = Primitive('orgqr')
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p))
orgqr_p.def_abstract_eval(_orgqr_abstract_eval)
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule
xla.register_translation(orgqr_p, _orgqr_translation_rule)

mlir.register_lowering(
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
platform='cpu')
if gpu_solver is not None:
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
platform='cuda')
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
platform='rocm')

if solver_apis is not None:
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, solver_apis.orgqr_mhlo),
platform='gpu')


def _qr_impl(operand, *, full_matrices):
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r
@@ -1317,105 +1485,43 @@ def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
x = batching.moveaxis(x, bd, 0)
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)

def _empty_qr(a, *, full_matrices):
*batch_shape, m, n = a.shape
k = m if full_matrices else min(m, n)
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_shape, m, k))
r = jnp.empty((*batch_shape, k, n), dtype=a.dtype)
return [q, r]

def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,
full_matrices):
operand_aval, = ctx.avals_in
q_aval, r_aval = ctx.avals_out
dims = operand_aval.shape
m, n = dims[-2:]
batch_dims = dims[:-2]

def _qr_lowering(a, *, full_matrices):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
return mlir.lower_fun(_empty_qr, multiple_results=True)(
ctx, operand, full_matrices=full_matrices)
k = m if full_matrices else min(m, n)
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_dims, m, k))
r = jnp.empty((*batch_dims, k, n), dtype=a.dtype)
return q, r

r, tau, info_geqrf = geqrf_impl(operand_aval.dtype, operand)
r, taus = geqrf(a)
if m < n:
q = mhlo.SliceOp(r,
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements(list(batch_dims) + [m, m]),
mlir.dense_int_elements([1] * len(dims))).result
q, info_orgqr = orgqr_impl(operand_aval.dtype, q, tau)
elif not full_matrices:
q, info_orgqr = orgqr_impl(operand_aval.dtype, r, tau)
r = mhlo.SliceOp(r,
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements(list(batch_dims) + [n, n]),
mlir.dense_int_elements([1] * len(dims))).result
else:
if jax._src.lib.mlir_api_version < 15:
q = mhlo.PadOp(mlir.aval_to_ir_type(q_aval), r,
mlir.ir_constant(np.array(0, dtype=operand_aval.dtype)),
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements([0] * (len(dims) - 1) + [m - n]),
mlir.dense_int_elements([0] * len(dims))).result
else:
q = mhlo.PadOp(r,
mlir.ir_constant(np.array(0, dtype=operand_aval.dtype)),
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements([0] * (len(dims) - 1) + [m - n]),
mlir.dense_int_elements([0] * len(dims))).result
q, info_orgqr = orgqr_impl(operand_aval.dtype, q, tau)
if info_geqrf is not None:
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mhlo.AndOp(
mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED"),
mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED"))
ok = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get(batch_dims + (1, 1),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
q = _broadcasting_select_mhlo(ok, q, _nan_like_mhlo(q_aval))
r = _broadcasting_select_mhlo(ok, r, _nan_like_mhlo(r_aval))
q = orgqr(r[..., :m, :m], taus)
elif full_matrices:
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
q = lax.pad(r, lax_internal._zero(r), pads)
q = orgqr(q, taus)
else:
pass # rocsolver does not return info
q = orgqr(r, taus)
r = r[..., :n, :n]
r = jnp.triu(r)
return q, r
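
# [Editor's note, not part of the diff] Illustrative shapes for the three
# branches above, assuming a single unbatched float32 input; these calls route
# through the qr_p lowering registered below:
#
#   import jax.numpy as jnp
#   a = jnp.ones((5, 3))
#   jnp.linalg.qr(a, mode="complete")  # full_matrices branch: q (5, 5), r (5, 3)
#   jnp.linalg.qr(a)                   # reduced branch:       q (5, 3), r (3, 3)
#   jnp.linalg.qr(a.T)                 # m < n branch:         q (3, 3), r (3, 5)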

sub_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
primitive=None,
avals_in=[r_aval],
avals_out=[r_aval],
tokens_in=ctx.tokens_in,
tokens_out=ctx.tokens_out)
r, = mlir.lower_fun(jnp.triu, multiple_results=False)(sub_ctx, r)
return [q, r]

qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(_qr_impl)
qr_p.def_abstract_eval(_qr_abstract_eval)
xla.register_translation(qr_p, _qr_translation_rule)
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = _qr_batching_rule

mlir.register_lowering(
qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
platform='cpu')
# Older jaxlibs didn't expose geqrf and orgqr as separate XLA operations.
# TODO(phawkins): remove after minimum jaxlib version is > 0.3.10.
if jax._src.lib.xla_extension_version < 69:
xla.register_translation(qr_p, _qr_translation_rule, platform="tpu")

if gpu_solver is not None:
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
gpu_solver.cuda_orgqr),
platform='cuda')
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
gpu_solver.rocm_orgqr),
platform='rocm')
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = _qr_batching_rule

mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering, multiple_results=True))

if solver_apis is not None:
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo),
platform='gpu')

# Singular value decomposition

5 changes: 4 additions & 1 deletion jax/_src/numpy/linalg.py
@@ -477,13 +477,16 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a, mode="reduced"):
a, = _promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
return _T(a), taus
if mode in ("reduced", "r", "full"):
full_matrices = False
elif mode == "complete":
full_matrices = True
else:
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
a, = _promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return r
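
As context for the new mode="raw" branch above (an editor's sketch, not part of the diff): the raw mode follows NumPy's convention of returning the transposed packed factorization together with the scalar factors of the Householder reflectors.

    import numpy as np
    import jax.numpy as jnp

    a = np.random.randn(5, 3).astype(np.float32)
    h, tau = jnp.linalg.qr(a, mode="raw")
    print(h.shape, tau.shape)   # (3, 5) and (3,), matching np.linalg.qr(a, mode="raw")
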
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1011,6 +1011,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"all_gather",
"lu_pivots_to_permutation",
"xla_pmap",
"geqrf",
"orgqr",
]

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
2 changes: 1 addition & 1 deletion jaxlib/lapack.py
@@ -208,7 +208,7 @@ def orgqr_mhlo(dtype, a, tau):
b *= d

tau_dims = ir.RankedTensorType(tau.type).shape
assert tau_dims[:-1] == dims[:-2]
assert tau_dims[:-1] == dims[:-2], (tau.type, a.type)
k = tau_dims[-1]

if dtype == np.float32:
5 changes: 3 additions & 2 deletions tests/linalg_test.py
@@ -657,15 +657,16 @@ def testJspSVDBasic(self):
"shape": shape, "dtype": dtype, "mode": mode}
for shape in [(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)]
for dtype in [np.float32]
for mode in ["reduced", "r", "full", "complete"]))
for mode in ["reduced", "r", "full", "complete", "raw"]))
def testNumpyQrModes(self, shape, dtype, mode):
rng = jtu.rand_default(self.rng())
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
np_func = partial(np.linalg.qr, mode=mode)
if mode == "full":
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1E-5, atol=1E-5)
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1e-5, atol=1e-5,
check_dtypes=(mode != "raw"))
self._CompileAndCheck(jnp_func, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
