# Automatic Differentiation in BrainState

BrainState provides an automatic differentiation system built on top of JAX and designed for stateful computations. This tutorial focuses on `brainstate.transform.grad` and related gradient transformations, demonstrating how to compute gradients with respect to function arguments and `State` objects.

## Key Concepts

BrainState's gradient system revolves around two key concepts:

1. **`argnums`**: Select which function arguments to differentiate with respect to (inherited from JAX)
2. **`grad_states`**: Select which `State` objects should receive gradients (BrainState's extension)

Additionally, BrainState uses **`ParamState`** to mark trainable parameters in neural networks and provides utilities to discover and manage states in arbitrary functions.

In [4]:
import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import grad, StateFinder

## 1. Understanding `argnums`: Gradients w.r.t. Function Arguments

The `argnums` parameter works just like in JAX's `jax.grad`. It specifies which positional arguments to differentiate with respect to.

In [5]:
def loss_fn(x, y, scale):
    """Simple loss function with multiple arguments."""
    return scale * jnp.sum((x - y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, 1.5, 2.5])
scale = 2.0

# Gradient w.r.t. the first argument (x)
grad_fn_x = grad(loss_fn, argnums=0)
grad_x = grad_fn_x(x, y, scale)
print("Gradient w.r.t. x:", grad_x)

# Gradient w.r.t. multiple arguments
grad_fn_xy = grad(loss_fn, argnums=[0, 1])
grad_x, grad_y = grad_fn_xy(x, y, scale)
print("Gradient w.r.t. x:", grad_x)
print("Gradient w.r.t. y:", grad_y)

Gradient w.r.t. x: [2. 2. 2.]
Gradient w.r.t. x: [2. 2. 2.]
Gradient w.r.t. y: [-2. -2. -2.]
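
As a cross-check, the values above agree with the closed-form gradients `2 * scale * (x - y)` for `x` and `-2 * scale * (x - y)` for `y`. The sketch below uses plain `jax.grad` rather than `brainstate.transform.grad`, relying on the statement above that `argnums` behaves the same way in both. The name `loss_fn_jax` is introduced here only for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn_jax(x, y, scale):
    """Same loss as above, redefined so this block is self-contained."""
    return scale * jnp.sum((x - y) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, 1.5, 2.5])
scale = 2.0

# argnums=(0, 1) returns one gradient per selected argument, in order.
gx, gy = jax.grad(loss_fn_jax, argnums=(0, 1))(x, y, scale)

# Closed-form gradients of scale * sum((x - y)**2).
expected_gx = 2.0 * scale * (x - y)
expected_gy = -2.0 * scale * (x - y)
print(jnp.allclose(gx, expected_gx), jnp.allclose(gy, expected_gy))

# argnums=2 differentiates w.r.t. the scalar `scale`: d/dscale = sum((x - y)**2).
g_scale = jax.grad(loss_fn_jax, argnums=2)(x, y, scale)
print(float(g_scale))
```

Selecting a scalar argument such as `scale` works the same way as selecting an array argument; the gradient simply has the shape of the selected input.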


## 2. Understanding `grad_states`: Gradients w.r.t. State Objects

### 2.1 ParamState for Trainable Parameters

In BrainState, **`ParamState`** is used to mark parameters that should receive gradients during training. This is the standard way to define trainable parameters in neural network modules.

In [7]:
class LinearRegressor(brainstate.nn.Module):
    """Simple linear regression model."""
    
    def __init__(self, in_features: int, out_features: int = 1):
        super().__init__()
        # ParamState marks these as trainable parameters
        self.weight = brainstate.ParamState(jnp.zeros((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value


# Create model and training data
model = LinearRegressor(1)
xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0


def mse_loss(x: jax.Array, target: jax.Array) -> jax.Array:
    """Mean squared error loss."""
    pred = model(x)
    return jnp.mean((pred - target) ** 2)


# Compute gradients w.r.t. model parameters
loss_grad = grad(
    mse_loss,
    grad_states=model.states(brainstate.ParamState),  # Get all ParamState instances
    return_value=True,
)

param_grads, loss_value = loss_grad(xs, y_true)
print(f"Loss: {float(loss_value):.4f}")
print("\nParameter gradients:")
for path, g in param_grads.items():
    print(f"  {path}: {g}")

Loss: 5.5000

Parameter gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]
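
To see what `grad_states` abstracts away, the following plain-JAX sketch passes the parameters explicitly as a dictionary and differentiates with respect to it. This is an illustrative equivalent, not BrainState's implementation; `params` and `mse_loss_functional` are names introduced here. With zero-initialized parameters it should give the same loss and gradients as the output above.

```python
import jax
import jax.numpy as jnp

# Explicit parameter container standing in for the two ParamState objects.
params = {
    "weight": jnp.zeros((1, 1)),
    "bias": jnp.zeros((1,)),
}

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0


def mse_loss_functional(params, x, target):
    """Mean squared error with parameters passed as an argument."""
    pred = x @ params["weight"] + params["bias"]
    return jnp.mean((pred - target) ** 2)


# value_and_grad returns (loss, grads); grads mirrors the structure of `params`.
loss_value, param_grads = jax.value_and_grad(mse_loss_functional)(params, xs, y_true)
print(float(loss_value))
print(param_grads)
```

In the functional version the caller must thread `params` through every call. With `grad_states`, the function body reads the states directly and the transformation collects their gradients.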


### 2.2 Retrieving States from Modules

BrainState provides two main ways to retrieve states from modules:

1. **`module.states(*filter)`**: Get states directly from a `Module` instance
2. **`brainstate.graph.treefy_states(node, *filter)`**: Get states from any object (more general)

In [9]:
# Method 1: Using module.states()
params_method1 = model.states(brainstate.ParamState)
print("Using model.states():")
for path, state in params_method1.items():
    print(f"  {path}: shape={state.value.shape}")

# Method 2: Using brainstate.graph.treefy_states()
params_method2 = brainstate.graph.treefy_states(model, brainstate.ParamState)
print("\nUsing brainstate.graph.treefy_states():")
for path, state in params_method2.to_flat().items():
    print(f"  {path}: shape={state.value.shape}")

# Both methods return the same states
assert set(params_method1.keys()) == set(params_method2.to_flat().keys())

Using model.states():
  ('bias',): shape=(1,)
  ('weight',): shape=(1, 1)

Using brainstate.graph.treefy_states():
  ('bias',): shape=(1,)
  ('weight',): shape=(1, 1)


### 2.3 Using StateFinder for Arbitrary Functions

Not every function is an `nn.Module`. For arbitrary functions, you can use **`StateFinder`** to discover which states are used inside the function.

In [11]:
# Create some standalone states
scale = brainstate.ParamState(jnp.array(1.5), name="scale")
offset = brainstate.ParamState(jnp.array(-0.2), name="offset")
cache = brainstate.State(jnp.array(0.0), name="cache")  # Not a ParamState


def energy(x: jax.Array) -> jax.Array:
    """Energy function using external states."""
    shifted = x * scale.value + offset.value
    # Update a state to track it as a write operation
    scale.value = scale.value + 0.0  # Dummy update to mark as written
    cache.value = jnp.sum(shifted)  # Write to cache
    return jnp.sum(jnp.square(shifted))


# Use StateFinder to discover states used in the function
finder = StateFinder(
    energy,
    filter=brainstate.ParamState,  # Only find ParamState instances
    usage='all',  # Find both read and write states
    return_type='dict',  # Return as a dictionary
)

all_param_states = finder(jnp.ones((2,)))
print("States found by StateFinder:")
for name, state in all_param_states.items():
    print(f"  {name}: {state.name}")

# Now compute gradients w.r.t. these discovered states
energy_grad = grad(
    energy,
    grad_states=all_param_states,
    return_value=True,
)

state_grads, energy_value = energy_grad(jnp.array([1.0, 3.0]))
print(f"\nEnergy: {float(energy_value):.4f}")
print("Gradients:")
for key in state_grads:
    print(f"  {key}: {state_grads[key]}")

States found by StateFinder:
  0: scale
  1: offset

Energy: 20.1800
Gradients:
  0: 28.400001525878906
  1: 11.200000762939453
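
The reported gradients can be sanity-checked without any autodiff machinery. The sketch below re-expresses `energy` as a pure Python function of `scale` and `offset` and approximates both derivatives with central differences. The helper names `energy_plain` and `central_diff` are hypothetical. Because the function is quadratic in each parameter, the central difference is exact up to floating-point rounding.

```python
def energy_plain(scale, offset, xs):
    """Pure-Python version of `energy`: sum((x * scale + offset)**2)."""
    return sum((x * scale + offset) ** 2 for x in xs)


def central_diff(f, v, eps=1e-4):
    """Central finite-difference approximation of df/dv at v."""
    return (f(v + eps) - f(v - eps)) / (2.0 * eps)


xs_check = [1.0, 3.0]
scale0, offset0 = 1.5, -0.2

# Vary one parameter at a time while holding the other fixed.
d_scale = central_diff(lambda s: energy_plain(s, offset0, xs_check), scale0)
d_offset = central_diff(lambda o: energy_plain(scale0, o, xs_check), offset0)

print(round(energy_plain(scale0, offset0, xs_check), 4))
print(round(d_scale, 4), round(d_offset, 4))
```

A finite-difference check like this is a cheap way to validate a gradient when the set of discovered states is not obvious from the function body.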


### 2.4 Important Note: Gradients are Not Limited to ParamState

While `ParamState` is the standard way to mark trainable parameters, **gradient computation works with any `State` instance**, not only `ParamState`.

In [12]:
# Create a regular State (not ParamState)
regular_state = brainstate.State(jnp.array(2.0), name="regular_state")


def compute_with_state(x):
    return jnp.sum((x * regular_state.value) ** 2)


# Compute gradient w.r.t. regular State
grad_fn = grad(compute_with_state, grad_states=[regular_state])
gradient = grad_fn(jnp.array([1.0, 2.0, 3.0]))
print(f"Gradient w.r.t. regular State: {gradient[0]}")

Gradient w.r.t. regular State: 56.0


## 3. Combining `argnums` and `grad_states`

You can compute gradients with respect to both function arguments and states simultaneously.

In [14]:
reg_model = LinearRegressor(1)


def penalized_loss(l2_coeff: float, inputs: jax.Array, target: jax.Array) -> jax.Array:
    """Loss with L2 regularization."""
    pred = reg_model(inputs)
    mse = jnp.mean((pred - target) ** 2)
    # L2 penalty on parameters
    l2 = jnp.sum(reg_model.weight.value ** 2) + jnp.sum(reg_model.bias.value ** 2)
    return mse + l2_coeff * l2


# Compute gradients w.r.t. both states and the first argument
grad_penalized = grad(
    penalized_loss,
    grad_states=reg_model.states(brainstate.ParamState),
    argnums=0,  # Also differentiate w.r.t. l2_coeff
    return_value=True,
)

(state_grads, coeff_grad), loss_val = grad_penalized(0.5, xs, y_true)
print(f"Loss: {float(loss_val):.4f}")
print(f"Gradient w.r.t. l2_coeff: {float(coeff_grad):.4f}")
print("\nState gradients:")
for path, g in state_grads.items():
    print(f"  {path}: {g}")

Loss: 5.5000
Gradient w.r.t. l2_coeff: 0.0000

State gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]
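
The gradient with respect to `l2_coeff` is zero above only because the model is zero-initialized: the derivative of `mse + l2_coeff * l2` with respect to `l2_coeff` is the penalty `l2` itself. A plain-JAX sketch with non-zero parameters makes this visible. The parameter values and the name `penalized_loss_functional` are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

# Non-zero parameters chosen arbitrarily for this illustration.
params = {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.5])}


def penalized_loss_functional(l2_coeff, params, inputs, target):
    """MSE plus an L2 penalty, with parameters passed explicitly."""
    pred = inputs @ params["weight"] + params["bias"]
    mse = jnp.mean((pred - target) ** 2)
    l2 = jnp.sum(params["weight"] ** 2) + jnp.sum(params["bias"] ** 2)
    return mse + l2_coeff * l2


# Differentiate w.r.t. the coefficient (argument 0) and the parameters (argument 1).
coeff_grad, param_grads = jax.grad(penalized_loss_functional, argnums=(0, 1))(
    0.5, params, xs, y_true
)

# d(loss)/d(l2_coeff) equals the penalty term: 1.0**2 + 0.5**2 = 1.25.
print(float(coeff_grad))
print(param_grads)
```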


## 4. Return Value Structures

All gradient transformations in BrainState share a common signature pattern. The return structure depends on the combination of `grad_states`, `argnums`, `has_aux`, and `return_value`.

### 4.1 Basic Return Structures

When `grad_states` is `None`:

- `has_aux=False` + `return_value=False` → `arg_grads`
- `has_aux=True` + `return_value=False` → `(arg_grads, aux_data)`
- `has_aux=False` + `return_value=True` → `(arg_grads, loss_value)`
- `has_aux=True` + `return_value=True` → `(arg_grads, loss_value, aux_data)`

When `grad_states` is not `None` and `argnums` is `None`:

- `has_aux=False` + `return_value=False` → `var_grads`
- `has_aux=True` + `return_value=False` → `(var_grads, aux_data)`
- `has_aux=False` + `return_value=True` → `(var_grads, loss_value)`
- `has_aux=True` + `return_value=True` → `(var_grads, loss_value, aux_data)`

When both `grad_states` and `argnums` are not `None`:

- `has_aux=False` + `return_value=False` → `(var_grads, arg_grads)`
- `has_aux=True` + `return_value=False` → `((var_grads, arg_grads), aux_data)`
- `has_aux=False` + `return_value=True` → `((var_grads, arg_grads), loss_value)`
- `has_aux=True` + `return_value=True` → `((var_grads, arg_grads), loss_value, aux_data)`


The same combinations in table form:

| grad_states | argnums | has_aux | return_value | result |
|-------------|---------|---------|--------------|--------|
| `None` | any | `False` | `False` | `arg_grads` |
| `None` | any | `True` | `False` | `(arg_grads, aux)` |
| `None` | any | `False` | `True` | `(arg_grads, loss)` |
| `None` | any | `True` | `True` | `(arg_grads, loss, aux)` |
| not `None` | `None` | `False` | `False` | `var_grads` |
| not `None` | `None` | `True` | `False` | `(var_grads, aux)` |
| not `None` | `None` | `False` | `True` | `(var_grads, loss)` |
| not `None` | `None` | `True` | `True` | `(var_grads, loss, aux)` |
| not `None` | not `None` | `False` | `False` | `(var_grads, arg_grads)` |
| not `None` | not `None` | `True` | `False` | `((var_grads, arg_grads), aux)` |
| not `None` | not `None` | `False` | `True` | `((var_grads, arg_grads), loss)` |
| not `None` | not `None` | `True` | `True` | `((var_grads, arg_grads), loss, aux)` |
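
The rows compose in a regular way: the gradient part is chosen by `grad_states` and `argnums`, then `loss_value` and `aux_data` are appended in that order when requested. The small helper below encodes just this bookkeeping as strings. It is a hypothetical illustration of the table, not a BrainState function.

```python
from itertools import product


def return_structure(has_states, has_argnums, has_aux, return_value):
    """Describe the return structure for one combination of options."""
    if not has_states:
        grads = "arg_grads"
    elif not has_argnums:
        grads = "var_grads"
    else:
        grads = "(var_grads, arg_grads)"

    # loss_value comes before aux_data when both are requested.
    parts = [grads]
    if return_value:
        parts.append("loss_value")
    if has_aux:
        parts.append("aux_data")

    # A single part is returned bare; several parts form a tuple.
    return parts[0] if len(parts) == 1 else "(" + ", ".join(parts) + ")"


# Print the rows where both states and arguments are differentiated.
for has_aux, return_value in product([False, True], repeat=2):
    print(has_aux, return_value, return_structure(True, True, has_aux, return_value))
```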



### 4.2 Complete Example: All Return Options

In [15]:
example_model = LinearRegressor(1)


def loss_with_metrics(l2_coeff: float, x: jax.Array, target: jax.Array):
    """Loss function that returns auxiliary metrics."""
    pred = example_model(x)
    mse = jnp.mean((pred - target) ** 2)
    l2 = jnp.sum(example_model.weight.value ** 2)
    loss = mse + l2_coeff * l2
    
    # Return loss and auxiliary metrics
    metrics = {
        "mae": jnp.mean(jnp.abs(pred - target)),
        "mse": mse,
        "l2": l2,
    }
    return loss, metrics


# Example: grad_states + argnums + has_aux + return_value
grad_complete = grad(
    loss_with_metrics,
    grad_states=example_model.states(brainstate.ParamState),
    argnums=0,
    has_aux=True,
    return_value=True,
)

((state_grads, coeff_grad), loss_val, aux_metrics) = grad_complete(0.3, xs, y_true)

print(f"Loss: {float(loss_val):.4f}")
print(f"\nGradient w.r.t. l2_coeff: {float(coeff_grad):.4f}")
print("\nState gradients:")
for path, g in state_grads.items():
    print(f"  {path}: {g}")
print("\nAuxiliary metrics:")
for key, val in aux_metrics.items():
    print(f"  {key}: {float(val):.4f}")

Loss: 5.5000

Gradient w.r.t. l2_coeff: 0.0000

State gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]

Auxiliary metrics:
  l2: 0.0000
  mae: 2.0000
  mse: 5.5000
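
For readers coming from plain JAX, note that the ordering differs from `jax.value_and_grad`, which returns `((value, aux), grads)` when `has_aux=True`, whereas the BrainState call above returns the gradients first. The sketch below shows the JAX side only; `loss_with_aux` is a name introduced here.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(w, x, target):
    """Return (loss, aux) so that has_aux=True can be used."""
    pred = w * x
    mse = jnp.mean((pred - target) ** 2)
    return mse, {"mae": jnp.mean(jnp.abs(pred - target))}


x = jnp.array([1.0, 2.0])
target = jnp.array([2.0, 4.0])

# JAX ordering: the function outputs come first, the gradient last.
(loss, aux), g = jax.value_and_grad(loss_with_aux, has_aux=True)(1.0, x, target)

# Rearranged to mirror the (grads, loss, aux) ordering used in this tutorial.
result = (g, loss, aux)
print(result)
```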


## 5. Other Gradient Transformations

BrainState provides several other gradient transformations, all sharing the same signature pattern as `grad`.

### 5.1 Vector Gradient

`vector_grad` is used for vector-valued functions. It sums the gradients of all output components, so the result has the same shape as the input.

In [16]:
from brainstate.transform import vector_grad


def vector_fun(x):
    """Vector-valued function."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0]**2 + x[1]**2])


x0 = jnp.array([1.0, 2.0])

# Vector gradient sums gradients across all outputs
vgrad = vector_grad(vector_fun)
result = vgrad(x0)
print("Vector gradient:", result)

Vector gradient: [4.5403023 5.       ]
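
Summing the gradients of every output component is the same as taking a vector-Jacobian product with a vector of ones, or equivalently summing the Jacobian over its output axis. The plain-JAX sketch below reproduces the result above both ways. It illustrates the arithmetic and is not a description of how `vector_grad` is implemented.

```python
import jax
import jax.numpy as jnp


def vector_fun(x):
    """Vector-valued function from the example above."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0] ** 2 + x[1] ** 2])


x0 = jnp.array([1.0, 2.0])

# Vector-Jacobian product with a cotangent of ones: ones^T @ J.
out, vjp_fn = jax.vjp(vector_fun, x0)
(summed_via_vjp,) = vjp_fn(jnp.ones_like(out))

# Same quantity from the full Jacobian: sum over the output axis.
summed_via_jac = jax.jacrev(vector_fun)(x0).sum(axis=0)

print(summed_via_vjp, summed_via_jac)
```

The first route needs a single reverse pass, while the second materializes the full Jacobian before reducing it.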


### 5.2 Jacobians: `jacrev` and `jacfwd`

- **`jacrev`**: Jacobian using reverse-mode autodiff (efficient for many inputs, few outputs)
- **`jacfwd`**: Jacobian using forward-mode autodiff (efficient for few inputs, many outputs)
- **`jacobian`**: Alias for `jacrev`

In [17]:
from brainstate.transform import jacrev, jacfwd, jacobian


def multi_output(x):
    """Function with multiple outputs."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])


x0 = jnp.array([1.0, 2.0])

# Reverse-mode Jacobian
jac_rev = jacrev(multi_output)
result_rev = jac_rev(x0)
print("Jacobian (reverse-mode):")
print(result_rev)

# Forward-mode Jacobian
jac_fwd = jacfwd(multi_output)
result_fwd = jac_fwd(x0)
print("\nJacobian (forward-mode):")
print(result_fwd)

# They should be the same
assert jnp.allclose(result_rev, result_fwd)

# jacobian is an alias for jacrev
jac_alias = jacobian(multi_output)
result_alias = jac_alias(x0)
assert jnp.allclose(result_rev, result_alias)

Jacobian (reverse-mode):
[[2.        1.       ]
 [0.5403023 0.       ]
 [0.        7.389056 ]]

Jacobian (forward-mode):
[[2.        1.       ]
 [0.5403023 0.       ]
 [0.        7.389056 ]]
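
The efficiency trade-off follows from how each mode assembles the matrix: forward mode obtains one Jacobian column per input direction (a Jacobian-vector product), while reverse mode obtains one row per output direction (a vector-Jacobian product). The sketch below builds the same 3×2 Jacobian both ways using `jax.jvp` and `jax.vjp` in explicit Python loops. It is written for clarity rather than speed; JAX's own `jacfwd` and `jacrev` batch these passes instead of looping.

```python
import jax
import jax.numpy as jnp


def multi_output(x):
    """Function with 2 inputs and 3 outputs, as in the example above."""
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])


x0 = jnp.array([1.0, 2.0])

# Forward mode: one JVP per input basis vector gives one column each.
columns = [
    jax.jvp(multi_output, (x0,), (basis,))[1]
    for basis in jnp.eye(2)
]
jac_by_columns = jnp.stack(columns, axis=1)

# Reverse mode: one VJP per output basis vector gives one row each.
_, vjp_fn = jax.vjp(multi_output, x0)
rows = [vjp_fn(basis)[0] for basis in jnp.eye(3)]
jac_by_rows = jnp.stack(rows, axis=0)

print(jac_by_columns)
print(jac_by_rows)
```

Here forward mode needs 2 passes and reverse mode needs 3, which is why forward mode is preferred when inputs are fewer than outputs.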


### 5.3 Hessian

`hessian` computes second-order derivatives.

In [18]:
from brainstate.transform import hessian


def quadratic(x):
    """Quadratic function."""
    return jnp.dot(x, x)


x0 = jnp.array([1.0, 2.0])

hess_fn = hessian(quadratic)
result = hess_fn(x0)
print("Hessian:")
print(result)

# For a quadratic form x^T x, the Hessian is 2*I
expected = 2 * jnp.eye(2)
assert jnp.allclose(result, expected)

Hessian:
[[2. 0.]
 [0. 2.]]
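
A Hessian can be understood as the Jacobian of a gradient. In JAX, `jax.hessian` is documented as forward-over-reverse differentiation, that is `jacfwd(jacrev(f))`. The sketch below checks this on a function whose Hessian is not constant. Whether BrainState's `hessian` composes the modes in the same order is not stated in this tutorial, so treat the composition as a JAX-level illustration.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    """Scalar function with a non-constant Hessian."""
    return x[0] ** 2 * x[1] + jnp.sin(x[1])


x0 = jnp.array([1.0, 2.0])

# Forward-over-reverse: differentiate the reverse-mode gradient in forward mode.
hess_composed = jax.jacfwd(jax.jacrev(f))(x0)

# Analytic Hessian [[2*x1, 2*x0], [2*x0, -sin(x1)]] evaluated at (1, 2).
hess_analytic = jnp.array([
    [4.0, 2.0],
    [2.0, -math.sin(2.0)],
])

print(jnp.allclose(hess_composed, hess_analytic, atol=1e-5))
print(jnp.allclose(hess_composed, jax.hessian(f)(x0), atol=1e-5))
```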


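### 5.4 Using Gradient Transformations with States

The other transformations accept `grad_states` as well; the result is keyed by state path, as with `grad`.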

In [20]:
# Example: Jacobian with states
jac_model = LinearRegressor(2)


def model_output(x):
    """Multiple outputs from a model."""
    return jac_model(x)


# Compute Jacobian w.r.t. model parameters
jac_states = jacrev(
    model_output,
    grad_states=jac_model.states(brainstate.ParamState)
)

x_input = jnp.array([1.0, 2.0])
param_jacobian = jac_states(x_input)

print("Jacobian w.r.t. parameters:")
for path, jac in param_jacobian.items():
    print(f"  {path}: shape={jac.shape}")

Jacobian w.r.t. parameters:
  ('bias',): shape=(1, 1)
  ('weight',): shape=(1, 2, 1)
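
A plain-JAX sketch with an explicit parameter dictionary (a stand-in for the model's `ParamState` objects, introduced here for illustration) reproduces these shapes. For an output of shape `(1,)`, the bias of shape `(1,)` gives a Jacobian of shape `(1, 1)` and the weight of shape `(2, 1)` gives `(1, 2, 1)`.

```python
import jax
import jax.numpy as jnp

# Stand-in for LinearRegressor(2): weight (2, 1) and bias (1,).
params = {"weight": jnp.zeros((2, 1)), "bias": jnp.zeros((1,))}
x_input = jnp.array([1.0, 2.0])


def model_output_functional(params, x):
    """Linear model with explicit parameters; output shape is (1,)."""
    return x @ params["weight"] + params["bias"]


# jacrev differentiates w.r.t. the first argument and mirrors its structure.
jac = jax.jacrev(model_output_functional)(params, x_input)

for name, j in jac.items():
    # Each entry has shape: output shape (1,) followed by the parameter shape.
    print(name, j.shape)
```

Because the model is linear, the weight Jacobian simply contains the input values, which makes the entries easy to verify by eye.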


## 6. Custom Gradient Transformations with `GradientTransform`

You can create custom gradient transformations by using the `GradientTransform` class. This allows you to wrap any JAX gradient function while maintaining BrainState's state-aware behavior.

### 6.1 Basic Custom Transform

In [22]:
from brainstate.transform import GradientTransform


def scaled_grad_transform(fun, *, argnums, has_aux, scale):
    """Custom gradient transform that scales gradients."""
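    # Use JAX's grad as the base transformation. Note that `has_aux=True` is
    # hard-coded here, so the `has_aux` argument is unused and `wrapped`
    # always returns a `(grads, aux)` pair.
    base = jax.grad(fun, argnums=argnums, has_aux=True)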

    def wrapped(*args, **kwargs):
        grads, aux = base(*args, **kwargs)
        # Scale all gradients
        grads = jax.tree.map(lambda g: scale * g, grads)
        return grads, aux

    return wrapped


def scaled_grad(
    fun,
    *,
    scale=1.0,
    grad_states=None,
    argnums=None,
    has_aux=False,
    return_value=False,
):
    """Create a gradient function with scaled gradients."""
    return GradientTransform(
        fun,
        transform=scaled_grad_transform,
        grad_states=grad_states,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        transform_params={"scale": scale},  # Pass custom parameters
    )


# Example usage
custom_model = LinearRegressor(1)


def custom_loss(x, target):
    pred = custom_model(x)
    return jnp.mean((pred - target) ** 2)


# Use custom scaled gradient
scaled_grad_fn = scaled_grad(
    custom_loss,
    scale=0.5,  # Scale gradients by 0.5
    grad_states=custom_model.states(brainstate.ParamState),
)

scaled_grads = scaled_grad_fn(xs, y_true)
print("Scaled gradients:")
for path, g in scaled_grads.items():
    print(f"  {path}: {g}")

# Compare with unscaled gradients
normal_grad_fn = grad(custom_loss, grad_states=custom_model.states(brainstate.ParamState))
normal_grads = normal_grad_fn(xs, y_true)
print("\nNormal gradients:")
for path, g in normal_grads.items():
    print(f"  {path}: {g}")

Scaled gradients:
  ('bias',): [-1.]
  ('weight',): [[-1.5]]

Normal gradients:
  ('bias',): [-2.]
  ('weight',): [[-3.]]


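### 6.2 Advanced: Gradient Clipping Transform

The same pattern supports clipping by global norm. The norm is taken over all gradients together, so the per-parameter norms printed below are each smaller than `max_norm`.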

In [23]:
def clipped_grad_transform(fun, *, argnums, has_aux, max_norm):
    """Custom gradient transform with gradient clipping."""
    base = jax.grad(fun, argnums=argnums, has_aux=True)

    def wrapped(*args, **kwargs):
        grads, aux = base(*args, **kwargs)
        
        # Compute global norm
        global_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads))
        )
        
        # Clip gradients
        scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
        grads = jax.tree.map(lambda g: scale * g, grads)
        
        return grads, aux

    return wrapped


def clipped_grad(
    fun,
    *,
    max_norm=1.0,
    grad_states=None,
    argnums=None,
    has_aux=False,
    return_value=False,
):
    """Create a gradient function with gradient clipping."""
    return GradientTransform(
        fun,
        transform=clipped_grad_transform,
        grad_states=grad_states,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        transform_params={"max_norm": max_norm},
    )


# Example: gradient clipping
clip_model = LinearRegressor(1)


def clip_loss(x, target):
    pred = clip_model(x)
    return jnp.mean((pred - target) ** 2)


clipped_grad_fn = clipped_grad(
    clip_loss,
    max_norm=0.1,  # Clip gradients to max norm of 0.1
    grad_states=clip_model.states(brainstate.ParamState),
)

clipped_grads = clipped_grad_fn(xs, y_true)
print("Clipped gradients:")
for path, g in clipped_grads.items():
    print(f"  {path}: {g}")
    print(f"    norm: {jnp.linalg.norm(g):.4f}")

Clipped gradients:
  ('bias',): [-0.05547]
    norm: 0.0555
  ('weight',): [[-0.08320501]]
    norm: 0.0832
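
The clipped values can be reproduced by hand. The unclipped gradients are `-2` (bias) and `-3` (weight), so the global norm is `sqrt(13) ≈ 3.606` and every gradient is multiplied by roughly `0.1 / 3.606`. The sketch below applies the same formula to a plain dictionary of arrays; `clip_by_global_norm` is a hypothetical helper name, not a BrainState function.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a pytree of gradients so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (global_norm + eps))
    return jax.tree_util.tree_map(lambda g: scale * g, grads)


# Unclipped gradients taken from the "Normal gradients" output earlier.
raw = {"bias": jnp.array([-2.0]), "weight": jnp.array([[-3.0]])}
clipped = clip_by_global_norm(raw, max_norm=0.1)

# The global norm of the result is approximately max_norm.
total = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(clipped)))
print(clipped, float(total))
```

Scaling all leaves by one common factor preserves the direction of the overall gradient, which per-parameter clipping would not.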


## 7. Practical Example: Training Loop

Let's put everything together in a complete training example.

In [28]:
# Create a fresh model
training_model = LinearRegressor(1)

# Generate training data
true_weight = 3.0
true_bias = 1.0
x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_train = true_weight * x_train + true_bias + 0.1 * brainstate.random.normal(size=x_train.shape)


@brainstate.transform.jit
def training_loss(x, y):
    """MSE loss with L2 regularization."""
    pred = training_model(x)
    mse = jnp.mean((pred - y) ** 2)
    l2 = 0.01 * (jnp.sum(training_model.weight.value ** 2) + jnp.sum(training_model.bias.value ** 2))
    return mse + l2, {"mse": mse, "l2": l2}


# Create gradient function
loss_grad_fn = grad(
    training_loss,
    grad_states=training_model.states(brainstate.ParamState),
    has_aux=True,
    return_value=True,
)

# Training loop
learning_rate = 0.1
num_epochs = 50

print("Training started...")
print(f"Initial weight: {training_model.weight.value}")
print(f"Initial bias: {training_model.bias.value}")

for epoch in range(num_epochs):
    # Compute gradients
    grads, loss_val, aux = loss_grad_fn(x_train, y_train)
    
    # Update parameters (simple SGD)
    for path, state in training_model.states(brainstate.ParamState).items():
        # Use a name other than `grad` so the imported `grad` transform is not shadowed
        g = grads[path]
        state.value = state.value - learning_rate * g
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"\nEpoch {epoch + 1}:")
        print(f"  Loss: {float(loss_val):.4f}")
        print(f"  MSE: {float(aux['mse']):.4f}")
        print(f"  L2: {float(aux['l2']):.6f}")
        print(f"  Weight: {training_model.weight.value}")
        print(f"  Bias: {training_model.bias.value}")

print("\nTraining completed!")
print(f"Final weight: {training_model.weight.value} (true: {true_weight})")
print(f"Final bias: {training_model.bias.value} (true: {true_bias})")

Training started...
Initial weight: [[0.]]
Initial bias: [0.]

Epoch 10:
  Loss: 0.8954
  MSE: 0.8662
  L2: 0.029213
  Weight: [[1.5811116]]
  Bias: [0.89346814]

Epoch 20:
  Loss: 0.2654
  MSE: 0.2050
  L2: 0.060385
  Weight: [[2.300836]]
  Bias: [0.98703194]

Epoch 30:
  Loss: 0.1382
  MSE: 0.0604
  L2: 0.077845
  Weight: [[2.628456]]
  Bias: [0.9968299]

Epoch 40:
  Loss: 0.1119
  MSE: 0.0254
  L2: 0.086541
  Weight: [[2.7775886]]
  Bias: [0.997856]

Epoch 50:
  Loss: 0.1065
  MSE: 0.0158
  L2: 0.090662
  Weight: [[2.8454742]]
  Bias: [0.9979635]

Training completed!
Final weight: [[2.8454742]] (true: 3.0)
Final bias: [0.9979635] (true: 1.0)
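
The fitted weight ends below the true value of 3.0 partly because of the L2 penalty, which shrinks the minimizer toward zero. Ignoring the noise term (an assumption made here to obtain a closed form) and using the fact that the inputs are symmetric around zero, setting the gradient of the regularized loss to zero gives `w* = 3 * E[x²] / (E[x²] + λ)` and `b* = 1 / (1 + λ)` with `λ = 0.01`. The sketch below evaluates these expressions and confirms them with plain gradient descent in JAX. All names are introduced for illustration.

```python
import jax
import jax.numpy as jnp

lam = 0.01
x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_clean = 3.0 * x_train + 1.0  # noise omitted for this check

# Closed-form minimizer of the regularized loss for symmetric inputs.
ex2 = float(jnp.mean(x_train ** 2))
w_star = 3.0 * ex2 / (ex2 + lam)
b_star = 1.0 / (1.0 + lam)


def reg_loss(params, x, y):
    """MSE plus lam * (w**2 + b**2), with explicit scalar parameters."""
    w, b = params
    pred = x * w + b
    return jnp.mean((pred - y) ** 2) + lam * (w ** 2 + b ** 2)


# Run gradient descent long enough to converge (500 steps, learning rate 0.1).
grad_fn = jax.jit(jax.grad(reg_loss))
params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    g = grad_fn(params, x_train, y_clean)
    params = tuple(p - 0.1 * gi for p, gi in zip(params, g))

print(w_star, b_star)
print(float(params[0]), float(params[1]))
```

Under these assumptions the regularized optimum for the weight is about 2.92, so a value near 2.85 after 50 epochs would be consistent with the optimizer still approaching that point rather than the unregularized value of 3.0.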


## Summary

In this tutorial, we covered:

1. **`argnums`**: Specify which function arguments to differentiate (inherited from JAX)
2. **`grad_states`**: Specify which `State` objects should receive gradients (BrainState extension)
3. **`ParamState`**: Standard way to mark trainable parameters in modules
4. **Retrieving states**: Use `module.states()` or `brainstate.graph.treefy_states()`
5. **`StateFinder`**: Discover states used in arbitrary functions
6. **Return structures**: How `has_aux` and `return_value` affect the output
7. **Other transforms**: `vector_grad`, `jacrev`, `jacfwd`, `jacobian`, `hessian`
8. **Custom transforms**: Build your own using `GradientTransform`

### Key Takeaways

- All gradient transformations share the same signature and return structure patterns
- `ParamState` is the standard for trainable parameters, but gradients work with any `State`
- `StateFinder` helps discover states in arbitrary functions
- `GradientTransform` enables custom gradient transformations while maintaining state-awareness
- The system integrates JAX's autodiff with BrainState's stateful computation model