Skip to content

Commit

Permalink
Add jax.scipy.linalg.funm
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcr committed May 2, 2022
1 parent 44006c7 commit 372371c
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jax.scipy.linalg
eigh_tridiagonal
expm
expm_frechet
funm
inv
lu
lu_factor
Expand Down
77 changes: 77 additions & 0 deletions jax/_src/third_party/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import scipy.linalg

from jax import jit, lax
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.linalg import norm
from jax._src.numpy.util import _wraps
from jax._src.scipy.linalg import rsf2csf, schur

@jit
def _algorithm_11_1_1(F, T):
# Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations"
N = T.shape[0]
minden = jnp.abs(T[0, 0])

def _outer_loop(p, F_minden):
_, F, minden = lax.fori_loop(1, N-p+1, _inner_loop, (p, *F_minden))
return F, minden

def _inner_loop(i, p_F_minden):
p, F, minden = p_F_minden
j = i+p
s = T[i-1, j-1] * (F[j-1, j-1] - F[i-1, i-1])
T_row, T_col = T[i-1], T[:, j-1]
F_row, F_col = F[i-1], F[:, j-1]
ind = (jnp.arange(N) >= i) & (jnp.arange(N) < j-1)
val = (jnp.where(ind, T_row, 0) @ jnp.where(ind, F_col, 0) -
jnp.where(ind, F_row, 0) @ jnp.where(ind, T_col, 0))
s = s + val
den = T[j-1, j-1] - T[i-1, i-1]
s = jnp.where(den != 0, s / den, s)
F = F.at[i-1, j-1].set(s)
minden = jnp.minimum(minden, jnp.abs(den))
return p, F, minden

return lax.fori_loop(1, N, _outer_loop, (F, minden))

_FUNM_LAX_DESCRIPTION = """\
The array returned by :py:func:`jax.scipy.linalg.funm` may differ in dtype
from the array returned by py:func:`scipy.linalg.funm`. Specifically, in cases
where all imaginary parts of the array values are close to zero, the SciPy
function may return a real-valued array, whereas the JAX implementation will
return a complex-valued array.
Additionally, unlike the SciPy implementation, when ``disp=True`` no warning
will be printed if the error in the array output is estimated to be large.
"""

@_wraps(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
def funm(A, func, disp=True):
A = jnp.asarray(A)
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError('expected square array_like input')

T, Z = schur(A)
T, Z = rsf2csf(T, Z)

F = jnp.diag(func(jnp.diag(T)))
F = F.astype(T.dtype.char)

F, minden = _algorithm_11_1_1(F, T)
F = Z @ F @ Z.conj().T

if disp:
return F

if F.dtype.char.lower() == 'e':
tol = jnp.finfo(jnp.float16).eps
if F.dtype.char.lower() == 'f':
tol = jnp.finfo(jnp.float32).eps
else:
tol = jnp.finfo(jnp.float64).eps

minden = jnp.where(minden == 0.0, tol, minden)
err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, jnp.minimum(1, jnp.maximum(
tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

return F, err
4 changes: 4 additions & 0 deletions jax/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@
from jax._src.lax.polar import (
polar_unitary as polar_unitary,
)

from jax._src.third_party.scipy.linalg import (
funm as funm,
)
22 changes: 22 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,28 @@ def testRsf2csf(self, shape, dtype):
args_maker, tol=tol)
self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker)

@parameterized.named_parameters(
jtu.cases_from_list({
"testcase_name":
"_shape={}_disp={}".format(jtu.format_shape_dtype_string(shape, dtype), disp),
"shape": shape, "dtype": dtype, "disp": disp
} for shape in [(1, 1), (5, 5), (20, 20), (50, 50)]
for dtype in float_types + complex_types
for disp in [True, False]))
# funm uses jax.scipy.linalg.schur which is implemented for a CPU
# backend only, so tests on GPU and TPU backends are skipped here
@jtu.skip_on_devices("gpu", "tpu")
def testFunm(self, shape, dtype, disp):
def func(x):
return x**-2.718
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp)
scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp)
self._CheckAgainstNumpy(jnp_fun, scp_fun, args_maker, check_dtypes=False,
tol={np.complex64: 1e-5, np.complex128: 1e-6})
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(
jtu.cases_from_list({
"testcase_name":
Expand Down

0 comments on commit 372371c

Please sign in to comment.