Gradient of sqrtm #16579

Closed
rkruegs123 opened this issue Jun 28, 2023 · 1 comment · Fixed by #16600
@rkruegs123

Description

I need to take gradients through the square root of a matrix, so I have turned to jax.scipy.linalg.sqrtm. Based on previous discussions and issues on this repo, I understand that it is only implemented on CPU. I can accept this for now: I can always do a callback when necessary and just eat the computational overhead.

However, I get an error when trying to compute gradients through this operation. The following minimal example reproduces it:

import jax.numpy as jnp
from jax.scipy.linalg import sqrtm
from jax import grad

arr = jnp.ones((2, 2))
sqrt_arr = sqrtm(arr) # This works
grad_sqrt_arr = grad(sqrtm)(arr) # This does not work

This yields the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 935, in sqrtm
    return _sqrtm(A)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 915, in _sqrtm
    T, Z = schur(A, output='complex')
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 206, in schur
    return _schur(a, output)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 199, in _schur
    return lax_linalg.schur(a)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: _schur_jvp_rule() got an unexpected keyword argument 'select_callable'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/api.py", line 647, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/api.py", line 723, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2208, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 935, in sqrtm
    return _sqrtm(A)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1465, in _pjit_jvp
    jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 699, in jvp_jaxpr
    return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 709, in _jvp_jaxpr
    jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 229, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 448, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1465, in _pjit_jvp
    jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 699, in jvp_jaxpr
    return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 709, in _jvp_jaxpr
    jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 229, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 448, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: _schur_jvp_rule() got an unexpected keyword argument 'select_callable'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 935, in sqrtm
    return _sqrtm(A)
TypeError: _schur_jvp_rule() got an unexpected keyword argument 'select_callable'

What jax/jaxlib version are you using?

jax v0.4.13, jaxlib v0.4.13

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.9.7, Ubuntu 18.04.6

NVIDIA GPU info

No response

@rkruegs123 added the bug (Something isn't working) label on Jun 28, 2023
@jakevdp
Collaborator

jakevdp commented Jun 28, 2023

Thanks for the report. There is indeed a missing argument in the schur JVP implementation, but unfortunately all it's preventing is a NotImplementedError: https://github.com/google/jax/blob/f463437c7ee81f915d2404d302b6bc5b32ecffbe/jax/_src/lax/linalg.py#L2113-L2116
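
To make the point above concrete, here is a plain-Python sketch of the failure mode. The traceback shows the rule being invoked as `jvp(primals_in, tangents_in, **params)`, so a rule whose signature lacks one of the primitive's parameters raises a `TypeError` before its body runs. The function names and the `compute_schur_vectors` parameter below are illustrative assumptions, not JAX's actual source; only `select_callable` comes from the error message.

```python
# Hypothetical stand-in for a JVP rule whose signature lags behind the
# primitive's parameters. Names are illustrative, not JAX's actual code.
def jvp_rule_sketch(primals, tangents, *, compute_schur_vectors):
    raise NotImplementedError("Schur differentiation rule is not implemented.")


# Same rule with the missing keyword added (the signature-only fix).
def jvp_rule_fixed_sketch(primals, tangents, *, compute_schur_vectors,
                          select_callable):
    raise NotImplementedError("Schur differentiation rule is not implemented.")


# Assumed parameter set; 'select_callable' is the keyword named in the error.
params = {"compute_schur_vectors": True, "select_callable": None}


def call_rule(rule, params):
    """Invoke the rule as rule(primals, tangents, **params); report the error type."""
    try:
        rule((None,), (None,), **params)
    except (TypeError, NotImplementedError) as err:
        return type(err).__name__
    return "ok"


before_fix = call_rule(jvp_rule_sketch, params)
after_fix = call_rule(jvp_rule_fixed_sketch, params)
print(before_fix, after_fix)
```

In this sketch, adding the missing keyword only changes which exception is raised: the call gets past the signature check and then reaches the `NotImplementedError` in the body.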

I don't think this is trivial to implement, unfortunately; see some related discussion at #669.
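
For readers who only need the square root of a symmetric positive-definite matrix, one possible workaround (not something proposed in this thread) is to build the square root from `jnp.linalg.eigh`, which JAX can differentiate. The sketch below rests on two assumptions: the input is symmetric positive definite, and its eigenvalues are distinct, since the `eigh` derivative is ill-defined for repeated eigenvalues. The helper names `sqrtm_spd` and `trace_sqrt` are hypothetical. Note also that `jax.grad` requires a scalar-valued function, so the example differentiates the trace of the square root instead of the matrix itself.

```python
import jax
import jax.numpy as jnp


def sqrtm_spd(a):
    """Square root of a symmetric positive-definite matrix via eigh.

    Hypothetical helper: assumes `a` is SPD with distinct eigenvalues.
    """
    w, v = jnp.linalg.eigh(a)
    # Scale each eigenvector column by sqrt(eigenvalue), then rotate back.
    return (v * jnp.sqrt(w)) @ v.T


def trace_sqrt(a):
    """Scalar-valued function of the matrix square root, as jax.grad requires."""
    return jnp.trace(sqrtm_spd(a))


# SPD test matrix with distinct eigenvalues.
a = jnp.array([[2.0, 1.0], [1.0, 3.0]])

g = jax.grad(trace_sqrt)(a)

# Analytic reference for SPD input: d tr(A^{1/2}) / dA = 0.5 * A^{-1/2}.
expected = 0.5 * jnp.linalg.inv(sqrtm_spd(a))
```

For a matrix-valued derivative, `jax.jacfwd(sqrtm_spd)(a)` or `jax.jvp` can be used in place of `jax.grad`. This does not cover general non-symmetric matrices, which is the case the Schur-based `sqrtm` handles.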
