value_and_grad of pytrees #18173
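Typically, we take the value_and_grad of a scalar-valued loss. Here's a simple example which complains about the non-scalar (dict) output:

import jax
import jax.numpy as jnp

def simple_loss(x):
    # Returns a dict (pytree) of scalar losses rather than a single scalar.
    return {
        "a": jnp.sum(x),
        "b": jnp.sum(x ** 2),
        "c": jnp.sum(x ** 3),
    }

x = 2 * jnp.ones((5,))
val_and_grad = jax.value_and_grad(simple_loss)
print(val_and_grad(x))  # raises TypeError: output is not a scalar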
You could do this directly with jax.jacrev:

grad_only = jax.jacrev(simple_loss)
print(grad_only(x))

There is no equivalent function that computes the value and Jacobian in a single pass, primarily because it wouldn't save much computation in practice (see the discussion in #762).
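As a quick sanity check of what jax.jacrev returns for the dict-valued function from the question, here is a minimal sketch (it redefines simple_loss so that it runs on its own). The result mirrors the output pytree, with one gradient array per key. The values in the comments follow from the analytic derivatives 1, 2x, and 3x² evaluated at x = 2.

```python
import jax
import jax.numpy as jnp

def simple_loss(x):
    # Same dict-valued function as in the question.
    return {
        "a": jnp.sum(x),
        "b": jnp.sum(x ** 2),
        "c": jnp.sum(x ** 3),
    }

x = 2 * jnp.ones((5,))
jac = jax.jacrev(simple_loss)(x)

# One entry per output key, each with the same shape as x:
#   "a" -> all 1s, "b" -> all 4s (2x), "c" -> all 12s (3x^2)
for name, g in jac.items():
    print(name, g.shape, g)
```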
Here you could also do the following:

def simple_loss_with_aux(x):
    # Return the output twice: once to differentiate, once as auxiliary data.
    y = simple_loss(x)
    return y, y

jac, value = jax.jacrev(simple_loss_with_aux, has_aux=True)(x)
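The has_aux trick can be wrapped into a reusable helper. The sketch below is only an illustration: value_and_jacrev is a hypothetical name, not a JAX API, and it assumes the wrapped function takes positional arguments only. It returns (value, jacobian) in the same order as jax.value_and_grad.

```python
import jax
import jax.numpy as jnp

def value_and_jacrev(fun, argnums=0):
    """Hypothetical helper (not part of JAX): build a function returning (value, jacobian)."""
    def fun_with_aux(*args):
        out = fun(*args)
        # The second copy is passed through untouched as auxiliary data.
        return out, out

    jac_fn = jax.jacrev(fun_with_aux, argnums=argnums, has_aux=True)

    def wrapped(*args):
        # With has_aux=True, jacrev returns (jacobian, aux).
        jac, value = jac_fn(*args)
        return value, jac

    return wrapped

def simple_loss(x):
    # Same dict-valued function as in the question.
    return {
        "a": jnp.sum(x),
        "b": jnp.sum(x ** 2),
        "c": jnp.sum(x ** 3),
    }

x = 2 * jnp.ones((5,))
value, jac = value_and_jacrev(simple_loss)(x)
print(value)  # dict of scalar losses
print(jac)    # dict of gradients, one array of shape (5,) per key
```

Both outputs keep the dict structure of simple_loss, so downstream code can look up a loss and its gradient by the same key.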