# Introduction to ``State``

In dynamical brain modeling, time-varying state variables are often encountered, such as the membrane potential ``V`` of neurons or the firing rate ``r`` in firing rate models. ``BrainState`` provides the ``State`` data structure, which helps users intuitively define and manage computational states. This tutorial provides a detailed introduction to the usage of ``State``. By following this tutorial, you will learn:
- The basic concepts and fundamental usage of ``State`` objects.
- How to create ``State`` objects and use the subclasses ``ShortTermState``, ``LongTermState``, ``HiddenState``, and ``ParamState``.
- How to use ``StateTraceStack`` to track the ``State`` objects used in a program.

In [8]:
import jax.numpy as jnp

import brainstate 
import brainunit as u

## 1. Basic Concepts and Usage of ``State`` Objects

``State`` is a key data structure in ``BrainState`` used to encapsulate state variables in models. These variables primarily represent values that change over time within the model. A ``State`` can wrap any Python data type, such as integers, floating-point numbers, arrays, ``jax.Array``, or any of these nested in containers such as dictionaries or lists. Unlike native Python data structures, the data within a ``State`` object remains mutable after program compilation.

For example, if a user needs to define a state array, a ``State`` object can be defined as follows:

In [9]:
example = brainstate.State(jnp.ones(10))

example

State(
  value=ShapedArray(float32[10])
)

Furthermore, ``State`` supports arbitrary [PyTrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html), which means users can encapsulate any nested data structure within a single ``State`` object. This allows for convenient state management and computation.

In [10]:
example2 = brainstate.State({'a': jnp.ones(3), 'b': jnp.zeros(4)})

example2

State(
  value={
    'a': ShapedArray(float32[3]),
    'b': ShapedArray(float32[4])
  }
)
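As a brief aside on what "PyTree" means here, the following sketch uses only `jax.tree_util` (no ``BrainState`` calls) on the same dictionary. It shows that JAX sees the dictionary as a container of array leaves, and that a function can be applied to every leaf while the container structure is preserved.

```python
import jax
import jax.numpy as jnp

# The same nested structure that was wrapped in `example2` above.
tree = {'a': jnp.ones(3), 'b': jnp.zeros(4)}

# Flatten the PyTree into its array leaves (dict leaves come out in sorted-key order).
leaves = jax.tree_util.tree_leaves(tree)
shapes = [leaf.shape for leaf in leaves]

# Apply a function to every leaf; the result has the same dictionary structure.
doubled = jax.tree_util.tree_map(lambda x: x * 2, tree)

print(shapes)        # [(3,), (4,)]
print(doubled['a'])  # [2. 2. 2.]
```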

Updating state variables is an essential operation. Users can access and modify these data through the ``State.value`` attribute. For example, accessing a state variable:

In [11]:
example.value

Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

You can see that the returned value is the array we initially passed in. Next, we can replace this array by assigning a new value to ``State.value``:

In [12]:
example.value = brainstate.random.random(3)

example

State(
  value=ShapedArray(float32[3])
)

From the output, we can see that the data within the ``State`` object has been successfully modified. The data in a ``State`` object is mutable, meaning users can update the values of state variables at any point during program execution. This feature is crucial for dynamical brain modeling, where state variables often change over time.

Core Features of State Syntax:
- Encapsulation of mutable variables: All quantities that need to change are encapsulated within ``State`` objects, making it easier for users to track and debug model states.
- Immutability of non-state variables: Variables that are not encapsulated in ``State`` objects are treated as constants at compile time, so changing them after compilation has no effect on the compiled program.

Important Notes:
1. Static data in JIT compilation: Any data not marked as a state variable is treated as static during JIT compilation. Modifying static data in a JIT-compiled environment has no effect.
2. Constraints on modifying ``State`` data: A value assigned through the ``value`` attribute should have the same type and shape as the original data. Inside the ``brainstate.check_state_value_tree()`` context manager, assigning a value whose PyTree structure differs from the original raises an error:

In [13]:
state = brainstate.ShortTermState(jnp.zeros((2, 3)))

with brainstate.check_state_value_tree():
    state.value = jnp.zeros((2, 3))

    try:
        state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
    except Exception as e:
        print(e)

The given value PyTreeDef((*, *)) does not match with the origin tree structure PyTreeDef(*).
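Both notes can be reproduced in plain JAX, without any ``BrainState`` calls. The sketch below is only an illustration of the underlying behavior, not of how ``BrainState`` implements it: the first part shows a Python global being frozen into a `jax.jit`-compiled function, and the second part compares the same PyTree structures that appear in the error message above. The names `scale` and `scaled` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Note 1: a plain Python value is baked into the compiled function as a constant.
scale = 2.0

@jax.jit
def scaled(x):
    return x * scale  # `scale` is read once, when the function is traced

first = scaled(jnp.ones(3))
scale = 10.0                  # rebinding the global after compilation ...
second = scaled(jnp.ones(3))  # ... has no effect: the cached program is reused

# Note 2: a single array and a tuple of two arrays have different tree structures.
original = jax.tree_util.tree_structure(jnp.zeros((2, 3)))
same_kind = jax.tree_util.tree_structure(jnp.ones((2, 3)))
as_tuple = jax.tree_util.tree_structure((jnp.zeros((2, 3)), jnp.zeros((2, 3))))

print(first, second)                                # both are [2. 2. 2.]
print(original == same_kind, original == as_tuple)  # True False
```

The second call still multiplies by 2 because the compiled function is cached for inputs of that shape and dtype. Values that must change between calls therefore need to be passed as arguments or, in ``BrainState``, stored in a ``State``.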


## 2. Subclasses of ``State``

``BrainState`` provides several subclasses of ``State``, allowing users to select the most suitable one based on their needs. Below, we will explain the usage of four subclasses: ``ShortTermState``, ``LongTermState``, ``HiddenState``, and ``ParamState``.

> Note: The subclasses of ``State`` are designed for better management of state variables. While there are no functional differences between these subclasses at the code implementation level, they help users distinguish between different types of state variables in their models. Users can choose the appropriate subclass based on their specific requirements.

### 2.1 ``ShortTermState``

``ShortTermState`` is a subclass of ``State`` designed to encapsulate short-term state variables in models. The data in a ``ShortTermState`` object is updated during every simulation iteration. Users can use ``ShortTermState`` objects to define short-term state variables in their models, such as a neuron’s last spike time (``last_spike_time``).

These state variables primarily capture instantaneous states in the model. They evolve over time but do not exhibit direct temporal dependencies, which is why they are referred to as short-term state variables.

In [14]:
# Example
short_term_state = brainstate.ShortTermState(jnp.ones(5))

import braintools
# In a neuron model, ShortTermState can record each neuron's last spike time
t_last_spike = brainstate.ShortTermState(braintools.init.param(braintools.init.Constant(-1e7 * u.ms), sizes=(10,)))
t_last_spike

ShortTermState(
  value=~float32[10] * msecond
)

### 2.2 ``LongTermState``

``LongTermState`` is a subclass of ``State`` used to encapsulate long-term state variables in models. The data in a ``LongTermState`` object is updated during every simulation iteration but is retained between iterations. This means long-term state variables preserve historical information over time, where previous state values influence subsequent iterations.

Users can use ``LongTermState`` objects to define long-term state variables in their models. For example, when computing a moving average, the running mean and running variance can be defined as long-term state variables.

In [15]:
# Example
long_term_state = brainstate.LongTermState(jnp.ones(5))

# We can use LongTermState to record the running mean of a variable
running_mean = brainstate.LongTermState(jnp.zeros(5))
running_mean

LongTermState(
  value=ShapedArray(float32[5])
)

### 2.3 ``HiddenState``

``HiddenState`` is a subclass of ``State`` designed to encapsulate hidden state variables in models. Similar to ``LongTermState``, the data in a ``HiddenState`` object is updated during every simulation iteration and retained between iterations. Its usage is identical to that of ``LongTermState``.

Users can define hidden state variables in their models using ``HiddenState`` objects. Examples include a neuron’s membrane potential (``V``), synaptic conductance (``g``), and postsynaptic current (``I``).

In [16]:
# Example
hidden_state = brainstate.HiddenState(jnp.ones(5))

# In a neuron model, HiddenState can hold each neuron's membrane potential
V = brainstate.HiddenState(braintools.init.param(braintools.init.Constant(-70. * u.mV), sizes=(10,)))
V

HiddenState(
  value=~float32[10] * mvolt
)

### 2.4 ``ParamState``

``ParamState`` is a subclass of ``State`` used to encapsulate trainable parameters in a model. ``ParamState`` objects are primarily used to define parameters in trainable models, such as neural network weights (``w``) and biases (``b``). Users can define parameters in their models using ``ParamState`` objects.

In [17]:
# Example
param_state = brainstate.ParamState(jnp.ones(5))

# In a neural network model, ParamState can hold the weights of the network
weight = brainstate.ParamState(braintools.init.param(braintools.init.Constant(0.1), sizes=(10, 10), batch_size=2))
weight

ParamState(
  value=ShapedArray(float32[2,10,10], weak_type=True)
)

To emphasize once more: users can work with the base ``State`` class directly instead of the four subclasses ``ShortTermState``, ``LongTermState``, ``HiddenState``, and ``ParamState``. The primary purpose of these subclasses is to help users manage state variables in the model and distinguish between different types of state variables.

## 3. Using ``StateTraceStack``

``StateTraceStack`` is an important tool in ``BrainState`` for tracking ``State`` objects used within a program. Since ``BrainState`` treats all ``State`` objects as intermediate variables during program compilation, there is no global space to store all ``State`` instances. Instead, the ``State`` objects used within a function are managed and stored temporarily, and they are released once the function execution ends.

Given this, how can we determine which ``State`` objects are used in a specific segment of code? ``StateTraceStack`` answers this question: it lets users see which ``State`` objects are read or written in a program, so that these states can be managed in one place.

``StateTraceStack`` can be used as a context manager. As an example, suppose we define a linear layer and call it inside a ``StateTraceStack`` context, so that the stack records which ``State`` objects are used during the call. The states that were read or written can then be retrieved through separate methods.

In [18]:
class Linear(brainstate.graph.Node):
    def __init__(self, din: int, dout: int):
        self.din, self.dout = din, dout
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))
        self.y = brainstate.HiddenState(jnp.zeros((dout,)))

    def __call__(self, x):
        self.y.value = x @ self.w.value + self.b.value
        return self.y.value
    
model = Linear(2, 5)

with brainstate.StateTraceStack() as stack:
    model(brainstate.random.rand(2))
    # Collect the states (and their values) that were read or written during the call
    states_to_be_read = list(stack.get_read_states())
    states_values_to_be_read = list(stack.get_read_state_values())
    states_to_be_written = list(stack.get_write_states())
    states_values_to_be_written = list(stack.get_write_state_values())

``StateTraceStack`` provides four methods for retrieving the ``State`` objects used in the program and their values: ``get_read_states``, ``get_read_state_values``, ``get_write_states``, and ``get_write_state_values``.

First, the ``get_read_states`` method returns the ``State`` objects that were read during program execution:

In [19]:
states_to_be_read

[ParamState(
   value=ShapedArray(float32[2,5])
 ),
 ParamState(
   value=ShapedArray(float32[5])
 )]

Next, the ``get_read_state_values`` method returns the values stored in the ``State`` objects that were read during program execution:

In [20]:
states_values_to_be_read

[Array([[0.4205774 , 0.17630506, 0.29401565, 0.5609126 , 0.66944456],
        [0.46394718, 0.8783361 , 0.64317834, 0.43470597, 0.64010286]],      dtype=float32),
 Array([0., 0., 0., 0., 0.], dtype=float32)]

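The ``get_write_states`` method returns the ``State`` objects that were written to during program execution. In this run the list also contains a ``RandomState``, presumably because the call to ``brainstate.random.rand(2)`` inside the context updates the random number generator's state: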

In [21]:
states_to_be_written

[RandomState([3852626694 3095418825]),
 HiddenState(
   value=ShapedArray(float32[5])
 )]

The ``get_write_state_values`` method returns the values stored in the ``State`` objects that were written to during program execution:

In [22]:
states_values_to_be_written

[Array([3852626694, 3095418825], dtype=uint32),
 Array([0.30106878, 0.1906452 , 0.24051404, 0.38418475, 0.46994978],      dtype=float32)]