# JVP and VJP

In [2]:
import jax
import jax.numpy as jnp


def fn(x):
    u = jnp.array(
        [
            x[0] ** 6 * x[1] ** 4 * x[2] ** 9 * x[3] ** 2,
            x[0] ** 2 * x[1] ** 3 * x[2] ** 5 * x[3] ** 3,
            x[0] ** 5 * x[1] ** 7 * x[2] ** 7 * x[3] ** 6,
        ]
    )
    return u


x = jnp.array([1.0, 0.5, 1.5, 2.0])
print(f"fn(x) = {x.shape} \n{fn(x)}\n")

##############################################################
print("## JVP ##")

# jacovian: df/dx
full_jacobian = jax.jacfwd(fn)(x)
print(f"full_jacobian = {full_jacobian.shape} \n{full_jacobian}\n")

# JVP: df/dx @ v
## 1.
v = jnp.array([0.2, 0.3, 0.4, 0.8])
print(f"v = {v.shape} \n{v}\n")
jvp_ = full_jacobian @ v
print(f"jvp_ = {jvp_.shape} \n{jvp_}\n")

## 2.
f_evaluated, jvp_evaluated = jax.jvp(fn, (x,), (v,))
print(f"f_evaluated = {f_evaluated.shape} \n{f_evaluated}\n")
print(f"jvp_evaluated = {jvp_evaluated.shape} \n{jvp_evaluated}\n")

##############################################################
print("## VJP ##")

# jacovian: df/dx
full_jacobian = jax.jacrev(fn)(x)
print(f"full_jacobian = {full_jacobian.shape} \n{full_jacobian}\n")

# VJP: df/dx @ v
## 1.
v = jnp.array([0.5, 0.8, 1.0])
print(f"v = {v.shape} \n{v}\n")
vjp_ = v.T @ full_jacobian
print(f"vjp_ = {vjp_.shape} \n{vjp_}\n")

## 2.
f_evaluated, vjp_fn = jax.vjp(fn, x)
print(f"f_evaluated = {f_evaluated.shape} \n{f_evaluated}\n")
vjp = vjp_fn(v)[0]
print(f"vjp = {vjp.shape} \n{vjp}\n")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


fn(x) = (4,) 
[9.61084  7.59375  8.542969]

## JVP ##
full_jacobian = (3, 4) 
[[ 57.66504   76.88672   57.66504    9.61084 ]
 [ 15.1875    45.5625    25.3125    11.390625]
 [ 42.714844 119.60156   39.867188  25.628906]]

v = (4,) 
[0.2 0.3 0.4 0.8]

jvp_ = (3,) 
[65.353714 35.943752 80.87344 ]

f_evaluated = (3,) 
[9.61084  7.59375  8.542969]

jvp_evaluated = (3,) 
[65.353714 35.943752 80.87344 ]

## VJP ##
full_jacobian = (3, 4) 
[[ 57.66504   76.88672   57.66504    9.61084 ]
 [ 15.1875    45.5625    25.3125    11.390625]
 [ 42.714844 119.60156   39.867188  25.628906]]

v = (3,) 
[0.5 0.8 1. ]

vjp_ = (4,) 
[ 83.697365 194.49492   88.94971   39.546825]

f_evaluated = (3,) 
[9.61084  7.59375  8.542969]

vjp = (4,) 
[ 83.697365 194.49492   88.94971   39.546825]



In [3]:
import torch


def fn(x):
    return torch.stack(
        [
            x[0] ** 6 * x[1] ** 4 * x[2] ** 9 * x[3] ** 2,
            x[0] ** 2 * x[1] ** 3 * x[2] ** 5 * x[3] ** 3,
            x[0] ** 5 * x[1] ** 7 * x[2] ** 7 * x[3] ** 6,
        ]
    )


x = torch.tensor([1.0, 0.5, 1.5, 2.0], requires_grad=True)
##############################################################
print("## JVP ##")
v = torch.tensor([0.2, 0.3, 0.4, 0.8])
func_output, jvp = torch.autograd.functional.jvp(fn, x, v)
print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"jvp = {jvp.shape} \n{jvp}\n")


##############################################################
print("## JVP with torch.func.jvp ##")
v = torch.tensor([0.2, 0.3, 0.4, 0.8])
func_output, jvp = torch.func.jvp(fn, (x,), (v,))
print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"jvp = {jvp.shape} \n{jvp}\n")


##############################################################
print("## VJP ##")
v = torch.tensor([0.5, 0.8, 1.0])
func_output, vjp = torch.autograd.functional.vjp(fn, x, v)
print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"vjp = {vjp.shape} \n{vjp}\n")

##############################################################
print("## VJP with torch.func.vjp ##")
v = torch.tensor([0.5, 0.8, 1.0])
func_output, vjp_fn = torch.func.vjp(fn, x)
vjp = vjp_fn(v)[0]
print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"vjp = {vjp.shape} \n{vjp}\n")

##############################################################
print("## VJP with torch.autograd.grad ##")
v = torch.tensor([0.5, 0.8, 1.0])
y = fn(x)
print(f"y = {y.shape} \n{y}\n")
grad = torch.autograd.grad(y, x, grad_outputs=v)[0]  # == vjp
print(f"grad = {grad.shape} \n{grad}\n")


print(f"func_output = {func_output.shape} \n{func_output}\n")
print(f"vjp = {vjp.shape} \n{vjp}\n")


# ##############################################################
print("## J ##")

v = torch.tensor([0.2, 0.3, 0.4, 0.8])
j_rev_fn = torch.func.jacrev(fn)
j = j_rev_fn(x)
print(f"j = {j.shape} \n{j}\n")
jvp = j @ v
print(f"jvp = {jvp.shape} \n{jvp}\n")

j_fwd_fn = torch.func.jacfwd(fn)
j = j_fwd_fn(x)
print(f"j = {j.shape} \n{j}\n")
v = torch.tensor([0.5, 0.8, 1.0])
print(f"v = {v.shape} \n{v}\n")
vjp = v.T @ j
print(f"vjp = {vjp.shape} \n{vjp}\n")

## JVP ##
func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430])

jvp = torch.Size([3]) 
tensor([65.3537, 35.9438, 80.8734])

## JVP with torch.func.jvp ##
func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430], grad_fn=<AliasBackward0>)

jvp = torch.Size([3]) 
tensor([65.3537, 35.9438, 80.8734], grad_fn=<StackBackward0>)

## VJP ##
func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430])

vjp = torch.Size([4]) 
tensor([ 83.6974, 194.4949,  88.9497,  39.5468])

## VJP with torch.func.vjp ##
func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430], grad_fn=<StackBackward0>)

vjp = torch.Size([4]) 
tensor([ 83.6974, 194.4949,  88.9497,  39.5468], grad_fn=<AddBackward0>)

## VJP with torch.autograd.grad ##
y = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430], grad_fn=<StackBackward0>)

grad = torch.Size([4]) 
tensor([ 83.6974, 194.4949,  88.9497,  39.5468])

func_output = torch.Size([3]) 
tensor([9.6108, 7.5938, 8.5430], grad_fn=<StackBackward0>)

vjp = torch

  vjp = v.T @ j


In [4]:
import tensorflow as tf


def fn(x):
    return tf.stack(
        [
            x[0] ** 6 * x[1] ** 4 * x[2] ** 9 * x[3] ** 2,
            x[0] ** 2 * x[1] ** 3 * x[2] ** 5 * x[3] ** 3,
            x[0] ** 5 * x[1] ** 7 * x[2] ** 7 * x[3] ** 6,
        ]
    )


x = tf.Variable([1.0, 0.5, 1.5, 2.0], dtype=tf.float32)

##############################################################
print("## VJP ##")
v = tf.constant([0.5, 0.8, 1.0], dtype=tf.float32)
with tf.GradientTape() as tape:
    y = fn(x)
vjp = tape.gradient(y, x, output_gradients=v)
print(f"y = {y.shape} \n{y}\n")
print(f"vjp = {vjp.shape} \n{vjp}\n")

##############################################################
print("## JVP ##")
v = tf.constant([0.2, 0.3, 0.4, 0.8], dtype=tf.float32)
with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
    y = fn(x)
jvp = acc.jvp(y)
print(f"y = {y.shape} \n{y}\n")
print(f"jvp = {jvp.shape} \n{jvp}\n")



## VJP ##
y = (3,) 
[9.61084  7.59375  8.542969]

vjp = (4,) 
[ 83.697365 194.49492   88.94971   39.546825]

## JVP ##
y = (3,) 
[9.61084  7.59375  8.542969]

jvp = (3,) 
[65.353714 35.943752 80.87344 ]



In [5]:
# with model
x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
targets = tf.constant([[1.0], [-1.0]])
dense = tf.keras.layers.Dense(1)
dense.build([None, 2])
with tf.autodiff.ForwardAccumulator(primals=dense.kernel, tangents=tf.constant([[1.0], [0.0]])) as acc:
    loss = tf.reduce_sum((dense(x) - targets) ** 2.0)
print(acc.jvp(loss))

tf.Tensor(-13.69874, shape=(), dtype=float32)


# JVP and VJP with model input

In [6]:
x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
targets = tf.constant([[1.0], [-1.0]])
dense = tf.keras.layers.Dense(1)
dense.build([None, 2])
with tf.autodiff.ForwardAccumulator(primals=dense.kernel, tangents=tf.constant([[1.0], [0.0]])) as acc:
    loss = tf.reduce_sum((dense(x) - targets) ** 2.0)
print(acc.jvp(loss))


print("dense.kernel")
print(dense.kernel)
print()
print("tf.constant([[1.], [0.]])")
print(tf.constant([[1.0], [0.0]]))
print()
print(x)
tangents = tf.constant([[0.0, 0], [0.0, 1.0]])
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tangents) as acc:
    loss = tf.reduce_sum((dense(x) - targets) ** 2.0)
print(acc.jvp(loss))

tf.Tensor(-5.216034, shape=(), dtype=float32)
dense.kernel
<tf.Variable 'kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[ 1.2166749],
       [-0.7691392]], dtype=float32)>

tf.constant([[1.], [0.]])
tf.Tensor(
[[1.]
 [0.]], shape=(2, 1), dtype=float32)

tf.Tensor(
[[2. 3.]
 [1. 4.]], shape=(2, 2), dtype=float32)
tf.Tensor(1.3227375, shape=(), dtype=float32)


In [8]:
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)


def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()


examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
result

tensor([0.2995, 0.2864, 0.5609], grad_fn=<ReluBackward0>)

In [None]:
@tf.function
def run_JVP(imgs, labs):
    L = tf.shape(imgs)[0] # T
    grad_importance_array = tf.TensorArray(
        tf.float32, size=0, dynamic_size=True, infer_shape=False, element_shape=[None]
    )
    grad_importance_array, _ = tf.while_loop(
        cond=lambda grad_TA, k: tf.cast(k, dtype=tf.int32)
        < tf.cast(tf.math.ceil(tf.cast(L, dtype=tf.float32) / tf.cast(bs, dtype=tf.float32)), dtype=tf.int32),
        # ! cond = k < ceil(T/_PARALLEL_BATCH)
        body=lambda grad_TA, k: (
            one_step_JVP(
                grad_TA,
                imgs[batching(L, bs, k)[0] : batching(L, bs, k)[1]],  # ! [_PARALLEL_BATCH, *img_size]
                labs[batching(L, bs, k)[0] : batching(L, bs, k)[1]],  # ! [_PARALLEL_BATCH]
                k,
            ),
            k + 1,
        ),
        loop_vars=(grad_importance_array, tf.constant(0)),
        back_prop=False,
        parallel_iterations=1,
    )
    return grad_importance_array.concat()  # ! [T, ??]

In [17]:
# tf.TensorArray Test
import tensorflow as tf

def fn(aa, k):
    tmp =tf.ones(12, dtype=tf.float32) * tf.cast(k, dtype=tf.float32)
    tmp = tf.reshape(tmp, [2, 3, 2])
    tmp = [i for i in tmp]
    
    print(tmp)
    aa = aa.write(k, tmp)
    return aa

aa = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False, element_shape=[None])
aa, _ = tf.while_loop(
    cond=lambda grad_TA, k: tf.cast(k, dtype=tf.int32) < tf.cast(5, dtype=tf.int32),
    body=lambda grad_TA, k: (
        fn(grad_TA, k),
        k + 1,
    ),
    loop_vars=(aa, tf.constant(0)),
    back_prop=False,
    parallel_iterations=1,
)

aa.concat()

[<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)>, <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)>]


ValueError: Incompatible shape for value ((2, 3, 2)), expected ((None,))

<tf.Tensor: shape=(50,), dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
      dtype=float32)>