Skip to content

Support custom EnvironmentState instances in environ helpers#118

Merged
chaoming0625 merged 2 commits intomainfrom
environ
Dec 13, 2025
Merged

Support custom EnvironmentState instances in environ helpers#118
chaoming0625 merged 2 commits intomainfrom
environ

Conversation

@chaoming0625
Copy link
Copy Markdown
Collaborator

@chaoming0625 chaoming0625 commented Dec 13, 2025

Summary by Sourcery

Add support for operating on custom EnvironmentState instances across environment management utilities while keeping JAX configuration tied to the global environment.

Enhancements:

  • Export the EnvironmentState class from the environ module for external use.
  • Extend reset, context, get, all, pop, set, get_dt, set_precision, get_precision, dftype, ditype, dutype, dctype, tolerance, and behavior registration utilities to accept an optional EnvironmentState, enabling multiple independent environment instances.
  • Ensure JAX precision and related global configuration changes only apply when modifying the global EnvironmentState, not custom environments.
  • Improve documentation and examples for environment helpers to demonstrate usage with custom EnvironmentState instances.

…ntState in reset, context, and various access functions
@sourcery-ai
Copy link
Copy Markdown
Contributor

sourcery-ai Bot commented Dec 13, 2025

Reviewer's Guide

Extend environment management API to support multiple independent EnvironmentState instances by threading an optional env parameter through core accessors (reset, context, get/set, dtype helpers, callbacks, etc.) while ensuring JAX configuration is only mutated for the global environment and updating docs accordingly.

Sequence diagram for set_precision with optional EnvironmentState

sequenceDiagram
    actor User
    participant environ_module
    participant EnvironmentState as env
    participant _ENV_STATE as global_env
    participant JaxConfig

    User->>environ_module: set_precision(precision, env=None_or_custom)
    alt env is None
        environ_module->>global_env: use _ENV_STATE as env
    else env is provided
        environ_module->>EnvironmentState: use provided env
    end

    environ_module->>environ_module: _validate_precision(precision)

    alt env is global_env
        environ_module->>JaxConfig: _set_jax_precision(precision)
    else env is custom EnvironmentState
        note over environ_module,JaxConfig: JAX config is not modified
    end

    environ_module->>EnvironmentState: env.settings[PRECISION] = precision

    alt PRECISION in env.functions
        environ_module->>EnvironmentState: env.functions[PRECISION](precision)
    else no callback
        environ_module-->>User: return
    end

    environ_module-->>User: precision updated in selected env
Loading

Sequence diagram for context manager with custom EnvironmentState

sequenceDiagram
    actor User
    participant environ_module
    participant EnvironmentState as env
    participant _ENV_STATE as global_env
    participant JaxConfig

    User->>environ_module: context(env=custom_or_None, **kwargs)
    alt env is None
        environ_module->>global_env: env = _ENV_STATE
    else env is provided
        environ_module->>EnvironmentState: use provided env
    end

    alt PRECISION in kwargs
        environ_module->>environ_module: original_precision = _get_precision(env)
        environ_module->>environ_module: _validate_precision(kwargs[PRECISION])
        alt env is global_env
            environ_module->>JaxConfig: _set_jax_precision(kwargs[PRECISION])
        else env is custom EnvironmentState
            note over environ_module,JaxConfig: JAX precision unchanged
        end
    end

    loop for each key,value in kwargs
        environ_module->>EnvironmentState: env.contexts[key].append(value)
        alt key in env.functions
            environ_module->>EnvironmentState: env.functions[key](value)
        end
    end

    environ_module->>environ_module: snapshot = all(env)
    environ_module-->>User: yield snapshot

    User-->>environ_module: exit context

    loop for each key in kwargs
        environ_module->>EnvironmentState: env.contexts[key].pop()
        alt key in env.functions
            environ_module->>environ_module: prev_value = get(key, default=None, env)
            alt prev_value is not None
                environ_module->>EnvironmentState: env.functions[key](prev_value)
            end
        end
    end

    alt original_precision is not None and env is global_env
        environ_module->>JaxConfig: _set_jax_precision(original_precision)
    end

    environ_module-->>User: context exited
Loading

Class diagram for EnvironmentState and environment API

classDiagram
    class EnvironmentState {
        +Dict~str,Any~ settings
        +Dict~str,List[Any]~ contexts
        +Dict~str,Callable[Any,None]~ functions
        +Dict~str,Lock~ locks
        +__post_init__()
    }

    class _ENV_STATE {
        <<singleton>>
    }

    class environ_module {
        +reset(env: EnvironmentState) void
        +context(env: EnvironmentState, kwargs: Dict_str_Any) ContextManager_Dict_str_Any
        +get(key: str, default: Any, desc: str, env: EnvironmentState) Any
        +all(env: EnvironmentState) Dict_str_Any
        +pop(key: str, default: Any, env: EnvironmentState) Any
        +set(platform: PlatformType, host_device_count: int, precision: PrecisionType, dt: float, env: EnvironmentState, kwargs: Dict_str_Any) void
        +get_dt(env: EnvironmentState) float
        +set_precision(precision: PrecisionType, env: EnvironmentState) void
        +get_precision(env: EnvironmentState) int
        +dftype(env: EnvironmentState) DTypeLike
        +ditype(env: EnvironmentState) DTypeLike
        +dutype(env: EnvironmentState) DTypeLike
        +dctype(env: EnvironmentState) DTypeLike
        +tolerance(env: EnvironmentState) jnp_ndarray
        +register_default_behavior(key: str, behavior: Callable_Any_None, replace_if_exist: bool, env: EnvironmentState) void
        +unregister_default_behavior(key: str, env: EnvironmentState) bool
        +list_registered_behaviors(env: EnvironmentState) List_str
        +_get_precision(env: EnvironmentState) PrecisionType
    }

    class JaxConfig {
        +_set_jax_precision(precision: PrecisionType) void
    }

    _ENV_STATE --|> EnvironmentState
    environ_module --> EnvironmentState : uses env_or__ENV_STATE
    environ_module --> JaxConfig : calls_when_env_is__ENV_STATE
    EnvironmentState o--> Lock
    EnvironmentState o--> Callable
Loading

Flow diagram for reset behavior with optional EnvironmentState

flowchart TD
    A["reset(env=None_or_custom)"] --> B["env is None?"]
    B -->|"Yes"| C["target_env = _ENV_STATE"]
    B -->|"No"| D["target_env = env"]

    C --> E["Is target_env _ENV_STATE?"]
    D --> E

    E -->|"Yes (global env)"| F["_ENV_STATE = new EnvironmentState()
Reapply default precision via _set_jax_precision(DEFAULT_PRECISION)"]
    E -->|"No (custom env)"| G["Clear env.settings, env.contexts, env.functions
Set env.settings[PRECISION] = DEFAULT_PRECISION"]

    F --> H["Warn: environment has been reset"]
    G --> H

    H --> I["Return"]
Loading

File-Level Changes

Change Details Files
Allow reset() to operate on either the global EnvironmentState or a provided custom EnvironmentState instance.
  • Add optional keyword-only env parameter to reset() and update docstring with custom-environment usage example.
  • If env is None or matches the global _ENV_STATE, reinitialize _ENV_STATE and reapply default JAX precision as before.
  • If env is a custom EnvironmentState, clear its settings/contexts/functions and reinitialize its precision to DEFAULT_PRECISION without touching global JAX config.
  • Keep warning behavior about environment reset unchanged.
brainstate/environ.py
Thread an optional env parameter through environment context management and accessors so that they operate on a specific EnvironmentState rather than always on the global singleton.
  • Add optional keyword-only env parameter to context(), get(), all(), pop(), set(), get_dt(), set_precision(), get_precision(), _get_precision(), dftype(), ditype(), dutype(), dctype(), tolerance(), register_default_behavior(), unregister_default_behavior(), and list_registered_behaviors().
  • Default env parameters to the global _ENV_STATE when None, preserving existing behavior for callers that do not pass env.
  • Update internal uses of _ENV_STATE (locks, contexts, settings, functions) in these functions to instead use the resolved env object, including callback registration and invocation.
  • Ensure platform and host_device_count remain global/JAX-wide and not per-environment, preserving their existing semantics.
brainstate/environ.py
Ensure JAX configuration (precision and related dtype/tolerance behavior) is only updated when mutating the global environment, while still allowing custom EnvironmentState instances to track their own precision.
  • In context(), set(), and set_precision(), guard calls to _set_jax_precision() so they only execute when env is the global _ENV_STATE.
  • Use env-aware helpers (get_precision, _get_precision) for computing precision-driven types and tolerance, so custom env instances can override precision independently of the global one.
  • Maintain raw precision retrieval via _get_precision(env=...) and wire all dtype helpers and tolerance() through this env-aware path.
brainstate/environ.py
Expand documentation and examples to cover usage with custom EnvironmentState instances across the API surface.
  • Add EnvironmentState to the module’s all export list.
  • Augment docstrings for reset(), context(), get(), all(), pop(), set(), get_dt(), set_precision(), get_precision(), dftype(), ditype(), dutype(), dctype(), tolerance(), register_default_behavior(), unregister_default_behavior(), and list_registered_behaviors() with env parameter descriptions and custom-environment code examples.
  • Clarify notes regarding when JAX configuration is affected (only for global env) in relevant docstrings such as context(), set(), and set_precision().
brainstate/environ.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625
Copy link
Copy Markdown
Collaborator Author

@sourcery-ai title

@sourcery-ai sourcery-ai Bot changed the title Enhance environment management by adding support for custom EnvironmentState in reset, context, and various access functions Support custom EnvironmentState instances in environ helpers Dec 13, 2025
Copy link
Copy Markdown
Contributor

@sourcery-ai sourcery-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there - I've reviewed your changes - here's some feedback:

  • The custom-env branch of reset() clears settings, contexts, and functions but does not restore the default keys/structure established in EnvironmentState.__post_init__, so subsequent calls like context() that assume env.contexts[key]/env.locks[key] exist may raise KeyError; consider reinitializing the object (e.g. by re-running __post_init__ or constructing a new EnvironmentState) rather than manually clearing the dicts.
  • When resetting a custom environment, only PRECISION is restored to a default value while any other defaults that the global _ENV_STATE would get on construction (e.g. dt or other standard settings) are left unset; if you intend reset(env=custom_env) to mirror the global reset behavior, you may want to populate all default settings rather than only precision.
  • Several functions gate JAX configuration updates on env is _ENV_STATE; this identity check can become surprising if callers hold onto an old global environment reference across reset(), so it might be worth explicitly documenting that JAX config is tied only to the current global _ENV_STATE object, not to any EnvironmentState instance that used to be global.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- The custom-env branch of `reset()` clears `settings`, `contexts`, and `functions` but does not restore the default keys/structure established in `EnvironmentState.__post_init__`, so subsequent calls like `context()` that assume `env.contexts[key]`/`env.locks[key]` exist may raise `KeyError`; consider reinitializing the object (e.g. by re-running `__post_init__` or constructing a new `EnvironmentState`) rather than manually clearing the dicts.
- When resetting a custom environment, only `PRECISION` is restored to a default value while any other defaults that the global `_ENV_STATE` would get on construction (e.g. `dt` or other standard settings) are left unset; if you intend `reset(env=custom_env)` to mirror the global reset behavior, you may want to populate all default settings rather than only precision.
- Several functions gate JAX configuration updates on `env is _ENV_STATE`; this identity check can become surprising if callers hold onto an old global environment reference across `reset()`, so it might be worth explicitly documenting that JAX config is tied only to the current global `_ENV_STATE` object, not to any `EnvironmentState` instance that used to be global.

## Individual Comments

### Comment 1
<location> `brainstate/environ.py:232-238` </location>
<code_context>
+        _ENV_STATE = EnvironmentState()
+        # Re-apply default precision to JAX
+        _set_jax_precision(DEFAULT_PRECISION)
+    else:
+        # Reset the custom env by clearing its state
+        env.settings.clear()
+        env.contexts.clear()
+        env.functions.clear()
+        # Re-initialize with default precision
+        env.settings[PRECISION] = DEFAULT_PRECISION

     warnings.warn(
</code_context>

<issue_to_address>
**issue (bug_risk):** Custom env reset does not reapply other default state configured in EnvironmentState.__post_init__

In the non-global branch, `reset` clears `settings`, `contexts`, and `functions` and only restores the precision setting. Any other defaults established in `EnvironmentState.__post_init__` (e.g., additional `settings` keys, callbacks, or other fields) will not be re-applied, so custom envs can end up in a different default state than the global env. It would be safer to reuse the same initialization path as construction/`__post_init__` (e.g., via a helper or re-instantiation) so both global and custom envs are reset consistently.
</issue_to_address>

### Comment 2
<location> `brainstate/environ.py:356-357` </location>
<code_context>
+    - When using a custom env, JAX config is only updated if env is the global environment
     """
+    # Use global state if no env provided
+    if env is None:
+        env = _ENV_STATE
+
     # Validate restricted parameters
</code_context>

<issue_to_address>
**suggestion (bug_risk):** Using identity checks against _ENV_STATE to decide global behavior can be surprising after env resets

Several call sites (`context`, `set`, `set_precision`, `tolerance`) branch on `env is _ENV_STATE` to decide whether to update JAX global config. After `reset()` rebinds `_ENV_STATE`, any previously captured `EnvironmentState` instance effectively becomes a "custom" env and stops updating JAX config, even if callers still treat it as the global env. If this change in behavior isn’t intentional, consider representing "global" via an explicit flag on `EnvironmentState` or a helper like `_is_global_env(env)` that you can keep consistent when `_ENV_STATE` is replaced.

Suggested implementation:

```python
    - Settings are restored in reverse order when exiting
    - Thread-safe: each thread maintains its own context stack
    - When using a custom env, JAX config is only updated if env is the global environment
    """
    # Treat `env is None` as "use the current global environment".
    # This avoids tying global behavior to a specific EnvironmentState instance,
    # which can be surprising after `_ENV_STATE` is replaced by `reset()`.
    is_global_env = env is None
    if env is None:
        env = _ENV_STATE

```

```python
    # Handle precision changes (only update JAX config for global env)
    original_precision = None
    if PRECISION in kwargs:

```

To fully apply this pattern across the file and remove reliance on `env is _ENV_STATE` identity checks:

1. In this function (the one containing the snippet), replace any remaining usages of `env is _ENV_STATE` in branches that decide whether to update JAX global config with `is_global_env`.
2. In other call sites you mentioned (`context`, `set`, `set_precision`, `tolerance`), follow the same pattern:
   - At the start of each function, derive `is_global_env = env is None` and, if needed for reading/writing state, set `env = _ENV_STATE` when `env is None`.
   - Replace all `env is _ENV_STATE` checks with `is_global_env`.
3. If you currently expose `_ENV_STATE` to callers, prefer to have callers pass `env=None` when they intend to operate on the global environment; otherwise, those explicit `EnvironmentState` instances are correctly treated as custom environments and will not update JAX global config after a reset.
</issue_to_address>

### Comment 3
<location> `brainstate/environ.py:228-211` </location>
<code_context>
-    _ENV_STATE = EnvironmentState()
-    # Re-apply default precision
-    _set_jax_precision(DEFAULT_PRECISION)
+    if env is None or env is _ENV_STATE:
+        _ENV_STATE = EnvironmentState()
+        # Re-apply default precision to JAX
+        _set_jax_precision(DEFAULT_PRECISION)
+    else:
+        # Reset the custom env by clearing its state
</code_context>

<issue_to_address>
**suggestion:** Reset warning message may be misleading when only a custom env is reset

The `warnings.warn` below always says "Environment has been reset to default settings", even when only a custom `env` is reset and the global `_ENV_STATE` (and its JAX config) are unchanged. Consider adjusting the message (or branching it) so it clearly indicates whether the global or just the provided environment was reset.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment thread brainstate/environ.py
Comment on lines +232 to +238
else:
# Reset the custom env by clearing its state
env.settings.clear()
env.contexts.clear()
env.functions.clear()
# Re-initialize with default precision
env.settings[PRECISION] = DEFAULT_PRECISION
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Custom env reset does not reapply other default state configured in EnvironmentState.post_init

In the non-global branch, reset clears settings, contexts, and functions and only restores the precision setting. Any other defaults established in EnvironmentState.__post_init__ (e.g., additional settings keys, callbacks, or other fields) will not be re-applied, so custom envs can end up in a different default state than the global env. It would be safer to reuse the same initialization path as construction/__post_init__ (e.g., via a helper or re-instantiation) so both global and custom envs are reset consistently.

Comment thread brainstate/environ.py
Comment on lines +356 to +357
if env is None:
env = _ENV_STATE
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Using identity checks against _ENV_STATE to decide global behavior can be surprising after env resets

Several call sites (context, set, set_precision, tolerance) branch on env is _ENV_STATE to decide whether to update JAX global config. After reset() rebinds _ENV_STATE, any previously captured EnvironmentState instance effectively becomes a "custom" env and stops updating JAX config, even if callers still treat it as the global env. If this change in behavior isn’t intentional, consider representing "global" via an explicit flag on EnvironmentState or a helper like _is_global_env(env) that you can keep consistent when _ENV_STATE is replaced.

Suggested implementation:

    - Settings are restored in reverse order when exiting
    - Thread-safe: each thread maintains its own context stack
    - When using a custom env, JAX config is only updated if env is the global environment
    """
    # Treat `env is None` as "use the current global environment".
    # This avoids tying global behavior to a specific EnvironmentState instance,
    # which can be surprising after `_ENV_STATE` is replaced by `reset()`.
    is_global_env = env is None
    if env is None:
        env = _ENV_STATE
    # Handle precision changes (only update JAX config for global env)
    original_precision = None
    if PRECISION in kwargs:

To fully apply this pattern across the file and remove reliance on env is _ENV_STATE identity checks:

  1. In this function (the one containing the snippet), replace any remaining usages of env is _ENV_STATE in branches that decide whether to update JAX global config with is_global_env.
  2. In other call sites you mentioned (context, set, set_precision, tolerance), follow the same pattern:
    • At the start of each function, derive is_global_env = env is None and, if needed for reading/writing state, set env = _ENV_STATE when env is None.
    • Replace all env is _ENV_STATE checks with is_global_env.
  3. If you currently expose _ENV_STATE to callers, prefer to have callers pass env=None when they intend to operate on the global environment; otherwise, those explicit EnvironmentState instances are correctly treated as custom environments and will not update JAX global config after a reset.

Comment thread brainstate/environ.py
@@ -201,14 +209,33 @@ def reset() -> None:
>>> env.reset()
>>> print(env.get('custom_param', default=None)) # None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Reset warning message may be misleading when only a custom env is reset

The warnings.warn below always says "Environment has been reset to default settings", even when only a custom env is reset and the global _ENV_STATE (and its JAX config) are unchanged. Consider adjusting the message (or branching it) so it clearly indicates whether the global or just the provided environment was reset.

@chaoming0625 chaoming0625 merged commit 63a394d into main Dec 13, 2025
6 checks passed
@chaoming0625 chaoming0625 deleted the environ branch December 13, 2025 07:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant