Skip to content

Commit

Permalink
adds jax.scipy.schur
Browse files Browse the repository at this point in the history
  • Loading branch information
SaturdayGenfo committed Feb 16, 2022
1 parent 2ae10ea commit 514d888
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
13 changes: 12 additions & 1 deletion jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,18 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
del overwrite_a, overwrite_b, turbo, check_finite
return _eigh(a, b, lower, eigvals_only, eigvals, type)


@partial(jit, static_argnames=('output',))
def _schur(a, output):
if output == "complex":
a = a.astype(jnp.result_type(a.dtype, 0j))
return lax_linalg.schur(a)

@_wraps(scipy.linalg.schur)
def schur(a, output='real'):
if output not in ('real', 'complex'):
raise ValueError(
"Expected 'output' to be either 'real' or 'complex', got output={}.".format(output))
return _schur(a, output)

@_wraps(scipy.linalg.inv)
def inv(a, overwrite_a=False, check_finite=True):
Expand Down
1 change: 1 addition & 0 deletions jax/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
cho_solve as cho_solve,
det as det,
eigh as eigh,
schur as schur,
eigh_tridiagonal as eigh_tridiagonal,
expm as expm,
expm_frechet as expm_frechet,
Expand Down
14 changes: 14 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,20 @@ def expm(x):
return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16)
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
rtol=tol)
@parameterized.named_parameters(
jtu.cases_from_list({
"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype
} for shape in [(4, 4), (15, 15), (50, 50), (100, 100)]
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSchur(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker)
self._CompileAndCheck(jsp.linalg.schur, args_maker)

@jtu.with_config(jax_numpy_rank_promotion='raise')
class LaxLinalgTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 514d888

Please sign in to comment.