# State Management in BrainState

In dynamical brain modeling, time-varying state variables are often encountered, such as the membrane potential `V` of neurons or the firing rate `r` in firing rate models. **BrainState** provides the `State` data structure, which helps users intuitively define and manage computational states.

This tutorial provides a detailed introduction to state management in BrainState. By following this tutorial, you will learn:

- The basic concepts and fundamental usage of `State` objects
- How to create `State` objects and use the `State` subclasses: `ShortTermState`, `LongTermState`, `HiddenState`, and `ParamState`
- State and JAX PyTree compatibility
- How to use `StateTraceStack` to track State objects in your programs
- Advanced state management patterns with `StateDictManager`

```python
import jax.numpy as jnp
import brainstate as bst
```

## 1. Basic Concepts and Usage of State Objects

`State` is a key data structure in **BrainState** used to encapsulate state variables in models. These variables primarily represent values that change over time within the model.

### Why States?

JAX is built on functional programming principles, which means:
- All data is immutable by default
- Functions cannot have side effects
- State must be explicitly threaded through computations

This creates a challenge for neural network programming, where we naturally think in terms of mutable states (weights, neuron voltages, etc.). **BrainState's `State`** solves this by:

✅ Providing a mutable interface for state variables  
✅ Automatically managing state updates during JAX transformations  
✅ Maintaining compatibility with JAX's functional paradigm  
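
For contrast, the sketch below threads a state explicitly through a jitted function in plain JAX, which is the bookkeeping that `State` is meant to take off your hands. The leaky update rule and its constants are illustrative only and are not taken from BrainState.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(v, i_ext):
    """One Euler step of a leaky variable; returns the new state."""
    tau = 10.0
    return v + (-v + i_ext) / tau  # nothing is mutated in place

v = jnp.zeros(3)          # the state lives outside the function
for _ in range(5):
    v = step(v, 1.0)      # the caller must pass the state in and take it back out
print(v)                  # every entry equals 1 - 0.9**5
```

Every caller has to remember to pass `v` in and rebind the returned value; forgetting either silently discards the update.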

### Creating States

A `State` can wrap any Python data type, such as integers, floating-point numbers, arrays, `jax.Array`, or any of these nested in dictionaries or lists. Unlike plain Python values, which are frozen into a compiled program as constants, the data within a `State` object remains mutable after program compilation.

```python
# Create a simple State with an array
example = bst.State(jnp.ones(10))
example
```

### States and PyTrees

`State` supports arbitrary [PyTree](https://jax.readthedocs.io/en/latest/working-with-pytrees.html) structures, which means you can encapsulate complex nested data structures within a `State` object. This is particularly useful for models with hierarchical state representations.

```python
# State can hold complex PyTree structures
example2 = bst.State({'a': jnp.ones(3), 'b': jnp.zeros(4)})
example2
```

```python
# State can also hold nested structures
complex_state = bst.State({
    'neurons': {
        'V': jnp.zeros(100),
        'u': jnp.zeros(100)
    },
    'synapses': {
        'g': jnp.zeros((100, 100)),
        'weights': jnp.ones((100, 100)) * 0.1
    }
})
print("Complex state structure:")
print(complex_state)
```

### Accessing and Updating States

Users can access and modify state data through the `State.value` attribute.

```python
# Access the state value
print("Current value:", example.value)
```

```python
# Update the state value; the new array must match the original shape (10,)
example.value = bst.random.random(10)
print("Updated state:")
example
```

### Core Features of State

**✅ Mutable after compilation**: State values can be updated even in JIT-compiled functions

**✅ Structure consistency**: A State's PyTree structure is fixed when it is created, and updates can be checked against it

**✅ Integration with JAX**: Works seamlessly with JAX transformations

### Important Notes

⚠️ **Static Data in JIT Compilation**: Any data not stored in a `State` is treated as static during JIT compilation. Its value is baked into the compiled program when the function is traced, so modifying it afterwards has no effect on the compiled function.
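
The underlying JAX behaviour can be reproduced without BrainState. In this sketch (plain `jax.jit`, illustrative names), a Python variable read inside a jitted function is captured as a constant when the function is first traced:

```python
import jax
import jax.numpy as jnp

gain = 2.0  # plain Python value, NOT a State

@jax.jit
def scale(x):
    return x * gain  # `gain` is read once, when the function is traced

first = scale(jnp.ones(3))
gain = 10.0                  # change the Python variable after compilation
second = scale(jnp.ones(3))  # same input shape and dtype: the cached trace is reused

print(first, second)  # both are [2. 2. 2.]
```

Keeping such a value in a `State` is the mechanism BrainState offers for values that must change between calls of a compiled function.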

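⚠️ **Constraints on Modifying State Data**: When updating via the `value` attribute, the assigned data should have the same PyTree structure as the original, and the shape and dtype should generally match as well. By default the structure is not re-checked on every assignment; wrapping updates in `bst.check_state_value_tree()` enables the check, as the next example shows.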

```python
# Demonstrate tree structure checking
state = bst.ShortTermState(jnp.zeros((2, 3)))

with bst.check_state_value_tree():
    # This works: same tree structure
    state.value = jnp.zeros((2, 3))
    print("✓ Successfully updated state with matching structure")

    # This fails: different tree structure
    try:
        state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
    except Exception as e:
        print(f"✗ Error: {e}")
```
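
The check above is about PyTree *structure*, that is, how leaves are arranged in containers. A rough sketch of that comparison using `jax.tree_util.tree_structure` follows; it illustrates the idea and is not BrainState's actual implementation. Note that a structure comparison alone does not see shapes.

```python
import jax
import jax.numpy as jnp

old = jnp.zeros((2, 3))
same = jnp.ones((2, 3))
different = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))

def same_tree(a, b):
    # Compare only the container layout (treedef), not leaf values or shapes.
    return jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)

print(same_tree(old, same))            # True: one array leaf in both
print(same_tree(old, different))       # False: one leaf vs. a tuple of two leaves
print(same_tree(old, jnp.zeros(4)))    # True: a single leaf, even with another shape
```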

## 2. Subclasses of State

**BrainState** provides several subclasses of `State` to help organize different types of state variables in your models. While these subclasses are functionally identical to the base `State` class, they serve as semantic markers that:

- 📝 Improve code readability
- 🔍 Enable selective filtering (e.g., finding all trainable parameters)
- 🎯 Clarify the role of each state variable

### Overview of State Types

| State Type | Purpose | Examples |
|------------|---------|----------|
| `ParamState` | Trainable parameters | Weights, biases |
| `HiddenState` | Hidden activations | Membrane potentials, RNN hidden states |
| `ShortTermState` | Transient states | Last spike time, current input |
| `LongTermState` | Persistent states | Running averages, momentum |

### 2.1 ParamState - Trainable Parameters

`ParamState` is used for trainable parameters in neural networks. These are the values that get updated during training via gradient descent.

```python
# Example: Neural network parameters
weight = bst.ParamState(bst.random.randn(10, 10) * 0.1)
bias = bst.ParamState(jnp.zeros(10))

print("Weight:")
print(weight)
print("\nBias:")
print(bias)
```

### 2.2 HiddenState - Hidden Activations

`HiddenState` encapsulates hidden activation variables in models. These states are updated during every simulation iteration and retained between iterations, representing the internal dynamics of the model.

```python
# Example: Neuron membrane potential
V = bst.HiddenState(jnp.full(10, -70.0))  # Resting potential

# Example: RNN hidden state
h = bst.HiddenState(jnp.zeros((32, 128)))  # (batch_size, hidden_dim)

print("Membrane potential:")
print(V)
print("\nRNN hidden state:")
print(h)
```

### 2.3 ShortTermState - Transient States

`ShortTermState` is designed for short-term, transient state variables. These states capture instantaneous values that may not carry long-term dependencies.

```python
# Example: Last spike time
t_last_spike = bst.ShortTermState(jnp.full(10, -1e7))  # Very old time

# Example: Current input
current_input = bst.ShortTermState(jnp.zeros(10))

print("Last spike times:")
print(t_last_spike)
print("\nCurrent input:")
print(current_input)
```

### 2.4 LongTermState - Persistent States

`LongTermState` is used for long-term state variables that accumulate information over many iterations. These are commonly used for statistics tracking and optimization algorithms.

```python
# Example: Running mean for batch normalization
running_mean = bst.LongTermState(jnp.zeros(64))
running_var = bst.LongTermState(jnp.ones(64))

# Example: Optimizer momentum
momentum = bst.LongTermState(jnp.zeros((100, 100)))

print("Running mean:")
print(running_mean)
print("\nMomentum:")
print(momentum)
```

### Practical Example: LIF Neuron Model

Let's see how different state types work together in a realistic model:

```python
class LIFNeuron(bst.graph.Node):
    """Leaky Integrate-and-Fire neuron model."""

    def __init__(self, n_neurons, tau=10.0, V_th=1.0, V_reset=0.0):
        super().__init__()
        self.tau = tau
        self.V_th = V_th
        self.V_reset = V_reset

        # Hidden state: membrane potential (evolves continuously)
        self.V = bst.HiddenState(jnp.full(n_neurons, V_reset))

        # Short-term state: time of each neuron's most recent spike
        self.t_last_spike = bst.ShortTermState(jnp.full(n_neurons, -1e7))

        # Parameters: input weights (declared for illustration; not used by the update below)
        self.w_in = bst.ParamState(bst.random.randn(n_neurons, n_neurons) * 0.1)

    def __call__(self, I_ext, t):
        # Membrane potential dynamics (Euler step with dt = 1)
        dV = (-self.V.value + I_ext) / self.tau
        self.V.value = self.V.value + dV

        # Spike generation
        spike = self.V.value >= self.V_th

        # Reset spiking neurons and record their spike time
        self.V.value = jnp.where(spike, self.V_reset, self.V.value)
        self.t_last_spike.value = jnp.where(spike, t, self.t_last_spike.value)

        return spike

# Create and test the neuron
neuron = LIFNeuron(n_neurons=5)
print("Initial state:")
print(f"V: {neuron.V.value}")

# Simulate. The input must exceed V_th: with I_ext = 0.2 the potential
# would settle at 0.2 and never reach the threshold of 1.0.
for t in range(30):
    I_ext = jnp.ones(5) * 1.5  # External current
    spikes = neuron(I_ext, t)
    if jnp.any(spikes):
        print(f"t={t}: Spikes at neurons {jnp.where(spikes)[0]}")
```

## 3. State Tracking with StateTraceStack

`StateTraceStack` is a powerful debugging and introspection tool that tracks which `State` objects are accessed during program execution.

### Why Track States?

- 🔍 **Debugging**: Understand which states are being read/written
- 📊 **Profiling**: Identify state access patterns
- 🎯 **Selective updates**: Apply operations only to specific state types
- 🧪 **Testing**: Verify expected state interactions

### Basic Usage

```python
class Linear(bst.graph.Node):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.w = bst.ParamState(bst.random.randn(d_in, d_out) * 0.1)
        self.b = bst.ParamState(jnp.zeros(d_out))
        self.y = bst.HiddenState(jnp.zeros(d_out))

    def __call__(self, x):
        self.y.value = x @ self.w.value + self.b.value
        return self.y.value

model = Linear(2, 5)

# Draw the input outside the traced block: the random generator keeps its
# key in a State, so calling it inside the block would also be recorded.
x = bst.random.randn(2)

# Track state access
with bst.StateTraceStack() as stack:
    output = model(x)

# Get accessed states
read_states = list(stack.get_read_states())
write_states = list(stack.get_write_states())

print(f"States read: {len(read_states)}")      # w and b
print(f"States written: {len(write_states)}")  # y
```

### Inspecting State Access

`StateTraceStack` provides four main methods:

- `get_read_states()`: Returns State objects that were read
- `get_read_state_values()`: Returns the values of read states
- `get_write_states()`: Returns State objects that were written
- `get_write_state_values()`: Returns the values of written states

```python
# Inspect read states
print("=== Read States ===")
for i, state in enumerate(read_states):
    print(f"{i+1}. {type(state).__name__}: shape={state.value.shape}")
```

```python
# Inspect written states
print("=== Written States ===")
for i, state in enumerate(write_states):
    shape = state.value.shape if hasattr(state.value, 'shape') else 'N/A'
    print(f"{i+1}. {type(state).__name__}: shape={shape}")
```

## 4. Advanced: StateDictManager

`StateDictManager` provides utilities for managing collections of states, particularly useful for:

- Collecting all states of a specific type (e.g., all `ParamState`s)
- Saving and loading model checkpoints
- Freezing/unfreezing parameters

### Finding States in a Model

```python
# Create a model with various state types
class MultiStateModel(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.weight = bst.ParamState(jnp.ones((5, 5)))
        self.bias = bst.ParamState(jnp.zeros(5))
        self.activation = bst.HiddenState(jnp.zeros(5))
        self.running_mean = bst.LongTermState(jnp.zeros(5))

model = MultiStateModel()

print("Model structure:")
print(model)

# Find the states stored directly as attributes and report their types
print("\nStates found:")
for name, attr in vars(model).items():
    if isinstance(attr, bst.State):
        print(f"  {name}: {type(attr).__name__}")
```

## Summary

In this tutorial, you learned:

✅ **States** provide mutable variables compatible with JAX  
✅ Different **state types** serve different purposes:  
  - `ParamState` for trainable parameters  
  - `HiddenState` for hidden activations  
  - `ShortTermState` for transient states  
  - `LongTermState` for persistent states  
✅ **StateTraceStack** tracks state access for debugging  
✅ States support **PyTree structures** for complex data  

### Best Practices

1. 🎯 Use specific state types (`ParamState`, etc.) rather than generic `State`
2. 📝 Keep state updates simple and explicit
3. 🔍 Use `StateTraceStack` for debugging unexpected behavior
4. ⚠️ Remember: only `State` values are mutable; regular variables are static

### Next Steps

Continue with:
- **Random Number Generation** - Learn about stateful random number generation
- **Neural Network Modules** - Build complex models using states
- **Program Transformations** - Use states with JIT, grad, and vmap