# BrainState Typing System


This notebook introduces the type utilities in `brainstate.typing`.
You will learn how to annotate arrays, PyTrees, random seeds, and helper
structures so that static checkers and collaborators can understand your
code more easily.

Topics covered:

- Size/shape/axis aliases used in array APIs.
- `Array` / `ArrayLike` for expressing tensor expectations.
- `PyTree` annotations and path filters for tree utilities.
- Data type helpers (`DType`, `DTypeLike`, `SupportsDType`).
- Random key, sentinel, and filter helper types.


In [1]:
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

from brainstate.typing import (
    Array,
    ArrayLike,
    Axes,
    DType,
    DTypeLike,
    Filter,
    Key,
    Missing,
    PathParts,
    Predicate,
    PyTree,
    SeedOrKey,
    Shape,
    Size,
    SupportsDType,
)


## Shapes, sizes, and axes

`Size`, `Shape`, and `Axes` help you document functions that expect
specific tensor dimensions. They are thin aliases over Python integers and
integer sequences, but annotating with them still communicates intent to
readers and tooling.


In [2]:
def normalise_batch(batch: ArrayLike, shape: Shape, along: Axes = 0) -> jax.Array:
    """Reshape `batch` then standardise along the given axes."""
    array = jnp.asarray(batch).reshape(tuple(shape))
    mean = jnp.mean(array, axis=along, keepdims=True)
    std = jnp.maximum(jnp.std(array, axis=along, keepdims=True), 1e-6)
    return (array - mean) / std

example = normalise_batch(jnp.arange(12.0), shape=(3, 4), along=0)
example


Array([[-1.2247448, -1.2247448, -1.2247448, -1.2247448],
       [ 0.       ,  0.       ,  0.       ,  0.       ],
       [ 1.2247448,  1.2247448,  1.2247448,  1.2247448]], dtype=float32)
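An `Axes` value may be a single integer or a sequence, so a function that accepts one usually canonicalises it before doing any work. The sketch below does this using only the standard library. `AxesSpec` and `canonicalize_axes` are hypothetical names. The int-or-sequence form of the alias is an assumption based on the `along=0` default above, not a copy of `brainstate.typing`.

```python
from collections.abc import Sequence
from typing import Union

# Assumed local stand-in for an Axes-style alias: one axis or a sequence of axes.
AxesSpec = Union[int, Sequence[int]]


def canonicalize_axes(axes: AxesSpec, ndim: int) -> tuple[int, ...]:
    """Return sorted, non-negative axes for an `ndim`-dimensional array."""
    raw = (axes,) if isinstance(axes, int) else tuple(axes)
    seen: set[int] = set()
    for axis in raw:
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for a {ndim}-D array")
        canonical = axis % ndim  # maps -1 to ndim - 1, -2 to ndim - 2, and so on
        if canonical in seen:
            raise ValueError(f"duplicate axis {axis}")
        seen.add(canonical)
    return tuple(sorted(seen))


print(canonicalize_axes(0, ndim=2))        # (0,)
print(canonicalize_axes([-1, 0], ndim=3))  # (0, 2)
```

Raising on repeated axes matches NumPy, whose reductions also reject duplicates. A more lenient helper could silently de-duplicate instead.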

## Array annotations

Use `Array[...]` to describe shape expectations and `ArrayLike` when a
function accepts anything convertible to a JAX array. These annotations are
informative for readers and static type checkers alike.


In [3]:
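Matrix = Array["rows, cols"]
Weights = Array["cols, features"]
Bias = Array["features"]

def affine_transform(x: Matrix, weight: Weights, bias: Bias) -> Array["rows, features"]:
    # (rows, cols) @ (cols, features) -> (rows, features); bias broadcasts over rows.
    return x @ weight + bias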

x = jnp.ones((2, 3))
w = jnp.arange(6.0).reshape(3, 2)
b = jnp.array([0.5, -0.5])
affine_transform(x, w, b)


Array([[6.5, 8.5],
       [6.5, 8.5]], dtype=float32)
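The shape strings in `Array[...]` follow a naming convention: a dimension name such as `cols` should have the same size wherever it appears. The standard-library sketch below checks that convention by hand. `bind_dims` is a hypothetical helper, not part of `brainstate.typing`, and it operates on plain shape tuples rather than on arrays.

```python
def bind_dims(spec: str, shape: tuple[int, ...], sizes: dict[str, int]) -> dict[str, int]:
    """Match a spec such as "rows, cols" against `shape`, recording sizes in `sizes`."""
    names = [name.strip() for name in spec.split(",") if name.strip()]
    if len(names) != len(shape):
        raise ValueError(f"spec {spec!r} names {len(names)} dims, but shape is {shape}")
    for name, size in zip(names, shape):
        # The first occurrence of a name fixes its size; later ones must agree.
        expected = sizes.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dimension {name!r}: expected {expected}, got {size}")
    return sizes


# Shapes from the affine_transform example: x (2, 3), weight (3, 2), bias (2,).
sizes: dict[str, int] = {}
bind_dims("rows, cols", (2, 3), sizes)
bind_dims("cols, features", (3, 2), sizes)
bind_dims("features", (2,), sizes)
print(sizes)  # {'rows': 2, 'cols': 3, 'features': 2}

try:
    bind_dims("cols", (2,), sizes)
except ValueError as err:
    print(err)  # dimension 'cols': expected 3, got 2
```

A check like this also catches annotations that disagree with the data. For example, a length-2 bias labelled `cols` conflicts with `cols = 3`.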

You can still accept flexible data by annotating parameters as `ArrayLike`.
Converting with `jnp.asarray` inside the function keeps the signature
expressive yet ergonomic.


In [4]:
def sum_energy(signal: ArrayLike) -> float:
    arr = jnp.asarray(signal)
    return float(jnp.sum(arr ** 2))

print(sum_energy([1, 2, 3]))
print(sum_energy(np.float32(1.5)))


14.0
2.25


> `ArrayLike` also covers `brainunit.Quantity` objects, so unit-aware
> tensors can pass through the same APIs without losing type information.


## Annotating PyTrees

`PyTree` acts like `typing.Any`, but it documents the expected leaf type
(and optionally the structure). That improves readability when writing
utilities that operate on nested containers.


In [5]:
def tree_l2_norm(tree: PyTree[jax.Array]) -> float:
    """Global L2 norm: square root of the summed squares of every leaf."""
    leaves = jax.tree_util.tree_leaves(tree)
    total = sum(float(jnp.sum(jnp.square(jnp.asarray(leaf)))) for leaf in leaves)
    return total ** 0.5

nested = {"encoder": jnp.ones((2, 2)), "decoder": [jnp.arange(3.0)]}
tree_l2_norm(nested)


3.0
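Flattening a tree by hand shows which values `PyTree[jax.Array]` refers to. This pure-Python sketch imitates how `jax.tree_util` treats the built-in containers: dict values are visited in sorted-key order, list and tuple items in sequence order, and `None` counts as an empty subtree. `flatten_leaves` is a hypothetical name. The sketch ignores registered custom nodes and the other node types JAX supports.

```python
from typing import Any


def flatten_leaves(tree: Any) -> list[Any]:
    """Collect leaves in the order jax.tree_util uses for dicts, lists, and tuples."""
    if tree is None:
        # JAX treats None as a subtree with no leaves.
        return []
    if isinstance(tree, dict):
        # Dict entries are visited in sorted-key order, not insertion order.
        return [leaf for key in sorted(tree) for leaf in flatten_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in flatten_leaves(child)]
    # Anything else is a leaf.
    return [tree]


plain_tree = {"encoder": [1.0, 1.0, 1.0, 1.0], "decoder": [[0.0, 1.0, 2.0]], "unused": None}
print(flatten_leaves(plain_tree))  # [0.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0]
```

For real trees, prefer `jax.tree_util.tree_leaves`, which also handles registered custom nodes.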

### Working with paths and filters

`PathParts`, `Predicate`, and `Filter` describe how to select parts of a
PyTree. The snippet below collects leaves whose path ends with `"weight"`.


In [6]:
def walk(tree: Any, predicate: Predicate, path: PathParts = ()) -> list[tuple[PathParts, Any]]:
    """Depth-first search returning every (path, node) pair accepted by `predicate`."""
    matches: list[tuple[PathParts, Any]] = []
    if predicate(path, tree):
        matches.append((path, tree))
    # Recurse into dict values (keyed by name) and list/tuple items (keyed by index).
    if isinstance(tree, dict):
        for key, value in tree.items():
            matches.extend(walk(value, predicate, path + (key,)))
    elif isinstance(tree, (list, tuple)):
        for idx, value in enumerate(tree):
            matches.extend(walk(value, predicate, path + (idx,)))
    return matches

model = {
    "dense1": {"weight": jnp.ones((3, 3)), "bias": jnp.zeros(3)},
    "dense2": {"weight": jnp.eye(3), "bias": jnp.ones(3)},
}

def weight_filter(path: PathParts, value: Any) -> bool:
    # Return a real bool: the root path () should give False, not ().
    return len(path) > 0 and path[-1] == "weight"

for found_path, leaf in walk(model, weight_filter):
    print(found_path, leaf.shape)


('dense1', 'weight') (3, 3)
('dense2', 'weight') (3, 3)
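The snippet above uses a hand-written predicate. In libraries of this style, a `Filter` is usually a more general selection spec that gets converted into such a predicate. The sketch below shows one plausible conversion:

- a string matches the last path component;
- a type matches the leaf's class;
- a tuple means "any of these";
- a callable is used as-is.

`to_predicate`, `PathTuple`, `PredicateFn`, and these rules are assumptions for illustration. `brainstate`'s own `Filter` may accept other forms.

```python
from typing import Any, Callable, Union

PathTuple = tuple[Any, ...]                 # local stand-in for a path alias
PredicateFn = Callable[[PathTuple, Any], bool]
# Hypothetical filter spec: a key name, a leaf type, a predicate, or a tuple of these.
FilterSpec = Union[str, type, PredicateFn, tuple]


def to_predicate(spec: FilterSpec) -> PredicateFn:
    """Turn a simple filter spec into a (path, value) -> bool predicate."""
    if isinstance(spec, str):
        return lambda path, value: bool(path) and path[-1] == spec
    if isinstance(spec, type):
        return lambda path, value: isinstance(value, spec)
    if isinstance(spec, tuple):
        preds = [to_predicate(s) for s in spec]
        return lambda path, value: any(p(path, value) for p in preds)
    if callable(spec):
        return spec
    raise TypeError(f"unsupported filter spec: {spec!r}")


is_weight_or_float = to_predicate(("weight", float))
print(is_weight_or_float(("dense1", "weight"), [1.0]))  # True
print(is_weight_or_float(("scale",), 2.0))              # True
print(is_weight_or_float(("dense1", "bias"), [0.0]))    # False
```

The `type` check must come before the `callable` check, because classes are themselves callable.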


## Data type helpers

`DType` names a concrete NumPy dtype, while `DTypeLike` accepts any object
that can be coerced into one. Implementing the `SupportsDType` protocol
lets custom containers participate too.


In [7]:
class TensorView(SupportsDType):
    """Minimal wrapper that exposes the dtype of the array it holds."""

    def __init__(self, array: jax.Array):
        self._array = array

    @property
    def dtype(self) -> DType:
        return self._array.dtype

def make_zeros(shape: Shape, dtype: DTypeLike) -> jax.Array:
    # `dtype` may be a dtype, a scalar type such as np.float32, or a SupportsDType object.
    return jnp.zeros(shape, dtype=dtype)

print(make_zeros((2, 2), np.float32))
print(make_zeros((1, 3), TensorView(jnp.ones(3))))


[[0. 0.]
 [0. 0.]]
[[0. 0. 0.]]
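`np.dtype` accepts any object that has a `.dtype` attribute. That is the mechanism that lets a `SupportsDType` container stand in for a dtype. The NumPy-only sketch below illustrates it. `ColumnView` and `resolve_dtype` are hypothetical names. We assume, without having inspected brainstate's internals, that JAX resolves such objects the same way, as the `TensorView` output above suggests.

```python
from typing import Any

import numpy as np


class ColumnView:
    """Hypothetical container that exposes its element type through `.dtype`."""

    def __init__(self, data: np.ndarray):
        self._data = data

    @property
    def dtype(self) -> np.dtype:
        return self._data.dtype


def resolve_dtype(spec: Any) -> np.dtype:
    # np.dtype accepts strings, scalar types, dtype instances, and objects
    # whose `.dtype` attribute is itself a dtype.
    return np.dtype(spec)


print(resolve_dtype("float32"))                                  # float32
print(resolve_dtype(np.int16))                                   # int16
print(resolve_dtype(ColumnView(np.zeros(3, dtype=np.float64))))  # float64
```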


## Random seeds and keys

`SeedOrKey` lists the accepted random sources: an integer seed, a JAX PRNG
key, or a NumPy array holding raw key data. Normalising the input inside
your function keeps call sites ergonomic.


In [8]:
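def sample_normal(key: SeedOrKey, shape: Shape) -> jax.Array:
    if isinstance(key, (int, np.integer)):
        # Integer seed: build a fresh key from it.
        key = jax.random.PRNGKey(int(key))
    elif isinstance(key, np.ndarray):
        # Raw key data held in NumPy: move it to a JAX uint32 array.
        key = jnp.asarray(key, dtype=jnp.uint32)
    # Anything else is assumed to already be a JAX key.
    return jax.random.normal(key, shape)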

print(sample_normal(0, (2,)))
print(sample_normal(jax.random.PRNGKey(1), (2,)))


[1.6226422 2.0252647]
[-0.15443718  0.08470728]


## Keys and sentinels

`Key` is a protocol for path components, such as the string and integer
keys that appear in the paths above. `Missing` is a sentinel object you can
use when `None` is a meaningful value.


In [9]:
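_MISSING = Missing()

def resolve_config(name: Key, *, output_dir: str | None | Missing = _MISSING) -> str | None:
    if output_dir is _MISSING:
        # Argument omitted: fall back to a default location.
        return f'/tmp/{name}'
    # An explicit None means "no output directory", so pass it through unchanged.
    return output_dir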

print(resolve_config('experiment-A'))
print(resolve_config('experiment-B', output_dir=None))


/tmp/experiment-A
None
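The same sentinel technique applies to any API where `None` is a legitimate value. In the sketch below, a lookup helper must distinguish "no default supplied" from an explicit `default=None`. `_Unset` is a local stand-in for `Missing` so the example runs without brainstate, and `lookup` is a hypothetical helper.

```python
from typing import Any


class _Unset:
    """Local sentinel type, standing in for brainstate's `Missing` in this sketch."""

    def __repr__(self) -> str:
        return "<unset>"


_UNSET = _Unset()


def lookup(table: dict[str, Any], key: str, default: Any = _UNSET) -> Any:
    """Return table[key]; if absent, return `default`, or raise KeyError if none was given."""
    if key in table:
        return table[key]
    if default is _UNSET:
        raise KeyError(key)
    # Because the sentinel marks "no default", default=None is a valid fallback.
    return default


settings = {"output_dir": None}
print(lookup(settings, "output_dir"))      # None (the stored value)
print(lookup(settings, "seed", default=0)) # 0 (the fallback)
```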


## Summary

BrainState's typing helpers build on standard Python typing to describe arrays,
PyTrees, dtypes, random keys, and structural filters. Applying them consistently
makes complex scientific code easier to navigate and verify.
