Skip to content

Refactor graph ops, update JAX/Python requirements, improve tests#138

Merged
chaoming0625 merged 10 commits intomainfrom
update
Mar 8, 2026
Merged

Refactor graph ops, update JAX/Python requirements, improve tests#138
chaoming0625 merged 10 commits intomainfrom
update

Conversation

@chaoming0625
Copy link
Member

@chaoming0625 chaoming0625 commented Mar 8, 2026

Summary by Sourcery

Refine the graph operation and transformation infrastructure, modernize typing and JAX version handling, improve error messages and docstrings, and significantly expand and stabilize the associated tests.

Bug Fixes:

  • Fix misuse of internal RefMap storage by accessing its items via the public mapping interface when scanning for states.
  • Ensure debug callbacks used in error_if are ordered to avoid non‑deterministic behavior.
  • Correct filter validation and error paths in graph splitting utilities to provide more precise exceptions.

Enhancements:

  • Simplify and clarify graph operation internals with stricter type hints, reduced duplication, and more focused docstrings across graph utilities and node helpers.
  • Modernize JAX integration by consolidating make_jaxpr handling, updating compatible imports (including mapped_aval), and refining IR visualization type usage.
  • Tighten graph context management for split/merge operations using thread‑local stacks and clearer context helpers, and adjust graph/tree conversion helpers for safer RefMap usage and alias checking.

Build:

  • Raise the minimum supported Python version to 3.11 and require JAX >= 0.6.0 across core, extras, and documentation tooling.
  • Update ReadTheDocs configuration to build documentation with Python 3.13.

Documentation:

  • Streamline and update docstrings throughout graph operations, node helpers, conversion utilities, and compatibility helpers to better explain behavior while removing obsolete or overly verbose sections.

Tests:

  • Unskip and substantially rewrite graph operation tests to rely on assertions instead of prints, cover iteration, graph manipulation, helper utilities, threading, and complex aliasing scenarios.
  • Adjust compatibility and module‑structure tests to match the updated compatible import surface (e.g., removal of extend_axis_env_nd).

@sourcery-ai
Copy link
Contributor

sourcery-ai bot commented Mar 8, 2026

Reviewer's Guide

Refactors the graph operation and conversion utilities for better type-safety and JAX compatibility, simplifies make_jaxpr by dropping the custom implementation in favor of jax.make_jaxpr, tightens Python/JAX version requirements, and substantially rewrites and unskips the graph operation test suite to validate the new behavior and threading support.

Sequence diagram for graph flattening and unflattening with shared references

sequenceDiagram
    actor User
    participant SplitContext
    participant RefIndex as RefMap
    participant GraphOp as flatten
    participant Impl as _get_node_impl
    participant GFlat as _graph_flatten
    participant MergeContext
    participant IndexRef as index_ref(dict)
    participant GUnflat as _graph_unflatten

    User->>SplitContext: split_context()
    activate SplitContext
    SplitContext->>RefIndex: create RefMap
    SplitContext-->>User: (SplitContext, RefIndex)
    deactivate SplitContext

    User->>GraphOp: flatten(root_node, ref_index=RefIndex)
    activate GraphOp
    GraphOp->>GFlat: _graph_flatten(path=(), ref_index, flatted_state_mapping, node, treefy_state)
    activate GFlat
    GFlat->>Impl: _get_node_impl(node)
    Impl-->>GFlat: GraphNodeImpl or PyTreeNodeImpl

    alt node already in RefIndex
        GFlat-->>GraphOp: NodeRef
    else new node
        GFlat->>RefIndex: register node -> index
        GFlat->>Impl: flatten(node)
        Impl-->>GFlat: values, metadata
        GFlat->>GFlat: recurse into subgraphs and leaves
        GFlat-->>GraphOp: NodeDef(root)
    end

    GraphOp-->>User: GraphDef, NestedDict.from_flat(flatted_state_mapping)
    deactivate GraphOp

    User->>MergeContext: merge_context()
    activate MergeContext
    MergeContext-->>User: (MergeContext, IndexRef)
    deactivate MergeContext

    User->>GUnflat: _graph_unflatten(GraphDef, state_mapping, IndexRef, index_ref_cache)
    activate GUnflat
    alt graph_def is NodeRef
        GUnflat-->>User: existing node from IndexRef
    else graph_def is NodeDef
        GUnflat->>Impl: get_node_impl_for_type(graph_def.type)
        Impl-->>GUnflat: GraphNodeImpl or PyTreeNodeImpl
        alt GraphNodeImpl
            GUnflat->>Impl: create_empty(metadata) or reuse from index_ref_cache
            Impl-->>GUnflat: node
            GUnflat->>IndexRef: register node at graph_def.index
        end
        GUnflat->>GUnflat: _get_children(graph_def, state_mapping,...)
        alt GraphNodeImpl
            GUnflat->>Impl: init(node, children.items())
        else PyTreeNodeImpl
            GUnflat->>Impl: unflatten(children.items(), metadata)
        end
        GUnflat-->>User: reconstructed node
    end
    deactivate GUnflat
Loading

Architecture/flow diagram for graph_to_tree and tree_to_graph interaction with JAX pytrees

flowchart LR
    subgraph GraphSide[Graph-based objects]
        GNode[Graph nodes\nNode]
        GDef[GraphDef / NodeDef / NodeRef]
        GState[GraphStateMapping / NestedDict]
    end

    subgraph JAXSide[JAX pytree world]
        PTIn[Input pytree\nmay contain graph nodes]
        NodeStatesPT[NodeStates pytrees]
        PTOut[Reconstructed pytree]
    end

    subgraph SplitMergeCtx[Graph contexts]
        SC[SplitContext\nref_index: RefMap]
        MC[MergeContext\nindex_ref: dict]
    end

    PTIn -->|graph_to_tree| GT(GraphToTree)
    GT --> SC
    SC -->|treefy_split| GDef
    SC -->|treefy_split| GState
    GT --> NodeStatesPT

    NodeStatesPT -->|tree_to_graph| TG(TreeToGraph)
    TG --> MC
    MC -->|treefy_merge| GNode
    GNode --> PTOut
Loading

File-Level Changes

Change Details Files
Modernize typing, error handling, and documentation in graph operations while keeping behavior and APIs largely intact.
  • Simplified generic type parameters (Node→N), loosened hashable bounds, and clarified type aliases for nodes and indices.
  • Condensed or replaced large docstrings with short, focused descriptions and added explicit return type annotations and runtime type checks (raising TypeError/ValueError instead of using asserts).
  • Refined NodeImpl hierarchy and registration (GraphNodeImpl/PyTreeNodeImpl) while preserving behavior; adjusted helper functions like _get_node_impl, is_node_type, and iter* utilities.
  • Tightened state and node update logic in _graph_update_dynamic, _graph_pop, _get_children, and related helpers, improving validation of immutable vs mutable nodes and leaf types without changing external APIs.
  • Simplified graphdef, clone, flatten/unflatten, treefy_split/merge/states, and nodes/states utilities to be thinner wrappers with clearer semantics and more precise errors.
brainstate/graph/_operation.py
Improve graph node base implementation and registration helpers with leaner behavior-focused utilities.
  • Reduced Node and helper function docstrings to concise descriptions while keeping behavior the same.
  • Ensured Node.init_subclass consistently registers subclasses via register_graph_node_type.
  • Simplified _node_flatten, _node_set_key, _node_pop_key, _node_create_empty, and _node_clear to minimal implementations focused on attribute management and State/TreefyState handling.
brainstate/graph/_node.py
Strengthen graph context management for split/merge operations with clearer semantics and typing.
  • Documented GraphContext as a thread-local container for split/merge stacks and converted attributes to typed lists.
  • Clarified SplitContext.treefy_split and MergeContext.treefy_merge signatures and behavior, returning merged/unflattened nodes directly.
  • Simplified split_context and merge_context context managers, ensuring stacks are pushed/popped reliably and exposing the active context and index maps.
brainstate/graph/_context.py
Refine graph-to-tree conversion, alias checking, and NodeStates wrapper for JAX pytrees.
  • Relaxed KeyEntry type bound and updated broadcast_prefix and graph_to_tree/tree_to_graph signatures to use modern typing and clearer docstrings.
  • Reimplemented check_consistent_aliasing to accumulate node/prefix mismatches more readably and tightened its error messages.
  • Adjusted NodeStates factory methods and default split function signatures, documenting their use in tree conversions.
  • Avoided depending on RefMap internals in graph_to_tree by reconstructing a public mapping before calling states() and simplified tree_to_graph to just unflatten the transformed leaves.
brainstate/graph/_convert.py
Simplify make_jaxpr internals by dropping the custom tracer path and relying on jax.make_jaxpr, while updating JAX compatibility helpers and loop transforms.
  • In StatefulFunction.make_jaxpr, removed the version-dependent call path and custom _make_jaxpr implementation, always calling jax.make_jaxpr with static_argnums and axis_env.
  • Deleted the internal _make_jaxpr implementation and helper utilities no longer needed (e.g., _check_callable, _flatten_fun) from _make_jaxpr.py.
  • Extended _compatible_import to expose mapped_aval (with version-dependent import) and jaxpr-related aliases without extend_axis_env_nd, simplifying version handling.
  • Updated scan and checkpointed_scan to use get_aval and mapped_aval from the compatibility layer instead of jax.core.get_aval/mapped_aval directly.
brainstate/transform/_make_jaxpr.py
brainstate/_compatible_import.py
brainstate/transform/_loop_collect_return.py
Align JAX and Python version requirements and build configs with the new minimum versions.
  • Raised project requires-python to >=3.11 and removed Python 3.10 classifier from pyproject.toml; updated docs build to use Python 3.13.
  • Bumped all JAX-related dependencies in optional extras and requirements.txt to require jax>=0.6.0 (for cpu, cuda, tpu, testing).
pyproject.toml
requirements.txt
.readthedocs.yml
Update IR visualization code to use compatibility-layer JAX core types and drop no-longer-exported APIs from tests.
  • Replaced direct jax.core imports in _ir_visualize with imports from brainstate._compatible_import (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) and adjusted type annotations and isinstance checks accordingly.
  • Removed tests that referenced extend_axis_env_nd and all expectations that no longer hold, keeping the rest of the compatibility tests intact.
brainstate/transform/_ir_visualize.py
brainstate/_compatible_import_test.py
Refine error_if JIT behavior to ensure error callbacks are ordered under JIT.
  • Pass ordered=True to jax.debug.callback inside _err_jit_true_branch so side-effecting error callbacks preserve execution ordering.
brainstate/transform/_error_if.py
Major overhaul of graph operation tests to remove skips, reduce dependencies, and improve coverage of new behavior.
  • Removed the module-level pytest.skip and unused imports (braintools, brainpy) and reorganized tests into sections (iter tests, graph utilities, RefMap, helper functions, node registration, HashableMapping, NodeDef/NodeRef, graphdef/clone, nodes, Static, error handling, integration, threading).
  • Rewrote iterator tests to assert counts and hierarchy behavior instead of printing, and added tests for nested modules, shared-node traversal, and allowed_hierarchy semantics.
  • Updated graph operation tests to use jax.nn.relu, brainstate.nn.Linear, ParamState, ShortTermState, and pop_states behavior in ways aligned with the refactored code (e.g., no brainpy LIF).
  • Adjusted RefMap, helper, registration, HashableMapping, NodeDef/NodeRef, Static, and error-handling tests to validate the new APIs, error messages, and type checks, including removal of extend_axis_env_nd references and aligning expectations with new behavior.
  • Kept and updated threading tests to ensure treefy_split remains thread-safe under the new context and graph machinery.
brainstate/graph/_operation_test.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625
Copy link
Member Author

@sourcery-ai title

@sourcery-ai sourcery-ai bot changed the title Update Refactor graph ops, update JAX/Python requirements, improve tests Mar 8, 2026
Copy link
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 2 issues

Prompt for AI Agents
Please address the comments from this code review:

## Individual Comments

### Comment 1
<location path="brainstate/graph/_operation.py" line_range="217" />
<code_context>

-class HashableMapping(Mapping[HA, HB], Hashable):
-    def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
+class HashableMapping(Mapping[HA, HB]):
+    """An immutable, hashable mapping."""
+
</code_context>
<issue_to_address>
**suggestion:** Consider keeping HA/HB constrained to Hashable to reflect the true runtime requirements of HashableMapping.

Dropping the Hashable bound on HA/HB while still computing the hash via `hash(frozenset(self._mapping.items()))` means type checkers no longer enforce that keys/values are hashable, even though runtime still requires it and will fail if they are not. Restoring the Hashable constraint keeps the static types aligned with the actual requirements without changing behavior.

Suggested implementation:

```python
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence
from typing import Any, Generic, TypeVar

```

```python
import jax
import numpy as np


HA = TypeVar("HA", bound=Hashable)
HB = TypeVar("HB", bound=Hashable)


class HashableMapping(Mapping[HA, HB]):
    """An immutable, hashable mapping."""

```

1. If `HA` and `HB` are already defined elsewhere in this file, replace those existing `TypeVar` definitions with the new bounded versions instead of adding new ones.
2. Ensure the `HashableMapping` class definition in the file uses `Mapping[HA, HB]` (as shown) and that any previous inheritance from `Hashable` is removed or intentionally replaced, since the type-level hashability is now enforced via the bounds on `HA` and `HB`.
</issue_to_address>

### Comment 2
<location path="brainstate/graph/_convert.py" line_range="190-195" />
<code_context>
-        pass

-    find_states = states(index_ref._mapping)
+    # Build a dict mirroring RefMap's content via the public API, then extract
+    # State objects from it.  We must not access the private ._mapping attribute.
+    public_map = {id(k): (k, v) for k, v in index_ref.items()}
+    find_states = states(public_map)
     pytree_out = jax.tree.unflatten(treedef, leaves_out)
     return pytree_out, find_states
</code_context>
<issue_to_address>
**suggestion (performance):** The `find_states` computation is still unused and now adds extra work; consider removing it entirely.

This block builds `public_map` and calls `states`, but the resulting `find_states` is never used (same as the prior `states(index_ref._mapping)` call). Since it just adds extra traversal with no effect, consider removing it, or if it’s intended for validation, make that purpose explicit by using or propagating the result/error.

```suggestion
    # NOTE: `find_states` is currently unused by callers; avoid extra traversal
    # work by not computing it. Preserve the return signature for compatibility.
    find_states = None
    pytree_out = jax.tree.unflatten(treedef, leaves_out)
    return pytree_out, find_states
```
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.


class HashableMapping(Mapping[HA, HB], Hashable):
def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
class HashableMapping(Mapping[HA, HB]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider keeping HA/HB constrained to Hashable to reflect the true runtime requirements of HashableMapping.

Dropping the Hashable bound on HA/HB while still computing the hash via hash(frozenset(self._mapping.items())) means type checkers no longer enforce that keys/values are hashable, even though runtime still requires it and will fail if they are not. Restoring the Hashable constraint keeps the static types aligned with the actual requirements without changing behavior.

Suggested implementation:

from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence
from typing import Any, Generic, TypeVar
import jax
import numpy as np


HA = TypeVar("HA", bound=Hashable)
HB = TypeVar("HB", bound=Hashable)


class HashableMapping(Mapping[HA, HB]):
    """An immutable, hashable mapping."""
  1. If HA and HB are already defined elsewhere in this file, replace those existing TypeVar definitions with the new bounded versions instead of adding new ones.
  2. Ensure the HashableMapping class definition in the file uses Mapping[HA, HB] (as shown) and that any previous inheritance from Hashable is removed or intentionally replaced, since the type-level hashability is now enforced via the bounds on HA and HB.

Comment on lines +190 to 195
# Build a dict mirroring RefMap's content via the public API, then extract
# State objects from it. We must not access the private ._mapping attribute.
public_map = {id(k): (k, v) for k, v in index_ref.items()}
find_states = states(public_map)
pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out, find_states
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): The find_states computation is still unused and now adds extra work; consider removing it entirely.

This block builds public_map and calls states, but the resulting find_states is never used (same as the prior states(index_ref._mapping) call). Since it just adds extra traversal with no effect, consider removing it, or if it’s intended for validation, make that purpose explicit by using or propagating the result/error.

Suggested change
# Build a dict mirroring RefMap's content via the public API, then extract
# State objects from it. We must not access the private ._mapping attribute.
public_map = {id(k): (k, v) for k, v in index_ref.items()}
find_states = states(public_map)
pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out, find_states
# NOTE: `find_states` is currently unused by callers; avoid extra traversal
# work by not computing it. Preserve the return signature for compatibility.
find_states = None
pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out, find_states

@chaoming0625
Copy link
Member Author

fix #134

@chaoming0625 chaoming0625 merged commit dbbcf4b into main Mar 8, 2026
4 of 7 checks passed
@chaoming0625 chaoming0625 deleted the update branch March 8, 2026 17:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant