# JIT Compilation in BrainState

`brainstate.transform.jit` extends `jax.jit` with automatic `State` tracking and
extra runtime controls. This guide highlights how BrainState JIT differs from
plain JAX JIT, when to prefer each API, and how to decompose modules with
`brainstate.graph.treefy_split` or `brainstate.graph.treefy_states`.

```python
import jax
import jax.numpy as jnp

import brainstate
```

## Why BrainState JIT?

`brainstate.transform.jit` understands `State` objects: it records which states
a function reads and writes and threads their values through the compiled
function automatically. The returned object is a `JittedFunction` that exposes
helpers such as the `compile` and `clear_cache` methods and the `origin_fun`
attribute (the undecorated Python function). Pure functions still work, but
stateful modules are first-class citizens.

```python
@brainstate.transform.jit
def softplus(x: jax.Array) -> jax.Array:
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)

xs = jnp.linspace(-5.0, 5.0, 7)
softplus(xs)
```

```text
Array([0.00671535, 0.03505242, 0.17300805, 0.69314724, 1.839675  ,
       3.368386  , 5.0067153 ], dtype=float32)
```

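Subsequent calls with inputs of the same shape and dtype reuse the compiled
executable. If you disable JIT, either globally with
`jax.config.update("jax_disable_jit", True)` or temporarily with the
`jax.disable_jit()` context manager, BrainState falls back to the original
Python implementation automatically.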

```python
with jax.disable_jit():
    # Runs op by op in Python; nothing is compiled inside this block.
    eager_out = softplus(xs * 2.0)
```
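To see what falling back to the Python implementation means in practice, the sketch below uses plain `jax.jit` (not BrainState) and a hypothetical `calls` counter bumped by a Python side effect. Under `jax.jit` the Python body runs only while tracing; under `jax.disable_jit()` it runs on every call. We assume BrainState's wrapper behaves the same way, since it follows JAX's flag, but the sketch itself only exercises JAX.

```python
import jax
import jax.numpy as jnp

calls = {"python_body": 0}  # hypothetical counter, for illustration only


def plain_softplus(x):
    calls["python_body"] += 1  # Python side effect: counts executions of the Python body
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)


jitted_softplus = jax.jit(plain_softplus)
inputs = jnp.linspace(-5.0, 5.0, 7)

jitted_softplus(inputs)  # first call: traces the body once, then compiles
jitted_softplus(inputs)  # same shape and dtype: reuses the cached executable
traced_calls = calls["python_body"]

with jax.disable_jit():
    jitted_softplus(inputs)  # eager: the Python body runs on every call
    jitted_softplus(inputs)
eager_calls = calls["python_body"] - traced_calls

print(traced_calls, eager_calls)  # 1 2
```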

## Stateful modules with zero boilerplate

BrainState keeps modules stateful inside compiled code. Below, a running-mean
tracker updates hidden state at each call without any manual intervention.

```python
class RunningMean(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.sum = brainstate.HiddenState(jnp.array(0.0))
        self.count = brainstate.HiddenState(jnp.array(0))

    def __call__(self, batch: jax.Array) -> jax.Array:
        self.sum.value += jnp.sum(batch)
        self.count.value += batch.size
        return self.sum.value / self.count.value


tracker = RunningMean()

@brainstate.transform.jit
def update_running_mean(batch: jax.Array) -> jax.Array:
    return tracker(batch)

for step in range(3):
    data = jnp.arange(4.0) + step
    print(f'step {step}: mean={float(update_running_mean(data)):.2f}')

float(tracker.sum.value), int(tracker.count.value)
```

```text
step 0: mean=1.50
step 1: mean=2.00
step 2: mean=2.50

(30.0, 12)
```

The hidden states remain in sync because BrainState passes the current state
values into the compiled executable as inputs and writes the returned values
back into the `HiddenState` objects after each call.
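A minimal sketch of this read-in/write-back pattern in plain JAX follows. It is not BrainState's implementation: `ToyState` and `stateful_jit` are hypothetical names, and the list of states is passed explicitly instead of being discovered by tracing. Each call feeds the current state values into a `jax.jit`-compiled pure function and assigns the returned values back afterwards.

```python
import jax
import jax.numpy as jnp


class ToyState:
    """Hypothetical mutable container standing in for a BrainState State."""

    def __init__(self, value):
        self.value = value


def stateful_jit(fn, states):
    """Compile `fn`, threading the values of `states` in and out explicitly (sketch)."""

    def pure(state_values, *args):
        saved = [s.value for s in states]
        for s, v in zip(states, state_values):  # point each state at its traced input
            s.value = v
        try:
            out = fn(*args)
            new_values = [s.value for s in states]  # collect the updated values
        finally:
            for s, v in zip(states, saved):  # restore concrete values: no tracer leaks
                s.value = v
        return out, new_values

    compiled = jax.jit(pure)

    def wrapper(*args):
        out, new_values = compiled([s.value for s in states], *args)
        for s, v in zip(states, new_values):  # write back after the compiled call
            s.value = v
        return out

    return wrapper


total = ToyState(jnp.array(0.0))
count = ToyState(jnp.array(0))


def running_mean(batch):
    total.value = total.value + jnp.sum(batch)
    count.value = count.value + batch.size
    return total.value / count.value


update_mean = stateful_jit(running_mean, [total, count])
means = [float(update_mean(jnp.arange(4.0) + step)) for step in range(3)]
print(means, float(total.value), int(count.value))  # [1.5, 2.0, 2.5] 30.0 12
```

Restoring the concrete values inside `pure` keeps JAX tracers from escaping the trace; a leaked tracer would raise an error the next time the state is read.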

## Extra controls exposed by `JittedFunction`

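Bare `jax.jit` offers ahead-of-time compilation only through a separate
`lower(...).compile()` object; BrainState's wrapper exposes these controls
directly on the function. You can precompile an executable for given example
inputs or drop its cached traces explicitly.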

```python
softplus.compile(jnp.ones((4,)))
softplus(jnp.ones((4,)))
```

```text
Array([1.3132617, 1.3132617, 1.3132617, 1.3132617], dtype=float32)
```

```python
softplus.clear_cache()
softplus(jnp.linspace(-1.0, 1.0, 5))
```

```text
Array([0.3132617, 0.474077 , 0.6931472, 0.974077 , 1.3132617], dtype=float32)
```
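For comparison, plain JAX provides the closest counterparts through its ahead-of-time (AOT) API and the process-wide `jax.clear_caches()`. In this sketch `jax_softplus` is a plain function rather than the BrainState-wrapped `softplus`. Note that `lower(...).compile()` returns a separate compiled object that you call directly, and it only accepts arguments matching the lowered shapes and dtypes.

```python
import jax
import jax.numpy as jnp


def jax_softplus(x):
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)


jitted_softplus = jax.jit(jax_softplus)

# Ahead-of-time: lower for a concrete input signature (float32[4]), then compile.
compiled_softplus = jitted_softplus.lower(jnp.ones((4,))).compile()
aot_out = compiled_softplus(jnp.ones((4,)))

# JAX's cache clearing is process-wide rather than per BrainState-style function.
jax.clear_caches()
fresh_out = jitted_softplus(jnp.ones((4,)))  # traces and compiles again on demand
```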

## Working directly with `jax.jit`

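If you prefer raw JAX primitives, you can still make modules jit-friendly by
turning them into pure functions. `brainstate.graph.treefy_split` returns a
`GraphDef` plus one or more state trees; you pass the state trees in as
arguments, rebuild the module with `brainstate.graph.treefy_merge`, return the
updated states, and thread them manually between calls.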

```python
model = RunningMean()

# Separate the static graph definition from the HiddenState values.
graph_def, hidden_state_tree = brainstate.graph.treefy_split(model, brainstate.HiddenState)


def running_mean_stateless(state_tree, batch):
    # Rebuild a module from the graph definition and the incoming states.
    module = brainstate.graph.treefy_merge(graph_def, state_tree)
    out = module(batch)
    # Extract the updated states so the caller can pass them to the next call.
    new_state_tree = brainstate.graph.treefy_states(module, brainstate.HiddenState)
    return out, new_state_tree


jax_jitted = jax.jit(running_mean_stateless)

state_tree = hidden_state_tree
for step in range(3):
    batch = jnp.arange(4.0) + step
    mean, state_tree = jax_jitted(state_tree, batch)
    print(f'step {step}: mean={float(mean):.2f}')

int(state_tree['count'].value), float(state_tree['sum'].value)
```

```text
step 0: mean=1.50
step 1: mean=2.00
step 2: mean=2.50

(12, 30.0)
```

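The JAX version works, but you are responsible for threading the state trees
between calls and reconstructing the module inside the function yourself. Note
that `model` itself is never updated: `treefy_merge` builds a new module from
the state tree, so the latest values live only in `state_tree`.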
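Manual threading also makes it straightforward to move the whole loop into one compiled program with `jax.lax.scan`. The sketch below replaces the BrainState states with a plain `(sum, count)` tuple carried between steps, so it illustrates the pattern rather than any BrainState API; `scan_step` and `run_scan` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def scan_step(carry, batch):
    """One running-mean update; the carry plays the role of the hidden states."""
    total, count = carry
    total = total + jnp.sum(batch)
    count = count + batch.size
    return (total, count), total / count  # (new carry, per-step output)


# Row i holds the batch for step i, i.e. [i, i+1, i+2, i+3], as in the loop above.
batches = jnp.arange(4.0)[None, :] + jnp.arange(3.0)[:, None]
init = (jnp.zeros((), jnp.float32), jnp.zeros((), jnp.int32))


@jax.jit
def run_scan(init, batches):
    return jax.lax.scan(scan_step, init, batches)


(final_sum, final_count), scan_means = run_scan(init, batches)
print(scan_means, float(final_sum), int(final_count))
```

The carry must keep the same shapes and dtypes across steps, which is why the initial values are created with explicit dtypes.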

## `treefy_split` vs `treefy_states`

Both helpers live in `brainstate.graph` but serve different purposes:

- **`treefy_split`** → returns `(graph_def, state_tree1, state_tree2, ...)`. Use
  it when you need to rebuild modules (e.g. JAX interop or serialising complete
  graphs).
- **`treefy_states`** → returns one or more state trees without the graph
  definition. It's the lightweight choice when you only need a PyTree of
  parameters for optimisation or checkpointing.


See also [BrainState Graph and Node System](../utilities/01_graph_operations.ipynb) for more details on these interfaces.
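For readers coming from plain JAX, a rough analogy is `jax.tree_util.tree_flatten`, which also separates a structure description from its leaves. In the sketch below a nested dict stands in for a module, and selecting the `"params"` key stands in for filtering by state type; both are illustrative assumptions, not how BrainState stores or filters `State` objects.

```python
import jax
import jax.numpy as jnp

# A nested dict standing in for a module's states (illustrative only).
module_like = {
    "params": {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])},
    "hidden": {"count": jnp.array(0)},
}

# treefy_split-like: a structure description plus leaves, enough to rebuild everything.
leaves, treedef = jax.tree_util.tree_flatten(module_like)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# treefy_states-like: only the parameter subtree, with no structure object to carry.
param_subtree = module_like["params"]
print(treedef.num_leaves, sorted(param_subtree))
```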

```python
class TinyLinear(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.array([[1.0]]))
        self.bias = brainstate.ParamState(jnp.array([0.0]))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value + self.bias.value


lin = TinyLinear()

# Split into graph + states (useful for reconstruction / JAX interop)
lin_graph, param_tree, other_states = brainstate.graph.treefy_split(
    lin, brainstate.ParamState, ...,
)
print('treefy_split param paths:', list(param_tree.to_flat().keys()))

# Fetch only the parameter tree (perfect for gradient updates)
params_only = brainstate.graph.treefy_states(lin, brainstate.ParamState)
print('treefy_states param paths:', list(params_only.to_flat().keys()))
```

```text
treefy_split param paths: [('bias',), ('weight',)]
treefy_states param paths: [('bias',), ('weight',)]
```


```python
# Example: compute gradients w.r.t. the ParamState tree using jax.value_and_grad
def mse_loss(params, others, x):
    lin_recovered = brainstate.graph.treefy_merge(lin_graph, params, others)
    pred = lin_recovered(x)
    target = 2.0 * x + 1.0
    return jnp.mean((pred - target) ** 2)

loss_grad = jax.value_and_grad(mse_loss)

loss_value, grads = loss_grad(param_tree, other_states, jnp.array([[0.0], [1.0]]))
print('loss:', float(loss_value))
for path, g in grads.items():
    print('grad', path, g)
```

```text
loss: 2.5
grad bias TreefyState(
  type=<class 'brainstate.ParamState'>,
  value=Array([-3.], dtype=float32),
  tag=None
)
grad weight TreefyState(
  type=<class 'brainstate.ParamState'>,
  value=Array([[-2.]], dtype=float32),
  tag=None
)
```


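`treefy_states` drops directly into optimisation and checkpointing pipelines:
you obtain a PyTree keyed by parameter paths without carrying a `GraphDef`. The
gradient example above uses `treefy_split` instead because its loss function
rebuilds the module with `treefy_merge`, which needs the `GraphDef`.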
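Once the parameters form a PyTree, an optimiser step is a `jax.tree_util.tree_map` over matching leaves. The sketch below is plain JAX: a dict stands in for the `treefy_states` result, and the learning rate of 0.1 is an arbitrary assumption. Its loss and gradients reproduce the values printed by the `TinyLinear` example above (loss 2.5, gradient -3 for the bias and -2 for the weight).

```python
import jax
import jax.numpy as jnp

# Plain-dict parameters standing in for TinyLinear's ParamState values.
dict_params = {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])}


def dict_mse_loss(params, x):
    pred = x @ params["weight"] + params["bias"]
    target = 2.0 * x + 1.0
    return jnp.mean((pred - target) ** 2)


x_batch = jnp.array([[0.0], [1.0]])
loss_val, dict_grads = jax.value_and_grad(dict_mse_loss)(dict_params, x_batch)

# One SGD step applied leaf by leaf; the gradient tree has the same structure as the params.
lr = 0.1
updated_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, dict_params, dict_grads)
```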

## Static arguments still apply

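Static-argument handling mirrors `jax.jit`: each distinct value of a static
argument is traced and compiled separately, while repeated values reuse the
cached executable. The example below specialises the compiled program by
polynomial degree, so `p1` and `p2` share one executable and `p3` needs a new
one.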

```python
@brainstate.transform.jit(static_argnums=1)
def polynomial_series(x: jax.Array, degree: int) -> jax.Array:
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)


p1 = polynomial_series(jnp.array([1.0, 2.0]), 3)
p2 = polynomial_series(jnp.array([1.0, 2.0]), 3)
p3 = polynomial_series(jnp.array([1.0, 2.0]), 4)
print(p1, p2, p3)
```

```text
[ 6. 34.] [ 6. 34.] [10. 98.]
```
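The caching behaviour can be observed with a Python side effect that only runs while JAX traces. The sketch uses plain `jax.jit` with `static_argnums=1`; `polynomial_series_plain` and the `traces` list are hypothetical names for illustration. Since BrainState mirrors `jax.jit` here we would expect the same pattern, though the sketch itself only exercises JAX.

```python
import jax
import jax.numpy as jnp

traces = []  # records the static degree each time JAX traces the function


def polynomial_series_plain(x, degree):
    traces.append(degree)  # Python side effect: runs only during tracing
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)


poly_jit = jax.jit(polynomial_series_plain, static_argnums=1)
x = jnp.array([1.0, 2.0])
q1 = poly_jit(x, 3)  # first trace for degree=3
q2 = poly_jit(x, 3)  # cache hit: same static value and input shape
q3 = poly_jit(x, 4)  # new static value: triggers a fresh trace
print(traces)  # [3, 4]
```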


## Which API should you choose?

| Scenario | `brainstate.transform.jit` | `jax.jit` |
| -------- | -------------------------- | --------- |
| Stateful BrainState modules | ✅ Zero boilerplate | ⚠️ Requires `treefy_split` and manual state threading |
| Pure stateless functions | ✅ Works (with helper methods) | ✅ Often the leanest choice |
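| Need `compile()` / `clear_cache()` | ✅ Built-in methods | ⚠️ AOT via `.lower(...).compile()`; `jax.clear_caches()` clears caches globally |
| Custom sharding / device placement | ✅ Mirrors `jax.jit`'s arguments (e.g. `in_shardings`, `out_shardings`) | ✅ |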

`treefy_split` is the workhorse when you need a `GraphDef` for reconstruction or
JAX interop. `treefy_states` is the light option for extracting parameter
PyTrees, for example to inspect or update parameter values or to save a
checkpoint.

## Summary

- `brainstate.transform.jit` tracks BrainState `State` objects automatically and
  returns a `JittedFunction` with extra controls.
- `jax.jit` still works, but you must explicitly split and merge module state.
- `graph.treefy_split` produces `(graph_def, state_tree1, state_tree2, …)` for
  reconstruction; `graph.treefy_states` returns just the requested state trees.
- Choose the interface that matches your workflow: use BrainState JIT for
  module-centric code, drop down to JAX primitives when integrating with other
  systems.