
Difference between VJP and JVP functions in JAX #10271

Answered by mattjj
GoktugGuvercin asked this question in Q&A

Have you already checked out the Autodiff cookbook?

I think there are two parts to this question:

  1. why does jax.vjp return a function that performs the multiplication in reverse order?
  2. why the asymmetry between the signatures of jax.jvp and jax.vjp, where in particular the latter returns a function and the former does not?
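Here is a minimal sketch of the two call signatures side by side. The function `f` and the vectors `x`, `v` and `u` are hypothetical and chosen only for illustration. `jax.jacfwd` is used solely to build a dense Jacobian for checking.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function R^3 -> R^2, for illustration only.
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([1.0, 2.0, 3.0])

# jax.jvp takes the primals AND the tangents together. It evaluates f(x)
# and J(x) @ v in one forward pass and returns (primal_out, tangent_out).
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))

# jax.vjp takes only the primals. It runs f forward, keeps what the backward
# pass needs, and returns (primal_out, vjp_fun). vjp_fun maps a cotangent u
# (same shape as f(x)) to a tuple with one entry per primal argument.
y2, vjp_fun = jax.vjp(f, x)
u = jnp.array([1.0, 0.0])
(uj,) = vjp_fun(u)  # u @ J(x)

# Dense Jacobian of shape (2, 3), used only to check the results above.
J = jax.jacfwd(f)(x)
```

Because `vjp_fun` closes over the forward pass, it can be called repeatedly with different cotangents without re-evaluating `f`. For example, calling it on each standard basis vector of the output space recovers `J` one row at a time.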

For the first part: that's just the mathematical definition of what a VJP computes, along with how the chain rule works (i.e. the rule for how differentiation interacts with function composition). If you have a function f = g ∘ h, where ∘ means composition, i.e. f(x) = g(h(x)), then ∂f(x) = ∂g(h(x)) ∘ ∂h(x), where the ∘ is composing linear maps or, if you like, multip…
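The following minimal sketch makes the ordering concrete. The functions `h` and `g` and the input values are hypothetical and serve only as illustration. The sketch builds the dense Jacobians with `jax.jacfwd`, multiplies them in the two orders by hand, and checks the products against `jax.jvp` and `jax.vjp`. Neither `jax.jvp` nor `jax.vjp` actually materializes these matrices. They appear here only to make the chain rule visible.

```python
import jax
import jax.numpy as jnp

# Hypothetical pieces of a composition f = g ∘ h, for illustration only.
def h(x):  # R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.cos(x[2])])

def g(z):  # R^2 -> R^2
    return jnp.stack([z[0] ** 2 + z[1], jnp.exp(z[1])])

def f(x):
    return g(h(x))

x = jnp.array([0.5, -1.0, 2.0])
Dh = jax.jacfwd(h)(x)      # ∂h(x), shape (2, 3)
Dg = jax.jacfwd(g)(h(x))   # ∂g(h(x)), shape (2, 2)
Df = jax.jacfwd(f)(x)      # ∂f(x), shape (2, 3)

# Chain rule: ∂f(x) = ∂g(h(x)) ∘ ∂h(x), which is a matrix product here.
# A JVP pushes a tangent v forward: ∂h is applied first, then ∂g,
# in the same order in which f evaluates h and g.
v = jnp.array([1.0, 2.0, 3.0])
jvp_manual = Dg @ (Dh @ v)

# A VJP pulls a cotangent u back: u ∂f = (u ∂g) ∂h, so ∂g is applied
# first and ∂h last, the reverse of the evaluation order.
u = jnp.array([1.0, -1.0])
vjp_manual = (u @ Dg) @ Dh

# The same quantities computed by JAX directly.
_, jvp_jax = jax.jvp(f, (x,), (v,))
_, f_vjp = jax.vjp(f, x)
(vjp_jax,) = f_vjp(u)
```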

Replies: 2 comments 10 replies

3 replies
@GoktugGuvercin
@GoktugGuvercin
@YouJiacheng

7 replies
@mattjj
@mattjj
@AntixK
@frhack
@mattjj

Answer selected by GoktugGuvercin
Category: Q&A
Labels: useful read (PR or issue that contains useful design discussion)
5 participants