In [2]:
import jax
import jax.numpy as jnp

# Advanced Differentiation
We introduced **jax.grad** and how it computes the gradients of a given function. Sometimes we need to define a custom derivative rule, either for a function that JAX can't differentiate automatically or to improve the numerical stability of the derivative JAX would otherwise compute.
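
To make the numerical-stability case concrete, here is a minimal sketch in the spirit of the examples in the JAX documentation. It uses **jax.custom_jvp**, which is introduced just below. The function name `log1pexp` and the test point `100.0` are illustrative assumptions, not part of this notebook. With the default 32-bit floats, `jnp.exp(100.0)` overflows, so the automatically derived gradient of log(1 + e^x) can come out as `nan`; the hand-written rule below uses an algebraically equivalent form that stays finite.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  # Illustrative function (an assumption for this sketch): log(1 + e^x)
  return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  y = log1pexp(x)
  # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); this form stays finite
  # even when e^x overflows to inf
  return y, (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot

stable_grad = jax.grad(log1pexp)(100.0)
print(stable_grad)
```

Removing the `@jax.custom_jvp` decorator and the `defjvp` rule and calling `jax.grad` again at the same point is a quick way to compare against the automatically derived gradient.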

We are going to introduce **jax.custom_jvp** and **jax.custom_vjp**.

- **custom_jvp**: To define custom [forward-mode](https://en.wikipedia.org/wiki/Automatic_differentiation) differentiation rules

Let's take the sigmoid function as an example:


\begin{align}
\sigma(x) &= \frac{1}{1 + e^{-x}}
\end{align}


Its derivative:

\begin{align}
\frac{d\sigma(x)}{dx} &= \sigma(x) \cdot (1 - \sigma(x))
\end{align}
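
For readers who want the intermediate step, the identity can be derived by differentiating the quotient and regrouping the factors, using the fact that 1 - σ(x) = e^{-x} / (1 + e^{-x}):

\begin{align}
\frac{d\sigma(x)}{dx} &= \frac{e^{-x}}{(1 + e^{-x})^2} \\
&= \frac{1}{1 + e^{-x}} \cdot \frac{e^{-x}}{1 + e^{-x}} \\
&= \sigma(x) \cdot (1 - \sigma(x))
\end{align}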


In [6]:

from jax import custom_jvp

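@custom_jvp
def sigmoid(x):
  return 1 / (1 + jnp.exp(-x))

# Now we define its custom JVP (forward-mode) rule

@sigmoid.defjvp
def sigmoid_jvp(primals, tangents):
  x, = primals       # point at which the function is evaluated
  x_dot, = tangents  # tangent (input perturbation) at that point
  y = sigmoid(x)
  # Return the primal output and the output tangent: sigma'(x) * x_dot
  return y, y * (1 - y) * x_dot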

x = 0.0
print(sigmoid(x))  # sigmoid(0.0) = 1 / (1 + 1) = 0.5

# Computed by hand, the gradient should be
# (d(sigmoid)/dx)(0.0) = 0.5 * (1.0 - 0.5) = 0.25
print(jax.grad(sigmoid)(x))



0.5
0.25
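
Because the rule above is a forward-mode (JVP) rule, it can also be exercised directly with `jax.jvp`, without going through `jax.grad`. The sketch below repeats the definitions so that it runs on its own; the variable names `primal_out` and `tangent_out` are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def sigmoid(x):
  return 1 / (1 + jnp.exp(-x))

@sigmoid.defjvp
def sigmoid_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  y = sigmoid(x)
  return y, y * (1 - y) * x_dot

# Push the tangent x_dot = 1.0 through sigmoid at x = 0.0.
# jax.jvp returns the primal output and the output tangent.
primal_out, tangent_out = jax.jvp(sigmoid, (0.0,), (1.0,))
print(primal_out, tangent_out)
```

With a unit input tangent, the output tangent equals the derivative at that point, so it should match the value printed by `jax.grad(sigmoid)(0.0)`.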


- **custom_vjp**: To define custom reverse-mode differentiation rules. One use is to choose the gradient of a function at points where it is not differentiable

Let's consider the following function:

\begin{align}
f(x) = |x|
\end{align}

The derivative of this function is not defined at 0. We'll define a custom VJP rule that handles this case.

In [11]:
from jax import custom_vjp

@custom_vjp
def f(x):
  return jnp.abs(x)

In [12]:
# Forward pass: return the output together with the residuals
# (here just x) that the backward pass will need
def f_forward(x):
  return f(x), x

for x > 0:

\begin{align}
\frac{df(x)}{dx} &= 1
\end{align}

for x < 0:

\begin{align}
\frac{df(x)}{dx} &= -1
\end{align}

for x = 0 (the derivative is undefined here, so we choose the value 0 by convention):

\begin{align}
\frac{df(x)}{dx} &= 0
\end{align}

In [24]:
# Backward pass
# Chain rule: the incoming cotangent g is multiplied by the local derivative df/dx.
# Have a look at https://en.wikipedia.org/wiki/Automatic_differentiation
# The nested jnp.where looks a bit odd, so take your time with it:
# it selects 1 for x > 0, -1 for x < 0, and 0 for x = 0.

def f_backward(res, g):
  x = res  # Residual saved by the forward pass
  return (jnp.where(x > 0, 1.0, jnp.where(x < 0, -1.0, 0.0)) * g,)


In [22]:
# Attach the forward and backward passes to f as its custom VJP rule
f.defvjp(f_forward, f_backward)

In [23]:
#---- Test

# Forward pass
x = jnp.array([-3.0, -0.8, 1.0, 0.0, 10.0, -22.0])
print(f(x))  # Expected: the absolute value of each element of x

# Backward pass: jax.grad needs a scalar output, so we sum over f(x)
grad_f = jax.grad(lambda x: jnp.sum(f(x)))
print(grad_f(x))

[ 3.   0.8  1.   0.  10.  22. ]
[-1. -1.  1.  0.  1. -1.]
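
The backward rule can also be called directly through `jax.vjp`, which returns the output and a function mapping an output cotangent to input cotangents. The sketch below is self-contained and makes one substitution that is not in the notebook above: it writes the backward pass with `jnp.sign`, which returns 1, -1, or 0 and so should be equivalent to the nested `jnp.where`. The names `f_vjp` and `x_bar` are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.abs(x)

def f_forward(x):
  # Save x as the residual for the backward pass
  return f(x), x

def f_backward(res, g):
  x = res
  # Swapped-in alternative to the nested jnp.where: sign(x) is 1, -1, or 0
  return (jnp.sign(x) * g,)

f.defvjp(f_forward, f_backward)

x = jnp.array([-3.0, -0.8, 1.0, 0.0, 10.0, -22.0])
y, f_vjp = jax.vjp(f, x)
# Pull back a cotangent of ones, which corresponds to the gradient of sum(f(x))
(x_bar,) = f_vjp(jnp.ones_like(y))
print(x_bar)
```

Pulling back a cotangent of ones is the same computation that `jax.grad` performs on `jnp.sum(f(x))`, so `x_bar` should agree with the gradient printed above.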
