Skip to content

Commit

Permalink
remove _const from public jax.lax module
Browse files Browse the repository at this point in the history
Modify all internal call sites to use `jax._src.lax.lax._const`.
  • Loading branch information
froystig committed Mar 7, 2022
1 parent 03a50c0 commit f7731bf
Show file tree
Hide file tree
Showing 23 changed files with 171 additions and 138 deletions.
10 changes: 6 additions & 4 deletions jax/_src/lax/linalg.py
Expand Up @@ -323,8 +323,9 @@ def cholesky_jvp_rule(primals, tangents):
# Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
def phi(X):
l = jnp.tril(X)
return l / lax.expand_dims(lax._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype),
range(l.ndim - 2))
return l / lax.expand_dims(
lax_internal._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype),
range(l.ndim - 2))

tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
conjugate_a=True, lower=True)
Expand Down Expand Up @@ -991,15 +992,16 @@ def _lu_jvp_rule(primals, tangents):
ndims = len(a_shape)
l_padding = [(0, 0, 0)] * ndims
l_padding[-1] = (0, m - k, 0)
zero = lax._const(lu, 0)
zero = lax_internal._const(lu, 0)
l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
l = l + lax.expand_dims(jnp.eye(m, m, dtype=dtype), range(l.ndim - 2))

u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
((k, 0, 0), (k, 0, 0)))
u_padding = [(0, 0, 0)] * ndims
u_padding[-2] = (0, n - k, 0)
u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + lax.expand_dims(u_eye, range(lu.ndim - 2))
u = (lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) +
lax.expand_dims(u_eye, range(lu.ndim - 2)))

la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
unit_diagonal=True)
Expand Down

0 comments on commit f7731bf

Please sign in to comment.