# Documentation on JAX Automatic Differentiation
This document describes the automatic differentiation notation used throughout jax-gcm and how to set up your own JAX gradient code.

### Notation 

To take gradients of functions in jax-gcm we use the JAX function jax.vjp() (vector-Jacobian product). jax.grad() can also take derivatives, but it is a special case of jax.vjp() and requires the output of the function being differentiated to be a scalar. jax.vjp() is more general and handles the cases we need, i.e., functions whose outputs are not scalars.
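The relationship between jax.grad() and jax.vjp() can be checked with a minimal sketch. It assumes a hypothetical scalar-valued function, scalar_func, that is not part of jax-gcm. For a scalar output, feeding the cotangent 1.0 into the VJP function reproduces jax.grad(). Calling jax.grad() on a vector-valued function raises a TypeError.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function, used only for illustration
def scalar_func(x):
    return jnp.sum(x**2)

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad returns the gradient directly (scalar output required)
g_grad = jax.grad(scalar_func)(x)

# Same gradient via jax.vjp: for a scalar output the cotangent is just 1.0
primals, f_vjp = jax.vjp(scalar_func, x)
(g_vjp,) = f_vjp(jnp.ones_like(primals))  # f_vjp returns one entry per input argument

print(g_grad)  # [2. 4. 6.]
print(g_vjp)   # [2. 4. 6.]

# jax.grad refuses a vector-valued output
try:
    jax.grad(lambda v: v**2)(x)
except TypeError as err:
    print("jax.grad failed:", type(err).__name__)  # jax.grad failed: TypeError
```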

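jax.vjp() takes the function being differentiated and its input values. It returns a pair. The first element is the primals, i.e., the outputs of the forward model, which we won't worry about much. The second is a function, f_vjp, that computes the vector-Jacobian product. Pass f_vjp a cotangent that has the same shape as the function output, and it returns one gradient per input. Here is an example: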

```python
import jax
import jax.numpy as jnp

# Example of the function to be differentiated
def my_func(x, y):  # Two inputs (can be arrays of the same shape)
    return x**2 + y**3  # One output array (or scalar)

# Define a function to calculate the gradient
def grad_my_func(f, x, y):  # Parameters should be the same as my_func
    primals, f_vjp = jax.vjp(f, x, y)  # Forward outputs and the VJP function
    # Cotangent fed into f_vjp: must match the shape and dtype of my_func's output.
    # An all-ones cotangent returns, for each input element, the sum of the partial
    # derivatives of all output elements with respect to it. Because my_func is
    # elementwise, that sum is just the elementwise derivative.
    cotangent = jnp.ones_like(primals)
    df_dx, df_dy = f_vjp(cotangent)  # One result per input argument of f
    return df_dx, df_dy

# Test
xx = 2 * jnp.ones(3)  # x input
yy = jnp.arange(0.0, 3.0, 1.0)  # y input (float, so it can be differentiated)
df_dx, df_dy = grad_my_func(my_func, xx, yy)  # Call gradient function
print('The derivative of my function with respect to x is:', df_dx)
print('The derivative of my function with respect to y is:', df_dy)
```

```
The derivative of my function with respect to x is: [4. 4. 4.]
The derivative of my function with respect to y is: [ 0.  3. 12.]
```
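The all-ones cotangent used above has a simple interpretation. The VJP with a ones cotangent equals the gradient of the sum of the outputs. The sketch below checks this for my_func with jax.grad(..., argnums=(0, 1)). The lambda that sums the output is an illustrative helper, not jax-gcm code.

```python
import jax
import jax.numpy as jnp

def my_func(x, y):
    return x**2 + y**3

xx = 2 * jnp.ones(3)
yy = jnp.arange(0.0, 3.0, 1.0)

# VJP with an all-ones cotangent
primals, f_vjp = jax.vjp(my_func, xx, yy)
df_dx_vjp, df_dy_vjp = f_vjp(jnp.ones_like(primals))

# Gradient of the summed output with respect to both arguments
summed = lambda x, y: jnp.sum(my_func(x, y))
df_dx_grad, df_dy_grad = jax.grad(summed, argnums=(0, 1))(xx, yy)

print(jnp.allclose(df_dx_vjp, df_dx_grad), jnp.allclose(df_dy_vjp, df_dy_grad))  # True True
```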


##### Best practices when creating gradient functions
- Name a gradient function by putting 'grad_' in front of the original function name (e.g., grad_my_func).
- When calling jax.vjp(), always name the outputs 'primals, f_vjp'.
- Build the cotangent that is passed into f_vjp() so it has the same structure, shape, and dtype as the primals (the outputs of the function for the given inputs). Filling it with ones, e.g. jnp.ones_like(primals), gives for each input element the sum of the partial derivatives of all output elements. For an elementwise function this is exactly the elementwise derivative.
- Name partial derivatives d(function output)_d(variable the derivative is taken with respect to), e.g. df_dx for the derivative of f with respect to x.
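The cotangent must match the primals even when the output is not a single array. A minimal sketch, assuming a hypothetical function dict_func that returns a dictionary (a JAX pytree). Here jax.tree_util.tree_map(jnp.ones_like, primals) builds an all-ones cotangent with the matching structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical function returning a dict (a pytree) instead of one array
def dict_func(x, y):
    return {"a": x**2, "b": jnp.sum(y**3)}

x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0, 3.0])

primals, f_vjp = jax.vjp(dict_func, x, y)
# All-ones cotangent with the same pytree structure, shapes, and dtypes as primals
cotangent = jax.tree_util.tree_map(jnp.ones_like, primals)
df_dx, df_dy = f_vjp(cotangent)

print(df_dx)  # [2. 4.]
print(df_dy)  # [ 3. 12. 27.]
```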




### Other important tips
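- JAX cannot differentiate with respect to integer-valued inputs (jax.grad() raises a TypeError), so initialize the function parameters as floats or float arrays, e.g. jnp.array(3.0) rather than jnp.array(3).
- If the function has multiple outputs (or one output with several elements) and you pass an all-ones cotangent, the result of f_vjp() sums the contributions of all outputs.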
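The integer-input pitfall can be seen in a small sketch. It assumes a hypothetical function, square. jax.grad() rejects an integer array, while a float input (or an integer array cast to float) works.

```python
import jax
import jax.numpy as jnp

# Hypothetical function, used only for illustration
def square(x):
    return x**2

x_int = jnp.array(3)      # integer dtype
x_float = jnp.array(3.0)  # float dtype

int_rejected = False
try:
    jax.grad(square)(x_int)
except TypeError:
    int_rejected = True
print("integer input rejected:", int_rejected)  # integer input rejected: True

print(jax.grad(square)(x_float))                    # 6.0
print(jax.grad(square)(x_int.astype(jnp.float32)))  # 6.0
```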

### Taking gradients using jax.jvp (Jacobian-vector product)

- Sometimes it is better to use jax.jvp instead of jax.vjp.
- JVP computes directional derivatives, so it can isolate the derivative with respect to one parameter at a time.
- Like VJP, JVP also returns the primals (the outputs of the forward function).
- The third argument of jax.jvp is the tangents: a tuple (or list) with one entry per function argument, each with the same shape as that argument. Setting the tangent of one parameter to ones and all other tangents to zeros gives the derivative with respect to that parameter.
- Below are some examples, along with explanations of the notation.

```python
# Use JVP to take the gradient with respect to x
def jac_my_func_x(f, x, y):
    primals, df_dx = jax.jvp(f, (x, y),
                             (jnp.ones_like(x),    # ones for the parameter you differentiate with respect to
                              jnp.zeros_like(y)))  # zeros for all other parameters
    return df_dx

# Use JVP to take the gradient with respect to y
def jac_my_func_y(f, x, y):
    primals, df_dy = jax.jvp(f, (x, y),
                             (jnp.zeros_like(x),
                              jnp.ones_like(y)))
    return df_dy

# If you set all the tangents to ones, the partials are summed:
def jac_my_func_all(f, x, y):
    primals, df = jax.jvp(f, (x, y),
                          (jnp.ones_like(x),
                           jnp.ones_like(y)))
    return df

df_dx = jac_my_func_x(my_func, xx, yy)
df_dy = jac_my_func_y(my_func, xx, yy)
df = jac_my_func_all(my_func, xx, yy)
print('The derivative of my function with respect to x is:', df_dx)
print('The derivative of my function with respect to y is:', df_dy)
print('The derivative of my function summed over all partials is:', df)
```

```
The derivative of my function with respect to x is: [4. 4. 4.]
The derivative of my function with respect to y is: [ 0.  3. 12.]
The derivative of my function summed over all partials is: [ 4.  7. 16.]
```
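To see what the tangents select, it helps to compare against the full Jacobian. jax.jacfwd builds the full Jacobian from JVPs. The sketch below shows two results for my_func. An all-ones tangent on x returns the Jacobian times a ones vector, i.e., the row sums. A one-hot tangent on x[0] returns the first column of the Jacobian. The one-hot vector e0 is an illustrative choice.

```python
import jax
import jax.numpy as jnp

def my_func(x, y):
    return x**2 + y**3

xx = 2 * jnp.ones(3)
yy = jnp.arange(0.0, 3.0, 1.0)

# Full Jacobians with respect to x and y (each 3x3), computed in forward mode
J_x, J_y = jax.jacfwd(my_func, argnums=(0, 1))(xx, yy)

# All-ones tangent on x, zeros on y: result is J_x @ ones (row sums of J_x)
_, df_dx = jax.jvp(my_func, (xx, yy), (jnp.ones_like(xx), jnp.zeros_like(yy)))
print(jnp.allclose(df_dx, J_x @ jnp.ones(3)))  # True

# One-hot tangent on x[0] only: result is the first column of J_x
e0 = jnp.zeros(3).at[0].set(1.0)
_, col0 = jax.jvp(my_func, (xx, yy), (e0, jnp.zeros_like(yy)))
print(jnp.allclose(col0, J_x[:, 0]))  # True
```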


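### JVP vs. VJP for functions with multiple output components
- JVP: returns one tangent per function output, each holding that output's derivative with respect to the parameter of interest. The result has the same structure as the outputs.
- VJP: with an all-ones cotangent, returns one result per parameter, summing the derivatives of all function outputs with respect to that parameter. The result has the same structure as the inputs.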

```python
# JVP and VJP behave differently when the function has multiple outputs
def my_multi_func(x, y):
    return x**2 + y**3, x**2 + y**3

# VJP
def grad_my_multi_func(f, x, y):
    primals, f_vjp = jax.vjp(f, x, y)
    # primals is a tuple of two arrays, so the cotangent must be a matching tuple
    cotangent = (jnp.ones_like(primals[0]), jnp.ones_like(primals[1]))
    df_dx, df_dy = f_vjp(cotangent)
    return df_dx, df_dy

# JVP with respect to x
def jac_my_multi_func_x(f, x, y):
    primals, df_dx = jax.jvp(f, (x, y),
                             (jnp.ones_like(x),
                              jnp.zeros_like(y)))
    return df_dx

# JVP with respect to y
def jac_my_multi_func_y(f, x, y):
    primals, df_dy = jax.jvp(f, (x, y),
                             (jnp.zeros_like(x),
                              jnp.ones_like(y)))
    return df_dy

df_dx_vjp, df_dy_vjp = grad_my_multi_func(my_multi_func, xx, yy)
df_dx_jvp = jac_my_multi_func_x(my_multi_func, xx, yy)
df_dy_jvp = jac_my_multi_func_y(my_multi_func, xx, yy)
print('The derivative of my function with respect to x from VJP is:', df_dx_vjp)
print('The derivative of my function with respect to y from VJP is:', df_dy_vjp)
print('The derivative of my function with respect to x from JVP is:', df_dx_jvp)
print('The derivative of my function with respect to y from JVP is:', df_dy_jvp)
```

```
The derivative of my function with respect to x from VJP is: [8. 8. 8.]
The derivative of my function with respect to y from VJP is: [ 0.  6. 24.]
The derivative of my function with respect to x from JVP is: (Array([4., 4., 4.], dtype=float32), Array([4., 4., 4.], dtype=float32))
The derivative of my function with respect to y from JVP is: (Array([ 0.,  3., 12.], dtype=float32), Array([ 0.,  3., 12.], dtype=float32))
```
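The summation in the VJP comes entirely from the cotangent, so it can be switched off. Here is a sketch for my_multi_func. A cotangent of ones on the first output and zeros on the second gives only the first output's derivatives. A cotangent of ones on both outputs, built here with jax.tree_util.tree_map, reproduces the summed values shown above.

```python
import jax
import jax.numpy as jnp

def my_multi_func(x, y):
    return x**2 + y**3, x**2 + y**3

xx = 2 * jnp.ones(3)
yy = jnp.arange(0.0, 3.0, 1.0)

primals, f_vjp = jax.vjp(my_multi_func, xx, yy)

# Ones on the first output only: no summation across outputs
cot_first = (jnp.ones_like(primals[0]), jnp.zeros_like(primals[1]))
df_dx_first, df_dy_first = f_vjp(cot_first)
print(df_dx_first)  # [4. 4. 4.]
print(df_dy_first)  # [ 0.  3. 12.]

# Ones on every output: contributions of both outputs are summed
cot_all = jax.tree_util.tree_map(jnp.ones_like, primals)
df_dx_all, df_dy_all = f_vjp(cot_all)
print(df_dx_all)  # [8. 8. 8.]
```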


```python
# What if only one parameter vector is passed into the function?
def vectorized_func(theta):  # theta must have length 6
    return theta[:3]**2 + theta[3:]**3

# VJP
def grad_v_func(f, theta):
    primals, f_vjp = jax.vjp(f, theta)
    cotangent = jnp.ones_like(primals)
    # f_vjp returns a tuple with one entry per input argument, so this is a 1-tuple
    df_dtheta = f_vjp(cotangent)
    return df_dtheta

# JVP along theta[0] only: the one-hot tangent selects the first parameter
def jac_v_func_theta(f, theta):
    primals, df_dtheta = jax.jvp(f, (theta,), (jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0]),))
    return df_dtheta

theta = jnp.array([2.0, 2.0, 2.0, 0.0, 1.0, 2.0])
df_dtheta_vjp = grad_v_func(vectorized_func, theta)
df_dtheta_jvp = jac_v_func_theta(vectorized_func, theta)
print('The derivative of my vectorized function with respect to theta from VJP is:', df_dtheta_vjp)
print('The derivative of my vectorized function with respect to theta[0] from JVP is:', df_dtheta_jvp)
```

```
The derivative of my vectorized function with respect to theta from VJP is: (Array([ 4.,  4.,  4.,  0.,  3., 12.], dtype=float32),)
The derivative of my vectorized function with respect to theta[0] from JVP is: [4. 0. 0.]
```
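Both results for vectorized_func can be read off its full Jacobian. The sketch below computes the Jacobian with jax.jacrev, a reverse-mode Jacobian, and checks two identities. The all-ones VJP equals a ones vector times the Jacobian, i.e., the column sums. The one-hot JVP on theta[0] equals the first column. The one-hot vector e0 is an illustrative choice.

```python
import jax
import jax.numpy as jnp

def vectorized_func(theta):
    return theta[:3]**2 + theta[3:]**3

theta = jnp.array([2.0, 2.0, 2.0, 0.0, 1.0, 2.0])

# Full Jacobian: rows are the 3 outputs, columns are the 6 entries of theta
J = jax.jacrev(vectorized_func)(theta)
print(J.shape)  # (3, 6)

# VJP with an all-ones cotangent equals ones @ J (column sums)
_, f_vjp = jax.vjp(vectorized_func, theta)
(vjp_out,) = f_vjp(jnp.ones(3))
print(jnp.allclose(vjp_out, jnp.ones(3) @ J))  # True

# JVP with a one-hot tangent on theta[0] equals J @ e0 (first column)
e0 = jnp.zeros(6).at[0].set(1.0)
_, jvp_out = jax.jvp(vectorized_func, (theta,), (e0,))
print(jnp.allclose(jvp_out, J[:, 0]))  # True
```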


### Additional thoughts

##### VJP
- With an all-ones cotangent, VJP sums the partial derivatives over the function's output elements. For each parameter, it returns an array with the same shape and size as that parameter. This is the vector-Jacobian product of a ones vector with the Jacobian.

##### JVP
- With an all-ones tangent, JVP sums the partial derivatives over the parameter elements. It returns an array with the same shape and size as the function output. This is the Jacobian-vector product of the Jacobian with a ones vector.
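The two summaries above can be written compactly in terms of the Jacobian $J$ of a function $f$ of parameters $\theta$. A tangent $v$ has the shape of $\theta$, and a cotangent $u$ has the shape of $f(\theta)$. The second line is a worked instance for vectorized_func at $\theta = (2, 2, 2, 0, 1, 2)$. Its values match the printed VJP and JVP outputs above.

```latex
J_{ij} = \frac{\partial f_i}{\partial \theta_j}, \qquad
\text{JVP: } (J v)_i = \sum_j \frac{\partial f_i}{\partial \theta_j}\, v_j, \qquad
\text{VJP: } (u^\top J)_j = \sum_i u_i\, \frac{\partial f_i}{\partial \theta_j}

J = \begin{pmatrix}
4 & 0 & 0 & 0 & 0 & 0 \\
0 & 4 & 0 & 0 & 3 & 0 \\
0 & 0 & 4 & 0 & 0 & 12
\end{pmatrix}, \qquad
\mathbf{1}^\top J = (4,\ 4,\ 4,\ 0,\ 3,\ 12), \qquad
J e_0 = (4,\ 0,\ 0)^\top
```

Here $e_0$ is the one-hot tangent on $\theta_0$, and $\mathbf{1}$ is the all-ones cotangent.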