# Overview

`brainpy.state` introduces a modern, state-based architecture built on top of `brainstate`. This overview will help you understand the key concepts and design philosophy.

## What's New

`brainpy.state` has been completely rewritten to provide:

- **State-based programming**: Built on `brainstate` for efficient state management
- **Modular architecture**: Clear separation of concerns (communication, dynamics, outputs)
- **Physical units**: Integration with `brainunit` for scientifically accurate simulations
- **Modern API**: Cleaner, more intuitive interfaces
- **Better performance**: Optimized JIT compilation and memory management

## Key Architectural Components

`brainpy.state` is organized around several core concepts:

### 1. State Management

Everything in `brainpy.state` revolves around **states**. States are variables that persist across time steps:

- `brainstate.State`: Base state container
- `brainstate.ParamState`: Trainable parameters
- `brainstate.ShortTermState`: Temporary variables

States enable:

- Automatic differentiation for training
- Efficient memory management
- Batching and parallelization
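
As a rough mental model (plain Python, not how `brainstate` is implemented), a state can be pictured as a mutable box whose value a step function overwrites on every time step. The `Box` class and `step` function below are hypothetical names used only for illustration.

```python
class Box:
    """Hypothetical stand-in for a state container: it holds one mutable value."""

    def __init__(self, value):
        self.value = value


def step(v, decay):
    """Advance one time step by overwriting the state's value in place."""
    v.value = v.value * decay.value


v = Box(1.0)       # plays the role of a variable that changes every step
decay = Box(0.5)   # plays the role of a parameter that stays fixed during simulation

for _ in range(3):
    step(v, decay)

print(v.value)  # 0.125
```

Because the values live in explicit containers rather than in local variables, a framework can find them, differentiate with respect to them, or replicate them across a batch.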

### 2. Neurons

Neurons are the fundamental computational units:

```python
import brainpy
import brainunit as u

# Create a population of 100 LIF neurons
neurons = brainpy.state.LIF(100, tau=10 * u.ms, V_th=-50 * u.mV)
```

Key neuron models:

- `brainpy.state.IF`: Integrate-and-Fire
- `brainpy.state.LIF`: Leaky Integrate-and-Fire
- `brainpy.state.LIFRef`: LIF with refractory period
- `brainpy.state.ALIF`: Adaptive LIF
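
To make the LIF dynamics concrete, here is a minimal plain-Python sketch of the textbook leaky integrate-and-fire equation, integrated with forward Euler. It is not the `brainpy.state.LIF` implementation: only `tau` and `v_th` come from the example above, while the rest potential, reset potential, and the drive expressed directly in mV are assumptions chosen for illustration.

```python
def simulate_lif(drive_mv, n_steps=1000, dt=0.1, tau=10.0,
                 v_rest=-65.0, v_reset=-65.0, v_th=-50.0):
    """Forward-Euler LIF: tau * dV/dt = -(V - v_rest) + drive (times in ms, voltages in mV)."""
    v = v_rest
    spike_steps = []
    for i in range(n_steps):
        v += (dt / tau) * (-(v - v_rest) + drive_mv)
        if v >= v_th:
            # Threshold crossed: record the step index and reset the membrane
            spike_steps.append(i)
            v = v_reset
    return spike_steps


strong = simulate_lif(drive_mv=20.0)  # steady state -45 mV, above threshold
weak = simulate_lif(drive_mv=10.0)    # steady state -55 mV, below threshold
print(len(strong), len(weak))
```

The strong drive produces regular spiking, while the weak drive settles below threshold and never fires.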

### 3. Synapses

Synapses model the dynamics of neural connections:

```python
# Exponential synapse: instantaneous rise, exponential decay
synapse = brainpy.state.Expon(100, tau=5 * u.ms)

# Alpha synapse: finite rise time, so the response builds up smoothly
synapse = brainpy.state.Alpha(100, tau=5 * u.ms)
```

Synapse models:

- `brainpy.state.Expon`: Single exponential decay
- `brainpy.state.Alpha`: Alpha function (smooth rise and decay governed by a single time constant)
- `brainpy.state.AMPA`: Excitatory receptor dynamics
- `brainpy.state.GABAa`: Inhibitory receptor dynamics
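
The difference between the first two models is easiest to see in their responses to a single spike at `t = 0`. The sketch below uses the standard closed-form kernels in plain Python. The peak normalisation of the alpha kernel is an assumption made for illustration, and the library's own scaling may differ.

```python
import math


def expon_kernel(t, tau):
    """Single-exponential response: jumps to 1 at t = 0, then decays with tau."""
    return math.exp(-t / tau) if t >= 0 else 0.0


def alpha_kernel(t, tau):
    """Alpha-function response: starts at 0, peaks at 1 when t == tau, then decays."""
    return (t / tau) * math.exp(1.0 - t / tau) if t >= 0 else 0.0


tau = 5.0  # ms
for t in (0.0, 5.0, 20.0):
    print(t, round(expon_kernel(t, tau), 3), round(alpha_kernel(t, tau), 3))
```

The exponential kernel is largest at the moment of the spike, whereas the alpha kernel rises first and reaches its maximum one time constant later.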

### 4. Projections

Projections connect neural populations:

```python
import brainstate

N_pre = 100
N_post = 50
prob = 0.1
weight = 0.5

projection = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(N_pre, N_post, prob, weight),
    syn=brainpy.state.Expon.desc(N_post, tau=5 * u.ms),
    out=brainpy.state.CUBA.desc(),
    post=neurons,
)
```

The projection architecture separates:

- **Communication**: How spikes are transmitted (connectivity, weights)
- **Synaptic dynamics**: How synapses respond (temporal filtering)
- **Output mechanism**: How synaptic currents affect neurons (CUBA/COBA)
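
The three stages can be sketched in plain Python as three small functions applied in sequence. This is only an illustration of the separation of concerns, not the `AlignPostProj` implementation. All function names are hypothetical, and the update order inside `expon_update` (decay, then add input) is an assumption.

```python
import math
import random


def fixed_prob_weights(n_pre, n_post, prob, weight, seed=0):
    """Communication setup: connect each pre/post pair with probability `prob`."""
    rng = random.Random(seed)
    return [[weight if rng.random() < prob else 0.0 for _ in range(n_post)]
            for _ in range(n_pre)]


def communicate(weights, pre_spikes):
    """Communication: for each target, sum the weights of the presynaptic neurons that spiked."""
    n_post = len(weights[0])
    return [sum(row[j] for row, s in zip(weights, pre_spikes) if s)
            for j in range(n_post)]


def expon_update(g, inp, dt, tau):
    """Synaptic dynamics: exponential decay over one step, then add the new input."""
    decay = math.exp(-dt / tau)
    return [gi * decay + x for gi, x in zip(g, inp)]


def cuba(g):
    """Current-based output: the synaptic variable is used directly as the current."""
    return list(g)


def coba(g, v, e_rev):
    """Conductance-based output: the current scales with the driving force (e_rev - v)."""
    return [gi * (e_rev - vi) for gi, vi in zip(g, v)]


weights = fixed_prob_weights(n_pre=100, n_post=50, prob=0.1, weight=0.5)
g = [0.0] * 50
pre_spikes = [i % 10 == 0 for i in range(100)]  # every tenth neuron fires
g = expon_update(g, communicate(weights, pre_spikes), dt=0.1, tau=5.0)
current = cuba(g)
print(len(current))  # 50
```

Swapping `cuba` for `coba` changes how the synaptic variable drives the postsynaptic neuron without touching connectivity or synaptic dynamics, which is the point of keeping the stages separate.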

### 5. Networks

Networks combine neurons and projections:

```python
class EINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.E = brainpy.state.LIF(800)
        self.I = brainpy.state.LIF(200)
        self.E2E = brainpy.state.AlignPostProj(...)
        self.E2I = brainpy.state.AlignPostProj(...)
        # ... more projections

    def update(self, input):
        # Define network dynamics
        pass
```

## Computational Model

### Time-Stepped Simulation

BrainPy uses discrete time steps for simulation:

```python
# Set simulation time step
brainstate.environ.set(dt=0.1 * u.ms)

# Create a simple neuron for demonstration
neurons = brainpy.state.LIF(100, tau=10 * u.ms, V_th=-50 * u.mV)

# Initialize all states
brainstate.nn.init_all_states(neurons)

# One simulation step: advance the neurons and return their spikes
def step(t, i):
    with brainstate.environ.context(t=t, i=i):
        # Provide input current to the neurons
        neurons.update(5 * u.nA)
        return neurons.get_spike()

# Run the step function over 1000 ms of simulated time
times = u.math.arange(0 * u.ms, 1000 * u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)
results = brainstate.transform.for_loop(step, times, indices)
```
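
Conceptually, the loop above calls the step function once per time step and collects the outputs in order. The plain-Python sketch below mirrors that behaviour with a hypothetical `for_loop` helper. The real `brainstate.transform.for_loop` additionally compiles the loop and returns stacked arrays, which this sketch does not attempt.

```python
def for_loop(step_fn, *sequences):
    """Call step_fn once per time step and collect the outputs in order."""
    return [step_fn(*args) for args in zip(*sequences)]


dt = 0.1         # ms
duration = 1000  # ms
n_steps = round(duration / dt)            # number of discrete steps
times = [i * dt for i in range(n_steps)]
indices = list(range(n_steps))

# Toy step function: reports whether the step index is a multiple of 1000
outputs = for_loop(lambda t, i: i % 1000 == 0, times, indices)
print(n_steps, sum(outputs))  # 10000 10
```

With `dt = 0.1 ms`, one second of simulated time corresponds to 10,000 calls of the step function, which is why compiling the loop matters.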

### JIT Compilation

BrainPy leverages JAX for Just-In-Time compilation:

```python
# Create a simple network for demonstration
network = brainpy.state.LIF(100, tau=10 * u.ms, V_th=-50 * u.mV)
brainstate.nn.init_all_states(network)

# Define input current
input_current = 5 * u.nA

# JIT-compiled simulation step
@brainstate.transform.jit
def simulate(t, i):
    with brainstate.environ.context(t=t, i=i):
        network.update(input_current)
        return network.get_spike()

# The step is traced and compiled once, then reused for every time step
times = u.math.arange(0 * u.ms, 100 * u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)
result = brainstate.transform.for_loop(simulate, times, indices)
```

Benefits:

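- Compiled execution through XLA, which removes per-step Python interpreter overhead
- The same code runs on CPU, GPU, or TPU without modification
- Memory use optimized by the compiler (for example, through operation fusion)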
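
For readers who want to see the underlying mechanism, the sketch below uses plain JAX (`jax.jit` and `jax.lax.scan`) rather than `brainstate`. It is an assumption-laden illustration, not the library's code path: the LIF update is the textbook Euler step, and the rest and reset potentials and the 20 mV drive are made up for the example.

```python
import jax
import jax.numpy as jnp

DT, TAU = 0.1, 10.0                          # ms
V_REST, V_RESET, V_TH = -65.0, -65.0, -50.0  # mV (rest and reset are assumed)


def lif_step(v, drive):
    """One Euler step for the whole population; returns (new carry, per-step output)."""
    v = v + (DT / TAU) * (-(v - V_REST) + drive)
    spike = v >= V_TH
    v = jnp.where(spike, V_RESET, v)
    return v, spike


@jax.jit
def simulate(v0, drives):
    # lax.scan traces the loop body once instead of unrolling it in Python
    return jax.lax.scan(lif_step, v0, drives)


v0 = jnp.full((100,), V_REST)
drives = jnp.full((1000, 100), 20.0)    # 1000 steps of constant 20 mV drive
v_final, spikes = simulate(v0, drives)  # first call traces and compiles
v_final, spikes = simulate(v0, drives)  # same shapes: reuses the compiled code
print(spikes.shape)  # (1000, 100)
```

The second call skips tracing because the argument shapes and dtypes match the first, which is where most of the speed-up in repeated simulations comes from.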

### Physical Units

`brainpy.state` integrates `brainunit` for scientific accuracy:

```python
import brainunit as u

# Define parameters with units
tau = 10 * u.ms
V_threshold = -50 * u.mV
current = 5 * u.nA

# Units are checked automatically
neurons = brainpy.state.LIF(100, tau=tau, V_th=V_threshold)
```

This prevents unit-related bugs and makes code self-documenting.
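
The idea behind automatic unit checking can be sketched with a tiny, hypothetical `Quantity` class in plain Python. This is not how `brainunit` is implemented. It only shows the principle: quantities with the same dimension are reconciled to a common scale, and quantities with different dimensions are rejected.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Quantity:
    """Hypothetical quantity: a magnitude in base SI units plus a dimension tag."""
    si_value: float
    dim: str

    def __add__(self, other):
        if self.dim != other.dim:
            raise TypeError(f"cannot add {self.dim} to {other.dim}")
        return Quantity(self.si_value + other.si_value, self.dim)


def ms(x):
    return Quantity(x * 1e-3, "time")


def second(x):
    return Quantity(float(x), "time")


def mV(x):
    return Quantity(x * 1e-3, "voltage")


total = ms(10) + second(1)   # same dimension: both are stored in seconds
print(total.si_value)        # approximately 1.01

try:
    ms(10) + mV(-50)         # different dimensions: rejected
except TypeError as err:
    print("rejected:", err)
```

A mistake such as adding a time constant to a voltage fails immediately instead of silently producing a wrong number.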

## Training and Learning

`brainpy.state` supports gradient-based training:

```python
import braintools

# Create a simple network for training
net = brainpy.state.LIF(10, tau=10 * u.ms, V_th=-50 * u.mV)
brainstate.nn.init_all_states(net)

# Define optimizer (only ParamState instances are registered for training)
optimizer = braintools.optim.Adam(lr=1e-3)
optimizer.register_trainable_weights(net.states(brainstate.ParamState))

# Prepare dummy data for demonstration
num_steps = 100
inputs = u.math.ones((num_steps,)) * 5 * u.nA
targets = u.math.zeros((num_steps, 10))  # dummy target

# Define loss function
def loss_fn():
    def step(t, i, inp):
        with brainstate.environ.context(t=t, i=i):
            net.update(inp)
            return net.get_spike()

    dt = brainstate.environ.get_dt()
    times = u.math.arange(0 * u.ms, num_steps * dt, dt)
    indices = u.math.arange(times.size)
    predictions = brainstate.transform.for_loop(step, times, indices, inputs)
    # Simple MSE loss
    return u.math.mean((predictions.astype(float) - targets) ** 2)

# Training step
@brainstate.transform.jit
def train_step():
    grads, loss_value = brainstate.transform.grad(
        loss_fn,
        net.states(brainstate.ParamState),
        return_value=True,
    )()
    optimizer.update(grads)
    return loss_value
```

Key features:

- Surrogate gradients for spiking neurons
- Automatic differentiation
- Various optimizers (Adam, SGD, etc.)
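
Surrogate gradients address the fact that a spike is a step function, whose true derivative is zero almost everywhere. The sketch below shows one common choice in plain Python: the forward pass keeps the hard threshold, while the backward pass substitutes the derivative of a fast sigmoid. Which surrogate `brainpy.state` uses by default is not stated here, and the function names and `beta` value are assumptions for illustration.

```python
def heaviside(x):
    """Forward pass: emit a spike (1.0) when the membrane is at or above threshold."""
    return 1.0 if x >= 0.0 else 0.0


def surrogate_grad(x, beta=10.0):
    """Backward-pass stand-in: derivative of the fast sigmoid x / (1 + beta * |x|)."""
    return 1.0 / (1.0 + beta * abs(x)) ** 2


# x is the distance from threshold (V - V_th), in arbitrary units
for x in (-1.0, 0.0, 1.0):
    print(x, heaviside(x), surrogate_grad(x))
```

The surrogate is largest at the threshold and falls off on both sides, so neurons close to firing receive the strongest learning signal.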

## Ecosystem Components

`brainpy.state` is part of a larger ecosystem:

### brainstate

The foundation for state management and compilation:

- State-based IR construction
- JIT compilation
- Program augmentation (batching, etc.)

### brainunit

Physical units system:

- SI units support
- Automatic unit checking
- Unit conversions

### braintools

Utilities and tools:

- Optimizers (`braintools.optim`)
- Initialization (`braintools.init`)
- Metrics and losses (`braintools.metric`)
- Surrogate gradients (`braintools.surrogate`)
- Visualization (`braintools.visualize`)

## Design Philosophy

`brainpy.state` follows these principles:

1. **Explicit over implicit**: Clear, readable code
2. **Modular composition**: Build complex models from simple components
3. **Performance by default**: JIT compilation and optimization built-in
4. **Scientific accuracy**: Physical units and biologically realistic models
5. **Extensibility**: Easy to add custom components

## Next Steps

Now that you understand the core concepts:

- Try the [5-minute tutorial](5min-tutorial.ipynb) to get hands-on experience
- Read the detailed [core concepts](../core-concepts/index.rst) documentation
- Explore [basic tutorials](../tutorials/index.rst) to learn each component
- Check out the [examples gallery](../examples/gallery.rst) for real-world models