value_and_grad of pytrees #18173

Answered by jakevdp
smartalecH asked this question in Q&A
Oct 18, 2023 · 2 comments · 2 replies

You could do this directly with jax.jacrev:

```python
import jax
# `simple_loss` and `x` are the function and input from the question.
grad_only = jax.jacrev(simple_loss)
print(grad_only(x))
```
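The question's actual `simple_loss` is not shown in this thread, so here is a minimal sketch with a hypothetical stand-in that maps a pytree (a dict of arrays) to a vector. `jax.value_and_grad` only accepts scalar-output functions. `jax.jacrev` handles vector outputs and returns a Jacobian with the same pytree structure as the input, one entry per leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's `simple_loss`: it takes a pytree
# (a dict of arrays) and returns a vector of shape (3,), so
# jax.value_and_grad would reject it, while jax.jacrev accepts it.
def simple_loss(params):
    return params["w"] * params["b"]

x = {"w": jnp.arange(3.0), "b": jnp.ones(3)}

jac = jax.jacrev(simple_loss)(x)

# The Jacobian mirrors the input pytree: each leaf has shape
# output_shape + leaf_shape, here (3,) + (3,) = (3, 3).
print({name: leaf.shape for name, leaf in jac.items()})
```

Here `jac["w"]` is the identity, because `b` is all ones, and `jac["b"]` is `diag(w)`.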

There is no equivalent built-in function that computes both the value and the Jacobian in a single pass, mainly because in practice it wouldn't save much computation (see the discussion in #762).
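If you need both the value and the Jacobian, the sketch below uses the same hypothetical `simple_loss` and shows two options. The simplest is to call the function and `jax.jacrev` separately. Alternatively, `jax.jacrev` accepts `has_aux=True`, so returning the output twice (once to differentiate, once as auxiliary data) yields `(jacobian, value)` from a single call. This second pattern is a common workaround, not something proposed in this thread.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's `simple_loss` (vector output).
def simple_loss(params):
    return params["w"] * params["b"]

x = {"w": jnp.arange(3.0), "b": jnp.ones(3)}

# Option 1: two separate calls, one for the value and one for the Jacobian.
value = simple_loss(x)
jac = jax.jacrev(simple_loss)(x)

# Option 2: one jacrev call. The first output is differentiated, and the
# second is passed through as auxiliary data. With has_aux=True, jacrev
# returns a (jacobian, aux) pair.
def loss_with_aux(params):
    out = simple_loss(params)
    return out, out

jac2, value2 = jax.jacrev(loss_with_aux, has_aux=True)(x)
print(value2)
```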

Answer selected by smartalecH
Category: Q&A
3 participants