Support custom EnvironmentState instances in environ helpers#118
Support custom EnvironmentState instances in environ helpers#118chaoming0625 merged 2 commits intomainfrom
Conversation
…ntState in reset, context, and various access functions
Reviewer's GuideExtend environment management API to support multiple independent EnvironmentState instances by threading an optional env parameter through core accessors (reset, context, get/set, dtype helpers, callbacks, etc.) while ensuring JAX configuration is only mutated for the global environment and updating docs accordingly. Sequence diagram for set_precision with optional EnvironmentStatesequenceDiagram
actor User
participant environ_module
participant EnvironmentState as env
participant _ENV_STATE as global_env
participant JaxConfig
User->>environ_module: set_precision(precision, env=None_or_custom)
alt env is None
environ_module->>global_env: use _ENV_STATE as env
else env is provided
environ_module->>EnvironmentState: use provided env
end
environ_module->>environ_module: _validate_precision(precision)
alt env is global_env
environ_module->>JaxConfig: _set_jax_precision(precision)
else env is custom EnvironmentState
note over environ_module,JaxConfig: JAX config is not modified
end
environ_module->>EnvironmentState: env.settings[PRECISION] = precision
alt PRECISION in env.functions
environ_module->>EnvironmentState: env.functions[PRECISION](precision)
else no callback
environ_module-->>User: return
end
environ_module-->>User: precision updated in selected env
Sequence diagram for context manager with custom EnvironmentStatesequenceDiagram
actor User
participant environ_module
participant EnvironmentState as env
participant _ENV_STATE as global_env
participant JaxConfig
User->>environ_module: context(env=custom_or_None, **kwargs)
alt env is None
environ_module->>global_env: env = _ENV_STATE
else env is provided
environ_module->>EnvironmentState: use provided env
end
alt PRECISION in kwargs
environ_module->>environ_module: original_precision = _get_precision(env)
environ_module->>environ_module: _validate_precision(kwargs[PRECISION])
alt env is global_env
environ_module->>JaxConfig: _set_jax_precision(kwargs[PRECISION])
else env is custom EnvironmentState
note over environ_module,JaxConfig: JAX precision unchanged
end
end
loop for each key,value in kwargs
environ_module->>EnvironmentState: env.contexts[key].append(value)
alt key in env.functions
environ_module->>EnvironmentState: env.functions[key](value)
end
end
environ_module->>environ_module: snapshot = all(env)
environ_module-->>User: yield snapshot
User-->>environ_module: exit context
loop for each key in kwargs
environ_module->>EnvironmentState: env.contexts[key].pop()
alt key in env.functions
environ_module->>environ_module: prev_value = get(key, default=None, env)
alt prev_value is not None
environ_module->>EnvironmentState: env.functions[key](prev_value)
end
end
end
alt original_precision is not None and env is global_env
environ_module->>JaxConfig: _set_jax_precision(original_precision)
end
environ_module-->>User: context exited
Class diagram for EnvironmentState and environment APIclassDiagram
class EnvironmentState {
+Dict~str,Any~ settings
+Dict~str,List[Any]~ contexts
+Dict~str,Callable[Any,None]~ functions
+Dict~str,Lock~ locks
+__post_init__()
}
class _ENV_STATE {
<<singleton>>
}
class environ_module {
+reset(env: EnvironmentState) void
+context(env: EnvironmentState, kwargs: Dict_str_Any) ContextManager_Dict_str_Any
+get(key: str, default: Any, desc: str, env: EnvironmentState) Any
+all(env: EnvironmentState) Dict_str_Any
+pop(key: str, default: Any, env: EnvironmentState) Any
+set(platform: PlatformType, host_device_count: int, precision: PrecisionType, dt: float, env: EnvironmentState, kwargs: Dict_str_Any) void
+get_dt(env: EnvironmentState) float
+set_precision(precision: PrecisionType, env: EnvironmentState) void
+get_precision(env: EnvironmentState) int
+dftype(env: EnvironmentState) DTypeLike
+ditype(env: EnvironmentState) DTypeLike
+dutype(env: EnvironmentState) DTypeLike
+dctype(env: EnvironmentState) DTypeLike
+tolerance(env: EnvironmentState) jnp_ndarray
+register_default_behavior(key: str, behavior: Callable_Any_None, replace_if_exist: bool, env: EnvironmentState) void
+unregister_default_behavior(key: str, env: EnvironmentState) bool
+list_registered_behaviors(env: EnvironmentState) List_str
+_get_precision(env: EnvironmentState) PrecisionType
}
class JaxConfig {
+_set_jax_precision(precision: PrecisionType) void
}
_ENV_STATE --|> EnvironmentState
environ_module --> EnvironmentState : uses env_or__ENV_STATE
environ_module --> JaxConfig : calls_when_env_is__ENV_STATE
EnvironmentState o--> Lock
EnvironmentState o--> Callable
Flow diagram for reset behavior with optional EnvironmentStateflowchart TD
A["reset(env=None_or_custom)"] --> B["env is None?"]
B -->|"Yes"| C["target_env = _ENV_STATE"]
B -->|"No"| D["target_env = env"]
C --> E["Is target_env _ENV_STATE?"]
D --> E
E -->|"Yes (global env)"| F["_ENV_STATE = new EnvironmentState()
Reapply default precision via _set_jax_precision(DEFAULT_PRECISION)"]
E -->|"No (custom env)"| G["Clear env.settings, env.contexts, env.functions
Set env.settings[PRECISION] = DEFAULT_PRECISION"]
F --> H["Warn: environment has been reset"]
G --> H
H --> I["Return"]
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Hey there - I've reviewed your changes - here's some feedback:
- The custom-env branch of
reset()clearssettings,contexts, andfunctionsbut does not restore the default keys/structure established inEnvironmentState.__post_init__, so subsequent calls likecontext()that assumeenv.contexts[key]/env.locks[key]exist may raiseKeyError; consider reinitializing the object (e.g. by re-running__post_init__or constructing a newEnvironmentState) rather than manually clearing the dicts. - When resetting a custom environment, only
PRECISIONis restored to a default value while any other defaults that the global_ENV_STATEwould get on construction (e.g.dtor other standard settings) are left unset; if you intendreset(env=custom_env)to mirror the global reset behavior, you may want to populate all default settings rather than only precision. - Several functions gate JAX configuration updates on
env is _ENV_STATE; this identity check can become surprising if callers hold onto an old global environment reference acrossreset(), so it might be worth explicitly documenting that JAX config is tied only to the current global_ENV_STATEobject, not to anyEnvironmentStateinstance that used to be global.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The custom-env branch of `reset()` clears `settings`, `contexts`, and `functions` but does not restore the default keys/structure established in `EnvironmentState.__post_init__`, so subsequent calls like `context()` that assume `env.contexts[key]`/`env.locks[key]` exist may raise `KeyError`; consider reinitializing the object (e.g. by re-running `__post_init__` or constructing a new `EnvironmentState`) rather than manually clearing the dicts.
- When resetting a custom environment, only `PRECISION` is restored to a default value while any other defaults that the global `_ENV_STATE` would get on construction (e.g. `dt` or other standard settings) are left unset; if you intend `reset(env=custom_env)` to mirror the global reset behavior, you may want to populate all default settings rather than only precision.
- Several functions gate JAX configuration updates on `env is _ENV_STATE`; this identity check can become surprising if callers hold onto an old global environment reference across `reset()`, so it might be worth explicitly documenting that JAX config is tied only to the current global `_ENV_STATE` object, not to any `EnvironmentState` instance that used to be global.
## Individual Comments
### Comment 1
<location> `brainstate/environ.py:232-238` </location>
<code_context>
+ _ENV_STATE = EnvironmentState()
+ # Re-apply default precision to JAX
+ _set_jax_precision(DEFAULT_PRECISION)
+ else:
+ # Reset the custom env by clearing its state
+ env.settings.clear()
+ env.contexts.clear()
+ env.functions.clear()
+ # Re-initialize with default precision
+ env.settings[PRECISION] = DEFAULT_PRECISION
warnings.warn(
</code_context>
<issue_to_address>
**issue (bug_risk):** Custom env reset does not reapply other default state configured in EnvironmentState.__post_init__
In the non-global branch, `reset` clears `settings`, `contexts`, and `functions` and only restores the precision setting. Any other defaults established in `EnvironmentState.__post_init__` (e.g., additional `settings` keys, callbacks, or other fields) will not be re-applied, so custom envs can end up in a different default state than the global env. It would be safer to reuse the same initialization path as construction/`__post_init__` (e.g., via a helper or re-instantiation) so both global and custom envs are reset consistently.
</issue_to_address>
### Comment 2
<location> `brainstate/environ.py:356-357` </location>
<code_context>
+ - When using a custom env, JAX config is only updated if env is the global environment
"""
+ # Use global state if no env provided
+ if env is None:
+ env = _ENV_STATE
+
# Validate restricted parameters
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Using identity checks against _ENV_STATE to decide global behavior can be surprising after env resets
Several call sites (`context`, `set`, `set_precision`, `tolerance`) branch on `env is _ENV_STATE` to decide whether to update JAX global config. After `reset()` rebinds `_ENV_STATE`, any previously captured `EnvironmentState` instance effectively becomes a "custom" env and stops updating JAX config, even if callers still treat it as the global env. If this change in behavior isn’t intentional, consider representing "global" via an explicit flag on `EnvironmentState` or a helper like `_is_global_env(env)` that you can keep consistent when `_ENV_STATE` is replaced.
Suggested implementation:
```python
- Settings are restored in reverse order when exiting
- Thread-safe: each thread maintains its own context stack
- When using a custom env, JAX config is only updated if env is the global environment
"""
# Treat `env is None` as "use the current global environment".
# This avoids tying global behavior to a specific EnvironmentState instance,
# which can be surprising after `_ENV_STATE` is replaced by `reset()`.
is_global_env = env is None
if env is None:
env = _ENV_STATE
```
```python
# Handle precision changes (only update JAX config for global env)
original_precision = None
if PRECISION in kwargs:
```
To fully apply this pattern across the file and remove reliance on `env is _ENV_STATE` identity checks:
1. In this function (the one containing the snippet), replace any remaining usages of `env is _ENV_STATE` in branches that decide whether to update JAX global config with `is_global_env`.
2. In other call sites you mentioned (`context`, `set`, `set_precision`, `tolerance`), follow the same pattern:
- At the start of each function, derive `is_global_env = env is None` and, if needed for reading/writing state, set `env = _ENV_STATE` when `env is None`.
- Replace all `env is _ENV_STATE` checks with `is_global_env`.
3. If you currently expose `_ENV_STATE` to callers, prefer to have callers pass `env=None` when they intend to operate on the global environment; otherwise, those explicit `EnvironmentState` instances are correctly treated as custom environments and will not update JAX global config after a reset.
</issue_to_address>
### Comment 3
<location> `brainstate/environ.py:228-211` </location>
<code_context>
- _ENV_STATE = EnvironmentState()
- # Re-apply default precision
- _set_jax_precision(DEFAULT_PRECISION)
+ if env is None or env is _ENV_STATE:
+ _ENV_STATE = EnvironmentState()
+ # Re-apply default precision to JAX
+ _set_jax_precision(DEFAULT_PRECISION)
+ else:
+ # Reset the custom env by clearing its state
</code_context>
<issue_to_address>
**suggestion:** Reset warning message may be misleading when only a custom env is reset
The `warnings.warn` below always says "Environment has been reset to default settings", even when only a custom `env` is reset and the global `_ENV_STATE` (and its JAX config) are unchanged. Consider adjusting the message (or branching it) so it clearly indicates whether the global or just the provided environment was reset.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| else: | ||
| # Reset the custom env by clearing its state | ||
| env.settings.clear() | ||
| env.contexts.clear() | ||
| env.functions.clear() | ||
| # Re-initialize with default precision | ||
| env.settings[PRECISION] = DEFAULT_PRECISION |
There was a problem hiding this comment.
issue (bug_risk): Custom env reset does not reapply other default state configured in EnvironmentState.post_init
In the non-global branch, reset clears settings, contexts, and functions and only restores the precision setting. Any other defaults established in EnvironmentState.__post_init__ (e.g., additional settings keys, callbacks, or other fields) will not be re-applied, so custom envs can end up in a different default state than the global env. It would be safer to reuse the same initialization path as construction/__post_init__ (e.g., via a helper or re-instantiation) so both global and custom envs are reset consistently.
| if env is None: | ||
| env = _ENV_STATE |
There was a problem hiding this comment.
suggestion (bug_risk): Using identity checks against _ENV_STATE to decide global behavior can be surprising after env resets
Several call sites (context, set, set_precision, tolerance) branch on env is _ENV_STATE to decide whether to update JAX global config. After reset() rebinds _ENV_STATE, any previously captured EnvironmentState instance effectively becomes a "custom" env and stops updating JAX config, even if callers still treat it as the global env. If this change in behavior isn’t intentional, consider representing "global" via an explicit flag on EnvironmentState or a helper like _is_global_env(env) that you can keep consistent when _ENV_STATE is replaced.
Suggested implementation:
- Settings are restored in reverse order when exiting
- Thread-safe: each thread maintains its own context stack
- When using a custom env, JAX config is only updated if env is the global environment
"""
# Treat `env is None` as "use the current global environment".
# This avoids tying global behavior to a specific EnvironmentState instance,
# which can be surprising after `_ENV_STATE` is replaced by `reset()`.
is_global_env = env is None
if env is None:
env = _ENV_STATE # Handle precision changes (only update JAX config for global env)
original_precision = None
if PRECISION in kwargs:To fully apply this pattern across the file and remove reliance on env is _ENV_STATE identity checks:
- In this function (the one containing the snippet), replace any remaining usages of
env is _ENV_STATEin branches that decide whether to update JAX global config withis_global_env. - In other call sites you mentioned (
context,set,set_precision,tolerance), follow the same pattern:- At the start of each function, derive
is_global_env = env is Noneand, if needed for reading/writing state, setenv = _ENV_STATEwhenenv is None. - Replace all
env is _ENV_STATEchecks withis_global_env.
- At the start of each function, derive
- If you currently expose
_ENV_STATEto callers, prefer to have callers passenv=Nonewhen they intend to operate on the global environment; otherwise, those explicitEnvironmentStateinstances are correctly treated as custom environments and will not update JAX global config after a reset.
| @@ -201,14 +209,33 @@ def reset() -> None: | |||
| >>> env.reset() | |||
| >>> print(env.get('custom_param', default=None)) # None | |||
|
|
|||
There was a problem hiding this comment.
suggestion: Reset warning message may be misleading when only a custom env is reset
The warnings.warn below always says "Environment has been reset to default settings", even when only a custom env is reset and the global _ENV_STATE (and its JAX config) are unchanged. Consider adjusting the message (or branching it) so it clearly indicates whether the global or just the provided environment was reset.
Summary by Sourcery
Add support for operating on custom EnvironmentState instances across environment management utilities while keeping JAX configuration tied to the global environment.
Enhancements: