You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I require taking gradients with respect to the square root of a matrix. So, I have turned to jax.scipy.linalg.sqrtm. Based on previous discussions and issues on this repo, I understand that it is only implemented on CPU. I can accept this for now -- I can always do a callback when necessary and just eat the computational overhead.
But, I am getting an error when trying to calculate gradients w.r.t. this operation. For example, see the following minimal example:
import jax.numpy as jnp
from jax.scipy.linalg import sqrtm
from jax import grad
arr = jnp.ones((2, 2))
sqrt_arr = sqrtm(arr)  # This works
grad_sqrt_arr = grad(sqrtm)(arr)  # This does not work
This yields the following error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 935, in sqrtm
return _sqrtm(A)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 915, in _sqrtm
T, Z = schur(A, output='complex')
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 206, in schur
return _schur(a, output)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 199, in _schur
return lax_linalg.schur(a)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: _schur_jvp_rule() got an unexpected keyword argument 'select_callable'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/api.py", line 647, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/api.py", line 723, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2208, in _vjp
out_primal, out_vjp = ad.vjp(
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 935, in sqrtm
return _sqrtm(A)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1465, in _pjit_jvp
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 699, in jvp_jaxpr
return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 709, in _jvp_jaxpr
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 229, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 448, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1465, in _pjit_jvp
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 699, in jvp_jaxpr
return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 709, in _jvp_jaxpr
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 229, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 448, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: _schur_jvp_rule() got an unexpected keyword argument 'select_callable'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/ryan/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 935, in sqrtm
return _sqrtm(A)
TypeError: _schur_jvp_rule() got an unexpected keyword argument 'select_callable'
What jax/jaxlib version are you using?
jax v0.4.13, jaxlib v0.4.13
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.9.7, Ubuntu 18.04.6
NVIDIA GPU info
No response
The text was updated successfully, but these errors were encountered:
Description
I require taking gradients with respect to the square root of a matrix. So, I have turned to jax.scipy.linalg.sqrtm. Based on previous discussions and issues on this repo, I understand that it is only implemented on CPU. I can accept this for now -- I can always do a callback when necessary and just eat the computational overhead. But, I am getting an error when trying to calculate gradients w.r.t. this operation. For example, see the following minimal example:
This yields the following error:
What jax/jaxlib version are you using?
jax v0.4.13, jaxlib v0.4.13
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.9.7, Ubuntu 18.04.6
NVIDIA GPU info
No response
The text was updated successfully, but these errors were encountered: