# Brainstate Tutorial

This tutorial demonstrates how to build stateful neural systems with [brainstate](https://github.com/chaobrain/brainstate). The walkthrough adapts several scripts from the `examples/` directory and shows how to compose modules, stage transformations with JAX, and handle checkpoints in practice.

## Getting Ready

Install the core packages before running the notebook. On a CPU-only machine you can start with:

```bash
pip install brainstate braintools "jax[cpu]" matplotlib orbax-checkpoint
```

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import brainstate

## Functional API: Working with States and Graphs

This section mirrors `examples/001_functional_api.py`. We build a `RegressionMLP` node that tracks internal call counts, split its states into parameter and auxiliary collections, and train it with `jax.jit`-compiled functions.

In [None]:
# Synthetic 1-D regression data: 100 evenly spaced inputs on [0, 1] as a column
# vector, and targets scattered around the quadratic 0.8 * x**2 + 0.1 with
# Gaussian noise (standard deviation 0.1) from a generator seeded with 0.
rng = np.random.default_rng(0)
X = np.linspace(0.0, 1.0, 100).reshape(-1, 1)
Y = (0.8 * X ** 2 + 0.1) + rng.normal(loc=0.0, scale=0.1, size=X.shape)


def dataset(batch_size: int):
    """Yield an endless stream of random mini-batches from ``X`` and ``Y``.

    Row indices are sampled with replacement from the module-level ``rng``,
    so every call to ``next`` advances that shared generator.

    Args:
        batch_size: Number of rows drawn for each batch.

    Yields:
        An ``(inputs, targets)`` pair built by indexing ``X`` and ``Y`` with
        the same sampled rows.
    """
    while True:
        rows = rng.choice(len(X), size=batch_size, replace=True)
        inputs, targets = X[rows], Y[rows]
        yield inputs, targets


class Linear(brainstate.nn.Module):
    """Affine layer computing ``x @ w + b`` from two ``ParamState`` values."""

    def __init__(self, din: int, dout: int):
        """Create a ``(din, dout)`` weight and a zero bias of shape ``(dout,)``.

        The weight is initialised with ``brainstate.random.rand``.
        """
        super().__init__()
        weight_init = brainstate.random.rand(din, dout)
        bias_init = jnp.zeros((dout,))
        self.w = brainstate.ParamState(weight_init)
        self.b = brainstate.ParamState(bias_init)

    def __call__(self, x):
        """Return ``x`` multiplied by the weight, plus the bias."""
        weight, bias = self.w.value, self.b.value
        return x @ weight + bias


class CallCount(brainstate.State):
    """Marker ``State`` subclass for forward-pass counters.

    Having a dedicated type lets the counter be selected separately from
    ``ParamState`` values when states are filtered by class.
    """


class RegressionMLP(brainstate.graph.Node):
    """Two-layer perceptron with a ReLU in between that counts its calls."""

    def __init__(self, din: int, dhidden: int, dout: int):
        """Create the call counter and the ``din -> dhidden -> dout`` layers."""
        self.count = CallCount(jnp.array(0))
        self.linear1 = Linear(din, dhidden)
        self.linear2 = Linear(dhidden, dout)

    def __call__(self, x):
        """Increment the call counter, then run the forward pass on ``x``."""
        self.count.value = self.count.value + 1
        hidden = jax.nn.relu(self.linear1(x))
        return self.linear2(hidden)

We separate the graph definition from its states using `treefy_split`, define training and evaluation steps, and then merge the updated states back into a full model to inspect the call count and plot the fit.

In [None]:
# Split a freshly built model into its graph definition plus two state
# collections, selected by type: the ParamState values and the CallCount
# counter. The three names are consumed by train_step, eval_step and the
# training loop.
graphdef, params, counts = brainstate.graph.treefy_split(
    RegressionMLP(din=1, dhidden=32, dout=1), brainstate.ParamState, CallCount
)


@jax.jit
def train_step(param_states, count_states, batch):
    """Run one gradient-descent update on a mini-batch.

    Args:
        param_states: State tree holding the trainable parameters.
        count_states: State tree holding the call counter.
        batch: ``(x, y)`` pair of inputs and targets.

    Returns:
        ``(updated_params, updated_counts)``: the parameters after a step of
        size 0.1 against the gradient of the mean squared error, and the
        counter states read back from the model after its forward pass.
    """
    inputs, targets = batch
    step_size = 0.1

    def mse_with_counts(trainable):
        # Rebuild a callable model from the graph definition and the states;
        # the counter states ride along as the auxiliary output.
        net = brainstate.graph.treefy_merge(graphdef, trainable, count_states)
        residual = targets - net(inputs)
        counts_after = brainstate.graph.treefy_states(net, CallCount)
        return jnp.mean(residual ** 2), counts_after

    def sgd(weight, grad):
        return weight - step_size * grad

    grad_fn = jax.grad(mse_with_counts, has_aux=True)
    grads, updated_counts = grad_fn(param_states)
    updated_params = jax.tree.map(sgd, param_states, grads)
    return updated_params, updated_counts


@jax.jit
def eval_step(param_states, count_states, batch):
    """Compute the mean squared error of the merged model on ``batch``.

    Args:
        param_states: State tree holding the trainable parameters.
        count_states: State tree holding the call counter.
        batch: ``(x, y)`` pair of inputs and targets.

    Returns:
        A dict with the single key ``'loss'``.
    """
    inputs, targets = batch
    net = brainstate.graph.treefy_merge(graphdef, param_states, count_states)
    residual = targets - net(inputs)
    return {'loss': jnp.mean(residual ** 2)}


# Train on 1000 mini-batches of 32 samples; every 200th step, report the loss
# on the full dataset.
for step, batch in zip(range(1000), dataset(32)):
    params, counts = train_step(params, counts, batch)
    if step % 200:
        continue
    logs = eval_step(params, counts, (X, Y))
    print(f"step: {step}, loss: {logs['loss']:.4f}")

# Reassemble a concrete model from the trained states so it can be inspected.
model = brainstate.graph.treefy_merge(graphdef, params, counts)
print(f"call count: {model.count.value}")

# Plot the noisy samples together with the fitted curve.
y_pred = model(X)
plt.scatter(X, Y, color='steelblue', label='data')
plt.plot(X, y_pred, color='black', label='fit')
plt.legend()
plt.show()

## Lifted Transforms and Optimizers

Here we rework the regression example following `examples/002_lifted_transforms.py`. The model stays entirely inside a `brainstate.nn.Module`, and `brainstate.transform` decorators lift JAX transformations so that the state collections are updated automatically. We pair the model with `braintools.optim.SGD` to handle parameter updates.

In [None]:
import braintools

class LiftedMLP(brainstate.nn.Module):
    """Two-layer perceptron with a ReLU in between that counts its calls."""

    def __init__(self, din: int, dhidden: int, dout: int):
        """Create the call counter and the ``din -> dhidden -> dout`` layers."""
        super().__init__()
        self.count = CallCount(jnp.array(0))
        self.linear1 = Linear(din, dhidden)
        self.linear2 = Linear(dhidden, dout)

    def __call__(self, x):
        """Increment the call counter, then run the forward pass on ``x``."""
        self.count.value = self.count.value + 1
        hidden = jax.nn.relu(self.linear1(x))
        return self.linear2(hidden)

# Build the lifted model and hand its ParamState collection to an SGD
# optimizer constructed with 1e-3.
model2 = LiftedMLP(din=1, dhidden=32, dout=1)
optimizer = braintools.optim.SGD(1e-3)
optimizer.register_trainable_weights(model2.states(brainstate.ParamState))


@brainstate.transform.jit
def train_step_lifted(batch):
    """Run one optimizer update of ``model2`` on a mini-batch.

    The mean squared error is differentiated with ``brainstate.transform.grad``,
    which is given the optimizer's registered parameter states, and the
    resulting gradients are passed to ``optimizer.update``.

    Args:
        batch: ``(x, y)`` pair of inputs and targets.
    """
    inputs, targets = batch

    def mse():
        residual = targets - model2(inputs)
        return jnp.mean(residual ** 2)

    trainable = optimizer.param_states.to_pytree()
    grads = brainstate.transform.grad(mse, trainable)()
    optimizer.update(grads)


@brainstate.transform.jit
def eval_step_lifted(batch):
    """Compute the mean squared error of ``model2`` on ``batch``.

    Args:
        batch: ``(x, y)`` pair of inputs and targets.

    Returns:
        A dict with the single key ``'loss'``.
    """
    inputs, targets = batch
    residual = targets - model2(inputs)
    return {'loss': jnp.mean(residual ** 2)}

# Train on 1000 mini-batches of 32 samples; every 200th step, report the loss
# on the full dataset.
for step, batch in zip(range(1000), dataset(32)):
    train_step_lifted(batch)
    if step % 200:
        continue
    logs = eval_step_lifted((X, Y))
    print(f"[lifted] step: {step}, loss: {logs['loss']:.4f}")

print(f"lifted call count: {model2.count.value}")

# Plot the noisy samples together with the fitted curve.
plt.scatter(X, Y, color='steelblue', label='data')
plt.plot(X, model2(X), color='darkorange', label='lifted fit')
plt.legend()
plt.show()

## Saving and Loading Checkpoints

Checkpointing follows `examples/005_save_load_checkpoints.py`. We snapshot the full state tree with `orbax.checkpoint`, restore it into an abstractly initialised model, and resume inference without writing custom serialization logic.

In [None]:
import os
from tempfile import TemporaryDirectory

import orbax.checkpoint as orbax


class CheckpointMLP(brainstate.nn.Module):
    """Two ``brainstate.nn.Linear`` layers with a ReLU in between."""

    def __init__(self, din: int, dmid: int, dout: int):
        """Create the ``din -> dmid`` and ``dmid -> dout`` layers."""
        super().__init__()
        self.dense1 = brainstate.nn.Linear(din, dmid)
        self.dense2 = brainstate.nn.Linear(dmid, dout)

    def __call__(self, x):
        """Run the forward pass on ``x``."""
        hidden = jax.nn.relu(self.dense1(x))
        return self.dense2(hidden)


def create_model(seed: int):
    """Seed ``brainstate.random`` and build a ``CheckpointMLP`` of sizes 10, 20, 30.

    Args:
        seed: Value passed to ``brainstate.random.seed`` before construction.

    Returns:
        The newly constructed ``CheckpointMLP``.
    """
    brainstate.random.seed(seed)
    return CheckpointMLP(din=10, dmid=20, dout=30)


def create_and_save(seed: int, path: str):
    """Build a model from ``seed`` and save its state tree under ``path``.

    Args:
        seed: Seed forwarded to ``create_model``.
        path: Directory that receives the checkpoint, written to the
            ``state`` entry inside it.
    """
    target = os.path.join(path, 'state')
    states = brainstate.graph.treefy_states(create_model(seed))
    orbax.PyTreeCheckpointer().save(target, states)


def load_model(path: str) -> CheckpointMLP:
    """Restore a ``CheckpointMLP`` from the checkpoint stored under ``path``.

    The model is first created through ``brainstate.transform.abstract_init``.
    Its state tree is passed to the checkpointer as the ``item`` for the
    restore, and the restored states are then written back into the model.

    Args:
        path: Directory containing the ``state`` checkpoint entry.

    Returns:
        The model, updated with the restored states.
    """
    source = os.path.join(path, 'state')
    model = brainstate.transform.abstract_init(lambda: create_model(0))
    template = brainstate.graph.treefy_states(model)
    restored_states = orbax.PyTreeCheckpointer().restore(source, item=template)
    brainstate.graph.update_states(model, restored_states)
    return model


# Round-trip demo: save a model seeded with 42 into a temporary directory,
# restore it, and run the restored model on a batch of ones with shape (1, 10).
# The directory is removed when the `with` block exits.
with TemporaryDirectory() as tmpdir:
    create_and_save(42, tmpdir)
    restored = load_model(tmpdir)
    y = restored(jnp.ones((1, 10)))
    print(restored)
    print(y)

## Where to Go Next

Explore the rest of the `examples/` directory for domain-specific workflows:

- `examples/100_hh_neuron_model.py` shows how to simulate detailed Hodgkin-Huxley neurons.
- `examples/203_brainscale_for_snns.py` demonstrates brain-inspired training loops for spiking neural networks.
- `examples/300_integrator_rnn.py` introduces rate-based recurrent models with integrator dynamics.

The [Brainstate documentation](https://brainstate.readthedocs.io/) expands on graph manipulation, transforms, and interoperability across the BrainX ecosystem.