# Documentation on JAX Automatic Differentiation
This document describes the automatic differentiation notation used throughout jax-gcm and how to set up your own JAX gradient code.

### Notation 

To take gradients of functions in jax-gcm, we use the JAX function jax.vjp(). While jax.grad() can also take derivatives, it is a special case of jax.vjp() and requires the function being differentiated to return a scalar. jax.vjp() is more general and handles the cases we need, i.e. functions whose outputs are not scalars.
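As a rough illustration of how the two relate, the sketch below uses a hypothetical function `square_plus_one` (not part of jax-gcm). It assumes that summing the output is an acceptable way to obtain a scalar for jax.grad(), and compares the result with calling jax.vjp() on the vector-valued function directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function (not part of jax-gcm): vector in, vector out
def square_plus_one(x):
    return x**2 + 1.0

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad() needs a scalar output, so reduce the vector output with a sum first
grad_of_sum = jax.grad(lambda v: jnp.sum(square_plus_one(v)))(x)

# jax.vjp() works on the vector-valued function directly
primals, f_vjp = jax.vjp(square_plus_one, x)
(vjp_result,) = f_vjp(jnp.ones_like(primals))  # f_vjp returns a tuple, one entry per input

print(grad_of_sum)
print(vjp_result)
```

Both calls should produce the same values (2*x here), which is why jax.vjp() can stand in for jax.grad() while also covering non-scalar outputs.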

jax.vjp() takes the function being differentiated and its input values, and returns the primals (the function's output at those inputs, which we won't worry about too much) and a function that outputs the gradients. Here is an example:

import jax
import jax.numpy as jnp

# Example of the function to be differentiated
def my_func(x, y):  # Takes two parameters (can be vectors of the same shape)
    return x**2 + y**3  # Returns one vector (or scalar)

# Define a function to calculate the gradient
def grad_my_func(f, x, y):  # Takes the function f followed by the same parameters as my_func
    primals, f_vjp = jax.vjp(f, x, y)  # Returns the output of f(x, y) and the "derivative function"
    cotangent = jnp.ones_like(primals)  # Input to f_vjp. Must have the shape of the output of my_func.
                                        # Because my_func acts element-wise, a vector of all ones
                                        # returns the derivative of each output element with
                                        # respect to the matching element of each parameter.
    df_dx, df_dy = f_vjp(cotangent)  # Takes the derivative with respect to each parameter
    return df_dx, df_dy

# Test 
xx = 2*jnp.ones(3) # x input 
yy = jnp.arange(0.0, 3.0, 1) # y input
df_dx, df_dy = grad_my_func(my_func, xx, yy) # Call gradient function
print('The derivative of my function with respect to x is:', df_dx)
print('The derivative of my function with respect to y is:', df_dy)

The derivative of my function with respect to x is: [4. 4. 4.]
The derivative of my function with respect to y is: [ 0.  3. 12.]


##### Best practices when creating gradient functions
- To define the gradient of a function, put 'grad_' at the front of the function name (for example, grad_my_func for my_func).
- When calling jax.vjp(), always name the outputs 'primals, f_vjp'.
- Initialize the input values to be passed back into f_vjp(). You can build them from the primals (the outputs of the function for the given input), but make sure you create an object of the same type and shape filled with ones, for example with jnp.ones_like(primals), so that the gradient is accurate.
- When calculating the partial derivatives, name them d(function output)_d(variable the derivative is with respect to), for example df_dx.
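The all-ones input recommended above returns element-wise derivatives when each output element depends only on the matching input element, as in my_func. For a function that mixes elements, the same call returns the gradient of the sum of the outputs. The sketch below illustrates this with a hypothetical function `running_sum_of_squares` (not part of jax-gcm), using jax.jacrev() for the full Jacobian as a cross-check.

```python
import jax
import jax.numpy as jnp

# Hypothetical example (not part of jax-gcm): output i depends on inputs 0..i
def running_sum_of_squares(x):
    return jnp.cumsum(x**2)

x = jnp.array([1.0, 2.0, 3.0])

primals, f_vjp = jax.vjp(running_sum_of_squares, x)

# All-ones input: sums the Jacobian over the outputs (gradient of sum(outputs))
(dsum_dx,) = f_vjp(jnp.ones_like(primals))

# One-hot input: picks out the derivatives of a single output element
one_hot = jnp.zeros_like(primals).at[1].set(1.0)
(dout1_dx,) = f_vjp(one_hot)

# Full Jacobian for comparison: jac[i, j] = d output_i / d x_j
jac = jax.jacrev(running_sum_of_squares)(x)

print(dsum_dx)   # matches jac.sum(axis=0)
print(dout1_dx)  # matches jac[1]
```

If you need the derivative of every output with respect to every input, a Jacobian helper such as jax.jacrev() may be a better fit than a single f_vjp() call.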




### Other important tips
- JAX cannot take the gradient with respect to integer values (jax.grad() raises an error for integer inputs), so make sure to initialize the function parameter values with floats (or an array of floats).
- More to come...
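The integer tip above can be checked with a short sketch. The function `cube` is a hypothetical example (not part of jax-gcm), and converting with astype(jnp.float32) is just one assumed way to obtain floats.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function (not part of jax-gcm)
def cube(x):
    return x**3

# Integer input: jax.grad() rejects it with a TypeError
try:
    jax.grad(cube)(2)
    int_grad_failed = False
except TypeError:
    int_grad_failed = True

# Float input works as expected
float_grad = jax.grad(cube)(2.0)

# Convert integer arrays to floats before differentiating
x_int = jnp.arange(3)                # integer dtype
x_float = x_int.astype(jnp.float32)  # float dtype
primals, f_vjp = jax.vjp(cube, x_float)
(dcube_dx,) = f_vjp(jnp.ones_like(primals))

print(int_grad_failed, float_grad, dcube_dx)
```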