# Autodiff

JAX offers a NumPy-like API (`jax.numpy`) plus automatic differentiation (autodiff): you can compute gradients of functions with respect to their inputs.

### Standard NumPy

In [11]:
import numpy as np

def f(x,y,z):
    return (x+y)*z

f(-2,5,-4)

-12
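Before using JAX, the gradient of `f` can be worked out by hand with the chain rule. Writing $q = x + y$ (an intermediate name used only in this derivation), at $(x, y, z) = (-2, 5, -4)$:

$$
\begin{aligned}
f &= q\,z, \qquad q = x + y = 3,\\
\frac{\partial f}{\partial x} &= \frac{\partial f}{\partial q}\,\frac{\partial q}{\partial x} = z \cdot 1 = -4,\\
\frac{\partial f}{\partial y} &= \frac{\partial f}{\partial q}\,\frac{\partial q}{\partial y} = z \cdot 1 = -4,\\
\frac{\partial f}{\partial z} &= q = 3.
\end{aligned}
$$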

### Same with JAX

In [12]:
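# Same function in JAX
import jax.numpy as jnp
from jax import grad

def g(x,y,z):
    return (x+y)*z

g(-2,5,-4)  # evaluates to -12, but only the last expression of a cell is displayed
# Gradient of g with respect to its first argument x (dg/dx = z = -4).
# jax.grad requires floating-point inputs, hence -2. instead of -2
grad(g)(-2.,5.,-4.)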

Array(-4., dtype=float32, weak_type=True)
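`grad(g)` differentiates only with respect to the first argument. A minimal sketch using the `argnums` parameter of `jax.grad` returns all three partial derivatives at once; `jax.value_and_grad` additionally returns the function value in the same call.

```python
import jax

def g(x, y, z):
    return (x + y) * z

# argnums=(0, 1, 2): one gradient per listed argument, returned as a tuple
dg_dx, dg_dy, dg_dz = jax.grad(g, argnums=(0, 1, 2))(-2., 5., -4.)
print(dg_dx, dg_dy, dg_dz)  # -4.0 -4.0 3.0

# value_and_grad returns (g(x, y, z), gradients) together
value, grads = jax.value_and_grad(g, argnums=(0, 1, 2))(-2., 5., -4.)
print(value, grads)
```

The three values match the hand derivation: `dg/dz` is `x + y = 3`.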

The next cell defines a custom gradient rule for `x**33` with `jax.custom_vjp` and compares it with the gradient JAX derives automatically. At x = 2 both give 33 * 2**32.

In [20]:
import jax
import jax.numpy as jnp

# Custom function
def custom_func(x):
    return x**33

# Forward function: returns the output and the residuals needed by the backward pass
def custom_func_fwd(x):
    y = x**33
    return y, x  # x is passed on to the backward function

# Backward function: receives the residuals and the output cotangent, returns the gradient
def custom_func_bwd(x, grad_y):
    grad_x = 33 * x**32 * grad_y
    return (grad_x,)  # one entry per input of custom_func

# Register the derivative rule
custom_func = jax.custom_vjp(custom_func)
custom_func.defvjp(custom_func_fwd, custom_func_bwd)

# Test the gradient computation
print(jax.grad(custom_func)(2.0))  # Output: 141733920000.0

# Same function without custom_vjp
def func2(x):
    return x**33

print(jax.grad(func2)(2.0))  # Output: 141733920000.0



141733920000.0
141733920000.0
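`jax.grad` is built on vector-Jacobian products (VJPs). As a sketch, calling `jax.vjp` directly on a `custom_vjp` function exposes both halves of the rule: it runs the forward function, and the returned `vjp_fn` runs the backward function on a given output cotangent. A cotangent of 1 reproduces `jax.grad`. The name `pow33` is hypothetical, introduced only for this example.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def pow33(x):
    return x**33

def pow33_fwd(x):
    return x**33, x  # (output, residuals for the backward pass)

def pow33_bwd(x, g):
    return (33 * x**32 * g,)  # cotangent with respect to x

pow33.defvjp(pow33_fwd, pow33_bwd)

# jax.vjp runs the forward pass and returns a function for the backward pass
y, vjp_fn = jax.vjp(pow33, 2.0)

# A cotangent of 1 gives the ordinary derivative, same as jax.grad
(x_bar,) = vjp_fn(jnp.ones_like(y))

# The VJP is linear in the cotangent: doubling it doubles the result
(x_bar_scaled,) = vjp_fn(2.0 * jnp.ones_like(y))

print(y, x_bar, x_bar_scaled)
```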


### Custom gradients

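The next examples show how to customize gradients in JAX. ReLU is not differentiable at 0, so its gradient there is not mathematically defined. With the `jnp.maximum(0, x)` implementation used below, JAX splits the gradient evenly between the two arguments when they are equal, so the gradient at 0 comes out as 0.5.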

In [24]:
# Define the ReLU function
def relu(x):
    return jnp.maximum(0, x)

# Compute the gradient of ReLU
x_values = jnp.array([-1.0, 0.0, 1.0])  # Test inputs
grads = jax.grad(lambda x: jnp.sum(relu(x)))(x_values)

# Print results
print("X values:", (x_values))
print("ReLU values:", relu(x_values))
print("Gradients of ReLU at test points:", grads)

X values: [-1.  0.  1.]
ReLU values: [0. 0. 1.]
Gradients of ReLU at test points: [0.  0.5 1. ]
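The 0.5 is a property of `jnp.maximum`, not of ReLU itself. For comparison, `jax.nn.relu` ships its own derivative rule, which uses 0 at x = 0. The sketch also uses `jax.vmap(jax.grad(...))` to get per-element derivatives of a scalar function, an alternative to summing the outputs before calling `jax.grad`. The name `relu_max` is hypothetical.

```python
import jax
import jax.numpy as jnp

x_values = jnp.array([-1.0, 0.0, 1.0])

def relu_max(x):
    return jnp.maximum(0, x)  # same definition as relu above

# Per-element derivative: grad of a scalar function, vectorized with vmap
grad_max = jax.vmap(jax.grad(relu_max))(x_values)
grad_nn = jax.vmap(jax.grad(jax.nn.relu))(x_values)

print("jnp.maximum:", grad_max)  # gradient 0.5 at x = 0
print("jax.nn.relu:", grad_nn)   # gradient 0 at x = 0
```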


### Let's define another value for the gradient at 0

The following code defines a custom gradient for the ReLU function, with the gradient at 0 arbitrarily set to 0.42.


In [None]:
import jax
import jax.numpy as jnp

# Define the custom ReLU function
def custom_relu(x):
    return jnp.maximum(0, x)

# Forward function: computes the output and any intermediate values
def custom_relu_fwd(x):
    return custom_relu(x), x  # Pass x to the backward function for gradient handling

# Backward function: provides a custom gradient rule
def custom_relu_bwd(x, grad_y):
    # Gradient is:
    # - 0 for x < 0
    # - 1 for x > 0
    # - 0.42 at x = 0
    grad_x = jnp.where(x > 0, grad_y, 0)  # Grad = 1 for x > 0
    grad_x = jnp.where(x == 0, grad_y * 0.42, grad_x)  # Grad = 0.42 for x == 0
    return (grad_x,)

# Register the custom VJP
custom_relu = jax.custom_vjp(custom_relu)
custom_relu.defvjp(custom_relu_fwd, custom_relu_bwd)

# Test the function and its gradient
x_values = jnp.array([-1.0, 0.0, 1.0])  # Test points
print("Custom ReLU values:", custom_relu(x_values))

# Compute gradients
grads = jax.grad(lambda x: jnp.sum(custom_relu(x)))(x_values)
print("Gradients of Custom ReLU at test points:", grads)

Custom ReLU values: [0. 0. 1.]
Gradients of Custom ReLU at test points: [0.   0.42 1.  ]
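`jax.custom_jvp` is a second mechanism for the same goal: instead of a forward/backward pair, you supply a Jacobian-vector product rule. A rule defined this way serves both reverse mode (`jax.grad`) and forward mode (`jax.jvp`), whereas a `custom_vjp` function supports reverse mode only. A minimal sketch with the same 0.42 convention; `relu_042` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def relu_042(x):
    return jnp.maximum(0.0, x)

@relu_042.defjvp
def relu_042_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Slope: 1 for x > 0, 0.42 at x == 0, 0 for x < 0
    slope = jnp.where(x > 0, 1.0, jnp.where(x == 0, 0.42, 0.0))
    # The tangent output must be linear in x_dot so reverse mode can transpose it
    return relu_042(x), slope * x_dot

x_values = jnp.array([-1.0, 0.0, 1.0])

# Reverse mode
grads = jax.grad(lambda x: jnp.sum(relu_042(x)))(x_values)
print("grad:", grads)

# Forward mode: directional derivative along a vector of ones
_, tangent_out = jax.jvp(relu_042, (x_values,), (jnp.ones_like(x_values),))
print("jvp:", tangent_out)
```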


### Complicated functions

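You can also differentiate functions that contain Python loops and conditionals. A Python `if` needs a concrete boolean, though: plain `jax.grad` can evaluate it, but under `jax.vmap` or `jax.jit` the input becomes an abstract tracer, so branching on it raises a `TracerBoolConversionError`. That is what happens to `complicated_func`, which is wrapped in `jax.vmap`.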
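For contrast, a minimal sketch (assuming a recent JAX version) that applies `jax.grad` directly to a scalar version of the same function, without `jax.vmap` or `jax.jit`. Here the tracer still carries the concrete input value, so the Python `if` and `for` loop run normally. `scalar_func` is a hypothetical name; it is `complicated_func` without the final `.mean()`.

```python
import jax

def scalar_func(x):
    # Python branch: evaluated on the concrete value of x
    if x > 3.:
        y = x**2
    else:
        y = 2 * x
    # Python loop: 1*y + 2*y + ... + 5*y = 15*y
    result = 0.0
    for i in range(1, 6):
        result += i * y
    return result

dfdx = jax.grad(scalar_func)
# One call per input; vmap would turn x into a tracer without a concrete value
grads = [float(dfdx(x)) for x in [2.0, 4.0, 5.0]]
print(grads)  # [30.0, 120.0, 150.0]
```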

In [27]:
import jax
import jax.numpy as jnp

# Function with conditions and loops
def complicated_func(x):
    # Conditional operation: square if greater than 3, else multiply by 2.
    # Under jax.vmap, x is a BatchTracer with no concrete value, so this Python `if` raises.
    if x > 3.:
        y = x**2
    else:
        y = 2 * x
    
    # Use a loop to compute a cumulative sum
    result = 0.0
    for i in range(1, 6):  # Sum of i * y for i in range(1, 6)
        result += i * y
    return result.mean()

# Vectorize the function for multiple inputs
v_complicated_func = jax.vmap(complicated_func)

# Compute the gradient of the function
x_values = jnp.array([2.0, 4.0, 5.0])  # Test inputs
grads = jax.grad(lambda x: jnp.sum(v_complicated_func(x)))(x_values)

# Print results
print("Function values:", v_complicated_func(x_values))
print("Gradients at test points:", grads)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
This BatchTracer with object id 5872364544 was created on line:
  /var/folders/bk/0vv7sh9n43n3dm4fth1qw93r0000gq/T/ipykernel_84478/168064051.py:7 (complicated_func)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
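One way to make `complicated_func` work under `jax.vmap` is to replace the Python `if` with `jnp.where`, which evaluates both branches and selects elementwise, so no concrete boolean is needed. This is a sketch of that swap; `jax.lax.cond` is another option. The Python `for` loop can stay: its bounds are plain Python integers, so it is unrolled during tracing. The redundant `.mean()` on a scalar is dropped.

```python
import jax
import jax.numpy as jnp

def complicated_func(x):
    # Branch without Python control flow: x**2 if x > 3, else 2*x
    y = jnp.where(x > 3., x**2, 2 * x)
    # Loop bounds are Python ints, so this loop is unrolled at trace time
    result = 0.0
    for i in range(1, 6):
        result += i * y
    return result

v_complicated_func = jax.vmap(complicated_func)

x_values = jnp.array([2.0, 4.0, 5.0])
values = v_complicated_func(x_values)
grads = jax.grad(lambda x: jnp.sum(v_complicated_func(x)))(x_values)

print("Function values:", values)            # 60, 240, 375
print("Gradients at test points:", grads)    # 30, 120, 150
```

Since `result = 15 * y`, the gradient is 30 for x <= 3 and 30x for x > 3.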