Refactor graph ops, update JAX/Python requirements, improve tests#138
Refactor graph ops, update JAX/Python requirements, improve tests#138chaoming0625 merged 10 commits intomainfrom
Conversation
… enhance error callback functionality
…rts and functions for clarity
… requirements.txt
…ntation build settings
Reviewer's GuideRefactors the graph operation and conversion utilities for better type-safety and JAX compatibility, simplifies make_jaxpr by dropping the custom implementation in favor of jax.make_jaxpr, tightens Python/JAX version requirements, and substantially rewrites and unskips the graph operation test suite to validate the new behavior and threading support. Sequence diagram for graph flattening and unflattening with shared referencessequenceDiagram
actor User
participant SplitContext
participant RefIndex as RefMap
participant GraphOp as flatten
participant Impl as _get_node_impl
participant GFlat as _graph_flatten
participant MergeContext
participant IndexRef as index_ref(dict)
participant GUnflat as _graph_unflatten
User->>SplitContext: split_context()
activate SplitContext
SplitContext->>RefIndex: create RefMap
SplitContext-->>User: (SplitContext, RefIndex)
deactivate SplitContext
User->>GraphOp: flatten(root_node, ref_index=RefIndex)
activate GraphOp
GraphOp->>GFlat: _graph_flatten(path=(), ref_index, flatted_state_mapping, node, treefy_state)
activate GFlat
GFlat->>Impl: _get_node_impl(node)
Impl-->>GFlat: GraphNodeImpl or PyTreeNodeImpl
alt node already in RefIndex
GFlat-->>GraphOp: NodeRef
else new node
GFlat->>RefIndex: register node -> index
GFlat->>Impl: flatten(node)
Impl-->>GFlat: values, metadata
GFlat->>GFlat: recurse into subgraphs and leaves
GFlat-->>GraphOp: NodeDef(root)
end
GraphOp-->>User: GraphDef, NestedDict.from_flat(flatted_state_mapping)
deactivate GraphOp
User->>MergeContext: merge_context()
activate MergeContext
MergeContext-->>User: (MergeContext, IndexRef)
deactivate MergeContext
User->>GUnflat: _graph_unflatten(GraphDef, state_mapping, IndexRef, index_ref_cache)
activate GUnflat
alt graph_def is NodeRef
GUnflat-->>User: existing node from IndexRef
else graph_def is NodeDef
GUnflat->>Impl: get_node_impl_for_type(graph_def.type)
Impl-->>GUnflat: GraphNodeImpl or PyTreeNodeImpl
alt GraphNodeImpl
GUnflat->>Impl: create_empty(metadata) or reuse from index_ref_cache
Impl-->>GUnflat: node
GUnflat->>IndexRef: register node at graph_def.index
end
GUnflat->>GUnflat: _get_children(graph_def, state_mapping,...)
alt GraphNodeImpl
GUnflat->>Impl: init(node, children.items())
else PyTreeNodeImpl
GUnflat->>Impl: unflatten(children.items(), metadata)
end
GUnflat-->>User: reconstructed node
end
deactivate GUnflat
Architecture/flow diagram for graph_to_tree and tree_to_graph interaction with JAX pytreesflowchart LR
subgraph GraphSide[Graph-based objects]
GNode[Graph nodes\nNode]
GDef[GraphDef / NodeDef / NodeRef]
GState[GraphStateMapping / NestedDict]
end
subgraph JAXSide[JAX pytree world]
PTIn[Input pytree\nmay contain graph nodes]
NodeStatesPT[NodeStates pytrees]
PTOut[Reconstructed pytree]
end
subgraph SplitMergeCtx[Graph contexts]
SC[SplitContext\nref_index: RefMap]
MC[MergeContext\nindex_ref: dict]
end
PTIn -->|graph_to_tree| GT(GraphToTree)
GT --> SC
SC -->|treefy_split| GDef
SC -->|treefy_split| GState
GT --> NodeStatesPT
NodeStatesPT -->|tree_to_graph| TG(TreeToGraph)
TG --> MC
MC -->|treefy_merge| GNode
GNode --> PTOut
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Hey - I've found 2 issues
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location path="brainstate/graph/_operation.py" line_range="217" />
<code_context>
-class HashableMapping(Mapping[HA, HB], Hashable):
- def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
+class HashableMapping(Mapping[HA, HB]):
+ """An immutable, hashable mapping."""
+
</code_context>
<issue_to_address>
**suggestion:** Consider keeping HA/HB constrained to Hashable to reflect the true runtime requirements of HashableMapping.
Dropping the Hashable bound on HA/HB while still computing the hash via `hash(frozenset(self._mapping.items()))` means type checkers no longer enforce that keys/values are hashable, even though runtime still requires it and will fail if they are not. Restoring the Hashable constraint keeps the static types aligned with the actual requirements without changing behavior.
Suggested implementation:
```python
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence
from typing import Any, Generic, TypeVar
```
```python
import jax
import numpy as np
HA = TypeVar("HA", bound=Hashable)
HB = TypeVar("HB", bound=Hashable)
class HashableMapping(Mapping[HA, HB]):
"""An immutable, hashable mapping."""
```
1. If `HA` and `HB` are already defined elsewhere in this file, replace those existing `TypeVar` definitions with the new bounded versions instead of adding new ones.
2. Ensure the `HashableMapping` class definition in the file uses `Mapping[HA, HB]` (as shown) and that any previous inheritance from `Hashable` is removed or intentionally replaced, since the type-level hashability is now enforced via the bounds on `HA` and `HB`.
</issue_to_address>
### Comment 2
<location path="brainstate/graph/_convert.py" line_range="190-195" />
<code_context>
- pass
- find_states = states(index_ref._mapping)
+ # Build a dict mirroring RefMap's content via the public API, then extract
+ # State objects from it. We must not access the private ._mapping attribute.
+ public_map = {id(k): (k, v) for k, v in index_ref.items()}
+ find_states = states(public_map)
pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out, find_states
</code_context>
<issue_to_address>
**suggestion (performance):** The `find_states` computation is still unused and now adds extra work; consider removing it entirely.
This block builds `public_map` and calls `states`, but the resulting `find_states` is never used (same as the prior `states(index_ref._mapping)` call). Since it just adds extra traversal with no effect, consider removing it, or if it’s intended for validation, make that purpose explicit by using or propagating the result/error.
```suggestion
# NOTE: `find_states` is currently unused by callers; avoid extra traversal
# work by not computing it. Preserve the return signature for compatibility.
find_states = None
pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out, find_states
```
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
|
|
||
| class HashableMapping(Mapping[HA, HB], Hashable): | ||
| def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None: | ||
| class HashableMapping(Mapping[HA, HB]): |
There was a problem hiding this comment.
suggestion: Consider keeping HA/HB constrained to Hashable to reflect the true runtime requirements of HashableMapping.
Dropping the Hashable bound on HA/HB while still computing the hash via hash(frozenset(self._mapping.items())) means type checkers no longer enforce that keys/values are hashable, even though runtime still requires it and will fail if they are not. Restoring the Hashable constraint keeps the static types aligned with the actual requirements without changing behavior.
Suggested implementation:
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence
from typing import Any, Generic, TypeVarimport jax
import numpy as np
HA = TypeVar("HA", bound=Hashable)
HB = TypeVar("HB", bound=Hashable)
class HashableMapping(Mapping[HA, HB]):
"""An immutable, hashable mapping."""- If
HAandHBare already defined elsewhere in this file, replace those existingTypeVardefinitions with the new bounded versions instead of adding new ones. - Ensure the
HashableMappingclass definition in the file usesMapping[HA, HB](as shown) and that any previous inheritance fromHashableis removed or intentionally replaced, since the type-level hashability is now enforced via the bounds onHAandHB.
| # Build a dict mirroring RefMap's content via the public API, then extract | ||
| # State objects from it. We must not access the private ._mapping attribute. | ||
| public_map = {id(k): (k, v) for k, v in index_ref.items()} | ||
| find_states = states(public_map) | ||
| pytree_out = jax.tree.unflatten(treedef, leaves_out) | ||
| return pytree_out, find_states |
There was a problem hiding this comment.
suggestion (performance): The find_states computation is still unused and now adds extra work; consider removing it entirely.
This block builds public_map and calls states, but the resulting find_states is never used (same as the prior states(index_ref._mapping) call). Since it just adds extra traversal with no effect, consider removing it, or if it’s intended for validation, make that purpose explicit by using or propagating the result/error.
| # Build a dict mirroring RefMap's content via the public API, then extract | |
| # State objects from it. We must not access the private ._mapping attribute. | |
| public_map = {id(k): (k, v) for k, v in index_ref.items()} | |
| find_states = states(public_map) | |
| pytree_out = jax.tree.unflatten(treedef, leaves_out) | |
| return pytree_out, find_states | |
| # NOTE: `find_states` is currently unused by callers; avoid extra traversal | |
| # work by not computing it. Preserve the return signature for compatibility. | |
| find_states = None | |
| pytree_out = jax.tree.unflatten(treedef, leaves_out) | |
| return pytree_out, find_states |
|
fix #134 |
Summary by Sourcery
Refine the graph operation and transformation infrastructure, modernize typing and JAX version handling, improve error messages and docstrings, and significantly expand and stabilize the associated tests.
Bug Fixes:
Enhancements:
Build:
Documentation:
Tests: