# Automatic differentiation: Reverse mode

In [1]:
#pip install -U "jax[cpu]"

In [2]:
import numpy as np
from IPython.display import display, Math
import time

Lets use our previously used function:

\begin{gather*}
f(x_1, x_2) = exp(x_1)-x_1 \cdot x_2+cos(x_2)
\end{gather*}

And do reverse and forward mode for AD

In the following $x_1$ and $x_2$ will always have the values:

In [3]:
x1 = 2.0
x2 = 4.0

Therefore our function result is:

In [4]:
def f(x1, x2):
    v1 = np.exp(x1)
    v2 = x1*x2
    v3 = np.cos(x2)
    v4 = v1 - v2
    v5 = v3 + v4
    y = v5
    return y

print(f(x1,x2))

-1.2645875219329614


## 1. With transverse jacobian
For each of the derivative steps it helps to look at an example computational tree with reverse mode

In [5]:
def exp_rev(x1,dx1,deriv_child):
    dx1 += deriv_child * np.exp(x1)
    return dx1

def mul_rev(x1,x2,dx1,dx2,deriv_child):
    dx2 += deriv_child * x1
    dx1 += deriv_child * x2
    return dx1,dx2

def dif_rev(x1,x2,dx1,dx2,deriv_child):
    dx1 += deriv_child * 1
    dx2 += deriv_child * (-1)
    return dx1, dx2

def cos_rev(x1,dx1,deriv_child):
    dx1 += deriv_child * (-np.sin(x1))
    return dx1

def sum_rev(x1,x2,dx1,dx2,deriv_child):
    dx1 += deriv_child * 1
    dx2 += deriv_child * 1
    return dx1,dx2


def function_rev(x1, x2, dy):
    
    #run forward
    v1 = np.exp(x1)
    v2 = x1 * x2
    v3 = np.cos(x2)
    v4 = v1 - v2
    v5 = v3 + v4
    y = v5


    #zero gradients 
    dx1 = 0
    dx2 = 0
    dv1 = 0
    dv2 = 0
    dv3 = 0
    dv4 = 0
    
    #run backward
    dv3,dv4  = sum_rev(v3, v4, dv3, dv4, dy)
    dv1,dv2  = dif_rev(v1, v2, dv1, dv2, dv4)
    dx2  = cos_rev(x2, dx2, dv3)
    dx1,dx2 = mul_rev(x1,x2, dx1, dx2, dv2)
    dx1  = exp_rev(x1, dx1, dv1)
    
    return y,[dx1,dx2]

In [21]:
function_rev(x1,x2,1.0)

(-1.2645875219329614, [3.3890560989306504, -1.2431975046920718])

## 2. With Jax

In [22]:
#conda install jax -c conda-forge

In [23]:
import jax
import jax.numpy as jnp

In [24]:
# define function uing jax numpy library
def f_jax(x1, x2):
    return jnp.exp(x1) - x1*x2 + jnp.cos(x2)

In [25]:
# backward

# define with argnum that gradient is derived for first input position
print('Gradient for first input position:', jax.grad(f_jax, argnums=0)(x1, x2)) 

# define with argnum that gradient is derived for second input position
print('Gradient for second input position:', jax.grad(f_jax, argnums=1)(x1, x2))

Gradient for first input position: 3.3890562
Gradient for second input position: -1.2431974


## 3. With PyTorch

In [26]:
#pip install torch torchvision torchaudio

In [27]:
import torch

In [28]:
# torch can only use torch tensors as function input
x1_t = torch.tensor([x1], requires_grad=True)
x2_t = torch.tensor([x2], requires_grad=True)

# define function
def f_pt(x1_t, x2_t):

    v1 = torch.exp(x1_t)
    v2 = torch.matmul(x1_t, x2_t)
    v3 = torch.cos(x2_t)
    v4 = torch.sub(v1, v2)
    v5 = torch.add(v3, v4)
    y = v5
    return y

In [29]:
# apply reverse mode, we don't need to specify for which input position; it assigns the object variable
torch.autograd.backward([f_pt(x1_t, x2_t)], inputs = [x1_t, x2_t])

In [30]:
# gradient is derived for first input position; assigns the variable
print(x1_t.grad)

# gradient is derived for second input position
print(x2_t.grad)

tensor([3.3891])
tensor([-1.2432])


## 4. With TensorFlow

In [31]:
# pip install tensorflow-cpu

In [32]:
import tensorflow as tf

In [33]:
# tensorflow functions only work with tensorflow tensors
x1_tf = tf.Variable(x1)
x2_tf = tf.Variable(x2)

# define function with tensorflow mathematical differentiation
def f_tf(x1_tf, x2_tf):
    v1 = tf.math.exp(x1_tf)
    v2 = tf.math.multiply(x1_tf, x2_tf)
    v3 = tf.math.cos(x2_tf)
    v4 = tf.math.subtract(v1, v2)
    v5 = tf.math.add(v3, v4)
    y = v5
    return y


def train_tf(x1_tf, x2_tf):
    with tf.GradientTape() as tape:
        # forward pass
        y = f_tf(x1_tf, x2_tf)
        
    # compute reverse AD
    gradients = tape.gradient(y, [x1_tf, x2_tf])
    return gradients

t1 = time.time()
gradients = train_tf(x1_tf, x2_tf)
%timeit train_tf(x1_tf, x2_tf)
t2 = time.time()
print(t2-t1, gradients[0].numpy(), gradients[1].numpy())

2.75 ms ± 374 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.300999879837036 3.3890562 -1.2431974


## Summary

print derivative of function for all different approaches and how long it took

In [35]:
print('1. Gradient via Jacobian-vector-product:')
f_rev, [dx1, dx2] = function_rev(x1,x2,1.0)
print(f'f({x1}, {x2})=', round(f_rev, 3), 'with dx1=', round(dx1,3), 'and dx2=', round(dx2,3))
%timeit function_rev(x1,x2,1.0)

print('\n')
print('2. Gradient via Jax')
f_jx = f_jax(x1, x2)
dx1, dx2 = jax.grad(f_jax, argnums=0)(x1, x2), jax.grad(f_jax, argnums=1)(x1, x2)
print(f'f({x1}, {x2})=', round(f_jx, 3), 'with dx1=', round(dx1,3), 'and dx2=', round(dx2,3))
%timeit jax.grad(f_jax, argnums=0)(x1, x2)

print('lets try with jit')
jit_f_jx = jax.jit(f_jax)
%timeit jax.grad(jit_f_jx, argnums=0)(x1, x2)
%timeit jax.grad(jit_f_jx, argnums=0)(x1, x2).block_until_ready()

print('\n')
print('3. Gradient via PyTorch')
x1_t, x2_t = torch.tensor([x1], requires_grad=True), torch.tensor([x2], requires_grad=True)
f_pyt = f_pt(x1_t, x2_t)
torch.autograd.backward([f_pt(x1_t, x2_t)], inputs = [x1_t, x2_t])
dx1, dx2 = x1_t.grad, x2_t.grad
print(f'f({x1}, {x2})=', f_pyt, 'with dx1=', dx1, 'and dx2=', dx2)
%timeit torch.autograd.backward([f_pt(x1_t, x2_t)], inputs = [x1_t, x2_t])

print('\n')
print('4. Gradient via TensorFlow')
x1_tf, x2_tf = tf.Variable(x1), tf.Variable(x2)
f_tfw = f_tf(x1_tf, x2_tf)
gradients = train_tf(x1_tf, x2_tf)
print(f'f({x1}, {x2})=', f_tfw.numpy(), 'with dx1=', gradients[0].numpy(), 'and dx2=', gradients[1].numpy())
%timeit train_tf(x1_tf, x2_tf)

1. Gradient via Jacobian-vector-product:
f(2.0, 4.0)= -1.265 with dx1= 3.389 and dx2= -1.243
10.5 µs ± 3.32 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


2. Gradient via Jax
f(2.0, 4.0)= -1.2650001 with dx1= 3.3890002 and dx2= -1.243
7 ms ± 349 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
lets try with jit
2.72 ms ± 65.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4 ms ± 960 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


3. Gradient via PyTorch
f(2.0, 4.0)= tensor([-1.2646], grad_fn=<AddBackward0>) with dx1= tensor([3.3891]) and dx2= tensor([-1.2432])
330 µs ± 39 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


4. Gradient via TensorFlow
f(2.0, 4.0)= -1.2645874 with dx1= 3.3890562 and dx2= -1.2431974
2.8 ms ± 501 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Check difference between tensorflow with and without XLA

In [36]:
# train normal tensorflow
x1_tf, x2_tf = tf.Variable(x1), tf.Variable(x2)

t1 = time.time()
gradients = train_tf(x1_tf, x2_tf)
t2 = time.time()

t_tf = t2-t1

In [37]:
# train with tensorflow and jit

t1 = time.time()
@tf.function(jit_compile=True)
def train_jit(x1_tf, x2_tf):
    with tf.GradientTape() as tape:
        # forward pass
        y = f_tf(x1_tf, x2_tf)
    gradients = tape.gradient(y, [x1_tf, x2_tf])
t2 = time.time()
t_tf_jit = t2-t1

In [38]:
print('Time needed for tensorflow:', t_tf)
print('Time needed for tensorflow with jit:', t_tf_jit) # using XLA 

Time needed for tensorflow: 0.004996061325073242
Time needed for tensorflow with jit: 0.007998228073120117


# Forward mode

## 1. With Jacobian
For a better understanding look at the computation steps of the forward mode using a computation graph

In [41]:
# forward
def mul_forward(x1,dx1,x2,dx2):
    y  = x1*x2
    dy = x1*dx2 + x2*dx1
    return y, dy

def sum_forward(x1,dx1,x2,dx2):
    return x1+x2, dx1 + dx2

def dif_forward(x1,dx1,x2,dx2):
    return x1-x2, dx1 - dx2

def exp_forward(x1,dx1):
    return np.exp(x1), dx1 * np.exp(x1)

def cos_forward(x2,dx2):
    return np.cos(x2), - dx2 * np.sin(x2)

def function_forward(x,dx):
    x1,x2 = x
    dx1,dx2 = dx
    v1, dv1 = exp_forward(x1,dx1)
    v2,dv2 = mul_forward(x1,dx1,x2,dx2)
    v3, dv3 = cos_forward(x2,dx2)
    v4, dv4 = dif_forward(v1,dv1,v2,dv2)
    v5, dv5 = sum_forward(v3,dv3,v4,dv4)
    y, dy = v5, dv5
    return y, dy

print(function_forward([x1,x2],[1,0]))
print(function_forward([x1,x2],[0,1]))

(-1.2645875219329614, 3.3890560989306504)
(-1.2645875219329614, -1.2431975046920718)


In [42]:
%timeit function_forward([x1,x2],[1,0])

9.75 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## 2. With Jax

In [44]:
# forward

y, f_jvp = jax.jvp(f_jax, (x1, x2,), (1.0, 0.0,))
print('y=', y, ', dy=', f_jvp)

y, f_jvp = jax.jvp(f_jax, (x1, x2,), (0.0, 1.0,))
print('y=', y, ', dy=', f_jvp)

%timeit jax.jvp(f_jax, (x1, x2,), (1.0, 0.0,))

y= -1.2645874 , dy= 3.3890562
y= -1.2645874 , dy= -1.2431974
5.73 ms ± 866 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
