Skip to content

Commit

Permalink
[JAX] Add an option subset_by_index that allows computing a contiguou…
Browse files Browse the repository at this point in the history
…s subset of singular components from svd.

PiperOrigin-RevId: 607493941
  • Loading branch information
jax authors committed Feb 16, 2024
1 parent 0203d15 commit 7e7094c
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 80 deletions.
25 changes: 17 additions & 8 deletions jax/_src/internal_test_util/test_harnesses.py
Expand Up @@ -1800,25 +1800,35 @@ def _fft_rng_factory(dtype):
for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]:
for full_matrices in [False, True]:
for compute_uv in [False, True]:
subset_by_index = None
define(
lax.linalg.svd_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}",
lambda *args: lax.linalg.svd_p.bind(
args[0], full_matrices=args[1], compute_uv=args[2]), [
RandArg(shape, dtype),
StaticArg(full_matrices),
StaticArg(compute_uv)
],
args[0],
full_matrices=args[1],
compute_uv=args[2],
subset_by_index=args[3],
),
[
RandArg(shape, dtype),
StaticArg(full_matrices),
StaticArg(compute_uv),
StaticArg(subset_by_index),
],
jax_unimplemented=[
Limitation(
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16]),
dtypes=[np.float16, dtypes.bfloat16],
),
],
shape=shape,
dtype=dtype,
full_matrices=full_matrices,
compute_uv=compute_uv)
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)

for dtype in jtu.dtypes.all_inexact:
for shape in [(0, 0), (5, 5), (2, 6, 6)]:
Expand Down Expand Up @@ -2666,7 +2676,6 @@ def _make_reducer_harness(prim,
dtype=dtype)



def wrap_and_split():
key = jax.random.key(42)
result = jax.random.split(key, 2)
Expand Down
154 changes: 124 additions & 30 deletions jax/_src/lax/linalg.py
Expand Up @@ -298,32 +298,69 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
q, r = qr_p.bind(x, full_matrices=full_matrices)
return q, r


@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: Literal[True],
subset_by_index: tuple[int, int] | None = None,
) -> tuple[Array, Array, Array]:
...


@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: Literal[False],
subset_by_index: tuple[int, int] | None = None,
) -> Array:
...


@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: ...
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
) -> Array | tuple[Array, Array, Array]:
...


# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
@_warn_on_positional_kwargs
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]:
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
) -> Array | tuple[Array, Array, Array]:
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
containing the left singular vectors, the singular values and the adjoint of
the right singular vectors.
"""
result = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
result = svd_p.bind(
x,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
if compute_uv:
s, u, v = result
return u, s, v
else:
s, = result
return s


@_warn_on_positional_kwargs
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
left_side: bool = False, lower: bool = False,
Expand Down Expand Up @@ -1043,7 +1080,6 @@ def _triangular_solve_cpu_lower(
# Support operation for LU decomposition: Transformation of the pivots returned
# by LU decomposition into permutations.


# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
permutation, swaps = permutation_and_swaps
Expand Down Expand Up @@ -1138,7 +1174,6 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
gpu_linalg.hip_lu_pivots_to_permutation),
platform='rocm')


# LU decomposition

# Computes a pivoted LU decomposition such that
Expand Down Expand Up @@ -1745,35 +1780,50 @@ def _qr_lowering(a, *, full_matrices):


# Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None):
return dispatch.apply_primitive(
svd_p,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)

def _svd_impl(operand, *, full_matrices, compute_uv):
return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv)

def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to singular value decomposition must have ndims >= 2")

batch_dims = operand.shape[:-2]
m = operand.shape[-2]
n = operand.shape[-1]
s = operand.update(shape=batch_dims + (min(m, n),),
dtype=lax_internal._complex_basetype(operand.dtype))
rank = min(m, n)
if subset_by_index is not None:
if full_matrices and subset_by_index != (0, rank):
raise ValueError("full_matrices and subset_by_index cannot both be set")
rank = min(rank, subset_by_index[1] - subset_by_index[0])

s = operand.update(
shape=batch_dims + (rank,),
dtype=lax_internal._complex_basetype(operand.dtype),
)
if compute_uv:
u = operand.update(shape=batch_dims + (m, m if full_matrices else min(m, n)))
vt = operand.update(shape=batch_dims + (n if full_matrices else min(m, n), n))
u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
return s, u, vt
else:
return s,
else:
raise NotImplementedError


@jax.default_matmul_precision("float32")
def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
def _svd_jvp_rule(
primals, tangents, *, full_matrices, compute_uv, subset_by_index
):
A, = primals
dA, = tangents
s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)
s, U, Vt = svd_p.bind(
A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index
)

if compute_uv and full_matrices:
# TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
Expand Down Expand Up @@ -1812,6 +1862,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):

return (s, U, Vt), (ds, dU, _H(dV))


def _empty_svd(a, *, full_matrices, compute_uv):
batch_shape = a.shape[:-2]
m, n = a.shape[-2:]
Expand All @@ -1828,8 +1879,17 @@ def _empty_svd(a, *, full_matrices, compute_uv):
u, v = v, u
return s, u, v

def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
compute_uv, platform: str):

def _svd_cpu_gpu_lowering(
gesvd_impl,
ctx,
operand,
*,
full_matrices,
compute_uv,
subset_by_index,
platform: str,
):
operand_aval, = ctx.avals_in
s_aval = ctx.avals_out[0]
m, n = operand_aval.shape[-2:]
Expand All @@ -1841,9 +1901,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
f"implemented only for the batch dimensions: {operand_aval.shape}")
batch_dims = operand_aval.shape[:-2]

if not (subset_by_index is None or subset_by_index == (0, min(m, n))):
raise NotImplementedError("subset_by_index not implemented for CPU and GPU")

if m == 0 or n == 0:
return mlir.lower_fun(_empty_svd, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
)

if platform in ["cuda", "rocm"]:
if not is_constant_shape(operand_aval.shape):
Expand Down Expand Up @@ -1891,10 +1958,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,

return result

def _svd_tpu(a, *, full_matrices, compute_uv):

def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):
batch_dims = a.shape[:-2]

fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
fn = partial(
lax_svd.svd,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
for _ in range(len(batch_dims)):
fn = api.vmap(fn)

Expand All @@ -1905,28 +1978,49 @@ def _svd_tpu(a, *, full_matrices, compute_uv):
s = fn(a)
return [s]

def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):

def _svd_tpu_lowering_rule(
ctx, operand, *, full_matrices, compute_uv, subset_by_index
):
operand_aval, = ctx.avals_in
m, n = operand_aval.shape[-2:]

if m == 0 or n == 0:
return mlir.lower_fun(_empty_svd, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
)

return mlir.lower_fun(_svd_tpu, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)


def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
def _svd_batching_rule(
batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index
):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
outs = svd_p.bind(
x,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)

if compute_uv:
return outs, (0, 0, 0)
else:
return outs, (0,)


svd_p = Primitive('svd')
svd_p.multiple_results = True
svd_p.def_impl(_svd_impl)
Expand Down

0 comments on commit 7e7094c

Please sign in to comment.