# Projections

Projections are the mechanism `brainpy.state` uses to connect neural populations.
They implement the **Communication-Synapse-Output (Comm-Syn-Out)** architecture,
which separates connectivity, synaptic dynamics, and output computation into modular components.


This guide explains what projections are, how each of their three layers works, and how to combine them into complete networks in `brainpy.state`.


## Overview

### What are Projections?

A **projection** connects a presynaptic population to a postsynaptic population through:

1. **Communication (Comm)**: How spikes propagate through connections
2. **Synapse (Syn)**: Temporal filtering and synaptic dynamics
3. **Output (Out)**: How synaptic currents affect postsynaptic neurons


**Key benefits:**

- Modular design (swap components independently)
- Biologically realistic (separate connectivity and dynamics)
- Efficient (optimized sparse operations)
- Flexible (combine components in different ways)


### The Comm-Syn-Out Architecture

```python
import brainstate
import braintools
import brainunit as u
import numpy as np

import brainpy

# Global simulation time step
brainstate.environ.set(dt=0.1 * u.ms)
```

```text
Presynaptic        Communication         Synapse          Output        Postsynaptic
Population    ──►  (Connectivity)  ──►  (Dynamics)  ──►  (Current) ──►  Population

Spikes        ──►  Weight matrix   ──►  g(t)        ──►  I_syn     ──►  Neurons
                   Sparse/Dense         Expon/Alpha     CUBA/COBA
```

**Flow:**

1. Presynaptic spikes arrive
2. Communication: Spikes propagate through connectivity matrix
3. Synapse: Temporal dynamics filter the signal
4. Output: The filtered conductance is converted into a postsynaptic current
5. Postsynaptic neurons receive input
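The five steps can be traced in a few lines of NumPy. This is a minimal sketch of one simulation step for a dense projection with an exponential synapse and a conductance-based output, not the `brainpy.state` implementation; the population sizes, weights, and the exact-exponential decay used for the synapse are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post = 5, 3
dt, tau = 0.1, 5.0                                  # ms
E_syn = 0.0                                         # mV, excitatory reversal potential

W = rng.uniform(0.0, 0.5, size=(n_pre, n_post))     # Communication: weight matrix
g = np.zeros(n_post)                                # Synapse: one conductance per post neuron
V = np.full(n_post, -65.0)                          # postsynaptic membrane potentials (mV)

spikes = np.array([1, 0, 1, 0, 0], dtype=float)     # 1. presynaptic spikes arrive

syn_input = spikes @ W                              # 2. Communication: propagate through W
g = g * np.exp(-dt / tau) + syn_input               # 3. Synapse: decay, then add new input
I_syn = g * (E_syn - V)                             # 4. Output: conductance -> current (COBA)
# 5. I_syn would now be added to the postsynaptic neurons' input

print(I_syn)
```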



### Types of Projections

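BrainPy provides two main projection types:

**AlignPostProj**
   - Synaptic state variables are aligned with (stored per) postsynaptic neuron
   - Weighted presynaptic events are summed by the communication layer, then filtered once per postsynaptic neuron
   - Memory for synaptic state scales with the postsynaptic population size
   - Most common choice for standard networks with linear synapses such as `Expon`

**AlignPreProj**
   - Synaptic state variables are aligned with (stored per) presynaptic neuron
   - Synaptic dynamics run before the communication layer, so they can depend on each presynaptic neuron's own spike history
   - Useful for certain learning rules and presynaptically driven dynamics

For most use cases, use `AlignPostProj`.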
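Post-alignment relies on linearity: for a linear synapse, filtering each synapse separately and then summing gives the same result as summing the weighted spikes first and filtering once per postsynaptic neuron. The NumPy sketch below checks this numerically for an exponential synapse; it illustrates the principle and is not library code. The shortcut is exact only for linear synaptic dynamics.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pre, n_post, n_steps = 20, 4, 200
decay = np.exp(-0.1 / 5.0)               # exp(-dt / tau) with dt = 0.1 ms, tau = 5 ms

W = rng.uniform(0.0, 1.0, size=(n_pre, n_post))
spikes = (rng.random((n_steps, n_pre)) < 0.05).astype(float)

# Per-synapse state: one trace per (pre, post) pair, filtered then summed
g_syn = np.zeros((n_pre, n_post))
# Post-aligned state: one trace per postsynaptic neuron, summed then filtered
g_post = np.zeros(n_post)

for s in spikes:
    g_syn = g_syn * decay + s[:, None] * W
    g_post = g_post * decay + s @ W

print(np.allclose(g_syn.sum(axis=0), g_post))  # True, with n_pre times less state
```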

## Communication Layer

The Communication layer defines **how spikes propagate** through connections.

### Dense Connectivity

Every presynaptic neuron is connected to every postsynaptic neuron through a full weight matrix (individual weights may still be zero).

**Use case:** Small networks, fully connected layers

```python
# Dense linear transformation
comm = brainstate.nn.Linear(
    100,  # in_size
    50,  # out_size
    w_init=braintools.init.KaimingNormal(),
    b_init=None  # No bias for synapses
)
```

**Characteristics:**

- Memory: O(n_pre × n_post)
- Computation: Full matrix multiplication
- Best for: Small networks, fully connected architectures

### Sparse Connectivity

Only a subset of possible connections exists, as in biological circuits.

**Use case:** Large networks, biological connectivity patterns

#### Event-Based Fixed Probability

Connect neurons with fixed probability.

```python
# Sparse random connectivity (2% connection probability)
comm = brainstate.nn.EventFixedProb(
    1000,  # pre_size
    800,  # post_size
    conn_num=0.02,  # 2% connectivity
    conn_weight=0.5  # Synaptic weight (unitless here; can also carry units, e.g. 0.5 * u.mS)
)
```

**Characteristics:**

- Memory: O(n_pre × n_post × prob)
- Computation: Only active connections
- Best for: Large-scale networks, biological models
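The memory estimate follows from counting expected synapses. This NumPy sketch draws a random boolean connection mask and stores only the surviving index pairs; the Bernoulli sampling and COO-style storage are assumptions for illustration, and the library's own sampling scheme and storage format may differ.

```python
import numpy as np

rng = np.random.default_rng(42)
n_pre, n_post, prob = 1000, 800, 0.02

mask = rng.random((n_pre, n_post)) < prob     # True where a synapse exists
pre_idx, post_idx = np.nonzero(mask)          # sparse (COO-style) storage

n_syn = pre_idx.size
expected = n_pre * n_post * prob              # 16,000 synapses on average
print(n_syn, expected, n_pre * n_post)        # stored vs expected vs dense entries
```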

#### Event-Based All-to-All

Every presynaptic neuron connects to every postsynaptic neuron; a single scalar weight can be shared by all connections.

```python
# All-to-all connectivity with one shared weight
comm = brainstate.nn.AllToAll(
    100,  # pre_size
    100,  # post_size
    0.3  # Unitless weight
)
```
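When all connections share one scalar weight, the all-to-all output does not need an explicit weight matrix: every postsynaptic neuron receives the weight times the total presynaptic activity. A quick NumPy check of that identity (illustrative only; this does not assume anything about how `AllToAll` is implemented internally):

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post, w = 100, 100, 0.3
spikes = (rng.random(n_pre) < 0.1).astype(float)

dense = spikes @ np.full((n_pre, n_post), w)      # explicit all-to-all matrix
scalar = np.full(n_post, w * spikes.sum())        # same result from one reduction

print(np.allclose(dense, scalar))                 # True
```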

#### Event-Based One-to-One

Each presynaptic neuron connects only to the postsynaptic neuron with the same index, so both populations must have the same size.

```python
size = 100
weight = 1.0

# One-to-one connections
comm = brainstate.nn.OneToOne(
    size,
    weight  # Unitless weight
)
```

**Use case:** Feedforward pathways, identity mappings
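Because each neuron talks only to its partner, one-to-one connectivity is equivalent to a diagonal weight matrix and reduces to an element-wise product. NumPy sketch (illustrative, not library code):

```python
import numpy as np

size, weight = 100, 1.0
rng = np.random.default_rng(0)
x = rng.random(size)

via_matrix = x @ (weight * np.eye(size))   # diagonal weight matrix: O(n^2) work
elementwise = weight * x                   # what one-to-one actually needs: O(n) work

print(np.allclose(via_matrix, elementwise))  # True
```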


## Synapse Layer

The Synapse layer defines **temporal dynamics** of synaptic transmission.

### Exponential Synapse

Single exponential decay (most common).

**Dynamics:**


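$$
\frac{dg}{dt} = -\frac{g}{\tau} + \sum_k \delta(t - t_k)
$$

Each incoming spike at time $t_k$ increments $g$ (by the connection weight, once the communication layer has been applied), and $g$ then decays back to zero with time constant $\tau$.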

**Implementation:**

```python
# Exponential synapse with 5ms time constant
syn = brainpy.state.Expon(
    in_size=100,  # Postsynaptic population size
    tau=5.0 * u.ms  # Decay time constant
)
```

**Characteristics:**

- Single time constant
- Fast computation
- Good for most applications

**When to use:** Default choice for most models
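To see what the single time constant means in practice, this NumPy sketch integrates the exponential synapse after a single input raises $g$ to 1 at $t = 0$, using the exact per-step decay factor $e^{-dt/\tau}$ (an assumption; other integrators give slightly different numbers), and checks that $g$ has fallen to $1/e$ of its peak after one $\tau$.

```python
import numpy as np

dt, tau = 0.1, 5.0                 # ms, matching the example above
n_steps = int(round(20.0 / dt))    # simulate 20 ms

g = np.zeros(n_steps)
g[0] = 1.0                         # a single input at t = 0 raises g to 1
for k in range(1, n_steps):
    g[k] = g[k - 1] * np.exp(-dt / tau)

t = np.arange(n_steps) * dt
k_tau = int(round(tau / dt))       # index of t = tau
print(g[k_tau], np.exp(-1.0))      # both ~0.3679
```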

### Alpha Synapse

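Alpha-function synapse: a single time constant $\tau$ sets both the rise and the decay, so after a spike the conductance rises smoothly to a peak at $t = \tau$ and then decays.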

**Dynamics:**


$$
\begin{aligned}
\tau \frac{dg}{dt} &= -g + h \\
\frac{dh}{dt} &= -\frac{h}{\tau} + \sum_k \delta(t - t_k)
\end{aligned}
$$

**Implementation:**

```python
# Alpha synapse
syn = brainpy.state.Alpha(
    in_size=100,
    tau=10.0 * u.ms  # Characteristic time
)
```

**Characteristics:**

- Realistic rise time
- Smoother response
- Slightly slower computation

**When to use:** When rise time matters, more biological realism
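The alpha kernel's rise time can be checked numerically. Assuming $\tau\,dg/dt = -g + h$ with $h(t) = e^{-t/\tau}$ after a single input raises $h$ to 1, the analytic solution is $g(t) = (t/\tau)\,e^{-t/\tau}$, which peaks at $t = \tau$ with height $1/e$. A forward-Euler sketch in NumPy (not the library's integrator):

```python
import numpy as np

dt, tau = 0.01, 10.0               # ms; a small dt keeps forward Euler accurate
n_steps = int(round(50.0 / dt))    # 50 ms

g, h = 0.0, 1.0                    # a single input has just raised h to 1
g_trace = np.empty(n_steps)
for k in range(n_steps):
    g_trace[k] = g
    # Forward Euler on  tau*dg/dt = -g + h,  tau*dh/dt = -h
    g, h = g + dt * (-g + h) / tau, h - dt * h / tau

t = np.arange(n_steps) * dt
t_peak = t[np.argmax(g_trace)]
print(t_peak, g_trace.max())       # close to tau = 10 ms and 1/e ~ 0.368
```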

### NMDA Synapse

Slow glutamatergic receptors whose conductance depends on membrane voltage through a magnesium (Mg²⁺) block.

**Dynamics:**


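$$
g_{NMDA} = \frac{g}{1 + \eta [Mg^{2+}] e^{-\gamma V}}
$$

Here $g$ is the gating variable from the receptor kinetics, $[Mg^{2+}]$ is the extracellular magnesium concentration, $V$ is the postsynaptic membrane potential, and $\eta$, $\gamma$ are empirical constants. In `brainpy.state` the receptor kinetics are modeled by `BioNMDA`, while the voltage-dependent factor is applied by the `MgBlock` output (see the Output layer).

**Implementation:**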

```python
# NMDA receptor kinetics (slow rise and decay); the Mg²⁺ block itself
# is applied separately by the MgBlock output
syn = brainpy.state.BioNMDA(
    in_size=100,
    T=1.0 * u.mM,  # Transmitter concentration released per spike
    T_dur=0.5 * u.ms,  # Duration of the transmitter pulse
)
```

**Characteristics:**

- Voltage-dependent (via the Mg²⁺ block)
- Slow kinetics
- Important for plasticity

**When to use:** Long-term potentiation, working memory models

### AMPA Synapse

Fast glutamatergic transmission.

```python
# AMPA receptor (fast excitation)
syn = brainpy.state.AMPA(
    in_size=100,
    beta=0.5 / u.ms,  # Fast decay (~2ms)
)
```

**When to use:** Fast excitatory transmission
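Receptor models such as `AMPA` and `GABAa` are commonly written as a two-state kinetic scheme, $dg/dt = \alpha [T](1 - g) - \beta g$, in which each presynaptic spike releases transmitter at concentration $T$ for a duration $T_{dur}$. Assuming this is the scheme behind the `beta` parameter above, the NumPy sketch below shows that once the transmitter is gone, $g$ decays with time constant $1/\beta$. The values of `alpha`, `T_conc`, and `T_dur` are hypothetical.

```python
import numpy as np

dt = 0.01                           # ms
alpha, beta = 0.98, 0.5             # 1/(ms*mM) (hypothetical) and 1/ms (from the example above)
T_conc, T_dur = 0.5, 0.5            # mM, ms: transmitter pulse after a spike at t = 0

n_steps = int(round(10.0 / dt))
t = np.arange(n_steps) * dt
g = np.zeros(n_steps)
for k in range(1, n_steps):
    T = T_conc if t[k - 1] < T_dur else 0.0           # rectangular transmitter pulse
    dg = alpha * T * (1.0 - g[k - 1]) - beta * g[k - 1]
    g[k] = g[k - 1] + dt * dg                          # forward Euler

# After the pulse, g should decay roughly as exp(-beta * t)
k0 = int(round(2.0 / dt))               # 2 ms, well after the pulse ended
k1 = k0 + int(round(1.0 / beta / dt))   # one decay time constant (2 ms) later
print(g[k1] / g[k0], np.exp(-1.0))      # ratio close to 1/e
```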

### GABA Synapse

Inhibitory transmission.

**GABAa (fast):**

```python
# GABAa receptor (fast inhibition)
syn = brainpy.state.GABAa(
    in_size=100,
    beta=0.16 / u.ms,  # ~6ms decay
)
```

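**GABAb-like (slow):**

```python
# GABAb-like slow inhibition, approximated here with the GABAa kinetic model
# and a much slower decay rate (1 / 0.0066 ms⁻¹ ≈ 150 ms)
syn = brainpy.state.GABAa(
    in_size=100,
    beta=0.0066 / u.ms,  # ~150 ms decay
)
```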

**When to use:**
- GABAa: Fast inhibition, cortical networks
- GABAb: Slow inhibition, rhythm generation

### Custom Synapses

Create custom synaptic dynamics by subclassing `Synapse`.

```python
import brainstate
import brainunit as u
import jax.numpy as jnp

import brainpy


class DoubleExpSynapse(brainpy.state.Synapse):
    """Custom synapse: sum of a fast and a slow exponentially decaying component."""

    def __init__(self, size, tau_fast=2 * u.ms, tau_slow=10 * u.ms, **kwargs):
        super().__init__(size, **kwargs)
        self.tau_fast = tau_fast
        self.tau_slow = tau_slow

        # State variables: one trace per component, shaped like the population
        self.g_fast = brainstate.ShortTermState(jnp.zeros(self.varshape))
        self.g_slow = brainstate.ShortTermState(jnp.zeros(self.varshape))

    def reset_state(self, batch_size=None, **kwargs):
        shape = self.varshape if batch_size is None else (batch_size, *self.varshape)
        self.g_fast.value = jnp.zeros(shape)
        self.g_slow.value = jnp.zeros(shape)

    def update(self, x):
        dt = brainstate.environ.get_dt()

        # Fast component (forward Euler step); receives 70% of the input
        dg_fast = -self.g_fast.value / self.tau_fast.to_decimal(u.ms)
        self.g_fast.value += dg_fast * dt.to_decimal(u.ms) + x * 0.7

        # Slow component (forward Euler step); receives 30% of the input
        dg_slow = -self.g_slow.value / self.tau_slow.to_decimal(u.ms)
        self.g_slow.value += dg_slow * dt.to_decimal(u.ms) + x * 0.3

        return self.g_fast.value + self.g_slow.value
```
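The update rule in `DoubleExpSynapse` can be sanity-checked without BrainPy. This NumPy sketch repeats the same forward-Euler steps for a single unit input and compares the result with the analytic response $0.7\,e^{-t/\tau_{fast}} + 0.3\,e^{-t/\tau_{slow}}$; with dt = 0.1 ms the two agree closely, and the gap shrinks as dt decreases.

```python
import numpy as np

dt, tau_fast, tau_slow = 0.1, 2.0, 10.0   # ms
n_steps = 300                             # 30 ms

g_fast = g_slow = 0.0
total = np.empty(n_steps)
for k in range(n_steps):
    x = 1.0 if k == 0 else 0.0            # one unit input at the first step
    # Same update as DoubleExpSynapse.update: Euler decay, then add the split input
    g_fast += -g_fast / tau_fast * dt + x * 0.7
    g_slow += -g_slow / tau_slow * dt + x * 0.3
    total[k] = g_fast + g_slow

t = np.arange(n_steps) * dt
analytic = 0.7 * np.exp(-t / tau_fast) + 0.3 * np.exp(-t / tau_slow)
print(np.max(np.abs(total - analytic)))   # small discretization error
```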

## Output Layer

The Output layer defines **how synaptic conductance affects neurons**.

### CUBA (Current-Based)

The synaptic variable is passed to the neuron directly as an input current, independent of the postsynaptic membrane potential.

**Model:**


$$
I_{syn} = g_{syn}
$$

**Implementation:**

```python
# Current-based output: no reversal potential, no voltage dependence
out_cuba = brainpy.state.CUBA()
```

**Characteristics:**

- Simple and fast
- No voltage dependence
- Good for rate-based models

**When to use:**
- Abstract models
- When voltage dependence is not important
- When faster computation is needed

### COBA (Conductance-Based)

Synaptic conductance with reversal potential.

**Model:**


$$
I_{syn} = g_{syn} (E_{syn} - V_{post})
$$

**Implementation:**

```python
# Excitatory conductance-based
out_exc = brainpy.state.COBA(E=0.0 * u.mV)

# Inhibitory conductance-based
out_inh = brainpy.state.COBA(E=-80.0 * u.mV)
```

**Characteristics:**

- Voltage-dependent
- Biologically realistic
- Self-limiting (the current shrinks to zero as $V_{post}$ approaches $E_{syn}$)

**When to use:**
- Biologically detailed models
- When voltage dependence matters
- Shunting inhibition needed
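The practical difference between the two output models appears when the membrane potential changes. Holding the conductance fixed, this NumPy sketch evaluates both formulas (with CUBA taken literally as $I = g$; the library may apply an additional scale factor) and shows that the COBA current vanishes at the reversal potential and reverses sign beyond it, while the CUBA current stays constant.

```python
import numpy as np

g = 0.5                                          # fixed synaptic conductance (arbitrary units)
V = np.array([-80.0, -65.0, -50.0, 0.0, 20.0])   # mV

I_cuba = np.full_like(V, g)                      # CUBA: I = g, independent of V
I_coba_exc = g * (0.0 - V)                       # COBA, E = 0 mV (excitatory)
I_coba_inh = g * (-80.0 - V)                     # COBA, E = -80 mV (inhibitory)

for v, a, b, c in zip(V, I_cuba, I_coba_exc, I_coba_inh):
    print(f"V={v:6.1f} mV  CUBA={a:5.2f}  COBA_exc={b:6.2f}  COBA_inh={c:6.2f}")
```

At $V = -80$ mV the inhibitory COBA current is zero even though the conductance is open, which is the shunting effect listed above.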

### MgBlock (NMDA)

Voltage-dependent magnesium block for NMDA.

```python
# NMDA with Mg²⁺ block
out_nmda = brainpy.state.MgBlock(
    E=0.0 * u.mV,
    cc_Mg=1.2 * u.mM,
    alpha=0.062 / u.mV,
    beta=3.57
)
```

**When to use:** NMDA receptors, voltage-dependent plasticity
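To get a feel for these parameters, the sketch below evaluates the Mg²⁺ unblock factor $B(V) = 1 / (1 + [Mg^{2+}]/\beta \cdot e^{-\alpha V})$, the common Jahr and Stevens form, which matches the NMDA equation above with $\eta = 1/\beta$ and $\gamma = \alpha$. That `MgBlock` computes exactly this expression is an assumption based on its parameter names; the sketch uses plain NumPy with the values from the call above.

```python
import numpy as np

cc_Mg, alpha, beta = 1.2, 0.062, 3.57            # mM, 1/mV, mM (values from the MgBlock call)
V = np.array([-80.0, -65.0, -40.0, -20.0, 0.0])  # mV

B = 1.0 / (1.0 + cc_Mg / beta * np.exp(-alpha * V))  # fraction of channels unblocked
for v, b in zip(V, B):
    print(f"V = {v:6.1f} mV  ->  B(V) = {b:.3f}")
```

Near rest (about -65 mV) only around 5% of the NMDA conductance is available, while near 0 mV roughly three quarters is, which is why NMDA currents grow with depolarization.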

## Complete Projection Examples

### Example 1: Simple Feedforward

```python
# Create populations
pre = brainpy.state.LIF(100, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
post = brainpy.state.LIF(50, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)

# Create projection: 100 → 50 neurons
proj = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(
        100,  # pre_size
        50,  # post_size
        conn_num=0.1,  # 10% connectivity
        conn_weight=0.5 * u.mS  # Weight
    ),
    syn=brainpy.state.Expon(
        in_size=50,  # Postsynaptic size
        tau=5.0 * u.ms
    ),
    out=brainpy.state.CUBA(),
    post=post  # Postsynaptic population
)

# Initialize
brainstate.nn.init_all_states([pre, post, proj])


# Simulate
def step(t, i, inp):
    with brainstate.environ.context(t=t, i=i):
        # Update neurons
        pre(inp)

        # Get presynaptic spikes
        pre_spikes = pre.get_spike()

        # Update projection
        proj(pre_spikes)

        post(0.0 * u.nA)  # Projection provides input

        return pre.get_spike(), post.get_spike()


indices = np.arange(1000)
times = indices * brainstate.environ.get_dt()
inputs = brainstate.random.uniform(30., 50., indices.shape) * u.nA
_ = brainstate.transform.for_loop(step, times, indices, inputs)
```


### Example 2: Excitatory-Inhibitory Network

```python
class EINetwork(brainstate.nn.Module):
    def __init__(self, n_exc=800, n_inh=200):
        super().__init__()

        # Populations
        self.E = brainpy.state.LIF(n_exc, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=15 * u.ms)
        self.I = brainpy.state.LIF(n_inh, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)

        # E → E projection (AMPA-like fast excitation, modeled with Expon)
        self.E2E = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_exc, conn_num=0.02, conn_weight=0.6 * u.mS),
            syn=brainpy.state.Expon(n_exc, tau=2. * u.ms),
            out=brainpy.state.COBA(E=0.0 * u.mV),
            post=self.E
        )

        # E → I projection (AMPA-like fast excitation, modeled with Expon)
        self.E2I = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_inh, conn_num=0.02, conn_weight=0.6 * u.mS),
            syn=brainpy.state.Expon(n_inh, tau=2. * u.ms),
            out=brainpy.state.COBA(E=0.0 * u.mV),
            post=self.I
        )

        # I → E projection (GABAa-like inhibition, modeled with Expon)
        self.I2E = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_exc, conn_num=0.02, conn_weight=6.7 * u.mS),
            syn=brainpy.state.Expon(n_exc, tau=6. * u.ms),
            out=brainpy.state.COBA(E=-80.0 * u.mV),
            post=self.E
        )

        # I → I projection (GABAa-like inhibition, modeled with Expon)
        self.I2I = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_inh, conn_num=0.02, conn_weight=6.7 * u.mS),
            syn=brainpy.state.Expon(n_inh, tau=6. * u.ms),
            out=brainpy.state.COBA(E=-80.0 * u.mV),
            post=self.I
        )

    def update(self, i, inp_e, inp_i):
        t = brainstate.environ.get_dt() * i
        with brainstate.environ.context(t=t, i=i):
            # Get spikes BEFORE updating neurons
            spk_e = self.E.get_spike()
            spk_i = self.I.get_spike()

            # Update all projections
            self.E2E(spk_e)
            self.E2I(spk_e)
            self.I2E(spk_i)
            self.I2I(spk_i)

            # Update neurons (projections provide synaptic input)
            self.E(inp_e)
            self.I(inp_i)

            return spk_e, spk_i


net = EINetwork()
brainstate.nn.init_all_states(net)
# Reuses `indices` and `inputs` from Example 1 (same drive for both populations)
_ = brainstate.transform.for_loop(net.update, indices, inputs, inputs)
```

### Example 3: Multi-Timescale Synapses

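Combine a fast AMPA-like and a slow NMDA-like excitatory component on the same pathway. Here both are approximated with exponential synapses of different time constants, and the NMDA-like component additionally uses the `MgBlock` output.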

```python
class DualExcitatory(brainstate.nn.Module):
    """E → E with a fast (AMPA-like) and a slow (NMDA-like) component."""

    def __init__(self, n_pre=100, n_post=100):
        super().__init__()

        self.post = brainpy.state.LIF(n_post, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)

        # Fast AMPA-like component: short exponential decay
        # (each component draws its own random connectivity here)
        self.ampa_proj = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),
            syn=brainpy.state.Expon(n_post, tau=2.0 * u.ms),
            out=brainpy.state.COBA(E=0.0 * u.mV),
            post=self.post
        )

        # Slow NMDA-like component: long exponential decay, gated by the Mg²⁺ block
        self.nmda_proj = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),
            syn=brainpy.state.Expon(n_post, tau=100.0 * u.ms),
            out=brainpy.state.MgBlock(E=0.0 * u.mV, cc_Mg=1.2 * u.mM),
            post=self.post
        )

    def update(self, t, i, pre_spikes):
        with brainstate.environ.context(t=t, i=i):
            # Both projections share same presynaptic spikes
            self.ampa_proj(pre_spikes)
            self.nmda_proj(pre_spikes)

            # Post receives combined input
            self.post(0.0 * u.nA)

            return self.post.get_spike()
```

### Example 4: Delay Projections

Add synaptic delays to projections.

```python
# Synaptic delay for the recurrent spikes
delay_time = 5.0 * u.ms


# A recurrent projection (post → post) driven by delayed spikes
class DelayedProjection(brainstate.nn.Module):
    def __init__(self, pre_size, post_size):
        super().__init__()
        # The delayed spikes come from `post` itself, so the sizes must match
        assert pre_size == post_size, "pre_size must equal post_size for a recurrent projection"

        self.post = brainpy.state.LIF(post_size, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
        # Delayed access to the output (spikes) of `post`
        self.delay = self.post.output_delay(delay_time)

        # Standard projection
        self.proj = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(pre_size, post_size, conn_num=0.1, conn_weight=0.5 * u.mS),
            syn=brainpy.state.Expon(post_size, tau=5.0 * u.ms),
            out=brainpy.state.CUBA(),
            post=self.post
        )

    def update(self, inp=0. * u.nA):
        # Retrieve delayed spikes
        delayed_spikes = self.delay()
        # Update projection with delayed spikes
        self.proj(delayed_spikes)
        self.post(inp)
        # Store current spikes in delay buffer
        self.delay(self.post.get_spike())

    def step_run(self, i, inp):
        t = brainstate.environ.get_dt() * i
        with brainstate.environ.context(t=t, i=i):
            # Update post neurons
            self.update(inp)
            return self.post.get_spike()


net = DelayedProjection(100, 100)
brainstate.nn.init_all_states(net)
# Reuses `indices` and `inputs` from Example 1
_ = brainstate.transform.for_loop(net.step_run, indices, inputs)
```
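Conceptually, a fixed synaptic delay is a ring buffer holding delay / dt past steps. The NumPy sketch below is a hypothetical stand-in (the class `RingDelay` is ours, not part of BrainPy, and says nothing about how `output_delay` is implemented) showing that with dt = 0.1 ms a 5 ms delay needs 50 slots and returns each value exactly 50 steps after it was written.

```python
import numpy as np


class RingDelay:
    """Hypothetical fixed-delay buffer (illustration only, not the BrainPy API)."""

    def __init__(self, delay, dt, size):
        self.n = int(round(delay / dt))          # number of steps to delay
        self.buf = np.zeros((self.n, size))
        self.idx = 0

    def retrieve(self):
        return self.buf[self.idx].copy()         # value written n steps ago

    def push(self, value):
        self.buf[self.idx] = value               # overwrite the oldest slot
        self.idx = (self.idx + 1) % self.n


delay = RingDelay(delay=5.0, dt=0.1, size=3)
outputs = []
for step in range(60):
    outputs.append(delay.retrieve())             # read before writing, as in update()
    delay.push(np.full(3, float(step == 0)))     # a "spike" only at step 0

first = next(i for i, o in enumerate(outputs) if o.any())
print(delay.n, first)                            # 50 50
```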