Consider using autograd.make_vjp() instead of ceviche.jacobian() #1

Labels: enhancement (New feature or request)
Hey there, cool package! I wanted to suggest an alternative implementation that may be more efficient.
It looks like `javiche` is currently using `ceviche`'s `jacobian()` method, which (I think) may be less efficient if `@javiche.jaxit` is ever applied to a function whose output is not scalar-valued. The reason is that `ceviche`'s `jacobian()` method loops over the output basis vectors to construct the Jacobian, and you typically don't need the explicit Jacobian when calculating VJPs. An example of where this would matter is if you applied `@jaxit` to a function that returns the field distribution, but performed the loss function calculation in terms of the field distribution in JAX, calling `jax.grad` or `jax.value_and_grad` on the combination.

Below is a sketch of a more direct approach that maps autograd's `make_vjp()` function to JAX's VJP mechanism. Given an autograd function `f_ag(*args) -> np.ndarray`, we can wrap it into a function `f(*args) -> jnp.ndarray`. This is not ceviche-specific; it can be used for any autograd function with multiple inputs (`*args`) and a single array output, though it could be generalized to support multiple array outputs as well.

I was thinking about adding something like this to ceviche_challenges, but had not gotten around to it.