Add support for padded arrays in QDWH algorithm.
This change is in preparation for adding a jit-able QDWH-eig implementation.

PiperOrigin-RevId: 448571523
hawkinsp authored and jax authors committed May 13, 2022
1 parent 7c582ab commit db73670
Showing 3 changed files with 98 additions and 29 deletions.
98 changes: 78 additions & 20 deletions jax/_src/lax/qdwh.py
@@ -25,31 +25,81 @@
"""

import functools
from typing import Optional, Tuple

import jax
from jax import core
from jax import lax
import jax.numpy as jnp
from jax._src.lax import linalg as lax_linalg


# Helpers for working with padded shapes

def _mask(x, dims, alternative=0):
  """Masks `x` up to the dynamic shape `dims`.

  Replaces values outside those dimensions with `alternative`. `alternative` is
  broadcast with `x`.
  """
  assert jnp.ndim(x) == len(dims)
  mask = None
  for i, d in enumerate(dims):
    if d is not None:
      mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d
      mask = mask_dim_i if mask is None else (mask & mask_dim_i)
  return x if mask is None else jnp.where(mask, x, alternative)
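For intuition, a quick sketch of `_mask` on a 3x3 array (illustrative values, assuming the helper above is in scope):

x = jnp.arange(9.).reshape(3, 3)
_mask(x, (2, 2))     # zeroes row 2 and column 2, keeping the leading 2x2 block
_mask(x, (2, None))  # a `None` entry leaves that axis unmasked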

def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0):
  pads = [(0, 0, 0)] * x.ndim
  pads[axis] = (low, high, interior)
  return lax.pad(x, jnp.array(fill_value, x.dtype), pads)
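`_pad_in_dim` is a thin wrapper around `lax.pad` that pads a single axis; for instance (illustrative values):

v = jnp.array([1., 2., 3.])
_pad_in_dim(v, high=2)                # -> [1., 2., 3., 0., 0.]
_pad_in_dim(v, low=1, fill_value=9.)  # -> [9., 1., 2., 3.]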

def _dynamic_concat(a, b, m, axis=0):
  """Concatenates padded arrays `a` and `b` where the true size of `a` is `m`."""
  if m is None:
    return jnp.concatenate([a, b], axis=axis)
  return lax.dynamic_update_slice_in_dim(
      _pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis)
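The dynamic variant matters because a plain `concatenate` on padded operands would leave `a`'s padding rows sandwiched between the two blocks; writing `b` at dynamic offset `m` places it directly after the true rows. A small sketch, reusing the helpers above:

a = _mask(jnp.ones((4, 2)), (3, None))  # true row count m=3; row 3 is padding
b = jnp.eye(2)
_dynamic_concat(a, b, 3)  # rows 0-2: ones, rows 3-4: eye(2), row 5: padding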


def _use_qr(u, m, n, params):
  """QDWH iteration using QR decomposition.

  Args:
    u: a matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N.
    params: the QDWH parameters.
  """
  a, b, c = params
  M, N = u.shape

  y = _dynamic_concat(jnp.sqrt(c) * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
  q, _ = lax_linalg.qr(y, full_matrices=False)
  # q1 = q[:m, :]
  q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
  # q2 = (q[m:, :]).T.conj()
  q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
  q2 = _mask(q2, (n, n)).T.conj()
  e = b / c
  u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2))
  return u
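For intuition: the QR-based step above is algebraically the dynamically weighted Halley update `u <- u @ (a*I + b*G) @ inv(I + c*G)` with `G = u.T.conj() @ u`, computed without forming the inverse. A minimal self-contained check of that identity for the unpadded case (a sketch with illustrative values; plain `jnp.linalg` calls stand in for `lax_linalg`):

import jax.numpy as jnp
import numpy as np

u = jnp.asarray(np.random.default_rng(0).standard_normal((5, 3)),
                dtype=jnp.float32)
a, b, c = 3.0, 2.0, 4.0  # arbitrary, illustrative parameter values

# Direct Halley update: u @ (a*I + b*G) @ inv(I + c*G), with G = u^H u.
g = u.T.conj() @ u
eye = jnp.eye(3, dtype=u.dtype)
direct = u @ (a * eye + b * g) @ jnp.linalg.inv(eye + c * g)

# QR-based update, as in `_use_qr` (unpadded case).
y = jnp.concatenate([jnp.sqrt(c) * u, eye])
q, _ = jnp.linalg.qr(y)
q1, q2 = q[:5, :], q[5:, :].T.conj()
e = b / c
qr_based = e * u + (a - e) / jnp.sqrt(c) * (q1 @ q2)

np.testing.assert_allclose(direct, qr_based, rtol=1e-4, atol=1e-4)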


def _use_cholesky(u, m, n, params):
  """QDWH iteration using Cholesky decomposition.

  Args:
    u: a matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N.
    params: the QDWH parameters.
  """
  a, b, c = params
  _, N = u.shape
  x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u))
  # Pads the lower-right corner with the identity matrix to prevent the
  # Cholesky decomposition from failing due to the matrix not being PSD if
  # padded with zeros.
  x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype))

  # `y` is lower triangular.
  y = lax_linalg.cholesky(x, symmetrize_input=False)
@@ -64,8 +64,7 @@ def _use_cholesky(u, params):
  u = e * u + (a - e) * z
  return u
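The elided lines solve for `z`. Mathematically the update needs `z = u @ inv(x)`, and the Cholesky factor makes that cheap: with `x = y @ y^H`, two triangular solves suffice. A sketch of that computation (illustrative only, using `jax.scipy.linalg.solve_triangular` rather than whatever the elided lines call):

from jax.scipy.linalg import solve_triangular

def _z_from_cholesky(u, y):
  # z = u @ inv(x) with x = y @ y^H, i.e. z^H = inv(y^H) @ inv(y) @ u^H.
  w = solve_triangular(y, u.T.conj(), lower=True)     # w  = inv(y) @ u^H
  zh = solve_triangular(y.T.conj(), w, lower=False)   # zh = inv(y^H) @ w
  return zh.T.conj()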


def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
  """QR-based dynamically weighted Halley iteration for polar decomposition."""

  # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
@@ -106,10 +106,10 @@ def body_fun(state):

    # Uses QR or Cholesky decomposition.
    def true_fn(u):
      return _use_qr(u, m, n, params=(a, b, c))

    def false_fn(u):
      return _use_cholesky(u, m, n, params=(a, b, c))

    u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))

@@ -146,16 +146,19 @@ def false_fn(u):

# TODO: Add pivoting.
@functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
         dynamic_shape: Optional[Tuple[int, int]] = None):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix, with static (padded) shape `M x N`. The matrix may
      be padded up to that size from a smaller true shape (``dynamic_shape``).
    is_hermitian: True if `x` is Hermitian. Defaults to `False`.
    eps: The final result will satisfy
      ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
    max_iterations: Iterations will terminate after this many steps even if the
      above is unsatisfied.
    dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.

  Returns:
    A four-tuple of (u, h, num_iters, is_converged) containing the
@@ -170,12 +170,18 @@ def qdwh(x, is_hermitian=False, max_iterations=None, eps=None):
  if max_iterations is None:
    max_iterations = 10

  M, N = x.shape
  if M < N:
    raise ValueError('The input matrix of shape M x N must have M >= N.')
  if dynamic_shape is not None:
    m, n = dynamic_shape
    x = _mask(x, (m, n))
  else:
    m, n = M, N

  with jax.default_matmul_precision('float32'):
    u, h, num_iters, is_converged = _qdwh(x, m, n, is_hermitian, max_iterations,
                                          eps)

  return u, h, num_iters, is_converged
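A usage sketch of the new `dynamic_shape` argument (assuming the module is imported from its path in this diff, `jax._src.lax.qdwh`): the caller pads the operand, passes the true shape, and slices the factors back down, just as the test further below does.

import jax.numpy as jnp
from jax._src.lax import qdwh

m, n = 3, 2
x = jnp.array([[1., 2.], [3., 4.], [5., 7.]])  # true (unpadded) operand
x_pad = jnp.pad(x, ((0, 4 - m), (0, 4 - n)))   # static shape 4x4, M >= N
u, h, num_iters, converged = qdwh.qdwh(x_pad, dynamic_shape=(m, n))
u, h = u[:m, :n], h[:n, :n]                    # recover the m x n and n x n factors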
3 changes: 2 additions & 1 deletion jax/_src/lax/svd.py
@@ -81,7 +81,8 @@ def _svd_tall_and_square_input(
  `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
  """

  u, h, _, _ = lax.linalg.qdwh(a, is_hermitian=hermitian,
                               max_iterations=max_iterations)

  # TODO: Uses `eigvals_only=True` if `compute_uv=False`.
  v, s = lax.linalg.eigh(h)
26 changes: 18 additions & 8 deletions tests/qdwh_test.py
@@ -79,7 +79,7 @@ def testQdwhUnconvergedAfterMaxNumberIterations(
    max_iterations = 2

    _, _, actual_num_iterations, is_converged = qdwh.qdwh(
        a, is_hermitian=is_hermitian, max_iterations=max_iterations)

    with self.subTest('Number of iterations.'):
      self.assertEqual(max_iterations, actual_num_iterations)
@@ -105,7 +105,8 @@ def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
    is_hermitian = _check_symmetry(a)
    max_iterations = 10

    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
                                         max_iterations=max_iterations)
    expected_u, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
@@ -132,12 +133,13 @@ def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}_padding={}'.format(
              m, n, log_cond, padding),
          'm': m, 'n': n, 'log_cond': log_cond, 'padding': padding}
      for m, n in zip([6, 8], [6, 4])
      for padding in (None, (3, 2))
      for log_cond in np.linspace(1, 4, 4)))
  def testQdwhWithRandomMatrix(self, m, n, log_cond, padding):
    """Tests qdwh with random input."""
    rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
    a = rng((m, n), _QDWH_TEST_DTYPE)
@@ -149,8 +151,15 @@ def testQdwhWithRandomMatrix(self, m, n, log_cond):
    max_iterations = 10

    def lsp_linalg_fn(a):
      if padding is not None:
        pm, pn = padding
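        # Pad with NaNs rather than zeros: if qdwh ever consumed a padded
        # entry, the NaN would propagate and the comparison below would fail.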
        a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan)
      u, h, _, _ = qdwh.qdwh(
          a, is_hermitian=is_hermitian, max_iterations=max_iterations,
          dynamic_shape=(m, n) if padding else None)
      if padding is not None:
        u = u[:m, :n]
        h = h[:n, :n]
      return u, h

    args_maker = lambda: [a]
@@ -187,7 +196,8 @@ def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):

    is_hermitian = _check_symmetry(a)
    max_iterations = 15
    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
                                         max_iterations=max_iterations)
    _, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.