# JVP and VJP

In [2]:
import jax
import jax.numpy as jnp

def fn(x):
    u=jnp.array([
        x[0]**6 * x[1]**4 * x[2]**9 * x[3]**2,
        x[0]**2 * x[1]**3 * x[2]**5 * x[3]**3,
        x[0]**5 * x[1]**7 * x[2]**7 * x[3]**6,
    ])
    return u

x = jnp.array([1.0, 0.5, 1.5, 2.0])
print(f"fn(x) = {x.shape} \n{fn(x)}\n")

##############################################################
print("## JVP ##")

# jacovian: df/dx
full_jacobian = jax.jacfwd(fn)(x)
print(f"full_jacobian = {full_jacobian.shape} \n{full_jacobian}\n")

# JVP: df/dx @ v
## 1.
v = jnp.array([0.2, 0.3, 0.4, 0.8])
print(f"v = {v.shape} \n{v}\n")
jvp_ = full_jacobian @ v
print(f"jvp_ = {jvp_.shape} \n{jvp_}\n")

## 2.
f_evaluated, jvp_evaluated = jax.jvp(fn, (x,), (v,))
print(f"f_evaluated = {f_evaluated.shape} \n{f_evaluated}\n")
print(f"jvp_evaluated = {jvp_evaluated.shape} \n{jvp_evaluated}\n")

##############################################################
print("## VJP ##")

# jacovian: df/dx
full_jacobian = jax.jacrev(fn)(x)
print(f"full_jacobian = {full_jacobian.shape} \n{full_jacobian}\n")

# VJP: df/dx @ v
## 1.
v = jnp.array([0.5, 0.8, 1.0])
print(f"v = {v.shape} \n{v}\n")
vjp_ = v.T @ full_jacobian
print(f"vjp_ = {vjp_.shape} \n{vjp_}\n")

## 2.
f_evaluated, vjp_fn = jax.vjp(fn, x)
print(f"f_evaluated = {f_evaluated.shape} \n{f_evaluated}\n")
vjp = vjp_fn(v)[0]
print(f"vjp = {vjp.shape} \n{vjp}\n")


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


fn(x) = (4,) 
[9.61084  7.59375  8.542969]

## JVP ##
full_jacobian = (3, 4) 
[[ 57.66504   76.88672   57.66504    9.61084 ]
 [ 15.1875    45.5625    25.3125    11.390625]
 [ 42.714844 119.60156   39.867188  25.628906]]

v = (4,) 
[0.2 0.3 0.4 0.8]

jvp_ = (3,) 
[65.353714 35.943752 80.87344 ]

f_evaluated = (3,) 
[9.61084  7.59375  8.542969]

jvp_evaluated = (3,) 
[65.353714 35.943752 80.87344 ]

## VJP ##
full_jacobian = (3, 4) 
[[ 57.66504   76.88672   57.66504    9.61084 ]
 [ 15.1875    45.5625    25.3125    11.390625]
 [ 42.714844 119.60156   39.867188  25.628906]]

v = (3,) 
[0.5 0.8 1. ]

vjp_ = (4,) 
[ 83.697365 194.49492   88.94971   39.546825]

f_evaluated = (3,) 
[9.61084  7.59375  8.542969]

vjp = (4,) 
[ 83.697365 194.49492   88.94971   39.546825]



In [3]:
import torch

def fn(x):
    return torch.stack([
        x[0]**6 * x[1]**4 * x[2]**9 * x[3]**2,
        x[0]**2 * x[1]**3 * x[2]**5 * x[3]**3,
        x[0]**5 * x[1]**7 * x[2]**7 * x[3]**6,
    ])

x = torch.tensor([1.0, 0.5, 1.5, 2.0], requires_grad=True)
##############################################################
print("## JVP ##")
v = torch.tensor([0.2, 0.3, 0.4, 0.8])
func_output, jvp = torch.autograd.functional.jvp(fn, x, v)
print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"jvp = {jvp.shape} \n{jvp}\n")

##############################################################
print("## VJP ##")
v = torch.tensor([0.5, 0.8, 1.0])
func_output, vjp = torch.autograd.functional.vjp(fn, x, v)
print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"vjp = {vjp.shape} \n{vjp}\n")

## JVP ##
func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430])

jvp = torch.Size([3]) 
tensor([65.3537, 35.9438, 80.8734])

## VJP ##
func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430])

vjp = torch.Size([4]) 
tensor([ 83.6974, 194.4949,  88.9497,  39.5468])



In [4]:
import tensorflow as tf

def fn(x):
    return tf.stack([
        x[0]**6 * x[1]**4 * x[2]**9 * x[3]**2,
        x[0]**2 * x[1]**3 * x[2]**5 * x[3]**3,
        x[0]**5 * x[1]**7 * x[2]**7 * x[3]**6,
    ])

x = tf.Variable([1.0, 0.5, 1.5, 2.0], dtype=tf.float32)
##############################################################
print("## JVP ##")
v = tf.constant([0.2, 0.3, 0.4, 0.8], dtype=tf.float32)
with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
    y= fn(x)
jvp = acc.jvp(y)
print(f"y = {y.shape} \n{y}\n")
print(f"jvp = {jvp.shape} \n{jvp}\n")

##############################################################
print("## VJP ##")
v = tf.constant([0.5, 0.8, 1.0], dtype=tf.float32)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = fn(x)
vjp = tape.gradient(y, x, output_gradients=v)
print(f"y = {y.shape} \n{y}\n")
print(f"vjp = {vjp.shape} \n{vjp}\n")




## JVP ##
y = (3,) 
[9.61084  7.59375  8.542969]

jvp = (3,) 
[65.353714 35.943752 80.87344 ]

## VJP ##
y = (3,) 
[9.61084  7.59375  8.542969]

vjp = (4,) 
[ 83.697365 194.49492   88.94971   39.546825]



In [5]:
x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
targets = tf.constant([[1.], [-1.]])
dense = tf.keras.layers.Dense(1)
dense.build([None, 2])
with tf.autodiff.ForwardAccumulator(
   primals=dense.kernel,
   tangents=tf.constant([[1.], [0.]])) as acc:
  loss = tf.reduce_sum((dense(x) - targets) ** 2.)
print(acc.jvp(loss))

tf.Tensor(-1.7002447, shape=(), dtype=float32)


# JVP and VJP with model input

In [6]:
x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
targets = tf.constant([[1.], [-1.]])
dense = tf.keras.layers.Dense(1)
dense.build([None, 2])
with tf.autodiff.ForwardAccumulator(
   primals=dense.kernel,
   tangents=tf.constant([[1.], [0.]])) as acc:
  loss = tf.reduce_sum((dense(x) - targets) ** 2.)
print(acc.jvp(loss))


print("dense.kernel")
print(dense.kernel)
print()
print("tf.constant([[1.], [0.]])")
print(tf.constant([[1.], [0.]]))
print()
print(x)
tangents = tf.constant([[0., 0], [0., 1.]])
with tf.autodiff.ForwardAccumulator(
   primals=x,
   tangents=tangents
  ) as acc:
  loss = tf.reduce_sum((dense(x) - targets) ** 2.)
print(acc.jvp(loss))



tf.Tensor(11.96414, shape=(), dtype=float32)
dense.kernel
<tf.Variable 'kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.3399825],
       [ 1.3681983]], dtype=float32)>

tf.constant([[1.], [0.]])
tf.Tensor(
[[1.]
 [0.]], shape=(2, 1), dtype=float32)

tf.Tensor(
[[2. 3.]
 [1. 4.]], shape=(2, 2), dtype=float32)
tf.Tensor(14.045405, shape=(), dtype=float32)
