Closed
Labels: bug (Something isn't working)
Description
Prerequisites
- I have searched existing issues and discussions to avoid duplicates.
- I have read CONTRIBUTING.md and followed the reporting guidelines.
Summary
When pydot is installed, simply importing brainstate raises an AttributeError because jax.core.ClosedJaxpr no longer exists in newer versions of JAX.
Steps to Reproduce
```shell
pip install pydot
```

```python
import brainstate  # raises AttributeError immediately
```

Expected Behavior
No error.
Actual Behavior
```
AttributeError: module 'jax.core' has no attribute 'ClosedJaxpr'
```

Environment
No response
Logs & Traceback
Additional Context
Root Cause
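In `_ir_visualize.py`, `draw_dot_graph` is defined inside an `if pydot_is_installed:` block, and its parameter annotation `jax.core.ClosedJaxpr` is evaluated as soon as that `def` statement runs, i.e. while `brainstate` is being imported. Because `ClosedJaxpr` has been removed from `jax.core` in newer JAX versions, this attribute lookup raises the `AttributeError`. Without pydot the block is skipped, which is why the import only fails when pydot is installed.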
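The mechanism can be reproduced without JAX or pydot. The sketch below uses a hypothetical stand-in module `fake_core` that lacks `ClosedJaxpr`: merely *defining* a function whose annotation is `fake_core.ClosedJaxpr` raises, while a string annotation does not. This assumes a Python version before 3.14; PEP 649 defers annotation evaluation starting in 3.14, so there the eager case no longer raises at definition time.

```python
import sys
import types

# Hypothetical stand-in for a `jax.core` that no longer exports `ClosedJaxpr`;
# real JAX is not used so the sketch runs anywhere.
fake_core = types.ModuleType("fake_core")

# Before Python 3.14 (PEP 649), annotations are evaluated when `def` runs.
ANNOTATIONS_ARE_EAGER = sys.version_info < (3, 14)


def define_with_eager_annotation():
    # Evaluating `fake_core.ClosedJaxpr` here mirrors the failing `draw_dot_graph`.
    def draw(fn: fake_core.ClosedJaxpr):
        return fn
    return draw


def define_with_string_annotation():
    # A string annotation is stored verbatim and not resolved at definition time.
    def draw(fn: "fake_core.ClosedJaxpr"):
        return fn
    return draw


try:
    define_with_eager_annotation()  # only defines `draw`; it is never called
    eager_error = None
except AttributeError as exc:
    eager_error = exc

lazy_draw = define_with_string_annotation()
```

On interpreters before 3.14, `eager_error` holds an `AttributeError` with the same shape of message as the one in this report, even though `draw` is never called.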
```python
# _ir_visualize.py
if pydot_is_installed:
    sub_graph_return = Tuple[
        Union[pydot.Node, pydot.Subgraph],
        ...
    ]

    def draw_dot_graph(
        fn: jax.core.ClosedJaxpr,  # ❌ evaluated at import time, raises AttributeError
        ...
    )
```

Suggested Fix
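The annotation should not be evaluated at import time. Either a compatibility import (resolving `ClosedJaxpr` from whichever module provides it in the installed JAX version) or a string-based annotation (for example `fn: "jax.core.ClosedJaxpr"`, or `from __future__ import annotations` at the top of the module) would resolve this issue.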
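A minimal sketch combining both options, not brainstate's actual fix. The helper `resolve_attr` is hypothetical, and the candidate module `jax.extend.core` is an assumption about where newer JAX releases expose `ClosedJaxpr`. The helper returns `None` instead of raising when no candidate provides the symbol, and the string annotation keeps the `def` from evaluating it at import time.

```python
import importlib
from typing import Any, Sequence


def resolve_attr(module_names: Sequence[str], attr: str) -> Any:
    """Return `attr` from the first importable module that defines it, else None."""
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:  # module (or its parent package) is unavailable
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
    return None


# Assumption: newer JAX exposes jax.extend.core.ClosedJaxpr, older JAX jax.core.ClosedJaxpr.
# This may be None (e.g. JAX not installed), which only matters to type checkers.
ClosedJaxpr = resolve_attr(("jax.extend.core", "jax.core"), "ClosedJaxpr")


def draw_dot_graph(fn: "ClosedJaxpr", *args, **kwargs):
    # Hypothetical signature: the real parameters are elided in this report.
    # The string annotation is not evaluated when this `def` runs.
    ...
```

Returning `None` rather than raising keeps a missing symbol from breaking `import brainstate`; code that needs `ClosedJaxpr` at runtime (e.g. for `isinstance` checks) would still have to handle the `None` case.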
Test Case
A test that imports brainstate with pydot installed would help catch this regression.
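A sketch of such a regression test, assuming the suite can use `unittest` (pytest also collects `unittest.TestCase` classes). The import runs in a fresh interpreter via `subprocess`, so a `brainstate` already imported by the test process cannot mask the failure. The helper and class names are hypothetical, and the test is skipped unless both `brainstate` and `pydot` are installed.

```python
import importlib.util
import subprocess
import sys
import unittest


def import_in_subprocess(module_name: str) -> subprocess.CompletedProcess:
    """Import `module_name` in a fresh interpreter and capture its output."""
    return subprocess.run(
        [sys.executable, "-c", f"import {module_name}"],
        capture_output=True,
        text=True,
    )


def is_installed(name: str) -> bool:
    # find_spec locates a top-level package without importing it.
    return importlib.util.find_spec(name) is not None


class TestImportWithPydot(unittest.TestCase):
    @unittest.skipUnless(
        is_installed("brainstate") and is_installed("pydot"),
        "requires brainstate and pydot",
    )
    def test_import_brainstate_with_pydot(self):
        result = import_in_subprocess("brainstate")
        # A non-zero exit code means the import raised; show stderr on failure.
        self.assertEqual(result.returncode, 0, result.stderr)
```

For this to catch the regression, the CI environment would need pydot installed; otherwise the test is reported as skipped rather than passed.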