`lax.cond` sometimes inserts a nonlinear `lax.stop_gradient` into its JVP rule. #22011
Labels: bug
Description
I finally have a MWE for an intermittent issue I've been seeing for months!
First of all, the root cause: the `lax.stop_gradient` on this line (jax/jax/_src/lax/control_flow/conditionals.py, line 388 in 84d748f) is being applied unconditionally to all operands. However, if we've been linearised first, then some of those operands may be tangents, for which only linear operations are valid! Note that JAX treats `lax.stop_gradient` as nonlinear, and does not offer a transpose rule for it. Thus the following MWE:
crashes at trace-time with:
Huge thanks to @dkweiss31 over in patrick-kidger/diffrax#387 for having enough of a MWE that I was able to isolate it down to this.
Tagging @mattjj as my best guess for the person most likely to know about this kind of composition-of-transforms stuff.
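For context on the root-cause paragraph, here is a minimal sketch of how `lax.stop_gradient` and `lax.cond` behave under the transforms involved. It is not the MWE from this issue (which isn't reproduced here) and assumes a recent JAX release; the function `f` is a hypothetical example. It shows that `stop_gradient` passes its primal through while contributing a zero tangent, and that the ordinary linearize-then-transpose path through `lax.cond` succeeds when every operation applied to the tangent has a transpose rule, which is exactly the property a stray `stop_gradient` on a tangent operand breaks.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Forward mode: stop_gradient returns the primal unchanged, but its
# output tangent is zero whatever tangent goes in.
primal_out, tangent_out = jax.jvp(lax.stop_gradient, (2.0,), (1.0,))
print("jvp of stop_gradient:", primal_out, tangent_out)

# Reverse mode: only the un-stopped factor contributes, so
# d/dx [stop_gradient(x) * x] evaluates to x.
grad_val = jax.grad(lambda x: lax.stop_gradient(x) * x)(3.0)
print("grad:", grad_val)


# Hypothetical example function: branch on the sign of x.
def f(x):
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


# Linearize at x = 1.0, then transpose the resulting linear map.
# Transposition needs a transpose rule for every op applied to the
# tangent input; here those ops are all linear, so this succeeds.
y, f_lin = jax.linearize(f, 1.0)
(ct,) = jax.linear_transpose(f_lin, 1.0)(1.0)
print("f(1.0) =", y, "transposed cotangent =", ct)
```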
System info (python version, jaxlib version, accelerator, etc.)