# Tutorial 10: Automatic Differentiation

In this tutorial, we'll explore automatic differentiation (autodiff) in BrainState, which is essential for training neural networks and optimizing models.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Compute gradients using `grad`
- Calculate Jacobian and Hessian matrices
- Differentiate vector-valued functions with `vector_grad`
- Work with the `GradientTransform` class
- Apply gradients in training loops
- Understand gradient flow through stateful modules

## What is Automatic Differentiation?

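Automatic differentiation (autodiff) computes derivatives of functions written as ordinary programs. It differs from two other approaches:
- **Numerical differentiation** approximates derivatives with finite differences. This introduces truncation and round-off error and needs at least one extra function evaluation per input dimension.
- **Symbolic differentiation** manipulates mathematical expressions. This can produce very large derivative expressions ("expression swell") and does not handle program control flow naturally.

Automatic differentiation applies the chain rule to each elementary operation the program executes. The result is exact up to floating-point rounding, and in reverse mode the full gradient of a scalar function costs only a small constant multiple of one function evaluation.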
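
The chain-rule idea can be made concrete with a toy forward-mode implementation based on dual numbers. This is an illustrative sketch in plain Python, not how BrainState or JAX are implemented (JAX traces programs and also supports reverse mode); the `Dual` class and the `dual_sin` and `derivative` helpers are hypothetical names. It also prints a finite-difference estimate for comparison.

```python
import math


class Dual:
    """Toy dual number val + dot*eps with eps**2 == 0; `dot` carries the derivative."""

    def __init__(self, val, dot=0.0):
        self.val = val
        self.dot = dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__


def dual_sin(u):
    # Chain rule: sin(u)' = cos(u) * u'
    return Dual(math.sin(u.val), math.cos(u.val) * u.dot)


def derivative(fn, x):
    """Forward-mode derivative of a one-argument function at the point x."""
    return fn(Dual(x, 1.0)).dot


def f(x):
    return dual_sin(x * x)  # f(x) = sin(x^2)


x0 = 1.5
exact = 2 * x0 * math.cos(x0 ** 2)  # analytic derivative: 2x * cos(x^2)
h = 1e-5
fd = (math.sin((x0 + h) ** 2) - math.sin((x0 - h) ** 2)) / (2 * h)

print("dual numbers:     ", derivative(f, x0))
print("analytic:         ", exact)
print("finite difference:", fd)
```

The dual-number result matches the analytic derivative to the last bit here, while the finite difference is only an approximation whose accuracy depends on the step `h`.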

In [None]:
import brainstate as bst
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
bst.random.seed(42)

## 1. Basic Gradient Computation with `grad`

The `grad` function computes the gradient of a scalar-valued function. By default, it differentiates with respect to the function's first argument.

In [None]:
# Simple function: f(x) = x^2
def square(x):
    return x ** 2

# Compute gradient: df/dx = 2x
grad_square = bst.transform.grad(square)

x = 3.0
print(f"f({x}) = {square(x)}")
print(f"df/dx at x={x}: {grad_square(x)}")
print(f"Expected (2*x): {2*x}")

### Gradients of More Complex Functions

In [None]:
# f(x) = sin(x^2)
def complex_fn(x):
    return jnp.sin(x ** 2)

# df/dx = 2x * cos(x^2)
grad_complex = bst.transform.grad(complex_fn)

# Visualize function and gradient
x_vals = jnp.linspace(-3, 3, 100)
y_vals = jnp.array([complex_fn(x) for x in x_vals])
grad_vals = jnp.array([grad_complex(x) for x in x_vals])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(x_vals, y_vals, 'b-', linewidth=2)
ax1.set_xlabel('x')
ax1.set_ylabel('f(x) = sin(x²)')
ax1.set_title('Function')
ax1.grid(True, alpha=0.3)

ax2.plot(x_vals, grad_vals, 'r-', linewidth=2)
ax2.set_xlabel('x')
ax2.set_ylabel("f'(x) = 2x·cos(x²)")
ax2.set_title('Gradient')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
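
Two practical refinements of the cell above, shown as a sketch in plain JAX (which BrainState builds on) rather than through `bst.transform.grad`: `jax.vmap` evaluates the gradient over the whole grid at once instead of in a Python loop, and a central finite difference gives an independent sanity check of the autodiff result. The finite difference agrees only approximately (float32 rounding plus truncation error), which is the weakness of numerical differentiation mentioned earlier.

```python
import jax
import jax.numpy as jnp


def complex_fn(x):
    return jnp.sin(x ** 2)


# jax.grad is the JAX transformation that BrainState's gradient tools build on
jax_grad_fn = jax.grad(complex_fn)

# jax.vmap maps the scalar-input gradient over a whole array, replacing the Python loop
x_grid = jnp.linspace(-3.0, 3.0, 100)
autodiff_vals = jax.vmap(jax_grad_fn)(x_grid)

# Central finite difference as an independent check: (f(x + h) - f(x - h)) / (2h)
step = 1e-3
fd_vals = (complex_fn(x_grid + step) - complex_fn(x_grid - step)) / (2 * step)

max_err = jnp.max(jnp.abs(autodiff_vals - fd_vals))
print(f"max |autodiff - finite difference| = {max_err:.2e}")
```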

## 2. Gradients with Multiple Arguments

By default, `grad` computes the gradient with respect to the first argument. Use `argnums` to specify which argument(s) to differentiate.

In [None]:
# f(x, y) = x^2 * y + y^3
def multi_arg_fn(x, y):
    return x**2 * y + y**3

# Gradient with respect to first argument (x)
grad_x = bst.transform.grad(multi_arg_fn, argnums=0)

# Gradient with respect to second argument (y)
grad_y = bst.transform.grad(multi_arg_fn, argnums=1)

# Gradient with respect to both arguments
grad_both = bst.transform.grad(multi_arg_fn, argnums=(0, 1))

x, y = 2.0, 3.0
print(f"f({x}, {y}) = {multi_arg_fn(x, y)}")
print(f"∂f/∂x = {grad_x(x, y)} (expected: 2xy = {2*x*y})")
print(f"∂f/∂y = {grad_y(x, y)} (expected: x² + 3y² = {x**2 + 3*y**2})")
print(f"Both gradients: {grad_both(x, y)}")
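
Instead of listing several arguments in `argnums`, you can bundle the inputs into a single pytree such as a dict; the gradient then comes back with the same structure. This sketch uses plain `jax.grad`, and `multi_arg_dict` is a hypothetical variant of `multi_arg_fn`. At (2, 3) the entries should match the values above: ∂f/∂x = 12 and ∂f/∂y = 31. The same idea explains why parameter gradients later in this tutorial are dictionaries keyed like the parameters.

```python
import jax


def multi_arg_dict(params):
    # Same function as multi_arg_fn, with both inputs bundled in one dict
    x, y = params["x"], params["y"]
    return x ** 2 * y + y ** 3


dict_grads = jax.grad(multi_arg_dict)({"x": 2.0, "y": 3.0})
print(dict_grads)  # a dict with the same keys: one partial derivative per entry
```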

## 3. Gradients of Vector-Valued Functions

For vector inputs, `grad` computes the gradient of a scalar output with respect to the vector input.

In [None]:
# L2 norm squared: f(x) = ||x||^2 = sum(x_i^2)
def norm_squared(x):
    return jnp.sum(x ** 2)

# Gradient: df/dx_i = 2*x_i
grad_norm = bst.transform.grad(norm_squared)

x = jnp.array([1.0, 2.0, 3.0])
print(f"x = {x}")
print(f"f(x) = {norm_squared(x)}")
print(f"∇f = {grad_norm(x)}")
print(f"Expected (2x): {2*x}")

### Example: Gradient Descent

Let's use gradients to minimize a quadratic function.

In [None]:
# Quadratic bowl: f(x, y) = (x-1)^2 + (y-2)^2
# Minimum at (1, 2)
def quadratic_bowl(params):
    x, y = params
    return (x - 1)**2 + (y - 2)**2

grad_fn = bst.transform.grad(quadratic_bowl)

# Gradient descent
params = jnp.array([5.0, -3.0])  # Start far from minimum
learning_rate = 0.1
history = [params]

for i in range(50):
    grads = grad_fn(params)
    params = params - learning_rate * grads
    history.append(params)

history = jnp.array(history)

# Visualize optimization path
fig, ax = plt.subplots(figsize=(8, 6))

# Create contour plot
x_range = jnp.linspace(-1, 6, 100)
y_range = jnp.linspace(-4, 4, 100)
X, Y = jnp.meshgrid(x_range, y_range)
Z = (X - 1)**2 + (Y - 2)**2

contour = ax.contour(X, Y, Z, levels=20, cmap='viridis', alpha=0.6)
ax.clabel(contour, inline=True, fontsize=8)

# Plot optimization path
ax.plot(history[:, 0], history[:, 1], 'ro-', linewidth=2, markersize=4, label='Optimization path')
ax.plot(history[0, 0], history[0, 1], 'go', markersize=10, label='Start')
ax.plot(history[-1, 0], history[-1, 1], 'r*', markersize=15, label='End')
ax.plot(1, 2, 'b*', markersize=15, label='True minimum')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Gradient Descent Optimization')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final params: {history[-1]}")
print(f"Final loss: {quadratic_bowl(history[-1])}")
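
Why the path ends at (1, 2): for this bowl the gradient is linear in the parameters, so each update shrinks the distance to the minimum by a fixed factor.

$$\nabla f(\mathbf{p}) = 2(\mathbf{p} - \mathbf{p}^*), \qquad \mathbf{p}^* = (1, 2)$$

$$\mathbf{p}_{k+1} - \mathbf{p}^* = \mathbf{p}_k - 2\eta\,(\mathbf{p}_k - \mathbf{p}^*) - \mathbf{p}^* = (1 - 2\eta)\,(\mathbf{p}_k - \mathbf{p}^*)$$

With $\eta = 0.1$ the error shrinks by a factor of 0.8 per step. Starting from $\mathbf{p}_0 - \mathbf{p}^* = (4, -5)$, after 50 steps

$$\mathbf{p}_{50} - \mathbf{p}^* = 0.8^{50}\,(4, -5) \approx (5.7 \times 10^{-5},\ -7.1 \times 10^{-5}).$$

The same formula shows that the update converges only for $0 < \eta < 1$, where $|1 - 2\eta| < 1$.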

## 4. Jacobian Matrices

The Jacobian is the matrix of all first-order partial derivatives of a vector-valued function.

In [None]:
# Vector-valued function: R^2 -> R^3
# f([x, y]) = [x^2 + y, x*y, y^2]
def vector_fn(xy):
    x, y = xy
    return jnp.array([x**2 + y, x*y, y**2])

# Compute Jacobian
jacobian_fn = bst.transform.jacobian(vector_fn)

xy = jnp.array([2.0, 3.0])
jac = jacobian_fn(xy)

print("Input:", xy)
print("Output:", vector_fn(xy))
print("\nJacobian matrix:")
print(jac)
print("\nExpected Jacobian:")
print("[[2x,  1],   [[4, 1],")
print(" [ y,  x],  = [3, 2],")
print(" [ 0, 2y]]    [0, 6]]")
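
A Jacobian can be built in forward mode (one Jacobian-vector product per input, i.e. one column at a time) or in reverse mode (one vector-Jacobian product per output, i.e. one row at a time). As a rule of thumb, forward mode is cheaper for tall Jacobians (few inputs, many outputs, like this R² → R³ function) and reverse mode for wide ones such as scalar losses. The sketch below uses plain JAX's `jax.jacfwd` and `jax.jacrev` to show that both give the same matrix; which mode `bst.transform.jacobian` uses internally is not assumed here.

```python
import jax
import jax.numpy as jnp


def vector_fn(xy):
    x, y = xy
    return jnp.array([x ** 2 + y, x * y, y ** 2])


xy = jnp.array([2.0, 3.0])

# Forward mode: one JVP per input (2 here), each giving a column
jac_fwd = jax.jacfwd(vector_fn)(xy)
# Reverse mode: one VJP per output (3 here), each giving a row
jac_rev = jax.jacrev(vector_fn)(xy)

print(jac_fwd)
print(jnp.allclose(jac_fwd, jac_rev))
```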

### Jacobian for Neural Network Layer

In [None]:
# Simple linear transformation: y = Wx + b
def linear_layer(x, W, b):
    return W @ x + b

# Create sample data
W = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3x2 matrix
b = jnp.array([0.1, 0.2, 0.3])
x = jnp.array([1.0, 2.0])

# Jacobian with respect to input x
jacobian_x = bst.transform.jacobian(lambda x: linear_layer(x, W, b))

jac_x = jacobian_x(x)
print("Jacobian w.r.t. input x:")
print(jac_x)
print("\nThis should equal W:")
print(W)
print("\nMatch:", jnp.allclose(jac_x, W))
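
The Jacobian with respect to the weights is a third-order tensor: its shape is the output shape followed by the parameter shape. For $y_i = \sum_k W_{ik} x_k + b_i$, each entry is $\partial y_i / \partial W_{jk} = \delta_{ij} x_k$. A short sketch in plain JAX, reusing the same `W`, `b`, and `x`, checks this:

```python
import jax
import jax.numpy as jnp

W = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([0.1, 0.2, 0.3])
x = jnp.array([1.0, 2.0])


def linear_layer(x, W, b):
    return W @ x + b


# Differentiate w.r.t. W (argnums=1): output shape (3,) followed by W's shape (3, 2)
jac_W = jax.jacrev(linear_layer, argnums=1)(x, W, b)
print(jac_W.shape)  # (3, 3, 2)

# Entry [i, j, k] is dy_i/dW_jk = delta_ij * x_k
expected = jnp.einsum("ij,k->ijk", jnp.eye(3), x)
print(jnp.allclose(jac_W, expected))
```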

## 5. Hessian Matrices

The Hessian is the matrix of all second-order partial derivatives.

In [None]:
# f(x, y) = x^3 + y^3 + 3xy
def hessian_example(xy):
    x, y = xy
    return x**3 + y**3 + 3*x*y

# Compute Hessian
hessian_fn = bst.transform.hessian(hessian_example)

xy = jnp.array([1.0, 2.0])
hess = hessian_fn(xy)

print("Hessian matrix:")
print(hess)
print("\nExpected Hessian:")
print("[[6x, 3 ],   [[6, 3],")
print(" [3,  6y]]  = [3, 12]]")
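
A Hessian is the Jacobian of the gradient. In JAX, `jax.hessian` is defined as forward-over-reverse composition, `jax.jacfwd(jax.jacrev(f))`; this sketch builds it both ways in plain JAX. BrainState's `hessian` may compose its transforms differently, which is not assumed here. For a smooth function the mixed partials commute, so the result is symmetric.

```python
import jax
import jax.numpy as jnp


def hessian_example(xy):
    x, y = xy
    return x ** 3 + y ** 3 + 3 * x * y


xy = jnp.array([1.0, 2.0])

# Reverse mode for the gradient, then forward mode for its Jacobian
hess_manual = jax.jacfwd(jax.jacrev(hessian_example))(xy)
hess_builtin = jax.hessian(hessian_example)(xy)

print(hess_manual)
print(jnp.allclose(hess_manual, hess_builtin))
# Mixed partials commute for smooth f, so the Hessian equals its transpose
print(jnp.allclose(hess_manual, hess_manual.T))
```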

### Using Hessian for Second-Order Optimization

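Newton's method rescales the gradient by the inverse Hessian. Near a minimum with a positive-definite Hessian it converges quadratically, at the cost of forming and solving an n×n linear system at every step.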

In [None]:
# Rosenbrock function: f(x,y) = (1-x)^2 + 100(y-x^2)^2
# Minimum at (1, 1)
def rosenbrock(xy):
    x, y = xy
    return (1 - x)**2 + 100 * (y - x**2)**2

grad_rb = bst.transform.grad(rosenbrock)
hess_rb = bst.transform.hessian(rosenbrock)

# Newton's method
params = jnp.array([0.0, 0.0])
history_newton = [params]

for i in range(10):
    g = grad_rb(params)
    H = hess_rb(params)
    # Newton update: x_new = x - H^{-1} * g
    params = params - jnp.linalg.solve(H, g)
    history_newton.append(params)

history_newton = jnp.array(history_newton)

# Compare with gradient descent
params_gd = jnp.array([0.0, 0.0])
history_gd = [params_gd]
lr = 0.001

for i in range(1000):
    g = grad_rb(params_gd)
    params_gd = params_gd - lr * g
    if i % 100 == 0:
        history_gd.append(params_gd)

history_gd = jnp.array(history_gd)

print(f"Newton's method final: {history_newton[-1]} (after {len(history_newton) - 1} iterations)")
print(f"Gradient descent final: {params_gd} (after 1000 iterations)")
print("True minimum: [1, 1]")
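
Newton's method finishes almost immediately here because, starting from (0, 0), two steps land exactly on the minimum. By hand: at (0, 0), ∇f = (-2, 0) and H = diag(2, 200), so the step moves to (1, 0). At (1, 0), ∇f = (400, -200) and H = [[1202, -400], [-400, 200]]; solving H d = ∇f gives d = (0, -1), which moves to (1, 1). The plain-JAX sketch below reproduces these two iterates. Pure Newton is not always this fortunate: away from a minimum the Hessian can be indefinite, and the step need not decrease f.

```python
import jax
import jax.numpy as jnp


def rosenbrock(xy):
    x, y = xy
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


p = jnp.array([0.0, 0.0])
newton_iterates = []
for _ in range(2):
    g = jax.grad(rosenbrock)(p)
    H = jax.hessian(rosenbrock)(p)
    # Newton step: solve H d = g rather than forming H^{-1} explicitly
    p = p - jnp.linalg.solve(H, g)
    newton_iterates.append(p)
    print(p)
```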

## 6. Vector Gradients with `vector_grad`

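`grad` requires a scalar output. `vector_grad` also accepts vector-valued functions: it pulls a vector of ones back through the vector-Jacobian product (VJP), so the result is the gradient of the *sum* of the outputs and has the same shape as the differentiated argument. It does not return per-sample gradients. A typical use is an elementwise function, where this yields the elementwise derivative (for example, `vector_grad(jnp.sin)(x)` gives `cos(x)`).

In [None]:
# Example: a loss function that returns one value per sample
def batch_loss(params, batch):
    # params: model parameters, shape (2,)
    # batch: multiple inputs, shape (3, 2)
    # Returns: vector of losses (one per sample), shape (3,)
    return jnp.sum((params - batch) ** 2, axis=1)

params = jnp.array([1.0, 2.0])
batch = jnp.array([[1.0, 1.0], 
                   [2.0, 2.0],
                   [3.0, 3.0]])

# vector_grad differentiates the vector-valued output w.r.t. the first argument (params)
vgrad_fn = bst.transform.vector_grad(batch_loss)

grads = vgrad_fn(params, batch)
print("Gradient of the summed per-sample losses:")
print(grads)
print("\nShape:", grads.shape)  # (2,) - same shape as params; the samples are summed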
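
If you need one gradient per sample, a standard pattern in plain JAX is to `vmap` a scalar gradient over the batch axis. The second half of this sketch computes the VJP of the vector-valued `batch_loss` with a vector of ones, which sums the per-sample gradients into one gradient shaped like `params`. Only JAX functions are used; `sample_loss` is a helper introduced for illustration.

```python
import jax
import jax.numpy as jnp

params = jnp.array([1.0, 2.0])
batch = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])


def sample_loss(params, sample):
    # Scalar loss for a single sample
    return jnp.sum((params - sample) ** 2)


# Per-sample gradients: map over axis 0 of `batch`, keep `params` fixed (None)
per_sample = jax.vmap(jax.grad(sample_loss), in_axes=(None, 0))(params, batch)
print(per_sample)  # shape (3, 2): one gradient row per sample


def batch_loss(params, batch):
    return jnp.sum((params - batch) ** 2, axis=1)


# VJP with a vector of ones = gradient of the summed losses
losses, vjp_fn = jax.vjp(lambda p: batch_loss(p, batch), params)
(summed_grad,) = vjp_fn(jnp.ones_like(losses))
print(summed_grad)
print(jnp.allclose(summed_grad, per_sample.sum(axis=0)))
```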

## 7. GradientTransform Class

The `GradientTransform` class is the machinery behind BrainState's gradient functions such as `grad`: in addition to ordinary function arguments, it can differentiate with respect to `State` objects, for example the `ParamState`s of a module. You rarely construct it directly. Instead, pass the states you want gradients for as `grad_states` to `bst.transform.grad`, and make sure the forward pass runs *inside* the function being differentiated; predictions computed outside it carry no gradient information.

In [None]:
# Create a simple neural network module
class SimpleNet(bst.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.w1 = bst.ParamState(bst.random.randn(input_dim, hidden_dim) * 0.1)
        self.b1 = bst.ParamState(jnp.zeros(hidden_dim))
        self.w2 = bst.ParamState(bst.random.randn(hidden_dim, output_dim) * 0.1)
        self.b2 = bst.ParamState(jnp.zeros(output_dim))
    
    def __call__(self, x):
        h = jnp.tanh(x @ self.w1.value + self.b1.value)
        return h @ self.w2.value + self.b2.value

# Create network
net = SimpleNet(input_dim=2, hidden_dim=4, output_dim=1)

# Generate sample data
x_train = bst.random.randn(32, 2)
y_train = bst.random.randn(32, 1)

# The loss runs the forward pass itself, so gradients can flow to the parameters
def loss_fn():
    predictions = net(x_train)
    return jnp.mean((predictions - y_train) ** 2)

# Differentiate w.r.t. the network's ParamStates (grad builds a GradientTransform)
params = net.states(bst.ParamState)
grad_fn = bst.transform.grad(loss_fn, grad_states=params, return_value=True)

# With return_value=True, the gradients come first, then the loss value
grads, loss = grad_fn()

print(f"Loss: {loss}")
# Gradients are keyed like `params`: one entry per ParamState path
for path, g in grads.items():
    print(f"Gradient w.r.t. {path}: shape {g.shape}")
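
For comparison, here is the same model written functionally in plain JAX, with the parameters held in an ordinary dict instead of `ParamState`s. This sketch (the helpers `init_params`, `forward`, and `mse_loss` are hypothetical) makes the two requirements explicit: the forward pass happens inside the differentiated function, and the gradients come back with the same structure as the parameters. Note that `jax.value_and_grad` returns `(value, grads)`, in that order.

```python
import jax
import jax.numpy as jnp


def init_params(key, input_dim=2, hidden_dim=4, output_dim=1):
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (input_dim, hidden_dim)),
        "b1": jnp.zeros(hidden_dim),
        "w2": 0.1 * jax.random.normal(k2, (hidden_dim, output_dim)),
        "b2": jnp.zeros(output_dim),
    }


def forward(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def mse_loss(params, x, y):
    # The forward pass runs inside the function being differentiated
    return jnp.mean((forward(params, x) - y) ** 2)


key_p, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
fn_params = init_params(key_p)
x_demo = jax.random.normal(key_x, (32, 2))
y_demo = jax.random.normal(key_y, (32, 1))

fn_loss, fn_grads = jax.value_and_grad(mse_loss)(fn_params, x_demo, y_demo)
print(f"Loss: {fn_loss}")
for name, g in fn_grads.items():
    print(name, g.shape)
```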

## 8. Training with Gradients

Let's put it all together and train a network using gradient descent.

In [None]:
import jax

# Create a regression dataset
def generate_data(n_samples=100):
    x = bst.random.randn(n_samples, 2)
    # True function: y = 3*x1 - 2*x2 + 1 + noise
    y = 3*x[:, 0:1] - 2*x[:, 1:2] + 1 + 0.1*bst.random.randn(n_samples, 1)
    return x, y

x_train, y_train = generate_data(100)
x_test, y_test = generate_data(20)

# Create network
class RegressionNet(bst.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = bst.nn.Linear(2, 8, w_init=bst.init.KaimingNormal())
        self.linear2 = bst.nn.Linear(8, 1, w_init=bst.init.KaimingNormal())
    
    def __call__(self, x):
        x = self.linear1(x)
        x = jnp.tanh(x)
        x = self.linear2(x)
        return x

model = RegressionNet()
params = model.states(bst.ParamState)

# Training function
def train_step(x, y, learning_rate=0.01):
    # Forward pass inside the differentiated function
    def loss_fn():
        predictions = model(x)
        return jnp.mean((predictions - y) ** 2)
    
    # Compute gradients w.r.t. the ParamStates; with return_value=True, grads come first
    grads, loss_val = bst.transform.grad(loss_fn, grad_states=params, return_value=True)()
    
    # Update parameters. A state's value may itself be a pytree (for example, a layer
    # that stores its weight and bias together), so apply the SGD step leaf by leaf
    for key, grad in grads.items():
        params[key].value = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params[key].value, grad
        )
    
    return loss_val

# Training loop
train_losses = []
test_losses = []

for epoch in range(200):
    # Train
    train_loss = train_step(x_train, y_train, learning_rate=0.01)
    train_losses.append(train_loss)
    
    # Evaluate
    with bst.environ.context(fit=False):
        test_pred = model(x_test)
        test_loss = jnp.mean((test_pred - y_test) ** 2)
        test_losses.append(test_loss)
    
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

# Plot training curves
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', alpha=0.7)
plt.plot(test_losses, label='Test Loss', alpha=0.7)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot predictions vs targets
plt.subplot(1, 2, 2)
with bst.environ.context(fit=False):
    final_pred = model(x_test)
plt.scatter(y_test, final_pred, alpha=0.6)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', label='Perfect prediction')
plt.xlabel('True values')
plt.ylabel('Predictions')
plt.title('Test Set Predictions')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
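
A stripped-down version of the same loop in plain JAX, using a linear model instead of `RegressionNet` so the learned weights can be compared directly with the ground truth (3, -2) and bias 1. This is a sketch: `jax.jit` and the pure-functional update replace BrainState's state handling, and `linear_mse` and `sgd_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp

key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_lin = jax.random.normal(key_x, (100, 2))
# Same ground truth as generate_data above: y = 3*x1 - 2*x2 + 1 + noise
y_lin = 3 * x_lin[:, 0:1] - 2 * x_lin[:, 1:2] + 1 + 0.1 * jax.random.normal(key_noise, (100, 1))

lin_params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}


def linear_mse(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, x, y, learning_rate=0.1):
    loss, grads = jax.value_and_grad(linear_mse)(params, x, y)
    # tree_map applies the same SGD update to every leaf of the parameter pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss


for _ in range(500):
    lin_params, lin_loss = sgd_step(lin_params, x_lin, y_lin)

print("learned w:", lin_params["w"].ravel(), "learned b:", lin_params["b"])
print("final loss:", lin_loss)
```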

## 9. Advanced: Custom Gradient Rules

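Sometimes you want to override the gradient that autodiff would derive, for example to clip the gradient flowing through one point of a computation. JAX's `jax.custom_vjp` pairs a forward function with a hand-written backward (vector-Jacobian product) rule.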

In [None]:
import jax
from functools import partial

# Define a function with a custom gradient; `threshold` (argument 1) is not differentiated
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def clip_gradient(x, threshold):
    # Forward pass: identity function
    return x

def clip_gradient_fwd(x, threshold):
    # Same output as the forward pass; no residuals are needed for the backward pass
    return x, None

def clip_gradient_bwd(threshold, residuals, g):
    # Backward pass: clip the incoming gradient elementwise to [-threshold, threshold]
    return (jnp.clip(g, -threshold, threshold),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# Test it
def loss_with_clipping(x):
    x = clip_gradient(x, 0.5)
    return jnp.sum(x ** 3)

x = jnp.array([2.0, -3.0, 1.0])
grad_fn = bst.transform.grad(loss_with_clipping)
gradients = grad_fn(x)

print(f"Input: {x}")
print(f"Unclipped gradients would be: {3 * x**2}")
print(f"Clipped gradients: {gradients}")
print("Notice gradients are clipped to [-0.5, 0.5]")

## 10. Gradient Checkpointing for Memory Efficiency

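For very deep networks, gradient checkpointing (also called rematerialization) trades computation for memory: instead of storing every intermediate activation for the backward pass, only the inputs of checkpointed segments are kept, and each segment's forward pass is recomputed during backpropagation.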

In [None]:
# Deep network that might run out of memory
class DeepNet(bst.graph.Node):
    def __init__(self, n_layers=50):
        super().__init__()
        self.layers = [bst.nn.Linear(10, 10) for _ in range(n_layers)]
    
    def __call__(self, x):
        for layer in self.layers:
            x = jnp.tanh(layer(x))
        return x

# With checkpointing: wrap each segment of layers in jax.checkpoint, so the
# segment's internal activations are recomputed during backprop instead of stored
class DeepNetCheckpointed(bst.graph.Node):
    def __init__(self, n_layers=50, checkpoint_every=10):
        super().__init__()
        self.checkpoint_every = checkpoint_every
        self.layers = [bst.nn.Linear(10, 10) for _ in range(n_layers)]
    
    def __call__(self, x):
        for start in range(0, len(self.layers), self.checkpoint_every):
            segment = self.layers[start:start + self.checkpoint_every]

            def run_segment(h):
                for layer in segment:
                    h = jnp.tanh(layer(h))
                return h

            x = jax.checkpoint(run_segment)(x)
        return x

# Memory savings only appear when differentiating through the model (e.g. during
# training); a plain forward call like this one computes the same output as DeepNet
model_checkpointed = DeepNetCheckpointed(n_layers=20, checkpoint_every=5)
x = bst.random.randn(4, 10)
output = model_checkpointed(x)
print(f"Output shape: {output.shape}")
print("Checkpointing trades extra forward computation for lower memory during backprop")
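
Checkpointing should change only memory use and compute, never the gradient values. The sketch below checks this in plain JAX on a 20-layer tanh chain, differentiating it with and without `jax.checkpoint` around 5-layer segments. Plain weight matrices stand in for the `bst.nn.Linear` layers (a simplifying assumption; no biases), and `block`, `plain_loss`, and `checkpointed_loss` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 20)
# Stand-in for 20 Linear layers: a list of 10x10 weight matrices
weights = [0.3 * jax.random.normal(k, (10, 10)) for k in keys]


def block(x, ws):
    for w in ws:
        x = jnp.tanh(x @ w)
    return x


def plain_loss(weights, x):
    return jnp.sum(block(x, weights) ** 2)


def checkpointed_loss(weights, x, every=5):
    for i in range(0, len(weights), every):
        # Activations inside each segment are recomputed on the backward pass
        x = jax.checkpoint(block)(x, weights[i:i + every])
    return jnp.sum(x ** 2)


x_in = jax.random.normal(jax.random.PRNGKey(1), (4, 10))
g_plain = jax.grad(plain_loss)(weights, x_in)
g_ckpt = jax.grad(checkpointed_loss)(weights, x_in)

grads_match = all(bool(jnp.allclose(a, b, atol=1e-6)) for a, b in zip(g_plain, g_ckpt))
print("gradients identical up to float tolerance:", grads_match)
```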

## Summary

In this tutorial, we covered:

1. **Basic Gradients**: Using `grad()` for scalar functions
2. **Multiple Arguments**: Computing gradients with respect to different arguments
3. **Vector Inputs**: Gradients of functions with vector inputs
4. **Jacobians**: Computing full Jacobian matrices for vector-valued functions
5. **Hessians**: Second-order derivatives for optimization
6. **Vector Gradients**: Differentiating vector-valued functions with `vector_grad`
7. **GradientTransform**: Gradients with respect to module states via `grad_states`
8. **Training Loops**: Practical gradient-based optimization
9. **Custom Gradients**: Defining custom backward passes
10. **Memory Optimization**: Gradient checkpointing for deep networks

## Key Takeaways

- BrainState's automatic differentiation builds on JAX's autodiff
- `grad()` is your primary tool for gradient computation
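- Use `jacobian()` and `hessian()` for full derivative matrices: first-order (Jacobian) and second-order (Hessian)
- `grad(..., grad_states=...)`, built on `GradientTransform`, computes gradients with respect to the states of stateful modules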
- Custom gradients allow fine-grained control over backpropagation
- Gradient checkpointing trades computation for memory

## Next Steps

In the next tutorial, we'll explore:
- **Vectorization with vmap and pmap**
- Batching operations automatically
- Parallel computation across devices
- StatefulMapping for stateful transformations