# Migrating Concepts from PyTorch to BrainState

BrainState borrows many familiar ideas from PyTorch—tensor computations,
modules with parameters, automatic differentiation—while leaning on JAX for
JIT compilation and functional programming. This note contrasts the key
building blocks so you can translate existing PyTorch workflows into
BrainState idioms quickly.

## Concept map

| PyTorch | BrainState | Notes |
| --- | --- | --- |
| `torch.Tensor` | `jax.Array` (`jnp.ndarray`) | Manipulated with `jax.numpy` semantics. |
| `nn.Module` | `brainstate.nn.Module` | Define `State` attributes (e.g. `ParamState`, `HiddenState`). |
| `nn.Parameter` | `brainstate.ParamState` | Holds differentiable weights; collected with `module.states(brainstate.ParamState)`. |
| `autograd.grad` / `backward()` | `brainstate.transform.grad` | Explicitly select which states or arguments receive gradients. |
| `torch.optim` optimisers | `braintools.optim` (optional) | Works on `.states(brainstate.ParamState)`. |
| `torch.jit.script` / `torch.jit.trace` | `brainstate.transform.jit` | JIT compile pure or stateful functions; integrates with JAX. |
| `state_dict()` / `load_state_dict()` | `brainstate.graph.treefy_states` / `brainstate.graph.update_states` | Serialize/restore state trees. |
| Random number generators (`torch.manual_seed`) | `brainstate.random.seed` / `RandomState` | Keys are JAX PRNGs, automatically split in transforms. |
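
The first row also hides a behavioural difference. PyTorch tensors support in-place writes such as `x[0] = 1.0`, but JAX arrays are immutable. The example below uses the standard `jax.numpy` indexed-update syntax to show the JAX equivalent:

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# PyTorch would allow `x[0] = 1.0` in place. JAX arrays are immutable,
# so `.at[index].set(value)` returns a *new* array and leaves `x` untouched.
y = x.at[0].set(1.0)

print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]
```

This is also why mutable quantities in BrainState live in `State` containers. You update them by reassigning `.value` rather than writing into the array.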

## PyTorch baseline

Consider a minimal linear regression in PyTorch:

```python
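import torch
import torch.nn as nn
import torch.optim as optim

# Synthetic data: y = 3x + 1 plus a little noise
torch.manual_seed(0)
inputs = torch.rand(64, 1) * 2 - 1            # uniform in [-1, 1)
targets = 3.0 * inputs + 1.0 + 0.1 * torch.randn(64, 1)

class TorchLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)  # weight shape is (out, in) = (1, 1)

    def forward(self, x):
        return self.linear(x)

model = TorchLinear()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1)

for step in range(100):
    optimizer.zero_grad()             # clear gradients accumulated in .grad
    preds = model(inputs)
    loss = criterion(preds, targets)
    loss.backward()                   # populate .grad on every parameter
    optimizer.step()                  # in-place update of each parameter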
```

BrainState follows the same logic but makes states and gradients explicit.

## BrainState equivalent

```python
import braintools.file  # used later for msgpack checkpoints
import jax
import jax.numpy as jnp
import numpy as np

import brainstate
from brainstate.transform import grad, jit
import braintools.optim as optim

# Synthetic dataset: y = 3x + 1 plus Gaussian noise
def make_dataset(n=64):
    rng = np.random.default_rng(0)
    x = rng.uniform(-1.0, 1.0, (n, 1)).astype(np.float32)
    y = 3.0 * x + 1.0 + rng.normal(0.0, 0.1, (n, 1)).astype(np.float32)
    return jnp.asarray(x), jnp.asarray(y)

x_train, y_train = make_dataset()

class LinearModel(brainstate.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        k1, k2 = jax.random.split(jax.random.PRNGKey(0))
        # ParamState plays the role of nn.Parameter; the weight layout is (in, out)
        self.weight = brainstate.ParamState(jax.random.normal(k1, (in_features, out_features)))
        self.bias = brainstate.ParamState(jax.random.normal(k2, (out_features,)))

    def __call__(self, x):
        return x @ self.weight.value + self.bias.value

model = LinearModel(1, 1)
params = model.states(brainstate.ParamState)  # trainable states, keyed by path
optimizer = optim.SGD(lr=1e-1)
optimizer.register_trainable_weights(params)

@jit
def train_step(x, y):
    def loss_fn():
        preds = model(x)
        return jnp.mean((preds - y) ** 2)

    # Differentiate w.r.t. the ParamStates and also return the loss value
    grads, loss = grad(loss_fn, grad_states=params, return_value=True)()
    optimizer.update(grads)  # writes the new values back into the ParamStates
    return loss

for step in range(200):
    loss = train_step(x_train, y_train)
    if step % 40 == 0:
        print(f"step {step:3d}, loss = {float(loss):.4f}")

@jit
def predict(x):
    return model(x)

print('predictions for x=0:', predict(jnp.array([[0.0]])))
```

```
step   0, loss = 13.1320
step  40, loss = 0.0144
step  80, loss = 0.0097
step 120, loss = 0.0097
step 160, loss = 0.0097
predictions for x=0: [[1.0059681]]
```


### Key observations

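- Parameters live in `ParamState` objects, so `grad` returns gradients as a
  dictionary keyed by state path (e.g. `('weight',)`, `('bias',)`), playing the
  role of the keys in `state_dict()`.
- `grad` takes the states to differentiate explicitly through `grad_states`;
  gradients with respect to function arguments are requested with `argnums`.
  This replaces PyTorch's per-tensor `requires_grad` flags.
- Optimisers receive that gradient tree through `optimizer.update(grads)`
  instead of reading the `.grad` attributes that `backward()` fills in.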

## Saving and loading state

```python
state_tree = brainstate.graph.treefy_states(model)
print('stored keys:', list(state_tree.to_flat().keys()))

# Later (or in another process):
restored = LinearModel(1, 1)
brainstate.graph.update_states(restored, state_tree)
print('restored weight:', restored.weight.value)
```

```
stored keys: [('bias',), ('weight',)]
restored weight: [[3.0168793]]
```


To persist states to disk, you can alternatively use `braintools.file.msgpack_save` and `braintools.file.msgpack_load`:

```python
braintools.file.msgpack_save('example.msgpack', model.states(brainstate.ParamState))
```

```
Saving checkpoint into example.msgpack
```


## Gradients with additional arguments

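Below, we differentiate with respect to both the model parameters (through `grad_states`) and an explicit scalar argument `alpha` (through `argnums=0`), which weights an L2 penalty on the weight.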

```python
scale = jnp.array(0.1)

def scaled_loss(alpha, inputs, targets):
    preds = model(inputs)
    mse = jnp.mean((preds - targets) ** 2)
    return mse + alpha * jnp.sum(model.weight.value ** 2)

(grads_state, alpha_grad), loss_val = grad(
    scaled_loss,
    grad_states=params,
    argnums=0,
    return_value=True,
)(scale, x_train, y_train)

print('loss:', float(loss_val))
print('grad w.r.t alpha:', float(alpha_grad))
for path, g in grads_state.items():
    print(path, g.shape)
```

```
loss: 0.9198333024978638
grad w.r.t alpha: 9.101560592651367
('bias',) (1,)
('weight',) (1, 1)
```
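
Both printed numbers can be checked by hand. The penalty is linear in `alpha`, so its derivative is just the squared norm of the weight. Using the trained weight `3.0168793` reported by the restore step above:

$$
L(\alpha) = \mathrm{MSE} + \alpha \sum_{ij} W_{ij}^2
\quad\Longrightarrow\quad
\frac{\partial L}{\partial \alpha} = \sum_{ij} W_{ij}^2 = 3.0168793^2 \approx 9.1016
$$

The loss likewise decomposes as $0.9198 \approx 0.0097 + 0.1 \times 9.1016$: the training MSE plus the scaled penalty.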


## Random numbers

BrainState wraps JAX PRNG keys. Use `brainstate.random.seed` to set the global
seed, or instantiate a `RandomState` for module-specific randomness. Transforms
like `vmap` and `pmap` split keys automatically per batch element.

```python
import brainstate.random as brandom

brandom.seed(42)
rs = brandom.RandomState()
print('single sample:', rs.normal(size=(2,)))
```

```
single sample: [ 0.6630465  -0.72396195]
```


## Debugging and JIT

BrainState leans on JAX's tooling. `brainstate.transform.jit` compiles stateful
functions, and `brainstate.transform.make_jaxpr` returns the traced computation
(the jaxpr) together with the states it reads.

```python
from brainstate.transform import make_jaxpr

jaxpr = make_jaxpr(model)
print(jaxpr(jnp.ones((1,))))
```

```
({ lambda ; a:f32[1] b:f32[1,1] c:f32[1]. let
    d:f32[1] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] a b
    e:f32[1] = add d c
  in (e, b, c) }, (ParamState(
  value=ShapedArray(float32[1,1])
), ParamState(
  value=ShapedArray(float32[1])
)))
```


## Summary

- Replace `nn.Module` + `nn.Parameter` with `brainstate.nn.Module` + `ParamState`.
- Use `brainstate.transform.grad`/`jit` instead of PyTorch autograd and scripting.
- Retrieve and update parameter trees via `graph.treefy_states` and
  `graph.update_states`.
- Optimisers in `braintools.optim` mirror the familiar PyTorch API, operating on
  state dictionaries.

With these substitutions most PyTorch training loops can be ported one module at
a time to BrainState.