Skip to content

Simplify JAX compat: use jax.make_jaxpr and aval helpers#137

Merged
chaoming0625 merged 8 commits intomainfrom
update
Mar 7, 2026
Merged

Simplify JAX compat: use jax.make_jaxpr and aval helpers#137
chaoming0625 merged 8 commits intomainfrom
update

Conversation

@chaoming0625
Copy link
Member

@chaoming0625 chaoming0625 commented Mar 7, 2026

Summary by Sourcery

Simplify JAX integration by removing a custom make_jaxpr implementation and aligning with newer JAX APIs, while centralizing version-dependent utilities in the compatibility layer.

Enhancements:

  • Always use jax.make_jaxpr in StatefulFunction instead of a custom _make_jaxpr shim and drop IR optimization plumbing.
  • Move get_aval and mapped_aval usage in scan utilities to the compatibility module to handle JAX version differences consistently.
  • Adjust compatibility imports to source ClosedJaxpr, jaxpr helpers, and mapped_aval from version-appropriate JAX modules and remove the custom extend_axis_env_nd context manager.
  • Simplify the fun_name helper by removing redundant documentation comments.
  • Ensure debug error callbacks in JIT paths are ordered by enabling ordered=True on jax.debug.callback.

Tests:

  • Remove extend_axis_env_nd-specific tests from the compatibility test suite since the helper is no longer provided.

@sourcery-ai
Copy link
Contributor

sourcery-ai bot commented Mar 7, 2026

Reviewer's Guide

This PR simplifies JAX version handling and compatibility utilities by removing a custom internal make_jaxpr implementation, centralizing axis/aval helpers in _compatible_import, updating call sites to use those helpers, and tightening debug callback behavior ordering.

Sequence diagram for updated StatefulFunction.make_jaxpr path

sequenceDiagram
    actor User
    participant StatefulFunction
    participant jax_make_jaxpr as JaxMakeJaxpr
    participant wrapped_fun_to_eval as WrappedFun

    User->>StatefulFunction: make_jaxpr(*args, **kwargs)
    StatefulFunction->>StatefulFunction: build cache_key, static_kwargs
    StatefulFunction->>WrappedFun: functools.partial(_wrapped_fun_to_eval, cache_key, static_kwargs)
    StatefulFunction->>JaxMakeJaxpr: jax.make_jaxpr(partial_fun, static_argnums, axis_env, return_shape=True)(*args, **dyn_kwargs)
    JaxMakeJaxpr-->>StatefulFunction: jaxpr, (out_shapes, state_shapes)
    StatefulFunction->>StatefulFunction: cache jaxpr, out_shapes, state_shapes
    StatefulFunction-->>User: stateful jaxpr representation
Loading

File-Level Changes

Change Details Files
Use JAX’s built-in make_jaxpr unconditionally and remove the custom internal _make_jaxpr implementation.
  • Drop the version-conditional fallback to a local _make_jaxpr and always call jax.make_jaxpr with return_shape=True in StatefulFunction.make_jaxpr.
  • Remove the entire internal _make_jaxpr implementation and its helper _flatten_fun/_check_callable, along with related imports like transformation_with_aux, ExitStack, extend_axis_env_nd, safe_zip, unzip2, and wrap_init from _make_jaxpr.py.
  • Clean up now-unused IR optimization plumbing (ir_optimizations argument path) that relied on the custom _make_jaxpr.
brainstate/transform/_make_jaxpr.py
Refactor compatibility shims to import mapped_aval across JAX versions and remove the custom extend_axis_env_nd shim and its tests.
  • Add mapped_aval to the public compatibility export list and implement a version-conditional import (jax.core.mapped_aval for older versions, jax.extend.core.mapped_aval for newer ones).
  • Reorganize ClosedJaxpr/Primitive/jaxpr_as_fun imports so that extend_axis_env_nd is taken from jax.core only for older versions and no longer wrapped by a custom contextmanager for newer versions.
  • Delete the custom extend_axis_env_nd contextmanager based on trace_ctx and remove its associated unit test that validated its behavior.
  • Remove now-unneeded fun_name docstring comments to reduce noise while keeping the helper behavior unchanged.
brainstate/_compatible_import.py
brainstate/_compatible_import_test.py
Update scan utilities to use compatibility-layer helpers for aval computation instead of directly referencing jax.core.
  • Replace direct usage of jax.core.get_aval with the compat get_aval helper when computing xs_avals in scan and checkpointed_scan.
  • Replace direct usage of jax.core.mapped_aval with the compat mapped_aval helper when computing x_avals in scan and checkpointed_scan.
brainstate/transform/_loop_collect_return.py
Ensure deterministic ordering for error callbacks under JIT by using ordered debug callbacks.
  • Pass ordered=True to jax.debug.callback in the JIT true branch of _error_if to guarantee ordered error reporting under JIT execution.
brainstate/transform/_error_if.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625
Copy link
Member Author

@sourcery-ai title

Copy link
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've left some high level feedback:

  • Now that _make_jaxpr and the ir_optimizations handling are removed in favor of jax.make_jaxpr, consider cleaning up any remaining references/configuration around self.ir_optimizations to avoid dead or misleading options.
  • The change to jax.debug.callback(err_fun, *args, **kwargs, ordered=True) assumes the ordered keyword is available in all supported JAX versions; if older versions are still supported, you may want to guard this with a version check or fall back when ordered is not accepted.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- Now that `_make_jaxpr` and the `ir_optimizations` handling are removed in favor of `jax.make_jaxpr`, consider cleaning up any remaining references/configuration around `self.ir_optimizations` to avoid dead or misleading options.
- The change to `jax.debug.callback(err_fun, *args, **kwargs, ordered=True)` assumes the `ordered` keyword is available in all supported JAX versions; if older versions are still supported, you may want to guard this with a version check or fall back when `ordered` is not accepted.

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

@sourcery-ai sourcery-ai bot changed the title Update Simplify JAX compat: use jax.make_jaxpr and aval helpers Mar 7, 2026
@chaoming0625 chaoming0625 merged commit cc1b84d into main Mar 7, 2026
4 of 7 checks passed
@chaoming0625 chaoming0625 deleted the update branch March 7, 2026 18:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant