# Parallelisation

`brainstate.transform.pmap` mirrors `jax.pmap` while keeping BrainState `State`
objects consistent across devices. This notebook explains how to configure the
API, how random states behave under device parallelism, and how `pmap` reuses
the same `StatefulMapping` infrastructure as `vmap`.

In [1]:
import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import pmap
from brainstate.util.filter import OfType

## Configuring devices

For CPU-only demonstrations we can provision multiple devices per host by
setting `jax_num_cpu_devices` before JAX initialises its backend, that is,
before the first computation or device query. Importing JAX beforehand is fine.
(If you are running on GPU or TPU you can skip this cell; the environment will
report the hardware devices it already sees.)

In [2]:
jax.config.update('jax_num_cpu_devices', 8)
print('local device count:', jax.local_device_count())

local device count: 8
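
On JAX versions that do not provide the `jax_num_cpu_devices` option, the XLA flag `--xla_force_host_platform_device_count` has the same effect. The sketch below assumes a fresh Python process in which JAX has not been imported yet. The device count it reports depends on the platform, so no output is shown.

```python
import os

# Append the flag to any XLA_FLAGS already present. This has to happen
# before JAX is imported, otherwise the CPU backend ignores it.
existing_flags = os.environ.get('XLA_FLAGS', '')
os.environ['XLA_FLAGS'] = (
    existing_flags + ' --xla_force_host_platform_device_count=8'
).strip()

import jax

print('local device count:', jax.local_device_count())
```

The flag only affects the host (CPU) platform; on GPU or TPU the reported devices are the accelerators JAX finds.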


## 1. Core arguments of `pmap`

`pmap` accepts the same signature as `jax.pmap` plus BrainState-specific
keywords (`state_in_axes`, `state_out_axes`, `unexpected_out_state_mapping`).
Use `axis_name` to enable collectives and `devices` / `backend` when you want to
pin the computation to specific hardware.

In [3]:
class Affine(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.ones((size,)))

    def __call__(self, delta):
        self.weight.value = self.weight.value + delta
        return self.weight.value


model = Affine(size=jax.local_device_count())
axis_name = 'devices'

pmapped_update = pmap(
    model,
    axis_name=axis_name,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)

# Each device receives a different delta vector
per_device_delta = jnp.arange(jax.local_device_count() * 4.).reshape(jax.local_device_count(), 4)
updated = pmapped_update(per_device_delta)
print('updated shape:', updated.shape)
print('final weights:', model.weight.value)

updated shape: (8, 4)
final weights: [[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]
 [17. 18. 19. 20.]
 [21. 22. 23. 24.]
 [25. 26. 27. 28.]
 [29. 30. 31. 32.]]
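
The `axis_name` argument only matters once the mapped function calls a collective. As a minimal sketch, the snippet below uses plain `jax.pmap` and `jax.lax.psum` rather than the BrainState wrapper, on the assumption that `brainstate.transform.pmap` forwards `axis_name` unchanged. The function `normalise` and its input are illustrative names, not part of either library.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def normalise(x):
    # psum adds x over every device that participates in the 'devices' axis,
    # so each device ends up holding the same global total.
    total = jax.lax.psum(x, axis_name='devices')
    return x / total


# One positive scalar per device: 1, 2, ..., n_dev
x = jnp.arange(1.0, n_dev + 1.0)
shares = jax.pmap(normalise, axis_name='devices')(x)
print('shares:', shares)
print('sum of shares:', shares.sum())
```

Each device divides its own value by the cross-device total, so the shares sum to one regardless of how many devices are available.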


### axis_size and devices

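As in `jax.pmap`, `axis_size` is normally inferred from the mapped axis of the
inputs, so you only need to pass it when there is nothing to infer it from.
`devices` lets you provide an explicit list of JAX devices to map over, which is
useful when you want a smaller logical mesh than the number of physical devices.
The cell below maps over the first two devices only.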

In [4]:
logical_devices = jax.devices()[:2]
model = Affine(size=len(logical_devices))

pairwise_update = pmap(
    model,
    axis_name='pair',
    in_axes=0,
    out_axes=0,
    devices=logical_devices,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)

deltas = jnp.stack([jnp.ones((4,)), -jnp.ones((4,))], axis=0)
pairwise_update(deltas)
print('weights after pairwise update:', model.weight.value)

weights after pairwise update: [[2. 2. 2. 2.]
 [0. 0. 0. 0.]]


### Handling static arguments and donation

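Most `jax.pmap` flags pass straight through: `static_broadcasted_argnums` treats
an argument as a compile-time constant that is broadcast to every device, while
`donate_argnums` can reduce memory usage by letting the compiler reuse the input
buffers. For an ordinary value that should simply be shared, `in_axes=None` is
enough, as with `scale` in the cell below.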

In [5]:
@pmap(axis_name=axis_name, in_axes=(0, None), out_axes=0)
def add_with_scale(delta, scale):
    return delta + scale

add_with_scale(jnp.arange(jax.local_device_count()), 0.5)

Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5], dtype=float32, weak_type=True)
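
A static argument is needed when a value drives Python-level control flow and therefore has to be concrete at trace time. The sketch below shows this with plain `jax.pmap`, assuming the BrainState wrapper forwards `static_broadcasted_argnums` as described above. The function `repeated_double` is a made-up example, and donation is left out because its effect depends on the backend.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def repeated_double(x, repeats):
    # `repeats` controls a Python loop, so it must be a concrete int while
    # the function is traced. Marking it static guarantees that.
    for _ in range(repeats):
        x = x * 2.0
    return x


x = jnp.arange(n_dev, dtype=jnp.float32)
doubled = jax.pmap(repeated_double, static_broadcasted_argnums=(1,))(x, 3)
print(doubled)
```

Calling the mapped function with a different `repeats` value triggers a new compilation, because static arguments are part of the compilation cache key.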

## 2. Random-number semantics

As with `vmap`, BrainState splits `RandomState` keys automatically so that each
device sees a different stream. This makes stochastic simulations reproducible
without manual key management.

In [6]:
rand_state = brainstate.random.RandomState(0)

@pmap(
    axis_name='devices',
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.random.RandomState)},
    state_out_axes={0: OfType(brainstate.random.RandomState)},
)
def sample_normal(scale):
    return brainstate.random.normal(0.0, scale)

per_device_scales = jnp.linspace(1.0, 2.0, jax.local_device_count())
sample_normal(per_device_scales)

Array([ 1.23822   , -0.2782504 , -1.9162552 , -0.21000428,  0.41403982,
       -0.7870412 , -1.6281602 , -1.1573448 ], dtype=float32)
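
For comparison, the following sketch shows roughly the manual key management that automatic splitting replaces. It uses only `jax.random` and `jax.pmap` and is not meant to reproduce BrainState's internal implementation or the numbers printed above.

```python
import jax

n_dev = jax.local_device_count()

# Split one base key into one independent key per device.
base_key = jax.random.PRNGKey(0)
device_keys = jax.random.split(base_key, n_dev)


def draw(key):
    # Each device draws a single standard-normal sample from its own key.
    return jax.random.normal(key, ())


samples = jax.pmap(draw)(device_keys)
print(samples)
```

With plain JAX the caller also has to carry the keys forward between calls, which is the bookkeeping a `RandomState` handles as state.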

If you need identical keys on all devices, use `jax.random` explicitly and
broadcast the key to every device by mapping it with `in_axes=None`.

In [7]:
shared_key = jax.random.PRNGKey(0)

@pmap(axis_name='devices', in_axes=(None, 0), out_axes=0)
def sample_shared(key, scale):
    return jax.random.normal(key, ()) * scale

sample_shared(shared_key, per_device_scales)

Array([1.6226422, 1.8544483, 2.0862544, 2.3180604, 2.5498662, 2.7816722,
       3.0134785, 3.2452843], dtype=float32)

## 3. Relationship to `StatefulMapping`

`pmap` creates a `StatefulMapping` under the hood, just like `vmap`. The wrapper
analyzes state usage, constructs IR for the batched computation, and writes the
updated state values back to the `State` objects after every parallel execution.

In [8]:
parallel_mapping = pmap(
    model,
    axis_name='devices',
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)

print(type(parallel_mapping))
print('origin fun:', parallel_mapping.origin_fun)
print('state_in_axes:', parallel_mapping.state_in_axes)

<class 'brainstate.transform.StatefulMapping'>
origin fun: Affine(
  weight=ParamState(
    value=ShapedArray(float32[2,4])
  )
)
state_in_axes: {0: OfType(<class 'brainstate.ParamState'>)}
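
Conceptually, a state-aware mapping turns the stateful call into a pure function that takes state values as inputs and returns the new values as outputs, then writes those values back. The sketch below illustrates that idea with plain `jax.pmap`. It is a simplified analogy under that assumption, not the actual `StatefulMapping` code, and `pure_update` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Stand-in for a ParamState value: one scalar weight per device.
weight = jnp.ones((n_dev,))
deltas = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)


def pure_update(weight, delta):
    # Pure version of the stateful update: state comes in as an argument
    # and leaves as an extra return value.
    new_weight = weight + delta
    return new_weight, new_weight  # (function output, new state value)


out, new_weight = jax.pmap(pure_update, in_axes=(0, 0), out_axes=(0, 0))(weight, deltas)

# The write-back step that a state-aware wrapper performs on the caller's behalf.
weight = new_weight
print('output shape:', out.shape)
print('state shape after write-back:', weight.shape)
```

The two `in_axes` entries play the roles of `in_axes` and `state_in_axes` in the wrapper, and the two `out_axes` entries correspond to `out_axes` and `state_out_axes`.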


Advanced users can construct `StatefulMapping` directly, selecting their own
mapping primitive. Below we recreate the earlier example but pass an explicit
`jax.pmap` with custom donation settings.

In [9]:
from brainstate.transform import StatefulMapping

model = Affine(size=jax.local_device_count())

custom_pmap = StatefulMapping(
    model,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
    axis_name='devices',
    # Positional arguments are unpacked before keywords; behaviour is unchanged.
    mapping_fn=lambda fun, *a, **kw: jax.pmap(fun, *a, donate_argnums=(0,), **kw),
)

custom_pmap(jnp.ones((jax.local_device_count(), 4)))

model.weight.value

Array([[2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.]], dtype=float32)

## Summary

- `brainstate.transform.pmap` supports the full `jax.pmap` interface and adds
  state-specific controls via `state_in_axes`, `state_out_axes`, and
  `unexpected_out_state_mapping`.
- Random states are split automatically so each device receives its own key.
  Use `jax.random` with `in_axes=None` to broadcast a shared key instead.
- Like `vmap`, `pmap` returns a `StatefulMapping` that identifies state axis
  mappings and compiles the computation into a state-aware IR.