<a href="https://colab.research.google.com/github/mridul-sahu/baking-with-jax-autodiff/blob/main/Baking_with_JAX_Autodiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Baking with JAX Autodiff 🍰

## Let's bake a cake!
Well, not literally, but we'll use the process of perfecting a cake recipe as our story to understand JAX's automatic differentiation (autodiff) capabilities, from the fundamentals to more advanced techniques.

Imagine we are bakers, and JAX is our incredibly smart assistant. Our goal is to create the perfect cake. This involves adjusting ingredients (inputs) and observing the cake's properties (outputs) like sweetness, fluffiness, or overall deliciousness score. To do this efficiently, we need to know how tweaking each ingredient affects the final cake. This "how things change" is the essence of differentiation, and JAX's autodiff is our tool to calculate it automatically.

### Our Baking Course

1.  **The Basics**
    Getting started with gradients (`jax.grad`), handling different input/output structures, and verifying results.
2.  **Advanced Baking**
    Dealing with multiple cake properties at once (Jacobians), understanding the rate at which improvements change (Hessians), efficiently calculating directional changes (JVPs, VJPs, HVPs), controlling gradient flow (`stop_gradient`), and computing per-customer gradients (`vmap`).
3.  **Exotic Flavors & Vibrations**
    Exploring how JAX differentiates functions involving complex numbers, understanding the difference between 'smooth' (holomorphic) and 'tricky' (non-holomorphic) cases, and using the right tools (`grad` with `holomorphic=True`, Jacobians) accordingly.
4.  **Secret Family Recipes**
    Teaching JAX custom differentiation rules using `jax.custom_jvp` and `jax.custom_vjp` to overcome limitations, such as fixing numerical instability, enforcing specific baking rules (like gradient clipping), or handling complex iterative processes (like dough maturation) that standard autodiff struggles with.

## Core Concepts: The Building Blocks of Change

Before diving into JAX functions, let's understand the ideas:

1. **Derivative**: If we change *one* ingredient (e.g., sugar) slightly, how much does *one* property (e.g., sweetness) change? The derivative measures this instantaneous rate of change.

2. **Gradient**: If we have *multiple* ingredients affecting *one* key outcome (e.g., an "overall deliciousness score"), the gradient is a list (vector) of derivatives. Each element tells us how the score changes with respect to one specific ingredient. The gradient points in the direction of ingredient adjustments that *most rapidly* increase the score.

3. **Jacobian**: What if changing *multiple* ingredients (sugar, flour, eggs) affects *multiple* properties (sweetness, fluffiness, cost)? The Jacobian is a table (matrix) containing *all* the partial derivatives – how *each* ingredient influences *each* property. It's the complete map of local sensitivities.

4. **Hessian**: This describes the *curvature* of our outcome. As we adjust ingredients to improve the deliciousness score, is the landscape sharply curved (the slope changes quickly, as near a narrow peak) or nearly flat (the slope barely changes)? It's the matrix of *second* derivatives, telling us how the *gradient* itself changes.

5. **Forward vs. Reverse Mode Autodiff**: Two main computational strategies:

  - **Forward Mode (JVP)**: Tracks changes *forward*. Efficient when you have *fewer* inputs than outputs. Think: "Let's nudge the sugar amount and see how it affects sweetness, fluffiness, and cost."

  - **Reverse Mode (VJP)**: Works *backward* from the output. Efficient when you have *more* inputs than outputs (like many ingredients affecting one score). Think: "Starting from the sweetness, how sensitive is it to every ingredient (sugar, flour, eggs...)?" One backward pass answers this for all ingredients at once.
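
For readers who prefer symbols, the five ideas above can be summarized for a function $f:\mathbb{R}^n \to \mathbb{R}^m$ with Jacobian $J$ evaluated at a point $x$. This is standard calculus notation, not anything JAX-specific:

$$
\begin{aligned}
J_{ij} &= \frac{\partial f_i}{\partial x_j}, \qquad J \in \mathbb{R}^{m \times n} \\
\nabla f(x) &= J^\top \in \mathbb{R}^{n} \qquad (m = 1) \\
H_{ij} &= \frac{\partial^2 f}{\partial x_i \, \partial x_j} \qquad (m = 1) \\
\text{JVP (forward):}\quad v &\mapsto J v, \qquad v \in \mathbb{R}^{n} \\
\text{VJP (reverse):}\quad u &\mapsto u^\top J, \qquad u \in \mathbb{R}^{m}
\end{aligned}
$$

Building the full $J$ takes $n$ JVPs (one per input direction) or $m$ VJPs (one per output). That is why forward mode suits few inputs and reverse mode suits few outputs.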

# Part 1: The Basics - Simple Adjustments & Verification

Let's start with the fundamental tool for finding gradients.


### Taking Gradients with `jax.grad`

JAX's primary function for differentiation is `jax.grad()`. It takes a Python function, built from JAX-compatible operations, that returns a *single scalar value*. `jax.grad()` then returns a *new function* that computes the gradient of the original function (by default, with respect to its first argument).

*Example:* A Simple Function (`tanh`)

The `tanh` function is often used in neural networks. Let's find its derivative.

In [None]:
import jax
import jax.numpy as jnp

# The function we want to differentiate
f_tanh = jnp.tanh

# Use jax.grad to get a function that computes the gradient (derivative)
grad_tanh = jax.grad(f_tanh)

# Evaluate the gradient function at a specific point (e.g., x=2.0)
gradient_value = grad_tanh(2.0)
print(f"The gradient of tanh at x=2.0 is: {gradient_value}")

# For tanh, the derivative is 1 - tanh(x)^2 or sech(x)^2
# Let's check: 1 - jnp.tanh(2.0)**2 = 0.07065...
assert jnp.allclose(gradient_value, 1 - jnp.tanh(2.0)**2)

The gradient of tanh at x=2.0 is: 0.07065081596374512
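Because `jax.grad` requires a scalar output, calling it on `tanh` applied to a whole vector fails. The sketch below shows two common workarounds for getting elementwise derivatives: `jax.vmap` over the scalar derivative, or differentiating the sum (valid here because `tanh` acts on each element independently). The array `xs` is a hypothetical input chosen for illustration.

```python
import jax
import jax.numpy as jnp

xs = jnp.array([0.0, 1.0, 2.0])

# jax.grad needs a scalar output; a vector output raises a TypeError.
raised = False
try:
    jax.grad(jnp.tanh)(xs)
except TypeError as err:
    raised = True
    print("grad on a vector output failed:", type(err).__name__)

# Option 1: vmap the scalar derivative over the inputs.
elementwise = jax.vmap(jax.grad(jnp.tanh))(xs)

# Option 2: differentiate the sum; since tanh is elementwise,
# d/dx_i sum_j tanh(x_j) = tanh'(x_i).
via_sum = jax.grad(lambda x: jnp.sum(jnp.tanh(x)))(xs)

print(elementwise)
print(via_sum)
```

The `vmap` version generalizes to functions whose elements interact, where the sum trick would no longer give per-element derivatives.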


#### *Higher-Order Derivatives*

Since `jax.grad(f)` returns a function, we can apply `jax.grad` again to get second, third, or higher-order derivatives.

*Example: A Polynomial Recipe*

Imagine a simplified "cake quality" score based on one ingredient `x`: $f(x)=x^3 + 2x^2 - 3x + 1$.

In [None]:
def cake_quality_simple(x):
  """A simple score based on one ingredient amount x."""
  return x**3 + 2*x**2 - 3*x + 1

# First derivative function (how score changes with x)
d_quality_dx = jax.grad(cake_quality_simple)

# Second derivative function (how the *rate of change* changes)
d2_quality_dx2 = jax.grad(d_quality_dx) # or grad(grad(cake_quality_simple))

# Third derivative function
d3_quality_dx3 = jax.grad(d2_quality_dx2)

# Fourth derivative function
d4_quality_dx4 = jax.grad(d3_quality_dx3)

# Let's evaluate these at x = 1.0
x_value = 1.0
print(f"\n--- Higher-Order Derivatives at x={x_value} ---")
print(f"f'(x)   = {d_quality_dx(x_value)}")   # Expected: 3*(1)^2 + 4*(1) - 3 = 4
print(f"f''(x)  = {d2_quality_dx2(x_value)}") # Expected: 6*(1) + 4 = 10
print(f"f'''(x) = {d3_quality_dx3(x_value)}") # Expected: 6
print(f"f''''(x)= {d4_quality_dx4(x_value)}") # Expected: 0


--- Higher-Order Derivatives at x=1.0 ---
f'(x)   = 4.0
f''(x)  = 10.0
f'''(x) = 6.0
f''''(x)= 0.0


### Computing Gradients for a Recipe (Logistic Regression Example)

Let's adapt a classic logistic regression example. Imagine we have features derived from our ingredients (`ingredient_features`) and we want to predict whether a customer will like the cake (`True`/`False`). We have parameters `W` (weights) and `b` (bias) for our prediction model. Our goal is to adjust `W` and `b` to minimize a `loss` function (like prediction error). We need the gradient of the `loss` with respect to `W` and `b`.

In [None]:
from jax import random

key = random.key(42)

def sigmoid(x):
  """Sigmoid activation function."""
  return 0.5 * (jnp.tanh(x / 2) + 1)

def predict_like(W, b, ingredient_features):
  """Predicts probability of liking the cake."""
  # Linear model followed by sigmoid
  logit = jnp.dot(ingredient_features, W) + b
  return sigmoid(logit)

# Toy data: [sweetness, fluffiness, cost_factor] features for 4 cakes
ingredient_features = jnp.array([[0.8, 0.9, -0.5],
                                 [0.7, 0.8, -0.4],
                                 [0.3, 0.4, -0.8], # Less sweet/fluffy, low cost
                                 [0.9, 0.7, -0.3]])
# Did customers like these cakes? (True/False)
customer_likes = jnp.array([True, True, False, True])

def calculate_loss(W, b, ingredient_features, customer_likes):
  """Calculates the negative log-likelihood loss."""
  predictions = predict_like(W, b, ingredient_features)
  # Formula for binary cross-entropy / negative log-likelihood
  # Avoid log(0) with a small epsilon
  epsilon = 1e-7
  label_probabilities = predictions * customer_likes + (1 - predictions) * (1 - customer_likes)
  return -jnp.sum(jnp.log(label_probabilities + epsilon))

# Initialize random parameters for our prediction model
key, W_key, b_key = random.split(key, 3)
# W has shape (num_features,) = (3,)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ()) # bias is a scalar

# --- Calculate Gradients ---
# Use argnums to specify which positional arguments to differentiate w.r.t.

# Gradient w.r.t. W (the 0-th argument)
grad_loss_W_fn = jax.grad(calculate_loss, argnums=0)
W_gradient = grad_loss_W_fn(W, b, ingredient_features, customer_likes)
print(f"\n--- Gradients for Cake Likelihood Model ---")
print(f"Gradient w.r.t. W (argnums=0):\n{W_gradient}")

# Gradient w.r.t. b (the 1st argument)
grad_loss_b_fn = jax.grad(calculate_loss, argnums=1)
b_gradient = grad_loss_b_fn(W, b, ingredient_features, customer_likes)
print(f"\nGradient w.r.t. b (argnums=1):\n{b_gradient}")

# Gradient w.r.t. both W and b (arguments 0 and 1)
# Returns a tuple of gradients in the same order as argnums
grad_loss_Wb_fn = jax.grad(calculate_loss, argnums=(0, 1))
W_gradient_tuple, b_gradient_tuple = grad_loss_Wb_fn(W, b, ingredient_features, customer_likes)
print(f"\nGradient w.r.t. W (from tuple):\n{W_gradient_tuple}")
print(f"Gradient w.r.t. b (from tuple):\n{b_gradient_tuple}")

# Default is argnums=0
assert jnp.allclose(jax.grad(calculate_loss)(W,b, ingredient_features, customer_likes), W_gradient)


--- Gradients for Cake Likelihood Model ---
Gradient w.r.t. W (argnums=0):
[-0.05690404  0.03078762 -0.52113414]

Gradient w.r.t. b (argnums=1):
0.45482203364372253

Gradient w.r.t. W (from tuple):
[-0.05690404  0.03078762 -0.52113414]
Gradient w.r.t. b (from tuple):
0.45482203364372253


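These gradients tell us how to adjust `W` and `b` to reduce the prediction error (the loss): each gradient points in the direction that *increases* the loss, so we step in the opposite direction.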
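Here is a minimal gradient-descent sketch that uses those gradients. It reuses the same toy data and loss, but assumes a few things that are not in the original: zero-initialized parameters (so the run is deterministic), a hypothetical learning rate of `0.1`, and 50 steps.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def calculate_loss(W, b, features, likes):
    preds = sigmoid(features @ W + b)
    label_probs = preds * likes + (1 - preds) * (1 - likes)
    return -jnp.sum(jnp.log(label_probs + 1e-7))

features = jnp.array([[0.8, 0.9, -0.5],
                      [0.7, 0.8, -0.4],
                      [0.3, 0.4, -0.8],
                      [0.9, 0.7, -0.3]])
likes = jnp.array([True, True, False, True])

# Assumption: start from zeros instead of random values, for reproducibility.
W = jnp.zeros(3)
b = jnp.array(0.0)
learning_rate = 0.1  # hypothetical step size

grad_fn = jax.grad(calculate_loss, argnums=(0, 1))
initial_loss = calculate_loss(W, b, features, likes)
for step in range(50):
    dW, db = grad_fn(W, b, features, likes)
    # Move *against* the gradient to reduce the loss.
    W = W - learning_rate * dW
    b = b - learning_rate * db
final_loss = calculate_loss(W, b, features, likes)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

With all predictions starting at 0.5, the initial loss is $4\ln 2 \approx 2.77$, and each small step downhill lowers it.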



### Differentiating with Respect to a Model (PyTrees)

Model parameters are often grouped into dictionaries or lists (e.g., `params = {'W': W, 'b': b}`). JAX differentiates through these nested Python containers (called PyTrees) seamlessly, returning gradients with the same structure as the input.

In [None]:
# Combine parameters into a dictionary
params = {'W': W, 'b': b}

def calculate_loss_dict(parameters, ingredient_features, customer_likes):
  """Loss function taking parameters as a dictionary."""
  # Access parameters by key
  w = parameters['W']
  b = parameters['b']
  return calculate_loss(w, b, ingredient_features, customer_likes) # Reuse the previous loss logic

# Get gradient function for the dictionary input
grad_loss_dict_fn = jax.grad(calculate_loss_dict)

# Calculate gradient - result is a dictionary with the same structure
param_gradients = grad_loss_dict_fn(params, ingredient_features, customer_likes)

print(f"\n--- Gradient for Dictionary Parameters ---")
print(f"Gradient dictionary:\n{param_gradients}")

# Check the values match the previous individual gradients
assert jnp.allclose(param_gradients['W'], W_gradient)
assert jnp.allclose(param_gradients['b'], b_gradient)


--- Gradient for Dictionary Parameters ---
Gradient dictionary:
{'W': Array([-0.05690404,  0.03078762, -0.52113414], dtype=float32), 'b': Array(0.45482203, dtype=float32)}
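Since the gradient PyTree mirrors the parameter PyTree, an update can be applied to every leaf at once with `jax.tree_util.tree_map`. The sketch below assumes a hypothetical nested parameter layout (`layers`, `scale`) and toy loss, purely to show that structure is preserved.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameters: a dict holding a list of arrays and a scalar.
params = {"layers": [jnp.array([1.0, -2.0]), jnp.array(0.5)], "scale": jnp.array(2.0)}

def toy_loss(p):
    w, b = p["layers"]
    return p["scale"] * (jnp.sum(w ** 2) + b ** 2)

grads = jax.grad(toy_loss)(params)

# The gradients have exactly the same tree structure as the parameters.
same_structure = jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
print(same_structure)

# One gradient-descent step applied leaf by leaf.
learning_rate = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(new_params)
```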


### Getting the Cake and the Adjustment Plan Together (`jax.value_and_grad`)

Often, when optimizing, we need both the current value of our objective (e.g., the `loss` or `deliciousness_score`) *and* the gradient (how to improve it). Calling the function and `jax.grad` separately repeats the forward computation they share; `jax.value_and_grad` returns both from a single call.

In [None]:
# Get a function that returns both the loss value and the gradients w.r.t. (W, b)
value_grad_loss_fn = jax.value_and_grad(calculate_loss, argnums=(0, 1))

# Call the function
loss_val, (W_grad_vg, b_grad_vg) = value_grad_loss_fn(W, b, ingredient_features, customer_likes)

print(f"\n--- Using value_and_grad ---")
print(f"Calculated Loss: {loss_val}")
# Verify loss value
assert jnp.allclose(loss_val, calculate_loss(W, b, ingredient_features, customer_likes))
print(f"Gradient w.r.t. W: {W_grad_vg}")
print(f"Gradient w.r.t. b: {b_grad_vg}")
# Verify gradients
assert jnp.allclose(W_grad_vg, W_gradient)
assert jnp.allclose(b_grad_vg, b_gradient)


--- Using value_and_grad ---
Calculated Loss: 2.247727394104004
Gradient w.r.t. W: [-0.05690404  0.03078762 -0.52113414]
Gradient w.r.t. b: 0.45482203364372253
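When the objective function also produces extra information that should be returned but not differentiated (predictions, metrics), `jax.value_and_grad(..., has_aux=True)` lets the function return `(value, aux)` and yields `((value, aux), grad)`. The least-squares function and data below are hypothetical, chosen so the numbers are easy to check by hand.

```python
import jax
import jax.numpy as jnp

def loss_with_aux(w, x, y):
    preds = x @ w
    loss = jnp.mean((preds - y) ** 2)
    # Return the scalar loss plus auxiliary data that is not differentiated.
    return loss, {"preds": preds}

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

(loss, aux), grad_w = jax.value_and_grad(loss_with_aux, has_aux=True)(w, x, y)
print(loss, aux["preds"], grad_w)
```

At `w = 0` the loss is $(1 + 4 + 9)/3 = 14/3$, and the gradient is $\tfrac{2}{3} x^\top(-y) = [-8/3, -10/3]$.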


# Part 2: Advanced Baking - Juggling Multiple Properties & Complex Recipes

Now that we've mastered the basics of getting gradients for a single score, let's tackle more complex baking challenges.

### Multiple Cake Properties (Jacobians: `jacfwd` & `jacrev`)

Our cake has multiple important properties we care about *simultaneously*, maybe `sweetness`, `fluffiness`, and `cost`. We need the **Jacobian** matrix to see how *each* ingredient affects *each* of these properties.

In [None]:
# Let's define ingredients as a dictionary for clarity
initial_ingredients = {'sugar': 1.5, 'flour': 2.0, 'butter': 1.0, 'eggs': 3.0}
ingredient_vec = jnp.array([initial_ingredients[k] for k in sorted(initial_ingredients.keys())])

def bake_cake_properties(ingredients_vec):
    """Calculates multiple cake properties from ingredients vector."""
    # Fictional relationships - order matters for Jacobian.
    # The vector follows the sorted key order: butter, eggs, flour, sugar
    butter, eggs, flour, sugar = ingredients_vec

    sweetness = 10 * jnp.log1p(sugar) + 0.1 * butter - 0.2*flour
    fluffiness = 5 * (eggs / 3.0) - 0.5 * (flour - 2.0)**2 + 0.2*butter
    # Cost might depend non-linearly (e.g., bulk discounts implicit)
    cost = (sugar**1.1 + flour**0.9 + butter**1.2 + eggs**1.0) * 0.5

    return jnp.array([sweetness, fluffiness, cost]) # Return a vector of properties

# Calculate the properties for our initial ingredients
properties = bake_cake_properties(ingredient_vec)
print(f"\n--- Jacobians: Multiple Properties ---")
print(f"Initial Cake Properties (Sweetness, Fluffiness, Cost): {properties}")

# --- Calculate Jacobian using jacfwd (Forward Mode) ---
# Efficient if num_ingredients < num_properties ("tall")
# jacfwd pushes tangents forward for each input dimension
jacobian_fwd_fn = jax.jacfwd(bake_cake_properties)
J_fwd = jacobian_fwd_fn(ingredient_vec)
print("\nJacobian (Forward Mode - jacfwd):")
print(f"Shape (num_properties, num_ingredients): {J_fwd.shape}") # (3, 4)
# Row i = sensitivities of property i; Col j = impact of ingredient j
# Ingredient order: butter, eggs, flour, sugar
print(J_fwd)

# --- Calculate Jacobian using jacrev (Reverse Mode) ---
# Efficient if num_ingredients > num_properties ("wide")
# jacrev pulls cotangents back for each output dimension
jacobian_rev_fn = jax.jacrev(bake_cake_properties)
J_rev = jacobian_rev_fn(ingredient_vec)
print("\nJacobian (Reverse Mode - jacrev):")
print(f"Shape (num_properties, num_ingredients): {J_rev.shape}") # (3, 4)
# Row i = sensitivities of property i; Col j = impact of ingredient j
# Ingredient order: butter, eggs, flour, sugar
print(J_rev)

# They compute the same value
assert jnp.allclose(J_fwd, J_rev)
print("\nForward and Reverse Jacobians match.")


--- Jacobians: Multiple Properties ---
Initial Cake Properties (Sweetness, Fluffiness, Cost): [8.862908 5.2      3.714068]

Jacobian (Forward Mode - jacfwd):
Shape (num_properties, num_ingredients): (3, 4)
[[ 0.1         0.         -0.2         4.        ]
 [ 0.2         1.6666667   0.          0.        ]
 [ 0.6         0.5         0.41986483  0.5727589 ]]

Jacobian (Reverse Mode - jacrev):
Shape (num_properties, num_ingredients): (3, 4)
[[ 0.1         0.         -0.2         4.        ]
 [ 0.2         1.6666666   0.          0.        ]
 [ 0.6         0.5         0.41986483  0.5727589 ]]

Forward and Reverse Jacobians match.


The Jacobian `J[i, j]` tells us how property `i` (e.g., fluffiness) changes per unit change in ingredient `j` (e.g., eggs).
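A quick way to verify an autodiff Jacobian is to compare it against central finite differences, column by column. The sketch below re-defines the same recipe so it runs standalone. The step size `eps = 1e-2` and the tolerance are assumptions suited to float32.

```python
import jax
import jax.numpy as jnp

def bake_cake_properties(v):
    # Ingredient order (sorted keys): butter, eggs, flour, sugar
    butter, eggs, flour, sugar = v
    sweetness = 10 * jnp.log1p(sugar) + 0.1 * butter - 0.2 * flour
    fluffiness = 5 * (eggs / 3.0) - 0.5 * (flour - 2.0) ** 2 + 0.2 * butter
    cost = (sugar ** 1.1 + flour ** 0.9 + butter ** 1.2 + eggs ** 1.0) * 0.5
    return jnp.array([sweetness, fluffiness, cost])

v = jnp.array([1.0, 3.0, 2.0, 1.5])
J = jax.jacrev(bake_cake_properties)(v)

# Central finite differences: column j approximates dF/dv_j.
eps = 1e-2
cols = []
for j in range(v.shape[0]):
    e_j = jnp.zeros_like(v).at[j].set(eps)
    cols.append((bake_cake_properties(v + e_j) - bake_cake_properties(v - e_j)) / (2 * eps))
J_fd = jnp.stack(cols, axis=1)  # shape (3, 4), same layout as J

print("max abs difference:", jnp.max(jnp.abs(J - J_fd)))
print("d(fluffiness)/d(eggs) =", J[1, 1])  # 5/3
```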

### The Engine Room: JVPs and VJPs (`jax.jvp`, `jax.vjp`)

`jacfwd` and `jacrev` are built on fundamental operations: Jacobian-Vector Products (JVP) for forward mode, and Vector-Jacobian Products (VJP) for reverse mode.

* **JVP** (`jax.jvp`): Answers: "If I change the ingredients along this specific direction (the `tangents`), how will the output properties change, to first order?" It computes `J @ v` without ever forming `J`.

In [None]:
# Define a change in ingredients (tangent vector)
# Increase sugar (+0.1), decrease flour (-0.05), keep others same
# Order: butter, eggs, flour, sugar
ingredient_changes = jnp.array([0.0, 0.0, -0.05, 0.1])

# Compute JVP: how do properties change given ingredient_changes?
# jvp returns (primal_output, tangent_output)
primals_out, tangents_out = jax.jvp(bake_cake_properties, # function
                                 (ingredient_vec,),    # where to evaluate (primals)
                                 (ingredient_changes,)) # how inputs change (tangents)

print(f"\n--- JVP Example ---")
print(f"Original properties: {primals_out}") # Same as 'properties' calculated before
print(f"Change in properties (J @ v): {tangents_out}") # [dSweetness, dFluffiness, dCost]

# Verify: Manually compute J @ v using the previously computed Jacobian
manual_jvp = J_fwd @ ingredient_changes
print(f"Manual J @ v check: {manual_jvp}")
assert jnp.allclose(tangents_out, manual_jvp)


--- JVP Example ---
Original properties: [8.862908 5.2      3.714068]
Change in properties (J @ v): [0.40999997 0.         0.03628265]
Manual J @ v check: [0.41       0.         0.03628265]
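One JVP gives one column of the Jacobian (`J @ e_j`), so pushing every standard basis vector through `jax.jvp` and batching with `jax.vmap` rebuilds the whole matrix. This is conceptually how forward-mode Jacobians are assembled. The small function `f` below is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(v):
    # A small hypothetical R^3 -> R^2 map used only for this sketch.
    return jnp.array([v[0] * v[1], jnp.sin(v[2]) + v[0] ** 2])

x = jnp.array([1.0, 2.0, 0.5])

def column(tangent):
    # One JVP = one column of the Jacobian (J @ e_j).
    return jax.jvp(f, (x,), (tangent,))[1]

basis = jnp.eye(x.shape[0])
# vmap over the basis vectors; each result row is a column of J, so transpose.
J_from_jvps = jax.vmap(column)(basis).T

print(J_from_jvps)
print(jax.jacfwd(f)(x))
```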


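* **VJP** (`jax.vjp`): Answers: "Given a weighting of the output properties (the `cotangent` vector `v`), how sensitive is that weighted combination of properties to each ingredient?" It computes `v^T @ J`, which is the gradient of the scalar `v · f(x)`. It does *not* solve for the ingredient change that would produce `v`. `vjp` returns the primal output and a *function* (`vjp_fun`) that computes the backward pass.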
> **VJPs are the core mechanism behind** `jax.grad` **and** `jacrev`.




In [None]:
# Define a cotangent vector 'v': one weight per output property.
# v^T @ J is the gradient of the weighted sum 0.2*sweetness - 0.1*fluffiness + 0*cost
desired_property_changes = jnp.array([0.2, -0.1, 0.0])

# Compute VJP:
# 1. Forward pass to get output and vjp_fun
primals_out_vjp, vjp_fun = jax.vjp(bake_cake_properties, ingredient_vec)

# 2. Backward pass using vjp_fun
# vjp_fun takes the cotangent vector (one weight per output property)
# and returns a tuple with one cotangent (gradient-like value) per primal input
ingredient_sensitivities = vjp_fun(desired_property_changes) # Computes v^T @ J

print(f"\n--- VJP Example ---")
print(f"Original properties: {primals_out_vjp}")
# The result has the same structure as the input (a single vector here)
print(f"Ingredient sensitivities for desired change (v^T @ J): {ingredient_sensitivities[0]}")


# Verify: Manually compute v^T @ J
manual_vjp = desired_property_changes @ J_fwd # Note: v is a row vector here
print(f"Manual v^T @ J check: {manual_vjp}")
assert jnp.allclose(ingredient_sensitivities[0], manual_vjp)


--- VJP Example ---
Original properties: [8.862908 5.2      3.714068]
Ingredient sensitivities for desired change (v^T @ J): [ 0.         -0.16666667 -0.04        0.8       ]
Manual v^T @ J check: [ 8.1956386e-10 -1.6666667e-01 -4.0000003e-02  8.0000001e-01]
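Two consequences of the `v^T @ J` view can be checked directly. First, pulling back each output basis vector gives one row of the Jacobian, which is conceptually how reverse-mode Jacobians are built. Second, a VJP with weights `w` equals `jax.grad` of the scalar `w · f(x)`. The function `f` and weights `w` are hypothetical.

```python
import jax
import jax.numpy as jnp

def f(v):
    # Hypothetical R^3 -> R^2 map for illustration.
    return jnp.array([v[0] * v[1], jnp.sin(v[2]) + v[0] ** 2])

x = jnp.array([1.0, 2.0, 0.5])
y, f_vjp = jax.vjp(f, x)

# One VJP = one row of the Jacobian (e_i^T @ J); vmap over the output basis.
J_from_vjps = jax.vmap(lambda ct: f_vjp(ct)[0])(jnp.eye(y.shape[0]))

# A VJP with weights w is the gradient of the weighted sum w . f(x).
w = jnp.array([0.2, -0.1])
g_via_vjp = f_vjp(w)[0]
g_via_grad = jax.grad(lambda v: jnp.dot(w, f(v)))(x)

print(J_from_vjps)
print(g_via_vjp, g_via_grad)
```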


### Understanding Curvature (Hessians)

Let's define an `overall_deliciousness_score` (scalar output) and find its Hessian to understand the curvature of our deliciousness landscape.

In [None]:
def deliciousness_score(ingredient_vec):
    """Combines properties into a single score."""
    sweetness, fluffiness, cost = bake_cake_properties(ingredient_vec)
    # Example: reward sweetness and fluffiness, penalize cost
    # Penalize deviation from ideal sweetness (e.g., 10) and fluffiness (e.g., 5)
    score = -(sweetness - 10.0)**2 - (fluffiness - 5.0)**2 - 0.5 * cost
    return score

# jax.hessian(f) uses jacfwd(jacrev(f)) because forward-over-reverse is typically the most efficient.
# It can also be written as jacrev(jacfwd(f)) or any other composition of these two.
hessian_score_fn = jax.hessian(deliciousness_score)
H_score = hessian_score_fn(ingredient_vec)

print(f"\n--- Hessian of Deliciousness Score ---")
print(f"Shape (num_ingredients, num_ingredients): {H_score.shape}") # (4, 4)
print("Hessian Matrix:")
print(H_score)
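# The Hessian tells us how the gradient of the score changes.
# A local maximum needs a (near-)zero gradient AND a negative definite Hessian (all eigenvalues < 0).
# Negative diagonal entries alone are not enough; the positive flour entry H[2, 2] rules a maximum out here.
# Off-diagonal elements show how changing one ingredient affects the gradient w.r.t. another.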


--- Hessian of Deliciousness Score ---
Shape (num_ingredients, num_ingredients): (4, 4)
Hessian Matrix:
[[ -0.16000001  -0.6666667    0.04        -0.8       ]
 [ -0.6666667   -5.555556     0.           0.        ]
 [  0.04         0.           0.33049622   1.6       ]
 [ -0.8          0.           1.6        -35.657787  ]]
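To check the "near a maximum?" question properly, look at the eigenvalues of the (symmetric) Hessian rather than just its diagonal. This sketch re-defines the recipe and score so it runs standalone, uses `jnp.linalg.eigvalsh` for the symmetric eigenvalue problem, and also confirms that reverse-over-reverse gives the same Hessian as `jax.hessian`.

```python
import jax
import jax.numpy as jnp

def bake_cake_properties(v):
    # Ingredient order: butter, eggs, flour, sugar
    butter, eggs, flour, sugar = v
    sweetness = 10 * jnp.log1p(sugar) + 0.1 * butter - 0.2 * flour
    fluffiness = 5 * (eggs / 3.0) - 0.5 * (flour - 2.0) ** 2 + 0.2 * butter
    cost = (sugar ** 1.1 + flour ** 0.9 + butter ** 1.2 + eggs ** 1.0) * 0.5
    return jnp.array([sweetness, fluffiness, cost])

def deliciousness_score(v):
    sweetness, fluffiness, cost = bake_cake_properties(v)
    return -(sweetness - 10.0) ** 2 - (fluffiness - 5.0) ** 2 - 0.5 * cost

v = jnp.array([1.0, 3.0, 2.0, 1.5])
H = jax.hessian(deliciousness_score)(v)
H_rev_rev = jax.jacrev(jax.jacrev(deliciousness_score))(v)

# Symmetric matrix: eigvalsh returns real eigenvalues in ascending order.
eigvals = jnp.linalg.eigvalsh(H)
# A local maximum (with zero gradient) would need every eigenvalue < 0.
is_negative_definite = bool(jnp.all(eigvals < 0))
print(eigvals, is_negative_definite)
```

The largest eigenvalue is always at least as large as any diagonal entry, so the positive flour entry already guarantees at least one positive eigenvalue.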


### Efficient Curvature Check (Hessian-Vector Products - HVP)

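For recipes with many ingredients, the full $n \times n$ Hessian becomes expensive to compute and store. A Hessian-vector product (HVP), `H @ v`, gives the curvature along a single direction `v` without ever building `H`, at a cost comparable to a few gradient evaluations.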

In [None]:
def hvp(f, primals, tangents):
  # Forward-over-reverse: a JVP of the gradient function (jacrev of a scalar f) computes H @ v
  return jax.jvp(jax.jacrev(f), primals, tangents)[1]

# Direction of ingredient change we want to analyze
ingredient_direction = jnp.array([0.1, -0.2, 0.05, 0.1]) # Change butter, eggs, flour, sugar

# Calculate HVP for the deliciousness score
hvp_score_result = hvp(deliciousness_score, (ingredient_vec,), (ingredient_direction,))

print(f"\n--- HVP of Deliciousness Score ---")
print(f"HVP result in direction {ingredient_direction}:")
print(hvp_score_result)

# Verify manually
manual_hvp_score = H_score @ ingredient_direction
print(f"Manual HVP check: {manual_hvp_score}")
assert jnp.allclose(hvp_score_result, manual_hvp_score)


--- HVP of Deliciousness Score ---
HVP result in direction [ 0.1  -0.2   0.05  0.1 ]:
[ 0.03933334  1.0444444   0.18052481 -3.5657785 ]
Manual HVP check: [ 0.03933333  1.0444446   0.18052481 -3.5657787 ]
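To see why HVPs scale, here is a sketch at a size where the full Hessian would hold $10^8$ entries. The test function $f(x) = \sum_i x_i^4$ is hypothetical and chosen because its Hessian is diagonal, $H_{ii} = 12 x_i^2$, so the answer is known in closed form. It uses `jax.grad` in place of `jax.jacrev`, which is equivalent for a scalar function.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse: JVP of the gradient function in direction v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def f(x):
    # Hypothetical scalar function with diagonal Hessian: d^2f/dx_i^2 = 12 x_i^2.
    return jnp.sum(x ** 4)

n = 10_000  # a full Hessian would need n * n = 100 million entries
x = jnp.linspace(-1.0, 1.0, n)
v = jnp.ones(n)

result = hvp(f, x, v)  # only vectors of length n are ever materialized
print(result.shape)
```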


### Controlling Gradient Flow (`jax.lax.stop_gradient`)

In recipe refinement, we might adjust `flour` to meet an `ideal_target_fluffiness`, which itself depends on other ingredients like `butter`, `eggs`, and even `flour`. When computing the gradient, we may want to treat that target as a *fixed goal* for the current step. `jax.lax.stop_gradient` does this: the forward pass uses the target's value, but no gradient flows back through its calculation, so the update only moves the actual fluffiness toward the target, not the target toward the actual fluffiness.

In [None]:
def calculate_fluffiness(ingredients_vec):
    """Calculates cake fluffiness."""
    butter, eggs, flour, sugar = ingredients_vec
    return 5 * (eggs / 3.0) - 0.5 * (flour - 2.0)**2 + 0.2 * butter

def predict_ideal_fluffiness(butter, eggs, flour):
    """Calculates target fluffiness, including a flour dependency."""
    # Target depends on wet ingredients and slightly on flour
    return 2.0 + butter * 0.5 + eggs * 1.0 - 0.1 * flour

# Loss comparing actual fluffiness to the calculated ideal target
def fluffiness_loss(ingredients_vec):
    """Calculates loss with full gradient dependencies."""
    butter, eggs, flour, sugar = ingredients_vec
    actual_fluff = calculate_fluffiness(ingredients_vec)
    ideal_target = predict_ideal_fluffiness(butter, eggs, flour)
    return (actual_fluff - ideal_target)**2

# Loss where the ideal target is treated as fixed for gradient calculation
def fluffiness_loss_stopgrad(ingredients_vec):
    """Calculates loss, stopping gradient flow through ideal_target."""
    butter, eggs, flour, sugar = ingredients_vec
    actual_fluff = calculate_fluffiness(ingredients_vec)
    ideal_target = predict_ideal_fluffiness(butter, eggs, flour)
    # Stop gradient flow back through ideal_target calculation
    loss = (actual_fluff - jax.lax.stop_gradient(ideal_target))**2
    return loss

# --- Calculate and Compare Gradients ---
print(f"\n--- stop_gradient Example: Focused Adjustment ---")

# Calculate gradient without stop_gradient
grad_normal_fn = jax.grad(fluffiness_loss)
grad_normal = grad_normal_fn(ingredient_vec)
print(f"Gradient WITHOUT stop_gradient:\n{grad_normal}")
print(f"  Gradient for Flour (idx 2) w/o stop_gradient: {grad_normal[2]:.4f}")

# Calculate gradient with stop_gradient
grad_stopgrad_fn = jax.grad(fluffiness_loss_stopgrad)
grad_stop = grad_stopgrad_fn(ingredient_vec)
print(f"\nGradient WITH stop_gradient on target:\n{grad_stop}")
print(f"  Gradient for Flour (idx 2) w/  stop_gradient: {grad_stop[2]:.4f}")

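# The gradients differ for butter, eggs and flour, because the target depends on all three.
# For flour, d(actual)/dFlour = 0 at flour = 2.0, so without stop_gradient the only contribution
# comes from -dT/dFlour = +0.1 (giving -0.02); stop_gradient drops that term, leaving 0.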


--- stop_gradient Example: Focused Adjustment ---
Gradient WITHOUT stop_gradient:
[ 0.06000023 -0.13333383 -0.02000008  0.        ]
  Gradient for Flour (idx 2) w/o stop_gradient: -0.0200

Gradient WITH stop_gradient on target:
[-0.04000015 -0.3333346   0.          0.        ]
  Gradient for Flour (idx 2) w/  stop_gradient: 0.0000
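The core behavior of `stop_gradient` fits in a few lines: the forward value is unchanged, but the wrapped expression is treated as a constant during differentiation. The function `f` below is a hypothetical example.

```python
import jax
from jax import lax

def f(x):
    # Forward value is x * x, but the second factor is treated as a constant.
    return x * lax.stop_gradient(x)

x = 3.0
value = f(x)                               # 9.0, same as x * x
grad_value = jax.grad(f)(x)                # d/dx [x * c] with c = x held fixed -> 3.0
full_grad = jax.grad(lambda x: x * x)(x)   # 2x -> 6.0
print(value, grad_value, full_grad)
```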


### Baking for Many Customers (`vmap` + `grad`)

We need to bake slightly different cakes for a batch of customer orders, each with a target `sweetness`. We want the gradient of `(actual_sweetness - target_sweetness)**2` for each customer, using *their* specific target.

In [None]:
def sweetness_error(ingredients_vec, target_sweet):
    """Squared error for sweetness."""
    # bake_cake_properties returns [sweetness, fluffiness, cost]
    actual_sweetness = bake_cake_properties(ingredients_vec)[0]
    return (actual_sweetness - target_sweet)**2

# Batch of target sweetness levels for different customers
target_sweetness_batch = jnp.array([9.5, 10.0, 10.5, 9.8])

# Gradient function for one target
grad_sweet_error_fn = jax.grad(sweetness_error, argnums=0) # Grad w.r.t ingredients

# Map the gradient function over the batch of targets
# Keep ingredients fixed (None), map over target_sweet (axis 0)
per_customer_grad_fn = jax.vmap(grad_sweet_error_fn, in_axes=(None, 0))

# Calculate all gradients
per_customer_ingredient_grads = per_customer_grad_fn(ingredient_vec, target_sweetness_batch)

print(f"\n--- Per-Example Gradients (vmap + grad) ---")
print(f"Shape of gradients: {per_customer_ingredient_grads.shape}") # (batch_size, num_ingredients) -> (4, 4)
print("Gradients per customer order:")
print(per_customer_ingredient_grads)
# Each row tells how to adjust 'ingredient_vec' to reduce sweetness error for that specific customer.


--- Per-Example Gradients (vmap + grad) ---
Shape of gradients: (4, 4)
Gradients per customer order:
[[ -0.12741832   0.           0.25483665  -5.096733  ]
 [ -0.22741833   0.           0.45483667  -9.096733  ]
 [ -0.32741833   0.           0.65483665 -13.096733  ]
 [ -0.18741837   0.           0.37483674  -7.4967346 ]]
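Two sanity checks for the per-customer gradients: `vmap` should agree with a plain Python loop, and averaging the per-example gradients should equal the gradient of the average loss (differentiation is linear). This sketch keeps only the sweetness formula from the recipe so it runs standalone.

```python
import jax
import jax.numpy as jnp

def sweetness(v):
    # Same sweetness formula as above; order: butter, eggs, flour, sugar.
    butter, eggs, flour, sugar = v
    return 10 * jnp.log1p(sugar) + 0.1 * butter - 0.2 * flour

def sweetness_error(v, target):
    return (sweetness(v) - target) ** 2

v = jnp.array([1.0, 3.0, 2.0, 1.5])
targets = jnp.array([9.5, 10.0, 10.5, 9.8])

per_example = jax.vmap(jax.grad(sweetness_error), in_axes=(None, 0))(v, targets)

# The same result with a Python loop (slower, but easy to read).
looped = jnp.stack([jax.grad(sweetness_error)(v, t) for t in targets])

# Gradient of the mean loss over all customers.
def mean_loss(w):
    return jnp.mean(jax.vmap(sweetness_error, in_axes=(None, 0))(w, targets))

mean_loss_grad = jax.grad(mean_loss)(v)
print(per_example.mean(axis=0), mean_loss_grad)
```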


# Part 3: Exotic Flavors Using Complex Numbers

Here, we'll explore how JAX handles an even more advanced situation: complex numbers in our analysis.

### Complex Flavors & Vibrations: Differentiation with Complex Numbers

Sometimes, simple numbers aren't enough. Maybe we're inventing truly magical cakes where flavors interact in ways best described by *complex* numbers (numbers with a 'real' part and an 'imaginary' part, like `3 + 4j`). Or perhaps we're analyzing the complex vibrations (`amplitude` + `phase`) in our mixer batter, knowing they affect the final texture. Can JAX handle derivatives involving these "magic numbers"? Yes!

### The Nuance: Smooth Magic vs. Tricky Magic

Think of complex functions like magic spells:

1. **Holomorphic Spells (Smooth Magic)**: These behave nicely: near any point they act like a rotation plus a stretch. Functions like $f(z)=z^2$ or $f(z)=\sin(z)$ are like this. Their derivative is a single complex number, just as in regular calculus.
2. **Non-Holomorphic Spells (Tricky Magic)**: For these spells, the rate of change depends on the *direction* in which you nudge $z$. Examples are taking just the real part ($f(z)=\operatorname{Re}(z)$) or the complex conjugate ($f(z)=\bar{z}=x-iy$). A single complex number can't capture their derivative; instead we need to describe how both the real and imaginary output parts change when either the real or imaginary input part changes (a 2x2 real Jacobian).

### JAX's Approach: The Universal Magic Wand (JVPs & VJPs)

JAX's fundamental tools, JVP and VJP, handle *all* types of complex magic correctly, because they work with the underlying real and imaginary parts rather than assuming that a single complex derivative exists.

*Think like this*: A complex number $z=x+iy$ is like a point $(x,y)$ on a 2D map. A function $f(z)=u(x,y)+iv(x,y)$ takes a point on one map and gives a point on another map. JAX's JVP and VJP figure out the derivatives for this underlying $\mathbb{R}^2 \to \mathbb{R}^2$ map.

* `jvp(f, (z,), (dz,))`: If you wiggle the input `z` by a complex amount `dz`, this tells you the complex wiggle `df` in the output `f(z)`. It works for any complex function `f`.
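* `vjp(f, z)`: Gives you the output `f(z)` and a "pullback" function. The pullback takes a complex cotangent `g` (a weighting of the output) and returns the corresponding cotangent for the input `z`; `jax.grad` is this pullback applied to the cotangent `1.0`. It does *not* solve for the input wiggle that would produce `g`. It also works for any complex function `f`.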

### Using `jax.grad` (The Easy Spellbook)

The `jax.grad` function is simpler but has specific rules for complex numbers:

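1. **Magic Input -> Real Score (`C -> R`):** If our magical vibration `z` affects the real `deliciousness_score`, `jax.grad` works directly and returns a *complex* gradient. With $z = x + iy$ and a real score $u(x, y)$, JAX's convention gives $\partial u/\partial x - i\,\partial u/\partial y$. Its complex conjugate, $\partial u/\partial x + i\,\partial u/\partial y$, is the direction in the complex plane that increases the score fastest, so for gradient ascent or descent on complex parameters, step along the conjugate of what `jax.grad` returns.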

In [None]:
def score_from_vibration(z):
  """Real score based on complex vibration z = x + iy."""
  # Score increases with real part, decreases with distance from origin
  return jnp.real(z) * 3.0 - jnp.real(z * jnp.conjugate(z)) # 3x - (x^2 + y^2)

z_vibration = 2.0 + 1.0j
current_score = score_from_vibration(z_vibration)
print(f"\n--- Complex Input -> Real Output ---")
print(f"Vibration z = {z_vibration}, Score = {current_score:.3f}")

# Get the complex gradient
grad_score_complex_fn = jax.grad(score_from_vibration)
complex_grad = grad_score_complex_fn(z_vibration)
print(f"Complex Gradient ∇_z score: {complex_grad}")
# dScore/dx = 3 - 2x = -1.0, dScore/dy = -2y = -2.0. JAX returns dScore/dx - i*dScore/dy = -1.0 + 2.0j


--- Complex Input -> Real Output ---
Vibration z = (2+1j), Score = 1.000
Complex Gradient ∇_z score: (-1+2j)


2. **Magic Input -> Magic Output (`C -> C`):**

* **Holomorphic ("Smooth Magic")**: You must tell `grad` it's smooth magic with `holomorphic=True`. It then calculates the correct single complex derivative.

In [None]:
def smooth_magic_process(z):
  """A holomorphic function."""
  return z**2 + z

# Must use holomorphic=True for C->C
grad_smooth_magic_fn = jax.grad(smooth_magic_process, holomorphic=True)
complex_deriv = grad_smooth_magic_fn(z_vibration)
print(f"\n--- Holomorphic C -> C Example (z^2 + z) ---")
print(f"Input z = {z_vibration}")
print(f"Output f(z) = {smooth_magic_process(z_vibration)}")
print(f"Complex Derivative df/dz: {complex_deriv}")
# Check: derivative is 2z + 1 = 2*(2.0+1.0j) + 1 = 4 + 2.0j + 1 = 5.0 + 2.0j
assert jnp.allclose(complex_deriv, 2 * z_vibration + 1)


--- Holomorphic C -> C Example (z^2 + z) ---
Input z = (2+1j)
Output f(z) = (5+5j)
Complex Derivative df/dz: (5+2j)
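What makes a function holomorphic is that its JVP in *every* complex direction `dz` is just the derivative times `dz`. The sketch checks this for three arbitrary directions (the directions are assumptions chosen for illustration).

```python
import jax
import jax.numpy as jnp

def smooth_magic_process(z):
    return z ** 2 + z

z = jnp.array(2.0 + 1.0j)
deriv = jax.grad(smooth_magic_process, holomorphic=True)(z)  # 2z + 1

# For a holomorphic f, the JVP in any direction dz equals deriv * dz.
pairs = []
for dz in [jnp.array(1.0 + 0.0j), jnp.array(0.0 + 1.0j), jnp.array(0.6 - 0.8j)]:
    _, df = jax.jvp(smooth_magic_process, (z,), (dz,))
    pairs.append((df, deriv * dz))
    print(dz, df, deriv * dz)
```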


* **Non-Holomorphic ("Tricky Magic")**: For tricky spells, `grad(..., holomorphic=True)` doesn't give you the whole magic picture. We can still pass `holomorphic=True` when the function isn't holomorphic (this stops JAX from raising an error just because the output is complex), but the answer won't represent the *full Jacobian* (the complete derivative information).
  
  Instead, it will usually be the gradient of the function with the imaginary part of the output discarded, i.e., the gradient of $\operatorname{Re}(f(z))$. This single complex number can be misleading if you need to know how all parts of the magic change. Without `holomorphic=True`, `jax.grad` raises an error on a `C -> C` spell. For the full picture (how real/imaginary inputs affect real/imaginary outputs), use `jax.jacfwd` or `jax.jacrev` on the function viewed as a map $\mathbb{R}^2 \to \mathbb{R}^2$.

In [None]:
def tricky_magic_process(z):
  """Non-holomorphic: complex conjugate."""
  return jnp.conjugate(z) # Example: conj(x+iy) = x-iy

# Using grad with holomorphic=True gives a specific, potentially misleading, result
grad_tricky_magic_fn = jax.grad(tricky_magic_process, holomorphic=True)
misleading_result = grad_tricky_magic_fn(z_vibration)
print(f"\n--- Non-Holomorphic C -> C Example (conj(z)) ---")
print(f"Using grad(..., holomorphic=True) on conj(z): {misleading_result}")
# For conj(z) = x - iy: u = x, v = -y. grad(..., holomorphic=True) gives du/dx - i*du/dy = 1 - 0j

# Correct way: Use Jacobian on R^2 -> R^2 view
def tricky_magic_real_view(real_vec):
  z = real_vec[0] + 1j * real_vec[1]
  out_complex = tricky_magic_process(z)
  # Return [real_output, imaginary_output]
  return jnp.array([jnp.real(out_complex), jnp.imag(out_complex)])

jacobian_fn = jax.jacrev(tricky_magic_real_view)
# Use the real/imag parts of z_vibration = 2.0 + 1.0j
real_jacobian_info = jacobian_fn(jnp.array([jnp.real(z_vibration), jnp.imag(z_vibration)]))
print(f"\nFull Jacobian via R^2 view for conj(z):")
print(real_jacobian_info)
# Expected Jacobian for f(x,y) = [x, -y] w.r.t. [x, y] is [[1, 0], [0, -1]]


--- Non-Holomorphic C -> C Example (conj(z)) ---
Using grad(..., holomorphic=True) on conj(z): (1-0j)

Full Jacobian via R^2 view for conj(z):
[[ 1.  0.]
 [ 0. -1.]]
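For `conj(z)` the misleading result happens to be `1`. To see more clearly what `grad(..., holomorphic=True)` returns for a non-holomorphic spell, here is a sketch with a hypothetical second example, $f(z) = \overline{z}^2$. With $u = \mathrm{Re}\,f = x^2 - y^2$, the result should be $\partial u/\partial x - i\,\partial u/\partial y = 2x + 2iy = 2z$. It should also match the gradient of the real part alone. This holds even though $\overline{z}^2$ is not complex-differentiable anywhere except $z = 0$.

```python
import jax
import jax.numpy as jnp

def anti_magic(z):
  """Hypothetical non-holomorphic spell: conj(z)**2 = (x - iy)**2."""
  return jnp.conjugate(z) ** 2

def real_part_only(z):
  """The same spell with the imaginary part of the output discarded (C -> R)."""
  return jnp.real(anti_magic(z))

z = jnp.complex64(2.0 + 1.0j)  # illustrative input

# holomorphic=True is taken on trust, so JAX returns a number without complaint
via_holomorphic_flag = jax.grad(anti_magic, holomorphic=True)(z)
# A C -> R function needs no flag; this is the gradient of Re(f(z))
via_real_part = jax.grad(real_part_only)(z)

print("grad(..., holomorphic=True):", via_holomorphic_flag)
print("grad of Re(f(z))          :", via_real_part)
print("2 * z                     :", 2 * z)
```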


3. **Real Recipe, Magic Inside (`R -> R`)**: If your normal recipe (real inputs/outputs) uses complex math internally (like special mixing algorithms using FFTs), `jax.grad` works just fine and gives the correct real gradients. JAX handles the complex steps internally via JVPs/VJPs.

In [None]:
def calculate_flavor_stability(ingredient_potency_x):
  """
  Calculates a real 'stability' score based on potency 'x'.
  Uses complex math internally. R -> C -> R function.
  """
  # Create a complex number based on the real input potency 'x'
  z_activation = 1.0 + ingredient_potency_x * 1j  # e.g., z = 1 + ix

  # Perform some calculation involving complex numbers
  # Example: compute squared magnitude |z|^2 = (1+ix)(1-ix) = 1 + x^2
  stability_score = z_activation * jnp.conjugate(z_activation)
  return jnp.real(stability_score)

ingredient_potency = 3.0

print(f"Input Ingredient Potency x = {ingredient_potency}")
# Calculate the score directly
stability = calculate_flavor_stability(ingredient_potency)
# Expected: 1 + 3.0^2 = 10.0
print(f"Output Flavor Stability Score = {stability:.4f}")

# Even though we created complex numbers inside, the function maps R -> R.
# jax.grad works seamlessly.
grad_stability_fn = jax.grad(calculate_flavor_stability)
stability_gradient = grad_stability_fn(ingredient_potency)

print(f"\nGradient d(Score)/dx at x={ingredient_potency}: {stability_gradient:.4f}")
# Check: Derivative of 1 + x^2 is 2x. For x=3.0, derivative is 6.0
assert jnp.allclose(stability_gradient, 2 * ingredient_potency)

Input Ingredient Potency x = 3.0
Output Flavor Stability Score = 10.0000

Gradient d(Score)/dx at x=3.0: 6.0000
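The text above mentions FFT-based mixing. A minimal sketch of that case uses a hypothetical `spectral_energy` score that runs a real vector through `jnp.fft.fft` and sums the squared magnitudes of the complex coefficients. Parseval's theorem says $\sum_k |X_k|^2 = n \sum_j x_j^2$, so the real gradient should come out close to $2 n x$. This gives an easy check that `jax.grad` handles the complex detour.

```python
import jax
import jax.numpy as jnp

def spectral_energy(flavor_profile):
  """Hypothetical R^n -> R score computed through an FFT.

  jnp.fft.fft turns the real input into complex coefficients X_k; the score is
  the total energy sum_k |X_k|^2, written here as real(X * conj(X)).
  """
  spectrum = jnp.fft.fft(flavor_profile)
  return jnp.sum(jnp.real(spectrum * jnp.conj(spectrum)))

x = jnp.array([1.0, 2.0, 3.0, 4.0])
n = x.shape[0]

score = spectral_energy(x)
grad_fft = jax.grad(spectral_energy)(x)

# Parseval's theorem: sum_k |X_k|^2 = n * sum_j x_j^2,
# so the score should be about 4 * 30 = 120 and the gradient about 2 * n * x.
print("score    :", score)
print("gradient :", grad_fft)
print("2 * n * x:", 2 * n * x)
```

The gradient is an ordinary real array, just like in `calculate_flavor_stability`.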


# Part 4: Secret Family Recipes & Advanced Techniques!

We've seen how our JAX assistant can handle tricky calculations, multiple outputs, and even complex numbers. But sometimes, the standard way JAX figures out derivatives isn't quite right, or we have a special technique we want to use. Maybe JAX's calculation causes a kitchen disaster (`NaN!`), or we have a secret family rule for how much an ingredient should affect the taste.

Fear not! JAX lets us *teach* it our own **Custom Differentiation Rules**. Think of it like passing down a secret family recipe or technique that JAX wouldn't know otherwise. We'll focus on the main tools for this: `jax.custom_jvp` and `jax.custom_vjp`.

### Why Use Custom Rules? (Special Baking Scenarios)

Why would we override JAX's smarts?

* **Avoiding Kitchen Disasters (Numerical Stability)**: Sometimes, a standard calculation (like for yeast activity at high temperatures) might involve numbers so big or small that JAX's automatic derivative gives `NaN` or `inf`. We can provide a mathematically equivalent, but more stable, formula just for the derivative.

* **Enforcing Kitchen Rules (Gradient Modification)**: We might have a strict rule: "The gradient with respect to butter must never exceed X units in magnitude." We can enforce this by directly modifying the gradient during the backward pass.

* **Handling Secret Processes (Implicit Functions/Loops)**: Our unique sourdough starter might mature in a complex way described by a loop (`jax.lax.while_loop`) that stops when it's ready. `jax.grad` cannot reverse-differentiate a `while_loop` at all, because reverse mode is only supported for loops with a fixed, known number of steps (such as `lax.scan`). Unrolling thousands of steps in a Python loop instead would be slow and memory-hungry. We can use a math shortcut (the Implicit Function Theorem) to get the derivative of the final state directly and teach that shortcut to JAX.
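To see the loop problem concretely, here is a minimal sketch with a hypothetical `toy_maturity` loop that iterates $x \leftarrow a x + 1$ until it settles. Forward mode (`jax.jvp`) can push a tangent through every iteration. Reverse mode (`jax.grad`) should be refused because the number of iterations is only known at run time. The exact exception type and message depend on the JAX version, so the sketch catches any exception.

```python
import jax
import jax.numpy as jnp

def toy_maturity(a, tol=1e-6):
  """Hypothetical toy loop: iterate x <- a*x + 1 from x = 0 until it settles."""
  def keep_going(state):
    prev, curr = state
    return jnp.abs(curr - prev) > tol

  def one_hour(state):
    _, curr = state
    return curr, a * curr + 1.0

  x0 = jnp.zeros((), dtype=jnp.float32)
  return jax.lax.while_loop(keep_going, one_hour, (x0, a * x0 + 1.0))[1]

# Forward mode differentiates each iteration as it runs, so jvp works.
value, d_value_da = jax.jvp(toy_maturity, (0.5,), (1.0,))
print(value, d_value_da)  # should be near 1/(1-a) = 2 and 1/(1-a)**2 = 4

# Reverse mode needs a fixed trip count, so JAX refuses to transpose this loop.
try:
  jax.grad(toy_maturity)(0.5)
  reverse_failed = False
except Exception as err:
  reverse_failed = True
  print("grad failed:", type(err).__name__)
```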


### Teaching JAX: `jax.custom_jvp` (Forward Changes)

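Use `custom_jvp` when you want to define how the output changes (the output tangent) given a specific change in the inputs (the input `tangents`). As long as the output tangent your rule computes is linear in the input tangents, JAX can transpose the rule automatically, so `grad` works from the same rule.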

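Imagine yeast activity explodes at high temperatures. At `temp = 500`, `exp(temp)` overflows to `inf` in float32, and the standard derivative turns into `nan`. The custom rule below replaces only the derivative. The forward value `log(1 + exp(temp))` still overflows, which is why the activity itself prints as `inf`.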

In [None]:
@jax.custom_jvp
def yeast_activity(temp):
  # Activity increases sharply, potentially unstable derivative
  return jnp.log(1. + jnp.exp(temp))

# Custom JVP rule
@yeast_activity.defjvp
def yeast_activity_jvp(primals, tangents):
  temp, = primals
  temp_dot, = tangents # How temperature is changing
  activity = yeast_activity(temp) # Original function value

  # Standard derivative: (1 / (1 + jnp.exp(temp))) * jnp.exp(temp) * temp_dot
  # for high temp effectively turns into 0. * jnp.inf * temp_dot
  # Stable alternative: (1 - 1/(1 + jnp.exp(temp))) * temp_dot
  activity_dot = (1 - 1/(1 + jnp.exp(temp))) * temp_dot # More stable calculation

  return activity, activity_dot

# --- Test ---
high_temp = 500.0 # A high temperature where exp(temp) overflows float32

print(f"\n--- Custom JVP (Yeast Stability) ---")
print(f"Yeast activity at {high_temp}: {yeast_activity(high_temp)}")
# grad will use the custom rule
print(f"Gradient of activity at {high_temp}: {jax.grad(yeast_activity)(high_temp)}")


--- Custom JVP (Yeast Stability) ---
Yeast activity at 500.0: inf
Gradient of activity at 500.0: 1.0
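For comparison, here is a sketch that shows the disaster the custom rule avoids and also fixes the forward value. `naive_activity` is the same formula without a custom rule. Its gradient at `500.0` should be `nan`, because the backward pass multiplies `0` by `exp(500) = inf`. `stable_activity` is a hypothetical variant that swaps in `jnp.logaddexp(0., temp)` (equal to $\log(1 + e^{t})$ without overflow) for the primal and uses `jax.nn.sigmoid(temp)` (equal to `1 - 1/(1 + exp(temp))`) for the tangent.

```python
import jax
import jax.numpy as jnp

def naive_activity(temp):
  """Same formula as yeast_activity, with no custom rule."""
  return jnp.log(1. + jnp.exp(temp))

@jax.custom_jvp
def stable_activity(temp):
  """log(1 + exp(temp)) computed as logaddexp(0, temp), which avoids overflow."""
  return jnp.logaddexp(0., temp)

@stable_activity.defjvp
def stable_activity_jvp(primals, tangents):
  temp, = primals
  temp_dot, = tangents
  # d/dt log(1 + e^t) = sigmoid(t), which saturates at 1.0 instead of overflowing.
  # The tangent is linear in temp_dot, so JAX can transpose it for grad.
  return stable_activity(temp), jax.nn.sigmoid(temp) * temp_dot

high_temp = 500.0
naive_grad = jax.grad(naive_activity)(high_temp)
stable_value = stable_activity(high_temp)
stable_grad = jax.grad(stable_activity)(high_temp)

print("naive gradient :", naive_grad)    # 0 * inf in the backward pass -> nan
print("stable value   :", stable_value)
print("stable gradient:", stable_grad)
```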


### Teaching JAX: `jax.custom_vjp` (Backward Sensitivity)

Use `custom_vjp` when you need direct control over the backward pass used by `grad`. This is what you need for gradient clipping or for handling loops and implicit functions. *Note*: this only defines the rule for `grad`/reverse mode. Applying forward mode (`jax.jvp`) to a `custom_vjp` function raises an error.

Let's enforce a rule: the gradient component for `butter` (index 0) cannot have a magnitude larger than `butter_range`.

1. Decorate `f` with `@jax.custom_vjp`.
2. Define `f_fwd(...)`: it takes the same inputs as `f` and returns (`original_output`, `residuals`). The `residuals` are your "notes" to pass to the backward step.
3. Define `f_bwd(residuals, output_grad)`: it takes the notes and the incoming gradient (`output_grad`, the gradient of the final result with respect to `f`'s output) and returns a tuple with one gradient per input of `f`.
4. Link them: `f.defvjp(f_fwd, f_bwd)`.
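Before the butter example, here are the four steps on a tiny hypothetical function, `toasted(x) = sin(x)`. The forward step saves `cos(x)` as its note, and the backward step multiplies the incoming gradient by it, so `grad` should reproduce `cos(x)`. The sketch also checks the note about forward mode by attempting `jax.jvp`, which should fail for a `custom_vjp` function. Any exception is caught, since the exact type may vary across JAX versions.

```python
import jax
import jax.numpy as jnp

# Step 1: decorate f (a hypothetical toy function) with @jax.custom_vjp.
@jax.custom_vjp
def toasted(x):
  return jnp.sin(x)

# Step 2: same inputs as f; return (original_output, residuals).
def toasted_fwd(x):
  return jnp.sin(x), jnp.cos(x)  # cos(x) is the "note" the backward step needs

# Step 3: take (residuals, output_grad); return one gradient per input, as a tuple.
def toasted_bwd(cos_x, output_grad):
  return (output_grad * cos_x,)

# Step 4: link them.
toasted.defvjp(toasted_fwd, toasted_bwd)

grad_at_half = jax.grad(toasted)(0.5)
print("grad via custom rule:", grad_at_half, " cos(0.5):", jnp.cos(0.5))

# Forward mode is not defined for a custom_vjp function, so jvp raises an error.
try:
  jax.jvp(toasted, (0.5,), (1.0,))
  jvp_failed = False
except Exception as err:
  jvp_failed = True
  print("jvp failed:", type(err).__name__)
```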

In [None]:
@jax.custom_vjp
def pass_through_with_butter_limit(ingredient_vec, butter_range):
  """Passes ingredients through, but limits butter grad."""
  return ingredient_vec

# Forward: output is the input unchanged; save butter_range as the residual for bwd
def butter_limit_fwd(ingredient_vec, butter_range):
  return ingredient_vec, butter_range # output, residuals

# Backward: Apply clipping to butter component (index 0) of incoming grad
def butter_limit_bwd(residuals, grad_in):
  butter_grad = grad_in[0]
  clipped_butter_grad = jnp.clip(butter_grad, -residuals, residuals)
  # Create the output gradient tuple, replacing the butter component;
  # None indicates a zero cotangent for the butter_range argument
  return (grad_in.at[0].set(clipped_butter_grad), None)

pass_through_with_butter_limit.defvjp(butter_limit_fwd, butter_limit_bwd)

# --- Test ---
def score_with_butter_limit(ingredient_vec):
   # Apply the custom rule function before the score
   adjusted_for_grad = pass_through_with_butter_limit(ingredient_vec, 0.1)
   return deliciousness_score(adjusted_for_grad)

# Calculate gradients
grad_original_fn = jax.grad(deliciousness_score)
grad_clipped_fn = jax.grad(score_with_butter_limit)
original_grad = grad_original_fn(ingredient_vec)
clipped_grad = grad_clipped_fn(ingredient_vec)

print(f"\n--- Custom VJP (Butter Gradient Clipping) ---")
print(f"Original gradient:       {original_grad}")
print(f"Gradient w/ butter clip: {clipped_grad}")
print(f"(Note butter grad [index 0] is clipped to +/- 0.1)")


--- Custom VJP (Butter Gradient Clipping) ---
Original gradient:       [-0.1525816  -0.91666603 -0.66476905  8.810353  ]
Gradient w/ butter clip: [-0.1        -0.91666603 -0.66476905  8.810353  ]
(Note butter grad [index 0] is clipped to +/- 0.1)


### Handling a Secret Process (Dough Maturation)

Imagine our sourdough's final "maturity" level depends on a base "boost" (`b`, from sugar/yeast) and how much maturity is retained each hour (`a` < 1, from gluten strength), updated in a loop: `maturity = a * maturity + b`. Finding the final stable maturity `maturity_star` means running this loop until it stops changing. Since `jax.grad` cannot reverse-differentiate a `jax.lax.while_loop`, we use `custom_vjp` to supply `d(maturity_star)/da` and `d(maturity_star)/db` through a math shortcut instead.
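The shortcut in `maturity_bwd` comes from differentiating the fixed-point condition implicitly. Here is a sketch for the scalar case, assuming $|a| < 1$ so that the fixed point exists and the loop converges:

$$x^\star = a\,x^\star + b \quad\Longrightarrow\quad x^\star = \frac{b}{1-a}$$

$$\frac{\partial x^\star}{\partial a} = x^\star + a\,\frac{\partial x^\star}{\partial a} \;\Longrightarrow\; \frac{\partial x^\star}{\partial a} = \frac{x^\star}{1-a}, \qquad \frac{\partial x^\star}{\partial b} = 1 + a\,\frac{\partial x^\star}{\partial b} \;\Longrightarrow\; \frac{\partial x^\star}{\partial b} = \frac{1}{1-a}$$

Given the incoming gradient $\bar{x} = \partial L / \partial x^\star$ (`maturity_star_bar` in the code), set $w = \bar{x}/(1-a)$. Then $\bar{a} = w\,x^\star$ and $\bar{b} = w$. With $a = 0.75$ and $b = 10$: $x^\star = 40$, $\partial x^\star/\partial a = 40/0.25 = 160$, and $\partial x^\star/\partial b = 1/0.25 = 4$.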

In [None]:
@jax.custom_vjp
def find_final_dough_maturity(params, initial_maturity, tol=1e-6):
  """Finds equilibrium maturity x* where x* = a*x* + b via iteration."""
  a_retention, b_boost = params
  update = lambda x: a_retention * x + b_boost
  # Simple fixed-point loop
  state = (initial_maturity, update(initial_maturity))
  maturity_star = jax.lax.while_loop(lambda s: jnp.abs(s[0]-s[1]) > tol,
                             lambda s: (s[1], update(s[1])),
                             state)[1]
  return maturity_star

# Forward pass: Run iteration, save needed values.
def maturity_fwd(params, initial_maturity, tol=1e-6):
  maturity_star = find_final_dough_maturity(params, initial_maturity, tol)
  return maturity_star, (params, maturity_star) # Save params(a,b) and result x*

# Backward pass: Apply analytical shortcut for x = ax + b derivative.
def maturity_bwd(residuals, maturity_star_bar):
  params, maturity_star = residuals
  a_retention, b_boost = params
  # Analytical shortcut based on implicit function theorem:
  # w = (dL/dx) / (1 - a); dL/da = w*x; dL/db = w
  w = maturity_star_bar / (1.0 - a_retention + 1e-8) # Add epsilon for stability near a=1
  a_grad = w * maturity_star
  b_grad = w
  # Return one gradient per input: ((a_grad, b_grad), initial_maturity_grad, tol_grad)
  return ((a_grad, b_grad), None, None) # No grad for initial guess and tol

# Link the forward and backward rules
find_final_dough_maturity.defvjp(maturity_fwd, maturity_bwd)

current_params = (0.75, 10.0) # Equilibrium = 10 / (1-0.75) = 40
initial_guess = 0.0

print(f"\n--- Custom VJP (Minimal Loop Example) ---")
# Calculate gradient of final maturity w.r.t. params (a, b)
# The loop still runs forward; the backward pass uses our custom rule instead of the loop.
param_sensitivity = jax.grad(find_final_dough_maturity)(current_params, initial_guess)

# Expected analytical: dM*/da=b/(1-a)^2=10/(0.25^2)=160; dM*/db=1/(1-a)=1/0.25=4
print(f"Gradient d(Maturity)/d(params={current_params}): ({param_sensitivity[0]:.2f}, {param_sensitivity[1]:.2f})")



--- Custom VJP (Minimal Loop Example) ---
Gradient d(Maturity)/d(params=(0.75, 10.0)): (160.00, 4.00)
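Because this particular update is linear, its fixed point has a closed form, $x^\star = b/(1-a)$. We can sanity-check the custom rule by differentiating that formula directly with ordinary `jax.grad`. This check works only for this toy linear update. The hypothetical `closed_form_maturity` below never runs the loop, so there is nothing for `grad` to struggle with.

```python
import jax

def closed_form_maturity(params):
  """Fixed point of x = a*x + b in closed form, valid for |a| < 1."""
  a_retention, b_boost = params
  return b_boost / (1.0 - a_retention)

da, db = jax.grad(closed_form_maturity)((0.75, 10.0))
print(f"closed-form gradient: ({da:.2f}, {db:.2f})")
```

For real iterative processes without a closed form, the implicit-function shortcut in `maturity_bwd` is what makes the gradient available at all.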


# Conclusion: Bake Like a Pro with JAX! 🍰

Right then, Baker! You've officially graduated from JAX Autodiff Culinary Academy 🎓. You didn't just bake a cake, you *differentiated* the heck out of it! 🤯

You started with simple `jax.grad` (the "more sugar = more happy?" calculation 🤔), but quickly graduated to the wild stuff: Jacobians mapping sugar-to-sadness-and-fluffiness, Hessians checking if your deliciousness peak is pointy or flat, and HVPs for a quick curvature peek without melting your machine 🔥.

You learned the kitchen commands: `jax.lax.stop_gradient` (telling JAX 'just aim for *this* fluffiness target, don't ask where it came from!'), `vmap` (getting picky gradients for *all* your demanding customers at once), and even wrangled complex numbers - because apparently, imaginary flavor dimensions are a thing now, and JAX is cool with it.

Best of all? When JAX just didn't get your ancient sourdough loop-de-loop or that ingredient prone to numerical kitchen fires, you showed it who's boss 😎 with `jax.custom_jvp` and `jax.custom_vjp` - your Secret Family Recipe scrolls 📜.

You're now dangerously equipped to optimize practically anything. Go forth and `grad`! (Maybe bake a real cake now 🎂, you've earned it).