Skip to content

Commit

Permalink
Change non-array arguments to jax.lax.linalg functions to be keyword-…
Browse files Browse the repository at this point in the history
…only arguments.

PiperOrigin-RevId: 448066207
  • Loading branch information
hawkinsp authored and jax authors committed May 11, 2022
1 parent d092d63 commit 705e241
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 40 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
* Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
keyword-only. As a backward-compatibility step passing keyword-only
arguments positionally yields a warning, but in a future JAX release passing
keyword-only arguments positionally will fail.
However, most users should prefer to use {mod}`jax.numpy.linalg` instead.

## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
Expand Down
125 changes: 91 additions & 34 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import functools
from functools import partial
import warnings

import numpy as np

Expand Down Expand Up @@ -58,7 +60,52 @@

# traceables

def cholesky(x, symmetrize_input: bool = True):
# TODO(phawkins): remove backward compatibility shim after 2022/08/11.
def _warn_on_positional_kwargs(f):
"""Decorator used for backward compatibility of keyword-only arguments.
Some functions were changed to mark their keyword arguments as keyword-only.
This decorator allows existing code to keep working temporarily, while issuing
a warning if a now keyword-only parameter is passed positionally."""
sig = inspect.signature(f)
pos_names = [name for name, p in sig.parameters.items()
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
kwarg_names = [name for name, p in sig.parameters.items()
if p.kind == inspect.Parameter.KEYWORD_ONLY]

# This decorator assumes that all arguments to `f` are either
# positional-or-keyword or keyword-only.
assert len(pos_names) + len(kwarg_names) == len(sig.parameters)

@functools.wraps(f)
def wrapped(*args, **kwargs):
if len(args) < len(pos_names):
a = pos_names[len(args)]
raise TypeError(f"{f.__name__} missing required positional argument: {a}")

pos_args = args[:len(pos_names)]
extra_kwargs = args[len(pos_names):]

if len(extra_kwargs) > len(kwarg_names):
raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} "
f" arguments but {len(args)} were given.")

for name, value in zip(kwarg_names, extra_kwargs):
if name in kwargs:
raise TypeError(f"{f.__name__} got multiple values for argument: "
f"{name}")

warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only "
"argument. Support for passing it positionally will be "
"removed in an upcoming JAX release.",
DeprecationWarning)
kwargs[name] = value
return f(*pos_args, **kwargs)

return wrapped

@_warn_on_positional_kwargs
def cholesky(x, *, symmetrize_input: bool = True):
"""Cholesky decomposition.
Computes the Cholesky decomposition
Expand Down Expand Up @@ -87,15 +134,17 @@ def cholesky(x, symmetrize_input: bool = True):
x = symmetrize(x)
return jnp.tril(cholesky_p.bind(x))

def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
@_warn_on_positional_kwargs
def eig(x, *, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU.
"""
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)

def eigh(x, lower: bool = True, symmetrize_input: bool = True,
@_warn_on_positional_kwargs
def eigh(x, *, lower: bool = True, symmetrize_input: bool = True,
sort_eigenvalues: bool = True, ):
r"""Eigendecomposition of a Hermitian matrix.
Expand Down Expand Up @@ -182,7 +231,8 @@ def lu(x):
lu, pivots, permutation = lu_p.bind(x)
return lu, pivots, permutation

def qr(x, full_matrices: bool = True):
@_warn_on_positional_kwargs
def qr(x, *, full_matrices: bool = True):
"""QR decomposition.
Computes the QR decomposition
Expand Down Expand Up @@ -213,7 +263,8 @@ def qr(x, full_matrices: bool = True):
return q, r

# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
def svd(x, full_matrices=True, compute_uv=True):
@_warn_on_positional_kwargs
def svd(x, *, full_matrices=True, compute_uv=True):
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
Expand All @@ -228,7 +279,8 @@ def svd(x, full_matrices=True, compute_uv=True):
s, = result
return s

def triangular_solve(a, b, left_side: bool = False, lower: bool = False,
@_warn_on_positional_kwargs
def triangular_solve(a, b, *, left_side: bool = False, lower: bool = False,
transpose_a: bool = False, conjugate_a: bool = False,
unit_diagonal: bool = False):
r"""Triangular solve.
Expand Down Expand Up @@ -330,7 +382,7 @@ def g(c, *args, **kwargs):

# Cholesky decomposition

def cholesky_jvp_rule(primals, tangents):
def _cholesky_jvp_rule(primals, tangents):
x, = primals
sigma_dot, = tangents
L = jnp.tril(cholesky_p.bind(x))
Expand All @@ -349,15 +401,15 @@ def phi(X):
precision=lax.Precision.HIGHEST)
return L, L_dot

def cholesky_batching_rule(batched_args, batch_dims):
def _cholesky_batching_rule(batched_args, batch_dims):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return cholesky(x), 0

cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule
ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule

def _cholesky_lowering(ctx, x):
aval, = ctx.avals_out
Expand Down Expand Up @@ -635,7 +687,7 @@ def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
if a.ndim < 2:
msg = "triangular_solve requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
Expand All @@ -657,7 +709,8 @@ def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
return b.shape

def triangular_solve_jvp_rule_a(
g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal):
m, n = b.shape[-2:]
k = 1 if unit_diagonal else 0
g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
Expand All @@ -668,8 +721,9 @@ def triangular_solve_jvp_rule_a(
precision=lax.Precision.HIGHEST)

def a_inverse(rhs):
return triangular_solve(a, rhs, left_side, lower, transpose_a, conjugate_a,
unit_diagonal)
return triangular_solve(a, rhs, left_side=left_side, lower=lower,
transpose_a=transpose_a, conjugate_a=conjugate_a,
unit_diagonal=unit_diagonal)

# triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs
# for matrix/vector inputs). Order these operations in whichever order is
Expand All @@ -688,20 +742,22 @@ def a_inverse(rhs):
return dot(ans, a_inverse(g_a)) # X (∂A A^{-1})

def triangular_solve_transpose_rule(
cotangent, a, b, left_side, lower, transpose_a, conjugate_a,
cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal):
# Triangular solve is nonlinear in its first argument and linear in its second
# argument, analogous to `div` but swapped.
assert not ad.is_undefined_primal(a) and ad.is_undefined_primal(b)
if type(cotangent) is ad_util.Zero:
cotangent_b = ad_util.Zero(b.aval)
else:
cotangent_b = triangular_solve(a, cotangent, left_side, lower,
not transpose_a, conjugate_a, unit_diagonal)
cotangent_b = triangular_solve(a, cotangent, left_side=left_side,
lower=lower, transpose_a=not transpose_a,
conjugate_a=conjugate_a,
unit_diagonal=unit_diagonal)
return [None, cotangent_b]


def triangular_solve_batching_rule(batched_args, batch_dims, left_side,
def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
lower, transpose_a, conjugate_a,
unit_diagonal):
x, y = batched_args
Expand Down Expand Up @@ -1206,7 +1262,7 @@ def lu_solve(lu, permutation, b, trans=0):

# QR decomposition

def qr_impl(operand, full_matrices):
def _qr_impl(operand, *, full_matrices):
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r

Expand All @@ -1219,7 +1275,7 @@ def _qr_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices):
_zeros_like_xla(ctx.builder, avals_out[1])]
return xops.QR(operand, full_matrices)

def qr_abstract_eval(operand, full_matrices):
def _qr_abstract_eval(operand, *, full_matrices):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
Expand All @@ -1232,7 +1288,7 @@ def qr_abstract_eval(operand, full_matrices):
r = operand
return q, r

def qr_jvp_rule(primals, tangents, full_matrices):
def qr_jvp_rule(primals, tangents, *, full_matrices):
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
x, = primals
dx, = tangents
Expand All @@ -1252,7 +1308,7 @@ def qr_jvp_rule(primals, tangents, full_matrices):
dr = jnp.matmul(qt_dx_rinv - do, r)
return (q, r), (dq, dr)

def qr_batching_rule(batched_args, batch_dims, full_matrices):
def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
Expand Down Expand Up @@ -1329,11 +1385,11 @@ def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,

qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
qr_p.def_impl(_qr_impl)
qr_p.def_abstract_eval(_qr_abstract_eval)
xla.register_translation(qr_p, _qr_translation_rule)
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = qr_batching_rule
batching.primitive_batchers[qr_p] = _qr_batching_rule

mlir.register_lowering(
qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
Expand All @@ -1360,7 +1416,7 @@ def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,

# Singular value decomposition

def svd_impl(operand, full_matrices, compute_uv):
def _svd_impl(operand, *, full_matrices, compute_uv):
return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv)

Expand All @@ -1376,7 +1432,7 @@ def _eye_like_xla(c, aval):
xops.Iota(c, iota_shape, len(aval.shape) - 2))
return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))

def svd_abstract_eval(operand, full_matrices, compute_uv):
def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to singular value decomposition must have ndims >= 2")
Expand All @@ -1395,7 +1451,7 @@ def svd_abstract_eval(operand, full_matrices, compute_uv):
else:
raise NotImplementedError

def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
A, = primals
dA, = tangents
s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)
Expand Down Expand Up @@ -1520,7 +1576,7 @@ def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
return mlir.lower_fun(_svd_tpu, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)

def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
Expand All @@ -1533,10 +1589,10 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):

svd_p = Primitive('svd')
svd_p.multiple_results = True
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
svd_p.def_impl(_svd_impl)
svd_p.def_abstract_eval(_svd_abstract_eval)
ad.primitive_jvps[svd_p] = _svd_jvp_rule
batching.primitive_batchers[svd_p] = _svd_batching_rule

mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
Expand Down Expand Up @@ -1665,7 +1721,8 @@ def tridiagonal_solve(dl, d, du, b):
# Schur Decomposition


def schur(x,
@_warn_on_positional_kwargs
def schur(x, *,
compute_schur_vectors=True,
sort_eig_vals=False,
select_callable=None):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True,
else:
return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])

return lax_linalg.svd(a, full_matrices, compute_uv)
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)


@_wraps(np.linalg.matrix_power)
Expand Down Expand Up @@ -484,7 +484,7 @@ def qr(a, mode="reduced"):
else:
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
a, = _promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices)
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return r
return q, r
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a, *, full_matrices, compute_uv):
a, = _promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices, compute_uv)
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

@_wraps(scipy.linalg.svd,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver'))
Expand Down Expand Up @@ -189,7 +189,7 @@ def _qr(a, mode, pivoting):
else:
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
a, = _promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices)
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return (r,)
return q, r
Expand Down
5 changes: 3 additions & 2 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,9 @@ def check_left_eigenvectors(a, w, vl):
check_right_eigenvectors(aH, wC, vl)

a, = args_maker()
results = lax.linalg.eig(a, compute_left_eigenvectors,
compute_right_eigenvectors)
results = lax.linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
w = results[0]

if compute_left_eigenvectors:
Expand Down

0 comments on commit 705e241

Please sign in to comment.