# Automatic Differentiation with State-Aware Gradients

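`brainstate.transform.grad` behaves like `jax.grad`, but it can also
differentiate with respect to BrainState `State` objects that a function reads,
not only its positional arguments. This notebook shows how to differentiate
functions that depend on `State` objects, how to return auxiliary values, and
how to combine gradients with respect to states and positional arguments.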

In [1]:
from __future__ import annotations

import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import grad

## Example data

We'll fit a tiny linear regression problem (five points on the line
`y = 3x + 1`) so the gradients are easy to check by hand.

In [2]:
xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

## 1. Gradients with respect to `ParamState`

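Parameters live in `ParamState` containers. `grad_states` tells the transform
which states to differentiate with respect to, and `return_value=True` makes the
transformed function return `(grads, loss)` instead of the gradients alone.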

In [3]:
class LinearRegressor(brainstate.nn.Module):
    def __init__(self, in_features: int, out_features: int = 1):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.zeros((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value


model = LinearRegressor(in_features=1)


def mse_loss(x: jax.Array, target: jax.Array) -> jax.Array:
    prediction = model(x)
    return jnp.mean((prediction - target) ** 2)


loss_and_grads = grad(
    mse_loss,
    grad_states=model.states(brainstate.ParamState),
    return_value=True,
)
param_grads, loss_value = loss_and_grads(xs, y_true)
print('loss:', float(loss_value))
print('gradients:')
for path, g in param_grads.items():
    print(' ', path, g)

loss: 5.5
gradients:
  ('bias',) [-2.]
  ('weight',) [[-3.]]
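
As a sanity check, the same loss and gradients can be reproduced with plain `jax.value_and_grad` by passing the parameters explicitly. This is only a sketch for comparison and does not use BrainState; the `ref_params` dictionary and `mse_loss_explicit` are illustrative names, not part of either library.

```python
import jax
import jax.numpy as jnp

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

# Hypothetical explicit-parameter version of the model: the parameters are
# an ordinary dictionary instead of ParamState objects.
ref_params = {'weight': jnp.zeros((1, 1)), 'bias': jnp.zeros((1,))}


def mse_loss_explicit(params, x, target):
    prediction = x @ params['weight'] + params['bias']
    return jnp.mean((prediction - target) ** 2)


# jax.value_and_grad differentiates with respect to the first argument
# and returns (value, grads).
loss_ref, grads_ref = jax.value_and_grad(mse_loss_explicit)(ref_params, xs, y_true)
print('loss:', float(loss_ref))
print('gradients:', grads_ref)
```

Note the ordering: `jax.value_and_grad` returns `(value, grads)`, whereas the cell above unpacks `(grads, value)`.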


The gradient dictionary is keyed by the same paths as
`model.states(brainstate.ParamState)`, so a gradient-descent step is a single
loop over the parameter states.

In [4]:
learning_rate = 0.1
params = model.states(brainstate.ParamState)
for path, state in params.items():
    state.value = state.value - learning_rate * param_grads[path]

print('updated weight:', model.weight.value)
print('updated bias:', model.bias.value)

updated weight: [[0.3]]
updated bias: [0.2]
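
Written out, the loop performs one step of plain gradient descent with learning rate $\eta = 0.1$, starting from zero parameters and using the gradients printed in Section 1:

$$
w \leftarrow w - \eta \frac{\partial L}{\partial w} = 0 - 0.1 \times (-3) = 0.3,
\qquad
b \leftarrow b - \eta \frac{\partial L}{\partial b} = 0 - 0.1 \times (-2) = 0.2
$$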


## 2. Returning auxiliary data

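Set `has_aux=True` when your function returns a `(loss, aux)` pair, where `aux`
holds extra values such as metrics. Only `loss` is differentiated. Combined with
`return_value=True`, the transformed function returns `(grads, loss, aux)`.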

In [5]:
model_aux = LinearRegressor(1)


def loss_with_metrics(x: jax.Array, target: jax.Array):
    pred = model_aux(x)
    mse = jnp.mean((pred - target) ** 2)
    metrics = {
        'mae': jnp.mean(jnp.abs(pred - target)),
        'mean_pred': jnp.mean(pred),
    }
    return mse, metrics


grad_with_aux = grad(
    loss_with_metrics,
    grad_states=model_aux.states(brainstate.ParamState),
    has_aux=True,
    return_value=True,
)

param_grads_aux, loss_val_aux, metrics = grad_with_aux(xs, y_true)
print('loss:', float(loss_val_aux))
print('grad(weight):', param_grads_aux[('weight',)])
print('metrics:', {k: float(v) for k, v in metrics.items()})

loss: 5.5
grad(weight): [[-3.]]
metrics: {'mae': 2.0, 'mean_pred': 0.0}
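
For readers coming from JAX, the sketch below shows the corresponding pattern with `jax.value_and_grad(..., has_aux=True)` and explicit parameters. It is a comparison only and does not use BrainState; `aux_params` and `loss_with_metrics_explicit` are illustrative names.

```python
import jax
import jax.numpy as jnp

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0

# Hypothetical explicit parameters standing in for the ParamState objects.
aux_params = {'weight': jnp.zeros((1, 1)), 'bias': jnp.zeros((1,))}


def loss_with_metrics_explicit(params, x, target):
    pred = x @ params['weight'] + params['bias']
    mse = jnp.mean((pred - target) ** 2)
    metrics = {
        'mae': jnp.mean(jnp.abs(pred - target)),
        'mean_pred': jnp.mean(pred),
    }
    # Only the first element (mse) is differentiated; metrics pass through.
    return mse, metrics


# With has_aux=True, JAX nests the result as ((value, aux), grads).
(loss_aux_ref, metrics_ref), grads_aux_ref = jax.value_and_grad(
    loss_with_metrics_explicit, has_aux=True
)(aux_params, xs, y_true)
print('loss:', float(loss_aux_ref))
print('metrics:', {k: float(v) for k, v in metrics_ref.items()})
```

The values are the same as in the cell above; only the nesting differs, since the cell above unpacks a flat `(grads, loss, aux)` triple.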


## 3. Gradients w.r.t. states *and* arguments

Provide `argnums` to differentiate with respect to positional arguments in
addition to states. The gradients come back as a pair
`(state_grads, arg_grads)`; with `return_value=True` the full result is
`((state_grads, arg_grads), loss)`.

In [6]:
model_reg = LinearRegressor(1)


def penalised_loss(l2_coeff: float, x: jax.Array, target: jax.Array) -> jax.Array:
    pred = model_reg(x)
    mse = jnp.mean((pred - target) ** 2)
    reg = l2_coeff * (jnp.sum(model_reg.weight.value ** 2) + jnp.sum(model_reg.bias.value ** 2))
    return mse + reg


grad_penalised = grad(
    penalised_loss,
    grad_states=model_reg.states(brainstate.ParamState),
    argnums=0,
    return_value=True,
)

grads_pair, loss_val_penalised = grad_penalised(0.1, xs, y_true)
state_grads, coeff_grad = grads_pair
print('loss:', float(loss_val_penalised))
print('coeff gradient:', float(coeff_grad))
for path, g in state_grads.items():
    print('state grad', path, g)

loss: 5.5
coeff gradient: 0.0
state grad ('bias',) [-2.]
state grad ('weight',) [[-3.]]
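
The coefficient gradient is `0.0` here only because the parameters are still zero: the derivative of the loss with respect to `l2_coeff` is the sum of squared parameters. The plain-JAX sketch below illustrates this by evaluating the same loss at zero and at non-zero parameters. It does not use BrainState, and `penalised_loss_explicit`, `zero_params`, and `nonzero_params` are illustrative names.

```python
import jax
import jax.numpy as jnp

xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y_true = 3.0 * xs + 1.0


def penalised_loss_explicit(params, l2_coeff, x, target):
    pred = x @ params['weight'] + params['bias']
    mse = jnp.mean((pred - target) ** 2)
    reg = l2_coeff * (jnp.sum(params['weight'] ** 2) + jnp.sum(params['bias'] ** 2))
    return mse + reg


# Differentiate with respect to the parameters (argument 0) and the
# regularisation coefficient (argument 1) in one call.
value_and_grads = jax.value_and_grad(penalised_loss_explicit, argnums=(0, 1))

zero_params = {'weight': jnp.zeros((1, 1)), 'bias': jnp.zeros((1,))}
_, (_, coeff_grad_zero) = value_and_grads(zero_params, 0.1, xs, y_true)

# Hypothetical non-zero parameters chosen for illustration.
nonzero_params = {'weight': jnp.full((1, 1), 0.3), 'bias': jnp.full((1,), 0.2)}
_, (_, coeff_grad_nonzero) = value_and_grads(nonzero_params, 0.1, xs, y_true)

print('coeff gradient at zero params:', float(coeff_grad_zero))
# Approximately 0.3**2 + 0.2**2 = 0.13
print('coeff gradient at non-zero params:', float(coeff_grad_nonzero))
```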


## Summary

- Pass `grad_states` to target specific `State` objects.
- `return_value=True` returns the loss alongside the gradients, and
  `has_aux=True` additionally passes through auxiliary outputs such as metrics.
- Combine `grad_states` with `argnums` to differentiate both states and regular
  arguments in one call.