
[Bug]: AttributeError: module 'jax.core' has no attribute 'ClosedJaxpr' when pydot is installed #134

@Routhleck

Description


Prerequisites

  • I have searched existing issues and discussions to avoid duplicates.
  • I have read CONTRIBUTING.md and followed the reporting guidelines.

Summary

When pydot is installed, simply importing brainstate raises an AttributeError because jax.core.ClosedJaxpr no longer exists in newer versions of JAX.

Steps to Reproduce

pip install pydot
python -c "import brainstate"  # raises AttributeError immediately

Expected Behavior

Importing brainstate succeeds without error, whether or not pydot is installed.

Actual Behavior

AttributeError: module 'jax.core' has no attribute 'ClosedJaxpr'

Environment

No response

Logs & Traceback

Additional Context

Root Cause

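In _ir_visualize.py, draw_dot_graph is defined inside an if pydot_is_installed: block. With eager annotation evaluation (the default before Python 3.14, unless from __future__ import annotations is in effect), the parameter annotation jax.core.ClosedJaxpr is looked up when that def statement runs, which means at import time whenever pydot is installed. Because ClosedJaxpr has been removed from jax.core in newer JAX versions, that lookup fails with AttributeError.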

# _ir_visualize.py
if pydot_is_installed:
    sub_graph_return = Tuple[
        Union[pydot.Node, pydot.Subgraph],
        ...
    ]
    
    def draw_dot_graph(
        fn: jax.core.ClosedJaxpr,  # ❌ evaluated at import time, raises AttributeError
        ...
    )
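The mechanism can be reproduced without JAX or brainstate. In this sketch, fake_core (a module with no ClosedJaxpr attribute) and the helper try_define are hypothetical stand-ins. The point is only that a bare annotation performs a real attribute lookup, while a quoted one is stored as text.

```python
import types
from typing import Optional

# Hypothetical stand-in for a jax.core that no longer defines ClosedJaxpr.
fake_core = types.ModuleType("fake_core")


def try_define(quoted: bool) -> Optional[str]:
    """Define a function annotated with fake_core.ClosedJaxpr.

    Returns the AttributeError message if defining (or inspecting) the
    annotation fails, otherwise None.
    """
    try:
        if quoted:
            # String annotation: stored as text, never looked up.
            def draw(fn: "fake_core.ClosedJaxpr") -> None: ...
        else:
            # Bare annotation: evaluated right here on interpreters that
            # evaluate annotations eagerly (Python 3.13 and earlier).
            def draw(fn: fake_core.ClosedJaxpr) -> None: ...
        # On interpreters that defer annotations, reading them forces evaluation.
        _ = draw.__annotations__
    except AttributeError as exc:
        return str(exc)
    return None


print(try_define(quoted=False))  # prints the AttributeError message
print(try_define(quoted=True))   # prints None
```

Reading __annotations__ explicitly keeps the demo consistent across Python versions; in brainstate the failure happens earlier, at def time, on interpreters that evaluate annotations eagerly.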

Suggested Fix

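The annotation should not be evaluated at import time. Either a string-based annotation (or from __future__ import annotations at the top of _ir_visualize.py), or a compatibility import that looks up ClosedJaxpr wherever the installed JAX version provides it, would resolve this issue.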
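A sketch of one possible shape for the fix, not the actual patch. It assumes newer JAX releases expose the class as jax.extend.core.ClosedJaxpr and older ones as jax.core.ClosedJaxpr; find_closed_jaxpr is a hypothetical helper, and the draw_dot_graph signature is abbreviated.

```python
import importlib
from typing import Any, Optional


def find_closed_jaxpr() -> Optional[type]:
    """Hypothetical helper: return ClosedJaxpr from the first module providing it.

    Assumption: newer JAX exposes it in jax.extend.core, older JAX in jax.core.
    Returns None if neither module (or JAX itself) is available.
    """
    for module_name in ("jax.extend.core", "jax.core"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # module missing in this JAX version, or JAX not installed
        cls = getattr(module, "ClosedJaxpr", None)
        if cls is not None:
            return cls
    return None


ClosedJaxpr = find_closed_jaxpr()


def draw_dot_graph(fn: "ClosedJaxpr", *args: Any) -> None:
    # The quoted annotation is kept as a string and not evaluated when this
    # def runs, so defining the function succeeds even if ClosedJaxpr is None.
    ...
```

Quoting the annotation alone is enough to stop the import-time crash; the compatibility lookup only matters if the class is also needed at runtime, for example in isinstance checks.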

Test Case

A test that imports brainstate with pydot installed would help catch this regression.
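A sketch of such a regression test using only the standard library. The helper and test names are hypothetical. Running the import in a fresh interpreter ensures brainstate's module-level code actually executes instead of being served from sys.modules.

```python
import importlib.util
import subprocess
import sys
import unittest
from typing import Optional


def import_error_in_fresh_interpreter(module_name: str) -> Optional[str]:
    """Import module_name in a new Python process; return stderr on failure, else None."""
    result = subprocess.run(
        [sys.executable, "-c", f"import {module_name}"],
        capture_output=True,
        text=True,
    )
    return None if result.returncode == 0 else result.stderr


class TestImportWithPydot(unittest.TestCase):
    @unittest.skipUnless(
        importlib.util.find_spec("pydot") is not None,
        "pydot is not installed",
    )
    def test_brainstate_imports_cleanly(self) -> None:
        # Regression check: with pydot present, importing brainstate must not raise.
        self.assertIsNone(import_error_in_fresh_interpreter("brainstate"))
```

Run it with python -m unittest in an environment where brainstate is installed; the test is skipped when pydot is absent, so CI should include at least one job with pydot installed.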

Metadata


Assignees

No one assigned

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
