Skip to content

Commit

Permalink
Remove use of xla.lower_fun in SVD translation rule.
Browse files Browse the repository at this point in the history
This is the only use of xla.lower_fun that is still needed (as a fallback) when the non-MHLO path is removed.

PiperOrigin-RevId: 441538472
  • Loading branch information
hawkinsp authored and jax authors committed Apr 13, 2022
1 parent eb43071 commit 21f95d5
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,15 +1477,29 @@ def svd_impl(operand, full_matrices, compute_uv):
return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv)


def _zeros_like_xla(c, aval):
zero = xops.Constant(c, np.array(0, aval.dtype))
return xops.Broadcast(zero, aval.shape)

def _eye_like_xla(c, aval):
iota_shape = xla_client.Shape.array_shape(
xla.dtype_to_primitive_type(np.dtype(np.int32)), aval.shape)
x = xops.Eq(xops.Iota(c, iota_shape, len(aval.shape) - 1),
xops.Iota(c, iota_shape, len(aval.shape) - 2))
return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))

def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
compute_uv):
operand_aval, = avals_in
shape = operand_aval.shape
m, n = shape[-2:]
if m == 0 or n == 0:
return xla.lower_fun(_empty_svd, multiple_results=True, new_style=True)(
ctx, avals_in, avals_out, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
out = [_zeros_like_xla(ctx.builder, avals_out[0])]
if compute_uv:
out.append(_eye_like_xla(ctx.builder, avals_out[1]))
out.append(_eye_like_xla(ctx.builder, avals_out[2]))
return out

u, s, v = xops.SVD(operand)
permutation = list(range(len(shape)))
Expand Down

0 comments on commit 21f95d5

Please sign in to comment.