# Module 1 (Part 2): Automatic Differentiation

**Exercise:** [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kks32-courses/sciml/blob/main/lectures/07-ad/07-ad-exercise.ipynb)
**Solution:** [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kks32-courses/sciml/blob/main/lectures/07-ad/07-ad.ipynb)

## Introduction: Why Do We Need Better Ways to Compute Derivatives?

In scientific computing, we constantly need derivatives. Whether we're solving optimization problems, training neural networks, or inverting geophysical data, derivatives are the mathematical engine that drives our algorithms forward.

Traditionally, there have been three approaches to computing derivatives, each with fundamental limitations. Automatic differentiation is a fourth approach that avoids most of them:

![Modes of differentiation](figs/differentiation.png)

## The Four Modes of Differentiation

### 1. Manual Differentiation
The traditional calculus approach: derive formulas by hand using calculus rules.

**Pros**: 
- Exact derivatives
- Allows simplifying the expressions by hand during derivation

**Cons**: 
- Error-prone and time-consuming
- Not scalable for complex functions
- Impractical for functions with thousands of parameters

### 2. Symbolic Differentiation
Computer algebra systems such as Mathematica or SymPy apply calculus rules symbolically.

**Pros**:
- Provides exact, closed-form expressions

**Cons**:
- **Expression swell**: derivative expressions can grow exponentially in size
- Requires a closed-form expression, so it does not directly apply to programs with loops, branches, or other control flow
- Computationally expensive for complex expressions
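A short experiment shows expression swell in practice. This sketch assumes SymPy is installed. It builds a few iterations of the logistic map $l_{k+1} = 4 l_k (1 - l_k)$, a common illustration, differentiates the result symbolically, and compares operation counts using `sympy.count_ops`. The number of iterations is an arbitrary choice.

```python
import sympy as sp

x = sp.Symbol("x")

# Build l_n by iterating the logistic map l_{k+1} = 4 l_k (1 - l_k), starting from x
l = x
for _ in range(3):
    l = 4 * l * (1 - l)

# Symbolic derivative, with no simplification applied
dl = sp.diff(l, x)

print("operations in l_n:    ", sp.count_ops(l))
print("operations in dl_n/dx:", sp.count_ops(dl))
```

Every product in $l_n$ becomes a sum of products in the derivative, and these copies compound with each iteration.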

### 3. Numerical Differentiation
Approximate derivatives using finite differences, for example the central difference:
$$ f'(x) \approx \frac{f(x + h) - f(x - h)}{2h} $$

**Pros**:
- Simple to implement
- Works for any function you can evaluate

**Cons**:
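- **Accuracy issues**: results are only approximate and prone to floating-point round-off
- **Step size dilemma**: too large → truncation error, too small → round-off error
- **Computational cost**: requires extra function evaluations for every input (about $2n$ evaluations for a gradient with $n$ inputs using central differences)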
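The step size dilemma is easy to reproduce. The sketch below uses the central difference above to differentiate $\sin x$ at $x = 1$, where the exact answer is $\cos 1$, with a large, a moderate, and a tiny step. The helper name `central_diff` and the three step sizes are arbitrary choices for illustration.

```python
import math

def central_diff(f, x, h):
    """Central finite-difference approximation of f'(x) with step h."""
    return (f(x + h) - f(x - h)) / (2 * h)

x0 = 1.0
exact = math.cos(x0)  # d/dx sin(x) = cos(x)

errors = {}
for h in [1e-1, 1e-5, 1e-13]:
    errors[h] = abs(central_diff(math.sin, x0, h) - exact)
    print(f"h = {h:.0e}   error = {errors[h]:.2e}")
```

The moderate step typically gives the smallest error. The large step suffers from truncation error, which is $O(h^2)$ for central differences. The tiny step suffers from cancellation when computing $f(x+h) - f(x-h)$ in floating point.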

### 4. Automatic Differentiation (AD)
Compute exact derivatives alongside the function evaluation by applying the chain rule to each elementary operation the program executes.

**Pros**:
- **Exact derivatives** (up to machine precision)
- **Efficient**: computational cost proportional to function evaluation
- **General**: works for any differentiable function expressible as code

**Cons**:
- Requires specialized libraries
- Can be memory-intensive: reverse mode must store intermediate values from the forward pass

## The Core Insight: Functions Are Computational Graphs

Every computer program that evaluates a mathematical function can be viewed as a **computational graph**. Consider this simple function:

```python
def f(x1, x2):
    y = x1**2 + x2
    return y
```

This creates a computational graph where each operation is a node:

![AD graph](figs/ad1.png)

We can decompose this into elementary operations and assign intermediate variables:

![AD graph variables](figs/ad2.png)

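This decomposition is the key insight that makes automatic differentiation possible. Each elementary operation has a simple, known derivative, and the chain rule combines these local derivatives into derivatives of the whole function.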
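The same decomposition can be written as straight-line code, often called an evaluation trace or Wengert list, in which each line performs exactly one elementary operation. Here is a minimal sketch for the example function. The name `f_trace` is ours and is used only for illustration.

```python
def f_trace(x1, x2):
    """Evaluate y = x1**2 + x2 one elementary operation at a time."""
    v1 = x1 ** 2   # elementary op 1: square
    y = v1 + x2    # elementary op 2: add
    return v1, y   # also return the intermediate so it can be inspected

v1, y = f_trace(2.0, 3.0)
print(v1, y)  # 4.0 7.0
```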

## Forward Mode Automatic Differentiation

Forward mode AD computes derivatives by propagating derivative information **forward** through the computational graph, following the same path as the function evaluation.

![AD forward evaluation](figs/ad3.png)

### Forward Mode: Computing $\frac{\partial y}{\partial x_1}$

Starting with our function $y = x_1^2 + x_2$, let's trace through the computation:

1. **Seed the input**: Set $\dot{x}_1 = 1$ and $\dot{x}_2 = 0$ (we're differentiating w.r.t. $x_1$)

2. **Forward propagation**:
   - $v_1 = x_1^2$, so $\dot{v}_1 = 2x_1 \cdot \dot{x}_1 = 2x_1 \cdot 1 = 2x_1$
   - $y = v_1 + x_2$, so $\dot{y} = \dot{v}_1 + \dot{x}_2 = 2x_1 + 0 = 2x_1$

3. **Result**: $\frac{\partial y}{\partial x_1} = 2x_1$

### Forward Mode: Computing $\frac{\partial y}{\partial x_2}$

To get the derivative w.r.t. $x_2$, we seed differently:

1. **Seed the input**: Set $\dot{x}_1 = 0$ and $\dot{x}_2 = 1$

2. **Forward propagation**:
   - $v_1 = x_1^2$, so $\dot{v}_1 = 2x_1 \cdot \dot{x}_1 = 2x_1 \cdot 0 = 0$
   - $y = v_1 + x_2$, so $\dot{y} = \dot{v}_1 + \dot{x}_2 = 0 + 1 = 1$

3. **Result**: $\frac{\partial y}{\partial x_2} = 1$

**Key insight**: Forward mode requires one pass per input variable. Each pass yields the derivatives of all outputs with respect to that single input.
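Forward mode is commonly implemented with **dual numbers**. Each value carries a pair (value, tangent), and every elementary operation updates both parts using its local derivative. The sketch below supports only addition, multiplication, and constant integer powers, which is enough for the example. The class name `Dual` is our own and is not a library API.

```python
class Dual:
    """A value paired with its tangent (derivative w.r.t. the seeded input)."""

    def __init__(self, val, dot=0.0):
        self.val = val  # primal value
        self.dot = dot  # tangent value

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Sum rule: (a + b)' = a' + b'
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a * b)' = a' b + a b'
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__

    def __pow__(self, n):
        # Power rule for a constant exponent n: (a^n)' = n a^(n-1) a'
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.dot)


def f(x1, x2):
    return x1**2 + x2


# One forward pass per input: seed that input's tangent with 1 and the others with 0
dy_dx1 = f(Dual(2.0, 1.0), Dual(3.0, 0.0)).dot
dy_dx2 = f(Dual(2.0, 0.0), Dual(3.0, 1.0)).dot
print(dy_dx1, dy_dx2)  # 4.0 1.0
```

The two calls to `f` correspond to the two seedings traced above, $(\dot{x}_1, \dot{x}_2) = (1, 0)$ and $(0, 1)$.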

## Reverse Mode Automatic Differentiation

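Reverse mode AD computes derivatives by propagating derivative information **backward** through the computational graph, from the output to the inputs. **Backpropagation**, the algorithm used to train neural networks, is reverse mode AD applied to a network's loss.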

![Reverse mode AD](figs/ad4.png)

### The Backward Pass Algorithm

1. **Forward pass**: Compute function values and store intermediate results
2. **Seed the output**: Set $\bar{y} = 1$ (derivative of output w.r.t. itself)
3. **Backward pass**: Use the chain rule to propagate derivatives backward
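These three steps can be written out by hand for $y = x_1^2 + x_2$. In this sketch (the function name is hypothetical), the forward pass keeps the values the backward pass will need. The backward pass then propagates adjoints $\bar{v} = \partial y / \partial v$ from the output toward the inputs, visiting operations in reverse order and multiplying by each operation's local derivative.

```python
def f_and_grad(x1, x2):
    # 1. Forward pass: evaluate, keeping what the backward pass needs
    v1 = x1 ** 2
    y = v1 + x2

    # 2. Seed the output adjoint: dy/dy = 1
    y_bar = 1.0

    # 3. Backward pass, operations in reverse order
    # y = v1 + x2  ->  dy/dv1 = 1, dy/dx2 = 1
    v1_bar = y_bar * 1.0
    x2_bar = y_bar * 1.0
    # v1 = x1**2   ->  dv1/dx1 = 2*x1 (uses the stored x1)
    x1_bar = v1_bar * 2.0 * x1

    return y, (x1_bar, x2_bar)

y, (g1, g2) = f_and_grad(2.0, 3.0)
print(y, g1, g2)  # 7.0 4.0 1.0
```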

![Chain rule AD](figs/ad5.png)

Let's trace through our example:

![Chain rule AD steps](figs/ad6.png)

![Final chain rule AD](figs/ad7.png)

### Computing All Partial Derivatives in One Pass

The beauty of reverse mode is that it computes **all** partial derivatives in a single backward pass:

1. **Forward pass**: $y = x_1^2 + x_2$ (store intermediate values)

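2. **Backward pass with $\bar{y} = 1$**, writing $\bar{v} = \frac{\partial y}{\partial v}$ for the adjoint of each variable $v$:
   - $\bar{v}_1 = \bar{y} \cdot \frac{\partial y}{\partial v_1} = 1 \cdot 1 = 1$
   - $\frac{\partial y}{\partial x_1} = \bar{x}_1 = \bar{v}_1 \cdot \frac{\partial v_1}{\partial x_1} = 1 \cdot 2x_1 = 2x_1$
   - $\frac{\partial y}{\partial x_2} = \bar{x}_2 = \bar{y} \cdot 1 = 1$, since $y = v_1 + x_2$ depends on $x_2$ directly with local derivative 1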

**Key insight**: Reverse mode computes gradients w.r.t. all inputs in a single backward pass!
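To see why one backward pass is enough for every input, here is a minimal sketch of a reverse-mode engine. The class and method names are illustrative. They loosely echo PyTorch's `.backward()` but are not its implementation. Each node records its parents together with the local derivatives with respect to them. `backward` visits nodes in reverse topological order and accumulates adjoints, so an input used in several places receives the sum of its contributions.

```python
class Var:
    """A node in the computational graph that records how it was produced."""

    def __init__(self, val, parents=()):
        self.val = val
        self.parents = parents  # tuples of (parent Var, local derivative)
        self.grad = 0.0         # adjoint, filled in by backward()

    def __add__(self, other):
        return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.val * other.val, ((self, other.val), (other, self.val)))

    def backward(self):
        # Topological order: every node appears after all the nodes it depends on
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # seed: dy/dy = 1
        # Reverse order guarantees a node's adjoint is complete before it is propagated
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += node.grad * local  # chain rule, summed over paths


# Three inputs, one output, one backward pass
x1, x2, x3 = Var(2.0), Var(3.0), Var(4.0)
y = x1 * x1 + x2 * x3 + x1 * x3
y.backward()
print(x1.grad, x2.grad, x3.grad)  # 8.0 4.0 5.0
```

Analytically, $\partial y/\partial x_1 = 2x_1 + x_3 = 8$, $\partial y/\partial x_2 = x_3 = 4$, and $\partial y/\partial x_3 = x_2 + x_1 = 5$, all obtained from the single call to `backward()`.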

## When to Use Forward vs Reverse Mode

The choice depends on the structure of your problem:

- **Forward Mode**: Efficient when there are **few inputs, many outputs**, i.e., $f: \mathbb{R}^n \to \mathbb{R}^m$ with $n \ll m$ (a full Jacobian takes $n$ forward passes)
- **Reverse Mode**: Efficient when there are **many inputs, few outputs**, i.e., $f: \mathbb{R}^n \to \mathbb{R}^m$ with $n \gg m$ (a full Jacobian takes $m$ backward passes)

In machine learning, we typically have millions of parameters (inputs) and a single loss function (output), making reverse mode the natural choice.
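JAX provides both modes directly. `jax.jacfwd` builds a Jacobian from forward-mode products, conceptually one per input and batched with `vmap`. `jax.jacrev` builds it from reverse-mode products, one per output. In the sketch below, the function $g: \mathbb{R}^3 \to \mathbb{R}^2$ is arbitrary and chosen only for illustration. Both transformations return the same $2 \times 3$ Jacobian, so the choice between them is purely about cost.

```python
import jax
import jax.numpy as jnp

def g(x):
    # An arbitrary map from R^3 to R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(g)(x)  # forward mode: one JVP per input (3 here)
J_rev = jax.jacrev(g)(x)  # reverse mode: one VJP per output (2 here)

print(J_fwd.shape)                 # (2, 3)
print(jnp.allclose(J_fwd, J_rev))  # True
```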

## Automatic Differentiation in Practice: PyTorch

Let's see how automatic differentiation works in PyTorch:

```python
import torch

# Define variables that require gradients
x1 = torch.tensor(2.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)

# Define the function
y = x1**2 + x2

# Compute gradients using reverse mode AD
y.backward()

# Access the computed gradients
print(f"dy/dx1: {x1.grad.item()}")  # Should be 2*x1 = 4.0
print(f"dy/dx2: {x2.grad.item()}")  # Should be 1.0
```

### A More Complex Example: Neural Network

```python
import torch
import torch.nn as nn

# Simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(2, 3)
        self.layer2 = nn.Linear(3, 1)
    
    def forward(self, x):
        x = torch.tanh(self.layer1(x))
        x = self.layer2(x)
        return x

# Create network and data
net = SimpleNet()
x = torch.tensor([[1.0, 2.0]], requires_grad=True)
target = torch.tensor([[0.5]])

# Forward pass
output = net(x)
loss = ((output - target)**2).mean()

# Backward pass - computes gradients for ALL parameters
loss.backward()

# Access gradients
for name, param in net.named_parameters():
    print(f"{name}: gradient shape {param.grad.shape}")
```

The power of automatic differentiation becomes clear: PyTorch automatically computes gradients for all parameters in the network, regardless of how complex the architecture becomes.

## JAX: Functional Automatic Differentiation

JAX takes a different, more functional approach to automatic differentiation. Instead of tracking gradients on data structures (like PyTorch Tensors), JAX provides **function transformations** that take a function and return a new, transformed function.

### The Core Idea: Function Transformations

JAX's power comes from a few key transformations:

- **`jit` (Just-In-Time Compilation)**: Speeds up your code by compiling it to highly optimized machine code using XLA (Accelerated Linear Algebra). This is especially effective for functions built from many small array operations, which XLA can fuse together.
- **`grad` (Gradient)**: Takes a function and returns a new function that computes its gradient. This is JAX's implementation of reverse-mode AD.
- **`vmap` (Vectorization)**: Automatically vectorizes a function, allowing it to process batches of data without you needing to write explicit loops. It's like adding a batch dimension to your function for free.
- **`pmap` (Parallelization)**: Runs computations in parallel across multiple devices (like GPUs or TPUs).
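Each transformation returns an ordinary function, so transformations compose. As a sketch (the linear-model loss and the data are made up for illustration), `vmap(grad(...))` produces per-example gradients and `jit` compiles the result:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example
    return (jnp.dot(w, x) - y) ** 2

# grad w.r.t. w for one example, vectorized over a batch of (x, y), then compiled
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

grads = per_example_grads(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient per example
```

`in_axes=(None, 0, 0)` shares the weights `w` across the batch and maps over the leading axis of `xs` and `ys`.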

### Computing Gradients and Hessians

Let's see how to use these transformations to get first and second-order derivatives.

```python
import jax
import jax.numpy as jnp
```

#### Gradients of Multi-Argument Functions

The `grad` function is incredibly flexible. For a function with multiple inputs, you can specify which argument to differentiate with respect to using `argnums`.

```python
# Define a function with two arguments
def my_func(x, y):
    return x**3 * y + 2*x*y**2

# Create a function that computes the gradient w.r.t. x (the 0-th argument)
grad_x = jax.grad(my_func, argnums=0)

# Create a function that computes the gradient w.r.t. y (the 1st argument)
grad_y = jax.grad(my_func, argnums=1)

# Evaluate at a point (x=2, y=3)
x_val, y_val = 2.0, 3.0
df_dx = grad_x(x_val, y_val)
df_dy = grad_y(x_val, y_val)

print(f"Original function f(2, 3) = {my_func(x_val, y_val)}")
print(f"∂f/∂x at (2, 3) = {df_dx}") # Analytical: 3*x^2*y + 2*y^2 = 3*4*3 + 2*9 = 36 + 18 = 54
print(f"∂f/∂y at (2, 3) = {df_dy}") # Analytical: x^3 + 4*x*y = 8 + 4*2*3 = 8 + 24 = 32
```
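`argnums` also accepts a tuple, in which case `grad` returns a tuple of partial derivatives from a single reverse pass. `jax.value_and_grad` additionally returns the function value. A short sketch, redefining the same function so the cell stands alone:

```python
import jax

def my_func(x, y):
    return x**3 * y + 2*x*y**2

# Both partial derivatives from one reverse-mode pass
df_dx, df_dy = jax.grad(my_func, argnums=(0, 1))(2.0, 3.0)

# Function value and both partials together
value, (gx, gy) = jax.value_and_grad(my_func, argnums=(0, 1))(2.0, 3.0)

print(df_dx, df_dy)  # 54.0 32.0
print(value)         # 60.0
```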

#### Higher-Order Derivatives: Hessians

Because JAX transformations are composable, computing higher-order derivatives is straightforward. The Hessian is the Jacobian of the gradient. We can compute it by composing `jax.jacfwd` (or `jax.jacrev`) with `jax.grad`.

```python
# Let's use a simpler function for clarity
def scalar_func(x):
    return x[0]**2 * x[1] + x[1]**3

# The Hessian is the Jacobian of the gradient
hessian_func = jax.jacfwd(jax.grad(scalar_func))

# Evaluate at a point
point = jnp.array([2.0, 3.0])
hessian_matrix = hessian_func(point)

print("Hessian matrix at (2, 3):")
print(hessian_matrix)

# --- For verification (writing x = x[0], y = x[1], so f = x^2 y + y^3) ---
# Gradient: ∇f = [2xy, x^2 + 3y^2]
# At (2,3): ∇f = [12, 4 + 27] = [12, 31]
#
# Hessian: H = [[∂²f/∂x², ∂²f/∂y∂x],
#               [∂²f/∂x∂y, ∂²f/∂y²]]
#
# H = [[2y, 2x],
#      [2x, 6y]]
#
# At (2,3): H = [[6, 4],
#               [4, 18]]
```
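JAX also provides `jax.hessian`, which is documented as forward-over-reverse composition (`jacfwd(jacrev(f))`), usually the efficient choice for Hessians. As a quick sketch, it should agree with the composed version in the cell above:

```python
import jax
import jax.numpy as jnp

def scalar_func(x):
    return x[0]**2 * x[1] + x[1]**3

point = jnp.array([2.0, 3.0])

H_builtin = jax.hessian(scalar_func)(point)           # forward-over-reverse
H_composed = jax.jacfwd(jax.grad(scalar_func))(point)  # jacfwd of grad, as above

print(H_builtin)
print(jnp.allclose(H_builtin, H_composed))  # True
```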

## Computational Considerations

### Memory vs Computation Trade-offs

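**Forward Mode**:
- Memory: a constant-factor overhead (each value carries its tangent alongside it); no computational graph needs to be stored
- Computation: one pass per input, so a full gradient with $n$ input variables costs $O(n)$ function evaluations

**Reverse Mode**:
- Memory: O(computation graph size), since intermediate values from the forward pass are stored for the backward pass
- Computation: one backward pass per output; for a scalar output, the full gradient costs $O(1)$ function evaluations, independent of the number of inputs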

### Modern Optimizations

1. **Checkpointing**: Trade computation for memory by storing only some intermediate values during the forward pass and recomputing the rest during the backward pass
2. **JIT compilation**: Compile computational graphs for faster execution
3. **Parallelization**: Distribute gradient computation across multiple devices
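As a sketch of checkpointing in JAX, `jax.checkpoint` (also available as `jax.remat`) wraps a function so that its intermediate values are recomputed during the backward pass instead of being stored. The gradient itself does not change; only the memory and compute trade-off differs. The function `block` below is arbitrary and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def block(x):
    # A chain of elementwise ops whose intermediates would normally be saved for the backward pass
    return jnp.sin(jnp.cos(jnp.sin(x)))

def loss_plain(x):
    return jnp.sum(block(block(x)))

def loss_remat(x):
    ckpt_block = jax.checkpoint(block)  # recompute block's intermediates during the backward pass
    return jnp.sum(ckpt_block(ckpt_block(x)))

x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(loss_plain)(x)
g_remat = jax.grad(loss_remat)(x)
print(jnp.allclose(g_plain, g_remat))  # True
```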

## The Mathematical Foundation

Automatic differentiation works because of a fundamental theorem:

**Chain Rule**: For composite functions $f(g(x))$:
$$\frac{d}{dx}f(g(x)) = f'(g(x)) \cdot g'(x)$$

By systematically applying the chain rule to each operation in a computational graph, AD can compute exact derivatives for arbitrarily complex functions.
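In a computational graph, one value often feeds several later operations. The single-variable rule above then generalizes to a sum over paths. For a variable $u$ whose value flows directly into intermediates $v_1, \dots, v_k$:

$$\frac{\partial y}{\partial u} = \sum_{i=1}^{k} \frac{\partial y}{\partial v_i} \frac{\partial v_i}{\partial u}$$

For example, $y = x^2 + x$ decomposes as $v_1 = x^2$, $y = v_1 + x$. Here $x$ reaches $y$ along two paths: one through $v_1$, and one through the direct dependence of $y = v_1 + x$ on $x$, which has local derivative 1:

$$\frac{dy}{dx} = \frac{\partial y}{\partial v_1}\frac{dv_1}{dx} + 1 = 1 \cdot 2x + 1 = 2x + 1$$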

## Summary: The AD Revolution

Automatic differentiation has revolutionized scientific computing by making gradients:

1. **Ubiquitous**: Available for any differentiable computation
2. **Exact**: No approximation errors (up to machine precision)
3. **Efficient**: Computational cost proportional to function evaluation
4. **Automatic**: No manual derivation required

This has enabled new approaches to:
- Machine learning (training deep networks with millions of parameters relies on reverse-mode AD)
- Scientific computing (differentiable simulators)
- Optimization (gradient-based methods for complex problems)
- Parameter estimation (automatic gradient computation for inverse problems)

The combination of automatic differentiation with modern hardware (GPUs, TPUs) has created unprecedented opportunities for scientific discovery through computation.

## Looking Forward

As we move into the era of scientific machine learning, automatic differentiation serves as the mathematical engine that powers:
- Physics-informed neural networks
- Neural differential equations  
- Differentiable simulators
- End-to-end optimization of scientific workflows

The ability to compute gradients automatically through arbitrary computational processes is not just a technical convenience—it's a fundamental capability that's reshaping how we approach scientific problems in the 21st century.