# Overview

``brainpy.state`` is a complete architectural redesign of BrainPy, built on top of the ``brainstate`` framework. This document explains the design principles and architectural components that make ``brainpy.state`` powerful and flexible.

## Design Philosophy

``brainpy.state`` is built around several core principles:

**State-Based Programming**
   All dynamical variables are managed as explicit states, enabling automatic differentiation, efficient compilation, and clear data flow.

**Modular Composition**
   Complex models are built by composing simple, reusable components. Each component has a well-defined interface and responsibility.

**Scientific Accuracy**
   Integration with ``brainunit`` ensures physical correctness and prevents unit-related errors.

**Performance by Default**
   JIT compilation and optimization are built into the framework, not an afterthought.

**Extensibility**
   Adding new neuron models, synapse types, or learning rules is straightforward and follows clear patterns.

## Architectural Layers

``brainpy.state`` is organized into four layers:

```text
┌─────────────────────────────────────────┐
│         User Models & Networks          │  ← Your code
├─────────────────────────────────────────┤
│      BrainPy Components Layer           │  ← Neurons, Synapses, Projections
├─────────────────────────────────────────┤
│       BrainState Framework              │  ← State management, compilation
├─────────────────────────────────────────┤
│       JAX + XLA Backend                 │  ← JIT compilation, autodiff
└─────────────────────────────────────────┘
```

### 1. JAX + XLA Backend

The foundation layer provides:

- Just-In-Time (JIT) compilation
- Automatic differentiation
- Hardware acceleration (CPU/GPU/TPU)
- Functional transformations (vmap, grad, etc.)

### 2. BrainState Framework

Built on JAX, ``brainstate`` provides:

- State management system
- Module composition
- Compilation and optimization
- Program transformations (for_loop, etc.)

### 3. BrainPy Components

High-level neuroscience-specific components:

- Neuron models (LIF, ALIF, etc.)
- Synapse models (Expon, Alpha, etc.)
- Projection architectures
- Learning rules and plasticity

### 4. User Models

Your custom networks and experiments built using BrainPy components.

## State Management System

### The Foundation: ``brainstate.State``

Everything in ``brainpy.state`` revolves around **states**:

```python
import brainpy
import brainstate
import braintools
import brainunit as u
import jax.numpy as jnp

# Create a state
voltage = brainstate.State(0.0)  # Scalar value
weights = brainstate.State(jnp.array([[0.1, 0.2], [0.3, 0.4]]))  # 2x2 matrix
```

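States are special containers that:

- Hold values that persist and change across time steps
- Can be selected as differentiation targets (e.g. via ``grad_states``)
- Are tracked by ``brainstate`` transformations such as ``jit`` and ``for_loop``, so they can be read and updated inside compiled code
- Work with batching transformations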
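Explicit states matter because JAX compiles pure functions. The framework-free NumPy sketch below illustrates the read, compute, write-back pattern that state containers make explicit; ``ToyState`` and ``leaky_step`` are hypothetical names for illustration and are not part of ``brainstate``.

```python
import numpy as np


class ToyState:
    """Hypothetical minimal state container: a mutable value holder."""

    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)


def leaky_step(v, inp, decay=0.9):
    """Pure function: the result depends only on the arguments."""
    return decay * v + inp


# Stateful view: read the state, call the pure function, write the result back.
voltage = ToyState(np.zeros(3))
for _ in range(5):
    voltage.value = leaky_step(voltage.value, np.ones(3))

print(voltage.value)
```

Because ``leaky_step`` only sees its arguments, a tracer can record exactly which values flow in and out of each step; the read and write-back around it is roughly the bookkeeping that a state framework automates.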

### State Types

BrainPy uses different state types for different purposes:

**ParamState** - Trainable Parameters
   Used for weights, time constants, and other trainable parameters.

```python
class MyNeuron(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.tau = brainstate.ParamState(10.0)  # Trainable
        self.weight = brainstate.ParamState(jnp.array([[0.1, 0.2]]))  # Trainable
```

**ShortTermState** - Dynamical Variables
   Used for membrane potentials, synaptic currents, and other variables that evolve at every time step and are typically reset between simulation runs.

```python
class MyNeuron(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.V = brainstate.ShortTermState(jnp.zeros(size))  # Dynamic
        self.spike = brainstate.ShortTermState(jnp.zeros(size))
```

### State Initialization

States can be initialized with various strategies:

```python
# Define example size and shape
size = 100  # Number of neurons
shape = (100, 50)  # Weight matrix shape

# Constant initialization
V = brainstate.ShortTermState(
    braintools.init.Constant(-65.0, unit=u.mV)(size)
)

# Normal distribution
V = brainstate.ShortTermState(
    braintools.init.Normal(-65.0, 5.0, unit=u.mV)(size)
)

# Uniform distribution
weights = brainstate.ParamState(
    braintools.init.Uniform(0.0, 1.0)(shape)
)
```
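The initializers above follow a configure-then-call pattern: the initializer object stores the distribution parameters, and calling it with a shape produces the array. A minimal NumPy sketch of that pattern, with hypothetical ``Toy*`` classes and units omitted (this is not how ``braintools`` implements it):

```python
import numpy as np


class ToyConstant:
    def __init__(self, value):
        self.value = value

    def __call__(self, shape):
        # Fill an array of the requested shape with a single value
        return np.full(shape, self.value, dtype=float)


class ToyNormal:
    def __init__(self, mean, std, seed=0):
        self.mean, self.std = mean, std
        self.rng = np.random.default_rng(seed)

    def __call__(self, shape):
        # Draw i.i.d. Gaussian samples of the requested shape
        return self.rng.normal(self.mean, self.std, size=shape)


class ToyUniform:
    def __init__(self, low, high, seed=0):
        self.low, self.high = low, high
        self.rng = np.random.default_rng(seed)

    def __call__(self, shape):
        # Draw samples from [low, high)
        return self.rng.uniform(self.low, self.high, size=shape)


V0 = ToyConstant(-65.0)(100)          # shape (100,), all equal to -65.0
V1 = ToyNormal(-65.0, 5.0)(100)       # shape (100,), centered near -65
W = ToyUniform(0.0, 1.0)((100, 50))   # shape (100, 50), values in [0, 1)
```

Deferring the shape lets one configured initializer be reused for populations or weight matrices of different sizes.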

## Module System

### Base Class: brainstate.nn.Module

All BrainPy components inherit from ``brainstate.nn.Module``:

```python
class MyComponent(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        # Initialize states
        self.state1 = brainstate.ShortTermState(jnp.zeros(size))
        self.param1 = brainstate.ParamState(jnp.ones(size))

    def update(self, input):
        # Define dynamics
        pass
```

Benefits of Module:

- Automatic state registration
- Nested module support
- State collection and filtering
- Serialization support
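To make automatic registration and filtering concrete, here is a toy re-implementation that discovers states by walking a module's attributes. The ``Toy*`` classes are hypothetical; ``brainstate``'s real mechanism is more complete, but the call shape mirrors ``network.states(brainstate.ParamState)`` used later in this document.

```python
import numpy as np


class ToyState:
    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)


class ToyParamState(ToyState):
    """Hypothetical stand-in for a trainable parameter."""


class ToyShortTermState(ToyState):
    """Hypothetical stand-in for a per-step dynamical variable."""


class ToyModule:
    def states(self, state_type=ToyState, prefix=""):
        """Recursively collect states of `state_type`, keyed by attribute path."""
        found = {}
        for name, attr in vars(self).items():
            path = f"{prefix}{name}"
            if isinstance(attr, state_type):
                found[path] = attr
            elif isinstance(attr, ToyModule):
                # Nested module: recurse with a dotted prefix
                found.update(attr.states(state_type, prefix=f"{path}."))
        return found


class ToyNeuron(ToyModule):
    def __init__(self, size):
        self.V = ToyShortTermState(np.zeros(size))
        self.tau = ToyParamState(10.0)


class ToyNet(ToyModule):
    def __init__(self):
        self.pop = ToyNeuron(4)
        self.w = ToyParamState(np.ones((4, 4)))


net = ToyNet()
print(sorted(net.states(ToyParamState)))
print(sorted(net.states(ToyShortTermState)))
```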

### Module Composition

Modules can contain other modules:

```python
class Network(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
        self.synapse = brainpy.state.Expon(100, tau=5*u.ms)
        self.projection = brainpy.state.AlignPostProj(...)  # Example - requires more setup

    def update(self, input):
        # Compose behavior: deliver the previous step's spikes, then update the neurons
        spikes = self.neurons.get_spike()
        self.projection(spikes)  # Example
        self.neurons(input)
```


## Component Architecture

### Neurons

Neurons model the dynamics of neural populations:

```python
class Neuron(brainstate.nn.Module):
    def __init__(self, size, **kwargs):
        super().__init__()
        # Membrane potential
        self.V = brainstate.ShortTermState(jnp.zeros(size))
        # Spike output
        self.spike = brainstate.ShortTermState(jnp.zeros(size))

    def update(self, input_current):
        # Update membrane potential
        # Generate spikes
        pass
```

Key responsibilities:

- Maintain membrane potential
- Generate spikes when threshold is crossed
- Reset after spiking
- Integrate input currents
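The four responsibilities above fit in a few lines. This NumPy sketch performs one forward-Euler step of the leaky integrate-and-fire equation τ dV/dt = -(V - V_rest) + R·I; the function name, parameter values, and the choice of Euler integration are assumptions for illustration, not how ``brainpy.state.LIF`` is implemented.

```python
import numpy as np


def lif_step(V, I, dt=0.1, tau=10.0, V_rest=-65.0, V_th=-50.0, V_reset=-65.0, R=10.0):
    """One forward-Euler step of tau * dV/dt = -(V - V_rest) + R * I.

    Assumed units: V in mV, I in nA, R in MOhm (so R * I is in mV), dt and tau in ms.
    """
    # 1. Integrate the input current
    V = V + dt / tau * (-(V - V_rest) + R * I)
    # 2. Spike where the threshold is crossed
    spike = V >= V_th
    # 3. Reset the neurons that spiked
    V = np.where(spike, V_reset, V)
    return V, spike


V = np.full(3, -65.0)
I = np.array([0.0, 1.0, 2.0])  # R * I = 0, 10, 20 mV above rest
n_spikes = np.zeros(3, dtype=int)
for _ in range(1000):  # 100 ms at dt = 0.1 ms
    V, spike = lif_step(V, I)
    n_spikes += spike

print(n_spikes)
```

With these values the first two neurons settle below threshold (at -65 mV and -55 mV), while the third is driven toward -45 mV and therefore fires repeatedly.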

### Synapses

Synapses model temporal filtering of spike trains:

```python
class Synapse(brainstate.nn.Module):
    def __init__(self, size, tau, **kwargs):
        super().__init__()
        # Synaptic conductance/current
        self.g = brainstate.ShortTermState(jnp.zeros(size))
        self.tau = tau

    def update(self, spike_input):
        # Update synaptic variable
        # Return filtered output
        pass
```

Key responsibilities:

- Filter spike inputs temporally
- Model synaptic dynamics (exponential, alpha, etc.)
- Provide smooth currents to postsynaptic neurons
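An exponential synapse follows dg/dt = -g/τ and jumps by the incoming spike input at each step. The sketch below (NumPy, hypothetical function name, unitless values) applies the exact decay factor exp(-dt/τ) once per step, which is one common discretization; ``brainpy.state.Expon`` may discretize differently.

```python
import numpy as np


def expon_step(g, spike_input, dt=0.1, tau=5.0):
    """Exact decay of dg/dt = -g / tau over one step, then add this step's input."""
    return g * np.exp(-dt / tau) + spike_input


g = np.zeros(1)
trace = []
for step in range(200):  # 20 ms at dt = 0.1 ms
    spike = np.array([1.0 if step == 0 else 0.0])  # a single spike at t = 0
    g = expon_step(g, spike)
    trace.append(g[0])

trace = np.array(trace)
# After one time constant (50 steps = 5 ms) the trace has decayed to 1/e of its peak
print(trace[0], trace[50])
```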

### Projections: The Comm-Syn-Out Pattern

Projections connect populations using a three-stage architecture:

```text
Presynaptic Spikes → [Comm] → [Syn] → [Out] → Postsynaptic Neurons
                      │         │       │
                  Connectivity  │    Current
                  & Weights   Dynamics  Injection
```

**Communication (Comm)**
   Handles spike transmission, connectivity, and weights.

```python
# Define population sizes
pre_size = 100
post_size = 50

# Connection probability and synaptic weight
prob = 0.1
weight = 0.5

comm = brainstate.nn.EventFixedProb(
    pre_size, post_size, prob, weight
)
```

**Synaptic Dynamics (Syn)**
   Temporal filtering of transmitted spikes.

```python
post_size = 50  # Postsynaptic population size

syn = brainpy.state.Expon.desc(post_size, tau=5*u.ms)
```

**Output Mechanism (Out)**
   How synaptic variables affect postsynaptic neurons.

```python
# Current-based output
out = brainpy.state.CUBA.desc()

# Or conductance-based output
out = brainpy.state.COBA.desc(E=0*u.mV)
```

**Complete Projection**

```python
# Define postsynaptic neurons
postsynaptic_neurons = brainpy.state.LIF(50, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)

# Create complete projection
projection = brainpy.state.AlignPostProj(
    comm=comm,
    syn=syn,
    out=out,
    post=postsynaptic_neurons
)
```

This separation provides:

- Clear responsibility boundaries
- Easy component swapping
- Reusable building blocks
- Better testing and debugging
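The three stages can be mimicked with plain functions. In this NumPy sketch, ``comm`` applies a fixed-probability weight matrix, ``syn`` is an exponential filter, and the two output functions contrast current-based (I = g) and conductance-based (I = g·(E - V)) coupling. All names and parameter values are hypothetical, and the sketch does not reproduce ``AlignPostProj``'s internals.

```python
import numpy as np

rng = np.random.default_rng(42)
pre_size, post_size, prob, weight = 100, 50, 0.1, 0.5
dt, tau = 0.1, 5.0  # ms

# Comm: fixed-probability connectivity with a uniform weight (pre x post)
W = np.where(rng.random((pre_size, post_size)) < prob, weight, 0.0)


def comm(pre_spikes):
    # Sum the weights of all active presynaptic neurons onto each target
    return pre_spikes @ W


def syn(g, syn_input):
    # Exponential decay plus the newly transmitted input
    return g * np.exp(-dt / tau) + syn_input


def out_cuba(g, V):
    # Current-based: the synaptic variable is the current (V unused)
    return g


def out_coba(g, V, E=0.0):
    # Conductance-based: current scales with the driving force (E - V)
    return g * (E - V)


pre_spikes = (rng.random(pre_size) < 0.2).astype(float)
g = syn(np.zeros(post_size), comm(pre_spikes))
V_post = np.full(post_size, -65.0)

I_cuba = out_cuba(g, V_post)
I_coba = out_coba(g, V_post)
print(I_cuba.shape, I_coba.shape)  # (50,) (50,)
```

Swapping ``out_cuba`` for ``out_coba`` leaves ``comm`` and ``syn`` untouched, which is the practical payoff of separating the stages.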

## Compilation and Execution

### Time-Stepped Simulation

BrainPy advances simulations in discrete time steps of size ``dt``, which is set globally through ``brainstate.environ``:

```python
# Example: create a simple network
class SimpleNetwork(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)

    def update(self, t, i):
        # Generate constant input current
        inp = jnp.ones(100) * 5.0 * u.nA
        with brainstate.environ.context(t=t, i=i):
            self.neurons(inp)
            return self.neurons.get_spike()

network = SimpleNetwork()
brainstate.nn.init_all_states(network)

# Set global time step
brainstate.environ.set(dt=0.1 * u.ms)

# Define simulation duration
times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)

# Run simulation
results = brainstate.transform.for_loop(
    network.update,
    times,
    indices,
    pbar=brainstate.transform.ProgressBar(10)
)
```

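``for_loop`` applies a per-step function along sequences (here ``times`` and ``indices``) and stacks the per-step outputs. That is the same pattern as ``jax.lax.scan``; the plain-JAX sketch below shows it with an explicit carry in place of ``brainstate`` states. The leaky-integrator model is hypothetical, and this is not a claim about how ``for_loop`` is implemented internally.

```python
import jax
import jax.numpy as jnp

dt = 0.1  # ms
n_steps = 1000  # 100 ms
indices = jnp.arange(n_steps)
times = indices * dt


def step(V, xs):
    t, i = xs  # current time and step index (unused by this simple model)
    # Hypothetical leaky integrator, tau = 10 ms, driven by a constant input of 1
    V = V + dt / 10.0 * (-V + 1.0)
    return V, V  # (carry for the next step, output collected at this step)


V0 = jnp.zeros(3)
V_final, V_trace = jax.lax.scan(step, V0, (times, indices))
print(V_trace.shape)  # (1000, 3)
```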

### JIT Compilation

Functions are compiled for performance:

```python
@brainstate.transform.jit
def simulate_step(t, i):
    # network.update sets the environment context (t, i) itself
    return network.update(t, i)

# First call: trace and compile (slow)
result = simulate_step(0.0*u.ms, 0)

# Subsequent calls with the same input shapes and dtypes reuse the compiled code (fast)
result = simulate_step(0.1*u.ms, 1)
```

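Compilation benefits:

- Removes Python interpreter overhead from every step (the actual speedup depends on model size and hardware)
- Runs on whichever device JAX is configured to use (CPU, GPU, or TPU)
- Lets XLA plan memory and reuse buffers
- Fuses element-wise operations into fewer kernels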
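The compile-once behavior can be observed directly in plain JAX: Python side effects inside a jitted function run only while JAX traces it, so a counter reveals when recompilation happens. The function and counter are hypothetical illustrations, not part of ``brainstate``.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def simulate_step(V, I):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return V + 0.01 * (-V + I)


I = jnp.ones(100)
V = simulate_step(jnp.zeros(100), I)  # first call: trace and compile
V = simulate_step(jnp.ones(100), I)   # same shapes/dtypes: compiled code is reused
after_two_calls = trace_count
print(after_two_calls)  # 1

simulate_step(jnp.zeros(200), jnp.ones(200))  # new shape: traced and compiled again
print(trace_count)  # 2
```

Calls with new input shapes or dtypes trigger a fresh trace, so keeping shapes fixed across time steps avoids repeated compilation.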

### Gradient Computation

For training, gradients are computed automatically:

```python
# Example: define a simple loss for demonstration
def compute_loss(predictions, targets):
    return jnp.mean((predictions.astype(float) - targets) ** 2)

# Mock targets: no spikes expected at any of the 100 steps
num_steps = 100
targets = jnp.zeros((num_steps, 100))

def loss_fn():
    # Run the network for multiple time steps
    def step(t, i):
        return network.update(t, i)  # update() sets the environment context itself

    times = u.math.arange(0*u.ms, num_steps*brainstate.environ.get_dt(), brainstate.environ.get_dt())
    indices = u.math.arange(times.size)
    predictions = brainstate.transform.for_loop(step, times, indices)
    return compute_loss(predictions, targets)

# Get trainable parameters
params = network.states(brainstate.ParamState)

# Compute gradients
if len(params) > 0:
    optimizer = braintools.optim.Adam(lr=1e-3)
    optimizer.register_trainable_weights(params)  # tell the optimizer which states to update
    grads, loss = brainstate.transform.grad(
        loss_fn,
        grad_states=params,
        return_value=True
    )()
    print(f"Loss: {loss}")
    # Apply one optimizer step to the registered parameters
    optimizer.update(grads)
else:
    # If there are no trainable parameters, just compute the loss
    loss = loss_fn()
    print(f"Loss (no trainable params): {loss}")
```

```text
Loss (no trainable params): 0.0
```
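Gradients flow through the whole simulated trajectory. The plain-JAX sketch below differentiates a loss with respect to a time constant τ after 200 scanned steps and checks the result against a central finite difference. The model, target, and learning rate are hypothetical; in ``brainstate`` the parameter would typically be a ``ParamState`` selected via ``grad_states``.

```python
import jax
import jax.numpy as jnp

dt, n_steps = 0.1, 200  # 20 ms of simulated time


def simulate(tau):
    """Leaky integrator tau * dV/dt = -V + 1, forward Euler, V(0) = 0."""
    def step(V, _):
        V = V + dt / tau * (-V + 1.0)
        return V, V

    _, trace = jax.lax.scan(step, jnp.zeros(()), None, length=n_steps)
    return trace


def loss_fn(tau):
    # Hypothetical target: V should equal 0.5 at the last step
    return (simulate(tau)[-1] - 0.5) ** 2


tau = 10.0
loss, grad_tau = jax.value_and_grad(loss_fn)(tau)

# Sanity check against a central finite difference
eps = 1e-2
fd = (loss_fn(tau + eps) - loss_fn(tau - eps)) / (2 * eps)
print(loss, grad_tau, fd)

# One plain gradient-descent step on tau (the learning rate is arbitrary)
tau = tau - 10.0 * grad_tau
```

Here V overshoots the target, and a larger τ slows the rise, so the gradient with respect to τ is negative and the update increases τ.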


## Physical Units System

### Integration with brainunit

``brainpy.state`` integrates ``brainunit`` for scientific accuracy:

```python
# Define with units
tau = 10 * u.ms
threshold = -50 * u.mV
current = 5 * u.nA

# Units are checked automatically
neuron = brainpy.state.LIF(100, tau=tau, V_th=threshold)
```

Benefits:

- Prevents unit errors (e.g., ms vs s)
- Self-documenting code
- Automatic unit conversions
- Scientific correctness

### Unit Operations

```python
# Arithmetic with compatible units converts automatically
total_time = 100 * u.ms + 0.5 * u.second  # equals 600 ms (0.6 s)

# Unit conversion
time_in_seconds = (100 * u.ms).to_decimal(u.second)  # → 0.1

# Derived units: voltage / current has the dimension of resistance
voltage = -65 * u.mV
current = 2 * u.nA
resistance = voltage / current  # -32.5 mV/nA, i.e. -32.5 MΩ
```
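Unit checking amounts to carrying a dimension vector alongside each value. The toy class below is a hypothetical illustration (it is not how ``brainunit`` is implemented): addition requires matching dimensions, division subtracts exponents, and values are stored in SI base units so that ms and s combine correctly.

```python
class ToyQuantity:
    """Hypothetical value-plus-dimension pair.

    ``dims`` holds the exponents of the SI base units (m, kg, s, A);
    ``value`` is stored in SI base units.
    """

    def __init__(self, value, dims):
        self.value = value
        self.dims = tuple(dims)

    def __add__(self, other):
        # Addition is only defined for identical dimensions
        if self.dims != other.dims:
            raise TypeError(f"dimension mismatch: {self.dims} vs {other.dims}")
        return ToyQuantity(self.value + other.value, self.dims)

    def __truediv__(self, other):
        # Division subtracts dimension exponents
        dims = tuple(a - b for a, b in zip(self.dims, other.dims))
        return ToyQuantity(self.value / other.value, dims)


SECOND = (0, 0, 1, 0)
AMPERE = (0, 0, 0, 1)
VOLT = (2, 1, -3, -1)  # kg·m²·s⁻³·A⁻¹
OHM = (2, 1, -3, -2)   # V / A


def ms(x):
    return ToyQuantity(x * 1e-3, SECOND)


def mV(x):
    return ToyQuantity(x * 1e-3, VOLT)


def nA(x):
    return ToyQuantity(x * 1e-9, AMPERE)


total = ms(100) + ToyQuantity(0.5, SECOND)  # 0.6 s, i.e. 600 ms
R = mV(-65) / nA(2)                         # about -3.25e7 Ω, i.e. -32.5 MΩ
print(total.value, R.value, R.dims == OHM)

try:
    ms(100) + mV(-65)  # time + voltage is rejected
except TypeError as err:
    print("caught:", err)
```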

## Ecosystem Integration

``brainpy.state`` integrates tightly with its ecosystem:

### braintools

Utilities and tools:

```python
# Optimizers
optimizer = braintools.optim.Adam(lr=1e-3)

# Initializers
init = braintools.init.KaimingNormal()

# Surrogate gradients
spike_fn = braintools.surrogate.ReluGrad()

# Metrics (example with dummy data)
# pred = jnp.array([0.1, 0.9])
# target = jnp.array([0, 1])
# loss = braintools.metric.cross_entropy(pred, target)
```

### brainunit

Physical units:

```python
# All standard SI units
time = 10 * u.ms
voltage = -65 * u.mV
current = 2 * u.nA
```

### brainstate

Core framework (used automatically):

```python
import brainstate

# Module system
class Net(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        pass

# Compilation
@brainstate.transform.jit
def fn():
    return 0

# Transformations
# result = brainstate.transform.for_loop(...)
```

## Data Flow Example

Here's how data flows through a typical ``brainpy.state`` simulation:

```python
# 1. Define network
class EINetwork(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.E = brainpy.state.LIF(800, V_rest=-65*u.mV, V_th=-50*u.mV, tau=15*u.ms)
        self.I = brainpy.state.LIF(200, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)

        # Example projections (simplified - full setup requires more code)
        # self.E2E = brainpy.state.AlignPostProj(...)
        # self.E2I = brainpy.state.AlignPostProj(...)
        # self.I2E = brainpy.state.AlignPostProj(...)
        # self.I2I = brainpy.state.AlignPostProj(...)

    def update(self, inp):
        # `inp` holds external currents for all 1000 neurons: first 800 → E, last 200 → I
        # Get spikes from the last time step
        e_spikes = self.E.get_spike()
        i_spikes = self.I.get_spike()

        # Update projections (spikes → synaptic currents)
        # self.E2E(e_spikes)  # Updates E2E.syn.g
        # self.E2I(e_spikes)
        # self.I2E(i_spikes)
        # self.I2I(i_spikes)

        # Update neurons (currents → new V and spikes)
        self.E(inp[:800])
        self.I(inp[800:])

        # Return the spikes produced at this time step
        return self.E.get_spike(), self.I.get_spike()

# 2. Initialize
net = EINetwork()
brainstate.nn.init_all_states(net)

# 3. Compile one simulation step
@brainstate.transform.jit
def step(t, i, inp):
    with brainstate.environ.context(t=t, i=i):
        return net.update(inp)

# 4. Simulate (commented out for quick execution)
# times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
# indices = u.math.arange(times.size)
# inputs = jnp.ones((times.size, 1000)) * 2.0 * u.nA
# results = brainstate.transform.for_loop(step, times, indices, inputs)
```

State Flow:

```text
Time t:
┌──────────────────────────────────────────┐
│  States at t-1:                          │
│    E.V[t-1], E.spike[t-1]                │
│    I.V[t-1], I.spike[t-1]                │
│    E2E.syn.g[t-1], ...                   │
└──────────────────────────────────────────┘
                ↓
┌──────────────────────────────────────────┐
│  Projection Updates:                     │
│    E2E.syn.g[t] = f(g[t-1], E.spike[t-1])│
│    ... (other projections)               │
└──────────────────────────────────────────┘
                ↓
┌──────────────────────────────────────────┐
│  Neuron Updates:                         │
│    E.V[t] = f(V[t-1], Σ currents[t])     │
│    E.spike[t] = E.V[t] >= V_th           │
│    ... (other neurons)                   │
└──────────────────────────────────────────┘
                ↓
Time t+1...
```
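The ordering in the diagram (projections read spikes from t-1, then neurons update) can be sketched without any framework. This NumPy version uses four fixed-probability projections with exponential synapses and a forward-Euler leaky integrate-and-fire update; sizes, weights, and the lumped external drive (in mV, i.e. R·I) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_E, n_I = 80, 20
dt, tau_syn, tau_m = 0.1, 5.0, 10.0  # ms
V_rest, V_th = -65.0, -50.0          # mV
ext = 20.0                           # hypothetical external drive R*I, in mV


def fixed_prob(n_pre, n_post, prob, w):
    # Hypothetical fixed-probability connectivity (pre x post) with a uniform weight
    return np.where(rng.random((n_pre, n_post)) < prob, w, 0.0)


# Inhibitory projections carry negative weights
W = {"E2E": fixed_prob(n_E, n_E, 0.1, 0.6), "E2I": fixed_prob(n_E, n_I, 0.1, 0.6),
     "I2E": fixed_prob(n_I, n_E, 0.1, -6.7), "I2I": fixed_prob(n_I, n_I, 0.1, -6.7)}
g = {name: np.zeros(w.shape[1]) for name, w in W.items()}

V_E, V_I = np.full(n_E, V_rest), np.full(n_I, V_rest)
s_E, s_I = np.zeros(n_E), np.zeros(n_I)  # spikes from the previous step


def lif(V, I):
    # Forward-Euler LIF step with reset to V_rest
    V = V + dt / tau_m * (-(V - V_rest) + I)
    spike = V >= V_th
    return np.where(spike, V_rest, V), spike.astype(float)


rates = []
for _ in range(1000):  # 100 ms
    # 1. Projection updates: g[t] = f(g[t-1], spikes[t-1])
    pre = {"E2E": s_E, "E2I": s_E, "I2E": s_I, "I2I": s_I}
    for name in g:
        g[name] = g[name] * np.exp(-dt / tau_syn) + pre[name] @ W[name]
    # 2. Neuron updates: V[t] = f(V[t-1], summed currents[t]), then spikes[t]
    V_E, s_E = lif(V_E, g["E2E"] + g["I2E"] + ext)
    V_I, s_I = lif(V_I, g["E2I"] + g["I2I"] + ext)
    rates.append((s_E.mean(), s_I.mean()))
```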

## Performance Considerations

### Memory Management

- States are preallocated
- In-place updates when possible
- Efficient batching support
- Automatic garbage collection

### Compilation Strategy

- Compile simulation loops
- Batch operations when possible
- Use ``for_loop`` instead of a Python loop for long sequences, so the loop is compiled once rather than unrolled
- Leverage JAX's XLA optimization

### Hardware Acceleration

- Runs on GPU when a GPU-enabled JAX installation is available
- Supports TPUs through JAX
- Falls back to CPU when no accelerator is present

## Summary

``brainpy.state``'s architecture provides:

✅ **Clear Abstractions**: Neurons, synapses, and projections with well-defined roles

✅ **State Management**: Explicit, efficient handling of dynamical variables

✅ **Modularity**: Compose complex models from simple components

✅ **Performance**: JIT compilation and hardware acceleration

✅ **Scientific Accuracy**: Integrated physical units

✅ **Extensibility**: Easy to add custom components

✅ **Modern Design**: Built on proven frameworks (JAX, brainstate)

## Next Steps

- Learn about specific components: [neurons](neurons.ipynb), [synapses](synapses.ipynb), [projections](projections.ipynb)
- Understand state management in depth: [state-management](state-management.ipynb)
- See practical examples in the [tutorials](../tutorials/basic/01-lif-neuron.ipynb)
- Explore the ecosystem: [brainstate docs](https://brainstate.readthedocs.io/)