Skip to content

Commit

Permalink
Refactor QDWH to be more efficient when run batched under vmap.
Browse files Browse the repository at this point in the history
In particular, avoid using lax.cond to switch to CholeskyQR for later iterations, as under vmap this can result in both branches being executed.

PiperOrigin-RevId: 628144162
  • Loading branch information
jlottes authored and jax authors committed Apr 25, 2024
1 parent beb49af commit 9fd5f7c
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 78 deletions.
1 change: 0 additions & 1 deletion jax/_src/internal_test_util/test_harnesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,7 +1819,6 @@ def _fft_rng_factory(dtype):
jax_unimplemented=[
Limitation(
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16],
),
],
Expand Down
165 changes: 104 additions & 61 deletions jax/_src/lax/qdwh.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,17 @@ def _use_qr(u, m, n, params):
m, n: the dynamic shape of the matrix, where m <= M and n <= N.
params: the QDWH parameters.
"""
a, b, c = params
a_minus_e_by_sqrt_c, sqrt_c, e = params
M, N = u.shape

y = _dynamic_concat(jnp.sqrt(c) * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
q, _ = lax_linalg.qr(y, full_matrices=False)
# q1 = q[:m, :]
q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
# q2 = (q[m:, :]).T.conj()
q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
q2 = _mask(q2, (n, n)).T.conj()
e = b / c
u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2))
return u
return e * u + a_minus_e_by_sqrt_c * (q1 @ q2)


def _use_cholesky(u, m, n, params):
Expand All @@ -94,7 +92,7 @@ def _use_cholesky(u, m, n, params):
m, n: the dynamic shape of the matrix, where m <= M and n <= N.
params: the QDWH parameters.
"""
a, b, c = params
a_minus_e, c, e = params
_, N = u.shape
x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u))
# Pads the lower-right corner with the identity matrix to prevent the Cholesky
Expand All @@ -111,9 +109,7 @@ def _use_cholesky(u, m, n, params):
z = lax_linalg.triangular_solve(y, z, left_side=True, lower=True,
transpose_a=True, conjugate_a=True).T.conj()

e = b / c
u = e * u + (a - e) * z
return u
return e * u + a_minus_e * z

def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
"""QR-based dynamically weighted Halley iteration for polar decomposition."""
Expand All @@ -123,81 +119,125 @@ def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
# the smallest singular value of x.
if eps is None:
eps = float(jnp.finfo(x.dtype).eps)
alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))).astype(x.dtype)
alpha_inverse = (lax.rsqrt(jnp.linalg.norm(x, ord=1)) *
lax.rsqrt(jnp.linalg.norm(x, ord=jnp.inf))).astype(x.dtype)
l = eps

u = x / alpha
u = x * alpha_inverse

# Iteration tolerances.
tol_l = 10.0 * eps / 2.0
tol_norm = jnp.cbrt(tol_l)

def cond_fun(state):
_, _, _, is_unconverged, is_not_max_iteration = state
return jnp.logical_and(is_unconverged, is_not_max_iteration)

def body_fun(state):
u, l, iter_idx, _, _ = state
def get_qr_params(a, b, c):
e = b / c
a_minus_e = a - e
sqrt_c = c ** (1 / 2)
return (a_minus_e / sqrt_c, sqrt_c, e)

def get_chol_params(a, b, c):
e = b / c
a_minus_e = a - e
return (a_minus_e, c, e)

CHOLESKY_CUTOFF = 100

qr_coefs = []
chol_coefs = []
k = 0
while l + tol_l < 1 and k < max_iterations:
k += 1
l2 = l * l
dd = (4 * (1 / l2 - 1) / l2) ** (1 / 3)
sqd = (1.0 + dd) ** (1 / 2)
a = sqd + (2 - dd + 2 * (2 - l2) / (l2 * sqd)) ** (1 / 2)
b = (a - 1) ** 2 / 4
c = a + b - 1
l = l * (a + b * l2) / (1 + c * l2)
if c > CHOLESKY_CUTOFF:
qr_coefs.append(get_qr_params(a, b, c))
else:
chol_coefs.append(get_chol_params(a, b, c))

def iteration(k, state, update_fn, coefs, test_convergence):
u, _ = state

if coefs is None:
# As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method.
params = get_chol_params(3, 1, 3)
else:
params = lax.dynamic_index_in_dim(coefs, k, keepdims=False)

u_prev = u

# Computes parameters.
l2 = l**2
dd = jnp.cbrt(4.0 * (1.0 / l2 - 1.0) / l2)
sqd = jnp.sqrt(1.0 + dd)
a = (sqd + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l2) / (l2 * sqd)) / 2)
a = jnp.real(a)
b = (a - 1.0)**2 / 4.0
c = a + b - 1.0

# Updates l.
l = l * (a + b * l2) / (1.0 + c * l2)

# Uses QR or Cholesky decomposition.
def true_fn(u):
return _use_qr(u, m, n, params=(a, b, c))

def false_fn(u):
return _use_cholesky(u, m, n, params=(a, b, c))

u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))

u = update_fn(u, m, n, params)
if is_hermitian:
u = (u + u.T.conj()) / 2.0

# Checks convergence.
iterating_l = jnp.abs(1.0 - l) > tol_l
iterating_u = jnp.linalg.norm(u-u_prev) > tol_norm
is_unconverged = jnp.logical_or(iterating_l, iterating_u)

is_not_max_iteration = iter_idx < max_iterations

return u, l, iter_idx + 1, is_unconverged, is_not_max_iteration
is_not_converged = True
if test_convergence:
is_not_converged = jnp.linalg.norm(u - u_prev) > tol_norm
return u, is_not_converged

def iterate(u, coefs, **kwargs):
if not coefs:
return u, True
coefs = jnp.array(coefs).astype(x.dtype)
body = functools.partial(iteration, coefs=coefs, **kwargs)
return lax.fori_loop(0, len(coefs), body, (u, True))

u, _ = iterate(
u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False
)
u, is_not_converged = iterate(
u, coefs=chol_coefs, update_fn=_use_cholesky, test_convergence=True
)

# If l has converged but u still has not, continue with Halley's method
# (coef = None) until convergence.
def cond_fun(state):
k, _, is_not_converged = state
return jnp.logical_and(is_not_converged, k < max_iterations)

iter_idx = 1
is_unconverged = True
is_not_max_iteration = True
u, _, num_iters, is_unconverged, _ = jax.lax.while_loop(
cond_fun=cond_fun, body_fun=body_fun,
init_val=(u, l, iter_idx, is_unconverged, is_not_max_iteration))
def body_fun(state):
k, u, is_not_converged = state
u, is_not_converged = iteration(
k,
(u, is_not_converged),
coefs=None,
update_fn=_use_cholesky,
test_convergence=True,
)
return k + 1, u, is_not_converged

k = len(qr_coefs) + len(chol_coefs)
num_iters, u, is_not_converged = lax.while_loop(
cond_fun, body_fun, (k, u, is_not_converged)
)

# Applies Newton-Schulz refinement for better accuracy.
u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)

h = u.T.conj() @ x
h = (h + h.T.conj()) / 2.0
h = (h + h.T.conj()) / 2

# Converged within the maximum number of iterations.
is_converged = jnp.logical_not(is_unconverged)
is_converged = jnp.logical_not(is_not_converged)

return u, h, num_iters - 1, is_converged
return u, h, num_iters, is_converged


# TODO: Add pivoting.
@functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
dynamic_shape: tuple[int, int] | None = None):
@functools.partial(
jax.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps')
)
def qdwh(
x,
*,
is_hermitian: bool = False,
max_iterations: int | None = None,
eps: float | None = None,
dynamic_shape: tuple[int, int] | None = None,
):
"""QR-based dynamically weighted Halley iteration for polar decomposition.
Args:
Expand All @@ -222,6 +262,10 @@ def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,

if max_iterations is None:
max_iterations = 10
else:
max_iterations = core.concrete_or_error(
int, max_iterations, 'The `max_iterations` argument must be statically '
'specified to use `qdwh` within JAX transformations.')

M, N = x.shape
if M < N:
Expand All @@ -236,5 +280,4 @@ def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
u, h, num_iters, is_converged = _qdwh(x, m, n, is_hermitian, max_iterations,
eps)


return u, h, num_iters, is_converged
2 changes: 1 addition & 1 deletion tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def testEighTinyNorm(self):
w = w.astype(v.dtype)
with jax.numpy_rank_promotion("allow"):
self.assertLessEqual(
np.linalg.norm(np.matmul(a, v) - w * v), 20 * eps * np.linalg.norm(a)
np.linalg.norm(np.matmul(a, v) - w * v), 80 * eps * np.linalg.norm(a)
)

@jtu.sample_product(
Expand Down
34 changes: 19 additions & 15 deletions tests/qdwh_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,27 +166,22 @@ def lsp_linalg_fn(a):
rtol=rtol, atol=1E-3)

@jtu.sample_product(
[dict(m=m, n=n) for m, n in [(10, 10), (8, 8)]],
[dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]],
log_cond=np.linspace(1, 4, 4),
)
def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
"""Tests qdwh with rank-deficient input."""
def testQdwhOnRankDeficientInput(self, m, n, r, log_cond):
"""Tests qdwh on rank-deficient input."""
a = np.triu(np.ones((m, n))).astype(_QDWH_TEST_DTYPE)

# Generates a rank-deficient input.
u, s, v = np.linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.linspace(cond, 1, min(m, n))
s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
a = (u * s) @ v
u, _, vh = np.linalg.svd(a, full_matrices=False)
s = 10**jnp.linspace(log_cond, 0, min(m, n))
s = jnp.expand_dims(s.at[r:].set(0), range(u.ndim - 1))
a = (u * s) @ vh

is_hermitian = _check_symmetry(a)
max_iterations = 15
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
max_iterations=max_iterations)
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=_check_symmetry(a))
_, expected_h = osp_linalg.polar(a)

# For rank-deficient matrix, `u` is not unique.
with self.subTest('Test h.'):
relative_diff_h = _compute_relative_diff(actual_h, expected_h)
np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
Expand All @@ -196,9 +191,18 @@ def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
relative_diff_a = _compute_relative_diff(a_round_trip, a)
np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

# QDWH gives U_p = U Σₖ V* for input A with SVD A = U Σ V*. For full rank
# input, we expect convergence Σₖ → I, giving the correct polar factor
# U_p = U V*. Zero singular values stay at 0 in exact arithmetic, but can
# end up anywhere in [0, 1] as a result of rounding errors---in particular,
# we do not generally expect convergence to 1. As a result, we can only
# expect (U_p V_r) to be orthogonal, where V_r are the columns of V
# corresponding to nonzero singular values.
with self.subTest('Test orthogonality.'):
actual_results = _dot(actual_u.T.conj(), actual_u)
expected_results = np.eye(n)
vr = vh.conj().T[:, :r]
uvr = _dot(actual_u, vr)
actual_results = _dot(uvr.T.conj(), uvr)
expected_results = np.eye(r)
self.assertAllClose(
actual_results, expected_results, rtol=_QDWH_TEST_EPS, atol=1e-6
)
Expand Down

0 comments on commit 9fd5f7c

Please sign in to comment.