Difference between VJP and JVP functions in JAX #10271
-
Hello JAX Community; I am confused about the VJP and JVP functions introduced by JAX. While …
Replies: 2 comments 10 replies
-
Because …
-
Have you already checked out the Autodiff cookbook? I think there are two parts to this question:

1. Why does `jax.vjp` return a function that performs the multiplication in reverse order?
2. What's the asymmetry between `jax.jvp` and `jax.vjp`, where in particular the latter returns a function and the former does not?

For the first part: that's just the mathematical definition of what a VJP computes, along with how the chain rule works (i.e. the rule for how differentiation interacts with function composition). If you have a function f = g ∘ h, where ∘ means composition, i.e. f(x) = g(h(x)), then ∂f(x) = ∂g(h(x)) ∘ ∂h(x), where the ∘ is composing linear maps or, if you like, multiplying matrices. So if we want to compute a vector-Jacobian product v ↦ v' ∂f(x), then we can decompose it in terms of the constituent functions as ((v' ∂g(h(x))) ∂h(x)), which looks like evaluating a sequence of two simpler VJPs in the reverse order from how we would evaluate f(x) for some x.

As for the asymmetry between the two APIs: in a bit more detail, the "dual" to `jax.vjp` is really `jax.linearize`, not `jax.jvp`. Compare the type signatures (writing `T a` for tangent types, `CT a` for cotangent types, and `-o` for a linear function):

```
jvp       : (a -> b) -> (a, T a) -> (b, T b)
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp       : (a -> b) -> a -> (b, CT b -o CT a)
```

Even if type signatures aren't your thing, you can see that the shapes of `linearize` and `vjp` line up: both take a point and return a primal output paired with a linear function, whereas `jvp` consumes its tangent input immediately and returns no function at all.

The relationship between `jax.jvp` and `jax.linearize` is:

```python
y, y_dot = jax.jvp(f, (x,), (x_dot,))

y2, f_lin = jax.linearize(f, x)
y_dot2 = f_lin(x_dot)

assert jnp.all(y == y2)  # or at least all close, up to floating point differences
assert jnp.all(y_dot == y_dot2)
```

In other words, we've just computed the same thing in two ways. But the `linearize` version hands back the tangent map as a function, so we can apply it to many different tangent vectors without redoing the work that depends only on x.

The relationship with `jax.vjp` is:

```python
y3, f_vjp = jax.vjp(f, x)
x_bar, = f_vjp(y_bar)

assert jnp.all(y == y3)
assert jnp.vdot(x_bar, x_dot) == jnp.vdot(y_bar, y_dot)
```

In math notation, the last line is saying we've computed a quantity like u' ∂f(x) v in two ways, once like (u' ∂f(x)) v and the other time like u' (∂f(x) v).

Maybe some last clues are to look at how each of these would look for a primitive operation like `sin` (in pseudocode):

```
def jvp(sin)(x, xdot):
  y = sin(x)
  ydot = cos(x) * xdot
  return y, ydot

def lin(sin)(x):
  y = sin(x)
  cos_x = cos(x)
  sin_lin = lambda xdot: cos_x * xdot
  return y, sin_lin

def vjp(sin)(x):
  y = sin(x)
  cos_x = cos(x)
  sin_vjp = lambda ybar: ybar * cos_x
  return y, sin_vjp
```

From this you might be able to see how `lin` and `vjp` both do the primal computation up front, save the residual `cos_x`, and return a closure over it, while `jvp` consumes its tangent input right away.

And for generic function composition:

```
def jvp(g ∘ h)(x, xdot):
  y, ydot = jvp(h)(x, xdot)
  z, zdot = jvp(g)(y, ydot)
  return z, zdot

def lin(g ∘ h)(x):
  y, h_lin = lin(h)(x)
  z, g_lin = lin(g)(y)
  return z, lambda xdot: g_lin(h_lin(xdot))

def vjp(g ∘ h)(x):
  y, h_vjp = vjp(h)(x)
  z, g_vjp = vjp(g)(y)
  return z, lambda ybar: h_vjp(g_vjp(ybar))
```

Hope these clues help!
Beta Was this translation helpful? Give feedback.
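For completeness, the `sin` pseudocode in the answer can be written out as ordinary Python closures (plain NumPy, no JAX; the function names here are my own, not JAX APIs), which shows that `lin` and `vjp` really do return closures capturing the residual `cos(x)`:

```python
import numpy as np

def jvp_sin(x, xdot):
    # Forward mode: primal and tangent computed together; no function is returned.
    return np.sin(x), np.cos(x) * xdot

def lin_sin(x):
    # Linearize: do the primal work once, close over the residual cos(x).
    y = np.sin(x)
    cos_x = np.cos(x)
    return y, lambda xdot: cos_x * xdot

def vjp_sin(x):
    # Reverse mode: same residual, but the closure multiplies on the other side.
    y = np.sin(x)
    cos_x = np.cos(x)
    return y, lambda ybar: ybar * cos_x

x, xdot, ybar = 0.3, 2.0, 0.7
y, ydot = jvp_sin(x, xdot)
y2, f_lin = lin_sin(x)
y3, f_vjp = vjp_sin(x)

assert np.isclose(ydot, f_lin(xdot))             # jvp and linearize agree
assert np.isclose(ybar * ydot, f_vjp(ybar) * xdot)  # scalar case of the vdot identity
```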
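The composition rule for VJPs can also be checked concretely. Here is a sketch (using `g = sin`, `h = exp`, my own choice of example) showing that composing the per-primitive pullbacks in reverse order reproduces the chain rule for f(x) = sin(exp(x)):

```python
import numpy as np

def vjp_exp(x):
    y = np.exp(x)
    return y, lambda ybar: ybar * y          # d/dx exp(x) = exp(x)

def vjp_sin(x):
    y = np.sin(x)
    cos_x = np.cos(x)
    return y, lambda ybar: ybar * cos_x      # d/dx sin(x) = cos(x)

def vjp_compose(vjp_g, vjp_h):
    # vjp(g ∘ h): run h then g forward, then compose the pullbacks in reverse.
    def vjp_f(x):
        y, h_vjp = vjp_h(x)
        z, g_vjp = vjp_g(y)
        return z, lambda zbar: h_vjp(g_vjp(zbar))
    return vjp_f

x = 0.5
z, f_vjp = vjp_compose(vjp_sin, vjp_exp)(x)  # f(x) = sin(exp(x))

assert np.isclose(z, np.sin(np.exp(x)))
# Chain rule: f'(x) = cos(exp(x)) * exp(x)
assert np.isclose(f_vjp(1.0), np.cos(np.exp(x)) * np.exp(x))
```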