Skip to content

Commit

Permalink
Merge pull request #436 from hawkinsp/master
Browse files Browse the repository at this point in the history
Move masking of lower triangle of cholesky into Python code.
  • Loading branch information
hawkinsp committed Feb 23, 2019
2 parents e2681ab + d6b6514 commit f08c3b7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 50 deletions.
4 changes: 2 additions & 2 deletions jax/lax_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
def cholesky(x, symmetrize_input=True):
if symmetrize_input:
x = symmetrize(x)
return cholesky_p.bind(x)
return np.tril(cholesky_p.bind(x))

def eigh(x, lower=True, symmetrize_input=True):
if symmetrize_input:
Expand Down Expand Up @@ -80,7 +80,7 @@ def symmetrize(x): return (x + _H(x)) / 2
def cholesky_jvp_rule(primals, tangents):
x, = primals
sigma_dot, = tangents
L = cholesky_p.bind(x)
L = np.tril(cholesky_p.bind(x))

# Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
phi = lambda X: np.tril(X) / (1 + np.eye(X.shape[-1], dtype=X.dtype))
Expand Down
48 changes: 0 additions & 48 deletions jaxlib/lapack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -333,18 +333,6 @@ cdef void lapack_spotrf(void* out_tuple, void** data) nogil:

spotrf(&uplo, &n, a_out, &n, info)

# spotrf leaves junk in the part of the triangle that is not written; zero it.
cdef int i
cdef int j
if lower:
for i in range(n):
for j in range(i):
a_out[i * n + j] = 0
else:
for i in range(n):
for j in range(i, n):
a_out[i * n + j] = 0

register_cpu_custom_call_target(b"lapack_spotrf", <void*>(lapack_spotrf))


Expand All @@ -362,18 +350,6 @@ cdef void lapack_dpotrf(void* out_tuple, void** data) nogil:

dpotrf(&uplo, &n, a_out, &n, info)

# dpotrf leaves junk in the part of the triangle that is not written; zero it.
cdef int i
cdef int j
if lower:
for i in range(n):
for j in range(i):
a_out[i * n + j] = 0
else:
for i in range(n):
for j in range(i, n):
a_out[i * n + j] = 0

register_cpu_custom_call_target(b"lapack_dpotrf", <void*>(lapack_dpotrf))


Expand All @@ -391,18 +367,6 @@ cdef void lapack_cpotrf(void* out_tuple, void** data) nogil:

cpotrf(&uplo, &n, a_out, &n, info)

# cpotrf leaves junk in the part of the triangle that is not written; zero it.
cdef int i
cdef int j
if lower:
for i in range(n):
for j in range(i):
a_out[i * n + j] = 0
else:
for i in range(n):
for j in range(i, n):
a_out[i * n + j] = 0

register_cpu_custom_call_target(b"lapack_cpotrf", <void*>(lapack_cpotrf))

cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:
Expand All @@ -419,18 +383,6 @@ cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:

zpotrf(&uplo, &n, a_out, &n, info)

# zpotrf leaves junk in the part of the triangle that is not written; zero it.
cdef int i
cdef int j
if lower:
for i in range(n):
for j in range(i):
a_out[i * n + j] = 0
else:
for i in range(n):
for j in range(i, n):
a_out[i * n + j] = 0

register_cpu_custom_call_target(b"lapack_zpotrf", <void*>(lapack_zpotrf))

def jax_potrf(c, a, lower=False):
Expand Down

0 comments on commit f08c3b7

Please sign in to comment.