diff --git a/.github/workflows/greetings.yml b/.github/workflows/greetings.yml
index 69598583..6112fe20 100644
--- a/.github/workflows/greetings.yml
+++ b/.github/workflows/greetings.yml
@@ -18,7 +18,7 @@ jobs:
uses: actions/first-interaction@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
- issue-message: >
+ issue_message: >
👋 Welcome to BrainPy! Thank you for opening your first issue.
@@ -36,7 +36,7 @@ jobs:
Happy coding! 🧠 ✨
- pr-message: >
+ pr_message: >
🎉 Congratulations on opening your first pull request in BrainPy! Thank you for
your contribution!
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
deleted file mode 100644
index 17b3adb1..00000000
--- a/.github/workflows/stale.yml
+++ /dev/null
@@ -1,58 +0,0 @@
-name: Close Stale Issues and PRs
-
-on:
- schedule:
- - cron: '0 0 * * *' # Run daily at midnight UTC
- workflow_dispatch: # Allow manual triggers
-
-permissions:
- issues: write
- pull-requests: write
-
-jobs:
- stale:
- runs-on: ubuntu-latest
- steps:
- - name: Mark/Close Stale Issues and PRs
- uses: actions/stale@v10
- with:
- repo-token: ${{ secrets.GITHUB_TOKEN }}
-
- # Issue settings
- days-before-issue-stale: 90
- days-before-issue-close: 14
- stale-issue-message: >
- This issue has been automatically marked as stale because it has not had
- recent activity in the last 90 days. It will be closed in 14 days if no
- further activity occurs. If this issue is still relevant, please comment
- to keep it open. Thank you for your contributions!
- close-issue-message: >
- This issue was automatically closed because it has been stale for 14 days
- with no activity. If you believe this is still relevant, please feel free
- to reopen it or create a new issue.
- stale-issue-label: 'stale'
-
- # PR settings
- days-before-pr-stale: 60
- days-before-pr-close: 14
- stale-pr-message: >
- This pull request has been automatically marked as stale because it has not had
- recent activity in the last 60 days. It will be closed in 14 days if no
- further activity occurs. If this PR is still relevant, please comment or push
- new commits to keep it open. Thank you for your contributions!
- close-pr-message: >
- This pull request was automatically closed because it has been stale for 14 days
- with no activity. If you'd like to continue working on this, please feel free
- to reopen it or create a new PR.
- stale-pr-label: 'stale'
-
- # Exemptions
- exempt-issue-labels: 'pinned,security,good first issue,help wanted,enhancement'
- exempt-pr-labels: 'pinned,work-in-progress,blocked'
- exempt-all-milestones: true
-
- # Operations per run (to avoid rate limits)
- operations-per-run: 100
-
- # Remove stale label when updated
- remove-stale-when-updated: true
diff --git a/.gitignore b/.gitignore
index 3ed5f76d..b3192417 100644
--- a/.gitignore
+++ b/.gitignore
@@ -236,3 +236,5 @@ cython_debug/
/docs_version3/_build/
/docs_version3/_static/logos/
/docs_version3/changelog.md
+/examples_version2/dynamics_training/data/
+/docs/
diff --git a/README.md b/README.md
index 90128845..14d0ca11 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
@@ -16,8 +16,9 @@
BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation. It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
-- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/
- **Source**: https://github.com/brainpy/BrainPy
+- **Documentation**: https://brainpy.readthedocs.io/
+- **Documentation (BrainPy 2.x)**: https://brainpy-v2.readthedocs.io/
- **Bug reports**: https://github.com/brainpy/BrainPy/issues
- **Ecosystem**: https://brainmodeling.readthedocs.io/
diff --git a/docs/conf.py b/docs/conf.py
index 70a03338..37382e33 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -15,10 +15,7 @@
import shutil
import sys
-# List of files/folders to keep
keep_files = {'highlight_test_lexer.py', 'conf.py', 'make.bat', 'Makefile'}
-
-# Iterate over the current directory
for item in os.listdir('.'):
if item not in keep_files:
path = os.path.join('.', item)
@@ -34,21 +31,19 @@
if build_version == 'v2':
shutil.copytree(
os.path.join(os.path.dirname(__file__), '../docs_version2'),
- os.path.join(os.path.dirname(__file__), ),
+ os.path.join(os.path.dirname(__file__)),
dirs_exist_ok=True
)
else:
shutil.copytree(
os.path.join(os.path.dirname(__file__), '../docs_version3'),
- os.path.join(os.path.dirname(__file__), ),
+ os.path.join(os.path.dirname(__file__)),
dirs_exist_ok=True
)
sys.path.insert(0, os.path.abspath('./'))
sys.path.insert(0, os.path.abspath('../'))
-import brainpy
-
shutil.copytree('../images/', './_static/logos/', dirs_exist_ok=True)
shutil.copyfile('../changelog.md', './changelog.md')
@@ -62,7 +57,8 @@
fix_ipython2_lexer_in_notebooks(os.path.dirname(os.path.abspath(__file__)))
-# The full version, including alpha/beta/rc tags
+import brainpy
+
release = brainpy.__version__
# -- General configuration ---------------------------------------------------
diff --git a/docs_version2/index.rst b/docs_version2/index.rst
index 2bd495d0..946947ea 100644
--- a/docs_version2/index.rst
+++ b/docs_version2/index.rst
@@ -127,7 +127,7 @@ Learn more
.. card:: :material-regular:`settings;2em` Examples
:class-card: sd-text-black sd-bg-light
- :link: https://brainpy-examples.readthedocs.io/
+ :link: https://brainpy-v2.readthedocs.io/projects/examples/
.. grid-item::
:columns: 6 6 6 4
diff --git a/docs_version3/api/index.rst b/docs_version3/api/index.rst
new file mode 100644
index 00000000..e23b9284
--- /dev/null
+++ b/docs_version3/api/index.rst
@@ -0,0 +1,204 @@
+API Reference
+=============
+
+Complete API reference for BrainPy 3.0.
+
+.. note::
+   BrainPy 3.0 is built on top of `brainstate <https://brainstate.readthedocs.io/>`_,
+   `brainunit <https://brainunit.readthedocs.io/>`_, and `braintools <https://braintools.readthedocs.io/>`_.
+
+Organization
+------------
+
+The API is organized into the following categories:
+
+.. grid:: 1 2 2 2
+
+ .. grid-item-card:: :material-regular:`psychology;2em` Neurons
+ :link: neurons.html
+
+ Spiking neuron models (LIF, ALIF, Izhikevich, etc.)
+
+ .. grid-item-card:: :material-regular:`timeline;2em` Synapses
+ :link: synapses.html
+
+ Synaptic dynamics (Expon, Alpha, AMPA, GABA, NMDA)
+
+ .. grid-item-card:: :material-regular:`account_tree;2em` Projections
+ :link: projections.html
+
+ Connect neural populations (AlignPostProj, AlignPreProj)
+
+ .. grid-item-card:: :material-regular:`hub;2em` Networks
+ :link: networks.html
+
+ Network building blocks and utilities
+
+ .. grid-item-card:: :material-regular:`school;2em` Training
+ :link: training.html
+
+ Gradient-based learning utilities
+
+ .. grid-item-card:: :material-regular:`input;2em` Input/Output
+ :link: input-output.html
+
+ Input generation and output processing
+
+Quick Reference
+---------------
+
+**Most commonly used classes:**
+
+Neurons
+~~~~~~~
+
+.. code-block:: python
+
+ import brainpy as bp
+
+ # Leaky Integrate-and-Fire
+ bp.LIF(size, V_rest, V_th, V_reset, tau, R, ...)
+
+ # Adaptive LIF
+ bp.ALIF(size, V_rest, V_th, V_reset, tau, tau_w, a, b, ...)
+
+ # Izhikevich
+ bp.Izhikevich(size, a, b, c, d, ...)
+
+Synapses
+~~~~~~~~
+
+.. code-block:: python
+
+ # Exponential
+ bp.Expon.desc(size, tau)
+
+ # Alpha
+ bp.Alpha.desc(size, tau)
+
+ # AMPA receptor
+ bp.AMPA.desc(size, tau)
+
+ # GABA_a receptor
+ bp.GABAa.desc(size, tau)
+
+Projections
+~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Standard projection
+ bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob, weight),
+ syn=bp.Expon.desc(n_post, tau),
+ out=bp.COBA.desc(E),
+ post=post_neurons
+ )
+
+Networks
+~~~~~~~~
+
+.. code-block:: python
+
+ # Module base class
+ class MyNetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ # ... define components
+
+ def update(self, x):
+ # ... network dynamics
+ return output
+
+Training
+~~~~~~~~
+
+.. code-block:: python
+
+ import braintools
+
+ # Optimizer
+ optimizer = braintools.optim.Adam(learning_rate=1e-3)
+
+ # Gradients
+ grads = brainstate.transform.grad(loss_fn, params)(...)
+
+ # Update
+ optimizer.update(grads)
+
+Documentation Sections
+----------------------
+
+.. toctree::
+ :maxdepth: 2
+
+ neurons
+ synapses
+ projections
+ networks
+ training
+ input-output
+
+Import Structure
+----------------
+
+BrainPy uses a clear import hierarchy:
+
+.. code-block:: python
+
+ import brainpy as bp # Core BrainPy
+ import brainstate # State management and modules
+ import brainunit as u # Physical units
+ import braintools # Training utilities
+
+ # Neurons and synapses
+ neuron = bp.LIF(100, ...)
+ synapse = bp.Expon.desc(100, tau=5*u.ms)
+
+ # State management
+ state = brainstate.ShortTermState(...)
+ brainstate.nn.init_all_states(net)
+
+ # Units
+ current = 2.0 * u.nA
+ voltage = -65 * u.mV
+ time = 10 * u.ms
+
+ # Training
+ optimizer = braintools.optim.Adam(learning_rate=1e-3)
+ loss = braintools.metric.softmax_cross_entropy(...)
+
+Type Conventions
+----------------
+
+**States:**
+
+- ``ShortTermState`` - Temporary dynamics (V, g, spikes)
+- ``ParamState`` - Learnable parameters (weights, biases)
+- ``LongTermState`` - Persistent statistics
+
+**Units:**
+
+All physical quantities use ``brainunit``:
+
+- Voltage: ``u.mV``
+- Current: ``u.nA``, ``u.pA``
+- Time: ``u.ms``, ``u.second``
+- Conductance: ``u.mS``, ``u.nS``
+- Concentration: ``u.mM``
+
+**Shapes:**
+
+- Single trial: ``(n_neurons,)``
+- Batched: ``(batch_size, n_neurons)``
+- Connectivity: ``(n_pre, n_post)``
+
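+A quick sketch tying these conventions together (parameter values are illustrative):
+
+.. code-block:: python
+
+   import brainpy as bp
+   import brainstate
+   import brainunit as u
+
+   neuron = bp.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+   brainstate.nn.init_all_states(neuron, batch_size=32)
+
+   # Batched input: (batch_size, n_neurons), with physical units
+   neuron(brainstate.random.rand(32, 100) * 2.0 * u.nA)
+   assert neuron.get_spike().shape == (32, 100)
+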
+See Also
+--------
+
+**External Documentation:**
+
+- `BrainState Documentation <https://brainstate.readthedocs.io/>`_ - State management
+- `BrainUnit Documentation <https://brainunit.readthedocs.io/>`_ - Physical units
+- `BrainTools Documentation <https://braintools.readthedocs.io/>`_ - Training utilities
+- `JAX Documentation <https://jax.readthedocs.io/>`_ - Underlying computation
diff --git a/docs_version3/api/input-output.rst b/docs_version3/api/input-output.rst
new file mode 100644
index 00000000..575f58ac
--- /dev/null
+++ b/docs_version3/api/input-output.rst
@@ -0,0 +1,288 @@
+Input and Output
+================
+
+Utilities for generating inputs and processing outputs.
+
+Input Encoding
+--------------
+
+Poisson Spike Trains
+~~~~~~~~~~~~~~~~~~~~
+
+Generate Poisson-distributed spikes:
+
+.. code-block:: python
+
+ import brainunit as u
+ import brainstate
+ import jax.numpy as jnp
+
+ def poisson_input(rates, dt):
+ """Generate Poisson spike train.
+
+ Args:
+ rates: Firing rates in Hz (array)
+ dt: Time step (Quantity[ms])
+
+ Returns:
+ Binary spike array
+ """
+ probs = rates * dt.to_decimal(u.second)
+ return (brainstate.random.rand(*rates.shape) < probs).astype(float)
+
+ # Usage
+ rates = jnp.ones(100) * 50 # 50 Hz
+ spikes = poisson_input(rates, dt=0.1*u.ms)
+
+Rate Coding
+~~~~~~~~~~~
+
+Encode values as firing rates:
+
+.. code-block:: python
+
+ def rate_encode(values, max_rate, dt):
+ """Encode values as spike rates.
+
+ Args:
+ values: Values to encode [0, 1]
+ max_rate: Maximum firing rate (Quantity[Hz])
+ dt: Time step (Quantity[ms])
+ """
+ rates = values * max_rate.to_decimal(u.Hz)
+ probs = rates * dt.to_decimal(u.second)
+ return (brainstate.random.rand(len(values)) < probs).astype(float)
+
+ # Usage
+ pixel_values = jnp.array([0.2, 0.8, 0.5, ...]) # Normalized pixels
+ spikes = rate_encode(pixel_values, max_rate=100*u.Hz, dt=0.1*u.ms)
+
+Population Coding
+~~~~~~~~~~~~~~~~~
+
+Encode with population of tuned neurons:
+
+.. code-block:: python
+
+ def population_encode(value, n_neurons, pref_values, sigma, max_rate, dt):
+ """Population coding with Gaussian tuning curves.
+
+ Args:
+ value: Value to encode (scalar)
+ n_neurons: Number of neurons in population
+ pref_values: Preferred values of neurons
+ sigma: Tuning width
+ max_rate: Maximum firing rate
+ dt: Time step
+ """
+ # Gaussian tuning curves
+ responses = jnp.exp(-0.5 * ((value - pref_values) / sigma)**2)
+ rates = responses * max_rate.to_decimal(u.Hz)
+ probs = rates * dt.to_decimal(u.second)
+ return (brainstate.random.rand(n_neurons) < probs).astype(float)
+
+ # Usage
+ pref_values = jnp.linspace(0, 1, 20)
+ spikes = population_encode(
+ value=0.5,
+ n_neurons=20,
+ pref_values=pref_values,
+ sigma=0.1,
+ max_rate=100*u.Hz,
+ dt=0.1*u.ms
+ )
+
+Temporal Contrast
+~~~~~~~~~~~~~~~~~
+
+Encode based on image gradients (event cameras):
+
+.. code-block:: python
+
+ def temporal_contrast_encode(image, prev_image, threshold=0.1, polarity=True):
+ """Encode based on temporal contrast.
+
+ Args:
+ image: Current image
+ prev_image: Previous image
+ threshold: Change threshold
+ polarity: If True, separate ON/OFF channels
+
+ Returns:
+ Spike events
+ """
+ diff = image - prev_image
+
+ if polarity:
+ on_spikes = (diff > threshold).astype(float)
+ off_spikes = (diff < -threshold).astype(float)
+ return on_spikes, off_spikes
+ else:
+ spikes = (jnp.abs(diff) > threshold).astype(float)
+ return spikes
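+
+ # Usage (a sketch with illustrative synthetic frames)
+ frame0 = jnp.zeros((32, 32))
+ frame1 = frame0.at[10:20, 10:20].set(0.5)
+ on_spikes, off_spikes = temporal_contrast_encode(frame1, frame0, threshold=0.1)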
+
+Output Decoding
+---------------
+
+Population Vector Decoding
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def population_decode(spike_counts, pref_values):
+ """Decode value from population spikes.
+
+ Args:
+ spike_counts: Spike counts from population
+ pref_values: Preferred values of neurons
+
+ Returns:
+ Decoded value
+ """
+ total_activity = jnp.sum(spike_counts)
+ if total_activity > 0:
+ decoded = jnp.sum(spike_counts * pref_values) / total_activity
+ return decoded
+ return 0.0
+
+ # Usage
+ spike_counts = jnp.array([5, 12, 20, 15, 3, ...]) # From 20 neurons
+ pref_values = jnp.linspace(0, 1, 20)
+ decoded_value = population_decode(spike_counts, pref_values)
+
+Spike Count
+~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Count total spikes over time window
+ spike_count = jnp.sum(spike_history, axis=0)
+
+ # Firing rate (Hz)
+ duration = n_steps * dt.to_decimal(u.second)
+ firing_rate = spike_count / duration
+
+Readout Layer
+~~~~~~~~~~~~~
+
+Use ``bp.Readout`` for trainable spike-to-output conversion:
+
+.. code-block:: python
+
+ readout = bp.Readout(n_neurons, n_outputs)
+
+ # Accumulate over time
+ def run_and_readout(net, inputs, n_steps):
+ brainstate.nn.init_all_states(net)
+
+ outputs = []
+ for t in range(n_steps):
+ net(inputs)
+ spikes = net.get_spike()
+ output = readout(spikes)
+ outputs.append(output)
+
+ # Sum over time for classification
+ logits = jnp.sum(jnp.array(outputs), axis=0)
+ return logits
+
+State Recording
+---------------
+
+Record Activity
+~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Record states during simulation
+ V_history = []
+ spike_history = []
+
+ for t in range(n_steps):
+ neuron(input_current)
+
+ V_history.append(neuron.V.value.copy())
+ spike_history.append(neuron.spike.value.copy())
+
+ V_history = jnp.array(V_history)
+ spike_history = jnp.array(spike_history)
+
+Spike Raster Plot
+~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import matplotlib.pyplot as plt
+
+ # Get spike times and neuron indices
+ times, neurons = jnp.where(spike_history > 0)
+
+ # Plot
+ plt.figure(figsize=(12, 6))
+ plt.scatter(times * 0.1, neurons, s=1, c='black')  # step index × dt (0.1 ms) → time in ms
+ plt.xlabel('Time (ms)')
+ plt.ylabel('Neuron index')
+ plt.title('Spike Raster')
+ plt.show()
+
+Firing Rate Over Time
+~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Population firing rate per time step, converted to Hz
+ firing_rate = jnp.mean(spike_history, axis=1) * (1000 / dt.to_decimal(u.ms))
+
+ # Time axis in ms (one entry per simulation step)
+ t_axis = jnp.arange(spike_history.shape[0]) * dt.to_decimal(u.ms)
+
+ plt.figure(figsize=(12, 4))
+ plt.plot(t_axis, firing_rate)
+ plt.xlabel('Time (ms)')
+ plt.ylabel('Population Rate (Hz)')
+ plt.show()
+
+Complete Example
+----------------
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+ import jax.numpy as jnp
+
+ # Setup
+ n_input = 784 # MNIST pixels
+ n_hidden = 100
+ n_output = 10
+ dt = 0.1 * u.ms
+ brainstate.environ.set(dt=dt)
+
+ # Network
+ class EncoderDecoderSNN(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hidden = bp.LIF(n_hidden, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ self.readout = bp.Readout(n_hidden, n_output)
+
+ def update(self, x):
+ self.hidden(x)
+ return self.readout(self.hidden.get_spike())
+
+ net = EncoderDecoderSNN()
+ brainstate.nn.init_all_states(net)
+
+ # Input encoding (rate coding)
+ image = brainstate.random.rand(784)  # Normalized image
+ encoded = rate_encode(image, max_rate=100*u.Hz, dt=dt) * 2.0 * u.nA
+
+ # Simulate
+ outputs = []
+ for t in range(100):
+ output = net(encoded)
+ outputs.append(output)
+
+ # Output decoding
+ logits = jnp.sum(jnp.array(outputs), axis=0)
+ prediction = jnp.argmax(logits)
+
+See Also
+--------
+
+- :doc:`../tutorials/basic/04-input-output` - Input/output tutorial
+- :doc:`../tutorials/advanced/05-snn-training` - Training with encoded inputs
+- :doc:`neurons` - Neuron models
diff --git a/docs_version3/api/networks.rst b/docs_version3/api/networks.rst
new file mode 100644
index 00000000..97d45b57
--- /dev/null
+++ b/docs_version3/api/networks.rst
@@ -0,0 +1,122 @@
+Network Components
+==================
+
+Building blocks for neural networks.
+
+.. currentmodule:: brainstate.nn
+
+Module System
+-------------
+
+Module
+~~~~~~
+
+.. class:: Module
+
+ Base class for all network components.
+
+ **Key Methods:**
+
+ .. method:: update(*args, **kwargs)
+
+ Forward pass / simulation step.
+
+ .. method:: reset_state(batch_size=None)
+
+ Reset component state.
+
+ .. method:: states(state_type=None)
+
+ Get states of specific type.
+
+ :param state_type: ParamState, ShortTermState, or LongTermState
+ :returns: Dictionary of states
+
+ **Example:**
+
+ .. code-block:: python
+
+ class MyNetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.neurons = bp.LIF(100, ...)
+ self.weights = brainstate.ParamState(jnp.ones((100, 100)))
+
+ def update(self, x):
+ self.neurons(x)
+ return self.neurons.get_spike()
+
+State Initialization
+--------------------
+
+init_all_states
+~~~~~~~~~~~~~~~
+
+.. function:: init_all_states(module, batch_size=None)
+
+ Initialize all states in a module hierarchy.
+
+ **Parameters:**
+
+ - ``module`` - Network module
+ - ``batch_size`` (int or None) - Optional batch dimension
+
+ **Example:**
+
+ .. code-block:: python
+
+ net = MyNetwork()
+ brainstate.nn.init_all_states(net) # Single trial
+ brainstate.nn.init_all_states(net, batch_size=32) # Batched
+
+Readout Layers
+--------------
+
+Readout
+~~~~~~~
+
+.. class:: Readout(in_size, out_size)
+
+ Convert spikes to continuous outputs.
+
+ **Example:**
+
+ .. code-block:: python
+
+ readout = bp.Readout(n_hidden, n_output)
+
+ # Usage
+ spikes = hidden_neurons.get_spike()
+ logits = readout(spikes)
+
+Linear
+~~~~~~
+
+.. class:: Linear(in_size, out_size, w_init=None, b_init=None)
+
+ Fully connected linear layer.
+
+ **Parameters:**
+
+ - ``in_size`` (int) - Input dimension
+ - ``out_size`` (int) - Output dimension
+ - ``w_init`` - Weight initializer
+ - ``b_init`` - Bias initializer (None for no bias)
+
+ **Example:**
+
+ .. code-block:: python
+
+ fc = brainstate.nn.Linear(
+ 100, 50,
+ w_init=brainstate.init.KaimingNormal()
+ )
+
+ output = fc(input_data)
+
+See Also
+--------
+
+- :doc:`../core-concepts/architecture` - Architecture overview
+- :doc:`../core-concepts/state-management` - State system
+- :doc:`../tutorials/basic/03-network-connections` - Network tutorial
diff --git a/docs_version3/api/neurons.rst b/docs_version3/api/neurons.rst
new file mode 100644
index 00000000..deceae0f
--- /dev/null
+++ b/docs_version3/api/neurons.rst
@@ -0,0 +1,439 @@
+Neuron Models
+=============
+
+Spiking neuron models in BrainPy.
+
+.. currentmodule:: brainpy
+
+Base Class
+----------
+
+Neuron
+~~~~~~
+
+.. class:: Neuron(size, **kwargs)
+
+ Base class for all neuron models.
+
+ All neuron models inherit from this class and implement the ``update()`` method
+ for their specific dynamics.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons in the population
+ - ``**kwargs`` - Additional keyword arguments
+
+ **Key Methods:**
+
+ .. method:: update(x)
+
+ Update neuron dynamics for one time step.
+
+ :param x: Input current with units (e.g., ``2.0 * u.nA``)
+ :type x: Array with brainunit
+ :returns: Updated state (typically membrane potential)
+
+ .. method:: get_spike()
+
+ Get current spike state.
+
+ :returns: Binary spike indicator (1 = spike, 0 = no spike)
+ :rtype: Array of shape ``(size,)`` or ``(batch_size, size)``
+
+ .. method:: reset_state(batch_size=None)
+
+ Reset neuron state for new trial.
+
+ :param batch_size: Optional batch dimension
+ :type batch_size: int or None
+
+ **Common Attributes:**
+
+ - ``V`` (ShortTermState) - Membrane potential
+ - ``spike`` (ShortTermState) - Spike indicator
+ - ``size`` (int) - Number of neurons
+
+ **Example:**
+
+ .. code-block:: python
+
+ # Subclass to create custom neuron
+ class CustomNeuron(bp.Neuron):
+ def __init__(self, size, tau=10*u.ms):
+ super().__init__(size)
+ self.tau = tau
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+
+ def update(self, x):
+ # Custom dynamics
+ pass
+
+Integrate-and-Fire Models
+--------------------------
+
+IF
+~~
+
+.. class:: IF(size, V_rest=-65*u.mV, V_th=-50*u.mV, V_reset=-65*u.mV, **kwargs)
+
+ Basic Integrate-and-Fire neuron.
+
+ **Model:**
+
+ .. math::
+
+ \\frac{dV}{dt} = I_{ext}
+
+ Spikes when :math:`V \\geq V_{th}`, then resets to :math:`V_{reset}`.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons
+ - ``V_rest`` (Quantity[mV]) - Resting potential (default: -65 mV)
+ - ``V_th`` (Quantity[mV]) - Spike threshold (default: -50 mV)
+ - ``V_reset`` (Quantity[mV]) - Reset potential (default: -65 mV)
+
+ **States:**
+
+ - ``V`` (ShortTermState) - Membrane potential [mV]
+ - ``spike`` (ShortTermState) - Spike indicator [0 or 1]
+
+ **Example:**
+
+ .. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+
+ neuron = bp.IF(100, V_rest=-65*u.mV, V_th=-50*u.mV)
+ brainstate.nn.init_all_states(neuron)
+
+ # Simulate
+ for t in range(1000):
+ inp = brainstate.random.rand(100) * 2.0 * u.nA
+ neuron(inp)
+ spikes = neuron.get_spike()
+
+LIF
+~~~
+
+.. class:: LIF(size, V_rest=-65*u.mV, V_th=-50*u.mV, V_reset=-65*u.mV, tau=10*u.ms, R=1*u.ohm, **kwargs)
+
+ Leaky Integrate-and-Fire neuron.
+
+ **Model:**
+
+ .. math::
+
+ \\tau \\frac{dV}{dt} = -(V - V_{rest}) + R I_{ext}
+
+ Most commonly used spiking neuron model.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons
+ - ``V_rest`` (Quantity[mV]) - Resting potential (default: -65 mV)
+ - ``V_th`` (Quantity[mV]) - Spike threshold (default: -50 mV)
+ - ``V_reset`` (Quantity[mV]) - Reset potential (default: -65 mV)
+ - ``tau`` (Quantity[ms]) - Membrane time constant (default: 10 ms)
+ - ``R`` (Quantity[ohm]) - Input resistance (default: 1 Ω)
+
+ **States:**
+
+ - ``V`` (ShortTermState) - Membrane potential [mV]
+ - ``spike`` (ShortTermState) - Spike indicator [0 or 1]
+
+ **Example:**
+
+ .. code-block:: python
+
+ # Standard LIF
+ neuron = bp.LIF(
+ size=100,
+ V_rest=-65*u.mV,
+ V_th=-50*u.mV,
+ V_reset=-65*u.mV,
+ tau=10*u.ms
+ )
+
+ brainstate.nn.init_all_states(neuron)
+
+ # Compute F-I curve
+ currents = u.math.linspace(0, 5, 20) * u.nA
+ rates = []
+
+ for I in currents:
+ brainstate.nn.init_all_states(neuron)
+ spike_count = 0
+ for _ in range(1000):
+ neuron(jnp.ones(100) * I)
+ spike_count += jnp.sum(neuron.get_spike())
+ rate = spike_count / (1000 * 0.1 * 1e-3) / 100  # spikes / duration (0.1 s) / neurons → Hz
+ rates.append(rate)
+
+ **See Also:**
+
+ - :doc:`../core-concepts/neurons` - Detailed LIF guide
+ - :doc:`../tutorials/basic/01-lif-neuron` - LIF tutorial
+
+LIFRef
+~~~~~~
+
+.. class:: LIFRef(size, V_rest=-65*u.mV, V_th=-50*u.mV, V_reset=-65*u.mV, tau=10*u.ms, tau_ref=2*u.ms, **kwargs)
+
+ LIF with refractory period.
+
+ **Model:**
+
+ .. math::
+
+ \\tau \\frac{dV}{dt} = -(V - V_{rest}) + R I_{ext} \\quad \\text{(if not refractory)}
+
+ After spike, neuron is unresponsive for ``tau_ref`` milliseconds.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons
+ - ``V_rest`` (Quantity[mV]) - Resting potential (default: -65 mV)
+ - ``V_th`` (Quantity[mV]) - Spike threshold (default: -50 mV)
+ - ``V_reset`` (Quantity[mV]) - Reset potential (default: -65 mV)
+ - ``tau`` (Quantity[ms]) - Membrane time constant (default: 10 ms)
+ - ``tau_ref`` (Quantity[ms]) - Refractory period (default: 2 ms)
+
+ **States:**
+
+ - ``V`` (ShortTermState) - Membrane potential [mV]
+ - ``spike`` (ShortTermState) - Spike indicator [0 or 1]
+ - ``t_last_spike`` (ShortTermState) - Time since last spike [ms]
+
+ **Example:**
+
+ .. code-block:: python
+
+ neuron = bp.LIFRef(
+ size=100,
+ tau=10*u.ms,
+ tau_ref=2*u.ms # 2ms refractory period
+ )
+
+Adaptive Models
+---------------
+
+ALIF
+~~~~
+
+.. class:: ALIF(size, V_rest=-65*u.mV, V_th=-50*u.mV, V_reset=-65*u.mV, tau=10*u.ms, tau_w=100*u.ms, a=0*u.nA, b=0.5*u.nA, **kwargs)
+
+ Adaptive Leaky Integrate-and-Fire neuron.
+
+ **Model:**
+
+ .. math::
+
+ \\tau \\frac{dV}{dt} &= -(V - V_{rest}) + R I_{ext} - R w \\\\
+ \\tau_w \\frac{dw}{dt} &= a(V - V_{rest}) - w
+
+ After spike: :math:`w \\rightarrow w + b`
+
+ Implements spike-frequency adaptation through adaptation current :math:`w`.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons
+ - ``V_rest`` (Quantity[mV]) - Resting potential (default: -65 mV)
+ - ``V_th`` (Quantity[mV]) - Spike threshold (default: -50 mV)
+ - ``V_reset`` (Quantity[mV]) - Reset potential (default: -65 mV)
+ - ``tau`` (Quantity[ms]) - Membrane time constant (default: 10 ms)
+ - ``tau_w`` (Quantity[ms]) - Adaptation time constant (default: 100 ms)
+ - ``a`` (Quantity[nA]) - Subthreshold adaptation (default: 0 nA)
+ - ``b`` (Quantity[nA]) - Spike-triggered adaptation (default: 0.5 nA)
+
+ **States:**
+
+ - ``V`` (ShortTermState) - Membrane potential [mV]
+ - ``w`` (ShortTermState) - Adaptation current [nA]
+ - ``spike`` (ShortTermState) - Spike indicator [0 or 1]
+
+ **Example:**
+
+ .. code-block:: python
+
+ # Adapting neuron
+ neuron = bp.ALIF(
+ size=100,
+ tau=10*u.ms,
+ tau_w=100*u.ms, # Slow adaptation
+ a=0.1*u.nA, # Subthreshold coupling
+ b=0.5*u.nA # Spike-triggered jump
+ )
+
+ # Constant input → decreasing firing rate
+ brainstate.nn.init_all_states(neuron)
+ rates = []
+
+ for t in range(2000):
+ neuron(jnp.ones(100) * 5.0 * u.nA)
+ if t % 100 == 0:
+ rate = jnp.mean(neuron.get_spike())
+ rates.append(rate)
+ # rates will decrease over time due to adaptation
+
+Izhikevich
+~~~~~~~~~~
+
+.. class:: Izhikevich(size, a=0.02, b=0.2, c=-65*u.mV, d=8*u.mV, **kwargs)
+
+ Izhikevich neuron model.
+
+ **Model:**
+
+ .. math::
+
+ \\frac{dV}{dt} &= 0.04 V^2 + 5V + 140 - u + I \\\\
+ \\frac{du}{dt} &= a(bV - u)
+
+ If :math:`V \\geq 30`, then :math:`V \\rightarrow c, u \\rightarrow u + d`
+
+ Can reproduce many different firing patterns by varying parameters.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons
+ - ``a`` (float) - Time scale of recovery variable (default: 0.02)
+ - ``b`` (float) - Sensitivity of recovery to V (default: 0.2)
+ - ``c`` (Quantity[mV]) - After-spike reset value of V (default: -65 mV)
+ - ``d`` (Quantity[mV]) - After-spike increment of u (default: 8 mV)
+
+ **States:**
+
+ - ``V`` (ShortTermState) - Membrane potential [mV]
+ - ``u`` (ShortTermState) - Recovery variable [mV]
+ - ``spike`` (ShortTermState) - Spike indicator [0 or 1]
+
+ **Common Parameter Sets:**
+
+ .. code-block:: python
+
+ # Regular spiking
+ neuron_rs = bp.Izhikevich(100, a=0.02, b=0.2, c=-65*u.mV, d=8*u.mV)
+
+ # Intrinsically bursting
+ neuron_ib = bp.Izhikevich(100, a=0.02, b=0.2, c=-55*u.mV, d=4*u.mV)
+
+ # Chattering
+ neuron_ch = bp.Izhikevich(100, a=0.02, b=0.2, c=-50*u.mV, d=2*u.mV)
+
+ # Fast spiking
+ neuron_fs = bp.Izhikevich(100, a=0.1, b=0.2, c=-65*u.mV, d=2*u.mV)
+
+ **Example:**
+
+ .. code-block:: python
+
+ neuron = bp.Izhikevich(100, a=0.02, b=0.2, c=-65*u.mV, d=8*u.mV)
+ brainstate.nn.init_all_states(neuron)
+
+ for t in range(1000):
+ inp = brainstate.random.rand(100) * 15.0 * u.nA
+ neuron(inp)
+
+Exponential Models
+------------------
+
+ExpIF
+~~~~~
+
+.. class:: ExpIF(size, V_rest=-65*u.mV, V_th=-50*u.mV, V_reset=-65*u.mV, tau=10*u.ms, delta_T=2*u.mV, **kwargs)
+
+ Exponential Integrate-and-Fire neuron.
+
+ **Model:**
+
+ .. math::
+
+ \\tau \\frac{dV}{dt} = -(V - V_{rest}) + \\Delta_T e^{\\frac{V - V_{th}}{\\Delta_T}} + R I_{ext}
+
+ Features exponential spike generation.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of neurons
+ - ``V_rest`` (Quantity[mV]) - Resting potential (default: -65 mV)
+ - ``V_th`` (Quantity[mV]) - Spike threshold (default: -50 mV)
+ - ``V_reset`` (Quantity[mV]) - Reset potential (default: -65 mV)
+ - ``tau`` (Quantity[ms]) - Membrane time constant (default: 10 ms)
+ - ``delta_T`` (Quantity[mV]) - Spike slope factor (default: 2 mV)
+
+AdExIF
+~~~~~~
+
+.. class:: AdExIF(size, V_rest=-65*u.mV, V_th=-50*u.mV, V_reset=-65*u.mV, tau=10*u.ms, tau_w=100*u.ms, delta_T=2*u.mV, a=0*u.nA, b=0.5*u.nA, **kwargs)
+
+ Adaptive Exponential Integrate-and-Fire neuron.
+
+ Combines exponential spike generation with adaptation.
+
+ **Parameters:**
+
+ Similar to ExpIF plus ALIF adaptation parameters (``tau_w``, ``a``, ``b``).
+
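+   **Example** (a sketch; these parameter values are illustrative, not verified defaults):
+
+   .. code-block:: python
+
+      # Exponential IF: spike initiation sharpness set by delta_T
+      expif = bp.ExpIF(100, tau=10*u.ms, delta_T=2*u.mV)
+
+      # Adaptive exponential IF (AdEx): adds the adaptation current w
+      adex = bp.AdExIF(100, tau=10*u.ms, tau_w=100*u.ms, a=0.1*u.nA, b=0.5*u.nA)
+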
+Usage Patterns
+--------------
+
+**Creating Neuron Populations:**
+
+.. code-block:: python
+
+ # Single population
+ neurons = bp.LIF(1000, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ # Multiple populations with different parameters
+ E_neurons = bp.LIF(800, tau=15*u.ms) # Excitatory: slower
+ I_neurons = bp.LIF(200, tau=10*u.ms) # Inhibitory: faster
+
+**Batched Simulation:**
+
+.. code-block:: python
+
+ neuron = bp.LIF(100, ...)
+ brainstate.nn.init_all_states(neuron, batch_size=32)
+
+ # Input shape: (32, 100)
+ inp = brainstate.random.rand(32, 100) * 2.0 * u.nA
+ neuron(inp)
+
+ # Output shape: (32, 100)
+ spikes = neuron.get_spike()
+
+**Custom Neurons:**
+
+.. code-block:: python
+
+ class CustomLIF(bp.Neuron):
+ def __init__(self, size, tau=10*u.ms):
+ super().__init__(size)
+ self.tau = tau
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+ self.spike = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.V.value = jnp.zeros(shape)
+ self.spike.value = jnp.zeros(shape)
+
+ def update(self, I):
+ # Custom dynamics
+ pass
+
+ def get_spike(self):
+ return self.spike.value
+
+See Also
+--------
+
+- :doc:`../core-concepts/neurons` - Detailed neuron model guide
+- :doc:`../tutorials/basic/01-lif-neuron` - LIF neuron tutorial
+- :doc:`../how-to-guides/custom-components` - Creating custom neurons
+- :doc:`synapses` - Synaptic models
+- :doc:`projections` - Connecting neurons
diff --git a/docs_version3/api/projections.rst b/docs_version3/api/projections.rst
new file mode 100644
index 00000000..0477c77f
--- /dev/null
+++ b/docs_version3/api/projections.rst
@@ -0,0 +1,157 @@
+Projections
+===========
+
+Connect neural populations with the Comm-Syn-Out architecture.
+
+.. currentmodule:: brainpy
+
+Projection Classes
+------------------
+
+AlignPostProj
+~~~~~~~~~~~~~
+
+.. class:: AlignPostProj(comm, syn, out, post, **kwargs)
+
+ Standard projection aligning synaptic states with postsynaptic neurons.
+
+ **Parameters:**
+
+ - ``comm`` - Communication layer (connectivity)
+ - ``syn`` - Synapse dynamics
+ - ``out`` - Output computation
+ - ``post`` - Postsynaptic neuron population
+
+ **Example:**
+
+ .. code-block:: python
+
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(100, 50, prob=0.1, weight=0.5*u.mS),
+ syn=bp.Expon.desc(50, tau=5*u.ms),
+ out=bp.COBA.desc(E=0*u.mV),
+ post=post_neurons
+ )
+
+ # Usage
+ pre_spikes = pre_neurons.get_spike()
+ proj(pre_spikes)
+
+AlignPreProj
+~~~~~~~~~~~~
+
+.. class:: AlignPreProj(comm, syn, out, post, **kwargs)
+
+ Projection aligning synaptic states with presynaptic neurons.
+
+ Used for certain learning rules that require presynaptic alignment.
+
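+   **Example** (a minimal sketch mirroring the ``AlignPostProj`` usage above; an
+   identical constructor signature is an assumption):
+
+   .. code-block:: python
+
+      proj = bp.AlignPreProj(
+          comm=brainstate.nn.EventFixedProb(100, 50, prob=0.1, weight=0.5*u.mS),
+          syn=bp.Expon.desc(100, tau=5*u.ms),  # aligned with the 100 presynaptic cells
+          out=bp.COBA.desc(E=0*u.mV),
+          post=post_neurons
+      )
+
+      proj(pre_neurons.get_spike())
+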
+Communication Layers
+--------------------
+
+From ``brainstate.nn``:
+
+EventFixedProb
+~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ comm = brainstate.nn.EventFixedProb(
+ pre_size,
+ post_size,
+ prob=0.1, # Connection probability
+ weight=0.5*u.mS # Synaptic weight
+ )
+
+Sparse connectivity with fixed connection probability.
+
+EventAll2All
+~~~~~~~~~~~~
+
+.. code-block:: python
+
+ comm = brainstate.nn.EventAll2All(
+ pre_size,
+ post_size,
+ weight=0.5*u.mS
+ )
+
+All-to-all connectivity (event-driven).
+
+EventOne2One
+~~~~~~~~~~~~
+
+.. code-block:: python
+
+ comm = brainstate.nn.EventOne2One(
+ size,
+ weight=0.5*u.mS
+ )
+
+One-to-one connections (same size populations).
+
+Linear
+~~~~~~
+
+.. code-block:: python
+
+ comm = brainstate.nn.Linear(
+ in_size,
+ out_size,
+ w_init=brainstate.init.KaimingNormal()
+ )
+
+Dense linear transformation (for small networks).
+
+Complete Examples
+-----------------
+
+**E → E Excitatory:**
+
+.. code-block:: python
+
+ E2E = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=0.02, weight=0.6*u.mS),
+ syn=bp.AMPA.desc(n_exc, tau=2*u.ms),
+ out=bp.COBA.desc(E=0*u.mV),
+ post=E_neurons
+ )
+
+**I → E Inhibitory:**
+
+.. code-block:: python
+
+ I2E = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=0.02, weight=6.7*u.mS),
+ syn=bp.GABAa.desc(n_exc, tau=6*u.ms),
+ out=bp.COBA.desc(E=-80*u.mV),
+ post=E_neurons
+ )
+
+**Multi-timescale (AMPA + NMDA):**
+
+.. code-block:: python
+
+ # Fast AMPA
+ ampa_proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob=0.1, weight=0.3*u.mS),
+ syn=bp.AMPA.desc(n_post, tau=2*u.ms),
+ out=bp.COBA.desc(E=0*u.mV),
+ post=post_neurons
+ )
+
+ # Slow NMDA
+ nmda_proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob=0.1, weight=0.3*u.mS),
+ syn=bp.NMDA.desc(n_post, tau_decay=100*u.ms),
+ out=bp.MgBlock.desc(E=0*u.mV),
+ post=post_neurons
+ )
+
+See Also
+--------
+
+- :doc:`../core-concepts/projections` - Complete projection guide
+- :doc:`../tutorials/basic/03-network-connections` - Network tutorial
+- :doc:`neurons` - Neuron models
+- :doc:`synapses` - Synapse models
diff --git a/docs_version3/api/synapses.rst b/docs_version3/api/synapses.rst
new file mode 100644
index 00000000..181971d4
--- /dev/null
+++ b/docs_version3/api/synapses.rst
@@ -0,0 +1,352 @@
+Synapse Models
+==============
+
+Synaptic dynamics models in BrainPy.
+
+.. currentmodule:: brainpy
+
+Base Class
+----------
+
+Synapse
+~~~~~~~
+
+.. class:: Synapse(size, **kwargs)
+
+ Base class for all synapse models.
+
+ **Parameters:**
+
+ - ``size`` (int) - Number of post-synaptic neurons
+ - ``**kwargs`` - Additional keyword arguments
+
+ **Key Methods:**
+
+ .. method:: update(x)
+
+ Update synaptic dynamics.
+
+ :param x: Pre-synaptic input (typically spike indicator)
+ :returns: Synaptic conductance/current
+
+ .. method:: reset_state(batch_size=None)
+
+ Reset synaptic state.
+
+ **Descriptor Pattern:**
+
+ Synapses use the ``.desc()`` class method for use in projections:
+
+ .. code-block:: python
+
+ syn = bp.Expon.desc(size=100, tau=5*u.ms)
+
+Simple Synapses
+---------------
+
+Delta
+~~~~~
+
+.. class:: Delta
+
+ Instantaneous synaptic transmission (no dynamics).
+
+ .. math::
+
+ g(t) = \\sum_k \\delta(t - t_k)
+
+ **Usage:**
+
+ .. code-block:: python
+
+ syn = bp.Delta.desc(100)
+
+Expon
+~~~~~
+
+.. class:: Expon
+
+ Exponential synapse (single time constant).
+
+ .. math::
+
+ \\tau \\frac{dg}{dt} = -g + \\sum_k \\delta(t - t_k)
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Time constant (default: 5 ms)
+
+ **Example:**
+
+ .. code-block:: python
+
+ syn = bp.Expon.desc(size=100, tau=5*u.ms)
+
+Alpha
+~~~~~
+
+.. class:: Alpha
+
+ Alpha function synapse (rise + decay).
+
+ .. math::
+
+ \\tau \\frac{dg}{dt} &= -g + h \\\\
+ \\tau \\frac{dh}{dt} &= -h + \\sum_k \\delta(t - t_k)
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Characteristic time (default: 10 ms)
+
+ **Example:**
+
+ .. code-block:: python
+
+ syn = bp.Alpha.desc(size=100, tau=10*u.ms)
+
+DualExponential
+~~~~~~~~~~~~~~~
+
+.. class:: DualExponential
+
+ Biexponential synapse with separate rise/decay.
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau_rise`` (Quantity[ms]) - Rise time constant
+ - ``tau_decay`` (Quantity[ms]) - Decay time constant
+
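+   **Example** (a sketch following the same ``.desc()`` pattern as the synapses above;
+   values are illustrative):
+
+   .. code-block:: python
+
+      syn = bp.DualExponential.desc(size=100, tau_rise=1*u.ms, tau_decay=10*u.ms)
+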
+Receptor Models
+---------------
+
+AMPA
+~~~~
+
+.. class:: AMPA
+
+ AMPA receptor (fast excitatory).
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Time constant (default: 2 ms)
+
+ **Example:**
+
+ .. code-block:: python
+
+ syn = bp.AMPA.desc(size=100, tau=2*u.ms)
+
+NMDA
+~~~~
+
+.. class:: NMDA
+
+ NMDA receptor (slow excitatory, voltage-dependent).
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau_rise`` (Quantity[ms]) - Rise time (default: 2 ms)
+ - ``tau_decay`` (Quantity[ms]) - Decay time (default: 100 ms)
+ - ``a`` (Quantity[1/mM]) - Mg²⁺ sensitivity (default: 0.5/mM)
+ - ``cc_Mg`` (Quantity[mM]) - Mg²⁺ concentration (default: 1.2 mM)
+
+ **Example:**
+
+ .. code-block:: python
+
+ syn = bp.NMDA.desc(
+ size=100,
+ tau_rise=2*u.ms,
+ tau_decay=100*u.ms
+ )
+
+GABAa
+~~~~~
+
+.. class:: GABAa
+
+ GABA_A receptor (fast inhibitory).
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Time constant (default: 6 ms)
+
+ **Example:**
+
+ .. code-block:: python
+
+ syn = bp.GABAa.desc(size=100, tau=6*u.ms)
+
+GABAb
+~~~~~
+
+.. class:: GABAb
+
+ GABA_B receptor (slow inhibitory).
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau_rise`` (Quantity[ms]) - Rise time (default: 3.5 ms)
+ - ``tau_decay`` (Quantity[ms]) - Decay time (default: 150 ms)
+
+Short-Term Plasticity
+---------------------
+
+STD
+~~~
+
+.. class:: STD
+
+ Short-term depression.
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Synaptic time constant
+ - ``tau_d`` (Quantity[ms]) - Depression recovery time
+ - ``U`` (float) - Utilization fraction (0-1)
+
+STF
+~~~
+
+.. class:: STF
+
+ Short-term facilitation.
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Synaptic time constant
+ - ``tau_f`` (Quantity[ms]) - Facilitation time constant
+ - ``U`` (float) - Baseline utilization
+
+STP
+~~~
+
+.. class:: STP
+
+ Combined short-term plasticity (depression + facilitation).
+
+ **Parameters:**
+
+ - ``size`` (int) - Population size
+ - ``tau`` (Quantity[ms]) - Synaptic time constant
+ - ``tau_d`` (Quantity[ms]) - Depression time constant
+ - ``tau_f`` (Quantity[ms]) - Facilitation time constant
+ - ``U`` (float) - Baseline utilization
+
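+These plasticity models follow the same ``.desc()`` pattern as the other synapses; a
+sketch with illustrative (not verified) parameter values:
+
+.. code-block:: python
+
+   # Depression only: resources deplete on each spike, recover with tau_d
+   std = bp.STD.desc(size=100, tau=5*u.ms, tau_d=500*u.ms, U=0.4)
+
+   # Facilitation only: utilization transiently rises after each spike
+   stf = bp.STF.desc(size=100, tau=5*u.ms, tau_f=750*u.ms, U=0.15)
+
+   # Combined depression and facilitation
+   stp = bp.STP.desc(size=100, tau=5*u.ms, tau_d=200*u.ms, tau_f=600*u.ms, U=0.2)
+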
+Output Models
+-------------
+
+CUBA
+~~~~
+
+.. class:: CUBA
+
+ Current-based synaptic output.
+
+ .. math::
+
+ I_{syn} = g_{syn}
+
+ **Usage:**
+
+ .. code-block:: python
+
+ out = bp.CUBA.desc()
+
+COBA
+~~~~
+
+.. class:: COBA
+
+ Conductance-based synaptic output.
+
+ .. math::
+
+ I_{syn} = g_{syn} (E_{syn} - V_{post})
+
+ **Parameters:**
+
+ - ``E`` (Quantity[mV]) - Reversal potential
+
+ **Example:**
+
+ .. code-block:: python
+
+ # Excitatory
+ out_exc = bp.COBA.desc(E=0*u.mV)
+
+ # Inhibitory
+ out_inh = bp.COBA.desc(E=-80*u.mV)
+
+MgBlock
+~~~~~~~
+
+.. class:: MgBlock
+
+ Voltage-dependent magnesium block (for NMDA).
+
+ **Parameters:**
+
+ - ``E`` (Quantity[mV]) - Reversal potential
+ - ``cc_Mg`` (Quantity[mM]) - Mg²⁺ concentration
+ - ``alpha`` (Quantity[1/mV]) - Voltage sensitivity
+ - ``beta`` (float) - Voltage offset
+
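+   **Example** (a sketch; values follow the NMDA projection shown below):
+
+   .. code-block:: python
+
+      out = bp.MgBlock.desc(E=0*u.mV, cc_Mg=1.2*u.mM)
+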
+Usage in Projections
+---------------------
+
+**Standard pattern:**
+
+.. code-block:: python
+
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob=0.1, weight=0.5*u.mS),
+ syn=bp.Expon.desc(n_post, tau=5*u.ms), # Synapse
+ out=bp.COBA.desc(E=0*u.mV), # Output
+ post=post_neurons
+ )
+
+**Receptor-specific:**
+
+.. code-block:: python
+
+ # Fast excitation (AMPA)
+ ampa_proj = bp.AlignPostProj(
+ comm=...,
+ syn=bp.AMPA.desc(n_post, tau=2*u.ms),
+ out=bp.COBA.desc(E=0*u.mV),
+ post=post_neurons
+ )
+
+ # Slow excitation (NMDA)
+ nmda_proj = bp.AlignPostProj(
+ comm=...,
+ syn=bp.NMDA.desc(n_post, tau_decay=100*u.ms),
+ out=bp.MgBlock.desc(E=0*u.mV),
+ post=post_neurons
+ )
+
+ # Fast inhibition (GABA_A)
+ gaba_proj = bp.AlignPostProj(
+ comm=...,
+ syn=bp.GABAa.desc(n_post, tau=6*u.ms),
+ out=bp.COBA.desc(E=-80*u.mV),
+ post=post_neurons
+ )
+
+See Also
+--------
+
+- :doc:`../core-concepts/synapses` - Detailed synapse guide
+- :doc:`../tutorials/basic/02-synapse-models` - Synapse tutorial
+- :doc:`../tutorials/advanced/06-synaptic-plasticity` - Plasticity tutorial
+- :doc:`projections` - Projection API
diff --git a/docs_version3/api/training.rst b/docs_version3/api/training.rst
new file mode 100644
index 00000000..94e8619f
--- /dev/null
+++ b/docs_version3/api/training.rst
@@ -0,0 +1,221 @@
+Training Utilities
+==================
+
+Tools for training spiking neural networks.
+
+.. currentmodule:: braintools
+
+Optimizers
+----------
+
+From ``braintools.optim``:
+
+Adam
+~~~~
+
+.. class:: optim.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-8)
+
+ Adam optimizer.
+
+ **Parameters:**
+
+ - ``learning_rate`` (float) - Learning rate
+ - ``beta1`` (float) - First moment decay
+ - ``beta2`` (float) - Second moment decay
+ - ``eps`` (float) - Numerical stability
+
+ **Methods:**
+
+ .. method:: register_trainable_weights(params)
+
+ Register parameters to optimize.
+
+ .. method:: update(grads)
+
+ Update parameters with gradients.
+
+ **Example:**
+
+ .. code-block:: python
+
+ optimizer = braintools.optim.Adam(learning_rate=1e-3)
+ params = net.states(brainstate.ParamState)
+ optimizer.register_trainable_weights(params)
+
+ # Training loop
+ grads = brainstate.transform.grad(loss_fn, params)(...)
+ optimizer.update(grads)
+
+SGD
+~~~
+
+.. class:: optim.SGD(learning_rate=0.01, momentum=0.0)
+
+ Stochastic gradient descent with momentum.
+
+RMSprop
+~~~~~~~
+
+.. class:: optim.RMSprop(learning_rate=0.001, decay=0.9, eps=1e-8)
+
+ RMSprop optimizer.
+
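+Both follow the same interface as ``Adam``; a minimal sketch:
+
+.. code-block:: python
+
+   optimizer = braintools.optim.SGD(learning_rate=0.01, momentum=0.9)
+   optimizer.register_trainable_weights(net.states(brainstate.ParamState))
+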
+Gradient Computation
+--------------------
+
+From ``brainstate.transform``:
+
+grad
+~~~~
+
+.. function:: transform.grad(fun, argnums=0, has_aux=False, return_value=False)
+
+ Compute gradients of a function.
+
+ **Parameters:**
+
+ - ``fun`` - Function to differentiate
+ - ``argnums`` - Which arguments to differentiate
+ - ``has_aux`` - Whether function returns auxiliary data
+ - ``return_value`` - Also return function value
+
+ **Example:**
+
+ .. code-block:: python
+
+ def loss_fn(net, X, y):
+     output = net(X)
+     return jnp.mean((output - y) ** 2)
+
+ params = net.states(brainstate.ParamState)
+
+ # Get gradients
+ grads = brainstate.transform.grad(loss_fn, params)(net, X, y)
+
+ # Get gradients and loss
+ grads, loss = brainstate.transform.grad(
+ loss_fn, params, return_value=True
+ )(net, X, y)
+
+value_and_grad
+~~~~~~~~~~~~~~
+
+.. function:: transform.value_and_grad(fun, argnums=0)
+
+ Compute both value and gradient (more efficient than separate calls).
+
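+   **Example** (a sketch assuming the JAX-style ``(value, grads)`` return order):
+
+   .. code-block:: python
+
+      loss, grads = brainstate.transform.value_and_grad(loss_fn, params)(net, X, y)
+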
+Surrogate Gradients
+-------------------
+
+From ``braintools.surrogate``:
+
+ReluGrad
+~~~~~~~~
+
+.. class:: surrogate.ReluGrad(alpha=1.0)
+
+ ReLU surrogate gradient for spike function.
+
+ **Example:**
+
+ .. code-block:: python
+
+ neuron = bp.LIF(
+ 100,
+ spike_fun=braintools.surrogate.ReluGrad()
+ )
+
+sigmoid
+~~~~~~~
+
+.. function:: surrogate.sigmoid(alpha=4.0)
+
+ Sigmoid surrogate gradient.
+
+slayer_grad
+~~~~~~~~~~~
+
+.. function:: surrogate.slayer_grad(alpha=4.0)
+
+ SLAYER/SuperSpike surrogate gradient.
+
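+Any of these can be supplied as the neuron's spike function, as in the ``ReluGrad``
+example above (a sketch; it assumes ``sigmoid(alpha=...)`` returns a spike function):
+
+.. code-block:: python
+
+   neuron = bp.LIF(100, spike_fun=braintools.surrogate.sigmoid(alpha=4.0))
+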
+Loss Functions
+--------------
+
+From ``braintools.metric``:
+
+softmax_cross_entropy_with_integer_labels
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. function:: metric.softmax_cross_entropy_with_integer_labels(logits, labels)
+
+ Cross-entropy loss for classification.
+
+ **Parameters:**
+
+ - ``logits`` - Network outputs (batch_size, num_classes)
+ - ``labels`` - Integer labels (batch_size,)
+
+ **Returns:**
+
+ Loss per example (batch_size,)
+
+ **Example:**
+
+ .. code-block:: python
+
+ logits = net(inputs) # (32, 10)
+ labels = jnp.array([0, 1, 2, ...]) # (32,)
+
+ loss = braintools.metric.softmax_cross_entropy_with_integer_labels(
+ logits, labels
+ ).mean()
+
+Training Workflow
+-----------------
+
+**Complete example:**
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import braintools
+
+ # 1. Create network
+ net = TrainableSNN()
+ brainstate.nn.init_all_states(net, batch_size=32)
+
+ # 2. Create optimizer
+ optimizer = braintools.optim.Adam(learning_rate=1e-3)
+ params = net.states(brainstate.ParamState)
+ optimizer.register_trainable_weights(params)
+
+ # 3. Define loss function
+ def loss_fn(net, X, y):
+ brainstate.nn.init_all_states(net)
+ logits = run_network(net, X) # Simulate and accumulate
+ loss = braintools.metric.softmax_cross_entropy_with_integer_labels(
+ logits, y
+ ).mean()
+ return loss
+
+ # 4. Training loop
+ for epoch in range(num_epochs):
+ for X_batch, y_batch in data_loader:
+ # Compute gradients
+ grads, loss = brainstate.transform.grad(
+ loss_fn, params, return_value=True
+ )(net, X_batch, y_batch)
+
+ # Update parameters
+ optimizer.update(grads)
+
+ print(f"Loss: {loss:.4f}")
+
+See Also
+--------
+
+- :doc:`../tutorials/advanced/05-snn-training` - SNN training tutorial
+- :doc:`../how-to-guides/save-load-models` - Model checkpointing
+- :doc:`../how-to-guides/gpu-tpu-usage` - GPU acceleration
diff --git a/docs_version3/core-concepts/architecture.rst b/docs_version3/core-concepts/architecture.rst
new file mode 100644
index 00000000..9b6f8108
--- /dev/null
+++ b/docs_version3/core-concepts/architecture.rst
@@ -0,0 +1,609 @@
+Architecture Overview
+=====================
+
+BrainPy 3.0 represents a complete architectural redesign built on top of the ``brainstate`` framework. This document explains the design principles and architectural components that make BrainPy 3.0 powerful and flexible.
+
+Design Philosophy
+-----------------
+
+BrainPy 3.0 is built around several core principles:
+
+**State-Based Programming**
+ All dynamical variables are managed as explicit states, enabling automatic differentiation, efficient compilation, and clear data flow.
+
+**Modular Composition**
+ Complex models are built by composing simple, reusable components. Each component has a well-defined interface and responsibility.
+
+**Scientific Accuracy**
+ Integration with ``brainunit`` ensures physical correctness and prevents unit-related errors.
+
+**Performance by Default**
+ JIT compilation and optimization are built into the framework, not an afterthought.
+
+**Extensibility**
+ Adding new neuron models, synapse types, or learning rules is straightforward and follows clear patterns.
+
+Architectural Layers
+--------------------
+
+BrainPy 3.0 is organized into several layers:
+
+.. code-block:: text
+
+    ┌──────────────────────────────────────────┐
+    │        User Models & Networks            │  ← Your code
+    ├──────────────────────────────────────────┤
+    │        BrainPy Components Layer          │  ← Neurons, Synapses, Projections
+    ├──────────────────────────────────────────┤
+    │        BrainState Framework              │  ← State management, compilation
+    ├──────────────────────────────────────────┤
+    │        JAX + XLA Backend                 │  ← JIT compilation, autodiff
+    └──────────────────────────────────────────┘
+
+1. JAX + XLA Backend
+~~~~~~~~~~~~~~~~~~~~
+
+The foundation layer provides:
+
+- Just-In-Time (JIT) compilation
+- Automatic differentiation
+- Hardware acceleration (CPU/GPU/TPU)
+- Functional transformations (vmap, grad, etc.)
+
+2. BrainState Framework
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+Built on JAX, ``brainstate`` provides:
+
+- State management system
+- Module composition
+- Compilation and optimization
+- Program transformations (for_loop, etc.)
+
+3. BrainPy Components
+~~~~~~~~~~~~~~~~~~~~~
+
+High-level neuroscience-specific components:
+
+- Neuron models (LIF, ALIF, etc.)
+- Synapse models (Expon, Alpha, etc.)
+- Projection architectures
+- Learning rules and plasticity
+
+4. User Models
+~~~~~~~~~~~~~~
+
+Your custom networks and experiments built using BrainPy components.
+
+State Management System
+-----------------------
+
+The Foundation: brainstate.State
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Everything in BrainPy 3.0 revolves around **states**:
+
+.. code-block:: python
+
+ import brainstate
+
+ # Create a state
+ voltage = brainstate.State(0.0) # Single value
+ weights = brainstate.State([[0.1, 0.2], [0.3, 0.4]]) # Matrix
+
+States are special containers that:
+
+- Track their values across time
+- Support automatic differentiation
+- Enable efficient compilation
+- Handle batching automatically
+
+State Types
+~~~~~~~~~~~
+
+BrainPy uses different state types for different purposes:
+
+**ParamState** - Trainable Parameters
+ Used for weights, time constants, and other trainable parameters.
+
+ .. code-block:: python
+
+ class MyNeuron(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tau = brainstate.ParamState(10.0) # Trainable
+ self.weight = brainstate.ParamState([[0.1, 0.2]])
+
+**ShortTermState** - Temporary Variables
+ Used for membrane potentials, synaptic currents, and other dynamics.
+
+ .. code-block:: python
+
+ class MyNeuron(brainstate.nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.V = brainstate.ShortTermState(jnp.zeros(size)) # Dynamic
+ self.spike = brainstate.ShortTermState(jnp.zeros(size))
+
+State Initialization
+~~~~~~~~~~~~~~~~~~~~
+
+States can be initialized with various strategies:
+
+.. code-block:: python
+
+ import braintools
+ import brainunit as u
+
+ # Constant initialization
+ V = brainstate.ShortTermState(
+ braintools.init.Constant(-65.0, unit=u.mV)(size)
+ )
+
+ # Normal distribution
+ V = brainstate.ShortTermState(
+ braintools.init.Normal(-65.0, 5.0, unit=u.mV)(size)
+ )
+
+ # Uniform distribution
+ weights = brainstate.ParamState(
+ braintools.init.Uniform(0.0, 1.0)(shape)
+ )
+
+Module System
+-------------
+
+Base Class: brainstate.nn.Module
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+All BrainPy components inherit from ``brainstate.nn.Module``:
+
+.. code-block:: python
+
+ class MyComponent(brainstate.nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ # Initialize states
+ self.state1 = brainstate.ShortTermState(...)
+ self.param1 = brainstate.ParamState(...)
+
+ def update(self, input):
+ # Define dynamics
+ pass
+
+Benefits of Module:
+
+- Automatic state registration
+- Nested module support
+- State collection and filtering
+- Serialization support
+
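+State collection in particular keeps training setups concise; a small sketch:
+
+.. code-block:: python
+
+   net = MyComponent(100)
+   all_states = net.states()                    # every registered state
+   params = net.states(brainstate.ParamState)   # trainable parameters only
+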
+Module Composition
+~~~~~~~~~~~~~~~~~~
+
+Modules can contain other modules:
+
+.. code-block:: python
+
+ import brainpy
+
+ class Network(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.neurons = brainpy.LIF(100) # Neuron module
+ self.synapse = brainpy.Expon(100) # Synapse module
+ self.projection = brainpy.AlignPostProj(...) # Projection module
+
+ def update(self, input):
+ # Compose behavior
+ self.projection(spikes)
+ self.neurons(input)
+
+Component Architecture
+----------------------
+
+Neurons
+~~~~~~~
+
+Neurons model the dynamics of neural populations:
+
+.. code-block:: python
+
+ class Neuron(brainstate.nn.Module):
+ def __init__(self, size, ...):
+ super().__init__()
+ # Membrane potential
+ self.V = brainstate.ShortTermState(...)
+ # Spike output
+ self.spike = brainstate.ShortTermState(...)
+
+ def update(self, input_current):
+ # Update membrane potential
+ # Generate spikes
+ pass
+
+Key responsibilities:
+
+- Maintain membrane potential
+- Generate spikes when threshold is crossed
+- Reset after spiking
+- Integrate input currents
+
+Synapses
+~~~~~~~~
+
+Synapses model temporal filtering of spike trains:
+
+.. code-block:: python
+
+ class Synapse(brainstate.nn.Module):
+ def __init__(self, size, tau, ...):
+ super().__init__()
+ # Synaptic conductance/current
+ self.g = brainstate.ShortTermState(...)
+
+ def update(self, spike_input):
+ # Update synaptic variable
+ # Return filtered output
+ pass
+
+Key responsibilities:
+
+- Filter spike inputs temporally
+- Model synaptic dynamics (exponential, alpha, etc.)
+- Provide smooth currents to postsynaptic neurons
+
+Projections: The Comm-Syn-Out Pattern
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Projections connect populations using a three-stage architecture:
+
+.. code-block:: text
+
+    Presynaptic Spikes → [Comm] → [Syn] → [Out] → Postsynaptic Neurons
+                            ↓        ↓       ↓
+                      Connectivity Dynamics Current
+                       & Weights            Injection
+
+**Communication (Comm)**
+ Handles spike transmission, connectivity, and weights.
+
+ .. code-block:: python
+
+ comm = brainstate.nn.EventFixedProb(
+ pre_size, post_size, prob=0.1, weight=0.5
+ )
+
+**Synaptic Dynamics (Syn)**
+ Temporal filtering of transmitted spikes.
+
+ .. code-block:: python
+
+ syn = brainpy.Expon.desc(post_size, tau=5*u.ms)
+
+**Output Mechanism (Out)**
+ How synaptic variables affect postsynaptic neurons.
+
+ .. code-block:: python
+
+ out = brainpy.CUBA.desc() # Current-based
+ # or
+ out = brainpy.COBA.desc() # Conductance-based
+
+**Complete Projection**
+
+.. code-block:: python
+
+ projection = brainpy.AlignPostProj(
+ comm=comm,
+ syn=syn,
+ out=out,
+ post=postsynaptic_neurons
+ )
+
+This separation provides:
+
+- Clear responsibility boundaries
+- Easy component swapping
+- Reusable building blocks
+- Better testing and debugging
+
+Compilation and Execution
+--------------------------
+
+Time-Stepped Simulation
+~~~~~~~~~~~~~~~~~~~~~~~
+
+BrainPy uses discrete time steps:
+
+.. code-block:: python
+
+ import brainunit as u
+
+ # Set global time step
+ brainstate.environ.set(dt=0.1 * u.ms)
+
+ # Define simulation duration
+ times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
+
+ # Run simulation
+ results = brainstate.transform.for_loop(
+ network.update,
+ times,
+ pbar=brainstate.transform.ProgressBar(10)
+ )
+
+JIT Compilation
+~~~~~~~~~~~~~~~
+
+Functions are compiled for performance:
+
+.. code-block:: python
+
+ @brainstate.compile.jit
+ def simulate_step(input):
+ return network.update(input)
+
+ # First call: compile
+ result = simulate_step(input) # Slow (compilation)
+
+ # Subsequent calls: fast
+ result = simulate_step(input) # Fast (compiled)
+
+Compilation benefits:
+
+- 10-100x speedup over Python
+- Automatic GPU/TPU dispatch
+- Memory optimization
+- Fusion of operations
+
+Gradient Computation
+~~~~~~~~~~~~~~~~~~~~
+
+For training, gradients are computed automatically:
+
+.. code-block:: python
+
+ def loss_fn():
+ predictions = network.forward(inputs)
+ return compute_loss(predictions, targets)
+
+ # Compute gradients
+ grads, loss = brainstate.transform.grad(
+ loss_fn,
+ network.states(brainstate.ParamState),
+ return_value=True
+ )()
+
+ # Update parameters
+ optimizer.update(grads)
+
+Physical Units System
+---------------------
+
+Integration with brainunit
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+BrainPy 3.0 integrates ``brainunit`` for scientific accuracy:
+
+.. code-block:: python
+
+ import brainunit as u
+
+ # Define with units
+ tau = 10 * u.ms
+ threshold = -50 * u.mV
+ current = 5 * u.nA
+
+ # Units are checked automatically
+ neuron = brainpy.LIF(100, tau=tau, V_th=threshold)
+
+Benefits:
+
+- Prevents unit errors (e.g., ms vs s)
+- Self-documenting code
+- Automatic unit conversions
+- Scientific correctness
+
+Unit Operations
+~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Arithmetic with units
+ total_time = 100 * u.ms + 0.5 * u.second # → 600 ms
+
+ # Unit conversion
+ time_in_seconds = (100 * u.ms).to_decimal(u.second) # → 0.1
+
+ # Unit checking (automatic in BrainPy operations)
+ voltage = -65 * u.mV
+ current = 2 * u.nA
+ resistance = voltage / current # Automatically gives MΩ
+
+Ecosystem Integration
+---------------------
+
+BrainPy 3.0 integrates tightly with its ecosystem:
+
+braintools
+~~~~~~~~~~
+
+Utilities and tools:
+
+.. code-block:: python
+
+ import braintools
+
+ # Optimizers
+ optimizer = braintools.optim.Adam(learning_rate=1e-3)
+
+ # Initializers
+ init = braintools.init.KaimingNormal()
+
+ # Surrogate gradients
+ spike_fn = braintools.surrogate.ReluGrad()
+
+ # Metrics
+ loss = braintools.metric.cross_entropy(pred, target)
+
+brainunit
+~~~~~~~~~
+
+Physical units:
+
+.. code-block:: python
+
+ import brainunit as u
+
+ # All standard SI units
+ time = 10 * u.ms
+ voltage = -65 * u.mV
+ current = 2 * u.nA
+
+brainstate
+~~~~~~~~~~
+
+Core framework (used automatically):
+
+.. code-block:: python
+
+ import brainstate
+
+ # Module system
+ class Net(brainstate.nn.Module): ...
+
+ # Compilation
+ @brainstate.compile.jit
+ def fn(): ...
+
+ # Transformations
+ result = brainstate.transform.for_loop(...)
+
+Data Flow Example
+-----------------
+
+Here's how data flows through a typical BrainPy 3.0 simulation:
+
+.. code-block:: python
+
+ # 1. Define network
+ class EINetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.E = brainpy.LIF(800) # States: V, spike
+ self.I = brainpy.LIF(200) # States: V, spike
+ self.E2E = brainpy.AlignPostProj(...) # States: g (in synapse)
+ self.E2I = brainpy.AlignPostProj(...)
+ self.I2E = brainpy.AlignPostProj(...)
+ self.I2I = brainpy.AlignPostProj(...)
+
+ def update(self, input):
+ # Get spikes from last time step
+ e_spikes = self.E.get_spike()
+ i_spikes = self.I.get_spike()
+
+           # Update projections (spikes → synaptic currents)
+ self.E2E(e_spikes) # Updates E2E.syn.g
+ self.E2I(e_spikes)
+ self.I2E(i_spikes)
+ self.I2I(i_spikes)
+
+           # Update neurons (currents → new V and spikes)
+ self.E(input) # Updates E.V and E.spike
+ self.I(input) # Updates I.V and I.spike
+
+ return e_spikes, i_spikes
+
+ # 2. Initialize
+ net = EINetwork()
+ brainstate.nn.init_all_states(net)
+
+ # 3. Compile
+ @brainstate.compile.jit
+ def step(input):
+ return net.update(input)
+
+ # 4. Simulate
+ times = u.math.arange(0*u.ms, 1000*u.ms, 0.1*u.ms)
+ results = brainstate.transform.for_loop(step, times)
+
+State Flow:
+
+.. code-block:: text
+
+ Time t:
+   ┌────────────────────────────────────────────┐
+   │ States at t-1:                             │
+   │   E.V[t-1], E.spike[t-1]                   │
+   │   I.V[t-1], I.spike[t-1]                   │
+   │   E2E.syn.g[t-1], ...                      │
+   └────────────────────────────────────────────┘
+                        │
+                        ▼
+   ┌────────────────────────────────────────────┐
+   │ Projection Updates:                        │
+   │   E2E.syn.g[t] = f(g[t-1], E.spike[t-1])   │
+   │   ... (other projections)                  │
+   └────────────────────────────────────────────┘
+                        │
+                        ▼
+   ┌────────────────────────────────────────────┐
+   │ Neuron Updates:                            │
+   │   E.V[t] = f(V[t-1], Σ currents[t])        │
+   │   E.spike[t] = E.V[t] >= V_th              │
+   │   ... (other neurons)                      │
+   └────────────────────────────────────────────┘
+                        │
+                        ▼
+ Time t+1...
+
+Performance Considerations
+--------------------------
+
+Memory Management
+~~~~~~~~~~~~~~~~~
+
+- States are preallocated
+- In-place updates when possible
+- Efficient batching support
+- Automatic garbage collection
+
+Compilation Strategy
+~~~~~~~~~~~~~~~~~~~~
+
+- Compile simulation loops
+- Batch operations when possible
+- Use ``for_loop`` for long sequences
+- Leverage JAX's XLA optimization
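+
+For example, the whole simulation loop can be compiled as one unit; a sketch reusing the ``network`` and ``times`` from earlier:
+
+.. code-block:: python
+
+   @brainstate.compile.jit
+   def run(times):
+       # for_loop carries states across steps; jit compiles the entire loop
+       return brainstate.transform.for_loop(network.update, times)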
+
+Hardware Acceleration
+~~~~~~~~~~~~~~~~~~~~~
+
+- Automatic GPU dispatch for large arrays
+- TPU support for massive parallelism
+- Efficient CPU fallback for small problems
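+
+Because BrainPy runs on JAX, standard JAX calls reveal which backend is active; a quick check:
+
+.. code-block:: python
+
+   import jax
+
+   print(jax.devices())          # e.g. [CudaDevice(id=0)] on a GPU machine
+   print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'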
+
+Summary
+-------
+
+BrainPy 3.0's architecture provides:
+
+✅ **Clear Abstractions**: Neurons, synapses, and projections with well-defined roles
+
+✅ **State Management**: Explicit, efficient handling of dynamical variables
+
+✅ **Modularity**: Compose complex models from simple components
+
+✅ **Performance**: JIT compilation and hardware acceleration
+
+✅ **Scientific Accuracy**: Integrated physical units
+
+✅ **Extensibility**: Easy to add custom components
+
+✅ **Modern Design**: Built on proven frameworks (JAX, brainstate)
+
+Next Steps
+----------
+
+- Learn about specific components: :doc:`neurons`, :doc:`synapses`, :doc:`projections`
+- Understand state management in depth: :doc:`state-management`
+- See practical examples: :doc:`../tutorials/basic/01-lif-neuron`
+- Explore the ecosystem: `brainstate docs <https://brainstate.readthedocs.io/>`_
diff --git a/docs_version3/core-concepts/neurons.rst b/docs_version3/core-concepts/neurons.rst
new file mode 100644
index 00000000..952f3492
--- /dev/null
+++ b/docs_version3/core-concepts/neurons.rst
@@ -0,0 +1,598 @@
+Neurons
+=======
+
+Neurons are the fundamental computational units in BrainPy 3.0. This document explains how neurons work, what models are available, and how to use and create them.
+
+Overview
+--------
+
+In BrainPy 3.0, neurons model the dynamics of neural populations. Each neuron model:
+
+- Maintains **membrane potential** (voltage)
+- Integrates **input currents**
+- Generates **spikes** when threshold is crossed
+- **Resets** after spiking (various strategies)
+
+All neuron models inherit from the base ``Neuron`` class and follow consistent interfaces.
+
+Basic Usage
+-----------
+
+Creating Neurons
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import brainpy
+ import brainunit as u
+
+ # Create a population of 100 LIF neurons
+ neurons = brainpy.LIF(
+ size=100,
+ V_rest=-65. * u.mV,
+ V_th=-50. * u.mV,
+ V_reset=-65. * u.mV,
+ tau=10. * u.ms
+ )
+
+Initializing States
+~~~~~~~~~~~~~~~~~~~
+
+Before simulation, initialize neuron states:
+
+.. code-block:: python
+
+ import brainstate
+
+ # Initialize all states to default values
+ brainstate.nn.init_all_states(neurons)
+
+ # Or with specific batch size
+ brainstate.nn.init_all_states(neurons, batch_size=32)
+
+Running Neurons
+~~~~~~~~~~~~~~~
+
+Update neurons by calling them with input current:
+
+.. code-block:: python
+
+ # Single time step
+ input_current = 2.0 * u.nA
+ neurons(input_current)
+
+ # Access results
+ voltage = neurons.V.value # Membrane potential
+ spikes = neurons.get_spike() # Spike output
+
+Available Neuron Models
+-----------------------
+
+IF (Integrate-and-Fire)
+~~~~~~~~~~~~~~~~~~~~~~~
+
+The simplest spiking neuron model.
+
+**Mathematical Model:**
+
+.. math::
+
+   \tau \frac{dV}{dt} = -V + R \cdot I(t)
+
+**Spike condition:** If :math:`V \geq V_{th}`, emit a spike and reset.
+
+**Example:**
+
+.. code-block:: python
+
+ neuron = brainpy.IF(
+ size=100,
+ V_rest=0. * u.mV,
+ V_th=1. * u.mV,
+ V_reset=0. * u.mV,
+ tau=20. * u.ms,
+ R=1. * u.ohm
+ )
+
+**Parameters:**
+
+- ``size``: Number of neurons
+- ``V_rest``: Resting potential
+- ``V_th``: Spike threshold
+- ``V_reset``: Reset potential after spike
+- ``tau``: Membrane time constant
+- ``R``: Input resistance
+
+**Use cases:**
+
+- Simple rate coding
+- Fast simulations
+- Theoretical studies
+
+LIF (Leaky Integrate-and-Fire)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The most commonly used spiking neuron model.
+
+**Mathematical Model:**
+
+.. math::
+
+   \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
+
+**Spike condition:** If :math:`V \geq V_{th}`, emit a spike and reset.
+
+**Example:**
+
+.. code-block:: python
+
+   import braintools
+
+   neuron = brainpy.LIF(
+ size=100,
+ V_rest=-65. * u.mV,
+ V_th=-50. * u.mV,
+ V_reset=-65. * u.mV,
+ tau=10. * u.ms,
+ R=1. * u.ohm,
+ V_initializer=braintools.init.Normal(-65., 5., unit=u.mV)
+ )
+
+**Parameters:**
+
+All IF parameters, plus:
+
+- ``V_initializer``: How to initialize membrane potential
+
+**Key Features:**
+
+- Leak toward resting potential
+- Realistic temporal integration
+- Well-studied dynamics
+
+**Use cases:**
+
+- General spiking neural networks
+- Cortical neuron modeling
+- Learning and training
+
+LIFRef (LIF with Refractory Period)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+LIF neuron with absolute refractory period.
+
+**Mathematical Model:**
+
+Same as LIF, but after spiking:
+
+- Neuron is "frozen" for refractory period
+- No integration during refractory period
+- More biologically realistic
+
+**Example:**
+
+.. code-block:: python
+
+ neuron = brainpy.LIFRef(
+ size=100,
+ V_rest=-65. * u.mV,
+ V_th=-50. * u.mV,
+ V_reset=-65. * u.mV,
+ tau=10. * u.ms,
+ tau_ref=2. * u.ms, # Refractory period
+ R=1. * u.ohm
+ )
+
+**Additional Parameters:**
+
+- ``tau_ref``: Refractory period duration
+
+**Key Features:**
+
+- Absolute refractory period
+- Prevents immediate re-firing
+- More realistic spike timing
+
+**Use cases:**
+
+- Precise temporal coding
+- Biological realism
+- Rate regulation
+
+ALIF (Adaptive Leaky Integrate-and-Fire)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+LIF with spike-frequency adaptation.
+
+**Mathematical Model:**
+
+.. math::
+
+   \tau \frac{dV}{dt} = -(V - V_{rest}) - R \cdot w + R \cdot I(t)
+
+   \tau_w \frac{dw}{dt} = -w
+
+When a spike occurs: :math:`w \leftarrow w + \beta`
+
+**Example:**
+
+.. code-block:: python
+
+ neuron = brainpy.ALIF(
+ size=100,
+ V_rest=-65. * u.mV,
+ V_th=-50. * u.mV,
+ V_reset=-65. * u.mV,
+ tau=10. * u.ms,
+ tau_w=200. * u.ms, # Adaptation time constant
+ beta=0.01, # Adaptation strength
+ R=1. * u.ohm
+ )
+
+**Additional Parameters:**
+
+- ``tau_w``: Adaptation time constant
+- ``beta``: Adaptation increment per spike
+
+**Key Features:**
+
+- Spike-frequency adaptation
+- Reduced firing with sustained input
+- More complex dynamics
+
+**Use cases:**
+
+- Cortical neuron modeling
+- Sensory adaptation
+- Complex temporal patterns
+
+Reset Modes
+-----------
+
+BrainPy supports different reset behaviors after spiking:
+
+Soft Reset (Default)
+~~~~~~~~~~~~~~~~~~~~
+
+Subtract threshold from membrane potential:
+
+.. math::
+
+   V \leftarrow V - V_{th}
+
+.. code-block:: python
+
+ neuron = brainpy.LIF(..., spk_reset='soft')
+
+**Properties:**
+
+- Preserves extra charge above threshold
+- Allows rapid re-firing
+- Common in machine learning
+
+Hard Reset
+~~~~~~~~~~
+
+Reset to fixed potential:
+
+.. math::
+
+   V \leftarrow V_{reset}
+
+.. code-block:: python
+
+ neuron = brainpy.LIF(..., spk_reset='hard')
+
+**Properties:**
+
+- Discards extra charge
+- More biologically realistic
+- Prevents immediate re-firing
+
+Choosing Reset Mode
+~~~~~~~~~~~~~~~~~~~~
+
+- **Soft reset**: Machine learning, rate coding, fast oscillations
+- **Hard reset**: Biological modeling, temporal coding, realism
+
+Spike Functions
+---------------
+
+For training spiking neural networks, use surrogate gradients:
+
+.. code-block:: python
+
+ import braintools
+
+ neuron = brainpy.LIF(
+ size=100,
+ ...,
+ spk_fun=braintools.surrogate.ReluGrad()
+ )
+
+Available surrogate functions:
+
+- ``ReluGrad()``: ReLU-like gradient
+- ``SigmoidGrad()``: Sigmoid-like gradient
+- ``GaussianGrad()``: Gaussian-like gradient
+- ``SuperSpike()``: SuperSpike surrogate
+
+See :doc:`../tutorials/advanced/03-snn-training` for training details.
+
+Advanced Features
+-----------------
+
+Initialization Strategies
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Different ways to initialize membrane potential:
+
+.. code-block:: python
+
+ import braintools
+
+ # Constant initialization
+ neuron = brainpy.LIF(
+ size=100,
+ V_initializer=braintools.init.Constant(-65., unit=u.mV)
+ )
+
+ # Normal distribution
+ neuron = brainpy.LIF(
+ size=100,
+ V_initializer=braintools.init.Normal(-65., 5., unit=u.mV)
+ )
+
+ # Uniform distribution
+ neuron = brainpy.LIF(
+ size=100,
+ V_initializer=braintools.init.Uniform(-70., -60., unit=u.mV)
+ )
+
+Accessing Neuron States
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Membrane potential (with units)
+ voltage = neuron.V.value # Quantity with units
+
+ # Spike output (binary or real-valued)
+ spikes = neuron.get_spike()
+
+ # Access underlying array (without units)
+ voltage_array = neuron.V.value.to_decimal(u.mV)
+
+Batched Simulation
+~~~~~~~~~~~~~~~~~~
+
+Simulate multiple trials in parallel:
+
+.. code-block:: python
+
+   import jax.numpy as jnp
+
+   # Initialize with batch dimension
+ brainstate.nn.init_all_states(neuron, batch_size=32)
+
+ # Input shape: (batch_size,) or (batch_size, size)
+ input_current = jnp.ones((32, 100)) * 2.0 * u.nA
+ neuron(input_current)
+
+ # Output shape: (batch_size, size)
+ spikes = neuron.get_spike()
+
+Complete Example
+----------------
+
+Here's a complete example simulating a LIF neuron:
+
+.. code-block:: python
+
+   import brainpy
+ import brainstate
+ import brainunit as u
+ import matplotlib.pyplot as plt
+
+ # Set time step
+ brainstate.environ.set(dt=0.1 * u.ms)
+
+ # Create neuron
+ neuron = brainpy.LIF(
+ size=1,
+ V_rest=-65. * u.mV,
+ V_th=-50. * u.mV,
+ V_reset=-65. * u.mV,
+ tau=10. * u.ms,
+ spk_reset='hard'
+ )
+
+ # Initialize
+ brainstate.nn.init_all_states(neuron)
+
+ # Simulation parameters
+ duration = 200. * u.ms
+ dt = brainstate.environ.get_dt()
+ times = u.math.arange(0. * u.ms, duration, dt)
+
+ # Input current (step input)
+ def get_input(t):
+ return 2.0 * u.nA if t > 50*u.ms else 0.0 * u.nA
+
+ # Run simulation
+ voltages = []
+ spikes = []
+
+ for t in times:
+ neuron(get_input(t))
+ voltages.append(neuron.V.value)
+ spikes.append(neuron.get_spike())
+
+ # Plot results
+ voltages = u.math.asarray(voltages)
+ times_plot = times.to_decimal(u.ms)
+ voltages_plot = voltages.to_decimal(u.mV)
+
+ plt.figure(figsize=(10, 4))
+ plt.plot(times_plot, voltages_plot)
+ plt.axhline(y=-50, color='r', linestyle='--', label='Threshold')
+ plt.xlabel('Time (ms)')
+ plt.ylabel('Membrane Potential (mV)')
+ plt.title('LIF Neuron Response')
+ plt.legend()
+ plt.grid(True, alpha=0.3)
+ plt.show()
+
+Creating Custom Neurons
+------------------------
+
+You can create custom neuron models by inheriting from ``Neuron``:
+
+.. code-block:: python
+
+   import braintools
+   import brainstate
+   import brainunit as u
+   import jax.numpy as jnp
+   from brainpy._base import Neuron
+
+ class MyNeuron(Neuron):
+ def __init__(self, size, tau, V_th, **kwargs):
+ super().__init__(size, **kwargs)
+
+ # Store parameters
+ self.tau = tau
+ self.V_th = V_th
+
+ # Initialize states
+ self.V = brainstate.ShortTermState(
+ braintools.init.Constant(0., unit=u.mV)(size)
+ )
+ self.spike = brainstate.ShortTermState(
+ jnp.zeros(size)
+ )
+
+ def update(self, x):
+ # Get time step
+ dt = brainstate.environ.get_dt()
+
+ # Update membrane potential (custom dynamics)
+ dV = (-self.V.value + x) / self.tau * dt
+ V_new = self.V.value + dV
+
+ # Check for spikes
+ spike = (V_new >= self.V_th).astype(float)
+
+ # Reset
+           V_new = u.math.where(spike > 0, 0. * u.mV, V_new)  # unit-aware where
+
+ # Update states
+ self.V.value = V_new
+ self.spike.value = spike
+
+ return spike
+
+ def get_spike(self):
+ return self.spike.value
+
+Usage:
+
+.. code-block:: python
+
+ neuron = MyNeuron(size=100, tau=10*u.ms, V_th=1*u.mV)
+ brainstate.nn.init_all_states(neuron)
+ neuron(input_current)
+
+Performance Tips
+----------------
+
+1. **Use JIT compilation** for repeated simulations:
+
+ .. code-block:: python
+
+ @brainstate.compile.jit
+ def simulate_step(input):
+ neuron(input)
+ return neuron.V.value
+
+2. **Batch multiple trials** for parallelism:
+
+ .. code-block:: python
+
+ brainstate.nn.init_all_states(neuron, batch_size=100)
+
+3. **Use appropriate data types**:
+
+ .. code-block:: python
+
+ # Float32 is usually sufficient and faster
+ brainstate.environ.set(dtype=jnp.float32)
+
+4. **Preallocate arrays** when recording:
+
+ .. code-block:: python
+
+ n_steps = len(times)
+ voltages = jnp.zeros((n_steps, neuron.size))
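+
+   As an alternative to manual preallocation, ``brainstate.transform.for_loop``
+   stacks per-step outputs automatically; a minimal sketch, reusing the
+   ``neuron`` from above:
+
+   .. code-block:: python
+
+      def record_step(t):
+          neuron(2.0 * u.nA)
+          return neuron.V.value
+
+      times = u.math.arange(0. * u.ms, 100. * u.ms, brainstate.environ.get_dt())
+      voltages = brainstate.transform.for_loop(record_step, times)  # (n_steps, size)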
+
+Common Patterns
+---------------
+
+Rate Coding
+~~~~~~~~~~~
+
+Neurons encoding information in firing rate:
+
+.. code-block:: python
+
+ neuron = brainpy.LIF(100, tau=10*u.ms, spk_reset='soft')
+ # Use soft reset for higher firing rates
+
+Temporal Coding
+~~~~~~~~~~~~~~~
+
+Neurons encoding information in spike timing:
+
+.. code-block:: python
+
+ neuron = brainpy.LIFRef(
+ 100,
+ tau=10*u.ms,
+ tau_ref=2*u.ms,
+ spk_reset='hard'
+ )
+ # Use refractory period for precise timing
+
+Burst Firing
+~~~~~~~~~~~~
+
+Neurons with bursting behavior:
+
+.. code-block:: python
+
+ neuron = brainpy.ALIF(
+ 100,
+ tau=10*u.ms,
+ tau_w=200*u.ms,
+ beta=0.01,
+ spk_reset='soft'
+ )
+ # Adaptation creates bursting patterns
+
+Summary
+-------
+
+Neurons in BrainPy 3.0:
+
+✅ **Multiple models**: IF, LIF, LIFRef, ALIF
+
+✅ **Physical units**: All parameters with proper units
+
+✅ **Flexible reset**: Soft or hard reset modes
+
+✅ **Training-ready**: Surrogate gradients for learning
+
+✅ **High performance**: JIT compilation and batching
+
+✅ **Extensible**: Easy to create custom models
+
+Next Steps
+----------
+
+- Learn about :doc:`synapses` to connect neurons
+- Explore :doc:`projections` for network connectivity
+- Follow :doc:`../tutorials/basic/01-lif-neuron` for hands-on practice
+- See :doc:`../examples/classical-networks/ei-balanced` for network examples
diff --git a/docs_version3/core-concepts/projections.rst b/docs_version3/core-concepts/projections.rst
new file mode 100644
index 00000000..a367da04
--- /dev/null
+++ b/docs_version3/core-concepts/projections.rst
@@ -0,0 +1,894 @@
+Projections: Connecting Neural Populations
+==========================================
+
+Projections are BrainPy's mechanism for connecting neural populations. They implement the **Communication-Synapse-Output (Comm-Syn-Out)** architecture, which separates connectivity, synaptic dynamics, and output computation into modular components.
+
+This guide provides a comprehensive understanding of projections in BrainPy 3.0.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Overview
+--------
+
+What are Projections?
+~~~~~~~~~~~~~~~~~~~~~
+
+A **projection** connects a presynaptic population to a postsynaptic population through:
+
+1. **Communication (Comm)**: How spikes propagate through connections
+2. **Synapse (Syn)**: Temporal filtering and synaptic dynamics
+3. **Output (Out)**: How synaptic currents affect postsynaptic neurons
+
+**Key benefits:**
+
+- Modular design (swap components independently)
+- Biologically realistic (separate connectivity and dynamics)
+- Efficient (optimized sparse operations)
+- Flexible (combine components in different ways)
+
+The Comm-Syn-Out Architecture
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: text
+
+   Presynaptic     Communication        Synapse         Output       Postsynaptic
+   Population ───▶ (Connectivity) ───▶ (Dynamics) ───▶ (Current) ───▶ Population
+
+   Spikes     ───▶ Weight matrix  ───▶    g(t)    ───▶   I_syn   ───▶ Neurons
+                   Sparse/Dense       Expon/Alpha      CUBA/COBA
+
+**Flow:**
+
+1. Presynaptic spikes arrive
+2. Communication: Spikes propagate through connectivity matrix
+3. Synapse: Temporal dynamics filter the signal
+4. Output: Convert to current/conductance
+5. Postsynaptic neurons receive input
+
+Types of Projections
+~~~~~~~~~~~~~~~~~~~~~
+
+BrainPy provides two main projection types:
+
+**AlignPostProj**
+ - Align synaptic states with postsynaptic neurons
+ - Most common for standard neural networks
+ - Efficient memory layout
+
+**AlignPreProj**
+ - Align synaptic states with presynaptic neurons
+ - Useful for certain learning rules
+ - Different memory organization
+
+For most use cases, use ``AlignPostProj``.
+
+Communication Layer
+-------------------
+
+The Communication layer defines **how spikes propagate** through connections.
+
+Dense Connectivity
+~~~~~~~~~~~~~~~~~~
+
+All neurons potentially connected (though weights may be zero).
+
+**Use case:** Small networks, fully connected layers
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+
+ # Dense linear transformation
+ comm = brainstate.nn.Linear(
+ in_size=100, # Presynaptic neurons
+ out_size=50, # Postsynaptic neurons
+ w_init=brainstate.init.KaimingNormal(),
+ b_init=None # No bias for synapses
+ )
+
+**Characteristics:**
+
+- Memory: O(n_pre × n_post)
+- Computation: Full matrix multiplication
+- Best for: Small networks, fully connected architectures
+
+Sparse Connectivity
+~~~~~~~~~~~~~~~~~~~
+
+Only a subset of connections exist (biologically realistic).
+
+**Use case:** Large networks, biological connectivity patterns
+
+Event-Based Fixed Probability
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Connect neurons with fixed probability.
+
+.. code-block:: python
+
+ # Sparse random connectivity (2% connection probability)
+ comm = brainstate.nn.EventFixedProb(
+ pre_size=1000,
+ post_size=800,
+ prob=0.02, # 2% connectivity
+ weight=0.5 * u.mS # Synaptic weight
+ )
+
+**Characteristics:**
+
+- Memory: O(n_pre × n_post × prob)
+- Computation: Only active connections
+- Best for: Large-scale networks, biological models
+
+Event-Based All-to-All
+^^^^^^^^^^^^^^^^^^^^^^^
+
+All neurons connected (but stored sparsely).
+
+.. code-block:: python
+
+ # All-to-all sparse (event-driven)
+ comm = brainstate.nn.EventAll2All(
+ pre_size=100,
+ post_size=100,
+ weight=0.3 * u.mS
+ )
+
+Event-Based One-to-One
+^^^^^^^^^^^^^^^^^^^^^^^
+
+One-to-one mapping (same size populations).
+
+.. code-block:: python
+
+ # One-to-one connections
+ comm = brainstate.nn.EventOne2One(
+ size=100,
+ weight=1.0 * u.mS
+ )
+
+**Use case:** Feedforward pathways, identity mappings
+
+Comparison Table
+~~~~~~~~~~~~~~~~
+
+.. list-table:: Communication Layer Options
+ :header-rows: 1
+ :widths: 20 20 20 20 20
+
+ * - Type
+ - Memory
+ - Speed
+ - Use Case
+ - Example
+ * - Linear (Dense)
+     - High (O(n²))
+ - Fast (optimized)
+ - Small networks
+ - Fully connected
+ * - EventFixedProb
+     - Low (O(n²p))
+ - Very fast
+ - Large networks
+ - Cortical connectivity
+ * - EventAll2All
+ - Medium
+ - Fast
+ - Medium networks
+ - Recurrent layers
+ * - EventOne2One
+ - Minimal (O(n))
+ - Fastest
+ - Feedforward
+ - Sensory pathways
+
+Synapse Layer
+-------------
+
+The Synapse layer defines **temporal dynamics** of synaptic transmission.
+
+Exponential Synapse
+~~~~~~~~~~~~~~~~~~~
+
+Single exponential decay (most common).
+
+**Dynamics:**
+
+.. math::
+
+ \tau \frac{dg}{dt} = -g + \sum_k \delta(t - t_k)
+
+**Implementation:**
+
+.. code-block:: python
+
+ # Exponential synapse with 5ms time constant
+ syn = bp.Expon.desc(
+ size=100, # Postsynaptic population size
+ tau=5.0 * u.ms # Decay time constant
+ )
+
+**Characteristics:**
+
+- Single time constant
+- Fast computation
+- Good for most applications
+
+**When to use:** Default choice for most models
+
+Alpha Synapse
+~~~~~~~~~~~~~
+
+Dual exponential with rise and decay.
+
+**Dynamics:**
+
+.. math::
+
+ \tau \frac{dg}{dt} = -g + h
+
+ \tau \frac{dh}{dt} = -h + \sum_k \delta(t - t_k)
+
+**Implementation:**
+
+.. code-block:: python
+
+ # Alpha synapse
+ syn = bp.Alpha.desc(
+ size=100,
+ tau=10.0 * u.ms # Characteristic time
+ )
+
+**Characteristics:**
+
+- Realistic rise time
+- Smoother response
+- Slightly slower computation
+
+**When to use:** When rise time matters, more biological realism
+
+NMDA Synapse
+~~~~~~~~~~~~
+
+Voltage-dependent NMDA receptors.
+
+**Dynamics:**
+
+.. math::
+
+ g_{NMDA} = \frac{g}{1 + \eta [Mg^{2+}] e^{-\gamma V}}
+
+**Implementation:**
+
+.. code-block:: python
+
+ # NMDA receptor
+ syn = bp.NMDA.desc(
+ size=100,
+ tau_decay=100.0 * u.ms, # Slow decay
+ tau_rise=2.0 * u.ms, # Fast rise
+       a=0.5 / u.mM,          # Mg²⁺ sensitivity
+       cc_Mg=1.2 * u.mM       # Mg²⁺ concentration
+ )
+
+**Characteristics:**
+
+- Voltage-dependent
+- Slow kinetics
+- Important for plasticity
+
+**When to use:** Long-term potentiation, working memory models
+
+AMPA Synapse
+~~~~~~~~~~~~
+
+Fast glutamatergic transmission.
+
+.. code-block:: python
+
+ # AMPA receptor (fast excitation)
+ syn = bp.AMPA.desc(
+ size=100,
+ tau=2.0 * u.ms # Fast decay (~2ms)
+ )
+
+**When to use:** Fast excitatory transmission
+
+GABA Synapse
+~~~~~~~~~~~~
+
+Inhibitory transmission.
+
+**GABAa (fast):**
+
+.. code-block:: python
+
+ # GABAa receptor (fast inhibition)
+ syn = bp.GABAa.desc(
+ size=100,
+ tau=6.0 * u.ms # ~6ms decay
+ )
+
+**GABAb (slow):**
+
+.. code-block:: python
+
+ # GABAb receptor (slow inhibition)
+ syn = bp.GABAb.desc(
+ size=100,
+ tau_decay=150.0 * u.ms, # Very slow
+ tau_rise=3.5 * u.ms
+ )
+
+**When to use:**
+
+- GABAa: Fast inhibition, cortical networks
+- GABAb: Slow inhibition, rhythm generation
+
+Custom Synapses
+~~~~~~~~~~~~~~~
+
+Create custom synaptic dynamics by subclassing ``Synapse``.
+
+.. code-block:: python
+
+   import jax.numpy as jnp
+
+   class DoubleExpSynapse(bp.Synapse):
+ """Custom synapse with two time constants."""
+
+ def __init__(self, size, tau_fast=2*u.ms, tau_slow=10*u.ms, **kwargs):
+ super().__init__(size, **kwargs)
+ self.tau_fast = tau_fast
+ self.tau_slow = tau_slow
+
+ # State variables
+ self.g_fast = brainstate.ShortTermState(jnp.zeros(size))
+ self.g_slow = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.g_fast.value = jnp.zeros(shape)
+ self.g_slow.value = jnp.zeros(shape)
+
+ def update(self, x):
+ dt = brainstate.environ.get_dt()
+
+ # Fast component
+ dg_fast = -self.g_fast.value / self.tau_fast.to_decimal(u.ms)
+ self.g_fast.value += dg_fast * dt.to_decimal(u.ms) + x * 0.7
+
+ # Slow component
+ dg_slow = -self.g_slow.value / self.tau_slow.to_decimal(u.ms)
+ self.g_slow.value += dg_slow * dt.to_decimal(u.ms) + x * 0.3
+
+ return self.g_fast.value + self.g_slow.value
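+
+A usage sketch for this custom synapse (``pre_spikes`` stands in for whatever spike vector feeds it):
+
+.. code-block:: python
+
+   syn = DoubleExpSynapse(100)
+   brainstate.nn.init_all_states(syn)
+
+   g_total = syn(pre_spikes)  # combined fast + slow conductance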
+
+Output Layer
+------------
+
+The Output layer defines **how synaptic conductance affects neurons**.
+
+CUBA (Current-Based)
+~~~~~~~~~~~~~~~~~~~~
+
+Synaptic conductance directly becomes current.
+
+**Model:**
+
+.. math::
+
+ I_{syn} = g_{syn}
+
+**Implementation:**
+
+.. code-block:: python
+
+ # Current-based output
+ out = bp.CUBA.desc()
+
+**Characteristics:**
+
+- Simple and fast
+- No voltage dependence
+- Good for rate-based models
+
+**When to use:**
+
+- Abstract models
+- When voltage dependence is not important
+- When faster computation is needed
+
+COBA (Conductance-Based)
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Synaptic conductance with reversal potential.
+
+**Model:**
+
+.. math::
+
+ I_{syn} = g_{syn} (E_{syn} - V_{post})
+
+**Implementation:**
+
+.. code-block:: python
+
+ # Excitatory conductance-based
+ out_exc = bp.COBA.desc(E=0.0 * u.mV)
+
+ # Inhibitory conductance-based
+ out_inh = bp.COBA.desc(E=-80.0 * u.mV)
+
+**Characteristics:**
+
+- Voltage-dependent
+- Biologically realistic
+- Self-limiting (saturates near reversal)
+
+**When to use:**
+
+- Biologically detailed models
+- When voltage dependence matters
+- Shunting inhibition required
+
+MgBlock (NMDA)
+~~~~~~~~~~~~~~
+
+Voltage-dependent magnesium block for NMDA.
+
+.. code-block:: python
+
+   # NMDA with Mg²⁺ block
+ out_nmda = bp.MgBlock.desc(
+ E=0.0 * u.mV,
+ cc_Mg=1.2 * u.mM,
+ alpha=0.062 / u.mV,
+ beta=3.57
+ )
+
+**When to use:** NMDA receptors, voltage-dependent plasticity
+
+Complete Projection Examples
+-----------------------------
+
+Example 1: Simple Feedforward
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+
+ # Create populations
+ pre = bp.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ post = bp.LIF(50, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+   # Create projection: 100 → 50 neurons
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(
+ pre_size=100,
+ post_size=50,
+ prob=0.1, # 10% connectivity
+ weight=0.5 * u.mS
+ ),
+ syn=bp.Expon.desc(
+ size=50, # Postsynaptic size
+ tau=5.0 * u.ms
+ ),
+ out=bp.CUBA.desc(),
+ post=post # Postsynaptic population
+ )
+
+ # Initialize
+ brainstate.nn.init_all_states([pre, post, proj])
+
+ # Simulate
+ def step(inp):
+ # Get presynaptic spikes
+ pre_spikes = pre.get_spike()
+
+ # Update projection
+ proj(pre_spikes)
+
+ # Update neurons
+ pre(inp)
+ post(0.0 * u.nA) # Projection provides input
+
+ return pre.get_spike(), post.get_spike()
+
+Example 2: Excitatory-Inhibitory Network
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class EINetwork(brainstate.nn.Module):
+ def __init__(self, n_exc=800, n_inh=200):
+ super().__init__()
+
+ # Populations
+ self.E = bp.LIF(n_exc, V_rest=-65*u.mV, V_th=-50*u.mV, tau=15*u.ms)
+ self.I = bp.LIF(n_inh, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+        # E → E projection (AMPA, excitatory)
+ self.E2E = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=0.02, weight=0.6*u.mS),
+ syn=bp.AMPA.desc(n_exc, tau=2.0*u.ms),
+ out=bp.COBA.desc(E=0.0*u.mV),
+ post=self.E
+ )
+
+        # E → I projection (AMPA, excitatory)
+ self.E2I = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_exc, n_inh, prob=0.02, weight=0.6*u.mS),
+ syn=bp.AMPA.desc(n_inh, tau=2.0*u.ms),
+ out=bp.COBA.desc(E=0.0*u.mV),
+ post=self.I
+ )
+
+        # I → E projection (GABAa, inhibitory)
+ self.I2E = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=0.02, weight=6.7*u.mS),
+ syn=bp.GABAa.desc(n_exc, tau=6.0*u.ms),
+ out=bp.COBA.desc(E=-80.0*u.mV),
+ post=self.E
+ )
+
+        # I → I projection (GABAa, inhibitory)
+ self.I2I = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_inh, n_inh, prob=0.02, weight=6.7*u.mS),
+ syn=bp.GABAa.desc(n_inh, tau=6.0*u.ms),
+ out=bp.COBA.desc(E=-80.0*u.mV),
+ post=self.I
+ )
+
+ def update(self, inp_e, inp_i):
+ # Get spikes BEFORE updating neurons
+ spk_e = self.E.get_spike()
+ spk_i = self.I.get_spike()
+
+ # Update all projections
+ self.E2E(spk_e)
+ self.E2I(spk_e)
+ self.I2E(spk_i)
+ self.I2I(spk_i)
+
+ # Update neurons (projections provide synaptic input)
+ self.E(inp_e)
+ self.I(inp_i)
+
+ return spk_e, spk_i
+
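+To drive this network, the same pattern used elsewhere in this guide applies; a minimal sketch (constant external currents are an assumption for illustration):
+
+.. code-block:: python
+
+   net = EINetwork()
+   brainstate.nn.init_all_states(net)
+
+   def step(t):
+       return net.update(1.0 * u.nA, 1.0 * u.nA)
+
+   times = u.math.arange(0 * u.ms, 100 * u.ms, brainstate.environ.get_dt())
+   spk_e, spk_i = brainstate.transform.for_loop(step, times)
+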
+Example 3: Multi-Timescale Synapses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Combine AMPA (fast) and NMDA (slow) for realistic excitation.
+
+.. code-block:: python
+
+ class DualExcitatory(brainstate.nn.Module):
+ """E โ E with both AMPA and NMDA."""
+
+ def __init__(self, n_pre=100, n_post=100):
+ super().__init__()
+
+ self.post = bp.LIF(n_post, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ # Fast AMPA component
+ self.ampa_proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob=0.1, weight=0.3*u.mS),
+ syn=bp.AMPA.desc(n_post, tau=2.0*u.ms),
+ out=bp.COBA.desc(E=0.0*u.mV),
+ post=self.post
+ )
+
+ # Slow NMDA component
+ self.nmda_proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob=0.1, weight=0.3*u.mS),
+ syn=bp.NMDA.desc(n_post, tau_decay=100.0*u.ms, tau_rise=2.0*u.ms),
+ out=bp.MgBlock.desc(E=0.0*u.mV, cc_Mg=1.2*u.mM),
+ post=self.post
+ )
+
+ def update(self, pre_spikes):
+ # Both projections share same presynaptic spikes
+ self.ampa_proj(pre_spikes)
+ self.nmda_proj(pre_spikes)
+
+ # Post receives combined input
+ self.post(0.0 * u.nA)
+
+ return self.post.get_spike()
+
+Advanced Topics
+---------------
+
+Delay Projections
+~~~~~~~~~~~~~~~~~
+
+Add synaptic delays to projections.
+
+.. code-block:: python
+
+ # Projection with 5ms synaptic delay
+ proj_delayed = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(100, 100, prob=0.1, weight=0.5*u.mS),
+ syn=bp.Expon.desc(100, tau=5.0*u.ms),
+ out=bp.CUBA.desc(),
+ post=post_neurons,
+ delay=5.0 * u.ms # Synaptic delay
+ )
+
+**Use cases:**
+
+- Biologically realistic transmission delays
+- Axonal conduction delays
+- Synchronization studies
+
+Heterogeneous Weights
+~~~~~~~~~~~~~~~~~~~~~~
+
+Different weights for different connections.
+
+.. code-block:: python
+
+ import jax.numpy as jnp
+
+ # Custom weight matrix
+ n_pre, n_post = 100, 50
+ weights = jnp.abs(brainstate.random.randn(n_pre, n_post)) * 0.5 * u.mS
+
+ # Sparse with heterogeneous weights
+ comm = brainstate.nn.EventJitFPHomoLinear(
+ num_in=n_pre,
+ num_out=n_post,
+ prob=0.1,
+ weight=weights # Heterogeneous
+ )
+
+Learning Synapses
+~~~~~~~~~~~~~~~~~
+
+Combine with plasticity (see :doc:`../tutorials/advanced/06-synaptic-plasticity`).
+
+.. code-block:: python
+
+ # Projection with learnable weights
+ class PlasticProjection(brainstate.nn.Module):
+ def __init__(self, n_pre, n_post):
+ super().__init__()
+
+ # Initialize weights as parameters
+ self.weights = brainstate.ParamState(
+ jnp.ones((n_pre, n_post)) * 0.5 * u.mS
+ )
+
+ self.proj = bp.AlignPostProj(
+           comm=CustomComm(self.weights),  # CustomComm: a user-defined comm layer (placeholder)
+ syn=bp.Expon.desc(n_post, tau=5.0*u.ms),
+ out=bp.CUBA.desc(),
+ post=post_neurons
+ )
+
+ def update_weights(self, dw):
+ """Update weights based on learning rule."""
+ self.weights.value += dw
+
+Best Practices
+--------------
+
+Choosing Communication Type
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Use EventFixedProb when:**
+
+- Large networks (>1000 neurons)
+- Sparse connectivity (<10%)
+- Biological models
+
+**Use Linear when:**
+
+- Small networks (<1000 neurons)
+- Fully connected layers
+- Training with gradients
+
+**Use EventOne2One when:**
+
+- Same-size populations
+- Feedforward pathways
+- Identity mappings
+
+Choosing Synapse Type
+~~~~~~~~~~~~~~~~~~~~~~
+
+**Use Expon when:**
+
+- Default choice for most models
+- Fast computation needed
+- Simple dynamics sufficient
+
+**Use Alpha when:**
+
+- Rise time is important
+- More biological realism
+- Smoother responses
+
+**Use AMPA/NMDA/GABA when:**
+
+- Specific receptor types matter
+- Pharmacological studies
+- Detailed biological models
+
+Choosing Output Type
+~~~~~~~~~~~~~~~~~~~~~
+
+**Use CUBA when:**
+
+- Abstract models
+- Training with gradients
+- Speed is critical
+
+**Use COBA when:**
+
+- Biological realism needed
+- Voltage dependence matters
+- Shunting inhibition required
+
+Performance Tips
+~~~~~~~~~~~~~~~~
+
+1. **Sparse over Dense:** Use sparse connectivity for large networks
+2. **Batch initialization:** Initialize all modules together
+3. **JIT compile:** Wrap simulation loop with ``@brainstate.compile.jit``
+4. **Appropriate precision:** Use float32 unless high precision needed
+5. **Minimize communication:** Group projections with same connectivity
+
+Common Patterns
+~~~~~~~~~~~~~~~
+
+**Pattern 1: Dale's Principle**
+
+Neurons are either excitatory OR inhibitory (not both).
+
+.. code-block:: python
+
+ # Separate excitatory and inhibitory populations
+ E = bp.LIF(800, ...) # Excitatory
+ I = bp.LIF(200, ...) # Inhibitory
+
+ # E always excitatory (E=0mV)
+ # I always inhibitory (E=-80mV)
+
+**Pattern 2: Balanced Networks**
+
+Excitation balanced by inhibition.
+
+.. code-block:: python
+
+ # Strong inhibition to balance excitation
+ w_exc = 0.6 * u.mS
+   w_inh = 6.7 * u.mS  # ~10× stronger
+
+ # More E neurons than I (4:1 ratio)
+ n_exc = 800
+ n_inh = 200
+
+**Pattern 3: Recurrent Loops**
+
+Self-connections for persistent activity.
+
+.. code-block:: python
+
+ # Excitatory recurrence (working memory)
+ E2E = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=0.02, weight=0.5*u.mS),
+ syn=bp.Expon.desc(n_exc, tau=5*u.ms),
+ out=bp.COBA.desc(E=0*u.mV),
+ post=E
+ )
+
+Troubleshooting
+---------------
+
+Issue: Spikes not propagating
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:** Postsynaptic neurons don't receive input
+
+**Solutions:**
+
+1. Check spike timing: Call ``get_spike()`` BEFORE updating
+2. Verify connectivity: Check ``prob`` and ``weight``
+3. Check update order: Projections before neurons
+
+.. code-block:: python
+
+ # CORRECT order
+ spk = pre.get_spike() # Get spikes from previous step
+ proj(spk) # Update projection
+ pre(inp) # Update neurons
+
+ # WRONG order
+ pre(inp) # Update first
+ spk = pre.get_spike() # Then get spikes (too late!)
+ proj(spk)
+
+Issue: Network silent or exploding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:** No activity or runaway firing
+
+**Solutions:**
+
+1. Balance E/I weights (I should be ~10× stronger)
+2. Check reversal potentials (E=0mV, I=-80mV)
+3. Verify threshold and reset values
+4. Add external input
+
+.. code-block:: python
+
+ # Balanced weights
+ w_exc = 0.5 * u.mS
+ w_inh = 5.0 * u.mS # Strong inhibition
+
+ # Proper reversal potentials
+ out_exc = bp.COBA.desc(E=0.0 * u.mV)
+ out_inh = bp.COBA.desc(E=-80.0 * u.mV)
+
+Issue: Slow simulation
+~~~~~~~~~~~~~~~~~~~~~~
+
+**Solutions:**
+
+1. Use sparse connectivity (EventFixedProb)
+2. Use JIT compilation
+3. Use CUBA instead of COBA (if appropriate)
+4. Reduce connectivity or neurons
+
+.. code-block:: python
+
+ # Fast configuration
+ @brainstate.compile.jit
+ def simulate_step(net, inp):
+ return net(inp)
+
+ # Sparse connectivity
+ comm = brainstate.nn.EventFixedProb(1000, 1000, prob=0.02, ...)
+
+Further Reading
+---------------
+
+- :doc:`../tutorials/basic/03-network-connections` - Network connections tutorial
+- :doc:`architecture` - Overall BrainPy architecture
+- :doc:`synapses` - Detailed synapse models
+- :doc:`../tutorials/advanced/06-synaptic-plasticity` - Learning in projections
+- :doc:`../tutorials/advanced/07-large-scale-simulations` - Scaling projections
+
+Summary
+-------
+
+**Key takeaways:**
+
+✅ Projections use Comm-Syn-Out architecture
+
+✅ Communication: Dense (Linear) or Sparse (EventFixedProb)
+
+✅ Synapse: Temporal dynamics (Expon, Alpha, AMPA, GABA, NMDA)
+
+✅ Output: Current-based (CUBA) or Conductance-based (COBA)
+
+✅ Choose components based on scale, realism, and performance needs
+
+✅ Follow Dale's principle and balanced E/I patterns
+
+✅ Get spikes BEFORE updating for correct propagation
+
+**Quick reference:**
+
+.. code-block:: python
+
+ # Standard projection template
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_pre, n_post, prob=0.1, weight=0.5*u.mS),
+ syn=bp.Expon.desc(n_post, tau=5.0*u.ms),
+ out=bp.COBA.desc(E=0.0*u.mV),
+ post=post_neurons
+ )
+
+ # Usage in network
+ def update(self):
+ spk = self.pre.get_spike() # Get spikes first
+ self.proj(spk) # Update projection
+ self.pre(inp) # Update neurons
+ self.post(0*u.nA)
diff --git a/docs_version3/core-concepts/state-management.rst b/docs_version3/core-concepts/state-management.rst
new file mode 100644
index 00000000..32385351
--- /dev/null
+++ b/docs_version3/core-concepts/state-management.rst
@@ -0,0 +1,1014 @@
+State Management: The Foundation of BrainPy 3.0
+===============================================
+
+State management is the core architectural change in BrainPy 3.0. Understanding states is essential for using BrainPy effectively. This guide provides comprehensive coverage of the state system built on ``brainstate``.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Overview
+--------
+
+What is State?
+~~~~~~~~~~~~~~
+
+**State** is any variable that persists across function calls and can change over time. In neural simulations:
+
+- Membrane potentials
+- Synaptic conductances
+- Spike trains
+- Learnable weights
+- Temporary buffers
+
+**Key insight:** BrainPy 3.0 makes states **explicit** rather than implicit. Every stateful variable is declared and tracked.
+
+Why Explicit State Management?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Problems with implicit state (BrainPy 2.x):**
+
+- Hard to track what changes when
+- Difficult to serialize/checkpoint
+- Unclear initialization procedures
+- Conflicts with JAX functional programming
+
+**Benefits of explicit state (BrainPy 3.0):**
+
+✅ Clear variable lifecycle
+
+✅ Easy checkpointing and loading
+
+✅ Functional programming compatible
+
+✅ Better debugging and introspection
+
+✅ Automatic differentiation support
+
+✅ Type safety and validation
+
+The State Hierarchy
+~~~~~~~~~~~~~~~~~~~~
+
+BrainPy uses different state types for different purposes:
+
+.. code-block:: text
+
+   State (base class)
+    │
+    ├── ParamState      →  Learnable parameters (weights, biases)
+    ├── ShortTermState  →  Temporary dynamics (V, g, spikes)
+    └── LongTermState   →  Persistent but non-learnable (statistics)
+
+Each type has different semantics and handling:
+
+- **ParamState**: Updated by optimizers, saved in checkpoints
+- **ShortTermState**: Reset each trial, not saved
+- **LongTermState**: Saved but not trained
+
+State Types
+-----------
+
+ParamState: Learnable Parameters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Use for:** Weights, biases, trainable parameters
+
+**Characteristics:**
+
+- Updated by gradient descent
+- Saved in model checkpoints
+- Persistent across trials
+- Registered with optimizers
+
+**Example:**
+
+.. code-block:: python
+
+ import brainstate
+ import jax.numpy as jnp
+
+ class LinearLayer(brainstate.nn.Module):
+ def __init__(self, in_size, out_size):
+ super().__init__()
+
+ # Learnable weight matrix
+ self.W = brainstate.ParamState(
+ brainstate.random.randn(in_size, out_size) * 0.01
+ )
+
+ # Learnable bias vector
+ self.b = brainstate.ParamState(
+ jnp.zeros(out_size)
+ )
+
+ def update(self, x):
+ # Use parameters in computation
+ return jnp.dot(x, self.W.value) + self.b.value
+
+ # Access all parameters
+ layer = LinearLayer(100, 50)
+ params = layer.states(brainstate.ParamState)
+ # Returns: {'W': ParamState(...), 'b': ParamState(...)}
+
+**Common uses:**
+
+- Synaptic weights
+- Neural biases
+- Time constants (if learning them)
+- Connectivity matrices (if plastic)
+
+ShortTermState: Temporary Dynamics
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Use for:** Variables that reset each trial
+
+**Characteristics:**
+
+- Reset at trial start
+- Not saved in checkpoints
+- Represent current dynamics
+- Fastest state type
+
+**Example:**
+
+.. code-block:: python
+
+   import brainstate
+   import brainunit as u
+   import jax.numpy as jnp
+
+ class LIFNeuron(brainstate.nn.Module):
+ def __init__(self, size):
+ super().__init__()
+
+ self.size = size
+ self.V_rest = -65.0 * u.mV
+ self.V_th = -50.0 * u.mV
+
+ # Membrane potential (resets each trial)
+ self.V = brainstate.ShortTermState(
+ jnp.ones(size) * self.V_rest.to_decimal(u.mV)
+ )
+
+ # Spike indicator (resets each trial)
+ self.spike = brainstate.ShortTermState(
+ jnp.zeros(size)
+ )
+
+ def reset_state(self, batch_size=None):
+ """Called at trial start."""
+ if batch_size is None:
+ self.V.value = jnp.ones(self.size) * self.V_rest.to_decimal(u.mV)
+ self.spike.value = jnp.zeros(self.size)
+ else:
+ self.V.value = jnp.ones((batch_size, self.size)) * self.V_rest.to_decimal(u.mV)
+ self.spike.value = jnp.zeros((batch_size, self.size))
+
+ def update(self, I):
+ # Update membrane potential
+ # ... (LIF dynamics)
+ self.V.value = new_V
+ self.spike.value = new_spike
+
+**Common uses:**
+
+- Membrane potentials
+- Synaptic conductances
+- Spike indicators
+- Refractory counters
+- Temporary buffers
+
+LongTermState: Persistent Non-Learnable
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Use for:** Statistics, counters, persistent metadata
+
+**Characteristics:**
+
+- Not reset each trial
+- Saved in checkpoints
+- Not updated by optimizers
+- Accumulates over time
+
+**Example:**
+
+.. code-block:: python
+
+ class NeuronWithStatistics(brainstate.nn.Module):
+ def __init__(self, size):
+ super().__init__()
+
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+
+ # Running spike count (persists across trials)
+ self.total_spikes = brainstate.LongTermState(
+ jnp.zeros(size, dtype=jnp.int32)
+ )
+
+ # Running average firing rate
+ self.avg_rate = brainstate.LongTermState(
+ jnp.zeros(size)
+ )
+
+ def update(self, I):
+ # ... update dynamics ...
+
+ # Accumulate statistics
+ self.total_spikes.value += self.spike.value.astype(jnp.int32)
+
+**Common uses:**
+
+- Spike counters
+- Running averages
+- Homeostatic variables
+- Simulation metadata
+- Custom statistics
+
+State Initialization
+--------------------
+
+Automatic Initialization
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+BrainPy provides ``init_all_states()`` for automatic initialization.
+
+**Basic usage:**
+
+.. code-block:: python
+
+ import brainstate
+
+ # Create network
+ net = MyNetwork()
+
+ # Initialize all states (single trial)
+ brainstate.nn.init_all_states(net)
+
+ # Initialize with batch dimension
+ brainstate.nn.init_all_states(net, batch_size=32)
+
+**What it does:**
+
+1. Finds all modules in the hierarchy
+2. Calls ``reset_state()`` on each module
+3. Handles nested structures automatically
+4. Sets up batch dimensions if requested
+
+**Example with network:**
+
+.. code-block:: python
+
+ class EINetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.E = bp.LIF(800, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ self.I = bp.LIF(200, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ # ... projections ...
+
+ net = EINetwork()
+
+ # This initializes E, I, and all projections
+ brainstate.nn.init_all_states(net, batch_size=10)
+
+Manual Initialization
+~~~~~~~~~~~~~~~~~~~~~
+
+For custom initialization, override ``reset_state()``.
+
+.. code-block:: python
+
+ class CustomNeuron(brainstate.nn.Module):
+ def __init__(self, size, V_init_range=(-70, -60)):
+ super().__init__()
+ self.size = size
+ self.V_init_range = V_init_range
+
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ """Custom initialization: random voltage in range."""
+
+ # Generate random initial voltages
+ low, high = self.V_init_range
+ if batch_size is None:
+ init_V = brainstate.random.uniform(low, high, size=self.size)
+ else:
+ init_V = brainstate.random.uniform(low, high, size=(batch_size, self.size))
+
+ self.V.value = init_V
+
+**Best practices:**
+
+- Always check ``batch_size`` parameter
+- Handle both single and batched cases
+- Initialize all ShortTermStates
+- Don't initialize ParamStates (they're learnable)
+- Don't initialize LongTermStates (they persist)
+
+Initializers for Parameters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Use ``brainstate.init`` for parameter initialization.
+
+.. code-block:: python
+
+ import brainstate.init as init
+
+ class Network(brainstate.nn.Module):
+ def __init__(self, in_size, out_size):
+ super().__init__()
+
+ # Xavier/Glorot initialization
+ self.W1 = brainstate.ParamState(
+ init.XavierNormal()(shape=(in_size, 100))
+ )
+
+ # Kaiming/He initialization (for ReLU)
+ self.W2 = brainstate.ParamState(
+ init.KaimingNormal()(shape=(100, out_size))
+ )
+
+ # Zero initialization
+ self.b = brainstate.ParamState(
+ init.Constant(0.0)(shape=(out_size,))
+ )
+
+ # Orthogonal initialization (for RNNs)
+ self.W_rec = brainstate.ParamState(
+ init.Orthogonal()(shape=(100, 100))
+ )
+
+**Available initializers:**
+
+- ``Constant(value)`` - Fill with constant
+- ``Normal(mean, std)`` - Gaussian distribution
+- ``Uniform(low, high)`` - Uniform distribution
+- ``XavierNormal()`` - Xavier/Glorot normal
+- ``XavierUniform()`` - Xavier/Glorot uniform
+- ``KaimingNormal()`` - He normal (for ReLU)
+- ``KaimingUniform()`` - He uniform
+- ``Orthogonal()`` - Orthogonal matrix (for RNNs)
+- ``Identity()`` - Identity matrix
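+
+Each initializer is a callable that returns an array of the requested shape (mirroring the call convention in the example above); a small sketch:
+
+.. code-block:: python
+
+   W0 = init.KaimingNormal()(shape=(100, 50))
+   print(W0.shape)  # (100, 50)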
+
+State Access and Manipulation
+------------------------------
+
+Reading State Values
+~~~~~~~~~~~~~~~~~~~~
+
+Access the current value with ``.value``.
+
+.. code-block:: python
+
+ neuron = bp.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ brainstate.nn.init_all_states(neuron)
+
+ # Read current membrane potential
+ current_V = neuron.V.value
+
+ # Read shape
+ print(current_V.shape) # (100,)
+
+ # Read specific neurons
+ V_neuron_0 = neuron.V.value[0]
+
+Writing State Values
+~~~~~~~~~~~~~~~~~~~~
+
+Update state by assigning to ``.value``.
+
+.. code-block:: python
+
+ # Set new value (entire array)
+ neuron.V.value = jnp.ones(100) * -60.0
+
+ # Update subset
+ neuron.V.value = neuron.V.value.at[0:10].set(-55.0)
+
+ # Increment
+ neuron.V.value = neuron.V.value + 0.1
+
+**Important:** Always assign to ``.value``, not the state object itself!
+
+.. code-block:: python
+
+ # CORRECT
+ neuron.V.value = new_V
+
+ # WRONG (creates new object, doesn't update state)
+ neuron.V = new_V
+
+Collecting States
+~~~~~~~~~~~~~~~~~
+
+Get all states of a specific type from a module.
+
+.. code-block:: python
+
+ # Get all parameters
+ params = net.states(brainstate.ParamState)
+ # Returns: dict with parameter names as keys
+
+ # Get all short-term states
+ short_term = net.states(brainstate.ShortTermState)
+
+ # Get all states (any type)
+ all_states = net.states()
+
+**Example:**
+
+.. code-block:: python
+
+ class SimpleNet(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.W = brainstate.ParamState(jnp.ones((10, 10)))
+ self.V = brainstate.ShortTermState(jnp.zeros(10))
+
+ net = SimpleNet()
+
+ params = net.states(brainstate.ParamState)
+ # {'W': ParamState(...)}
+
+ states = net.states(brainstate.ShortTermState)
+ # {'V': ShortTermState(...)}
+
+State in Training
+-----------------
+
+Gradient Computation
+~~~~~~~~~~~~~~~~~~~~
+
+Use ``brainstate.transform.grad()`` to compute gradients w.r.t. parameters.
+
+.. code-block:: python
+
+   def loss_fn(X, y):
+       """Loss computed with the network's current parameters."""
+       # net reads its ParamState values internally
+       output = net(X)
+       return jnp.mean((output - y) ** 2)
+
+   # Get parameters
+   params = net.states(brainstate.ParamState)
+
+   # Compute gradients with respect to those parameters
+   grads = brainstate.transform.grad(loss_fn, params)(X, y)
+
+ # grads has same structure as params
+ # grads = {'W': gradient_for_W, 'b': gradient_for_b, ...}
+
+**Key points:**
+
+- Gradients computed only for ParamState
+- ShortTermState treated as constants
+- Gradient structure matches parameter structure
+
+Optimizer Updates
+~~~~~~~~~~~~~~~~~
+
+Register parameters with optimizer and update.
+
+.. code-block:: python
+
+ import braintools
+
+ # Create optimizer
+ optimizer = braintools.optim.Adam(learning_rate=1e-3)
+
+ # Register trainable parameters
+ params = net.states(brainstate.ParamState)
+ optimizer.register_trainable_weights(params)
+
+ # Training loop
+ for epoch in range(num_epochs):
+ for batch in data_loader:
+ X, y = batch
+
+ # Compute gradients
+           grads = brainstate.transform.grad(
+               loss_fn,
+               params,
+               return_value=False
+           )(X, y)
+
+ # Update parameters
+ optimizer.update(grads)
+
+**The optimizer automatically:**
+
+- Updates all registered parameters
+- Applies learning rate
+- Handles momentum/adaptive rates
+- Maintains optimizer state (momentum buffers, etc.)
+
+State Persistence
+~~~~~~~~~~~~~~~~~
+
+Training doesn't reset ShortTermState between batches (unless you do it manually).
+
+.. code-block:: python
+
+ # Training with state reset each example
+ for X, y in data_loader:
+ # Reset dynamics for new example
+ brainstate.nn.init_all_states(net)
+
+ # Forward pass (dynamics evolve)
+ output = net(X)
+
+ # Backward pass
+ grads = compute_grads(...)
+ optimizer.update(grads)
+
+ # Training with persistent state (e.g., RNN)
+ for X, y in data_loader:
+ # Don't reset - state carries over
+ output = net(X)
+ grads = compute_grads(...)
+ optimizer.update(grads)
+
+Batching
+--------
+
+Batch Dimensions
+~~~~~~~~~~~~~~~~
+
+States can have a batch dimension for parallel trials.
+
+**Single trial:**
+
+.. code-block:: python
+
+ neuron = bp.LIF(100, ...) # 100 neurons
+ brainstate.nn.init_all_states(neuron)
+ # neuron.V.value.shape = (100,)
+
+**Batched trials:**
+
+.. code-block:: python
+
+ neuron = bp.LIF(100, ...) # 100 neurons
+ brainstate.nn.init_all_states(neuron, batch_size=32)
+ # neuron.V.value.shape = (32, 100)
+
+**Usage:**
+
+.. code-block:: python
+
+ # Input also needs batch dimension
+ inp = brainstate.random.rand(32, 100) * 2.0 * u.nA
+
+ # Update operates on all batches in parallel
+ neuron(inp)
+
+ # Output has batch dimension
+ spikes = neuron.get_spike() # shape: (32, 100)
+
+Benefits of Batching
+~~~~~~~~~~~~~~~~~~~~
+
+**1. Parallelism:** GPU processes all batches simultaneously
+
+**2. Statistical averaging:** Reduce noise in gradients
+
+**3. Exploration:** Try different initial conditions
+
+**4. Efficiency:** Amortize compilation cost
+
+**Example: Parameter sweep with batching**
+
+.. code-block:: python
+
+ # Test 10 different input currents in parallel
+ batch_size = 10
+ neuron = bp.LIF(100, ...)
+ brainstate.nn.init_all_states(neuron, batch_size=batch_size)
+
+ # Different input for each batch
+   currents = jnp.linspace(0., 5., batch_size).reshape(-1, 1)
+   inp = jnp.broadcast_to(currents, (batch_size, 100)) * u.nA
+
+ # Simulate
+ for _ in range(1000):
+ neuron(inp)
+
+ # Analyze each trial separately
+ spike_counts = jnp.sum(neuron.spike.value, axis=1) # (10,)
+
+Checkpointing and Serialization
+--------------------------------
+
+Saving Models
+~~~~~~~~~~~~~
+
+Save model state to disk.
+
+.. code-block:: python
+
+ import pickle
+
+ # Get all states to save
+ state_dict = {
+ 'params': net.states(brainstate.ParamState),
+ 'long_term': net.states(brainstate.LongTermState),
+ 'epoch': current_epoch,
+ 'optimizer_state': optimizer.state_dict() # If applicable
+ }
+
+ # Save to file
+ with open('checkpoint.pkl', 'wb') as f:
+ pickle.dump(state_dict, f)
+
+**Note:** Don't save ShortTermState (it resets each trial).
+
+Loading Models
+~~~~~~~~~~~~~~
+
+Restore model state from disk.
+
+.. code-block:: python
+
+ # Load checkpoint
+ with open('checkpoint.pkl', 'rb') as f:
+ state_dict = pickle.load(f)
+
+ # Create fresh model
+ net = MyNetwork()
+ brainstate.nn.init_all_states(net)
+
+ # Restore parameters
+   params = state_dict['params']
+   net_params = net.states(brainstate.ParamState)
+
+   # Find each corresponding parameter in net and copy its value
+   for name, param_state in params.items():
+       if name in net_params:
+           net_params[name].value = param_state.value
+
+ # Restore long-term states similarly
+ # ...
+
+ # Restore optimizer if continuing training
+ optimizer.load_state_dict(state_dict['optimizer_state'])
+
+Best Practices for Checkpointing
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**1. Save regularly during training**
+
+.. code-block:: python
+
+ if epoch % save_interval == 0:
+ save_checkpoint(net, optimizer, epoch, path)
+
+**2. Keep multiple checkpoints**
+
+.. code-block:: python
+
+ # Save with epoch number
+ save_path = f'checkpoint_epoch_{epoch}.pkl'
+
+**3. Save best model separately**
+
+.. code-block:: python
+
+ if val_loss < best_val_loss:
+ best_val_loss = val_loss
+ save_checkpoint(net, optimizer, epoch, 'best_model.pkl')
+
+**4. Include metadata**
+
+.. code-block:: python
+
+ state_dict = {
+ 'params': ...,
+ 'epoch': epoch,
+ 'best_val_loss': best_val_loss,
+ 'config': model_config, # Hyperparameters
+ 'timestamp': datetime.now()
+ }
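+
+The ``save_checkpoint`` helper used in the snippets above is not a BrainPy API; a minimal sketch of such a helper, following the state-dict layout shown earlier:
+
+.. code-block:: python
+
+   import pickle
+
+   def save_checkpoint(net, optimizer, epoch, path):
+       """Hypothetical helper matching the calls in this section."""
+       state_dict = {
+           'params': net.states(brainstate.ParamState),
+           'long_term': net.states(brainstate.LongTermState),
+           'epoch': epoch,
+           # add optimizer state here if your optimizer exposes one
+       }
+       with open(path, 'wb') as f:
+           pickle.dump(state_dict, f)
+
+Pickling the state objects keeps the format symmetric with the loading loop above; saving plain arrays instead works too, provided the loader is adapted accordingly.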
+
+Common Patterns
+---------------
+
+Pattern 1: Resetting Between Trials
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Simulate multiple trials
+ for trial in range(num_trials):
+ # Reset dynamics
+ brainstate.nn.init_all_states(net)
+
+ # Run trial
+ for t in range(trial_length):
+ inp = get_input(trial, t)
+ output = net(inp)
+ record(output)
+
+Pattern 2: Accumulating Statistics
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class NeuronWithStats(brainstate.nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+
+ # Accumulate across trials
+ self.total_spikes = brainstate.LongTermState(
+ jnp.zeros(size, dtype=jnp.int32)
+ )
+ self.n_steps = brainstate.LongTermState(0)
+
+ def update(self, I):
+ # ... dynamics ...
+
+ # Accumulate
+ self.total_spikes.value += self.spike.value.astype(jnp.int32)
+ self.n_steps.value += 1
+
+ def get_firing_rate(self):
+ """Average firing rate across all trials."""
+ dt = brainstate.environ.get_dt()
+ total_time = self.n_steps.value * dt.to_decimal(u.second)
+ return self.total_spikes.value / total_time
+
+Pattern 3: Conditional Updates
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class AdaptiveNeuron(brainstate.nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+ self.threshold = brainstate.ParamState(jnp.ones(size) * (-50.0))
+
+ def update(self, I):
+ # Dynamics
+ # ...
+
+ # Homeostatic threshold adaptation
+ spike_rate = compute_spike_rate(self.spike.value)
+
+ # Adjust threshold based on activity
+ target_rate = 5.0 # Hz
+ adjustment = 0.01 * (spike_rate - target_rate)
+
+ # Update learnable threshold
+ self.threshold.value -= adjustment
+
+Pattern 4: Hierarchical States
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class HierarchicalNetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ # Submodules have their own states
+ self.layer1 = MyLayer(100, 50)
+ self.layer2 = MyLayer(50, 10)
+
+ def update(self, x):
+ # Each layer manages its own states
+ h1 = self.layer1(x)
+ h2 = self.layer2(h1)
+ return h2
+
+ net = HierarchicalNetwork()
+
+ # Collect ALL states from hierarchy
+ all_params = net.states(brainstate.ParamState)
+ # Includes params from layer1 AND layer2
+
+ # Initialize ALL states in hierarchy
+ brainstate.nn.init_all_states(net)
+ # Calls reset_state() on net, layer1, and layer2
+
+Advanced Topics
+---------------
+
+Custom State Types
+~~~~~~~~~~~~~~~~~~
+
+Create custom state types for specialized needs.
+
+.. code-block:: python
+
+ class RandomState(brainstate.State):
+ """State that re-randomizes on reset."""
+
+ def __init__(self, shape, low=0.0, high=1.0):
+ super().__init__(jnp.zeros(shape))
+ self.shape = shape
+ self.low = low
+ self.high = high
+
+ def reset(self):
+ """Re-randomize on reset."""
+ self.value = brainstate.random.uniform(
+ self.low, self.high, size=self.shape
+ )
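+
+A usage sketch for this custom state:
+
+.. code-block:: python
+
+   noise = RandomState((100,), low=-1.0, high=1.0)
+   noise.reset()          # draws a fresh uniform sample
+   sample = noise.value   # shape (100,)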
+
+State Sharing
+~~~~~~~~~~~~~
+
+Share state between modules (use with caution).
+
+.. code-block:: python
+
+ class SharedState(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ # Shared weight matrix
+ shared_W = brainstate.ParamState(jnp.ones((100, 100)))
+
+ self.module1 = ModuleA(shared_W)
+ self.module2 = ModuleB(shared_W)
+
+ # module1 and module2 both modify the same weights
+
+**When to use:** Siamese networks, weight tying, parameter sharing
+
+**Caution:** Makes dependencies implicit, harder to debug
+
+State Inspection
+~~~~~~~~~~~~~~~~
+
+Debug by inspecting state values.
+
+.. code-block:: python
+
+ # Print all parameter shapes
+ params = net.states(brainstate.ParamState)
+ for name, state in params.items():
+ print(f"{name}: {state.value.shape}")
+
+ # Check for NaN values
+ for name, state in params.items():
+ if jnp.any(jnp.isnan(state.value)):
+ print(f"NaN detected in {name}!")
+
+ # Compute statistics
+ V_values = neuron.V.value
+ print(f"V range: [{V_values.min():.2f}, {V_values.max():.2f}]")
+ print(f"V mean: {V_values.mean():.2f}")
+
+Troubleshooting
+---------------
+
+Issue: States not updating
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:** Values stay constant
+
+**Solutions:**
+
+1. Assign to ``.value``, not the state itself
+2. Check you're updating the right variable
+3. Verify update function is called
+
+.. code-block:: python
+
+ # WRONG
+ self.V = new_V # Creates new object!
+
+ # CORRECT
+ self.V.value = new_V # Updates state
+
+Issue: Batch dimension errors
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:** Shape mismatch errors
+
+**Solutions:**
+
+1. Initialize with ``batch_size`` parameter
+2. Ensure inputs have batch dimension
+3. Check ``reset_state()`` handles batching
+
+.. code-block:: python
+
+ # Initialize with batching
+ brainstate.nn.init_all_states(net, batch_size=32)
+
+ # Input needs batch dimension
+ inp = jnp.zeros((32, 100)) # (batch, neurons)
+
+Issue: Gradients are None
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:** No gradients for parameters
+
+**Solutions:**
+
+1. Ensure parameters are ``ParamState``
+2. Check parameters are used in loss computation
+3. Verify gradient function call
+
+.. code-block:: python
+
+ # Parameters must be ParamState
+ self.W = brainstate.ParamState(init_W) # Correct
+
+ # Compute gradients for parameters only
+ params = net.states(brainstate.ParamState)
+ grads = brainstate.transform.grad(loss_fn, params)(...)
+
+Issue: Memory leak during training
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:** Memory grows over time
+
+**Solutions:**
+
+1. Don't accumulate history in Python lists
+2. Clear unnecessary references
+3. Use ``jnp.array`` operations (not Python append)
+
+.. code-block:: python
+
+ # BAD - accumulates in Python memory
+ history = []
+ for t in range(10000):
+ output = net(inp)
+ history.append(output) # Memory leak!
+
+ # GOOD - use fixed-size buffer or don't store
+ for t in range(10000):
+ output = net(inp)
+ # Process immediately, don't store
+
+Further Reading
+---------------
+
+- :doc:`architecture` - Overall BrainPy architecture
+- :doc:`neurons` - Neuron models and their states
+- :doc:`synapses` - Synapse models and their states
+- :doc:`../tutorials/advanced/05-snn-training` - Training with states
+- BrainState documentation: https://brainstate.readthedocs.io/
+
+Summary
+-------
+
+**Key takeaways:**
+
+✅ **Three state types:**
+ - ``ParamState``: Learnable parameters
+ - ``ShortTermState``: Temporary dynamics
+ - ``LongTermState``: Persistent statistics
+
+✅ **Initialization:**
+ - Use ``brainstate.nn.init_all_states(module)``
+ - Implement ``reset_state()`` for custom logic
+ - Handle batch dimensions
+
+✅ **Access:**
+ - Read/write with ``.value``
+ - Collect with ``.states(StateType)``
+ - Never assign to state object directly
+
+✅ **Training:**
+ - Gradients computed for ``ParamState``
+ - Register with optimizer
+ - Update with ``optimizer.update(grads)``
+
+✅ **Checkpointing:**
+ - Save ``ParamState`` and ``LongTermState``
+ - Don't save ``ShortTermState``
+ - Include metadata and optimizer state
+
+**Quick reference:**
+
+.. code-block:: python
+
+ # Define states
+ class MyModule(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.W = brainstate.ParamState(init_W) # Learnable
+ self.V = brainstate.ShortTermState(init_V) # Resets
+ self.count = brainstate.LongTermState(init_c) # Persists
+
+ def reset_state(self, batch_size=None):
+ """Initialize ShortTermState."""
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.V.value = jnp.zeros(shape)
+
+ # Initialize
+ brainstate.nn.init_all_states(module, batch_size=32)
+
+ # Access
+ params = module.states(brainstate.ParamState)
+ module.V.value = new_V
+
+ # Train
+ grads = brainstate.transform.grad(loss, params)(...)
+ optimizer.update(grads)
diff --git a/docs_version3/core-concepts/synapses.rst b/docs_version3/core-concepts/synapses.rst
new file mode 100644
index 00000000..9f36d4d2
--- /dev/null
+++ b/docs_version3/core-concepts/synapses.rst
@@ -0,0 +1,642 @@
+Synapses
+========
+
+Synapses model the temporal dynamics of neural connections in BrainPy 3.0. This document explains how synapses work, what models are available, and how to use them effectively.
+
+Overview
+--------
+
+Synapses provide temporal filtering of spike trains, transforming discrete spikes into continuous currents or conductances. They model:
+
+- **Postsynaptic potentials** (PSPs)
+- **Temporal integration** of spike trains
+- **Synaptic dynamics** (rise and decay)
+
+In BrainPy's architecture, synapses are part of the projection system:
+
+.. code-block:: text
+
+    Spikes → [Connectivity] → [Synapse] → [Output] → Neurons
+                                  ↑
+                          Temporal filtering
+
+Basic Usage
+-----------
+
+Creating Synapses
+~~~~~~~~~~~~~~~~~
+
+Synapses are typically created as part of projections:
+
+.. code-block:: python
+
+ import brainpy
+ import brainunit as u
+
+ # Create synapse descriptor
+ syn = brainpy.Expon.desc(
+ size=100, # Number of synapses
+ tau=5. * u.ms # Time constant
+ )
+
+ # Use in projection
+ projection = brainpy.AlignPostProj(
+ comm=...,
+ syn=syn, # Synapse here
+ out=...,
+ post=neurons
+ )
+
+Synapse Lifecycle
+~~~~~~~~~~~~~~~~~
+
+1. **Creation**: Define synapse with the ``.desc()`` method
+2. **Integration**: Include in projection
+3. **Update**: Called automatically by projection
+4. **Access**: Read synaptic variables as needed
+
+.. code-block:: python
+
+ # During simulation
+ projection(presynaptic_spikes) # Updates synapse internally
+
+ # Access synaptic variable
+ synaptic_current = projection.syn.g.value
+
+Available Synapse Models
+------------------------
+
+Expon (Single Exponential)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The simplest and most commonly used synapse model.
+
+**Mathematical Model:**
+
+.. math::
+
+   \tau \frac{dg}{dt} = -g
+
+When a spike arrives: :math:`g \leftarrow g + 1`
+
+**Impulse Response:**
+
+.. math::
+
+   g(t) = \exp(-t/\tau)
+
+**Example:**
+
+.. code-block:: python
+
+ syn = brainpy.Expon.desc(
+ size=100,
+ tau=5. * u.ms,
+ g_initializer=braintools.init.Constant(0. * u.mS)
+ )
+
+**Parameters:**
+
+- ``size``: Number of synapses
+- ``tau``: Decay time constant
+- ``g_initializer``: Initial synaptic variable (optional)
+
+**Key Features:**
+
+- Single time constant
+- Fast computation
+- Instantaneous rise
+
+**Use cases:**
+
+- General-purpose modeling
+- Fast simulations
+- When precise kinetics are not critical
+
+**Behavior:**
+
+.. code-block:: python
+
+    # Response to a single spike at t=0
+    # g(t) = exp(-t/τ)
+    # Fast rise, exponential decay
+
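+The single-exponential model also has an exact discrete-time update, which is
+what makes it so cheap. A minimal sketch in plain JAX, without units (the
+``tau`` and ``dt`` values are illustrative):
+
+.. code-block:: python
+
+    import jax.numpy as jnp
+
+    tau = 5.0    # ms
+    dt = 0.1     # ms
+    decay = jnp.exp(-dt / tau)      # closed-form solution of tau * dg/dt = -g
+
+    spike_train = jnp.zeros(500).at[0].set(1.0)   # one spike at t=0
+
+    g = 0.0
+    trace = []
+    for s in spike_train:
+        g = g * decay + s           # decay, then add the incoming spike
+        trace.append(g)
+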
+Alpha Synapse
+~~~~~~~~~~~~~
+
+A more realistic model with non-instantaneous rise time.
+
+**Mathematical Model:**
+
+.. math::
+
+   \tau \frac{dh}{dt} &= -h
+
+   \tau \frac{dg}{dt} &= -g + h
+
+When a spike arrives: :math:`h \leftarrow h + 1`
+
+**Impulse Response:**
+
+.. math::
+
+   g(t) = \frac{t}{\tau} \exp(-t/\tau)
+
+**Example:**
+
+.. code-block:: python
+
+ syn = brainpy.Alpha.desc(
+ size=100,
+ tau=5. * u.ms,
+ g_initializer=braintools.init.Constant(0. * u.mS)
+ )
+
+**Parameters:**
+
+Same as Expon, but produces alpha-shaped response.
+
+**Key Features:**
+
+- Smooth rise and fall
+- Biologically realistic
+- Peak at t = τ
+
+**Use cases:**
+
+- Biological realism
+- Detailed cortical modeling
+- When kinetics matter
+
+**Behavior:**
+
+.. code-block:: python
+
+    # Response to a single spike at t=0
+    # g(t) = (t/τ) * exp(-t/τ)
+    # Gradual rise to peak at t = τ, then decay
+
+AMPA (Excitatory)
+~~~~~~~~~~~~~~~~~
+
+Models AMPA receptor dynamics for excitatory synapses.
+
+**Mathematical Model:**
+
+Similar to Alpha, but with parameters tuned for AMPA receptors.
+
+**Example:**
+
+.. code-block:: python
+
+ syn = brainpy.AMPA.desc(
+ size=100,
+ tau=2. * u.ms, # Fast AMPA kinetics
+ g_initializer=braintools.init.Constant(0. * u.mS)
+ )
+
+**Key Features:**
+
+- Fast kinetics (τ ≈ 2 ms)
+- Excitatory receptor
+- Biologically parameterized
+
+**Use cases:**
+
+- Excitatory synapses
+- Cortical pyramidal neurons
+- Biological realism
+
+GABAa (Inhibitory)
+~~~~~~~~~~~~~~~~~~
+
+Models GABAa receptor dynamics for inhibitory synapses.
+
+**Mathematical Model:**
+
+Similar to Alpha, but with parameters tuned for GABAa receptors.
+
+**Example:**
+
+.. code-block:: python
+
+ syn = brainpy.GABAa.desc(
+ size=100,
+ tau=10. * u.ms, # Slower GABAa kinetics
+ g_initializer=braintools.init.Constant(0. * u.mS)
+ )
+
+**Key Features:**
+
+- Slower kinetics (τ ≈ 10 ms)
+- Inhibitory receptor
+- Biologically parameterized
+
+**Use cases:**
+
+- Inhibitory synapses
+- GABAergic interneurons
+- Biological realism
+
+Synaptic Variables
+------------------
+
+The Descriptor Pattern
+~~~~~~~~~~~~~~~~~~~~~~~
+
+BrainPy synapses use a descriptor pattern:
+
+.. code-block:: python
+
+ # Create descriptor (not yet instantiated)
+ syn_desc = brainpy.Expon.desc(size=100, tau=5*u.ms)
+
+ # Instantiated within projection
+ projection = brainpy.AlignPostProj(..., syn=syn_desc, ...)
+
+ # Access instantiated synapse
+ actual_synapse = projection.syn
+ g_value = actual_synapse.g.value
+
+Why Descriptors?
+~~~~~~~~~~~~~~~~
+
+- **Deferred instantiation**: Created when needed
+- **Reusability**: Same descriptor for multiple projections
+- **Flexibility**: Configure before instantiation
+
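+For example, one descriptor can parameterize several projections, each of which
+instantiates its own synapse (a sketch; ``post_a`` and ``post_b`` are assumed
+to be existing neuron populations, and ``comm``/``out`` are elided as above):
+
+.. code-block:: python
+
+    syn_desc = brainpy.Expon.desc(100, tau=5 * u.ms)
+
+    # post_a / post_b: existing postsynaptic populations (assumed)
+    proj_a = brainpy.AlignPostProj(comm=..., syn=syn_desc, out=..., post=post_a)
+    proj_b = brainpy.AlignPostProj(comm=..., syn=syn_desc, out=..., post=post_b)
+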
+Accessing Synaptic State
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Within projection
+ projection = brainpy.AlignPostProj(
+ comm=...,
+ syn=brainpy.Expon.desc(100, tau=5*u.ms),
+ out=...,
+ post=neurons
+ )
+
+ # After simulation step
+ synaptic_var = projection.syn.g.value # Current value with units
+
+ # Convert to array for plotting
+ g_array = synaptic_var.to_decimal(u.mS)
+
+Synaptic Dynamics Visualization
+--------------------------------
+
+Comparing Different Models
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    import brainpy
+ import brainstate
+ import brainunit as u
+ import matplotlib.pyplot as plt
+ import jax.numpy as jnp
+
+ brainstate.environ.set(dt=0.1 * u.ms)
+
+ # Create different synapses
+ expon = brainpy.Expon(100, tau=5*u.ms)
+ alpha = brainpy.Alpha(100, tau=5*u.ms)
+ ampa = brainpy.AMPA(100, tau=2*u.ms)
+ gaba = brainpy.GABAa(100, tau=10*u.ms)
+
+ # Initialize
+ for syn in [expon, alpha, ampa, gaba]:
+ brainstate.nn.init_all_states(syn)
+
+ # Single spike at t=0
+ spike_input = jnp.zeros(100)
+ spike_input = spike_input.at[0].set(1.0)
+
+ # Simulate
+ times = u.math.arange(0*u.ms, 50*u.ms, 0.1*u.ms)
+ responses = {
+ 'Expon': [],
+ 'Alpha': [],
+ 'AMPA': [],
+ 'GABAa': []
+ }
+
+ for syn, name in zip([expon, alpha, ampa, gaba],
+ ['Expon', 'Alpha', 'AMPA', 'GABAa']):
+ brainstate.nn.init_all_states(syn)
+ for i, t in enumerate(times):
+ if i == 0:
+ syn(spike_input)
+ else:
+ syn(jnp.zeros(100))
+ responses[name].append(syn.g.value[0])
+
+ # Plot
+ plt.figure(figsize=(10, 6))
+ for name, response in responses.items():
+ response_array = u.math.asarray(response)
+ plt.plot(times.to_decimal(u.ms),
+ response_array.to_decimal(u.mS),
+ label=name, linewidth=2)
+
+ plt.xlabel('Time (ms)')
+ plt.ylabel('Synaptic Variable (mS)')
+ plt.title('Comparison of Synapse Models (Single Spike)')
+ plt.legend()
+ plt.grid(True, alpha=0.3)
+ plt.show()
+
+Integration with Projections
+-----------------------------
+
+Complete Example
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    import brainpy
+    import brainstate
+    import brainunit as u
+
+    brainstate.environ.set(dt=0.1 * u.ms)
+
+ # Create neurons
+ pre_neurons = brainpy.LIF(80, V_th=-50*u.mV, tau=10*u.ms)
+ post_neurons = brainpy.LIF(100, V_th=-50*u.mV, tau=10*u.ms)
+
+ # Create projection with exponential synapse
+ projection = brainpy.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(
+ 80, 100, prob=0.1, weight=0.5*u.mS
+ ),
+ syn=brainpy.Expon.desc(100, tau=5*u.ms),
+ out=brainpy.CUBA.desc(),
+ post=post_neurons
+ )
+
+    # Initialize (the projection holds states too)
+    brainstate.nn.init_all_states(pre_neurons)
+    brainstate.nn.init_all_states(post_neurons)
+    brainstate.nn.init_all_states(projection)
+
+    # Simulation
+    def update(input_current):
+        # Propagate spikes from the previous step through the projection
+        spikes = pre_neurons.get_spike()
+        projection(spikes)
+
+        # Update presynaptic neurons
+        pre_neurons(input_current)
+
+        # Update postsynaptic neurons (synaptic input arrives via the projection)
+        post_neurons(0 * u.nA)
+
+        return post_neurons.get_spike()
+
+ # Run
+ times = u.math.arange(0*u.ms, 100*u.ms, 0.1*u.ms)
+ results = brainstate.transform.for_loop(
+ lambda t: update(2*u.nA),
+ times
+ )
+
+Short-Term Plasticity
+---------------------
+
+Synapses can be combined with short-term plasticity (STP):
+
+.. code-block:: python
+
+ # Create projection with STP
+ projection = brainpy.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(80, 100, prob=0.1, weight=0.5*u.mS),
+ syn=brainpy.STP.desc(
+ brainpy.Expon.desc(100, tau=5*u.ms), # Underlying synapse
+ tau_f=200*u.ms, # Facilitation time constant
+ tau_d=150*u.ms, # Depression time constant
+ U=0.2 # Utilization of synaptic efficacy
+ ),
+ out=brainpy.CUBA.desc(),
+ post=post_neurons
+ )
+
+See :doc:`plasticity` for more details on STP.
+
+Custom Synapses
+---------------
+
+Creating Custom Synapse Models
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can create custom synapse models by inheriting from ``Synapse``:
+
+.. code-block:: python
+
+    import brainstate
+    import braintools
+    import brainunit as u
+    from brainpy._base import Synapse
+
+ class MyCustomSynapse(Synapse):
+ def __init__(self, size, tau1, tau2, **kwargs):
+ super().__init__(size, **kwargs)
+
+ self.tau1 = tau1
+ self.tau2 = tau2
+
+ # Synaptic variable
+ self.g = brainstate.ShortTermState(
+ braintools.init.Constant(0., unit=u.mS)(size)
+ )
+
+ def update(self, spike_input):
+ dt = brainstate.environ.get_dt()
+
+ # Custom dynamics (double exponential)
+ dg = (-self.g.value / self.tau1 +
+ spike_input / self.tau2)
+ self.g.value = self.g.value + dg * dt
+
+ return self.g.value
+
+ @classmethod
+ def desc(cls, size, tau1, tau2, **kwargs):
+ """Descriptor for deferred instantiation."""
+ def create():
+ return cls(size, tau1, tau2, **kwargs)
+ return create
+
+Usage:
+
+.. code-block:: python
+
+ # Create descriptor
+ syn_desc = MyCustomSynapse.desc(
+ size=100,
+ tau1=5*u.ms,
+ tau2=10*u.ms
+ )
+
+ # Use in projection
+ projection = brainpy.AlignPostProj(..., syn=syn_desc, ...)
+
+Choosing the Right Synapse
+---------------------------
+
+Decision Guide
+~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 30 25 25
+
+ * - Model
+ - When to Use
+ - Pros
+ - Cons
+ * - Expon
+ - General purpose, speed
+ - Fast, simple
+ - Unrealistic rise
+ * - Alpha
+ - Biological realism
+ - Realistic kinetics
+ - Slower computation
+ * - AMPA
+ - Excitatory, fast
+ - Biologically accurate
+ - Specific use case
+ * - GABAa
+ - Inhibitory, slow
+ - Biologically accurate
+ - Specific use case
+
+Recommendations
+~~~~~~~~~~~~~~~
+
+**For machine learning / SNNs:**
+ Use ``Expon`` for speed and simplicity.
+
+**For biological modeling:**
+ Use ``Alpha``, ``AMPA``, or ``GABAa`` for realism.
+
+**For cortical networks:**
+   - Excitatory: ``AMPA`` (τ ≈ 2 ms)
+   - Inhibitory: ``GABAa`` (τ ≈ 10 ms)
+
+**For custom dynamics:**
+ Implement custom synapse class.
+
+Performance Considerations
+--------------------------
+
+Computational Cost
+~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 25 25 50
+
+ * - Model
+ - Relative Cost
+ - Notes
+ * - Expon
+ - 1x (baseline)
+ - Single state variable
+ * - Alpha
+ - 2x
+ - Two state variables
+ * - AMPA/GABAa
+ - 2x
+ - Similar to Alpha
+
+Optimization Tips
+~~~~~~~~~~~~~~~~~
+
+1. **Use Expon when possible**: Fastest option
+
+2. **Batch operations**: Multiple synapses together
+
+ .. code-block:: python
+
+ # Good: Single projection with 1000 synapses
+ proj = brainpy.AlignPostProj(..., syn=brainpy.Expon.desc(1000, ...))
+
+ # Bad: 1000 separate projections
+ projs = [brainpy.AlignPostProj(..., syn=brainpy.Expon.desc(1, ...))
+ for _ in range(1000)]
+
+3. **JIT compilation**: Always use for simulations
+
+ .. code-block:: python
+
+ @brainstate.compile.jit
+ def step():
+ projection(spikes)
+ neurons(0*u.nA)
+
+Common Patterns
+---------------
+
+Excitatory-Inhibitory Balance
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Excitatory projection (fast)
+ E_proj = brainpy.AlignPostProj(
+ comm=...,
+ syn=brainpy.Expon.desc(post_size, tau=2*u.ms),
+ out=brainpy.CUBA.desc(),
+ post=neurons
+ )
+
+ # Inhibitory projection (slow)
+ I_proj = brainpy.AlignPostProj(
+ comm=...,
+ syn=brainpy.Expon.desc(post_size, tau=10*u.ms),
+ out=brainpy.CUBA.desc(),
+ post=neurons
+ )
+
+Multiple Receptor Types
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # AMPA (fast excitatory)
+ ampa_proj = brainpy.AlignPostProj(
+ ..., syn=brainpy.AMPA.desc(size, tau=2*u.ms), ...
+ )
+
+ # NMDA (slow excitatory) - custom
+ nmda_proj = brainpy.AlignPostProj(
+ ..., syn=CustomNMDA.desc(size, tau=100*u.ms), ...
+ )
+
+ # GABAa (fast inhibitory)
+ gaba_proj = brainpy.AlignPostProj(
+ ..., syn=brainpy.GABAa.desc(size, tau=10*u.ms), ...
+ )
+
+Summary
+-------
+
+Synapses in BrainPy 3.0:
+
+✅ **Multiple models**: Expon, Alpha, AMPA, GABAa
+
+✅ **Temporal filtering**: Convert spikes to continuous signals
+
+✅ **Descriptor pattern**: Flexible, reusable configuration
+
+✅ **Integration ready**: Seamless use in projections
+
+✅ **Extensible**: Easy custom synapse models
+
+✅ **Physical units**: Proper unit handling throughout
+
+Next Steps
+----------
+
+- Learn about :doc:`projections` for complete connectivity
+- Explore :doc:`plasticity` for learning rules
+- Follow :doc:`../tutorials/basic/02-synapse-models` for practice
+- See :doc:`../examples/classical-networks/ei-balanced` for network examples
diff --git a/docs_version3/examples/gallery.rst b/docs_version3/examples/gallery.rst
new file mode 100644
index 00000000..3affb2b7
--- /dev/null
+++ b/docs_version3/examples/gallery.rst
@@ -0,0 +1,461 @@
+Examples Gallery
+================
+
+Welcome to the BrainPy 3.0 examples gallery! Here you'll find complete, runnable examples demonstrating various aspects of computational neuroscience modeling.
+
+All examples are available in the ``examples_version3/`` directory of the `BrainPy repository <https://github.com/brainpy/BrainPy>`_.
+
+Classical Network Models
+-------------------------
+
+These examples reproduce influential models from the computational neuroscience literature.
+
+E-I Balanced Networks
+~~~~~~~~~~~~~~~~~~~~~
+
+**102_EI_net_1996.py** - Van Vreeswijk & Sompolinsky (1996)
+
+Implements the classic excitatory-inhibitory balanced network showing chaotic dynamics.
+
+**Key features:**
+
+- 80% excitatory, 20% inhibitory neurons
+- Random sparse connectivity
+- Balanced excitation and inhibition
+- Asynchronous irregular firing
+
+:download:`Download <../../examples_version3/102_EI_net_1996.py>`
+
+**Key Concepts**: E-I balance, network dynamics, sparse connectivity
+
+----
+
+COBA Network (2005)
+~~~~~~~~~~~~~~~~~~~
+
+**103_COBA_2005.py** - Vogels & Abbott (2005)
+
+Conductance-based synaptic integration in balanced networks.
+
+**Key features:**
+
+- Conductance-based synapses (COBA)
+- Reversal potentials
+- More biologically realistic
+- Stable asynchronous activity
+
+:download:`Download <../../examples_version3/103_COBA_2005.py>`
+
+**Key Concepts**: COBA synapses, conductance-based models, reversal potentials
+
+----
+
+CUBA Network (2005)
+~~~~~~~~~~~~~~~~~~~
+
+**104_CUBA_2005.py** - Vogels & Abbott (2005)
+
+Current-based synaptic integration (simpler, faster variant).
+
+**Key features:**
+
+- Current-based synapses (CUBA)
+- Faster computation
+- Widely used for large-scale simulations
+
+:download:`Download <../../examples_version3/104_CUBA_2005.py>`
+
+**Alternative**: :download:`104_CUBA_2005_version2.py <../../examples_version3/104_CUBA_2005_version2.py>` - Different parameterization
+
+**Key Concepts**: CUBA synapses, current-based models
+
+----
+
+COBA with Hodgkin-Huxley Neurons (2007)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**106_COBA_HH_2007.py** - Conductance-based network with HH neurons
+
+More detailed neuron model with sodium and potassium channels.
+
+**Key features:**
+
+- Hodgkin-Huxley neuron dynamics
+- Action potential generation
+- Biophysically detailed
+- Computationally intensive
+
+:download:`Download <../../examples_version3/106_COBA_HH_2007.py>`
+
+**Key Concepts**: Hodgkin-Huxley model, ion channels, biophysical detail
+
+Oscillations and Rhythms
+-------------------------
+
+Gamma Oscillation (1996)
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**107_gamma_oscillation_1996.py** - Gamma rhythm generation
+
+Interneuron network generating gamma oscillations (30-80 Hz).
+
+**Key features:**
+
+- Interneuron-based gamma
+- Inhibition-based synchrony
+- Physiologically relevant frequency
+- Network oscillations
+
+:download:`Download <../../examples_version3/107_gamma_oscillation_1996.py>`
+
+**Key Concepts**: Gamma oscillations, network synchrony, inhibitory networks
+
+----
+
+Synfire Chains (199x)
+~~~~~~~~~~~~~~~~~~~~~
+
+**108_synfire_chains_199.py** - Feedforward activity propagation
+
+Demonstrates reliable spike sequence propagation.
+
+**Key features:**
+
+- Feedforward architecture
+- Reliable spike timing
+- Wave propagation
+- Temporal coding
+
+:download:`Download <../../examples_version3/108_synfire_chains_199.py>`
+
+**Key Concepts**: Synfire chains, feedforward networks, spike timing
+
+----
+
+Fast Global Oscillation
+~~~~~~~~~~~~~~~~~~~~~~~
+
+**109_fast_global_oscillation.py** - Ultra-fast network rhythms
+
+High-frequency oscillations (>100 Hz) in inhibitory networks.
+
+**Key features:**
+
+- Very fast oscillations (>100 Hz)
+- Gap junction coupling
+- Inhibitory synchrony
+- Pathological rhythms
+
+:download:`Download <../../examples_version3/109_fast_global_oscillation.py>`
+
+**Key Concepts**: Fast oscillations, gap junctions, pathological rhythms
+
+Gamma Oscillation Mechanisms (Susin & Destexhe 2021)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Series of models exploring different gamma generation mechanisms:
+
+**110_Susin_Destexhe_2021_gamma_oscillation_AI.py** - Asynchronous Irregular
+
+**AI state - no oscillations, irregular firing:**
+
+- Background activity state
+- Asynchronous firing
+- No clear rhythm
+
+:download:`Download <../../examples_version3/110_Susin_Destexhe_2021_gamma_oscillation_AI.py>`
+
+----
+
+**111_Susin_Destexhe_2021_gamma_oscillation_CHING.py** - CHattering-Induced Gamma (CHING)
+
+**CHING mechanism:**
+
+- Coherent inhibition
+- High-frequency gamma
+- Interneuron synchrony
+
+:download:`Download <../../examples_version3/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py>`
+
+----
+
+**112_Susin_Destexhe_2021_gamma_oscillation_ING.py** - Interneuron Network Gamma (ING)
+
+**ING mechanism:**
+
+- Pure inhibitory network
+- Gamma through inhibition
+- Fast synaptic kinetics
+
+:download:`Download <../../examples_version3/112_Susin_Destexhe_2021_gamma_oscillation_ING.py>`
+
+----
+
+**113_Susin_Destexhe_2021_gamma_oscillation_PING.py** - Pyramidal-Interneuron Gamma
+
+**PING mechanism:**
+
+- E-I loop generates gamma
+- Most common mechanism
+- Excitatory-inhibitory interaction
+
+:download:`Download <../../examples_version3/113_Susin_Destexhe_2021_gamma_oscillation_PING.py>`
+
+**Combined**: :download:`Susin_Destexhe_2021_gamma_oscillation.py <../../examples_version3/Susin_Destexhe_2021_gamma_oscillation.py>` - All mechanisms
+
+**Key Concepts**: Gamma mechanisms, network states, oscillation generation
+
+Spiking Neural Network Training
+--------------------------------
+
+Supervised Learning with Surrogate Gradients
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**200_surrogate_grad_lif.py** - Basic SNN training (SpyTorch tutorial reproduction)
+
+Trains a simple spiking network using surrogate gradients.
+
+**Key features:**
+
+- Surrogate gradient method
+- LIF neuron training
+- Simple classification task
+- Gradient-based learning
+
+:download:`Download <../../examples_version3/200_surrogate_grad_lif.py>`
+
+**Key Concepts**: Surrogate gradients, SNN training, backpropagation through time
+
+----
+
+Fashion-MNIST Classification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**201_surrogate_grad_lif_fashion_mnist.py** - Image classification with SNNs
+
+Trains a spiking network on Fashion-MNIST dataset.
+
+**Key features:**
+
+- Fashion-MNIST dataset
+- Multi-layer SNN
+- Spike-based processing
+- Real-world classification
+
+:download:`Download <../../examples_version3/201_surrogate_grad_lif_fashion_mnist.py>`
+
+**Key Concepts**: Image classification, multi-layer SNNs, practical applications
+
+----
+
+MNIST with Readout Layer
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**202_mnist_lif_readout.py** - MNIST with specialized readout
+
+Uses readout layer for classification.
+
+**Key features:**
+
+- MNIST handwritten digits
+- Specialized readout layer
+- Spike counting
+- Classification from spike rates
+
+:download:`Download <../../examples_version3/202_mnist_lif_readout.py>`
+
+**Key Concepts**: Readout layers, spike-based classification, MNIST
+
+Example Categories
+------------------
+
+By Difficulty
+~~~~~~~~~~~~~
+
+**Beginner** (Start here!)
+ - 102_EI_net_1996.py - Simple E-I network
+ - 104_CUBA_2005.py - Current-based synapses
+ - 200_surrogate_grad_lif.py - Basic training
+
+**Intermediate**
+ - 103_COBA_2005.py - Conductance-based synapses
+ - 107_gamma_oscillation_1996.py - Network oscillations
+ - 201_surrogate_grad_lif_fashion_mnist.py - Image classification
+
+**Advanced**
+ - 106_COBA_HH_2007.py - Biophysical detail
+ - 113_Susin_Destexhe_2021_gamma_oscillation_PING.py - Complex mechanisms
+ - Large-scale simulations (coming soon)
+
+By Topic
+~~~~~~~~
+
+**Network Dynamics**
+ - E-I balanced networks (102, 103, 104)
+ - Oscillations (107, 109, 110-113)
+ - Synfire chains (108)
+
+**Synaptic Mechanisms**
+ - CUBA models (104)
+ - COBA models (103, 106)
+ - Different synapse types
+
+**Learning and Training**
+ - Surrogate gradients (200, 201, 202)
+ - Classification tasks
+ - Supervised learning
+
+**Biophysical Models**
+ - Hodgkin-Huxley neurons (106)
+ - Detailed conductances
+ - Realistic parameters
+
+Running Examples
+----------------
+
+All examples can be run directly:
+
+.. code-block:: bash
+
+ # Clone repository
+ git clone https://github.com/brainpy/BrainPy.git
+ cd BrainPy
+
+ # Run an example
+ python examples_version3/102_EI_net_1996.py
+
+Or in Jupyter:
+
+.. code-block:: python
+
+ # In Jupyter notebook
+ %run examples_version3/102_EI_net_1996.py
+
+Requirements
+~~~~~~~~~~~~
+
+Examples require:
+
+- Python 3.10+
+- BrainPy 3.0
+- matplotlib (for visualization)
+- Additional dependencies as noted in examples
+
+Example Structure
+-----------------
+
+Most examples follow this structure:
+
+.. code-block:: python
+
+ # 1. Imports
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+ import matplotlib.pyplot as plt
+
+    # 2. Network definition
+    class MyNetwork(brainstate.nn.Module):
+        def __init__(self):
+            super().__init__()
+            # Define components
+            ...
+
+        def update(self, inp):
+            # Define dynamics
+            ...
+
+    # 3. Setup
+    dt = 0.1 * u.ms
+    brainstate.environ.set(dt=dt)
+    net = MyNetwork()
+    brainstate.nn.init_all_states(net)
+
+    # 4. Simulation
+    times = u.math.arange(0*u.ms, 1000*u.ms, dt)
+ results = brainstate.transform.for_loop(net.update, times)
+
+ # 5. Visualization
+ plt.figure()
+ # ... plotting code ...
+ plt.show()
+
+Contributing Examples
+---------------------
+
+We welcome new examples! To contribute:
+
+1. Fork the BrainPy repository
+2. Add your example to ``examples_version3/``
+3. Follow naming convention: ``NNN_descriptive_name.py``
+4. Include documentation at the top
+5. Submit a pull request
+
+Example Template:
+
+.. code-block:: python
+
+ # Copyright 2024 BrainX Ecosystem Limited.
+ # Licensed under Apache License 2.0
+
+ """
+ Short description of the example.
+
+ This example demonstrates:
+ - Feature 1
+ - Feature 2
+
+ References:
+ - Citation if reproducing paper
+ """
+
+ # Your code here...
+
+Additional Resources
+--------------------
+
+**Tutorials**
+ For step-by-step learning, see :doc:`../tutorials/basic/01-lif-neuron`
+
+**API Documentation**
+ For detailed API reference, see :doc:`../api/neurons`
+
+**Core Concepts**
+ For architectural understanding, see :doc:`../core-concepts/architecture`
+
+**Migration Guide**
+ For updating from 2.x, see :doc:`../migration/migration-guide`
+
+Browse All Examples
+-------------------
+
+View all examples in the ``examples_version3/`` directory of the
+`BrainPy repository <https://github.com/brainpy/BrainPy>`_ on GitHub.
+
+More extensive examples and notebooks are collected in the companion example
+projects linked from the BrainPy documentation.
+
+Getting Help
+------------
+
+If you have questions about examples:
+
+- Open an issue on GitHub
+- Check existing discussions
+- Read the tutorials
+- Consult the documentation
+
+Happy modeling! 🧠
diff --git a/docs_version3/how-to-guides/custom-components.rst b/docs_version3/how-to-guides/custom-components.rst
new file mode 100644
index 00000000..c4a1297e
--- /dev/null
+++ b/docs_version3/how-to-guides/custom-components.rst
@@ -0,0 +1,510 @@
+How to Create Custom Components
+================================
+
+This guide shows you how to create custom neurons, synapses, and other components in BrainPy.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Quick Start
+-----------
+
+**Custom neuron template:**
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+ import jax.numpy as jnp
+
+ class CustomNeuron(bp.Neuron):
+ def __init__(self, size, **kwargs):
+ super().__init__(size, **kwargs)
+
+ # Parameters
+ self.tau = 10.0 * u.ms
+ self.V_th = -50.0 * u.mV
+
+ # States
+ self.V = brainstate.ShortTermState(jnp.zeros(size))
+ self.spike = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.V.value = jnp.zeros(shape)
+ self.spike.value = jnp.zeros(shape)
+
+ def update(self, x):
+ dt = brainstate.environ.get_dt()
+
+ # Dynamics
+ dV = -self.V.value / self.tau.to_decimal(u.ms) + x.to_decimal(u.nA)
+ self.V.value += dV * dt.to_decimal(u.ms)
+
+ # Spike generation
+ self.spike.value = (self.V.value >= self.V_th.to_decimal(u.mV)).astype(float)
+
+ # Reset
+ self.V.value = jnp.where(
+ self.spike.value > 0,
+ 0.0, # Reset voltage
+ self.V.value
+ )
+
+ return self.V.value
+
+ def get_spike(self):
+ return self.spike.value
+
+Custom Neurons
+--------------
+
+Example 1: Adaptive LIF
+~~~~~~~~~~~~~~~~~~~~~~~
+
+**LIF with spike-frequency adaptation:**
+
+.. code-block:: python
+
+ class AdaptiveLIF(bp.Neuron):
+ """LIF neuron with adaptation current."""
+
+ def __init__(self, size, tau=10*u.ms, tau_w=100*u.ms,
+ V_th=-50*u.mV, V_reset=-65*u.mV, a=0.1*u.nA,
+ b=0.5*u.nA, **kwargs):
+ super().__init__(size, **kwargs)
+
+ self.tau = tau
+ self.tau_w = tau_w
+ self.V_th = V_th
+ self.V_reset = V_reset
+ self.a = a # Adaptation coupling
+ self.b = b # Spike-triggered adaptation
+
+ # States
+ self.V = brainstate.ShortTermState(jnp.ones(size) * V_reset.to_decimal(u.mV))
+ self.w = brainstate.ShortTermState(jnp.zeros(size)) # Adaptation current
+ self.spike = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.V.value = jnp.ones(shape) * self.V_reset.to_decimal(u.mV)
+ self.w.value = jnp.zeros(shape)
+ self.spike.value = jnp.zeros(shape)
+
+ def update(self, I_ext):
+ dt = brainstate.environ.get_dt()
+
+ # Membrane potential dynamics
+ dV = (-self.V.value + self.V_reset.to_decimal(u.mV) + I_ext.to_decimal(u.nA) - self.w.value) / self.tau.to_decimal(u.ms)
+ self.V.value += dV * dt.to_decimal(u.ms)
+
+ # Adaptation dynamics
+ dw = (self.a.to_decimal(u.nA) * (self.V.value - self.V_reset.to_decimal(u.mV)) - self.w.value) / self.tau_w.to_decimal(u.ms)
+ self.w.value += dw * dt.to_decimal(u.ms)
+
+ # Spike generation
+ self.spike.value = (self.V.value >= self.V_th.to_decimal(u.mV)).astype(float)
+
+ # Reset and adaptation jump
+ self.V.value = jnp.where(
+ self.spike.value > 0,
+ self.V_reset.to_decimal(u.mV),
+ self.V.value
+ )
+ self.w.value += self.spike.value * self.b.to_decimal(u.nA)
+
+ return self.V.value
+
+ def get_spike(self):
+ return self.spike.value
+
+Example 2: Izhikevich Neuron
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class Izhikevich(bp.Neuron):
+ """Izhikevich neuron model."""
+
+ def __init__(self, size, a=0.02, b=0.2, c=-65*u.mV, d=8*u.mV, **kwargs):
+ super().__init__(size, **kwargs)
+
+ self.a = a
+ self.b = b
+ self.c = c
+ self.d = d
+
+ # States
+ self.V = brainstate.ShortTermState(jnp.ones(size) * c.to_decimal(u.mV))
+ self.u = brainstate.ShortTermState(jnp.zeros(size))
+ self.spike = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.V.value = jnp.ones(shape) * self.c.to_decimal(u.mV)
+ self.u.value = jnp.zeros(shape)
+ self.spike.value = jnp.zeros(shape)
+
+ def update(self, I):
+ dt = brainstate.environ.get_dt()
+
+ # Izhikevich dynamics
+ dV = (0.04 * self.V.value**2 + 5 * self.V.value + 140 - self.u.value + I.to_decimal(u.nA))
+ du = self.a * (self.b * self.V.value - self.u.value)
+
+ self.V.value += dV * dt.to_decimal(u.ms)
+ self.u.value += du * dt.to_decimal(u.ms)
+
+ # Spike and reset
+ self.spike.value = (self.V.value >= 30).astype(float)
+ self.V.value = jnp.where(self.spike.value > 0, self.c.to_decimal(u.mV), self.V.value)
+ self.u.value = jnp.where(self.spike.value > 0, self.u.value + self.d.to_decimal(u.mV), self.u.value)
+
+ return self.V.value
+
+ def get_spike(self):
+ return self.spike.value
+
+Custom Synapses
+---------------
+
+Example: Biexponential Synapse
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class BiexponentialSynapse(bp.Synapse):
+ """Synapse with separate rise and decay."""
+
+ def __init__(self, size, tau_rise=1*u.ms, tau_decay=5*u.ms, **kwargs):
+ super().__init__(size, **kwargs)
+
+ self.tau_rise = tau_rise
+ self.tau_decay = tau_decay
+
+ # States
+ self.h = brainstate.ShortTermState(jnp.zeros(size)) # Rising phase
+ self.g = brainstate.ShortTermState(jnp.zeros(size)) # Decaying phase
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.h.value = jnp.zeros(shape)
+ self.g.value = jnp.zeros(shape)
+
+ def update(self, x):
+ dt = brainstate.environ.get_dt()
+
+ # Two-stage dynamics
+ dh = -self.h.value / self.tau_rise.to_decimal(u.ms) + x
+ dg = -self.g.value / self.tau_decay.to_decimal(u.ms) + self.h.value
+
+ self.h.value += dh * dt.to_decimal(u.ms)
+ self.g.value += dg * dt.to_decimal(u.ms)
+
+ return self.g.value
+
+Example: NMDA Synapse
+~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class NMDASynapse(bp.Synapse):
+ """NMDA receptor with voltage dependence."""
+
+ def __init__(self, size, tau=100*u.ms, a=0.5/u.mM, Mg=1.0*u.mM, **kwargs):
+ super().__init__(size, **kwargs)
+
+ self.tau = tau
+ self.a = a
+ self.Mg = Mg
+
+ self.g = brainstate.ShortTermState(jnp.zeros(size))
+
+ def reset_state(self, batch_size=None):
+ shape = self.size if batch_size is None else (batch_size, self.size)
+ self.g.value = jnp.zeros(shape)
+
+ def update(self, x, V_post=None):
+ """Update with optional postsynaptic voltage."""
+ dt = brainstate.environ.get_dt()
+
+ # Conductance dynamics
+ dg = -self.g.value / self.tau.to_decimal(u.ms) + x
+ self.g.value += dg * dt.to_decimal(u.ms)
+
+ # Voltage-dependent magnesium block
+ if V_post is not None:
+ mg_block = 1 / (1 + self.Mg.to_decimal(u.mM) * self.a.to_decimal(1/u.mM) * jnp.exp(-0.062 * V_post.to_decimal(u.mV)))
+ return self.g.value * mg_block
+ else:
+ return self.g.value
+
+Custom Learning Rules
+---------------------
+
+Example: Simplified STDP
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class SimpleSTDP(brainstate.nn.Module):
+ """Simplified STDP learning rule."""
+
+ def __init__(self, n_pre, n_post, A_plus=0.01, A_minus=0.01,
+ tau_plus=20*u.ms, tau_minus=20*u.ms):
+ super().__init__()
+
+ self.A_plus = A_plus
+ self.A_minus = A_minus
+ self.tau_plus = tau_plus
+ self.tau_minus = tau_minus
+
+ # Learnable weights
+ self.W = brainstate.ParamState(jnp.ones((n_pre, n_post)) * 0.5)
+
+ # Eligibility traces
+ self.pre_trace = brainstate.ShortTermState(jnp.zeros(n_pre))
+ self.post_trace = brainstate.ShortTermState(jnp.zeros(n_post))
+
+ def reset_state(self, batch_size=None):
+ shape_pre = self.W.value.shape[0] if batch_size is None else (batch_size, self.W.value.shape[0])
+ shape_post = self.W.value.shape[1] if batch_size is None else (batch_size, self.W.value.shape[1])
+ self.pre_trace.value = jnp.zeros(shape_pre)
+ self.post_trace.value = jnp.zeros(shape_post)
+
+ def update(self, pre_spike, post_spike):
+ dt = brainstate.environ.get_dt()
+
+ # Update traces
+ self.pre_trace.value += -self.pre_trace.value / self.tau_plus.to_decimal(u.ms) * dt.to_decimal(u.ms) + pre_spike
+ self.post_trace.value += -self.post_trace.value / self.tau_minus.to_decimal(u.ms) * dt.to_decimal(u.ms) + post_spike
+
+ # Weight updates
+ # LTP: pre spike finds existing post trace
+ dw_ltp = self.A_plus * jnp.outer(pre_spike, self.post_trace.value)
+
+ # LTD: post spike finds existing pre trace
+ dw_ltd = -self.A_minus * jnp.outer(self.pre_trace.value, post_spike)
+
+ # Update weights
+ self.W.value = jnp.clip(self.W.value + dw_ltp + dw_ltd, 0, 1)
+
+ return jnp.dot(pre_spike, self.W.value)
+
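+A usage sketch (the spike vectors here are random placeholders):
+
+.. code-block:: python
+
+    stdp = SimpleSTDP(n_pre=50, n_post=20)
+    brainstate.nn.init_all_states(stdp)
+
+    pre_spk = (brainstate.random.rand(50) < 0.1).astype(float)
+    post_spk = (brainstate.random.rand(20) < 0.1).astype(float)
+
+    out = stdp.update(pre_spk, post_spk)           # propagates spikes, updates W
+    print(stdp.W.value.min(), stdp.W.value.max())  # weights stay clipped to [0, 1]
+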
+Custom Network Architectures
+-----------------------------
+
+Example: Liquid State Machine
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class LiquidStateMachine(brainstate.nn.Module):
+ """Reservoir computing with spiking neurons."""
+
+ def __init__(self, n_input=100, n_reservoir=1000, n_output=10):
+ super().__init__()
+
+ # Input projection (trainable)
+ self.input_weights = brainstate.ParamState(
+ brainstate.random.randn(n_input, n_reservoir) * 0.1
+ )
+
+ # Reservoir (fixed random recurrent network)
+ self.reservoir = bp.LIF(n_reservoir, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ # Fixed random recurrent weights
+ w_reservoir = brainstate.random.randn(n_reservoir, n_reservoir) * 0.01
+ mask = (brainstate.random.rand(n_reservoir, n_reservoir) < 0.1).astype(float)
+ self.reservoir_weights = w_reservoir * mask # Not a ParamState (fixed)
+
+ # Readout (trainable)
+ self.readout = bp.Readout(n_reservoir, n_output)
+
+ def update(self, x):
+ # Input to reservoir
+ reservoir_input = jnp.dot(x, self.input_weights.value) * u.nA
+
+ # Reservoir recurrence
+ spk = self.reservoir.get_spike()
+ recurrent_input = jnp.dot(spk, self.reservoir_weights) * u.nA
+
+ # Update reservoir
+ self.reservoir(reservoir_input + recurrent_input)
+
+ # Readout from reservoir state
+ output = self.readout(self.reservoir.get_spike())
+
+ return output
+
+Custom Input Encoders
+----------------------
+
+Example: Temporal Contrast Encoder
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class TemporalContrastEncoder(brainstate.nn.Module):
+ """Encode images as spike timing based on contrast."""
+
+ def __init__(self, n_pixels, max_time=100, threshold=0.1):
+ super().__init__()
+ self.n_pixels = n_pixels
+ self.max_time = max_time
+ self.threshold = threshold
+
+ def encode(self, image):
+ """Convert image to spike timing.
+
+ Args:
+ image: Array of pixel values [0, 1]
+
+ Returns:
+ spike_times: When each pixel spikes (or max_time if no spike)
+ """
+            # Higher intensity → earlier spike
+ spike_times = jnp.where(
+ image > self.threshold,
+ self.max_time * (1 - image), # Invert: bright pixels spike early
+ self.max_time # Below threshold: no spike
+ )
+
+ return spike_times
+
+ def decode_to_spikes(self, spike_times, current_time):
+ """Get spikes at current simulation time."""
+ spikes = (spike_times == current_time).astype(float)
+ return spikes
+
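+A usage sketch (a hypothetical four-pixel image; note that
+``decode_to_spikes`` compares times with ``==``, so the encoded times are
+rounded to whole steps first):
+
+.. code-block:: python
+
+    encoder = TemporalContrastEncoder(n_pixels=4)
+    image = jnp.array([0.9, 0.5, 0.2, 0.05])
+
+    spike_times = jnp.round(encoder.encode(image))   # brighter pixels spike earlier
+
+    for t in range(encoder.max_time):
+        spikes = encoder.decode_to_spikes(spike_times, t)
+        # feed `spikes` into a downstream network here
+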
+Best Practices
+--------------
+
+✅ **Inherit from base classes**
+ - ``bp.Neuron`` for neurons
+ - ``bp.Synapse`` for synapses
+ - ``brainstate.nn.Module`` for general components
+
+✅ **Use ShortTermState for dynamics**
+ - Reset each trial
+ - Temporary variables
+
+✅ **Use ParamState for learnable parameters**
+ - Trained by optimizers
+ - Saved in checkpoints
+
+✅ **Implement reset_state()**
+ - Handle batch_size parameter
+ - Initialize all ShortTermStates
+
+✅ **Use physical units**
+ - All parameters with ``brainunit``
+ - Convert for computation with ``.to_decimal()``
+
+✅ **Follow naming conventions**
+ - ``V`` for voltage
+ - ``spike`` for spike indicator
+ - ``g`` for conductance
+ - ``w`` for weights
+
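+In practice, this means parameters stay unit-tagged at module boundaries and
+are converted once for raw array math (a small sketch):
+
+.. code-block:: python
+
+    import brainunit as u
+
+    tau = 10.0 * u.ms                # parameter stored with units
+    dt = 0.1 * u.ms
+
+    tau_ms = tau.to_decimal(u.ms)    # plain float 10.0 for jnp arithmetic
+    decay_per_step = dt.to_decimal(u.ms) / tau_ms
+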
+Testing Custom Components
+--------------------------
+
+.. code-block:: python
+
+ def test_custom_neuron():
+ """Test custom neuron implementation."""
+
+ neuron = CustomNeuron(size=10)
+ brainstate.nn.init_all_states(neuron)
+
+ # Test 1: Initialization
+ assert neuron.V.value.shape == (10,)
+ assert jnp.all(neuron.V.value == 0)
+
+ # Test 2: Response to input
+ strong_input = jnp.ones(10) * 10.0 * u.nA
+ for _ in range(100):
+ neuron(strong_input)
+
+ spike_count = jnp.sum(neuron.spike.value)
+ assert spike_count > 0, "Neuron should spike with strong input"
+
+ # Test 3: Batch dimension
+ brainstate.nn.init_all_states(neuron, batch_size=5)
+ assert neuron.V.value.shape == (5, 10)
+
+        print("✅ Custom neuron tests passed")
+
+ test_custom_neuron()
+
+Complete Example
+----------------
+
+**Putting it all together:**
+
+.. code-block:: python
+
+ # Custom components
+ class MyNeuron(bp.Neuron):
+ # ... (see examples above)
+ pass
+
+ class MySynapse(bp.Synapse):
+ # ... (see examples above)
+ pass
+
+ # Use in network
+ class CustomNetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.pre = MyNeuron(size=100)
+ self.post = MyNeuron(size=50)
+
+ self.projection = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(100, 50, prob=0.1, weight=0.5*u.mS),
+ syn=MySynapse.desc(50), # Use custom synapse
+ out=bp.CUBA.desc(),
+ post=self.post
+ )
+
+ def update(self, inp):
+ spk_pre = self.pre.get_spike()
+ self.projection(spk_pre)
+ self.pre(inp)
+ self.post(0*u.nA)
+ return self.post.get_spike()
+
+ # Use network
+ net = CustomNetwork()
+ brainstate.nn.init_all_states(net)
+
+ for _ in range(100):
+ output = net(input_data)
+
+Summary
+-------
+
+**Component creation checklist:**
+
+.. code-block:: text
+
+   ✅ Inherit from bp.Neuron, bp.Synapse, or brainstate.nn.Module
+   ✅ Define __init__ with parameters
+   ✅ Create states (ShortTermState or ParamState)
+   ✅ Implement reset_state(batch_size=None)
+   ✅ Implement update() method
+   ✅ Use physical units throughout
+   ✅ Test with different batch sizes
+
+See Also
+--------
+
+- :doc:`../core-concepts/state-management` - Understanding states
+- :doc:`../core-concepts/neurons` - Built-in neuron models
+- :doc:`../core-concepts/synapses` - Built-in synapse models
+- :doc:`../tutorials/advanced/06-synaptic-plasticity` - Plasticity examples
diff --git a/docs_version3/how-to-guides/debugging-networks.rst b/docs_version3/how-to-guides/debugging-networks.rst
new file mode 100644
index 00000000..2b8624af
--- /dev/null
+++ b/docs_version3/how-to-guides/debugging-networks.rst
@@ -0,0 +1,811 @@
+How to Debug Networks
+=====================
+
+This guide shows you how to identify and fix common issues when developing neural networks with BrainPy.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Quick Diagnostic Checklist
+---------------------------
+
+When your network isn't working, check these first:
+
+**❓ Is the network receiving input?**
+   Print input values, check shapes
+
+**❓ Are neurons firing?**
+   Count spikes, check spike rates
+
+**❓ Are projections working?**
+   Verify connectivity, check weights
+
+**❓ Is update order correct?**
+   Get spikes BEFORE updating neurons
+
+**❓ Are states initialized?**
+   Call ``brainstate.nn.init_all_states()``
+
+**❓ Are units correct?**
+   All values need physical units (mV, nA, ms)
+
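+These checks can be bundled into a small helper to run before a longer
+debugging session. A sketch, assuming a neuron module with ``V`` and ``spike``
+states as in the examples below:
+
+.. code-block:: python
+
+    import jax.numpy as jnp
+
+    def quick_diagnostics(neuron, inp, n_steps=100):
+        """Print basic health statistics for one neuron population."""
+        total_spikes = 0
+        for _ in range(n_steps):
+            neuron(inp)
+            total_spikes += int(jnp.sum(neuron.spike.value))
+
+        V = neuron.V.value
+        print(f"V range: [{V.min():.2f}, {V.max():.2f}]")
+        print(f"NaN in V: {bool(jnp.any(jnp.isnan(V)))}")
+        print(f"Spikes over {n_steps} steps: {total_spikes}")
+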
+Common Issues and Solutions
+----------------------------
+
+Issue 1: No Spikes / Silent Network
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:**
+
+- Network produces no spikes
+- All neurons stay at rest potential
+
+**Diagnosis:**
+
+.. code-block:: python
+
+    import brainpy as bp
+    import brainstate
+    import brainunit as u
+    import jax.numpy as jnp
+
+ neuron = bp.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ brainstate.nn.init_all_states(neuron)
+
+ # Check 1: Is input being provided?
+ inp = brainstate.random.rand(100) * 5.0 * u.nA
+ print("Input range:", inp.min(), "to", inp.max())
+
+ # Check 2: Are neurons updating?
+ V_before = neuron.V.value.copy()
+ neuron(inp)
+ V_after = neuron.V.value
+ print("Voltage changed:", not jnp.allclose(V_before, V_after))
+
+ # Check 3: Are any neurons near threshold?
+ print("Max voltage:", V_after.max())
+ print("Threshold:", neuron.V_th.to_decimal(u.mV))
+ print("Neurons above -55mV:", jnp.sum(V_after > -55))
+
+ # Check 4: Count spikes
+    # Check 4: Count spikes over 100 steps
+    spike_count = 0
+    for i in range(100):
+        neuron(inp)
+        spike_count += int(jnp.sum(neuron.spike.value))
+
+**Common Causes:**
+
+1. **Input too weak:**
+
+ .. code-block:: python
+
+ # Too weak
+ inp = brainstate.random.rand(100) * 0.1 * u.nA # Not enough!
+
+ # Better
+ inp = brainstate.random.rand(100) * 5.0 * u.nA # Stronger
+
+2. **Threshold too high:**
+
+ .. code-block:: python
+
+ # Check threshold
+ neuron = bp.LIF(100, V_th=-40*u.mV, ...) # Harder to spike
+ neuron = bp.LIF(100, V_th=-50*u.mV, ...) # Easier to spike
+
+3. **Time constant too large:**
+
+ .. code-block:: python
+
+ # Slow integration
+ neuron = bp.LIF(100, tau=100*u.ms, ...) # Very slow
+
+ # Faster
+ neuron = bp.LIF(100, tau=10*u.ms, ...) # Normal speed
+
+4. **Missing initialization:**
+
+ .. code-block:: python
+
+ neuron = bp.LIF(100, ...)
+ # MUST initialize!
+ brainstate.nn.init_all_states(neuron)
+
+Issue 2: Runaway Activity / Explosion
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:**
+
+- All neurons fire constantly
+- Membrane potentials go to infinity
+- NaN values appear
+
+**Diagnosis:**
+
+.. code-block:: python
+
+ # Check for NaN
+ if jnp.any(jnp.isnan(neuron.V.value)):
+        print("❌ NaN detected in membrane potential!")
+
+ # Check for explosion
+ if jnp.any(jnp.abs(neuron.V.value) > 1000):
+        print("❌ Membrane potential exploded!")
+
+ # Check spike rate
+ spike_rate = jnp.mean(neuron.spike.value)
+ print(f"Spike rate: {spike_rate*100:.1f}%")
+ if spike_rate > 0.5:
+        print("⚠️ More than 50% of neurons firing every step!")
+
+**Common Causes:**
+
+1. **Excitation-Inhibition imbalance:**
+
+ .. code-block:: python
+
+ # Imbalanced (explosion!)
+ w_exc = 5.0 * u.mS # Too strong
+ w_inh = 1.0 * u.mS # Too weak
+
+ # Balanced
+ w_exc = 0.5 * u.mS
+      w_inh = 5.0 * u.mS  # Inhibition ~10× stronger
+
+2. **Positive feedback loop:**
+
+ .. code-block:: python
+
+ # Check recurrent excitation
+      # E → E with no inhibition can explode
+
+ # Add inhibition
+ class BalancedNetwork(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.E = bp.LIF(800, ...)
+ self.I = bp.LIF(200, ...)
+
+ self.E2E = ... # Excitatory recurrence
+ self.I2E = ... # MUST have inhibition!
+
+3. **Time step too large:**
+
+ .. code-block:: python
+
+ # Unstable
+ brainstate.environ.set(dt=1.0 * u.ms) # Too large
+
+ # Stable
+ brainstate.environ.set(dt=0.1 * u.ms) # Standard
+
+4. **Wrong reversal potentials:**
+
+ .. code-block:: python
+
+ # WRONG: Inhibition with excitatory reversal
+ out_inh = bp.COBA.desc(E=0*u.mV) # Should be negative!
+
+ # CORRECT
+ out_exc = bp.COBA.desc(E=0*u.mV) # Excitation
+ out_inh = bp.COBA.desc(E=-80*u.mV) # Inhibition
+
+Issue 3: Spikes Not Propagating
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:**
+
+- Presynaptic neurons spike
+- Postsynaptic neurons don't respond
+- Projection seems inactive
+
+**Diagnosis:**
+
+.. code-block:: python
+
+ # Create simple network
+ pre = bp.LIF(10, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ post = bp.LIF(10, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(10, 10, prob=0.5, weight=2.0*u.mS),
+ syn=bp.Expon.desc(10, tau=5*u.ms),
+ out=bp.CUBA.desc(),
+ post=post
+ )
+
+ brainstate.nn.init_all_states([pre, post, proj])
+
+ # Diagnosis
+ for i in range(10):
+ # CRITICAL: Get spikes BEFORE update
+ pre_spikes = pre.get_spike()
+
+ # Strong input to pre
+ pre(brainstate.random.rand(10) * 10.0 * u.nA)
+
+ # Check: Did pre spike?
+ if jnp.sum(pre_spikes) > 0:
+ print(f"Step {i}: {jnp.sum(pre_spikes)} presynaptic spikes")
+
+ # Update projection
+ proj(pre_spikes)
+
+ # Check: Did projection produce current?
+ print(f" Synaptic conductance: {proj.syn.g.value.max():.4f}")
+
+ # Update post
+ post(0*u.nA) # Only synaptic input
+
+ # Check: Did post spike?
+ post_spikes = post.get_spike()
+ print(f" {jnp.sum(post_spikes)} postsynaptic spikes")
+
+**Common Causes:**
+
+1. **Wrong spike timing:**
+
+ .. code-block:: python
+
+ # WRONG: Spikes from current step
+ pre(inp) # Update first
+ spikes = pre.get_spike() # These are NEW spikes
+ proj(spikes) # But projection needs OLD spikes!
+
+ # CORRECT: Spikes from previous step
+ spikes = pre.get_spike() # Get OLD spikes first
+ proj(spikes) # Update projection
+ pre(inp) # Then update neurons
+
+2. **Weak connectivity:**
+
+ .. code-block:: python
+
+ # Too sparse
+ comm = brainstate.nn.EventFixedProb(..., prob=0.01, weight=0.1*u.mS)
+
+ # Stronger
+ comm = brainstate.nn.EventFixedProb(..., prob=0.1, weight=1.0*u.mS)
+
+3. **Missing projection update:**
+
+ .. code-block:: python
+
+ # Forgot to call projection!
+ spk = pre.get_spike()
+ # proj(spk) <- MISSING!
+ post(0*u.nA)
+
+4. **Wrong postsynaptic target:**
+
+ .. code-block:: python
+
+ # Wrong target
+ proj = bp.AlignPostProj(..., post=wrong_population)
+
+ # Correct target
+ proj = bp.AlignPostProj(..., post=correct_population)
+
+Issue 4: Shape Mismatch Errors
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptoms:**
+
+.. code-block:: text
+
+ ValueError: operands could not be broadcast together
+ with shapes (100,) (64, 100)
+
+**Common Causes:**
+
+1. **Batch dimension mismatch:**
+
+ .. code-block:: python
+
+ # Network initialized with batch
+ brainstate.nn.init_all_states(net, batch_size=64)
+ # States shape: (64, 100)
+
+ # But input has no batch
+ inp = jnp.zeros(100) # Shape: (100,) - WRONG!
+
+ # Fix: Add batch dimension
+ inp = jnp.zeros((64, 100)) # Shape: (64, 100) - CORRECT
+
+2. **Forgot batch in initialization:**
+
+ .. code-block:: python
+
+ # Initialized without batch
+ brainstate.nn.init_all_states(net) # Shape: (100,)
+
+ # But providing batched input
+ inp = jnp.zeros((64, 100)) # Shape: (64, 100)
+
+ # Fix: Initialize with batch
+ brainstate.nn.init_all_states(net, batch_size=64)
+
+**Debug shape mismatches:**
+
+.. code-block:: python
+
+ print(f"Input shape: {inp.shape}")
+ print(f"Network state shape: {net.neurons.V.value.shape}")
+ print(f"Expected: Both should have same batch dimension")
+
+Inspection Tools
+----------------
+
+Print State Values
+~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Inspect neuron states
+ neuron = bp.LIF(10, ...)
+ brainstate.nn.init_all_states(neuron)
+
+ print("Membrane potentials:", neuron.V.value)
+ print("Spikes:", neuron.spike.value)
+ print("Shape:", neuron.V.value.shape)
+
+ # Statistics
+ print(f"V range: [{neuron.V.value.min():.2f}, {neuron.V.value.max():.2f}]")
+ print(f"V mean: {neuron.V.value.mean():.2f}")
+ print(f"Spike count: {jnp.sum(neuron.spike.value)}")
+
+Visualize Activity
+~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ # Record activity
+ n_steps = 1000
+ V_history = []
+ spike_history = []
+
+ for i in range(n_steps):
+ neuron(inp)
+ V_history.append(neuron.V.value.copy())
+ spike_history.append(neuron.spike.value.copy())
+
+ V_history = jnp.array(V_history)
+ spike_history = jnp.array(spike_history)
+
+ # Plot membrane potential
+ plt.figure(figsize=(12, 4))
+ plt.plot(V_history[:, 0]) # First neuron
+ plt.xlabel('Time step')
+ plt.ylabel('Membrane Potential (mV)')
+ plt.title('Neuron 0 Membrane Potential')
+ plt.show()
+
+ # Plot raster
+ plt.figure(figsize=(12, 6))
+ times, neurons = jnp.where(spike_history > 0)
+ plt.scatter(times, neurons, s=1, c='black')
+ plt.xlabel('Time step')
+ plt.ylabel('Neuron index')
+ plt.title('Spike Raster')
+ plt.show()
+
+ # Firing rate over time
+ plt.figure(figsize=(12, 4))
+    firing_rate = jnp.mean(spike_history, axis=1) * 1000 / 0.1  # Hz, assuming dt = 0.1 ms
+ plt.plot(firing_rate)
+ plt.xlabel('Time step')
+ plt.ylabel('Population Rate (Hz)')
+ plt.title('Population Firing Rate')
+ plt.show()
+
+Check Connectivity
+~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # For sparse projections
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(100, 50, prob=0.1, weight=0.5*u.mS),
+ syn=bp.Expon.desc(50, tau=5*u.ms),
+ out=bp.CUBA.desc(),
+ post=post_neurons
+ )
+
+ # Check connection count
+ print(f"Expected connections: {100 * 50 * 0.1:.0f}")
+ # Note: Actual connectivity may vary due to randomness
+
+ # Check weights
+ # (Accessing internal connectivity structure depends on implementation)
+
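+When the internal structure is not exposed, connectivity can also be probed
+empirically: drive a single presynaptic spike and count which synapses respond
+(a sketch reusing the projection above; assumes freshly initialized states):
+
+.. code-block:: python
+
+    brainstate.nn.init_all_states(proj)
+
+    probe = jnp.zeros(100).at[0].set(1.0)    # spike from presynaptic neuron 0 only
+    proj(probe)
+
+    responding = int(jnp.sum(proj.syn.g.value > 0))
+    print(f"Synapses driven by neuron 0: {responding}")
+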
+Monitor Training
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Track loss and metrics
+ train_losses = []
+ val_accuracies = []
+
+ for epoch in range(num_epochs):
+ epoch_losses = []
+
+ for batch in train_loader:
+ loss = train_step(net, batch)
+ epoch_losses.append(float(loss))
+
+ avg_loss = np.mean(epoch_losses)
+ train_losses.append(avg_loss)
+
+ # Validation
+ val_acc = evaluate(net, val_loader)
+ val_accuracies.append(val_acc)
+
+ print(f"Epoch {epoch}: Loss={avg_loss:.4f}, Val Acc={val_acc:.2%}")
+
+ # Check for issues
+ if np.isnan(avg_loss):
+            print("❌ NaN loss! Stopping training.")
+ break
+
+ if avg_loss > 10 * train_losses[0]:
+            print("⚠️ Loss exploding!")
+
+ # Plot training curves
+ plt.figure(figsize=(12, 4))
+ plt.subplot(1, 2, 1)
+ plt.plot(train_losses)
+ plt.xlabel('Epoch')
+ plt.ylabel('Loss')
+ plt.title('Training Loss')
+
+ plt.subplot(1, 2, 2)
+ plt.plot(val_accuracies)
+ plt.xlabel('Epoch')
+ plt.ylabel('Accuracy')
+ plt.title('Validation Accuracy')
+ plt.show()
+
+Advanced Debugging
+------------------
+
+Gradient Checking
+~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    # Check whether gradients are being computed
+    params = net.states(brainstate.ParamState)
+
+    grads, loss_value = brainstate.transform.grad(
+        loss_fn,
+        params,
+        return_value=True  # returns (grads, loss) - unpack both
+    )(net, X, y)
+
+ # Inspect gradients
+ for name, grad in grads.items():
+ grad_norm = jnp.linalg.norm(grad.value.flatten())
+ print(f"{name}: gradient norm = {grad_norm:.6f}")
+
+        if jnp.any(jnp.isnan(grad.value)):
+            print(f"  ❌ NaN in gradient!")
+
+        if grad_norm == 0:
+            print(f"  ⚠️ Zero gradient - parameter not learning")
+
+        if grad_norm > 1000:
+            print(f"  ⚠️ Exploding gradient!")
+
+Trace Execution
+~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def debug_step(net, inp):
+ """Instrumented simulation step."""
+ print(f"\n--- Step Start ---")
+
+ # Before
+ print(f"Input range: [{inp.min():.2f}, {inp.max():.2f}]")
+ print(f"V before: [{net.neurons.V.value.min():.2f}, {net.neurons.V.value.max():.2f}]")
+
+ # Execute
+ output = net(inp)
+
+ # After
+ print(f"V after: [{net.neurons.V.value.min():.2f}, {net.neurons.V.value.max():.2f}]")
+ print(f"Spikes: {jnp.sum(net.neurons.spike.value)}")
+ print(f"Output range: [{output.min():.2f}, {output.max():.2f}]")
+
+ # Checks
+ if jnp.any(jnp.isnan(net.neurons.V.value)):
+            print("❌ NaN detected!")
+ import pdb; pdb.set_trace() # Drop into debugger
+
+ print(f"--- Step End ---\n")
+ return output
+
+ # Use for debugging
+ for i in range(10):
+ output = debug_step(net, input_data)
+
+Assertion Checks
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class SafeNetwork(brainstate.nn.Module):
+ """Network with built-in checks."""
+
+ def __init__(self, n_neurons=100):
+ super().__init__()
+ self.neurons = bp.LIF(n_neurons, ...)
+
+ def update(self, inp):
+ # Pre-checks
+ assert inp.shape[-1] == 100, f"Wrong input size: {inp.shape}"
+ assert not jnp.any(jnp.isnan(inp)), "NaN in input!"
+ assert not jnp.any(jnp.isinf(inp)), "Inf in input!"
+
+ # Execute
+ self.neurons(inp)
+ output = self.neurons.get_spike()
+
+ # Post-checks
+ assert not jnp.any(jnp.isnan(self.neurons.V.value)), "NaN in membrane potential!"
+ assert jnp.all(jnp.abs(self.neurons.V.value) < 1000), "Voltage explosion!"
+
+ return output
+
+Unit Testing
+~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def test_neuron_spikes():
+ """Test that neuron spikes with strong input."""
+ neuron = bp.LIF(1, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ brainstate.nn.init_all_states(neuron)
+
+ # Strong constant input should cause spiking
+ strong_input = jnp.array([20.0]) * u.nA
+
+ spike_count = 0
+ for _ in range(100):
+ neuron(strong_input)
+ spike_count += int(neuron.spike.value[0])
+
+ assert spike_count > 0, "Neuron didn't spike with strong input!"
+ assert spike_count < 100, "Neuron spiked every step (check reset!)"
+
+ print(f"โ
Neuron test passed ({spike_count} spikes)")
+
+ def test_projection():
+ """Test that projection propagates spikes."""
+ pre = bp.LIF(10, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ post = bp.LIF(10, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(10, 10, prob=1.0, weight=5.0*u.mS), # 100% connectivity
+ syn=bp.Expon.desc(10, tau=5*u.ms),
+ out=bp.CUBA.desc(),
+ post=post
+ )
+
+ brainstate.nn.init_all_states([pre, post, proj])
+
+ # Make pre spike
+ pre(jnp.ones(10) * 20.0 * u.nA)
+
+ # Projection should activate
+ spk = pre.get_spike()
+ assert jnp.sum(spk) > 0, "Pre didn't spike!"
+
+ proj(spk)
+
+ # Check synaptic conductance increased
+ assert proj.syn.g.value.max() > 0, "Synapse didn't activate!"
+
+ print("โ
Projection test passed")
+
+ # Run tests
+ test_neuron_spikes()
+ test_projection()
+
+Debugging Checklist
+-------------------
+
+When your network doesn't work:
+
+**1. Check Initialization**
+
+.. code-block:: text
+
+ ☐ Called brainstate.nn.init_all_states()?
+ ☐ Correct batch_size parameter?
+ ☐ All submodules initialized?
+
+**2. Check Input**
+
+.. code-block:: text
+
+ ☐ Input shape matches network?
+ ☐ Input has units (nA, mV, etc.)?
+ ☐ Input magnitude reasonable?
+ ☐ Input not all zeros?
+
+**3. Check Neurons**
+
+.. code-block:: text
+
+ ☐ Threshold reasonable (e.g., -50 mV)?
+ ☐ Reset potential below threshold?
+ ☐ Time constant reasonable (5-20 ms)?
+ ☐ Neurons actually spiking?
+
+**4. Check Projections**
+
+.. code-block:: text
+
+ ☐ Connectivity probability > 0?
+ ☐ Weights reasonable magnitude?
+ ☐ Correct update order (spikes before update)?
+ ☐ Projection actually called?
+
+**5. Check Balance**
+
+.. code-block:: text
+
+ ☐ Inhibition stronger than excitation (~10×)?
+ ☐ Reversal potentials correct (E=0, I=-80)?
+ ☐ E/I ratio appropriate (4:1)?
+
+**6. Check Training**
+
+.. code-block:: text
+
+ ☐ Loss decreasing?
+ ☐ Gradients non-zero?
+ ☐ No NaN in gradients?
+ ☐ Learning rate appropriate?
+
+Common Error Messages
+---------------------
+
+"operands could not be broadcast"
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Meaning:** Shape mismatch
+
+**Fix:** Check batch dimensions
+
+.. code-block:: python
+
+ print(f"Shapes: {x.shape} vs {y.shape}")
+
+"RESOURCE_EXHAUSTED: Out of memory"
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Meaning:** GPU/CPU memory full
+
+**Fix:** Reduce batch size or network size
+
+.. code-block:: python
+
+ # Reduce batch
+ brainstate.nn.init_all_states(net, batch_size=16) # Instead of 64
+
+"Concrete value required"
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Meaning:** JIT can't handle dynamic values
+
+**Fix:** Use static shapes
+
+.. code-block:: python
+
+ # Dynamic (bad for JIT)
+ n = len(data) # Changes each call
+
+ # Static (good for JIT)
+ n = 100 # Fixed value
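+
+If a size genuinely has to vary, one option is to declare it static so JIT
+compiles once per distinct value (a sketch using plain ``jax.jit`` and its
+``static_argnums`` option; each new ``n`` triggers a recompilation):
+
+.. code-block:: python
+
+    from functools import partial
+    import jax
+    import jax.numpy as jnp
+
+    @partial(jax.jit, static_argnums=0)
+    def sum_first_n(n, data):
+        # n is a compile-time constant here, so slicing is allowed
+        return data[:n].sum()
+
+    print(sum_first_n(100, jnp.arange(1000)))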
+
+"Invalid device"
+~~~~~~~~~~~~~~~~
+
+**Meaning:** Trying to use unavailable device
+
+**Fix:** Check available devices
+
+.. code-block:: python
+
+ import jax
+ print(jax.devices())
+
+Best Practices
+--------------
+
+✅ **Test small first** - Debug with 10 neurons before scaling to 10,000
+
+✅ **Visualize early** - Plot activity to see problems immediately
+
+✅ **Check incrementally** - Test each component before combining
+
+✅ **Use assertions** - Catch problems early with runtime checks
+
+✅ **Print liberally** - Add diagnostic prints during development
+
+✅ **Keep backups** - Save working versions before major changes
+
+✅ **Start simple** - Begin with minimal network, add complexity gradually
+
+✅ **Write tests** - Unit test individual components
+
+❌ **Don't debug by guessing** - Use systematic diagnosis
+
+❌ **Don't skip initialization** - Always call init_all_states
+
+❌ **Don't ignore warnings** - They often indicate real problems
+
+Summary
+-------
+
+**Debugging workflow:**
+
+1. **Identify symptom** (no spikes, explosion, etc.)
+2. **Isolate component** (neurons, projections, input)
+3. **Inspect state** (print values, plot activity)
+4. **Form hypothesis** (what might be wrong?)
+5. **Test fix** (make one change at a time)
+6. **Verify** (ensure problem solved)
+
+**Quick diagnostic code:**
+
+.. code-block:: python
+
+ # Comprehensive diagnostic
+ def diagnose_network(net, inp):
+ print("=== Network Diagnostic ===")
+
+ # Input
+ print(f"Input shape: {inp.shape}")
+ print(f"Input range: [{inp.min():.2f}, {inp.max():.2f}]")
+
+ # States
+ if hasattr(net, 'neurons'):
+ V = net.neurons.V.value
+ print(f"Voltage shape: {V.shape}")
+ print(f"Voltage range: [{V.min():.2f}, {V.max():.2f}]")
+
+ # Simulation
+ output = net(inp)
+
+ # Results
+ if hasattr(net, 'neurons'):
+ spk_count = jnp.sum(net.neurons.spike.value)
+ print(f"Spikes: {spk_count}")
+
+ print(f"Output shape: {output.shape}")
+ print(f"Output range: [{output.min():.2f}, {output.max():.2f}]")
+
+ # Checks
+ if jnp.any(jnp.isnan(output)):
+ print("โ NaN in output!")
+ if jnp.all(output == 0):
+ print("โ ๏ธ Output all zeros!")
+
+ print("=========================")
+ return output
+
+See Also
+--------
+
+- :doc:`../core-concepts/state-management` - Understanding state system
+- :doc:`../core-concepts/projections` - Projection architecture
+- :doc:`performance-optimization` - Optimization tips
diff --git a/docs_version3/how-to-guides/gpu-tpu-usage.rst b/docs_version3/how-to-guides/gpu-tpu-usage.rst
new file mode 100644
index 00000000..8de3cb5f
--- /dev/null
+++ b/docs_version3/how-to-guides/gpu-tpu-usage.rst
@@ -0,0 +1,826 @@
+How to Use GPU and TPU
+======================
+
+This guide shows you how to leverage GPU and TPU acceleration for faster simulations and training with BrainPy.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Quick Start
+-----------
+
+**Check available devices:**
+
+.. code-block:: python
+
+ import jax
+ print("Available devices:", jax.devices())
+ print("Default backend:", jax.default_backend())
+
+**BrainPy automatically uses available accelerators** - no code changes needed!
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+
+ # This automatically runs on GPU if available
+ net = bp.LIF(10000, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ brainstate.nn.init_all_states(net)
+
+ for _ in range(1000):
+ net(brainstate.random.rand(10000) * 2.0 * u.nA)
+
+Installation
+------------
+
+CPU-Only (Default)
+~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ pip install brainpy[cpu]
+
+GPU (CUDA 12)
+~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ # CUDA 12
+ pip install brainpy[cuda12]
+
+ # Or CUDA 11
+ pip install brainpy[cuda11]
+
+**Requirements:**
+
+- NVIDIA GPU (compute capability ≥ 3.5)
+- CUDA Toolkit installed
+- cuDNN libraries
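+
+These can be verified from the shell before installing (standard NVIDIA
+tools; availability depends on your system setup):
+
+.. code-block:: bash
+
+    # Driver version and visible GPUs
+    nvidia-smi
+
+    # CUDA toolkit version, if the toolkit is installed
+    nvcc --version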
+
+TPU (Google Cloud)
+~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ pip install brainpy[tpu]
+
+**Requirements:**
+
+- Google Cloud TPU instance
+- TPU runtime configured
+
+Verify Installation
+~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import jax
+ import jax.numpy as jnp
+
+ # Check JAX can see GPU/TPU
+ print("Devices:", jax.devices())
+
+ # Test computation
+ x = jnp.ones((1000, 1000))
+ y = jnp.dot(x, x)
+ print("โ
JAX computation works!")
+
+ # Check device placement
+ print("Result device:", y.device())
+
+Expected output (GPU):
+
+.. code-block:: text
+
+ Devices: [cuda(id=0)]
+ โ
JAX computation works!
+ Result device: cuda:0
+
+Understanding Device Placement
+-------------------------------
+
+Automatic Placement
+~~~~~~~~~~~~~~~~~~~
+
+**JAX automatically places computations on the best available device:**
+
+1. TPU (if available)
+2. GPU (if available)
+3. CPU (fallback)
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+
+ # Automatically uses GPU if available
+ net = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net)
+
+ # All operations run on GPU
+ net(input_data)
+
+Manual Device Selection
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Force computation on specific device:
+
+.. code-block:: python
+
+ import jax
+
+ # Run on specific GPU
+ with jax.default_device(jax.devices('gpu')[0]):
+ net = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net)
+ result = net(input_data)
+
+ # Run on CPU
+ with jax.default_device(jax.devices('cpu')[0]):
+ net_cpu = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net_cpu)
+ result_cpu = net_cpu(input_data)
+
+Check Data Location
+~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Check where data lives
+ neuron = bp.LIF(100, ...)
+ brainstate.nn.init_all_states(neuron)
+
+ print("Voltage device:", neuron.V.value.device())
+ # Output: cuda:0 (if on GPU)
+
+Optimizing for GPU
+-------------------
+
+Use JIT Compilation
+~~~~~~~~~~~~~~~~~~~
+
+**Essential for GPU performance!**
+
+.. code-block:: python
+
+ import brainstate
+
+ net = bp.LIF(10000, ...)
+ brainstate.nn.init_all_states(net)
+
+ # WITHOUT JIT (slow on GPU)
+ for _ in range(1000):
+ net(input_data) # Many small kernel launches
+
+ # WITH JIT (fast on GPU)
+ @brainstate.compile.jit
+ def simulate_step(net, inp):
+ return net(inp)
+
+ # Warmup (compilation)
+ _ = simulate_step(net, input_data)
+
+ # Fast execution
+ for _ in range(1000):
+ output = simulate_step(net, input_data)
+
+**Speedup:** 10-100× with JIT on GPU
+
+Batch Operations
+~~~~~~~~~~~~~~~~
+
+**Process multiple trials in parallel:**
+
+.. code-block:: python
+
+ # Single trial (underutilizes GPU)
+ net = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net) # Shape: (1000,)
+
+ # Multiple trials in parallel (efficient GPU usage)
+ net_batched = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net_batched, batch_size=64) # Shape: (64, 1000)
+
+ # GPU processes all 64 trials simultaneously
+ inp = brainstate.random.rand(64, 1000) * 2.0 * u.nA
+ output = net_batched(inp)
+
+**GPU Utilization:**
+
+- Small batches (1-10): ~10-30% GPU usage
+- Medium batches (32-128): ~60-80% GPU usage
+- Large batches (256+): ~90-100% GPU usage
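+
+A quick way to see this effect is to sweep the batch size and time a fixed
+number of steps (a rough sketch reusing the APIs above; absolute numbers
+depend on your GPU):
+
+.. code-block:: python
+
+    import time
+
+    for batch_size in [1, 32, 256]:
+        net = bp.LIF(1000, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+        brainstate.nn.init_all_states(net, batch_size=batch_size)
+
+        @brainstate.compile.jit
+        def step(net, inp):
+            return net(inp)
+
+        inp = brainstate.random.rand(batch_size, 1000) * 2.0 * u.nA
+        _ = step(net, inp)  # warmup (compilation)
+
+        start = time.time()
+        for _ in range(100):
+            out = step(net, inp)
+        print(f"batch={batch_size}: {time.time() - start:.3f}s per 100 steps")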
+
+Appropriate Problem Size
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+**GPU overhead is worth it for large problems:**
+
+.. list-table:: When to Use GPU
+ :header-rows: 1
+
+ * - Network Size
+ - GPU Speedup
+ - Recommendation
+ * - < 1,000 neurons
+ - 0.5-2×
+ - Use CPU
+ * - 1,000-10,000
+ - 2-10×
+ - GPU beneficial
+ * - 10,000-100,000
+ - 10-50×
+ - GPU strongly recommended
+ * - > 100,000
+ - 50-100×
+ - GPU essential
+
+Minimize Data Transfer
+~~~~~~~~~~~~~~~~~~~~~~
+
+**Avoid moving data between CPU and GPU:**
+
+.. code-block:: python
+
+ # BAD: Frequent CPU-GPU transfers
+ for i in range(1000):
+ inp_cpu = np.random.rand(1000) # On CPU
+ inp_gpu = jnp.array(inp_cpu) # Transfer to GPU
+ output_gpu = net(inp_gpu) # Compute on GPU
+ output_cpu = np.array(output_gpu) # Transfer to CPU
+ # CPU-GPU transfer dominates time!
+
+ # GOOD: Keep data on GPU
+ @brainstate.compile.jit
+ def simulate_step(net, key):
+ inp = brainstate.random.uniform(key, (1000,)) * 2.0 # Generated on GPU
+ return net(inp) # Stays on GPU
+
+ for i in range(1000):
+ key = brainstate.random.split_key() # fresh key each step
+ output = simulate_step(net, key) # all stays on GPU
+
+Use Sparse Operations
+~~~~~~~~~~~~~~~~~~~~~
+
+**Sparse connectivity is crucial for large networks:**
+
+.. code-block:: python
+
+ # Dense (memory intensive on GPU)
+ dense_proj = bp.AlignPostProj(
+ comm=brainstate.nn.Linear(10000, 10000), # 400MB just for weights!
+ syn=bp.Expon.desc(10000, tau=5*u.ms),
+ out=bp.CUBA.desc(),
+ post=post_neurons
+ )
+
+ # Sparse (memory efficient)
+ sparse_proj = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(
+ pre_size=10000,
+ post_size=10000,
+ prob=0.01, # 1% connectivity
+ weight=0.5*u.mS
+ ), # Only 4MB for weights!
+ syn=bp.Expon.desc(10000, tau=5*u.ms),
+ out=bp.CUBA.desc(),
+ post=post_neurons
+ )
+
+Multi-GPU Usage
+---------------
+
+Data Parallelism
+~~~~~~~~~~~~~~~~
+
+**Run different trials on different GPUs:**
+
+.. code-block:: python
+
+ import jax
+
+ # Check available GPUs
+ gpus = jax.devices('gpu')
+ print(f"Found {len(gpus)} GPUs")
+
+ # Split work across GPUs
+ def run_on_gpu(gpu_id, n_trials):
+ with jax.default_device(gpus[gpu_id]):
+ net = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net, batch_size=n_trials)
+
+ results = []
+ for _ in range(100):
+ output = net(input_data)
+ results.append(output)
+
+ return results
+
+ # Run on multiple GPUs in parallel
+ from concurrent.futures import ThreadPoolExecutor
+
+ with ThreadPoolExecutor(max_workers=len(gpus)) as executor:
+ futures = [
+ executor.submit(run_on_gpu, i, 32)
+ for i in range(len(gpus))
+ ]
+ all_results = [f.result() for f in futures]
+
+Using JAX pmap
+~~~~~~~~~~~~~~
+
+**Parallel map across devices:**
+
+.. code-block:: python
+
+ from jax import pmap
+ import jax.numpy as jnp
+
+ # Create model
+ net = bp.LIF(1000, ...)
+
+ @pmap
+ def parallel_simulate(inputs):
+ """Run on multiple devices in parallel."""
+ brainstate.nn.init_all_states(net)
+ return net(inputs)
+
+ # Split inputs across devices
+ n_devices = len(jax.devices())
+ inputs = jnp.ones((n_devices, 1000)) # One batch per device
+
+ # Run in parallel
+ outputs = parallel_simulate(inputs)
+ # outputs.shape = (n_devices, output_size)
+
+TPU-Specific Optimization
+--------------------------
+
+TPU Characteristics
+~~~~~~~~~~~~~~~~~~~
+
+**TPUs are optimized for:**
+
+✅ Large matrix multiplications (e.g., dense layers)
+
+✅ High batch sizes (128+)
+
+✅ Float32 operations (bf16 also good)
+
+❌ Small operations (overhead dominates)
+
+❌ Sparse operations (less optimized than GPU)
+
+❌ Dynamic shapes (requires recompilation)
+
+Optimal TPU Usage
+~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Configure for TPU
+ import brainstate
+
+ # Large batches for TPU
+ batch_size = 256 # TPUs like large batches
+
+ net = bp.LIF(1000, ...)
+ brainstate.nn.init_all_states(net, batch_size=batch_size)
+
+ # JIT is essential
+ @brainstate.compile.jit
+ def train_step(net, inputs, labels):
+ # Dense operations work well
+ # Avoid sparse operations on TPU
+ return loss
+
+ # Static shapes (avoid dynamic)
+ inputs = jnp.ones((batch_size, 1000)) # Fixed shape
+
+ # Run
+ for batch in data_loader:
+ loss = train_step(net, batch_inputs, batch_labels)
+
+TPU Pods
+~~~~~~~~
+
+**Multi-TPU training:**
+
+.. code-block:: python
+
+ # TPU pods provide multiple TPU cores
+ devices = jax.devices('tpu')
+ print(f"TPU cores: {len(devices)}")
+
+ # Use pmap for data parallelism
+ @pmap
+ def parallel_step(inputs):
+ return net(inputs)
+
+ # Split across TPU cores
+ inputs_per_core = jnp.reshape(inputs, (len(devices), -1, 1000))
+ outputs = parallel_step(inputs_per_core)
+
+Performance Benchmarking
+------------------------
+
+Measure Speedup
+~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import time
+ import jax
+
+ def benchmark_device(device_type, n_neurons=10000, n_steps=1000):
+ """Benchmark simulation on specific device."""
+
+ # Select device
+ if device_type == 'cpu':
+ device = jax.devices('cpu')[0]
+ elif device_type == 'gpu':
+ device = jax.devices('gpu')[0]
+ else:
+ device = jax.devices('tpu')[0]
+
+ with jax.default_device(device):
+ # Create network
+ net = bp.LIF(n_neurons, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ brainstate.nn.init_all_states(net)
+
+ @brainstate.compile.jit
+ def step(net, inp):
+ return net(inp)
+
+ # Warmup
+ inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA
+ _ = step(net, inp)
+
+ # Benchmark (block on the last result: JAX dispatches asynchronously)
+ start = time.time()
+ for _ in range(n_steps):
+ inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA
+ output = step(net, inp)
+ jax.block_until_ready(output)
+ elapsed = time.time() - start
+
+ return elapsed
+
+ # Compare devices
+ cpu_time = benchmark_device('cpu', n_neurons=10000, n_steps=1000)
+ gpu_time = benchmark_device('gpu', n_neurons=10000, n_steps=1000)
+
+ print(f"CPU time: {cpu_time:.2f}s")
+ print(f"GPU time: {gpu_time:.2f}s")
+ print(f"Speedup: {cpu_time/gpu_time:.1f}ร")
+
+Profile GPU Usage
+~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ # Monitor GPU memory
+ import jax
+
+ # Get memory info (NVIDIA GPUs; not supported on every backend)
+ try:
+ stats = jax.local_devices()[0].memory_stats()
+ print("GPU memory in use:", stats.get("bytes_in_use"))
+ except Exception:
+ print("Memory stats not available")
+
+ # Profile with TensorBoard (advanced)
+ with jax.profiler.trace("/tmp/tensorboard"):
+ for _ in range(100):
+ output = net(input_data)
+
+ # View with: tensorboard --logdir=/tmp/tensorboard
+
+Memory Management
+-----------------
+
+Check GPU Memory
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import jax
+
+ # Check total memory
+ for device in jax.devices('gpu'):
+ try:
+ # This may not work on all systems
+ print(f"Device: {device}")
+ print(f"Memory: {device.memory_stats()}")
+ except Exception:
+ print("Memory stats not available")
+
+Estimate Memory Requirements
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def estimate_memory_mb(n_neurons, n_synapses, batch_size=1, dtype_bytes=4):
+ """Estimate GPU memory needed.
+
+ Args:
+ n_neurons: Number of neurons
+ n_synapses: Number of synapses
+ batch_size: Batch size
+ dtype_bytes: 4 for float32, 2 for float16
+ """
+ # Neuron states (V, spike, etc.) × batch
+ neuron_memory = n_neurons * 3 * batch_size * dtype_bytes
+
+ # Synapse states (g, x, etc.)
+ synapse_memory = n_synapses * 2 * dtype_bytes
+
+ # Weights
+ weight_memory = n_synapses * dtype_bytes
+
+ total_bytes = neuron_memory + synapse_memory + weight_memory
+ total_mb = total_bytes / (1024 * 1024)
+
+ return total_mb
+
+ # Example
+ mem_mb = estimate_memory_mb(
+ n_neurons=100000,
+ n_synapses=100000 * 100000 * 0.01, # 1% connectivity
+ batch_size=32
+ )
+ print(f"Estimated memory: {mem_mb:.1f} MB ({mem_mb/1024:.2f} GB)")
+
+Clear GPU Memory
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import jax
+
+ # JAX manages memory automatically
+ # But you can force garbage collection
+
+ import gc
+
+ # Delete large arrays
+ del large_array
+ del network
+
+ # Force garbage collection
+ gc.collect()
+
+ # Clear JAX compilation cache (if needed)
+ jax.clear_caches()
+
+Common Issues and Solutions
+----------------------------
+
+Issue: Out of Memory
+~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** `RESOURCE_EXHAUSTED: Out of memory`
+
+**Solutions:**
+
+1. **Reduce batch size:**
+
+ .. code-block:: python
+
+ # Try smaller batch
+ brainstate.nn.init_all_states(net, batch_size=16) # Instead of 64
+
+2. **Use sparse connectivity:**
+
+ .. code-block:: python
+
+ # Reduce connectivity
+ comm = brainstate.nn.EventFixedProb(..., prob=0.01) # Instead of 0.1
+
+3. **Use lower precision:**
+
+ .. code-block:: python
+
+ # BrainPy primarily uses float32; make sure 64-bit mode is off,
+ # since float64 doubles memory use
+ jax.config.update('jax_enable_x64', False) # Default
+
+4. **Process in chunks:**
+
+ .. code-block:: python
+
+ # Split large population
+ for i in range(0, n_neurons, chunk_size):
+ chunk_output = process_chunk(neurons[i:i+chunk_size])
+
+Issue: Slow First Run
+~~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** First iteration very slow
+
+**Explanation:** JIT compilation happens on first call
+
+**Solution:** Warm up before timing
+
+.. code-block:: python
+
+ @brainstate.compile.jit
+ def step(net, inp):
+ return net(inp)
+
+ # Warmup (compile)
+ _ = step(net, dummy_input)
+
+ # Now fast
+ for real_input in data:
+ output = step(net, real_input)
+
+Issue: GPU Not Being Used
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** Computation on CPU despite GPU available
+
+**Check:**
+
+.. code-block:: python
+
+ import jax
+ print("Devices:", jax.devices())
+ print("Default backend:", jax.default_backend())
+
+ # Should show GPU
+
+**Solutions:**
+
+1. Check installation: `pip list | grep jax`
+2. Reinstall with GPU support: `pip install brainpy[cuda12]`
+3. Check CUDA installation: `nvidia-smi`
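+
+You can also pin the platform explicitly through an environment variable
+(``JAX_PLATFORMS`` is read by recent JAX releases; it must be set before JAX
+is imported):
+
+.. code-block:: python
+
+    import os
+    os.environ['JAX_PLATFORMS'] = 'cuda'  # or 'cpu' to force the CPU backend
+
+    import jax
+    print(jax.default_backend())  # should now report the CUDA backend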
+
+Issue: Version Mismatch
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** `RuntimeError: CUDA error`
+
+**Check versions:**
+
+.. code-block:: bash
+
+ # Check CUDA version
+ nvcc --version
+
+ # Check JAX version
+ python -c "import jax; print(jax.__version__)"
+
+**Solution:** Match JAX CUDA version with system CUDA
+
+.. code-block:: bash
+
+ # For CUDA 12.x
+ pip install brainpy[cuda12]
+
+ # For CUDA 11.x
+ pip install brainpy[cuda11]
+
+Best Practices
+--------------
+
+✅ **Use JIT compilation** - Essential for GPU performance
+
+✅ **Batch operations** - Process multiple trials in parallel
+
+✅ **Keep data on device** - Avoid CPU-GPU transfers
+
+✅ **Use sparse connectivity** - For biological-scale networks
+
+✅ **Profile before optimizing** - Identify real bottlenecks
+
+✅ **Warm up JIT** - Compile before timing
+
+✅ **Monitor memory** - Estimate before running large models
+
+✅ **Static shapes** - Avoid dynamic shapes (causes recompilation)
+
+❌ **Don't use GPU for small problems** - Overhead dominates
+
+❌ **Don't transfer data unnecessarily** - Keep on GPU
+
+❌ **Don't use dense connectivity for large networks** - Memory explosion
+
+Example: Complete GPU Workflow
+-------------------------------
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+ import braintools
+ import jax
+ import time
+
+ # 1. Check GPU availability
+ print("Devices:", jax.devices())
+ assert jax.default_backend() == 'gpu', "GPU not available!"
+
+ # 2. Create large network
+ class LargeNetwork(brainstate.nn.Module):
+ def __init__(self, n_exc=8000, n_inh=2000):
+ super().__init__()
+
+ self.E = bp.LIF(n_exc, V_rest=-65*u.mV, V_th=-50*u.mV, tau=15*u.ms)
+ self.I = bp.LIF(n_inh, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ # Sparse connectivity (GPU efficient)
+ self.E2E = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=0.02, weight=0.5*u.mS),
+ syn=bp.Expon.desc(n_exc, tau=5*u.ms),
+ out=bp.CUBA.desc(),
+ post=self.E
+ )
+ # ... more projections
+
+ def update(self, inp_e, inp_i):
+ spk_e = self.E.get_spike()
+ spk_i = self.I.get_spike()
+
+ self.E2E(spk_e)
+ # ... update all projections
+
+ self.E(inp_e)
+ self.I(inp_i)
+
+ return spk_e, spk_i
+
+ # 3. Initialize with large batch
+ net = LargeNetwork()
+ batch_size = 64 # Process 64 trials in parallel
+ brainstate.nn.init_all_states(net, batch_size=batch_size)
+
+ # 4. JIT compile
+ @brainstate.compile.jit
+ def simulate_step(net, inp_e, inp_i):
+ return net(inp_e, inp_i)
+
+ # 5. Warmup (compilation)
+ print("Compiling...")
+ inp_e = brainstate.random.rand(batch_size, 8000) * 1.0 * u.nA
+ inp_i = brainstate.random.rand(batch_size, 2000) * 1.0 * u.nA
+ _ = simulate_step(net, inp_e, inp_i)
+ print("โ
Compilation complete")
+
+ # 6. Run simulation
+ print("Running simulation...")
+ n_steps = 1000
+
+ start = time.time()
+ for _ in range(n_steps):
+ inp_e = brainstate.random.rand(batch_size, 8000) * 1.0 * u.nA
+ inp_i = brainstate.random.rand(batch_size, 2000) * 1.0 * u.nA
+ spk_e, spk_i = simulate_step(net, inp_e, inp_i)
+
+ elapsed = time.time() - start
+
+ print(f"โ
Simulation complete")
+ print(f" Time: {elapsed:.2f}s")
+ print(f" Throughput: {n_steps/elapsed:.1f} steps/s")
+ print(f" Speed: {batch_size * n_steps / elapsed:.1f} trials/s")
+
+Summary
+-------
+
+**Key Points:**
+
+- BrainPy automatically uses GPU/TPU when available
+- JIT compilation is essential for GPU performance
+- Batch operations maximize GPU utilization
+- Keep data on device to avoid transfer overhead
+- Use sparse connectivity for large networks
+- GPU beneficial for networks > 1,000 neurons
+
+**Quick Reference:**
+
+.. code-block:: python
+
+ # Check device
+ import jax
+ print(jax.devices())
+
+ # JIT for GPU
+ @brainstate.compile.jit
+ def step(net, inp):
+ return net(inp)
+
+ # Batch for GPU
+ brainstate.nn.init_all_states(net, batch_size=64)
+
+ # Sparse for memory
+ comm = brainstate.nn.EventFixedProb(..., prob=0.02)
+
+See Also
+--------
+
+- :doc:`../tutorials/advanced/07-large-scale-simulations` - Optimization techniques
+- :doc:`performance-optimization` - General performance tips
+- JAX documentation: https://jax.readthedocs.io/
diff --git a/docs_version3/how-to-guides/index.rst b/docs_version3/how-to-guides/index.rst
new file mode 100644
index 00000000..7526c311
--- /dev/null
+++ b/docs_version3/how-to-guides/index.rst
@@ -0,0 +1,41 @@
+How-to Guides
+=============
+
+Practical guides for common tasks in BrainPy.
+
+.. grid:: 1 2 2 2
+
+ .. grid-item-card:: :material-regular:`save;2em` Save and Load Models
+ :link: save-load-models.html
+
+ Learn how to checkpoint and restore your trained models
+
+ .. grid-item-card:: :material-regular:`speed;2em` GPU/TPU Usage
+ :link: gpu-tpu-usage.html
+
+ Accelerate simulations with GPU and TPU
+
+ .. grid-item-card:: :material-regular:`bug_report;2em` Debugging Networks
+ :link: debugging-networks.html
+
+ Troubleshoot common issues and debug effectively
+
+ .. grid-item-card:: :material-regular:`tune;2em` Performance Optimization
+ :link: performance-optimization.html
+
+ Make your simulations run faster
+
+ .. grid-item-card:: :material-regular:`extension;2em` Custom Components
+ :link: custom-components.html
+
+ Create custom neurons, synapses, and learning rules
+
+.. toctree::
+ :hidden:
+ :maxdepth: 1
+
+ save-load-models.rst
+ gpu-tpu-usage.rst
+ debugging-networks.rst
+ performance-optimization.rst
+ custom-components.rst
diff --git a/docs_version3/how-to-guides/performance-optimization.rst b/docs_version3/how-to-guides/performance-optimization.rst
new file mode 100644
index 00000000..0c1d4b8d
--- /dev/null
+++ b/docs_version3/how-to-guides/performance-optimization.rst
@@ -0,0 +1,340 @@
+How to Optimize Performance
+============================
+
+This guide shows you how to make your BrainPy simulations run faster.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Quick Wins
+----------
+
+**Top 5 optimizations (80% of speedup):**
+
+1. ✅ **Use JIT compilation** - 10-100× speedup
+2. ✅ **Use sparse connectivity** - 10-100× memory reduction
+3. ✅ **Batch operations** - 2-10× speedup on GPU
+4. ✅ **Use GPU/TPU** - 10-100× speedup for large networks
+5. ✅ **Minimize Python loops** - Use JAX operations instead
+
+JIT Compilation
+---------------
+
+**Essential for performance!**
+
+.. code-block:: python
+
+ import brainstate
+
+ # Slow (no JIT)
+ def slow_step(net, inp):
+ return net(inp)
+
+ # Fast (with JIT)
+ @brainstate.compile.jit
+ def fast_step(net, inp):
+ return net(inp)
+
+ # Warmup (compilation)
+ _ = fast_step(net, inp)
+
+ # 10-100× faster than slow_step
+ output = fast_step(net, inp)
+
+**Rules for JIT:**
+
+- Static shapes (no dynamic array sizes)
+- Pure functions (no side effects)
+- Avoid Python loops over data
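+
+As an illustration of the static-shape rule, padding variable-length data to
+a fixed size avoids one recompilation per new length (a sketch with an
+arbitrary padding length):
+
+.. code-block:: python
+
+    import jax
+    import jax.numpy as jnp
+
+    MAX_LEN = 128
+
+    @jax.jit
+    def masked_sum(x, mask):
+        # x always has shape (MAX_LEN,), so this compiles exactly once
+        return jnp.sum(x * mask)
+
+    data = jnp.arange(37, dtype=jnp.float32)
+    x = jnp.zeros(MAX_LEN).at[:37].set(data)    # pad to the fixed length
+    mask = jnp.zeros(MAX_LEN).at[:37].set(1.0)  # mark the valid entries
+    print(masked_sum(x, mask))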
+
+Sparse Connectivity
+-------------------
+
+**Biological networks are sparse (~1-10% connectivity)**
+
+.. code-block:: python
+
+ # Dense: 10,000 × 10,000 = 100M connections (400MB)
+ comm_dense = brainstate.nn.Linear(10000, 10000)
+
+ # Sparse: 10,000 × 10,000 × 0.01 = 1M connections (4MB)
+ comm_sparse = brainstate.nn.EventFixedProb(
+ 10000, 10000,
+ prob=0.01, # 1% connectivity
+ weight=0.5*u.mS
+ )
+
+**Memory savings:** 100× for 1% connectivity
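+
+The arithmetic behind that figure, assuming 4-byte float32 weights:
+
+.. code-block:: python
+
+    n_pre, n_post, prob = 10_000, 10_000, 0.01
+    dense_mb = n_pre * n_post * 4 / 1024**2          # a weight per pair
+    sparse_mb = n_pre * n_post * prob * 4 / 1024**2  # only realized synapses
+    print(f"dense: {dense_mb:.0f} MB, sparse: {sparse_mb:.0f} MB")
+    # dense: 381 MB, sparse: 4 MB  ->  ~100x less memory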
+
+Batching
+--------
+
+**Process multiple trials in parallel:**
+
+.. code-block:: python
+
+ # Sequential: 10 trials one by one
+ for trial in range(10):
+ brainstate.nn.init_all_states(net)
+ run_trial(net)
+
+ # Parallel: 10 trials simultaneously
+ brainstate.nn.init_all_states(net, batch_size=10)
+ run_batched(net) # 5-10× faster on GPU
+
+**Optimal batch sizes:**
+
+- CPU: 1-16
+- GPU: 32-256
+- TPU: 128-512
+
+GPU Usage
+---------
+
+**Automatic when available:**
+
+.. code-block:: python
+
+ import jax
+ print(jax.devices()) # Check for GPU
+
+ # BrainPy automatically uses GPU
+ net = bp.LIF(10000, ...)
+ # Runs on GPU if available
+
+**See:** :doc:`gpu-tpu-usage` for details
+
+Avoid Python Loops
+------------------
+
+**Replace Python loops with JAX operations:**
+
+.. code-block:: python
+
+ # SLOW: Python loop
+ result = []
+ for i in range(1000):
+ result.append(net(inp))
+
+ # FAST: JAX loop
+ def body_fun(i):
+ return net(inp)
+
+ results = brainstate.transform.for_loop(body_fun, jnp.arange(1000))
+
+Use Appropriate Precision
+--------------------------
+
+**Float32 is usually sufficient:**
+
+.. code-block:: python
+
+ # Default (float32) - fast
+ weights = jnp.ones((1000, 1000)) # 4 bytes/element
+
+ # Float64 - 2× slower, 2× memory
+ weights = jnp.ones((1000, 1000), dtype=jnp.float64) # 8 bytes/element
+
+Minimize State Storage
+----------------------
+
+**Don't accumulate history:**
+
+.. code-block:: python
+
+ # BAD: Stores all history in Python list
+ history = []
+ for t in range(10000):
+ output = net(inp)
+ history.append(output) # Memory leak!
+
+ # GOOD: Process on the fly
+ for t in range(10000):
+ output = net(inp)
+ metrics = compute_metrics(output) # Don't store raw data
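+
+If you do need the full trace, let the compiled loop stack it on device
+rather than growing a Python list (a sketch using the same ``for_loop``
+shown earlier):
+
+.. code-block:: python
+
+    import jax.numpy as jnp
+
+    def body_fun(t):
+        return net(inp)  # returned values are stacked across iterations
+
+    outputs = brainstate.transform.for_loop(body_fun, jnp.arange(10000))
+    # outputs is a single stacked array that stays on the device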
+
+Optimize Network Architecture
+------------------------------
+
+**1. Use simpler neuron models when possible:**
+
+.. code-block:: python
+
+ # Complex (slow but realistic)
+ neuron = bp.HH(1000, ...) # Hodgkin-Huxley
+
+ # Simple (fast)
+ neuron = bp.LIF(1000, ...) # Leaky Integrate-and-Fire
+
+**2. Use CUBA instead of COBA when possible:**
+
+.. code-block:: python
+
+ # Slower (conductance-based)
+ out = bp.COBA.desc(E=0*u.mV)
+
+ # Faster (current-based)
+ out = bp.CUBA.desc()
+
+**3. Reduce connectivity:**
+
+.. code-block:: python
+
+ # Dense
+ prob = 0.1 # 10% connectivity
+
+ # Sparse
+ prob = 0.02 # 2% connectivity (5× fewer connections)
+
+Profile Before Optimizing
+--------------------------
+
+**Identify actual bottlenecks:**
+
+.. code-block:: python
+
+ import time
+
+ # Time different components
+ start = time.time()
+ for _ in range(100):
+ net(inp)
+ print(f"Network update: {time.time() - start:.2f}s")
+
+ start = time.time()
+ for _ in range(100):
+ output = process_output(net.get_spike())
+ print(f"Output processing: {time.time() - start:.2f}s")
+
+**Don't optimize blindly - measure first!**
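+
+One caveat when timing JAX code: operations are dispatched asynchronously,
+so the Python loop can return before the device finishes. Block on the last
+result before reading the clock (sketch):
+
+.. code-block:: python
+
+    import time
+    import jax
+
+    start = time.time()
+    for _ in range(100):
+        output = net(inp)
+    jax.block_until_ready(output)  # wait for pending device work
+    print(f"Network update: {time.time() - start:.2f}s")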
+
+Performance Checklist
+---------------------
+
+**For maximum performance:**
+
+.. code-block:: text
+
+ ✅ JIT compiled (@brainstate.compile.jit)
+ ✅ Sparse connectivity (EventFixedProb with prob < 0.1)
+ ✅ Batched (batch_size ≥ 32 on GPU)
+ ✅ GPU enabled (check jax.devices())
+ ✅ Static shapes (no dynamic array sizes)
+ ✅ Minimal history storage
+ ✅ Appropriate neuron models (LIF vs HH)
+ ✅ Float32 precision
+
+Common Bottlenecks
+------------------
+
+**Issue 1: First run very slow**
+ → JIT compilation happens on first call (warmup)
+
+**Issue 2: CPU-GPU transfers**
+ → Keep data on GPU between operations
+
+**Issue 3: Small batch sizes**
+ → Increase batch_size for better GPU utilization
+
+**Issue 4: Python loops**
+ → Replace with JAX operations (for_loop, vmap)
+
+**Issue 5: Dense connectivity**
+ → Use sparse (EventFixedProb) for large networks
+
+Complete Optimization Example
+------------------------------
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+ import jax
+
+ # Optimized network
+ class OptimizedNetwork(brainstate.nn.Module):
+ def __init__(self, n_neurons=10000):
+ super().__init__()
+
+ # Simple neuron model
+ self.neurons = bp.LIF(n_neurons, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+
+ # Sparse connectivity
+ self.recurrent = bp.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(
+ n_neurons, n_neurons,
+ prob=0.01, # Sparse!
+ weight=0.5*u.mS
+ ),
+ syn=bp.Expon.desc(n_neurons, tau=5*u.ms),
+ out=bp.CUBA.desc(), # Simple output
+ post=self.neurons
+ )
+
+ def update(self, inp):
+ spk = self.neurons.get_spike()
+ self.recurrent(spk)
+ self.neurons(inp)
+ return spk
+
+ # Initialize
+ net = OptimizedNetwork()
+ brainstate.nn.init_all_states(net, batch_size=64) # Batched
+
+ # JIT compile
+ @brainstate.compile.jit
+ def simulate_step(net, inp):
+ return net(inp)
+
+ # Warmup
+ inp = brainstate.random.rand(64, 10000) * 2.0 * u.nA
+ _ = simulate_step(net, inp)
+
+ # Fast simulation
+ import time
+ start = time.time()
+ for _ in range(1000):
+ output = simulate_step(net, inp)
+ elapsed = time.time() - start
+
+ print(f"Optimized: {1000/elapsed:.1f} steps/s")
+ print(f"Throughput: {64*1000/elapsed:.1f} trials/s")
+
+Benchmark Results
+-----------------
+
+**Typical speedups from optimization:**
+
+.. list-table::
+ :header-rows: 1
+
+ * - Optimization
+ - Speedup
+ - Cumulative
+ * - Baseline (Python loops, dense)
+ - 1×
+ - 1×
+ * - + JIT compilation
+ - 10-50×
+ - 10-50×
+ * - + Sparse connectivity
+ - 2-10×
+ - 20-500×
+ * - + GPU
+ - 5-20×
+ - 100-10,000×
+ * - + Batching
+ - 2-5×
+ - 200-50,000×
+
+**Real example:** 10,000 neuron network
+
+- Baseline (CPU, no JIT): 0.5 steps/s
+- Optimized (GPU, JIT, sparse, batched): 5,000 steps/s
+- **Total speedup: 10,000×**
+
+See Also
+--------
+
+- :doc:`../tutorials/advanced/07-large-scale-simulations`
+- :doc:`gpu-tpu-usage`
+- :doc:`debugging-networks`
diff --git a/docs_version3/how-to-guides/save-load-models.rst b/docs_version3/how-to-guides/save-load-models.rst
new file mode 100644
index 00000000..5666376c
--- /dev/null
+++ b/docs_version3/how-to-guides/save-load-models.rst
@@ -0,0 +1,890 @@
+How to Save and Load Models
+============================
+
+This guide shows you how to save and load BrainPy models for checkpointing, resuming training, and deployment.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+Quick Start
+-----------
+
+**Save a trained model:**
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import pickle
+
+ # After training...
+ state_dict = {
+ 'params': net.states(brainstate.ParamState),
+ 'epoch': current_epoch,
+ }
+
+ with open('model.pkl', 'wb') as f:
+ pickle.dump(state_dict, f)
+
+**Load a model:**
+
+.. code-block:: python
+
+ # Create model with same architecture
+ net = MyNetwork()
+ brainstate.nn.init_all_states(net)
+
+ # Load saved state
+ with open('model.pkl', 'rb') as f:
+ state_dict = pickle.load(f)
+
+ # Restore parameters
+ for name, state in state_dict['params'].items():
+ net.states(brainstate.ParamState)[name].value = state.value
+
+Understanding What to Save
+---------------------------
+
+State Types
+~~~~~~~~~~~
+
+BrainPy has three state types with different persistence requirements:
+
+**ParamState (Always save)**
+ - Learnable weights and biases
+ - Required to restore trained model
+ - Examples: synaptic weights, neural biases
+
+**LongTermState (Usually save)**
+ - Persistent statistics and counters
+ - Not updated by gradients
+ - Examples: running averages, spike counts
+
+**ShortTermState (Never save)**
+ - Temporary dynamics that reset each trial
+ - Will be re-initialized anyway
+ - Examples: membrane potentials, synaptic conductances
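+
+Each group can be collected separately with the same ``states()`` filter
+used throughout this guide:
+
+.. code-block:: python
+
+    params = net.states(brainstate.ParamState)          # always save
+    long_term = net.states(brainstate.LongTermState)    # usually save
+    short_term = net.states(brainstate.ShortTermState)  # skip; re-initialized
+
+    print(f"{len(params)} params, {len(long_term)} long-term, "
+          f"{len(short_term)} short-term states")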
+
+Recommended Approach
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def save_checkpoint(net, optimizer, epoch, filepath):
+ """Save model checkpoint."""
+ state_dict = {
+ # Required: model parameters
+ 'params': net.states(brainstate.ParamState),
+
+ # Optional but recommended: long-term states
+ 'long_term': net.states(brainstate.LongTermState),
+
+ # Training metadata
+ 'epoch': epoch,
+ 'optimizer_state': optimizer.state_dict(), # If continuing training
+
+ # Model configuration (helpful for loading)
+ 'config': {
+ 'n_input': net.n_input,
+ 'n_hidden': net.n_hidden,
+ 'n_output': net.n_output,
+ # ... other hyperparameters
+ }
+ }
+
+ with open(filepath, 'wb') as f:
+ pickle.dump(state_dict, f)
+
+ print(f"โ
Saved checkpoint to {filepath}")
+
+Basic Save/Load
+---------------
+
+Using Pickle (Simple)
+~~~~~~~~~~~~~~~~~~~~~
+
+**Advantages:**
+
+- Simple and straightforward
+- Works with any Python object
+- Good for quick prototyping
+
+**Disadvantages:**
+
+- Python-specific format
+- Version compatibility issues
+- Not human-readable
+
+.. code-block:: python
+
+ import pickle
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+
+ # Define your model
+ class SimpleNet(brainstate.nn.Module):
+ def __init__(self, n_neurons=100):
+ super().__init__()
+ self.lif = bp.LIF(n_neurons, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
+ self.fc = brainstate.nn.Linear(n_neurons, 10)
+
+ def update(self, x):
+ self.lif(x)
+ return self.fc(self.lif.get_spike())
+
+ # Train model
+ net = SimpleNet()
+ brainstate.nn.init_all_states(net)
+ # ... training code ...
+
+ # Save
+ params = net.states(brainstate.ParamState)
+ with open('simple_net.pkl', 'wb') as f:
+ pickle.dump(params, f)
+
+ # Load
+ net_new = SimpleNet()
+ brainstate.nn.init_all_states(net_new)
+
+ with open('simple_net.pkl', 'rb') as f:
+ loaded_params = pickle.load(f)
+
+ # Restore parameters
+ for name, state in loaded_params.items():
+ net_new.states(brainstate.ParamState)[name].value = state.value
+
+Using NumPy (Arrays Only)
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Advantages:**
+
+- Language-agnostic
+- Efficient storage
+- Widely supported
+
+**Disadvantages:**
+
+- Only saves arrays (not structure)
+- Need to manually track parameter names
+
+.. code-block:: python
+
+ import numpy as np
+ import jax.numpy as jnp
+
+ # Save parameters as .npz
+ params = net.states(brainstate.ParamState)
+ param_dict = {name: np.array(state.value) for name, state in params.items()}
+ np.savez('model_params.npz', **param_dict)
+
+ # Load parameters
+ loaded = np.load('model_params.npz')
+ for name, array in loaded.items():
+ net.states(brainstate.ParamState)[name].value = jnp.array(array)
+
+Checkpointing During Training
+------------------------------
+
+Periodic Checkpoints
+~~~~~~~~~~~~~~~~~~~~
+
+Save at regular intervals during training.
+
+.. code-block:: python
+
+ import braintools
+
+ # Training setup
+ net = MyNetwork()
+ optimizer = braintools.optim.Adam(lr=1e-3)
+ optimizer.register_trainable_weights(net.states(brainstate.ParamState))
+
+ save_interval = 5 # Save every 5 epochs
+ checkpoint_dir = './checkpoints'
+ import os
+ os.makedirs(checkpoint_dir, exist_ok=True)
+
+ # Training loop
+ for epoch in range(num_epochs):
+ # Training step
+ for batch in train_loader:
+ loss = train_step(net, optimizer, batch)
+
+ # Periodic save
+ if (epoch + 1) % save_interval == 0:
+ checkpoint_path = f'{checkpoint_dir}/epoch_{epoch+1}.pkl'
+ save_checkpoint(net, optimizer, epoch, checkpoint_path)
+
+ print(f"Epoch {epoch+1}: Loss={loss:.4f}, Checkpoint saved")
+
+Best Model Checkpoint
+~~~~~~~~~~~~~~~~~~~~~
+
+Save only when validation performance improves.
+
+.. code-block:: python
+
+ best_val_loss = float('inf')
+ best_model_path = 'best_model.pkl'
+
+ for epoch in range(num_epochs):
+ # Training
+ train_loss = train_epoch(net, optimizer, train_loader)
+
+ # Validation
+ val_loss = validate(net, val_loader)
+
+ # Save if best
+ if val_loss < best_val_loss:
+ best_val_loss = val_loss
+ save_checkpoint(net, optimizer, epoch, best_model_path)
+ print(f"โ
New best model! Val loss: {val_loss:.4f}")
+
+ print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}")
+
+Resuming Training
+~~~~~~~~~~~~~~~~~
+
+Continue training from a checkpoint.
+
+.. code-block:: python
+
+ def load_checkpoint(filepath, net, optimizer=None):
+ """Load checkpoint and restore state."""
+ with open(filepath, 'rb') as f:
+ state_dict = pickle.load(f)
+
+ # Restore model parameters
+ params = net.states(brainstate.ParamState)
+ for name, state in state_dict['params'].items():
+ if name in params:
+ params[name].value = state.value
+
+ # Restore long-term states
+ if 'long_term' in state_dict:
+ long_term = net.states(brainstate.LongTermState)
+ for name, state in state_dict['long_term'].items():
+ if name in long_term:
+ long_term[name].value = state.value
+
+ # Restore optimizer state
+ if optimizer is not None and 'optimizer_state' in state_dict:
+ optimizer.load_state_dict(state_dict['optimizer_state'])
+
+ start_epoch = state_dict.get('epoch', 0) + 1
+ return start_epoch
+
+ # Resume training
+ net = MyNetwork()
+ brainstate.nn.init_all_states(net)
+ optimizer = braintools.optim.Adam(lr=1e-3)
+ optimizer.register_trainable_weights(net.states(brainstate.ParamState))
+
+ # Load checkpoint
+ start_epoch = load_checkpoint('checkpoint_epoch_50.pkl', net, optimizer)
+
+ # Continue training from where we left off
+ for epoch in range(start_epoch, num_epochs):
+ train_step(net, optimizer, train_loader)
+
+Advanced Saving Strategies
+---------------------------
+
+Versioned Checkpoints
+~~~~~~~~~~~~~~~~~~~~~
+
+Keep multiple checkpoints without overwriting.
+
+.. code-block:: python
+
+ from datetime import datetime
+
+ def save_versioned_checkpoint(net, epoch, base_dir='checkpoints'):
+ """Save checkpoint with timestamp."""
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+ filename = f'model_epoch{epoch}_{timestamp}.pkl'
+ filepath = os.path.join(base_dir, filename)
+
+ state_dict = {
+ 'params': net.states(brainstate.ParamState),
+ 'epoch': epoch,
+ 'timestamp': timestamp,
+ }
+
+ with open(filepath, 'wb') as f:
+ pickle.dump(state_dict, f)
+
+ return filepath
+
+Keep Last N Checkpoints
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+Automatically delete old checkpoints to save disk space.
+
+.. code-block:: python
+
+ import glob
+
+ def save_with_cleanup(net, epoch, checkpoint_dir='checkpoints', keep_last=5):
+ """Save checkpoint and keep only last N."""
+
+ # Save new checkpoint
+ filepath = f'{checkpoint_dir}/epoch_{epoch:04d}.pkl'
+ save_checkpoint(net, None, epoch, filepath)
+
+ # Get all checkpoints
+ checkpoints = sorted(glob.glob(f'{checkpoint_dir}/epoch_*.pkl'))
+
+ # Delete old ones
+ if len(checkpoints) > keep_last:
+ for old_checkpoint in checkpoints[:-keep_last]:
+ os.remove(old_checkpoint)
+ print(f"Removed old checkpoint: {old_checkpoint}")
+
+Conditional Saving
+~~~~~~~~~~~~~~~~~~
+
+Save based on custom criteria.
+
+.. code-block:: python
+
+ class CheckpointManager:
+ """Manage model checkpoints with custom logic."""
+
+ def __init__(self, checkpoint_dir, keep_best=True, keep_last=3):
+ self.checkpoint_dir = checkpoint_dir
+ self.keep_best = keep_best
+ self.keep_last = keep_last
+ self.best_metric = float('inf')
+ os.makedirs(checkpoint_dir, exist_ok=True)
+
+ def save(self, net, epoch, metric, is_better=None):
+ """Save checkpoint based on metric.
+
+ Args:
+ net: Network to save
+ epoch: Current epoch
+ metric: Validation metric
+ is_better: Function to compare metrics (default: lower is better)
+ """
+ if is_better is None:
+ is_better = lambda new, old: new < old
+
+ # Save if best
+ if self.keep_best and is_better(metric, self.best_metric):
+ self.best_metric = metric
+ filepath = f'{self.checkpoint_dir}/best_model.pkl'
+ save_checkpoint(net, None, epoch, filepath)
+ print(f"๐พ Saved best model (metric: {metric:.4f})")
+
+ # Save periodic
+ filepath = f'{self.checkpoint_dir}/epoch_{epoch:04d}.pkl'
+ save_checkpoint(net, None, epoch, filepath)
+
+ # Cleanup old checkpoints
+ self._cleanup()
+
+ def _cleanup(self):
+ """Keep only last N checkpoints."""
+ checkpoints = sorted(glob.glob(f'{self.checkpoint_dir}/epoch_*.pkl'))
+ if len(checkpoints) > self.keep_last:
+ for old in checkpoints[:-self.keep_last]:
+ os.remove(old)
+
+ # Usage
+ manager = CheckpointManager('./checkpoints', keep_best=True, keep_last=3)
+
+ for epoch in range(num_epochs):
+ train_loss = train_epoch(net, optimizer, train_loader)
+ val_loss = validate(net, val_loader)
+
+ manager.save(net, epoch, metric=val_loss)
+
+Model Export for Deployment
+----------------------------
+
+Minimal Model File
+~~~~~~~~~~~~~~~~~~
+
+Save only what's needed for inference.
+
+.. code-block:: python
+
+ def export_for_inference(net, filepath, metadata=None):
+ """Export minimal model for inference."""
+
+ export_dict = {
+ 'params': net.states(brainstate.ParamState),
+ 'config': {
+ # Only architecture info, no training state
+ 'model_type': net.__class__.__name__,
+ # ... architecture hyperparameters
+ }
+ }
+
+ if metadata:
+ export_dict['metadata'] = metadata
+
+ with open(filepath, 'wb') as f:
+ pickle.dump(export_dict, f)
+
+ # Report size
+ size_mb = os.path.getsize(filepath) / (1024 * 1024)
+ print(f"๐ฆ Exported model: {size_mb:.2f} MB")
+
+ # Export trained model
+ export_for_inference(
+ net,
+ 'deployed_model.pkl',
+ metadata={
+ 'description': 'LIF network for digit classification',
+ 'accuracy': 0.95,
+ 'date': datetime.now().isoformat()
+ }
+ )
+
+Loading for Inference
+~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def load_for_inference(filepath, model_class):
+ """Load model for inference only."""
+
+ with open(filepath, 'rb') as f:
+ export_dict = pickle.load(f)
+
+ # Create model from config
+ config = export_dict['config']
+ net = model_class(**config) # Must match saved config
+ brainstate.nn.init_all_states(net)
+
+ # Load parameters
+ params = net.states(brainstate.ParamState)
+ for name, state in export_dict['params'].items():
+ params[name].value = state.value
+
+ return net, export_dict.get('metadata')
+
+ # Load and use
+ net, metadata = load_for_inference('deployed_model.pkl', MyNetwork)
+ print(f"Loaded model: {metadata['description']}")
+
+ # Run inference
+ brainstate.nn.init_all_states(net)
+ output = net(input_data)
+
+Saving Model Architecture
+--------------------------
+
+Configuration-Based
+~~~~~~~~~~~~~~~~~~~
+
+Save hyperparameters to recreate model.
+
+.. code-block:: python
+
+ class ConfigurableNetwork(brainstate.nn.Module):
+ """Network that can be created from config."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ # Build from config
+ self.input_layer = brainstate.nn.Linear(
+ config['n_input'],
+ config['n_hidden']
+ )
+ self.hidden = bp.LIF(
+ config['n_hidden'],
+ V_rest=config['V_rest'],
+ V_th=config['V_th'],
+ tau=config['tau']
+ )
+ # ... more layers
+
+ @classmethod
+ def from_config(cls, config):
+ """Create model from config dict."""
+ return cls(config)
+
+ def get_config(self):
+ """Get configuration dict."""
+ return self.config.copy()
+
+ # Save with config
+ config = {
+ 'n_input': 784,
+ 'n_hidden': 128,
+ 'n_output': 10,
+ 'V_rest': -65.0,
+ 'V_th': -50.0,
+ 'tau': 10.0
+ }
+
+ net = ConfigurableNetwork(config)
+ # ... train ...
+
+ # Save both params and config
+ checkpoint = {
+ 'config': net.get_config(),
+ 'params': net.states(brainstate.ParamState)
+ }
+
+ with open('model_with_config.pkl', 'wb') as f:
+ pickle.dump(checkpoint, f)
+
+ # Load from config
+ with open('model_with_config.pkl', 'rb') as f:
+ checkpoint = pickle.load(f)
+
+ net_new = ConfigurableNetwork.from_config(checkpoint['config'])
+ brainstate.nn.init_all_states(net_new)
+
+ for name, state in checkpoint['params'].items():
+ net_new.states(brainstate.ParamState)[name].value = state.value
+
+Handling Model Updates
+----------------------
+
+Version Compatibility
+~~~~~~~~~~~~~~~~~~~~~
+
+Handle changes in model architecture.
+
+.. code-block:: python
+
+ VERSION = '2.0'
+
+ def save_with_version(net, filepath):
+ """Save model with version info."""
+ checkpoint = {
+ 'version': VERSION,
+ 'params': net.states(brainstate.ParamState),
+ 'config': net.get_config()
+ }
+
+ with open(filepath, 'wb') as f:
+ pickle.dump(checkpoint, f)
+
+ def load_with_migration(filepath, model_class):
+ """Load model with version migration."""
+ with open(filepath, 'rb') as f:
+ checkpoint = pickle.load(f)
+
+ version = checkpoint.get('version', '1.0')
+
+ # Migrate old versions
+ if version == '1.0':
+ print("Migrating from v1.0 to v2.0...")
+ checkpoint = migrate_v1_to_v2(checkpoint)
+
+ # Create model
+ net = model_class.from_config(checkpoint['config'])
+ brainstate.nn.init_all_states(net)
+
+ # Load parameters
+ for name, state in checkpoint['params'].items():
+ if name in net.states(brainstate.ParamState):
+ net.states(brainstate.ParamState)[name].value = state.value
+ else:
+ print(f"โ ๏ธ Skipping unknown parameter: {name}")
+
+ return net
+
+ def migrate_v1_to_v2(checkpoint):
+ """Migrate checkpoint from v1.0 to v2.0."""
+ # Example: rename parameter
+ if 'old_param_name' in checkpoint['params']:
+ checkpoint['params']['new_param_name'] = checkpoint['params'].pop('old_param_name')
+
+ checkpoint['version'] = '2.0'
+ return checkpoint
+
+Partial Loading
+~~~~~~~~~~~~~~~
+
+Load only some parameters (e.g., for transfer learning).
+
+.. code-block:: python
+
+ def load_partial(filepath, net, param_filter=None):
+ """Load only specified parameters.
+
+ Args:
+ filepath: Checkpoint file
+ net: Network to load into
+ param_filter: Function that takes param name and returns True to load
+ """
+ with open(filepath, 'rb') as f:
+ checkpoint = pickle.load(f)
+
+ if param_filter is None:
+ param_filter = lambda name: True
+
+ loaded_count = 0
+ skipped_count = 0
+
+ for name, state in checkpoint['params'].items():
+ if param_filter(name):
+ if name in net.states(brainstate.ParamState):
+ net.states(brainstate.ParamState)[name].value = state.value
+ loaded_count += 1
+ else:
+ print(f"โ ๏ธ Parameter not found in model: {name}")
+ skipped_count += 1
+ else:
+ skipped_count += 1
+
+ print(f"โ
Loaded {loaded_count} parameters, skipped {skipped_count}")
+
+ # Example: Load only encoder parameters
+ load_partial(
+ 'pretrained.pkl',
+ net,
+ param_filter=lambda name: name.startswith('encoder.')
+ )
+
+Common Patterns
+---------------
+
+Pattern 1: Training Session Manager
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class TrainingSession:
+ """Manage full training session with checkpointing."""
+
+ def __init__(self, net, optimizer, checkpoint_dir='./checkpoints'):
+ self.net = net
+ self.optimizer = optimizer
+ self.checkpoint_dir = checkpoint_dir
+ self.epoch = 0
+ self.best_metric = float('inf')
+
+ os.makedirs(checkpoint_dir, exist_ok=True)
+
+ def save(self, metric=None):
+ """Save current state."""
+ checkpoint = {
+ 'params': self.net.states(brainstate.ParamState),
+ 'optimizer': self.optimizer.state_dict(),
+ 'epoch': self.epoch,
+ 'best_metric': self.best_metric
+ }
+
+ # Regular checkpoint
+ filepath = f'{self.checkpoint_dir}/checkpoint_latest.pkl'
+ with open(filepath, 'wb') as f:
+ pickle.dump(checkpoint, f)
+
+ # Best checkpoint
+ if metric is not None and metric < self.best_metric:
+ self.best_metric = metric
+ best_path = f'{self.checkpoint_dir}/checkpoint_best.pkl'
+ with open(best_path, 'wb') as f:
+ pickle.dump(checkpoint, f)
+
+ def restore(self, filepath=None):
+ """Restore from checkpoint."""
+ if filepath is None:
+ filepath = f'{self.checkpoint_dir}/checkpoint_latest.pkl'
+
+ with open(filepath, 'rb') as f:
+ checkpoint = pickle.load(f)
+
+ # Restore state
+ for name, state in checkpoint['params'].items():
+ self.net.states(brainstate.ParamState)[name].value = state.value
+
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ self.epoch = checkpoint['epoch']
+ self.best_metric = checkpoint['best_metric']
+
+ print(f"โ
Restored from epoch {self.epoch}")
+
+Pattern 2: Model Zoo
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ class ModelZoo:
+ """Collection of pre-trained models."""
+
+ def __init__(self, zoo_dir='./model_zoo'):
+ self.zoo_dir = zoo_dir
+ os.makedirs(zoo_dir, exist_ok=True)
+
+ def save_model(self, net, name, metadata=None):
+ """Add model to zoo."""
+ model_path = f'{self.zoo_dir}/{name}.pkl'
+ export_dict = {
+ 'params': net.states(brainstate.ParamState),
+ 'config': net.get_config(),
+ 'metadata': metadata or {}
+ }
+
+ with open(model_path, 'wb') as f:
+ pickle.dump(export_dict, f)
+
+ print(f"๐ฆ Added {name} to model zoo")
+
+ def load_model(self, name, model_class):
+ """Load model from zoo."""
+ model_path = f'{self.zoo_dir}/{name}.pkl'
+
+ with open(model_path, 'rb') as f:
+ export_dict = pickle.load(f)
+
+ net = model_class.from_config(export_dict['config'])
+ brainstate.nn.init_all_states(net)
+
+ for param_name, state in export_dict['params'].items():
+ net.states(brainstate.ParamState)[param_name].value = state.value
+
+ return net, export_dict['metadata']
+
+ def list_models(self):
+ """List available models."""
+ models = glob.glob(f'{self.zoo_dir}/*.pkl')
+ return [os.path.basename(m).replace('.pkl', '') for m in models]
+
+ # Usage
+ zoo = ModelZoo()
+
+ # Save trained models
+ zoo.save_model(net1, 'mnist_classifier', {'accuracy': 0.98})
+ zoo.save_model(net2, 'fashion_classifier', {'accuracy': 0.92})
+
+ # List and load
+ print("Available models:", zoo.list_models())
+ net, metadata = zoo.load_model('mnist_classifier', MyNetwork)
+ print(f"Loaded model with accuracy: {metadata['accuracy']}")
+
+Troubleshooting
+---------------
+
+Issue: Pickle version mismatch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** `AttributeError` or `ModuleNotFoundError` when loading
+
+**Solution:** Use protocol version 4 or lower for compatibility
+
+.. code-block:: python
+
+ # Save with specific protocol
+ with open('model.pkl', 'wb') as f:
+ pickle.dump(state_dict, f, protocol=4)
+
+Issue: JAX array serialization
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** Can't pickle JAX arrays directly
+
+**Solution:** Convert to NumPy before saving
+
+.. code-block:: python
+
+ import numpy as np
+ import jax.numpy as jnp
+
+ # Convert to NumPy for saving
+ params_np = {
+ name: np.array(state.value)
+ for name, state in net.states(brainstate.ParamState).items()
+ }
+
+ with open('model.pkl', 'wb') as f:
+ pickle.dump(params_np, f)
+
+ # Convert back when loading
+ with open('model.pkl', 'rb') as f:
+ params_np = pickle.load(f)
+ for name, array in params_np.items():
+ net.states(brainstate.ParamState)[name].value = jnp.array(array)
+
+Issue: Model architecture changed
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**Symptom:** Parameter names don't match
+
+**Solution:** Use partial loading with error handling
+
+.. code-block:: python
+
+ def safe_load(checkpoint_path, net):
+ """Load with error handling."""
+ with open(checkpoint_path, 'rb') as f:
+ checkpoint = pickle.load(f)
+
+ current_params = net.states(brainstate.ParamState)
+ loaded_params = checkpoint['params']
+
+ # Check compatibility
+ missing = set(current_params.keys()) - set(loaded_params.keys())
+ unexpected = set(loaded_params.keys()) - set(current_params.keys())
+
+ if missing:
+ print(f"โ ๏ธ Missing parameters: {missing}")
+ if unexpected:
+ print(f"โ ๏ธ Unexpected parameters: {unexpected}")
+
+ # Load matching parameters
+ for name in current_params.keys() & loaded_params.keys():
+ current_params[name].value = loaded_params[name].value
+
+ print(f"โ
Loaded {len(current_params.keys() & loaded_params.keys())} parameters")
+
+Best Practices
+--------------
+
+✅ **Always save configuration** - Include hyperparameters for reproducibility
+
+✅ **Version your checkpoints** - Track model version for compatibility
+
+✅ **Save metadata** - Include training metrics, date, description
+
+✅ **Regular backups** - Save periodically during long training
+
+✅ **Keep best model** - Separate best and latest checkpoints
+
+✅ **Test loading** - Verify checkpoint can be loaded before continuing
+
+✅ **Use relative paths** - Make checkpoints portable
+
+✅ **Document format** - Comment what's in your checkpoint files
+
+❌ **Don't save ShortTermState** - It resets anyway
+
+❌ **Don't save everything** - Minimize checkpoint size
+
+❌ **Don't overwrite** - Keep multiple checkpoints for safety
+
+Summary
+-------
+
+**Quick reference:**
+
+.. code-block:: python
+
+ # Save
+ checkpoint = {
+ 'params': net.states(brainstate.ParamState),
+ 'epoch': epoch,
+ 'config': net.get_config()
+ }
+ with open('checkpoint.pkl', 'wb') as f:
+ pickle.dump(checkpoint, f)
+
+ # Load
+ with open('checkpoint.pkl', 'rb') as f:
+ checkpoint = pickle.load(f)
+
+ net = MyNetwork.from_config(checkpoint['config'])
+ brainstate.nn.init_all_states(net)
+
+ for name, state in checkpoint['params'].items():
+ net.states(brainstate.ParamState)[name].value = state.value
+
+See Also
+--------
+
+- :doc:`../core-concepts/state-management` - Understanding states
+- :doc:`../tutorials/advanced/05-snn-training` - Training models
+- :doc:`gpu-tpu-usage` - Accelerated training
diff --git a/docs_version3/index.rst b/docs_version3/index.rst
index 58d3bb32..82245c4d 100644
--- a/docs_version3/index.rst
+++ b/docs_version3/index.rst
@@ -1,8 +1,18 @@
-``brainpy`` documentation
-=========================
+BrainPy documentation
+=====================
-`brainpy `_ provides a powerful and flexible framework
-for building, simulating, and training spiking neural networks.
+`BrainPy`_ is a flexible, efficient, and extensible framework for computational neuroscience
+and brain-inspired computation. It provides a powerful and flexible framework for building,
+simulating, and training spiking neural networks.
+
+
+.. _BrainPy: https://github.com/brainpy/BrainPy
+
+
+.. note::
+
+ ``BrainPy>=3.0.0`` has been rewritten on top of ``brainstate`` since August 2025.
+ This documentation is for the latest version 3.x.
@@ -22,6 +32,7 @@ Installation
.. code-block:: bash
pip install -U brainpy[cuda12]
+
pip install -U brainpy[cuda13]
.. tab-item:: TPU
@@ -30,13 +41,83 @@ Installation
pip install -U brainpy[tpu]
+ .. tab-item:: Ecosystem
+
+ .. code-block:: bash
+
+ pip install -U BrainX
+
+
+
----
+Learn more
+^^^^^^^^^^
+
+.. grid::
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`rocket_launch;2em` 5-Minute Tutorial
+ :class-card: sd-text-black sd-bg-light
+ :link: quickstart/5min-tutorial.html
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`library_books;2em` Core Concepts
+ :class-card: sd-text-black sd-bg-light
+ :link: quickstart/concepts-overview.html
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`school;2em` Tutorials
+ :class-card: sd-text-black sd-bg-light
+ :link: tutorials/index.html
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`explore;2em` Examples Gallery
+ :class-card: sd-text-black sd-bg-light
+ :link: examples/gallery.html
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`data_exploration;2em` API Documentation
+ :class-card: sd-text-black sd-bg-light
+ :link: apis.html
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`swap_horiz;2em` Migration from 2.x
+ :class-card: sd-text-black sd-bg-light
+ :link: migration/migration-guide.html
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`settings;2em` Ecosystem
+ :class-card: sd-text-black sd-bg-light
+ :link: https://brainmodeling.readthedocs.io
+
+ .. grid-item::
+ :columns: 6 6 6 4
+
+ .. card:: :material-regular:`history;2em` Changelog
+ :class-card: sd-text-black sd-bg-light
+ :link: changelog.html
+
+
+----
See also the ecosystem
^^^^^^^^^^^^^^^^^^^^^^
-
 ``brainpy`` is one part of our `brain simulation ecosystem <https://brainmodeling.readthedocs.io>`_.
@@ -47,11 +128,53 @@ See also the ecosystem
:maxdepth: 1
:caption: Quickstart
- quickstart/concepts-en.ipynb
- quickstart/concepts-zh.ipynb
+ quickstart/installation.rst
+ quickstart/5min-tutorial.ipynb
+ quickstart/concepts-overview.rst
+
+
+.. toctree::
+ :hidden:
+ :maxdepth: 2
+ :caption: Core Concepts
+
+ core-concepts/architecture.rst
+ core-concepts/neurons.rst
+ core-concepts/synapses.rst
+ core-concepts/projections.rst
+ core-concepts/state-management.rst
+
+
+.. toctree::
+ :hidden:
+ :maxdepth: 2
+ :caption: Tutorials
+
+ tutorials/index.rst
+.. toctree::
+ :hidden:
+ :maxdepth: 2
+ :caption: How-to Guides
+
+ how-to-guides/index.rst
+
+
+.. toctree::
+ :hidden:
+ :maxdepth: 2
+ :caption: Examples
+
+ examples/gallery.rst
+
+
+.. toctree::
+ :hidden:
+ :maxdepth: 2
+ :caption: Migration
+ migration/migration-guide.rst
.. toctree::
@@ -59,6 +182,6 @@ See also the ecosystem
:maxdepth: 2
:caption: API Reference
+ api/index.rst
changelog.md
- apis.rst
diff --git a/docs_version3/migration/migration-guide.rst b/docs_version3/migration/migration-guide.rst
new file mode 100644
index 00000000..72c7abd3
--- /dev/null
+++ b/docs_version3/migration/migration-guide.rst
@@ -0,0 +1,567 @@
+Migration Guide: BrainPy 2.x to 3.0
+====================================
+
+This guide helps you migrate your code from BrainPy 2.x to BrainPy 3.0. BrainPy 3.0 represents a complete rewrite built on ``brainstate``, with significant architectural changes and API improvements.
+
+Overview of Changes
+-------------------
+
+BrainPy 3.0 introduces several major changes:
+
+**Architecture**
+ - Built on ``brainstate`` framework
+ - State-based programming model
+ - Integrated physical units (``brainunit``)
+ - Modular projection architecture
+
+**API Changes**
+ - New neuron and synapse interfaces
+ - Projection system redesign
+ - Updated simulation APIs
+ - Training framework changes
+
+**Performance**
+ - Improved JIT compilation
+ - Better memory efficiency
+ - Enhanced GPU/TPU support
+
+Compatibility Layer
+-------------------
+
+BrainPy 3.0 includes ``brainpy.version2`` for backward compatibility:
+
+.. code-block:: python
+
+ # Old code (BrainPy 2.x) - still works with deprecation warning
+ import brainpy as bp
+ # bp.math, bp.layers, etc. redirect to bp.version2
+
+ # Explicit version2 usage (recommended during migration)
+ import brainpy.version2 as bp2
+
+ # New BrainPy 3.0 API
+ import brainpy # Use new 3.0 features
+
+Migration Strategy
+------------------
+
+Recommended Approach
+~~~~~~~~~~~~~~~~~~~~
+
+1. **Gradual Migration**: Use ``brainpy.version2`` for old code while writing new code with 3.0 API
+2. **Test Thoroughly**: Ensure numerical equivalence between versions
+3. **Update Incrementally**: Migrate module by module, not all at once
+4. **Use Both**: Mix version2 and 3.0 code during transition
+
+.. code-block:: python
+
+ # During migration
+ import brainpy # New 3.0 API
+ import brainpy.version2 as bp2 # Old 2.x API
+
+ # Old model
+ old_network = bp2.dyn.Network(...)
+
+ # New model
+ new_network = brainpy.LIF(...)
+
+ # Can coexist in same codebase
+
+Key API Changes
+---------------
+
+Imports and Modules
+~~~~~~~~~~~~~~~~~~~
+
+**BrainPy 2.x:**
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainpy.math as bm
+ import brainpy.layers as layers
+ import brainpy.dyn as dyn
+ from brainpy import neurons, synapses
+
+**BrainPy 3.0:**
+
+.. code-block:: python
+
+ import brainpy as bp # Core neurons, synapses, projections
+ import brainstate # State management, modules
+ import brainunit as u # Physical units
+ import braintools # Utilities, optimizers, etc.
+
+Neuron Models
+~~~~~~~~~~~~~
+
+**BrainPy 2.x:**
+
+.. code-block:: python
+
+ # Old API
+ neurons = bp.neurons.LIF(
+ size=100,
+ V_rest=-65.,
+ V_th=-50.,
+ V_reset=-60.,
+ tau=10.,
+ V_initializer=bp.init.Normal(-60., 5.)
+ )
+
+**BrainPy 3.0:**
+
+.. code-block:: python
+
+ # New API - with units!
+ import brainunit as u
+ import braintools
+
+ neurons = brainpy.LIF(
+ size=100,
+ V_rest=-65. * u.mV, # Units required
+ V_th=-50. * u.mV,
+ V_reset=-60. * u.mV,
+ tau=10. * u.ms,
+ V_initializer=braintools.init.Normal(-60., 5., unit=u.mV)
+ )
+
+**Key Changes:**
+
+- Simpler import: ``brainpy.LIF`` instead of ``bp.neurons.LIF``
+- Physical units are mandatory
+- Initializers from ``braintools.init``
+- Must use ``brainstate.nn.init_all_states()`` before simulation
+
+Synapse Models
+~~~~~~~~~~~~~~
+
+**BrainPy 2.x:**
+
+.. code-block:: python
+
+ # Old API
+ syn = bp.synapses.Exponential(
+ pre=pre_neurons,
+ post=post_neurons,
+ conn=bp.connect.FixedProb(0.1),
+ tau=5.,
+ output=bp.synouts.CUBA()
+ )
+
+**BrainPy 3.0:**
+
+.. code-block:: python
+
+ # New API - using projection architecture
+ import brainstate
+
+ projection = brainpy.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(
+ pre_size, post_size, prob=0.1, weight=0.5*u.mS
+ ),
+ syn=brainpy.Expon.desc(post_size, tau=5.*u.ms),
+ out=brainpy.CUBA.desc(),
+ post=post_neurons
+ )
+
+**Key Changes:**
+
+- Synapse, connectivity, and output are separated
+- Use descriptor pattern (``.desc()``)
+- Projections handle the complete pathway
+- Physical units throughout
+
+Network Definition
+~~~~~~~~~~~~~~~~~~
+
+**BrainPy 2.x:**
+
+.. code-block:: python
+
+ # Old API
+ class EINet(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+ self.E = bp.neurons.LIF(800)
+ self.I = bp.neurons.LIF(200)
+ self.E2E = bp.synapses.Exponential(...)
+ self.E2I = bp.synapses.Exponential(...)
+ # ...
+
+ def update(self, tdi, x):
+ self.E(x)
+ self.I(x)
+ self.E2E()
+ # ...
+
+**BrainPy 3.0:**
+
+.. code-block:: python
+
+ # New API
+ import brainstate
+
+ class EINet(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.E = brainpy.LIF(800, ...)
+ self.I = brainpy.LIF(200, ...)
+ self.E2E = brainpy.AlignPostProj(...)
+ self.E2I = brainpy.AlignPostProj(...)
+ # ...
+
+ def update(self, x):
+ spikes_e = self.E.get_spike()
+ spikes_i = self.I.get_spike()
+
+ self.E2E(spikes_e)
+ self.E2I(spikes_e)
+ # ...
+
+ self.E(x)
+ self.I(x)
+
+**Key Changes:**
+
+- Inherit from ``brainstate.nn.Module`` instead of ``bp.DynamicalSystem``
+- No ``tdi`` argument (time info from ``brainstate.environ``)
+- Explicit spike handling with ``get_spike()``
+- Update order: projections first, then neurons
+
+Running Simulations
+~~~~~~~~~~~~~~~~~~~
+
+**BrainPy 2.x:**
+
+.. code-block:: python
+
+ # Old API
+ runner = bp.DSRunner(network, monitors=['E.spike'])
+ runner.run(duration=1000.)
+
+ # Access results
+ spikes = runner.mon['E.spike']
+
+**BrainPy 3.0:**
+
+.. code-block:: python
+
+ # New API
+ import brainunit as u
+
+ # Set time step
+ brainstate.environ.set(dt=0.1 * u.ms)
+
+ # Initialize
+ brainstate.nn.init_all_states(network)
+
+ # Run simulation
+ times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
+ results = brainstate.transform.for_loop(
+ network.update,
+ times,
+ pbar=brainstate.transform.ProgressBar(10)
+ )
+
+**Key Changes:**
+
+- No ``DSRunner`` class
+- Use ``brainstate.transform.for_loop`` for simulation
+- Must initialize states explicitly
+- Manual recording of variables (see the sketch below)
+- Physical units for time
+
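+Because there is no monitor system, you record variables yourself by returning them
+from the stepped function; ``for_loop`` stacks the returned values over time. A
+minimal sketch (``record_step`` is an illustrative name, not a fixed API):
+
+.. code-block:: python
+
+    def record_step(t):
+        network.update(inp)
+        # Return whatever you want recorded at each step
+        return network.E.V.value
+
+    voltages = brainstate.transform.for_loop(record_step, times)
+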
+Training
+~~~~~~~~
+
+**BrainPy 2.x:**
+
+.. code-block:: python
+
+ # Old API
+ trainer = bp.BPTT(
+ network,
+ loss_fun=loss_fn,
+ optimizer=bp.optim.Adam(lr=1e-3)
+ )
+ trainer.fit(train_data, epochs=100)
+
+**BrainPy 3.0:**
+
+.. code-block:: python
+
+ # New API
+ import braintools
+
+ # Define optimizer
+ optimizer = braintools.optim.Adam(lr=1e-3)
+ optimizer.register_trainable_weights(
+ network.states(brainstate.ParamState)
+ )
+
+ # Training loop
+ @brainstate.compile.jit
+ def train_step(inputs, targets):
+ def loss_fn():
+ predictions = brainstate.compile.for_loop(network.update, inputs)
+ return compute_loss(predictions, targets)
+
+ grads, loss = brainstate.transform.grad(
+ loss_fn,
+ network.states(brainstate.ParamState),
+ return_value=True
+ )()
+ optimizer.update(grads)
+ return loss
+
+ # Train
+ for epoch in range(100):
+ loss = train_step(train_inputs, train_targets)
+
+**Key Changes:**
+
+- No ``BPTT`` or ``Trainer`` classes
+- Manual training loop implementation
+- Explicit gradient computation
+- More control, more flexibility
+
+Common Migration Patterns
+--------------------------
+
+Pattern 1: Simple Neuron Population
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+**2.x Code:**
+
+.. code-block:: python
+
+ neurons = bp.neurons.LIF(100, V_rest=-65., V_th=-50., tau=10.)
+ runner = bp.DSRunner(neurons)
+ runner.run(100., inputs=2.0)
+
+**3.0 Code:**
+
+.. code-block:: python
+
+ import brainunit as u
+ import brainstate
+
+ brainstate.environ.set(dt=0.1*u.ms)
+ neurons = brainpy.LIF(100, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)
+ brainstate.nn.init_all_states(neurons)
+
+ times = u.math.arange(0*u.ms, 100*u.ms, brainstate.environ.get_dt())
+ results = brainstate.transform.for_loop(
+ lambda t: neurons(2.0*u.nA),
+ times
+ )
+
+Pattern 2: E-I Network
+~~~~~~~~~~~~~~~~~~~~~~
+
+**2.x Code:**
+
+.. code-block:: python
+
+ E = bp.neurons.LIF(800)
+ I = bp.neurons.LIF(200)
+ E2E = bp.synapses.Exponential(E, E, bp.connect.FixedProb(0.02))
+ E2I = bp.synapses.Exponential(E, I, bp.connect.FixedProb(0.02))
+ I2E = bp.synapses.Exponential(I, E, bp.connect.FixedProb(0.02))
+ I2I = bp.synapses.Exponential(I, I, bp.connect.FixedProb(0.02))
+
+ net = bp.Network(E, I, E2E, E2I, I2E, I2I)
+ runner = bp.DSRunner(net)
+ runner.run(1000.)
+
+**3.0 Code:**
+
+.. code-block:: python
+
+ import brainpy as bp
+ import brainstate
+ import brainunit as u
+
+ class EINet(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.E = brainpy.LIF(800, V_th=-50.*u.mV, tau=10.*u.ms)
+ self.I = brainpy.LIF(200, V_th=-50.*u.mV, tau=10.*u.ms)
+
+ self.E2E = brainpy.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(800, 800, 0.02, 0.1*u.mS),
+ syn=brainpy.Expon.desc(800, tau=5.*u.ms),
+ out=brainpy.CUBA.desc(),
+ post=self.E
+ )
+ # ... similar for E2I, I2E, I2I
+
+ def update(self, inp):
+ e_spk = self.E.get_spike()
+ i_spk = self.I.get_spike()
+ self.E2E(e_spk)
+ # ... other projections
+ self.E(inp)
+ self.I(inp)
+
+ brainstate.environ.set(dt=0.1*u.ms)
+ net = EINet()
+ brainstate.nn.init_all_states(net)
+
+ times = u.math.arange(0*u.ms, 1000*u.ms, 0.1*u.ms)
+ results = brainstate.transform.for_loop(
+ lambda t: net.update(1.*u.nA),
+ times
+ )
+
+Troubleshooting
+---------------
+
+Common Issues
+~~~~~~~~~~~~~
+
+**Issue 1: ImportError**
+
+.. code-block:: python
+
+ # Error: ModuleNotFoundError: No module named 'brainpy.math'
+ import brainpy.math as bm # Old import
+
+ # Solution: Use version2 or update to new API
+ import brainpy.version2.math as bm # Temporary
+ # or
+ import brainunit as u # New API
+
+**Issue 2: Unit Errors**
+
+.. code-block:: python
+
+ # Error: Units required but not provided
+ neuron = bp.LIF(100, tau=10.) # Missing units
+
+ # Solution: Add units
+ import brainunit as u
+ neuron = bp.LIF(100, tau=10.*u.ms)
+
+**Issue 3: State Initialization**
+
+.. code-block:: python
+
+ # Error: States not initialized
+ neuron = brainpy.LIF(100, ...)
+ neuron(input) # May fail or give wrong results
+
+ # Solution: Initialize states
+ import brainstate
+ neuron = brainpy.LIF(100, ...)
+ brainstate.nn.init_all_states(neuron)
+ neuron(input) # Now works correctly
+
+**Issue 4: Projection Update Order**
+
+.. code-block:: python
+
+ # Wrong: Neurons before projections
+ def update(self, inp):
+ self.neurons(inp)
+ self.projection(self.neurons.get_spike()) # Uses current spikes
+
+ # Correct: Projections before neurons
+ def update(self, inp):
+ spikes = self.neurons.get_spike() # Get previous spikes
+ self.projection(spikes) # Update synapses
+ self.neurons(inp) # Update neurons
+
+Testing Migration
+-----------------
+
+Numerical Equivalence
+~~~~~~~~~~~~~~~~~~~~~
+
+When migrating, verify that new code produces equivalent results:
+
+.. code-block:: python
+
+ # Old code results
+ import brainpy.version2 as bp2
+ old_network = bp2.neurons.LIF(100, ...)
+ old_runner = bp2.DSRunner(old_network)
+ old_runner.run(100.)
+ old_voltages = old_runner.mon['V']
+
+ # New code results
+ import brainpy as bp
+ import brainstate
+ new_network = brainpy.LIF(100, ...)
+ brainstate.nn.init_all_states(new_network)
+    # ... run simulation and record voltages (see the sketch below) ...
+    # new_voltages = ...
+
+ # Compare
+ import numpy as np
+ np.allclose(old_voltages, new_voltages, rtol=1e-5)
+
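+A hedged sketch of the elided steps, following the simulation pattern used earlier
+in this guide (``step`` is an illustrative helper; the input must match the 2.x run):
+
+.. code-block:: python
+
+    import brainunit as u
+
+    brainstate.environ.set(dt=0.1 * u.ms)
+
+    def step(t):
+        new_network(2.0 * u.nA)     # same constant input as the old run
+        # V is a quantity with mV units; strip with .to_decimal(u.mV) before comparing
+        return new_network.V.value
+
+    times = u.math.arange(0 * u.ms, 100 * u.ms, brainstate.environ.get_dt())
+    new_voltages = brainstate.transform.for_loop(step, times)
+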
+Feature Parity Checklist
+-------------------------
+
+Before completing migration, verify:
+
+☐ All neuron models migrated
+
+☐ All synapse models migrated
+
+☐ Network structure preserved
+
+☐ Simulation produces equivalent results
+
+☐ Training works (if applicable)
+
+☐ Visualization updated
+
+☐ Unit tests pass
+
+☐ Documentation updated
+
+Getting Help
+------------
+
+If you encounter issues during migration:
+
+- Check the `API documentation <../api/index.html>`_
+- Review `examples <../examples/gallery.html>`_
+- Search `GitHub issues <https://github.com/brainpy/BrainPy/issues>`_
+- Ask on GitHub Discussions
+- Read the `brainstate documentation <https://brainstate.readthedocs.io>`_
+
+Benefits of Migration
+---------------------
+
+Migrating to BrainPy 3.0 provides:
+
+✅ **Better Performance**: Optimized compilation and execution
+
+✅ **Physical Units**: Automatic unit checking prevents errors
+
+✅ **Cleaner API**: More intuitive and consistent interfaces
+
+✅ **Modularity**: Easier to compose and reuse components
+
+✅ **Modern Architecture**: Built on proven frameworks
+
+✅ **Better Tooling**: Improved ecosystem integration
+
+Summary
+-------
+
+Migration from BrainPy 2.x to 3.0 requires:
+
+1. Understanding new architecture (state-based, modular)
+2. Adding physical units to all parameters
+3. Updating import statements
+4. Refactoring network definitions
+5. Changing simulation and training code
+6. Testing for numerical equivalence
+
+The ``brainpy.version2`` compatibility layer enables gradual migration, allowing you to update your codebase incrementally.
+
+Next Steps
+----------
+
+- Start with the :doc:`../quickstart/5min-tutorial` to learn 3.0 basics
+- Review :doc:`../core-concepts/architecture` for design understanding
+- Follow :doc:`../tutorials/basic/01-lif-neuron` for hands-on practice
+- Study :doc:`../examples/gallery` for complete migration examples
diff --git a/docs_version3/quickstart/5min-tutorial.ipynb b/docs_version3/quickstart/5min-tutorial.ipynb
new file mode 100644
index 00000000..debf25fe
--- /dev/null
+++ b/docs_version3/quickstart/5min-tutorial.ipynb
@@ -0,0 +1,353 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 5-Minute Tutorial: Getting Started with BrainPy 3.0\n",
+ "\n",
+ "Welcome to BrainPy 3.0! This quick tutorial will get you up and running with your first neural simulation in just a few minutes.\n",
+ "\n",
+ "## What You'll Learn\n",
+ "\n",
+ "- How to create neurons\n",
+ "- How to build simple networks\n",
+ "- How to run simulations\n",
+ "- How to visualize results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1: Import Libraries\n",
+ "\n",
+ "First, let's import the necessary libraries:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Check version\n",
+ "print(f\"BrainPy version: {brainpy.__version__}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2: Create Your First Neuron\n",
+ "\n",
+ "Let's create a simple Leaky Integrate-and-Fire (LIF) neuron:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set simulation time step\n",
+ "brainstate.environ.set(dt=0.1 * u.ms)\n",
+ "\n",
+ "# Create a single LIF neuron\n",
+ "neuron = brainpy.LIF(\n",
+ " size=1,\n",
+ " V_rest=-65. * u.mV, # Resting potential\n",
+ " V_th=-50. * u.mV, # Spike threshold\n",
+ " V_reset=-65. * u.mV, # Reset potential\n",
+ " tau=10. * u.ms, # Membrane time constant\n",
+ ")\n",
+ "\n",
+ "print(\"Created a LIF neuron!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3: Simulate the Neuron\n",
+ "\n",
+ "Now let's inject a constant current and see how the neuron responds:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize neuron state\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "\n",
+ "# Define simulation parameters\n",
+ "duration = 200. * u.ms\n",
+ "dt = brainstate.environ.get_dt()\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "# Input current (constant)\n",
+ "I_input = 2.0 * u.nA\n",
+ "\n",
+ "# Run simulation and record membrane potential\n",
+ "voltages = []\n",
+ "spikes = []\n",
+ "\n",
+ "for t in times:\n",
+ " neuron(I_input)\n",
+ " voltages.append(neuron.V.value)\n",
+ " spikes.append(neuron.get_spike())\n",
+ "\n",
+ "voltages = u.math.asarray(voltages)\n",
+ "spikes = u.math.asarray(spikes)\n",
+ "\n",
+ "print(f\"Simulation complete! Recorded {len(times)} time steps.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4: Visualize the Results\n",
+ "\n",
+ "Let's plot the membrane potential over time:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Convert to appropriate units for plotting\n",
+ "times_plot = times.to_decimal(u.ms)\n",
+ "voltages_plot = voltages.to_decimal(u.mV)\n",
+ "\n",
+ "# Create plot\n",
+ "plt.figure(figsize=(10, 4))\n",
+ "plt.plot(times_plot, voltages_plot, linewidth=2)\n",
+ "plt.xlabel('Time (ms)')\n",
+ "plt.ylabel('Membrane Potential (mV)')\n",
+ "plt.title('LIF Neuron Response to Constant Input')\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Count spikes\n",
+ "n_spikes = int(u.math.sum(spikes != 0))\n",
+ "firing_rate = n_spikes / (duration.to_decimal(u.second))\n",
+ "print(f\"Number of spikes: {n_spikes}\")\n",
+ "print(f\"Average firing rate: {firing_rate:.2f} Hz\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5: Create a Network of Neurons\n",
+ "\n",
+ "Now let's create a small network with excitatory and inhibitory populations:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SimpleEINet(brainstate.nn.Module):\n",
+ " def __init__(self, n_exc=80, n_inh=20):\n",
+ " super().__init__()\n",
+ " self.n_exc = n_exc\n",
+ " self.n_inh = n_inh\n",
+ " self.num = n_exc + n_inh\n",
+ " \n",
+ " # Create neurons\n",
+ " self.neurons = brainpy.LIF(\n",
+ " self.num,\n",
+ " V_rest=-65. * u.mV,\n",
+ " V_th=-50. * u.mV,\n",
+ " V_reset=-65. * u.mV,\n",
+ " tau=10. * u.ms,\n",
+ " V_initializer=braintools.init.Normal(-65., 5., unit=u.mV)\n",
+ " )\n",
+ " \n",
+ " # Excitatory to all projection\n",
+ " self.E2all = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, self.num, prob=0.1, weight=0.6*u.mS),\n",
+ " syn=brainpy.Expon.desc(self.num, tau=2. * u.ms),\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=self.neurons,\n",
+ " )\n",
+ " \n",
+ " # Inhibitory to all projection\n",
+ " self.I2all = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, self.num, prob=0.1, weight=-5.0*u.mS),\n",
+ " syn=brainpy.Expon.desc(self.num, tau=2. * u.ms),\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=self.neurons,\n",
+ " )\n",
+ " \n",
+ " def update(self, input_current):\n",
+ " # Get spikes from previous time step\n",
+ " spikes = self.neurons.get_spike() != 0.\n",
+ " \n",
+ " # Update projections\n",
+ " self.E2all(spikes[:self.n_exc]) # Excitatory spikes\n",
+ " self.I2all(spikes[self.n_exc:]) # Inhibitory spikes\n",
+ " \n",
+ " # Update neurons\n",
+ " self.neurons(input_current)\n",
+ " \n",
+ " return self.neurons.get_spike()\n",
+ "\n",
+ "# Create network\n",
+ "net = SimpleEINet(n_exc=80, n_inh=20)\n",
+ "print(f\"Created network with {net.num} neurons\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6: Simulate the Network\n",
+ "\n",
+ "Let's run the network and visualize its activity:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize network states\n",
+ "brainstate.nn.init_all_states(net)\n",
+ "\n",
+ "# Simulation parameters\n",
+ "duration = 500. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "I_ext = 1.5 * u.nA # External input current\n",
+ "\n",
+ "# Run simulation\n",
+ "spike_history = brainstate.transform.for_loop(\n",
+ " lambda t: net.update(I_ext),\n",
+ " times,\n",
+ " pbar=brainstate.transform.ProgressBar(10)\n",
+ ")\n",
+ "\n",
+ "print(\"Network simulation complete!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7: Visualize Network Activity (Raster Plot)\n",
+ "\n",
+ "Create a raster plot showing when each neuron fired:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Find spike times and neuron indices\n",
+ "t_indices, n_indices = u.math.where(spike_history != 0)\n",
+ "\n",
+ "# Convert to plottable format\n",
+ "spike_times = times[t_indices].to_decimal(u.ms)\n",
+ "\n",
+ "# Create raster plot\n",
+ "plt.figure(figsize=(12, 6))\n",
+ "plt.scatter(spike_times, n_indices, s=1, c='black', alpha=0.5)\n",
+ "\n",
+ "# Mark excitatory and inhibitory populations\n",
+ "plt.axhline(y=net.n_exc, color='red', linestyle='--', alpha=0.5, label='E/I boundary')\n",
+ "\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Neuron Index', fontsize=12)\n",
+ "plt.title('Network Activity (Raster Plot)', fontsize=14)\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "\n",
+ "# Add text annotations\n",
+ "plt.text(10, net.n_exc/2, 'Excitatory', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))\n",
+ "plt.text(10, net.n_exc + net.n_inh/2, 'Inhibitory', fontsize=10, bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Calculate statistics\n",
+ "total_spikes = len(t_indices)\n",
+ "avg_rate = total_spikes / (net.num * duration.to_decimal(u.second))\n",
+ "print(f\"Total spikes: {total_spikes}\")\n",
+ "print(f\"Average firing rate: {avg_rate:.2f} Hz\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+    "Congratulations! 🎉 You've just:\n",
+    "\n",
+    "1. ✅ Created individual neurons with physical units\n",
+    "2. ✅ Simulated neuron dynamics with input currents\n",
+    "3. ✅ Built a network with excitatory and inhibitory populations\n",
+    "4. ✅ Connected neurons with synaptic projections\n",
+    "5. ✅ Visualized network activity\n",
+ "\n",
+ "## Next Steps\n",
+ "\n",
+ "Now that you've completed your first simulation, you can:\n",
+ "\n",
+ "- **Learn more concepts**: Read the [Core Concepts](../core-concepts/architecture.rst) guide\n",
+ "- **Follow tutorials**: Try the [Basic Tutorials](../tutorials/basic/01-lif-neuron.ipynb) for deeper understanding\n",
+ "- **Explore examples**: Check out the [Examples Gallery](../examples/gallery.rst) for real-world models\n",
+ "- **Experiment**: Modify the network parameters and see what happens!\n",
+ "\n",
+ "### Try These Experiments\n",
+ "\n",
+ "1. Change the connection probability in the network\n",
+ "2. Adjust the excitatory/inhibitory balance\n",
+ "3. Add more neuron populations\n",
+ "4. Try different input currents or patterns\n",
+ "\n",
+    "Happy modeling! 🧠"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/quickstart/concepts-overview.rst b/docs_version3/quickstart/concepts-overview.rst
new file mode 100644
index 00000000..8bfe4593
--- /dev/null
+++ b/docs_version3/quickstart/concepts-overview.rst
@@ -0,0 +1,300 @@
+Core Concepts Overview
+======================
+
+BrainPy 3.0 introduces a modern, state-based architecture built on top of ``brainstate``. This overview will help you understand the key concepts and design philosophy.
+
+What's New in BrainPy 3.0
+-------------------------
+
+BrainPy 3.0 has been completely rewritten to provide:
+
+- **State-based programming**: Built on ``brainstate`` for efficient state management
+- **Modular architecture**: Clear separation of concerns (communication, dynamics, outputs)
+- **Physical units**: Integration with ``brainunit`` for scientifically accurate simulations
+- **Modern API**: Cleaner, more intuitive interfaces
+- **Better performance**: Optimized JIT compilation and memory management
+
+Key Architectural Components
+-----------------------------
+
+BrainPy 3.0 is organized around several core concepts:
+
+1. State Management
+~~~~~~~~~~~~~~~~~~~
+
+Everything in BrainPy 3.0 revolves around **states**. States are variables that persist across time steps:
+
+- ``brainstate.State``: Base state container
+- ``brainstate.ParamState``: Trainable parameters
+- ``brainstate.ShortTermState``: Temporary variables
+
+States enable:
+
+- Automatic differentiation for training
+- Efficient memory management
+- Batching and parallelization
+
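+A minimal sketch of the state containers (assuming the ``brainstate`` API listed
+above; values are read and written through ``.value``):
+
+.. code-block:: python
+
+    import brainstate
+    import jax.numpy as jnp
+
+    w = brainstate.ParamState(jnp.zeros((10, 10)))   # trainable, visible to optimizers
+    v = brainstate.ShortTermState(jnp.zeros(10))     # temporary, reset between runs
+
+    # States are containers: read and update through .value
+    v.value = v.value + 1.0
+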
+2. Neurons
+~~~~~~~~~~
+
+Neurons are the fundamental computational units:
+
+.. code-block:: python
+
+ import brainpy
+ import brainunit as u
+
+ # Create a population of 100 LIF neurons
+ neurons = brainpy.LIF(100, tau=10*u.ms, V_th=-50*u.mV)
+
+Key neuron models:
+
+- ``brainpy.IF``: Integrate-and-Fire
+- ``brainpy.LIF``: Leaky Integrate-and-Fire
+- ``brainpy.LIFRef``: LIF with refractory period
+- ``brainpy.ALIF``: Adaptive LIF
+
+3. Synapses
+~~~~~~~~~~~
+
+Synapses model the dynamics of neural connections:
+
+.. code-block:: python
+
+ # Exponential synapse
+ synapse = brainpy.Expon(100, tau=5*u.ms)
+
+ # Alpha synapse (more realistic)
+ synapse = brainpy.Alpha(100, tau=5*u.ms)
+
+Synapse models:
+
+- ``brainpy.Expon``: Single exponential decay
+- ``brainpy.Alpha``: Double exponential (alpha function)
+- ``brainpy.AMPA``: Excitatory receptor dynamics
+- ``brainpy.GABAa``: Inhibitory receptor dynamics
+
+4. Projections
+~~~~~~~~~~~~~~
+
+Projections connect neural populations:
+
+.. code-block:: python
+
+ projection = brainpy.AlignPostProj(
+ comm=brainstate.nn.EventFixedProb(N_pre, N_post, prob=0.1, weight=0.5),
+ syn=brainpy.Expon.desc(N_post, tau=5*u.ms),
+ out=brainpy.CUBA.desc(),
+ post=neurons
+ )
+
+The projection architecture separates:
+
+- **Communication**: How spikes are transmitted (connectivity, weights)
+- **Synaptic dynamics**: How synapses respond (temporal filtering)
+- **Output mechanism**: How synaptic currents affect neurons (CUBA/COBA)
+
+5. Networks
+~~~~~~~~~~~
+
+Networks combine neurons and projections:
+
+.. code-block:: python
+
+ import brainstate
+
+ class EINet(brainstate.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.E = brainpy.LIF(800)
+ self.I = brainpy.LIF(200)
+ self.E2E = brainpy.AlignPostProj(...)
+ self.E2I = brainpy.AlignPostProj(...)
+ # ... more projections
+
+        def update(self, input):
+            # Projections first: deliver the spikes emitted at the previous step
+            e_spikes = self.E.get_spike()
+            self.E2E(e_spikes)
+            self.E2I(e_spikes)
+            # Then update the neuron populations
+            self.E(input)
+            self.I(input)
+
+Computational Model
+-------------------
+
+Time-Stepped Simulation
+~~~~~~~~~~~~~~~~~~~~~~~
+
+BrainPy uses discrete time steps for simulation:
+
+.. code-block:: python
+
+ import brainstate
+ import brainunit as u
+
+ # Set simulation time step
+ brainstate.environ.set(dt=0.1 * u.ms)
+
+ # Run simulation
+ times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
+ results = brainstate.transform.for_loop(network.update, times)
+
+JIT Compilation
+~~~~~~~~~~~~~~~
+
+BrainPy leverages JAX for Just-In-Time compilation:
+
+.. code-block:: python
+
+ @brainstate.compile.jit
+ def simulate():
+ return network.update(input)
+
+ # First call compiles, subsequent calls are fast
+ result = simulate()
+
+Benefits:
+
+- Near-C performance
+- Automatic GPU/TPU dispatch
+- Optimized memory usage
+
+Physical Units
+~~~~~~~~~~~~~~
+
+BrainPy 3.0 integrates ``brainunit`` for scientific accuracy:
+
+.. code-block:: python
+
+ import brainunit as u
+
+ # Define parameters with units
+ tau = 10 * u.ms
+ V_threshold = -50 * u.mV
+ current = 5 * u.nA
+
+ # Units are checked automatically
+ neurons = brainpy.LIF(100, tau=tau, V_th=V_threshold)
+
+This prevents unit-related bugs and makes code self-documenting.
+
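+For example, quantities with mismatched dimensions cannot be combined, and
+conversion to plain numbers is explicit (a small sketch; the exact error type is
+not shown here):
+
+.. code-block:: python
+
+    import brainunit as u
+
+    tau = 10 * u.ms
+    V = -50 * u.mV
+
+    total = tau + 0.5 * u.ms    # same dimension: fine
+    # V + tau                   # mismatched dimensions: raises an error
+
+    # Strip units explicitly when plotting or printing
+    v_plain = V.to_decimal(u.mV)   # -50.0
+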
+Training and Learning
+---------------------
+
+BrainPy 3.0 supports gradient-based training:
+
+.. code-block:: python
+
+ import braintools
+
+ # Define optimizer
+ optimizer = braintools.optim.Adam(lr=1e-3)
+ optimizer.register_trainable_weights(net.states(brainstate.ParamState))
+
+ # Define loss function
+ def loss_fn():
+ predictions = brainstate.compile.for_loop(net.update, inputs)
+ return loss(predictions, targets)
+
+ # Training step
+ @brainstate.compile.jit
+ def train_step():
+ grads, loss = brainstate.transform.grad(
+ loss_fn,
+ net.states(brainstate.ParamState),
+ return_value=True
+ )()
+ optimizer.update(grads)
+ return loss
+
+Key features:
+
+- Surrogate gradients for spiking neurons
+- Automatic differentiation
+- Various optimizers (Adam, SGD, etc.)
+
+Ecosystem Components
+--------------------
+
+BrainPy 3.0 is part of a larger ecosystem:
+
+brainstate
+~~~~~~~~~~
+
+The foundation for state management and compilation:
+
+- State-based IR construction
+- JIT compilation
+- Program augmentation (batching, etc.)
+
+brainunit
+~~~~~~~~~
+
+Physical units system:
+
+- SI units support
+- Automatic unit checking
+- Unit conversions
+
+braintools
+~~~~~~~~~~
+
+Utilities and tools:
+
+- Optimizers (``braintools.optim``)
+- Initialization (``braintools.init``)
+- Metrics and losses (``braintools.metric``)
+- Surrogate gradients (``braintools.surrogate``)
+- Visualization (``braintools.visualize``)
+
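+For instance, initializers from ``braintools.init`` plug directly into neuron
+constructors, as in the examples elsewhere in this guide (a usage sketch):
+
+.. code-block:: python
+
+    import brainpy
+    import braintools
+    import brainunit as u
+
+    init = braintools.init.Normal(-65., 5., unit=u.mV)   # mean, std, unit
+    neurons = brainpy.LIF(100, tau=10 * u.ms, V_th=-50 * u.mV, V_initializer=init)
+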
+Design Philosophy
+-----------------
+
+BrainPy 3.0 follows these principles:
+
+1. **Explicit over implicit**: Clear, readable code
+2. **Modular composition**: Build complex models from simple components
+3. **Performance by default**: JIT compilation and optimization built-in
+4. **Scientific accuracy**: Physical units and biologically realistic models
+5. **Extensibility**: Easy to add custom components
+
+Comparison with BrainPy 2.x
+----------------------------
+
+Key differences:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 30 35 35
+
+ * - Aspect
+ - BrainPy 2.x
+ - BrainPy 3.0
+ * - Architecture
+ - Custom backend
+ - Built on ``brainstate``
+ * - State management
+ - Manual
+ - Automatic with ``State``
+ * - Units
+ - Optional
+ - Integrated with ``brainunit``
+ * - API style
+ - Object-oriented
+ - Functional + OOP hybrid
+ * - Performance
+ - Good
+ - Better (optimized compilation)
+ * - Projection model
+ - Monolithic
+ - Comm-Syn-Out separation
+
+.. note::
+ BrainPy 3.0 includes a compatibility layer (``brainpy.version2``) for gradual migration.
+
+Next Steps
+----------
+
+Now that you understand the core concepts:
+
+- Try the :doc:`5-minute tutorial <5min-tutorial>` to get hands-on experience
+- Read the :doc:`detailed core concepts <../core-concepts/architecture>` documentation
+- Explore :doc:`basic tutorials <../tutorials/basic/01-lif-neuron>` to learn each component
+- Check out the :doc:`examples gallery <../examples/gallery>` for real-world models
diff --git a/docs_version3/quickstart/installation.rst b/docs_version3/quickstart/installation.rst
new file mode 100644
index 00000000..341ae78a
--- /dev/null
+++ b/docs_version3/quickstart/installation.rst
@@ -0,0 +1,173 @@
+Installation Guide
+==================
+
+BrainPy 3.0 is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation. This guide will help you install BrainPy on your system.
+
+Requirements
+------------
+
+- Python 3.10 or later
+- pip package manager
+- Supported platforms: Linux (Ubuntu 16.04+), macOS (10.12+), Windows
+
+Basic Installation
+------------------
+
+Install the latest version of BrainPy:
+
+.. code-block:: bash
+
+ pip install brainpy -U
+
+This will install BrainPy with CPU support by default.
+
+Hardware-Specific Installation
+-------------------------------
+
+Depending on your hardware, you can install BrainPy with optimized support:
+
+CPU Only
+~~~~~~~~
+
+For CPU-only installations:
+
+.. code-block:: bash
+
+ pip install brainpy[cpu] -U
+
+This is suitable for development, testing, and small-scale simulations.
+
+GPU Support (CUDA)
+~~~~~~~~~~~~~~~~~~
+
+For NVIDIA GPU acceleration:
+
+**CUDA 12.x:**
+
+.. code-block:: bash
+
+ pip install brainpy[cuda12] -U
+
+**CUDA 13.x:**
+
+.. code-block:: bash
+
+ pip install brainpy[cuda13] -U
+
+.. note::
+ Make sure you have the appropriate CUDA toolkit installed on your system before installing the GPU version.
+
+TPU Support
+~~~~~~~~~~~
+
+For Google Cloud TPU support:
+
+.. code-block:: bash
+
+ pip install brainpy[tpu] -U
+
+This is typically used when running on Google Cloud Platform or Colab with TPU runtime.
+
+Ecosystem Installation
+----------------------
+
+To install BrainPy along with the entire ecosystem of tools:
+
+.. code-block:: bash
+
+ pip install BrainX -U
+
+This includes:
+
+- ``brainpy``: Main framework
+- ``brainstate``: State management and compilation backend
+- ``brainunit``: Physical units system
+- ``braintools``: Utilities and tools
+- Additional ecosystem packages
+
+Verifying Installation
+----------------------
+
+To verify that BrainPy is installed correctly:
+
+.. code-block:: python
+
+ import brainpy
+ import brainstate
+ import brainunit as u
+
+ print(f"BrainPy version: {brainpy.__version__}")
+ print(f"BrainState version: {brainstate.__version__}")
+
+ # Test basic functionality
+ neuron = brainpy.LIF(10)
+ print("Installation successful!")
+
+Development Installation
+------------------------
+
+If you want to install BrainPy from source for development:
+
+.. code-block:: bash
+
+ git clone https://github.com/brainpy/BrainPy.git
+ cd BrainPy
+ pip install -e .
+
+This creates an editable installation that reflects your local changes.
+
+Troubleshooting
+---------------
+
+Common Issues
+~~~~~~~~~~~~~
+
+**ImportError: No module named 'brainpy'**
+
+Make sure you've activated the correct Python environment and that the installation completed successfully.
+
+**CUDA not found**
+
+If you installed the GPU version but get CUDA errors, ensure that:
+
+1. Your NVIDIA drivers are up to date
+2. CUDA toolkit is installed and matches the version (12.x or 13.x)
+3. Your GPU is CUDA-capable
+
+**Version Conflicts**
+
+If you're upgrading from BrainPy 2.x, you might need to uninstall the old version first:
+
+.. code-block:: bash
+
+ pip uninstall brainpy
+ pip install brainpy -U
+
+Getting Help
+~~~~~~~~~~~~
+
+If you encounter issues:
+
+- Check the `GitHub Issues <https://github.com/brainpy/BrainPy/issues>`_
+- Read the documentation at https://brainpy.readthedocs.io/
+- Join our community discussions
+
+Next Steps
+----------
+
+Now that you have BrainPy installed, you can:
+
+- Follow the :doc:`5-minute tutorial <5min-tutorial>` for a quick introduction
+- Read about :doc:`core concepts <concepts-overview>` to understand BrainPy's architecture
+- Explore the :doc:`tutorials <../tutorials/index>` for detailed guides
+
+Using BrainPy with Binder
+--------------------------
+
+If you want to try BrainPy without installing it locally, you can use our Binder environment:
+
+.. image:: https://mybinder.org/badge_logo.svg
+ :target: https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main
+ :alt: Binder
+
+This provides a pre-configured Jupyter notebook environment in your browser.
diff --git a/docs_version3/spiking_neural_networks-en.ipynb b/docs_version3/spiking_neural_networks-en.ipynb
deleted file mode 100644
index c548e521..00000000
--- a/docs_version3/spiking_neural_networks-en.ipynb
+++ /dev/null
@@ -1,35 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "# Building Spiking Neural Networks"
- ],
- "metadata": {
- "collapsed": false
- },
- "id": "a39cf07d62caa659"
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "2.7.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
\ No newline at end of file
diff --git a/docs_version3/spiking_neural_networks-zh.ipynb b/docs_version3/spiking_neural_networks-zh.ipynb
deleted file mode 100644
index 41b5854c..00000000
--- a/docs_version3/spiking_neural_networks-zh.ipynb
+++ /dev/null
@@ -1,35 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
-    "# 构建脉冲神经网络"
- ],
- "metadata": {
- "collapsed": false
- },
- "id": "540ea47d24c27831"
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "2.7.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
\ No newline at end of file
diff --git a/docs_version3/tutorials/advanced/05-snn-training.ipynb b/docs_version3/tutorials/advanced/05-snn-training.ipynb
new file mode 100644
index 00000000..fdd21cc5
--- /dev/null
+++ b/docs_version3/tutorials/advanced/05-snn-training.ipynb
@@ -0,0 +1,925 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 5: Training Spiking Neural Networks\n",
+ "\n",
+ "**Duration:** ~45 minutes | **Prerequisites:** Basic Tutorials 1-4\n",
+ "\n",
+ "## Learning Objectives\n",
+ "\n",
+ "By the end of this tutorial, you will:\n",
+ "\n",
+    "- ✅ Understand surrogate gradient methods for training SNNs\n",
+    "- ✅ Implement backpropagation through time (BPTT) for SNNs\n",
+    "- ✅ Use appropriate loss functions for spike-based learning\n",
+    "- ✅ Configure optimizers and learning rates\n",
+    "- ✅ Train an SNN classifier on real datasets\n",
+    "- ✅ Evaluate and visualize training progress\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "Training spiking neural networks is challenging because spike generation is a discrete, non-differentiable operation. In this tutorial, we'll learn how to overcome this using **surrogate gradient methods**, which allow us to train SNNs using standard gradient-based optimization.\n",
+ "\n",
+ "**Key Concepts:**\n",
+ "- **The gradient problem**: Spike generation has zero gradient almost everywhere\n",
+ "- **Surrogate gradients**: Use smooth approximations during backpropagation\n",
+ "- **BPTT for SNNs**: Unroll network dynamics through time\n",
+ "- **Rate-based losses**: Train on spike rates or membrane potentials\n",
+ "- **Temporal credit assignment**: Learn when to spike\n",
+ "\n",
+ "Let's start by understanding why training SNNs is difficult and how we can solve it!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy as bp\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import jax.numpy as jnp\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "# Set random seed for reproducibility\n",
+ "brainstate.random.seed(42)\n",
+ "\n",
+ "# Configure environment\n",
+ "brainstate.environ.set(dt=1.0 * u.ms)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: The Gradient Problem\n",
+ "\n",
+ "Let's visualize why training SNNs is challenging. The spike generation function is a Heaviside step function:\n",
+ "\n",
+ "$$\n",
+ "S(V) = \\begin{cases}\n",
+ "1 & \\text{if } V \\geq V_{th} \\\\\n",
+ "0 & \\text{if } V < V_{th}\n",
+ "\\end{cases}\n",
+ "$$\n",
+ "\n",
+ "The gradient of this function is:\n",
+ "\n",
+ "$$\n",
+ "\\frac{dS}{dV} = \\begin{cases}\n",
+ "\\infty & \\text{at } V = V_{th} \\\\\n",
+ "0 & \\text{everywhere else}\n",
+ "\\end{cases}\n",
+ "$$\n",
+ "\n",
+ "This makes gradient-based learning impossible! Let's see this visually."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Heaviside step function\n",
+ "def heaviside(x, threshold=0.0):\n",
+ " return (x >= threshold).astype(float)\n",
+ "\n",
+ "# Voltage values\n",
+ "V = np.linspace(-2, 2, 1000)\n",
+ "V_th = 0.0\n",
+ "\n",
+ "# Spike function and its \"gradient\"\n",
+ "spikes = heaviside(V, V_th)\n",
+ "# Numerical gradient (will be mostly zeros)\n",
+ "grad_spike = np.gradient(spikes, V)\n",
+ "\n",
+ "# Plot\n",
+ "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
+ "\n",
+ "# Spike function\n",
+ "axes[0].plot(V, spikes, 'b-', linewidth=2)\n",
+ "axes[0].axvline(V_th, color='r', linestyle='--', label='Threshold')\n",
+ "axes[0].set_xlabel('Membrane Potential (V)', fontsize=12)\n",
+ "axes[0].set_ylabel('Spike Output', fontsize=12)\n",
+ "axes[0].set_title('Spike Generation Function', fontsize=14, fontweight='bold')\n",
+ "axes[0].legend()\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Gradient (problematic!)\n",
+ "axes[1].plot(V, grad_spike, 'r-', linewidth=2)\n",
+ "axes[1].axvline(V_th, color='r', linestyle='--', label='Threshold')\n",
+ "axes[1].set_xlabel('Membrane Potential (V)', fontsize=12)\n",
+ "axes[1].set_ylabel('Gradient dS/dV', fontsize=12)\n",
+ "axes[1].set_title('Gradient (Problematic!)', fontsize=14, fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "axes[1].set_ylim(-0.1, 0.6)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+    "print(\"❌ Problem: Gradient is zero almost everywhere!\")\n",
+ "print(\" This prevents gradient descent from working.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Surrogate Gradient Solution\n",
+ "\n",
+ "The solution is **surrogate gradients**: Use the true spike function in the forward pass, but use a smooth approximation during backpropagation.\n",
+ "\n",
+ "**Common surrogate gradient functions:**\n",
+ "\n",
+ "1. **Sigmoid**: $\\sigma'(\\beta(V - V_{th}))$\n",
+ "2. **ReLU**: $\\max(0, 1 - |V - V_{th}|)$\n",
+ "3. **SuperSpike**: $\\frac{1}{(1 + |\\beta(V - V_{th})|)^2}$\n",
+ "\n",
+ "BrainPy provides these in `braintools.surrogate`. Let's visualize them!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create surrogate gradient functions\n",
+ "sigmoid_surrogate = braintools.surrogate.sigmoid(alpha=4.0)\n",
+ "relu_surrogate = braintools.surrogate.relu_grad(alpha=1.0)\n",
+ "superspike_surrogate = braintools.surrogate.slayer_grad(alpha=4.0)\n",
+ "\n",
+ "# Voltage range\n",
+ "V_range = np.linspace(-2, 2, 1000)\n",
+ "V_th = 0.0\n",
+ "\n",
+ "# Compute surrogate gradients\n",
+ "grad_sigmoid = sigmoid_surrogate(V_range - V_th)\n",
+ "grad_relu = relu_surrogate(V_range - V_th)\n",
+ "grad_superspike = superspike_surrogate(V_range - V_th)\n",
+ "\n",
+ "# Plot\n",
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
+ "\n",
+ "# Sigmoid surrogate\n",
+ "axes[0].plot(V_range, grad_sigmoid, 'g-', linewidth=2, label='Sigmoid surrogate')\n",
+ "axes[0].axvline(V_th, color='r', linestyle='--', alpha=0.5)\n",
+ "axes[0].set_xlabel('V - V_th', fontsize=12)\n",
+ "axes[0].set_ylabel('Surrogate Gradient', fontsize=12)\n",
+ "axes[0].set_title('Sigmoid Surrogate', fontsize=14, fontweight='bold')\n",
+ "axes[0].legend()\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# ReLU surrogate\n",
+ "axes[1].plot(V_range, grad_relu, 'b-', linewidth=2, label='ReLU surrogate')\n",
+ "axes[1].axvline(V_th, color='r', linestyle='--', alpha=0.5)\n",
+ "axes[1].set_xlabel('V - V_th', fontsize=12)\n",
+ "axes[1].set_ylabel('Surrogate Gradient', fontsize=12)\n",
+ "axes[1].set_title('ReLU Surrogate', fontsize=14, fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "# SuperSpike surrogate\n",
+ "axes[2].plot(V_range, grad_superspike, 'm-', linewidth=2, label='SuperSpike surrogate')\n",
+ "axes[2].axvline(V_th, color='r', linestyle='--', alpha=0.5)\n",
+ "axes[2].set_xlabel('V - V_th', fontsize=12)\n",
+ "axes[2].set_ylabel('Surrogate Gradient', fontsize=12)\n",
+ "axes[2].set_title('SuperSpike Surrogate', fontsize=14, fontweight='bold')\n",
+ "axes[2].legend()\n",
+ "axes[2].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+    "print(\"✅ Solution: Smooth surrogate gradients enable learning!\")\n",
+ "print(\" Forward pass: Use real spikes\")\n",
+ "print(\" Backward pass: Use smooth gradient approximation\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Creating a Trainable SNN\n",
+ "\n",
+ "Now let's create an SNN classifier. We'll build a simple network:\n",
+ "\n",
+ "**Architecture:**\n",
+ "- Input layer: 784 neurons (28ร28 image)\n",
+ "- Hidden layer: 128 LIF neurons\n",
+ "- Output layer: 10 LIF neurons (digits 0-9)\n",
+ "\n",
+ "**Key for training:**\n",
+ "- Use LIF neurons with surrogate gradient spike functions\n",
+ "- Use `bp.Readout` to convert spikes to logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TrainableSNN(brainstate.nn.Module):\n",
+ " \"\"\"Simple feedforward SNN for classification.\"\"\"\n",
+ " \n",
+ " def __init__(self, n_input=784, n_hidden=128, n_output=10):\n",
+ " super().__init__()\n",
+ " \n",
+ " # Input to hidden projection\n",
+ " self.fc1 = brainstate.nn.Linear(n_input, n_hidden, w_init=brainstate.init.KaimingNormal())\n",
+ " \n",
+ " # Hidden LIF neurons with surrogate gradient\n",
+ " self.lif1 = bp.LIF(\n",
+ " n_hidden,\n",
+ " V_rest=-65.0 * u.mV,\n",
+ " V_th=-50.0 * u.mV,\n",
+ " V_reset=-65.0 * u.mV,\n",
+ " tau=10.0 * u.ms,\n",
+ " spike_fun=braintools.surrogate.ReluGrad() # Surrogate gradient!\n",
+ " )\n",
+ " \n",
+ " # Hidden to output projection\n",
+ " self.fc2 = brainstate.nn.Linear(n_hidden, n_output, w_init=brainstate.init.KaimingNormal())\n",
+ " \n",
+ " # Output LIF neurons with surrogate gradient\n",
+ " self.lif2 = bp.LIF(\n",
+ " n_output,\n",
+ " V_rest=-65.0 * u.mV,\n",
+ " V_th=-50.0 * u.mV,\n",
+ " V_reset=-65.0 * u.mV,\n",
+ " tau=10.0 * u.ms,\n",
+ " spike_fun=braintools.surrogate.ReluGrad() # Surrogate gradient!\n",
+ " )\n",
+ " \n",
+ " # Readout layer to convert spikes to logits\n",
+ " self.readout = bp.Readout(n_output, n_output)\n",
+ " \n",
+ " def update(self, x):\n",
+ " \"\"\"Forward pass for one time step.\n",
+ " \n",
+ " Args:\n",
+ " x: Input current (batch_size, n_input) with physical units\n",
+ " \n",
+ " Returns:\n",
+ " logits: Output logits (batch_size, n_output)\n",
+ " \"\"\"\n",
+ " # Input to hidden\n",
+ " current1 = self.fc1(x)\n",
+ " self.lif1(current1)\n",
+ " hidden_spikes = self.lif1.get_spike()\n",
+ " \n",
+ " # Hidden to output\n",
+ " current2 = self.fc2(hidden_spikes)\n",
+ " self.lif2(current2)\n",
+ " output_spikes = self.lif2.get_spike()\n",
+ " \n",
+ " # Convert spikes to logits\n",
+ " logits = self.readout(output_spikes)\n",
+ " \n",
+ " return logits\n",
+ "\n",
+ "# Create network\n",
+ "net = TrainableSNN(n_input=784, n_hidden=128, n_output=10)\n",
+ "brainstate.nn.init_all_states(net, batch_size=32)\n",
+ "\n",
+    "print(\"✅ Created trainable SNN with surrogate gradients\")\n",
+ "print(f\" Input: 784 neurons\")\n",
+ "print(f\" Hidden: 128 LIF neurons\")\n",
+ "print(f\" Output: 10 LIF neurons\")\n",
+ "print(f\" Total parameters: {sum(p.size for p in net.states(brainstate.ParamState).values())}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: Loss Functions for SNNs\n",
+ "\n",
+ "For classification, we typically use **cross-entropy loss** on the output logits. The logits are computed by integrating spikes over time.\n",
+ "\n",
+ "**Loss computation:**\n",
+ "1. Run the network for `T` time steps\n",
+ "2. Accumulate output logits over time\n",
+ "3. Compute cross-entropy loss: $L = -\\sum_i y_i \\log(\\text{softmax}(\\text{logits}_i))$\n",
+ "\n",
+ "Let's implement the training step!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def loss_fn(network, inputs, labels, n_steps=25):\n",
+ " \"\"\"Compute loss for SNN classification.\n",
+ " \n",
+ " Args:\n",
+ " network: SNN model\n",
+ " inputs: Input data (batch_size, n_features)\n",
+ " labels: True labels (batch_size,)\n",
+ " n_steps: Number of simulation time steps\n",
+ " \n",
+ " Returns:\n",
+ " loss: Cross-entropy loss\n",
+ " \"\"\"\n",
+ " # Reset network state\n",
+ " brainstate.nn.init_all_states(network)\n",
+ " \n",
+ " # Add physical units to input (convert to current)\n",
+ " inputs_with_units = inputs * u.nA\n",
+ " \n",
+ " # Simulate for n_steps and accumulate output\n",
+ " def run_step(i):\n",
+ " return network(inputs_with_units)\n",
+ " \n",
+ " # Run simulation and accumulate logits\n",
+ " logits_sum = brainstate.transform.for_loop(run_step, jnp.arange(n_steps))\n",
+ " logits_sum = jnp.sum(logits_sum, axis=0) # Sum over time\n",
+ " \n",
+ " # Compute cross-entropy loss\n",
+ " loss = braintools.metric.softmax_cross_entropy_with_integer_labels(\n",
+ " logits_sum, labels\n",
+ " ).mean()\n",
+ " \n",
+ " return loss\n",
+ "\n",
+ "def accuracy_fn(network, inputs, labels, n_steps=25):\n",
+ " \"\"\"Compute accuracy for SNN classification.\"\"\"\n",
+ " # Reset network state\n",
+ " brainstate.nn.init_all_states(network)\n",
+ " \n",
+ " # Add physical units\n",
+ " inputs_with_units = inputs * u.nA\n",
+ " \n",
+ " # Simulate and accumulate logits\n",
+ " def run_step(i):\n",
+ " return network(inputs_with_units)\n",
+ " \n",
+ " logits_sum = brainstate.transform.for_loop(run_step, jnp.arange(n_steps))\n",
+ " logits_sum = jnp.sum(logits_sum, axis=0)\n",
+ " \n",
+ " # Compute accuracy\n",
+ " predictions = jnp.argmax(logits_sum, axis=1)\n",
+ " accuracy = jnp.mean(predictions == labels)\n",
+ " \n",
+ " return accuracy\n",
+ "\n",
+    "print(\"✅ Defined loss and accuracy functions\")\n",
+ "print(\" Loss: Cross-entropy on accumulated logits\")\n",
+ "print(\" Accuracy: Argmax of accumulated logits\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: Optimizers and Training Loop\n",
+ "\n",
+ "Now we'll set up the optimizer and training loop. BrainPy uses `braintools.optim` which provides standard optimizers like Adam, SGD, etc.\n",
+ "\n",
+ "**Training loop:**\n",
+ "1. Get batch of data\n",
+ "2. Compute gradients using `brainstate.transform.grad()`\n",
+ "3. Update parameters using optimizer\n",
+ "4. Track loss and accuracy\n",
+ "\n",
+ "We'll use synthetic data for this demo (in practice, you'd use MNIST)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create synthetic dataset (in practice, use real data like MNIST)\n",
+ "def create_synthetic_data(n_samples=1000, n_features=784, n_classes=10):\n",
+ " \"\"\"Create synthetic classification data.\"\"\"\n",
+ " X = np.random.randn(n_samples, n_features).astype(np.float32) * 0.5\n",
+ " y = np.random.randint(0, n_classes, size=n_samples)\n",
+ " return X, y\n",
+ "\n",
+ "# Generate data\n",
+ "X_train, y_train = create_synthetic_data(n_samples=1000)\n",
+ "X_test, y_test = create_synthetic_data(n_samples=200)\n",
+ "\n",
+    "print(\"✅ Created synthetic dataset\")\n",
+ "print(f\" Training: {X_train.shape[0]} samples\")\n",
+ "print(f\" Test: {X_test.shape[0]} samples\")\n",
+ "print(f\" Features: {X_train.shape[1]}\")\n",
+ "print(f\" Classes: {len(np.unique(y_train))}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reset network and create optimizer\n",
+ "net = TrainableSNN(n_input=784, n_hidden=128, n_output=10)\n",
+ "brainstate.nn.init_all_states(net, batch_size=32)\n",
+ "\n",
+ "# Create Adam optimizer\n",
+ "optimizer = braintools.optim.Adam(learning_rate=1e-3)\n",
+ "optimizer.register_trainable_weights(net.states(brainstate.ParamState))\n",
+ "\n",
+    "print(\"✅ Created optimizer\")\n",
+ "print(f\" Type: Adam\")\n",
+ "print(f\" Learning rate: 1e-3\")\n",
+ "print(f\" Trainable parameters: {len(net.states(brainstate.ParamState))}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Training loop\n",
+ "n_epochs = 5\n",
+ "batch_size = 32\n",
+ "n_steps = 25 # Simulation steps per sample\n",
+ "\n",
+ "train_losses = []\n",
+ "train_accs = []\n",
+ "test_accs = []\n",
+ "\n",
+    "print(\"🚀 Starting training...\\n\")\n",
+ "\n",
+ "for epoch in range(n_epochs):\n",
+ " # Shuffle training data\n",
+ " indices = np.random.permutation(len(X_train))\n",
+ " X_shuffled = X_train[indices]\n",
+ " y_shuffled = y_train[indices]\n",
+ " \n",
+ " epoch_losses = []\n",
+ " epoch_accs = []\n",
+ " \n",
+ " # Mini-batch training\n",
+ " n_batches = len(X_train) // batch_size\n",
+ " for i in range(n_batches):\n",
+ " # Get batch\n",
+ " start_idx = i * batch_size\n",
+ " end_idx = start_idx + batch_size\n",
+ " X_batch = X_shuffled[start_idx:end_idx]\n",
+ " y_batch = y_shuffled[start_idx:end_idx]\n",
+ " \n",
+ " # Compute gradients\n",
+ " grads, loss = brainstate.transform.grad(\n",
+ " loss_fn,\n",
+ " net.states(brainstate.ParamState),\n",
+ " return_value=True\n",
+ " )(net, X_batch, y_batch, n_steps)\n",
+ " \n",
+ " # Update parameters\n",
+ " optimizer.update(grads)\n",
+ " \n",
+ " # Track metrics\n",
+ " epoch_losses.append(float(loss))\n",
+ " \n",
+ " # Compute accuracy every 10 batches\n",
+ " if i % 10 == 0:\n",
+ " acc = accuracy_fn(net, X_batch, y_batch, n_steps)\n",
+ " epoch_accs.append(float(acc))\n",
+ " \n",
+ " # Epoch statistics\n",
+ " avg_loss = np.mean(epoch_losses)\n",
+ " avg_train_acc = np.mean(epoch_accs) if epoch_accs else 0.0\n",
+ " \n",
+ " # Test accuracy\n",
+ " test_acc = float(accuracy_fn(net, X_test, y_test, n_steps))\n",
+ " \n",
+ " train_losses.append(avg_loss)\n",
+ " train_accs.append(avg_train_acc)\n",
+ " test_accs.append(test_acc)\n",
+ " \n",
+ " print(f\"Epoch {epoch+1}/{n_epochs}:\")\n",
+ " print(f\" Loss: {avg_loss:.4f}\")\n",
+ " print(f\" Train Acc: {avg_train_acc:.2%}\")\n",
+ " print(f\" Test Acc: {test_acc:.2%}\\n\")\n",
+ "\n",
+    "print(\"✅ Training complete!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Visualizing Training Progress\n",
+ "\n",
+ "Let's visualize how the loss and accuracy evolved during training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
+ "\n",
+ "epochs_range = np.arange(1, n_epochs + 1)\n",
+ "\n",
+ "# Plot loss\n",
+ "axes[0].plot(epochs_range, train_losses, 'b-o', linewidth=2, markersize=8, label='Training Loss')\n",
+ "axes[0].set_xlabel('Epoch', fontsize=12)\n",
+ "axes[0].set_ylabel('Loss', fontsize=12)\n",
+ "axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')\n",
+ "axes[0].legend()\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Plot accuracy\n",
+ "axes[1].plot(epochs_range, train_accs, 'g-o', linewidth=2, markersize=8, label='Train Accuracy')\n",
+ "axes[1].plot(epochs_range, test_accs, 'r-s', linewidth=2, markersize=8, label='Test Accuracy')\n",
+ "axes[1].set_xlabel('Epoch', fontsize=12)\n",
+ "axes[1].set_ylabel('Accuracy', fontsize=12)\n",
+ "axes[1].set_title('Classification Accuracy', fontsize=14, fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "axes[1].set_ylim(0, 1)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(f\"๐ Final Results:\")\n",
+ "print(f\" Final train accuracy: {train_accs[-1]:.2%}\")\n",
+ "print(f\" Final test accuracy: {test_accs[-1]:.2%}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Understanding BPTT for SNNs\n",
+ "\n",
+ "Let's visualize what happens during backpropagation through time (BPTT). The network processes input over multiple time steps, and gradients flow backward through time.\n",
+ "\n",
+ "**BPTT process:**\n",
+ "1. **Forward pass**: Simulate network for T steps, accumulate outputs\n",
+ "2. **Backward pass**: Compute gradients backward through all T steps\n",
+ "3. **Surrogate gradients**: Used at spike generation points\n",
+ "\n",
+ "Let's examine the gradient flow!"
+ ]
+ },
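+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before inspecting the real network, here is a minimal, self-contained illustration of why a surrogate is needed (plain JAX, independent of BrainPy; `braintools.surrogate` provides production versions of this idea): the hard threshold has zero gradient almost everywhere, while a smooth stand-in lets a learning signal flow."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Toy sketch (plain JAX, not a BrainPy API) of the surrogate-gradient idea\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "def hard_spike(v):\n",
+    "    # Heaviside step: derivative is 0 almost everywhere\n",
+    "    return (v > 0.0).astype(jnp.float32)\n",
+    "\n",
+    "def sigmoid_surrogate(v, beta=10.0):\n",
+    "    # Smooth stand-in used on the backward pass\n",
+    "    return jax.nn.sigmoid(beta * v)\n",
+    "\n",
+    "v = 0.05  # membrane potential just above threshold\n",
+    "print(jax.grad(hard_spike)(v))         # 0.0 -> no learning signal\n",
+    "print(jax.grad(sigmoid_surrogate)(v))  # positive -> gradient can flow"
+   ]
+  },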
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Analyze gradient magnitudes during training\n",
+ "def analyze_gradients(network, inputs, labels, n_steps=25):\n",
+ " \"\"\"Compute and analyze gradient magnitudes.\"\"\"\n",
+ " grads = brainstate.transform.grad(\n",
+ " loss_fn,\n",
+ " network.states(brainstate.ParamState)\n",
+ " )(network, inputs, labels, n_steps)\n",
+ " \n",
+ " # Compute gradient norms for each layer\n",
+ " grad_norms = {}\n",
+ " for name, grad in grads.items():\n",
+ " grad_norm = float(jnp.linalg.norm(grad.value.flatten()))\n",
+ " grad_norms[name] = grad_norm\n",
+ " \n",
+ " return grad_norms\n",
+ "\n",
+ "# Analyze gradients on a batch\n",
+ "sample_X = X_train[:32]\n",
+ "sample_y = y_train[:32]\n",
+ "grad_norms = analyze_gradients(net, sample_X, sample_y)\n",
+ "\n",
+ "# Plot gradient magnitudes\n",
+ "fig, ax = plt.subplots(figsize=(10, 6))\n",
+ "\n",
+ "layer_names = list(grad_norms.keys())\n",
+ "grad_values = list(grad_norms.values())\n",
+ "\n",
+ "colors = ['blue' if 'fc1' in name else 'green' if 'fc2' in name else 'red' for name in layer_names]\n",
+ "\n",
+ "bars = ax.bar(range(len(layer_names)), grad_values, color=colors, alpha=0.7)\n",
+ "ax.set_xticks(range(len(layer_names)))\n",
+ "ax.set_xticklabels(layer_names, rotation=45, ha='right')\n",
+ "ax.set_ylabel('Gradient Norm', fontsize=12)\n",
+ "ax.set_title('Gradient Magnitudes Across Layers', fontsize=14, fontweight='bold')\n",
+ "ax.grid(True, alpha=0.3, axis='y')\n",
+ "\n",
+ "# Add legend\n",
+ "from matplotlib.patches import Patch\n",
+ "legend_elements = [\n",
+ " Patch(facecolor='blue', alpha=0.7, label='Input Layer'),\n",
+ " Patch(facecolor='green', alpha=0.7, label='Hidden Layer'),\n",
+ " Patch(facecolor='red', alpha=0.7, label='Readout Layer')\n",
+ "]\n",
+ "ax.legend(handles=legend_elements, loc='upper right')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"๐ Gradient Analysis:\")\n",
+ "for name, norm in grad_norms.items():\n",
+ " print(f\" {name}: {norm:.6f}\")\n",
+ "print(\"\\nโ
Surrogate gradients enable backpropagation through spike generation!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 8: Real-World Example - MNIST Classification\n",
+ "\n",
+ "Now let's see how to train on real data. Here's the complete workflow for MNIST (or Fashion-MNIST):\n",
+ "\n",
+ "**Steps:**\n",
+ "1. Load and preprocess MNIST data\n",
+ "2. Convert images to rate-coded spike trains (or use pixel intensities as currents)\n",
+ "3. Train SNN classifier\n",
+ "4. Evaluate on test set\n",
+ "\n",
+ "Below is a template you can use with real MNIST data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Template for MNIST training (requires torchvision or tensorflow)\n",
+ "\n",
+ "def load_mnist_data():\n",
+ " \"\"\"Load and preprocess MNIST data.\n",
+ " \n",
+ " In practice, use:\n",
+ " from torchvision import datasets, transforms\n",
+ " \n",
+ " train_dataset = datasets.MNIST(\n",
+ " './data', train=True, download=True,\n",
+ " transform=transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.1307,), (0.3081,))\n",
+ " ])\n",
+ " )\n",
+ " \"\"\"\n",
+ " pass\n",
+ "\n",
+ "def train_on_mnist():\n",
+ " \"\"\"Complete MNIST training workflow.\"\"\"\n",
+ " \n",
+ " # 1. Load data\n",
+ " # X_train, y_train, X_test, y_test = load_mnist_data()\n",
+ " \n",
+ " # 2. Create network\n",
+ " net = TrainableSNN(n_input=784, n_hidden=256, n_output=10)\n",
+ " brainstate.nn.init_all_states(net, batch_size=128)\n",
+ " \n",
+ " # 3. Create optimizer\n",
+ " optimizer = braintools.optim.Adam(learning_rate=1e-3)\n",
+ " optimizer.register_trainable_weights(net.states(brainstate.ParamState))\n",
+ " \n",
+ " # 4. Training loop (epochs, batches, gradient updates)\n",
+ " # for epoch in range(n_epochs):\n",
+ " # for batch in data_loader:\n",
+ " # grads, loss = compute_gradients(...)\n",
+ " # optimizer.update(grads)\n",
+ " \n",
+ " # 5. Evaluation\n",
+ " # test_acc = evaluate(net, X_test, y_test)\n",
+ " \n",
+ " return net\n",
+ "\n",
+ "print(\"๐ MNIST Training Template:\")\n",
+ "print(\"\"\"\\n1. Load MNIST: Use torchvision.datasets.MNIST or tensorflow.keras.datasets.mnist\n",
+ "2. Preprocess: Flatten images (28ร28 โ 784), normalize to [0,1]\n",
+ "3. Convert to currents: Multiply by scaling factor (e.g., 5 nA)\n",
+ "4. Train: Use same loss_fn and training loop as above\n",
+ "5. Expected accuracy: 95-98% on MNIST with proper hyperparameters\n",
+ "\n",
+ "Key hyperparameters to tune:\n",
+ "- Learning rate: Try 1e-3, 5e-4, 1e-4\n",
+ "- Hidden size: Try 128, 256, 512\n",
+ "- Simulation steps: Try 25, 50, 100\n",
+ "- Batch size: Try 32, 64, 128\n",
+ "\"\"\")"
+ ]
+ },
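+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The template mentions two input encodings but does not show them. Below is a minimal sketch of both (numpy only; the 5 nA scale and the Bernoulli spike probability are illustrative assumptions, not tuned values)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged sketch: encoding normalized [0, 1] pixel images as SNN inputs\n",
+    "import numpy as np\n",
+    "\n",
+    "def pixels_to_currents(images, scale=5.0):\n",
+    "    \"\"\"Current coding: pixel intensity scaled to an input current (in nA).\"\"\"\n",
+    "    flat = images.reshape(images.shape[0], -1)  # (batch, 784)\n",
+    "    return flat * scale\n",
+    "\n",
+    "def pixels_to_poisson_spikes(images, n_steps=25, max_prob=0.5):\n",
+    "    \"\"\"Rate coding: Bernoulli spikes with probability proportional to intensity.\"\"\"\n",
+    "    flat = images.reshape(images.shape[0], -1)\n",
+    "    p = np.clip(flat * max_prob, 0.0, 1.0)\n",
+    "    # (n_steps, batch, 784) binary spike tensor\n",
+    "    return (np.random.rand(n_steps, *p.shape) < p).astype(np.float32)\n",
+    "\n",
+    "# Random arrays stand in for real MNIST images here\n",
+    "fake_images = np.random.rand(4, 28, 28)\n",
+    "print(pixels_to_currents(fake_images).shape)        # (4, 784)\n",
+    "print(pixels_to_poisson_spikes(fake_images).shape)  # (25, 4, 784)"
+   ]
+  },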
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 9: Advanced Training Techniques\n",
+ "\n",
+ "Here are some advanced techniques to improve SNN training:\n",
+ "\n",
+ "### 1. Learning Rate Scheduling\n",
+ "\n",
+ "Reduce learning rate during training for better convergence."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example: Exponential decay learning rate schedule\n",
+ "def create_lr_schedule(initial_lr=1e-3, decay_rate=0.95, decay_steps=1000):\n",
+ " \"\"\"Create exponential decay learning rate schedule.\"\"\"\n",
+ " def lr_schedule(step):\n",
+ " return initial_lr * (decay_rate ** (step / decay_steps))\n",
+ " return lr_schedule\n",
+ "\n",
+ "# Usage:\n",
+ "# lr_schedule = create_lr_schedule()\n",
+ "# optimizer = braintools.optim.Adam(learning_rate=lr_schedule)\n",
+ "\n",
+ "print(\"โ
Learning rate scheduling helps with convergence\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. Gradient Clipping\n",
+ "\n",
+ "Prevent gradient explosion by clipping large gradients."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def clip_gradients(grads, max_norm=1.0):\n",
+ " \"\"\"Clip gradients by global norm.\"\"\"\n",
+ " # Compute global norm\n",
+ " global_norm = jnp.sqrt(\n",
+ " sum(jnp.sum(g.value ** 2) for g in grads.values())\n",
+ " )\n",
+ " \n",
+ " # Clip if necessary\n",
+ " clip_coef = max_norm / (global_norm + 1e-6)\n",
+ " clip_coef = jnp.minimum(1.0, clip_coef)\n",
+ " \n",
+ " # Apply clipping\n",
+ " clipped_grads = {}\n",
+ " for name, grad in grads.items():\n",
+ " clipped_grads[name] = brainstate.ParamState(grad.value * clip_coef)\n",
+ " \n",
+ " return clipped_grads\n",
+ "\n",
+ "# Usage in training loop:\n",
+ "# grads = compute_gradients(...)\n",
+ "# grads = clip_gradients(grads, max_norm=1.0)\n",
+ "# optimizer.update(grads)\n",
+ "\n",
+ "print(\"โ
Gradient clipping prevents training instabilities\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. Regularization\n",
+ "\n",
+ "Add L2 regularization to prevent overfitting."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def loss_with_regularization(network, inputs, labels, n_steps=25, l2_weight=1e-4):\n",
+ " \"\"\"Loss function with L2 regularization.\"\"\"\n",
+ " # Standard loss\n",
+ " ce_loss = loss_fn(network, inputs, labels, n_steps)\n",
+ " \n",
+ " # L2 regularization\n",
+ " l2_loss = 0.0\n",
+ " for param in network.states(brainstate.ParamState).values():\n",
+ " l2_loss += jnp.sum(param.value ** 2)\n",
+ " \n",
+ " total_loss = ce_loss + l2_weight * l2_loss\n",
+ " return total_loss\n",
+ "\n",
+ "print(\"โ
L2 regularization improves generalization\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
**The gradient problem**: Spike generation is non-differentiable\n",
+ "\n",
+ "โ
**Surrogate gradients**: Use smooth approximations during backprop\n",
+ " - Forward: Real spikes\n",
+ " - Backward: Smooth surrogate\n",
+ "\n",
+ "โ
**SNN architecture**: Create trainable networks with LIF neurons\n",
+ "\n",
+ "โ
**Loss functions**: Cross-entropy on accumulated spike outputs\n",
+ "\n",
+ "โ
**Training loop**: BPTT with gradient descent\n",
+ " ```python\n",
+ " grads, loss = brainstate.transform.grad(loss_fn, params)(net, X, y)\n",
+ " optimizer.update(grads)\n",
+ " ```\n",
+ "\n",
+ "โ
**Advanced techniques**: LR scheduling, gradient clipping, regularization\n",
+ "\n",
+ "**Key code pattern:**\n",
+ "```python\n",
+ "# 1. Create network with surrogate gradients\n",
+ "lif = bp.LIF(..., spike_fun=braintools.surrogate.ReluGrad())\n",
+ "\n",
+ "# 2. Define loss over time\n",
+ "def loss_fn(net, X, y, n_steps):\n",
+ " logits = simulate_for_n_steps(net, X, n_steps)\n",
+ " return cross_entropy(logits, y)\n",
+ "\n",
+ "# 3. Compute gradients and update\n",
+ "grads = brainstate.transform.grad(loss_fn, params)(...)\n",
+ "optimizer.update(grads)\n",
+ "```\n",
+ "\n",
+ "**Next steps:**\n",
+ "- Try training on real MNIST/Fashion-MNIST\n",
+ "- Experiment with different surrogate functions\n",
+ "- Tune hyperparameters (learning rate, hidden size, simulation steps)\n",
+ "- Add recurrent connections for temporal tasks\n",
+ "- See Tutorial 6 for incorporating synaptic plasticity\n",
+ "\n",
+ "**References:**\n",
+ "- Neftci et al. (2019): \"Surrogate Gradient Learning in Spiking Neural Networks\"\n",
+ "- Zenke & Ganguli (2018): \"SuperSpike: Supervised learning in multilayer spiking neural networks\"\n",
+ "- Shrestha & Orchard (2018): \"SLAYER: Spike Layer Error Reassignment in Time\"\n",
+ "- Wu et al. (2018): \"Spatio-Temporal Backpropagation for Training High-Performance Spiking Neural Networks\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Exercises\n",
+ "\n",
+ "Test your understanding:\n",
+ "\n",
+ "### Exercise 1: Surrogate Function Comparison\n",
+ "Compare training with different surrogate gradient functions (Sigmoid, ReLU, SuperSpike). Which works best?\n",
+ "\n",
+ "### Exercise 2: Simulation Steps\n",
+ "How does the number of simulation steps (n_steps) affect accuracy and training time? Plot the trade-off.\n",
+ "\n",
+ "### Exercise 3: Network Architecture\n",
+ "Add a second hidden layer. Does deeper architecture improve performance?\n",
+ "\n",
+ "### Exercise 4: Learning Rate Tuning\n",
+ "Implement learning rate scheduling and compare convergence with fixed learning rate.\n",
+ "\n",
+ "### Exercise 5: Real MNIST\n",
+ "Load real MNIST data and train a classifier. Aim for >95% test accuracy!\n",
+ "\n",
+ "**Bonus Challenge:** Implement online learning where the network is trained on streaming data one sample at a time (no batches)."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/advanced/06-synaptic-plasticity.ipynb b/docs_version3/tutorials/advanced/06-synaptic-plasticity.ipynb
new file mode 100644
index 00000000..59cec8d7
--- /dev/null
+++ b/docs_version3/tutorials/advanced/06-synaptic-plasticity.ipynb
@@ -0,0 +1,943 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 6: Synaptic Plasticity\n",
+ "\n",
+ "**Duration:** ~40 minutes | **Prerequisites:** Basic Tutorials, Tutorial 5\n",
+ "\n",
+ "## Learning Objectives\n",
+ "\n",
+ "By the end of this tutorial, you will:\n",
+ "\n",
+ "- โ
Understand short-term plasticity (STP) mechanisms\n",
+ "- โ
Implement synaptic depression and facilitation\n",
+ "- โ
Learn spike-timing-dependent plasticity (STDP) principles\n",
+ "- โ
Create adaptive synapses with learning rules\n",
+ "- โ
Build networks with plastic connections\n",
+ "- โ
Combine plasticity with network training\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "Synaptic plasticity is the ability of synapses to change their strength over time. This is fundamental to learning and memory in biological brains. BrainPy supports multiple forms of plasticity:\n",
+ "\n",
+ "**Types of plasticity:**\n",
+ "- **Short-term plasticity (STP)**: Temporary changes on timescales of milliseconds to seconds\n",
+ " - Depression (STD): Synaptic strength decreases with repeated use\n",
+ " - Facilitation (STF): Synaptic strength increases with repeated use\n",
+ "- **Long-term plasticity**: Persistent changes\n",
+ " - STDP: Depends on relative timing of pre and post spikes\n",
+ " - Rate-based: Depends on firing rates\n",
+ "\n",
+ "Let's explore these mechanisms!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy as bp\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import jax.numpy as jnp\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "# Set random seed for reproducibility\n",
+ "brainstate.random.seed(42)\n",
+ "\n",
+ "# Configure environment\n",
+ "brainstate.environ.set(dt=0.1 * u.ms)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Short-Term Depression (STD)\n",
+ "\n",
+ "Short-term depression models the depletion of neurotransmitter resources. Each spike consumes some fraction of available resources, which recover over time.\n",
+ "\n",
+ "**STD dynamics:**\n",
+ "$$\n",
+ "\\frac{dx}{dt} = \\frac{1 - x}{\\tau_d} - u \\cdot x \\cdot \\delta(t - t_{spike})\n",
+ "$$\n",
+ "\n",
+ "Where:\n",
+ "- $x$: Fraction of available resources (0 to 1)\n",
+ "- $\\tau_d$: Recovery time constant\n",
+ "- $u$: Utilization fraction per spike\n",
+ "\n",
+ "**Effect:** Repeated spikes deplete resources โ synaptic current decreases"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create synapse with short-term depression\n",
+ "class STDSynapse(bp.Synapse):\n",
+ " \"\"\"Synapse with short-term depression.\"\"\"\n",
+ " \n",
+ " def __init__(self, size, tau=5.0*u.ms, tau_d=200.0*u.ms, U=0.5, **kwargs):\n",
+ " super().__init__(size, **kwargs)\n",
+ " \n",
+ " # Synapse parameters\n",
+ " self.tau = tau # Synaptic time constant\n",
+ " self.tau_d = tau_d # Depression time constant\n",
+ " self.U = U # Utilization fraction\n",
+ " \n",
+ " # States\n",
+ " self.g = brainstate.ShortTermState(jnp.zeros(size)) # Conductance\n",
+ " self.x = brainstate.ShortTermState(jnp.ones(size)) # Available resources\n",
+ " \n",
+ " def reset_state(self, batch_size=None):\n",
+ " self.g.value = jnp.zeros(self.size if batch_size is None else (batch_size, self.size))\n",
+ " self.x.value = jnp.ones(self.size if batch_size is None else (batch_size, self.size))\n",
+ " \n",
+ " def update(self, pre_spike):\n",
+ " # Get time step\n",
+ " dt = brainstate.environ.get_dt()\n",
+ " \n",
+ " # Depression: reduce available resources on spike\n",
+ " x_new = self.x.value + pre_spike * (-self.U * self.x.value)\n",
+ " \n",
+ " # Recovery: exponential recovery of resources\n",
+ " dx = (1.0 - x_new) / self.tau_d.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.x.value = x_new + dx\n",
+ " \n",
+ " # Synaptic current: modulated by available resources\n",
+ " dg = -self.g.value / self.tau.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.g.value += dg + pre_spike * self.U * self.x.value\n",
+ " \n",
+ " return self.g.value\n",
+ "\n",
+ "# Test with spike train\n",
+ "std_syn = STDSynapse(size=1, tau=5.0*u.ms, tau_d=200.0*u.ms, U=0.5)\n",
+ "brainstate.nn.init_all_states(std_syn)\n",
+ "\n",
+ "# Generate spike train: 10 spikes at 20 Hz\n",
+ "duration = 1000 * u.ms\n",
+ "n_steps = int(duration / brainstate.environ.get_dt())\n",
+ "spike_times = [50, 100, 150, 200, 250, 300, 350, 400, 450, 500] # in ms\n",
+ "spike_indices = [int(t / 0.1) for t in spike_times]\n",
+ "\n",
+ "# Simulate\n",
+ "g_history = []\n",
+ "x_history = []\n",
+ "\n",
+ "for i in range(n_steps):\n",
+ " spike = 1.0 if i in spike_indices else 0.0\n",
+ " g = std_syn(spike)\n",
+ " g_history.append(float(g))\n",
+ " x_history.append(float(std_syn.x.value))\n",
+ "\n",
+ "# Plot\n",
+ "times = np.arange(n_steps) * 0.1\n",
+ "\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Synaptic conductance\n",
+ "axes[0].plot(times, g_history, 'b-', linewidth=2)\n",
+ "for st in spike_times:\n",
+ " axes[0].axvline(st, color='r', linestyle='--', alpha=0.3)\n",
+ "axes[0].set_ylabel('Synaptic Conductance g', fontsize=12)\n",
+ "axes[0].set_title('Short-Term Depression', fontsize=14, fontweight='bold')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Available resources\n",
+ "axes[1].plot(times, x_history, 'g-', linewidth=2)\n",
+ "for st in spike_times:\n",
+ " axes[1].axvline(st, color='r', linestyle='--', alpha=0.3, label='Spike' if st == spike_times[0] else '')\n",
+ "axes[1].set_xlabel('Time (ms)', fontsize=12)\n",
+ "axes[1].set_ylabel('Available Resources x', fontsize=12)\n",
+ "axes[1].set_title('Resource Depletion and Recovery', fontsize=14, fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"๐ STD Observations:\")\n",
+ "print(\" โข Each spike depletes available resources\")\n",
+ "print(\" โข Synaptic conductance decreases with repeated spikes\")\n",
+ "print(\" โข Resources recover exponentially between spikes\")\n",
+ "print(\" โข Implements synaptic fatigue\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Short-Term Facilitation (STF)\n",
+ "\n",
+ "Short-term facilitation models the buildup of calcium in the presynaptic terminal, which increases neurotransmitter release probability.\n",
+ "\n",
+ "**STF dynamics:**\n",
+ "$$\n",
+ "\\frac{du}{dt} = \\frac{U - u}{\\tau_f} + U(1 - u) \\cdot \\delta(t - t_{spike})\n",
+ "$$\n",
+ "\n",
+ "Where:\n",
+ "- $u$: Utilization parameter (increases with spikes)\n",
+ "- $\\tau_f$: Facilitation time constant\n",
+ "- $U$: Baseline utilization\n",
+ "\n",
+ "**Effect:** Repeated spikes increase utilization โ synaptic current increases"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class STFSynapse(bp.Synapse):\n",
+ " \"\"\"Synapse with short-term facilitation.\"\"\"\n",
+ " \n",
+ " def __init__(self, size, tau=5.0*u.ms, tau_f=200.0*u.ms, U=0.15, **kwargs):\n",
+ " super().__init__(size, **kwargs)\n",
+ " \n",
+ " self.tau = tau\n",
+ " self.tau_f = tau_f # Facilitation time constant\n",
+ " self.U = U # Baseline utilization\n",
+ " \n",
+ " # States\n",
+ " self.g = brainstate.ShortTermState(jnp.zeros(size))\n",
+ " self.u = brainstate.ShortTermState(jnp.ones(size) * U) # Current utilization\n",
+ " \n",
+ " def reset_state(self, batch_size=None):\n",
+ " self.g.value = jnp.zeros(self.size if batch_size is None else (batch_size, self.size))\n",
+ " self.u.value = jnp.ones(self.size if batch_size is None else (batch_size, self.size)) * self.U\n",
+ " \n",
+ " def update(self, pre_spike):\n",
+ " dt = brainstate.environ.get_dt()\n",
+ " \n",
+ " # Facilitation: increase utilization on spike\n",
+ " u_new = self.u.value + pre_spike * (self.U * (1.0 - self.u.value))\n",
+ " \n",
+ " # Decay: exponential decay of facilitation\n",
+ " du = -(u_new - self.U) / self.tau_f.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.u.value = u_new + du\n",
+ " \n",
+ " # Synaptic current: modulated by current utilization\n",
+ " dg = -self.g.value / self.tau.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.g.value += dg + pre_spike * self.u.value\n",
+ " \n",
+ " return self.g.value\n",
+ "\n",
+ "# Test facilitation\n",
+ "stf_syn = STFSynapse(size=1, tau=5.0*u.ms, tau_f=200.0*u.ms, U=0.15)\n",
+ "brainstate.nn.init_all_states(stf_syn)\n",
+ "\n",
+ "# Same spike train as STD\n",
+ "g_history_f = []\n",
+ "u_history = []\n",
+ "\n",
+ "for i in range(n_steps):\n",
+ " spike = 1.0 if i in spike_indices else 0.0\n",
+ " g = stf_syn(spike)\n",
+ " g_history_f.append(float(g))\n",
+ " u_history.append(float(stf_syn.u.value))\n",
+ "\n",
+ "# Plot comparison\n",
+ "fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)\n",
+ "\n",
+ "# STD conductance\n",
+ "axes[0].plot(times, g_history, 'b-', linewidth=2, label='STD')\n",
+ "for st in spike_times:\n",
+ " axes[0].axvline(st, color='r', linestyle='--', alpha=0.2)\n",
+ "axes[0].set_ylabel('Conductance', fontsize=12)\n",
+ "axes[0].set_title('Depression: Decreasing Response', fontsize=14, fontweight='bold')\n",
+ "axes[0].legend()\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# STF conductance\n",
+ "axes[1].plot(times, g_history_f, 'g-', linewidth=2, label='STF')\n",
+ "for st in spike_times:\n",
+ " axes[1].axvline(st, color='r', linestyle='--', alpha=0.2)\n",
+ "axes[1].set_ylabel('Conductance', fontsize=12)\n",
+ "axes[1].set_title('Facilitation: Increasing Response', fontsize=14, fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Utilization parameter\n",
+ "axes[2].plot(times, u_history, 'm-', linewidth=2)\n",
+ "for st in spike_times:\n",
+ " axes[2].axvline(st, color='r', linestyle='--', alpha=0.2, label='Spike' if st == spike_times[0] else '')\n",
+ "axes[2].set_xlabel('Time (ms)', fontsize=12)\n",
+ "axes[2].set_ylabel('Utilization u', fontsize=12)\n",
+ "axes[2].set_title('Facilitation Buildup', fontsize=14, fontweight='bold')\n",
+ "axes[2].legend()\n",
+ "axes[2].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"๐ STF vs STD:\")\n",
+ "print(\" STD: Synaptic strength DECREASES with repeated use\")\n",
+ "print(\" STF: Synaptic strength INCREASES with repeated use\")\n",
+ "print(\" Both effects are temporary (100s of ms)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Combined STP (Depression + Facilitation)\n",
+ "\n",
+ "Real synapses often exhibit both depression and facilitation. BrainPy provides a combined STP model.\n",
+ "\n",
+ "**Combined dynamics:**\n",
+ "- Depression: Resource depletion with time constant $\\tau_d$\n",
+ "- Facilitation: Utilization increase with time constant $\\tau_f$\n",
+ "- Effective synaptic current: $g_{eff} = u \\cdot x \\cdot g$\n",
+ "\n",
+ "Depending on relative values of $\\tau_d$, $\\tau_f$, and $U$, synapses can be:\n",
+ "- **Depressing**: $\\tau_d \\gg \\tau_f$, large $U$\n",
+ "- **Facilitating**: $\\tau_f \\gg \\tau_d$, small $U$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use BrainPy's built-in STP model\n",
+ "# For demonstration, we'll test different parameter regimes\n",
+ "\n",
+ "def simulate_stp(tau_f, tau_d, U, spike_indices, n_steps, label):\n",
+ " \"\"\"Simulate STP synapse and return conductance history.\"\"\"\n",
+ " \n",
+ " class STPSynapse(bp.Synapse):\n",
+ " def __init__(self, size, **kwargs):\n",
+ " super().__init__(size, **kwargs)\n",
+ " self.tau = 5.0 * u.ms\n",
+ " self.tau_f = tau_f\n",
+ " self.tau_d = tau_d\n",
+ " self.U = U\n",
+ " self.g = brainstate.ShortTermState(jnp.zeros(size))\n",
+ " self.x = brainstate.ShortTermState(jnp.ones(size))\n",
+ " self.u = brainstate.ShortTermState(jnp.ones(size) * U)\n",
+ " \n",
+ " def reset_state(self, batch_size=None):\n",
+ " self.g.value = jnp.zeros(self.size if batch_size is None else (batch_size, self.size))\n",
+ " self.x.value = jnp.ones(self.size if batch_size is None else (batch_size, self.size))\n",
+ " self.u.value = jnp.ones(self.size if batch_size is None else (batch_size, self.size)) * self.U\n",
+ " \n",
+ " def update(self, pre_spike):\n",
+ " dt = brainstate.environ.get_dt()\n",
+ " \n",
+ " # Facilitation\n",
+ " u_new = self.u.value + pre_spike * (self.U * (1.0 - self.u.value))\n",
+ " du = -(u_new - self.U) / self.tau_f.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.u.value = u_new + du\n",
+ " \n",
+ " # Depression\n",
+ " x_new = self.x.value + pre_spike * (-self.u.value * self.x.value)\n",
+ " dx = (1.0 - x_new) / self.tau_d.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.x.value = x_new + dx\n",
+ " \n",
+ " # Conductance\n",
+ " dg = -self.g.value / self.tau.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.g.value += dg + pre_spike * self.u.value * self.x.value\n",
+ " \n",
+ " return self.g.value\n",
+ " \n",
+ " syn = STPSynapse(size=1)\n",
+ " brainstate.nn.init_all_states(syn)\n",
+ " \n",
+ " g_hist = []\n",
+ " for i in range(n_steps):\n",
+ " spike = 1.0 if i in spike_indices else 0.0\n",
+ " g = syn(spike)\n",
+ " g_hist.append(float(g))\n",
+ " \n",
+ " return g_hist\n",
+ "\n",
+ "# Three parameter regimes\n",
+ "g_depressing = simulate_stp(\n",
+ " tau_f=50.0*u.ms, tau_d=400.0*u.ms, U=0.6,\n",
+ " spike_indices=spike_indices, n_steps=n_steps, label='Depressing'\n",
+ ")\n",
+ "\n",
+ "g_facilitating = simulate_stp(\n",
+ " tau_f=400.0*u.ms, tau_d=50.0*u.ms, U=0.1,\n",
+ " spike_indices=spike_indices, n_steps=n_steps, label='Facilitating'\n",
+ ")\n",
+ "\n",
+ "g_mixed = simulate_stp(\n",
+ " tau_f=200.0*u.ms, tau_d=200.0*u.ms, U=0.3,\n",
+ " spike_indices=spike_indices, n_steps=n_steps, label='Mixed'\n",
+ ")\n",
+ "\n",
+ "# Plot all three\n",
+ "fig, ax = plt.subplots(figsize=(14, 6))\n",
+ "\n",
+ "ax.plot(times, g_depressing, 'b-', linewidth=2, label='Depressing (large U, slow recovery)', alpha=0.8)\n",
+ "ax.plot(times, g_facilitating, 'g-', linewidth=2, label='Facilitating (small U, fast recovery)', alpha=0.8)\n",
+ "ax.plot(times, g_mixed, 'm-', linewidth=2, label='Mixed (balanced)', alpha=0.8)\n",
+ "\n",
+ "for st in spike_times:\n",
+ " ax.axvline(st, color='r', linestyle='--', alpha=0.2)\n",
+ "\n",
+ "ax.set_xlabel('Time (ms)', fontsize=12)\n",
+ "ax.set_ylabel('Synaptic Conductance', fontsize=12)\n",
+ "ax.set_title('Short-Term Plasticity: Parameter Regimes', fontsize=14, fontweight='bold')\n",
+ "ax.legend(fontsize=11)\n",
+ "ax.grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"๐ STP Parameter Regimes:\")\n",
+ "print(\" Blue (Depressing): High U, slow depression recovery\")\n",
+ "print(\" Green (Facilitating): Low U, fast depression recovery\")\n",
+ "print(\" Magenta (Mixed): Balanced parameters\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: Spike-Timing-Dependent Plasticity (STDP)\n",
+ "\n",
+ "STDP is a form of long-term plasticity where synaptic strength changes depend on the relative timing of pre- and postsynaptic spikes.\n",
+ "\n",
+ "**STDP rule:**\n",
+ "- **Potentiation**: If pre-spike occurs before post-spike ($\\Delta t > 0$), strengthen synapse\n",
+ "- **Depression**: If post-spike occurs before pre-spike ($\\Delta t < 0$), weaken synapse\n",
+ "\n",
+ "**Weight update:**\n",
+ "$$\n",
+ "\\Delta w = \\begin{cases}\n",
+ "A_+ e^{-\\Delta t / \\tau_+} & \\text{if } \\Delta t > 0 \\\\\n",
+ "-A_- e^{\\Delta t / \\tau_-} & \\text{if } \\Delta t < 0\n",
+ "\\end{cases}\n",
+ "$$\n",
+ "\n",
+ "Where $\\Delta t = t_{post} - t_{pre}$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# STDP learning window\n",
+ "def stdp_window(dt_values, A_plus=0.01, A_minus=0.01, tau_plus=20.0, tau_minus=20.0):\n",
+ " \"\"\"Compute STDP weight change as a function of spike timing difference.\"\"\"\n",
+ " dw = np.zeros_like(dt_values)\n",
+ " \n",
+ " # Potentiation (pre before post)\n",
+ " pos_mask = dt_values > 0\n",
+ " dw[pos_mask] = A_plus * np.exp(-dt_values[pos_mask] / tau_plus)\n",
+ " \n",
+ " # Depression (post before pre)\n",
+ " neg_mask = dt_values < 0\n",
+ " dw[neg_mask] = -A_minus * np.exp(dt_values[neg_mask] / tau_minus)\n",
+ " \n",
+ " return dw\n",
+ "\n",
+ "# Plot STDP window\n",
+ "dt_range = np.linspace(-100, 100, 1000)\n",
+ "dw_values = stdp_window(dt_range)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(12, 6))\n",
+ "\n",
+ "# Plot STDP curve\n",
+ "ax.plot(dt_range, dw_values, 'b-', linewidth=3)\n",
+ "ax.axhline(0, color='k', linestyle='--', alpha=0.3)\n",
+ "ax.axvline(0, color='k', linestyle='--', alpha=0.3)\n",
+ "\n",
+ "# Annotate regions\n",
+ "ax.fill_between(dt_range[dt_range > 0], 0, dw_values[dt_range > 0], \n",
+ " alpha=0.3, color='green', label='LTP (potentiation)')\n",
+ "ax.fill_between(dt_range[dt_range < 0], 0, dw_values[dt_range < 0], \n",
+ " alpha=0.3, color='red', label='LTD (depression)')\n",
+ "\n",
+ "ax.set_xlabel('ฮt = t_post - t_pre (ms)', fontsize=12)\n",
+ "ax.set_ylabel('Weight Change ฮw', fontsize=12)\n",
+ "ax.set_title('STDP Learning Window', fontsize=14, fontweight='bold')\n",
+ "ax.legend(fontsize=11)\n",
+ "ax.grid(True, alpha=0.3)\n",
+ "\n",
+ "# Add annotations\n",
+ "ax.annotate('Pre โ Post\\nStrengthen', xy=(30, 0.006), fontsize=10,\n",
+ " ha='center', bbox=dict(boxstyle='round', facecolor='green', alpha=0.2))\n",
+ "ax.annotate('Post โ Pre\\nWeaken', xy=(-30, -0.006), fontsize=10,\n",
+ " ha='center', bbox=dict(boxstyle='round', facecolor='red', alpha=0.2))\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"๐ STDP Principle:\")\n",
+ "print(\" 'Neurons that fire together, wire together'\")\n",
+ "print(\" Positive ฮt (preโpost): Potentiation (LTP)\")\n",
+ "print(\" Negative ฮt (postโpre): Depression (LTD)\")\n",
+ "print(\" Exponential decay with distance from ฮt=0\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: Implementing STDP in Networks\n",
+ "\n",
+ "Let's implement a simple STDP learning rule in a small network. We'll track spike times and update weights according to the STDP rule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class STDPSynapse(bp.Synapse):\n",
+ " \"\"\"Synapse with STDP learning.\"\"\"\n",
+ " \n",
+ " def __init__(self, size, tau=5.0*u.ms, A_plus=0.01, A_minus=0.01, \n",
+ " tau_plus=20.0*u.ms, tau_minus=20.0*u.ms, w_max=1.0, **kwargs):\n",
+ " super().__init__(size, **kwargs)\n",
+ " \n",
+ " self.tau = tau\n",
+ " self.A_plus = A_plus\n",
+ " self.A_minus = A_minus\n",
+ " self.tau_plus = tau_plus\n",
+ " self.tau_minus = tau_minus\n",
+ " self.w_max = w_max\n",
+ " \n",
+ " # States\n",
+ " self.g = brainstate.ShortTermState(jnp.zeros(size))\n",
+ " self.w = brainstate.ParamState(jnp.ones(size) * 0.5) # Learnable weights\n",
+ " self.pre_trace = brainstate.ShortTermState(jnp.zeros(size)) # Pre-synaptic trace\n",
+ " self.post_trace = brainstate.ShortTermState(jnp.zeros(size)) # Post-synaptic trace\n",
+ " \n",
+ " def reset_state(self, batch_size=None):\n",
+ " self.g.value = jnp.zeros(self.size if batch_size is None else (batch_size, self.size))\n",
+ " self.pre_trace.value = jnp.zeros(self.size if batch_size is None else (batch_size, self.size))\n",
+ " self.post_trace.value = jnp.zeros(self.size if batch_size is None else (batch_size, self.size))\n",
+ " \n",
+ " def update(self, pre_spike, post_spike=None):\n",
+ " dt = brainstate.environ.get_dt()\n",
+ " \n",
+ " # Update pre-synaptic trace\n",
+ " self.pre_trace.value += -self.pre_trace.value / self.tau_plus.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.pre_trace.value += pre_spike\n",
+ " \n",
+ " # Update conductance\n",
+ " dg = -self.g.value / self.tau.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.g.value += dg + pre_spike * self.w.value\n",
+ " \n",
+ " # STDP learning (if post spike provided)\n",
+ " if post_spike is not None:\n",
+ " # Update post-synaptic trace\n",
+ " self.post_trace.value += -self.post_trace.value / self.tau_minus.to_decimal(u.ms) * dt.to_decimal(u.ms)\n",
+ " self.post_trace.value += post_spike\n",
+ " \n",
+ " # Weight updates\n",
+ " # LTP: pre spike causes weight increase proportional to post trace\n",
+ " dw_ltp = self.A_plus * pre_spike * self.post_trace.value\n",
+ " # LTD: post spike causes weight decrease proportional to pre trace\n",
+ " dw_ltd = -self.A_minus * post_spike * self.pre_trace.value\n",
+ " \n",
+ " # Update weights with bounds\n",
+ " self.w.value = jnp.clip(self.w.value + dw_ltp + dw_ltd, 0.0, self.w_max)\n",
+ " \n",
+ " return self.g.value\n",
+ "\n",
+ "# Test STDP learning\n",
+ "stdp_syn = STDPSynapse(size=1, A_plus=0.005, A_minus=0.005)\n",
+ "brainstate.nn.init_all_states(stdp_syn)\n",
+ "\n",
+ "# Simulate with correlated pre-post spikes\n",
+ "duration = 1000 * u.ms\n",
+ "n_steps = int(duration / brainstate.environ.get_dt())\n",
+ "\n",
+ "# Pre spikes followed by post spikes (should cause LTP)\n",
+ "pre_spike_times = [100, 300, 500, 700, 900] # ms\n",
+ "post_spike_times = [105, 305, 505, 705, 905] # 5ms after pre (potentiation)\n",
+ "\n",
+ "pre_indices = [int(t / 0.1) for t in pre_spike_times]\n",
+ "post_indices = [int(t / 0.1) for t in post_spike_times]\n",
+ "\n",
+ "w_history = []\n",
+ "for i in range(n_steps):\n",
+ " pre_spike = 1.0 if i in pre_indices else 0.0\n",
+ " post_spike = 1.0 if i in post_indices else 0.0\n",
+ " stdp_syn(pre_spike, post_spike)\n",
+ " w_history.append(float(stdp_syn.w.value))\n",
+ "\n",
+ "# Plot weight evolution\n",
+ "times_plot = np.arange(n_steps) * 0.1\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(12, 6))\n",
+ "\n",
+ "ax.plot(times_plot, w_history, 'b-', linewidth=2, label='Synaptic Weight')\n",
+ "\n",
+ "for pt, pst in zip(pre_spike_times, post_spike_times):\n",
+ " ax.axvline(pt, color='g', linestyle='--', alpha=0.3, linewidth=1.5)\n",
+ " ax.axvline(pst, color='r', linestyle='--', alpha=0.3, linewidth=1.5)\n",
+ "\n",
+ "# Add legend entries\n",
+ "from matplotlib.lines import Line2D\n",
+ "legend_elements = [\n",
+ " Line2D([0], [0], color='b', linewidth=2, label='Synaptic Weight'),\n",
+ " Line2D([0], [0], color='g', linestyle='--', label='Pre-spike'),\n",
+ " Line2D([0], [0], color='r', linestyle='--', label='Post-spike (5ms later)')\n",
+ "]\n",
+ "ax.legend(handles=legend_elements, fontsize=11)\n",
+ "\n",
+ "ax.set_xlabel('Time (ms)', fontsize=12)\n",
+ "ax.set_ylabel('Weight', fontsize=12)\n",
+ "ax.set_title('STDP Learning: Weight Potentiation', fontsize=14, fontweight='bold')\n",
+ "ax.grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(f\"๐ STDP Learning Result:\")\n",
+ "print(f\" Initial weight: {w_history[0]:.3f}\")\n",
+ "print(f\" Final weight: {w_history[-1]:.3f}\")\n",
+ "print(f\" Change: {w_history[-1] - w_history[0]:+.3f}\")\n",
+ "print(f\" โ
Weight increased due to consistent preโpost timing!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Network with Plastic Synapses\n",
+ "\n",
+ "Let's build a small recurrent network with STDP to see how plasticity affects network dynamics."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class PlasticNetwork(brainstate.nn.Module):\n",
+ " \"\"\"Recurrent network with STDP.\"\"\"\n",
+ " \n",
+ " def __init__(self, n_neurons=10, connectivity=0.3):\n",
+ " super().__init__()\n",
+ " \n",
+ " self.n_neurons = n_neurons\n",
+ " \n",
+ " # LIF neurons\n",
+ " self.neurons = bp.LIF(\n",
+ " n_neurons,\n",
+ " V_rest=-65.0 * u.mV,\n",
+ " V_th=-50.0 * u.mV,\n",
+ " V_reset=-65.0 * u.mV,\n",
+ " tau=10.0 * u.ms\n",
+ " )\n",
+ " \n",
+ " # Recurrent connections with STDP (simplified)\n",
+ " # In practice, use projection structure\n",
+ " self.connectivity = connectivity\n",
+ " mask = (np.random.rand(n_neurons, n_neurons) < connectivity).astype(float)\n",
+ " np.fill_diagonal(mask, 0) # No self-connections\n",
+ " \n",
+ " self.conn_matrix = brainstate.ParamState(jnp.array(mask))\n",
+ " self.weights = brainstate.ParamState(\n",
+ " jnp.array(mask * 0.5) # Initial weights\n",
+ " )\n",
+ " \n",
+ " def update(self, inp):\n",
+ " # Get current spikes\n",
+ " spikes = self.neurons.get_spike()\n",
+ " \n",
+ " # Compute recurrent input\n",
+ " recurrent_input = jnp.dot(spikes, self.weights.value) * u.nA\n",
+ " \n",
+ " # Update neurons\n",
+ " self.neurons(inp + recurrent_input)\n",
+ " \n",
+ " return spikes\n",
+ " \n",
+ " def apply_stdp(self, pre_spikes, post_spikes, learning_rate=0.001):\n",
+ " \"\"\"Apply STDP update to weights.\"\"\"\n",
+ " # Simple STDP: strengthen connections where both fire\n",
+ " # (This is simplified; real STDP uses spike timing)\n",
+ " dw = learning_rate * jnp.outer(post_spikes, pre_spikes)\n",
+ " \n",
+ " # Update weights with connectivity mask\n",
+ " new_weights = self.weights.value + dw * self.conn_matrix.value\n",
+ " self.weights.value = jnp.clip(new_weights, 0.0, 1.0)\n",
+ "\n",
+ "# Create network\n",
+ "net = PlasticNetwork(n_neurons=20, connectivity=0.2)\n",
+ "brainstate.nn.init_all_states(net)\n",
+ "\n",
+ "# Simulate with external input\n",
+ "duration = 500 * u.ms\n",
+ "n_steps = int(duration / brainstate.environ.get_dt())\n",
+ "\n",
+ "spike_records = []\n",
+ "weight_norms = []\n",
+ "\n",
+ "for i in range(n_steps):\n",
+ " # Random external input\n",
+ " inp = brainstate.random.rand(net.n_neurons) * 2.0 * u.nA\n",
+ " \n",
+ " # Get spikes before update\n",
+ " pre_spikes = net.neurons.get_spike()\n",
+ " \n",
+ " # Update network\n",
+ " post_spikes = net(inp)\n",
+ " \n",
+ " # Apply STDP\n",
+ " if i % 10 == 0: # Update every 10 steps\n",
+ " net.apply_stdp(pre_spikes, post_spikes)\n",
+ " \n",
+ " spike_records.append(post_spikes)\n",
+ " weight_norms.append(float(jnp.linalg.norm(net.weights.value)))\n",
+ "\n",
+ "spike_records = jnp.array(spike_records)\n",
+ "\n",
+ "# Visualize\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(14, 8))\n",
+ "\n",
+ "# Spike raster\n",
+ "times_ms = np.arange(n_steps) * 0.1\n",
+ "for neuron_idx in range(net.n_neurons):\n",
+ " spike_times = times_ms[spike_records[:, neuron_idx] > 0]\n",
+ " axes[0].scatter(spike_times, [neuron_idx] * len(spike_times), \n",
+ " s=1, c='black', alpha=0.5)\n",
+ "\n",
+ "axes[0].set_ylabel('Neuron Index', fontsize=12)\n",
+ "axes[0].set_title('Network Activity with STDP', fontsize=14, fontweight='bold')\n",
+ "axes[0].set_xlim(0, float(duration.to_decimal(u.ms)))\n",
+ "axes[0].grid(True, alpha=0.3, axis='x')\n",
+ "\n",
+ "# Weight evolution\n",
+ "axes[1].plot(times_ms, weight_norms, 'b-', linewidth=2)\n",
+ "axes[1].set_xlabel('Time (ms)', fontsize=12)\n",
+ "axes[1].set_ylabel('Weight Norm', fontsize=12)\n",
+ "axes[1].set_title('Evolution of Synaptic Weights', fontsize=14, fontweight='bold')\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"๐ Network with Plasticity:\")\n",
+ "print(f\" Initial weight norm: {weight_norms[0]:.3f}\")\n",
+ "print(f\" Final weight norm: {weight_norms[-1]:.3f}\")\n",
+ "print(f\" Change: {weight_norms[-1] - weight_norms[0]:+.3f}\")\n",
+ "print(\" Weights adapt based on network activity!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Combining Plasticity with Training\n",
+ "\n",
+ "Plasticity can be combined with gradient-based training. This creates networks that:\n",
+ "1. Learn through backpropagation (supervised)\n",
+ "2. Adapt through plasticity (unsupervised)\n",
+ "\n",
+ "**Hybrid approach:**\n",
+ "- Use gradient descent to train feedforward weights\n",
+ "- Use STDP/STP for recurrent weights\n",
+ "- Combine benefits of both learning paradigms"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Template for hybrid learning\n",
+ "class HybridNetwork(brainstate.nn.Module):\n",
+ " \"\"\"Network combining gradient-based and plasticity-based learning.\"\"\"\n",
+ " \n",
+ " def __init__(self, n_input, n_hidden, n_output):\n",
+ " super().__init__()\n",
+ " \n",
+ " # Feedforward layers (trained with gradients)\n",
+ " self.fc1 = brainstate.nn.Linear(n_input, n_hidden)\n",
+ " self.hidden = bp.LIF(\n",
+ " n_hidden,\n",
+ " V_rest=-65.0*u.mV, V_th=-50.0*u.mV, tau=10.0*u.ms,\n",
+ " spike_fun=braintools.surrogate.ReluGrad()\n",
+ " )\n",
+ " \n",
+ " # Recurrent connections (updated with STDP)\n",
+ " # Would use STDPSynapse in practice\n",
+ " \n",
+ " self.fc2 = brainstate.nn.Linear(n_hidden, n_output)\n",
+ " self.output = bp.LIF(\n",
+ " n_output,\n",
+ " V_rest=-65.0*u.mV, V_th=-50.0*u.mV, tau=10.0*u.ms,\n",
+ " spike_fun=braintools.surrogate.ReluGrad()\n",
+ " )\n",
+ " \n",
+ " self.readout = bp.Readout(n_output, n_output)\n",
+ " \n",
+ " def update(self, x):\n",
+ " # Feedforward path (gradient-trained)\n",
+ " current1 = self.fc1(x)\n",
+ " self.hidden(current1)\n",
+ " h_spikes = self.hidden.get_spike()\n",
+ " \n",
+ " # Add recurrent dynamics here (STDP-updated)\n",
+ " # ...\n",
+ " \n",
+ " current2 = self.fc2(h_spikes)\n",
+ " self.output(current2)\n",
+ " o_spikes = self.output.get_spike()\n",
+ " \n",
+ " return self.readout(o_spikes)\n",
+ "\n",
+ "print(\"๐ก Hybrid Learning Strategy:\")\n",
+ "print(\"\"\"\\n1. Feedforward weights: Trained with gradient descent (supervised)\n",
+ " - Fast convergence\n",
+ " - Optimized for task objective\n",
+ "\n",
+ "2. Recurrent weights: Updated with STDP (unsupervised)\n",
+ " - Biologically plausible\n",
+ " - Adapts to input statistics\n",
+ " - Provides temporal dynamics\n",
+ "\n",
+ "3. Benefits:\n",
+ " - Best of both worlds\n",
+ " - Robust to distribution shift\n",
+ " - Continual adaptation\n",
+ "\n",
+ "Implementation:\n",
+ " - Train feedforward with brainstate.transform.grad()\n",
+ " - Update recurrent with STDP rule\n",
+ " - Alternate or interleave both updates\n",
+ "\"\"\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
**Short-term plasticity (STP)**\n",
+ " - Depression: Resource depletion, decreasing response\n",
+ " - Facilitation: Calcium buildup, increasing response\n",
+ " - Combined dynamics for realistic synapses\n",
+ "\n",
+ "โ
**STDP principles**\n",
+ " - Spike timing matters: preโpost strengthens, postโpre weakens\n",
+ " - Exponential learning window\n",
+ " - \"Fire together, wire together\"\n",
+ "\n",
+ "โ
**Implementation**\n",
+ " - Create custom synapse classes with plasticity\n",
+ " - Track spike traces for STDP\n",
+ " - Update weights based on activity\n",
+ "\n",
+ "โ
**Network plasticity**\n",
+ " - Embed plastic synapses in networks\n",
+ " - Observe weight evolution\n",
+ " - Combine with gradient-based training\n",
+ "\n",
+ "**Key code patterns:**\n",
+ "\n",
+ "```python\n",
+ "# Short-term depression\n",
+ "class STDSynapse(bp.Synapse):\n",
+ " def update(self, pre_spike):\n",
+ " # Deplete resources on spike\n",
+ " self.x.value -= pre_spike * U * self.x.value\n",
+ " # Exponential recovery\n",
+ " self.x.value += (1 - self.x.value) / tau_d * dt\n",
+ " # Modulated conductance\n",
+ " self.g.value += pre_spike * U * self.x.value\n",
+ "\n",
+ "# STDP learning\n",
+ "class STDPSynapse(bp.Synapse):\n",
+ " def update(self, pre_spike, post_spike):\n",
+ " # Update traces\n",
+ " self.pre_trace.value += pre_spike\n",
+ " self.post_trace.value += post_spike\n",
+ " # Weight updates\n",
+ " dw_ltp = A_plus * pre_spike * self.post_trace.value\n",
+ " dw_ltd = -A_minus * post_spike * self.pre_trace.value\n",
+ " self.w.value += dw_ltp + dw_ltd\n",
+ "```\n",
+ "\n",
+ "**Next steps:**\n",
+ "- Implement full STDP in recurrent networks\n",
+ "- Explore homeostatic plasticity (weight normalization)\n",
+ "- Combine plasticity with network training (Tutorial 5)\n",
+ "- Study biological learning rules (BCM, Oja's rule)\n",
+ "- See Tutorial 7 for scaling plastic networks\n",
+ "\n",
+ "**References:**\n",
+ "- Markram et al. (1998): \"Redistribution of synaptic efficacy between neocortical pyramidal neurons\" (STP)\n",
+ "- Bi & Poo (1998): \"Synaptic modifications in cultured hippocampal neurons\" (STDP)\n",
+ "- Song et al. (2000): \"Competitive Hebbian learning through spike-timing-dependent synaptic plasticity\" (STDP theory)\n",
+ "- Tsodyks & Markram (1997): \"The neural code between neocortical pyramidal neurons depends on neurotransmitter release probability\" (STP model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Exercises\n",
+ "\n",
+ "Test your understanding:\n",
+ "\n",
+ "### Exercise 1: Parameter Exploration\n",
+ "Vary STD/STF time constants and observe how they affect frequency filtering. Which regimes amplify or attenuate high-frequency inputs?\n",
+ "\n",
+ "### Exercise 2: STDP Pattern Learning\n",
+ "Create a network that learns to respond to specific temporal patterns using STDP. Test with repeated spike sequences.\n",
+ "\n",
+ "### Exercise 3: Homeostatic Plasticity\n",
+ "Implement weight normalization to prevent runaway potentiation/depression. Keep total synaptic weight constant.\n",
+ "\n",
+ "### Exercise 4: Recurrent STDP\n",
+ "Build a recurrent network where all connections use STDP. Observe emergence of structured connectivity.\n",
+ "\n",
+ "### Exercise 5: Hybrid Training\n",
+ "Combine gradient-based training (Tutorial 5) with STDP in recurrent connections. Compare performance with pure gradient descent.\n",
+ "\n",
+ "**Bonus Challenge:** Implement triplet STDP, which considers triplets of spikes for more accurate learning rules."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/advanced/07-large-scale-simulations.ipynb b/docs_version3/tutorials/advanced/07-large-scale-simulations.ipynb
new file mode 100644
index 00000000..b7d0f478
--- /dev/null
+++ b/docs_version3/tutorials/advanced/07-large-scale-simulations.ipynb
@@ -0,0 +1,1003 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 7: Large-Scale Simulations\n",
+ "\n",
+ "**Duration:** ~35 minutes | **Prerequisites:** All Basic Tutorials\n",
+ "\n",
+ "## Learning Objectives\n",
+ "\n",
+ "By the end of this tutorial, you will:\n",
+ "\n",
+ "- โ
Optimize memory usage for large networks\n",
+ "- โ
Apply JIT compilation best practices\n",
+ "- โ
Use batching strategies effectively\n",
+ "- โ
Leverage GPU/TPU acceleration\n",
+ "- โ
Profile and optimize performance\n",
+ "- โ
Implement sparse connectivity\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "Scaling neural simulations to thousands or millions of neurons requires careful optimization. BrainPy leverages JAX for high-performance computing on CPUs, GPUs, and TPUs.\n",
+ "\n",
+ "**Key concepts:**\n",
+ "- **JIT compilation**: Compile Python code to optimized machine code\n",
+ "- **Memory efficiency**: Minimize state storage and intermediate computations\n",
+ "- **Sparse operations**: Only compute where connections exist\n",
+ "- **Batching**: Process multiple trials simultaneously\n",
+ "- **Device acceleration**: Utilize GPU/TPU parallelism\n",
+ "\n",
+ "Let's learn how to build efficient large-scale simulations!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy as bp\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import time\n",
+ "\n",
+ "# Set random seed\n",
+ "brainstate.random.seed(42)\n",
+ "\n",
+ "# Configure environment\n",
+ "brainstate.environ.set(dt=0.1 * u.ms)\n",
+ "\n",
+ "# Check available devices\n",
+ "print(\"๐ฅ๏ธ Available devices:\")\n",
+ "print(f\" {jax.devices()}\")\n",
+ "print(f\" Default backend: {jax.default_backend()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: JIT Compilation Basics\n",
+ "\n",
+ "Just-In-Time (JIT) compilation converts Python code to optimized machine code. This can provide 10-100ร speedups!\n",
+ "\n",
+ "**Benefits of JIT:**\n",
+ "- Eliminates Python interpreter overhead\n",
+ "- Enables compiler optimizations (loop fusion, vectorization)\n",
+ "- Required for GPU/TPU execution\n",
+ "\n",
+ "**Rules for JIT:**\n",
+ "- Functions must be pure (no side effects)\n",
+ "- Array shapes must be static (known at compile time)\n",
+ "- Avoid Python loops over dynamic ranges\n",
+ "\n",
+ "Let's compare JIT vs non-JIT performance!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simple network without JIT\n",
+ "class SimpleNetwork(brainstate.nn.Module):\n",
+ " def __init__(self, n_neurons=1000):\n",
+ " super().__init__()\n",
+ " self.neurons = bp.LIF(\n",
+ " n_neurons,\n",
+ " V_rest=-65.0*u.mV, V_th=-50.0*u.mV, tau=10.0*u.ms\n",
+ " )\n",
+ " \n",
+ " def update(self, inp):\n",
+ " self.neurons(inp)\n",
+ " return self.neurons.get_spike()\n",
+ "\n",
+ "# Test without JIT\n",
+ "net_no_jit = SimpleNetwork(n_neurons=1000)\n",
+ "brainstate.nn.init_all_states(net_no_jit)\n",
+ "\n",
+ "# Warmup\n",
+ "inp = brainstate.random.rand(1000) * 2.0 * u.nA\n",
+ "_ = net_no_jit(inp)\n",
+ "\n",
+ "# Time execution\n",
+ "n_steps = 1000\n",
+ "start = time.time()\n",
+ "for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(1000) * 2.0 * u.nA\n",
+ " _ = net_no_jit(inp)\n",
+ "time_no_jit = time.time() - start\n",
+ "\n",
+ "print(f\"โฑ๏ธ Without JIT: {time_no_jit:.3f} seconds for {n_steps} steps\")\n",
+ "print(f\" ({time_no_jit/n_steps*1000:.2f} ms/step)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Same network WITH JIT\n",
+ "net_jit = SimpleNetwork(n_neurons=1000)\n",
+ "brainstate.nn.init_all_states(net_jit)\n",
+ "\n",
+ "# Apply JIT compilation\n",
+ "@brainstate.compile.jit\n",
+ "def run_step_jit(net, inp):\n",
+ " return net(inp)\n",
+ "\n",
+ "# Warmup (compilation happens here)\n",
+ "inp = brainstate.random.rand(1000) * 2.0 * u.nA\n",
+ "_ = run_step_jit(net_jit, inp)\n",
+ "\n",
+ "# Time execution\n",
+ "start = time.time()\n",
+ "for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(1000) * 2.0 * u.nA\n",
+ " _ = run_step_jit(net_jit, inp)\n",
+ "time_jit = time.time() - start\n",
+ "\n",
+ "print(f\"โฑ๏ธ With JIT: {time_jit:.3f} seconds for {n_steps} steps\")\n",
+ "print(f\" ({time_jit/n_steps*1000:.2f} ms/step)\")\n",
+ "print(f\"\\n๐ Speedup: {time_no_jit/time_jit:.1f}ร\")"
+ ]
+ },
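+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The loop above still dispatches each step from Python. For long simulations you can compile the whole loop; below is a minimal sketch with plain `jax.lax.scan` on a toy leaky integrator (an illustration of the loop-fusion idea, not the BrainPy API)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Toy sketch: compile an entire simulation loop with jax.lax.scan\n",
+    "def leaky_step(v, inp):\n",
+    "    v = v + 0.1 * (-v + inp)  # Euler step of dv/dt = -v + I\n",
+    "    return v, v               # (carry, per-step output)\n",
+    "\n",
+    "inputs = jnp.ones(10000)      # 10,000 time steps\n",
+    "run = jax.jit(lambda v0, xs: jax.lax.scan(leaky_step, v0, xs))\n",
+    "final_v, trace = run(jnp.zeros(()), inputs)\n",
+    "print(trace.shape)            # (10000,)"
+   ]
+  },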
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Memory Optimization\n",
+ "\n",
+ "Large networks require careful memory management. Key strategies:\n",
+ "\n",
+ "1. **Use appropriate data types**: Float32 instead of Float64\n",
+ "2. **Minimize state storage**: Only keep necessary variables\n",
+ "3. **Avoid unnecessary copies**: Use in-place updates where possible\n",
+ "4. **Clear intermediate results**: Don't accumulate large histories\n",
+ "\n",
+ "Let's compare memory usage for different approaches."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Estimate memory usage\n",
+ "def estimate_memory_mb(n_neurons, n_synapses, dtype_bytes=4):\n",
+ " \"\"\"Estimate memory requirements.\n",
+ " \n",
+ " Args:\n",
+ " n_neurons: Number of neurons\n",
+ " n_synapses: Number of synaptic connections\n",
+ " dtype_bytes: Bytes per element (4 for float32, 8 for float64)\n",
+ " \"\"\"\n",
+ " # Neuron states (V, spike)\n",
+ " neuron_memory = n_neurons * 2 * dtype_bytes\n",
+ " \n",
+ " # Synapse states (g, x for plasticity)\n",
+ " synapse_memory = n_synapses * 2 * dtype_bytes\n",
+ " \n",
+ " # Connection weights\n",
+ " weight_memory = n_synapses * dtype_bytes\n",
+ " \n",
+ " total_bytes = neuron_memory + synapse_memory + weight_memory\n",
+ " total_mb = total_bytes / (1024 * 1024)\n",
+ " \n",
+ " return total_mb\n",
+ "\n",
+ "# Compare different network sizes\n",
+ "sizes = [100, 1000, 10000, 100000, 1000000]\n",
+ "connectivity = 0.1\n",
+ "\n",
+ "print(\"๐ Memory Requirements (Float32):\")\n",
+ "print(\"=\"*60)\n",
+ "print(f\"{'Neurons':<12} {'Synapses':<15} {'Memory (MB)':<15} {'Memory (GB)'}\")\n",
+ "print(\"=\"*60)\n",
+ "\n",
+ "for n in sizes:\n",
+ " n_syn = int(n * n * connectivity)\n",
+ " mem_mb = estimate_memory_mb(n, n_syn, dtype_bytes=4)\n",
+ " mem_gb = mem_mb / 1024\n",
+ " print(f\"{n:<12,} {n_syn:<15,} {mem_mb:<15.2f} {mem_gb:<.3f}\")\n",
+ "\n",
+ "print(\"\\n๐ก Optimization tips:\")\n",
+ "print(\" โข Use sparse connectivity to reduce synapse count\")\n",
+ "print(\" โข Use float32 instead of float64 (2ร memory savings)\")\n",
+ "print(\" โข Don't store full spike history (record only what you need)\")\n",
+ "print(\" โข Process in batches if memory-constrained\")"
+ ]
+ },
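+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "One concrete knob from the tips above: JAX defaults to float32 and requires an explicit opt-in for float64, so large networks get the 2× savings automatically unless x64 is enabled."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# float32 is the JAX default; float64 must be enabled explicitly\n",
+    "print(jnp.zeros(3).dtype)  # float32\n",
+    "\n",
+    "# Opt in to float64 (doubles memory; do this before creating arrays):\n",
+    "# jax.config.update(\"jax_enable_x64\", True)"
+   ]
+  },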
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Sparse Connectivity\n",
+ "\n",
+ "Biological networks are sparsely connected (~1-10% connectivity). Using sparse matrices dramatically reduces memory and computation.\n",
+ "\n",
+ "**Dense vs Sparse:**\n",
+ "- Dense: Store all $N \\times N$ connections (even zeros)\n",
+ "- Sparse: Store only non-zero connections\n",
+ "\n",
+ "**Memory savings:**\n",
+ "- 10% connectivity โ 90% memory reduction\n",
+ "- 1% connectivity โ 99% memory reduction\n",
+ "\n",
+ "BrainPy's `EventFixedProb` connection automatically uses sparse representations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compare dense vs sparse connectivity\n",
+ "n_pre = 1000\n",
+ "n_post = 1000\n",
+ "prob = 0.05 # 5% connectivity\n",
+ "\n",
+ "# Dense connection matrix\n",
+ "dense_matrix = (np.random.rand(n_post, n_pre) < prob).astype(np.float32)\n",
+ "dense_size_mb = dense_matrix.nbytes / (1024 * 1024)\n",
+ "\n",
+ "# Sparse representation (only store indices and values)\n",
+ "indices = np.argwhere(dense_matrix > 0)\n",
+ "values = dense_matrix[dense_matrix > 0]\n",
+ "sparse_size_mb = (indices.nbytes + values.nbytes) / (1024 * 1024)\n",
+ "\n",
+ "print(\"๐ Dense vs Sparse Comparison:\")\n",
+ "print(f\" Network size: {n_pre} โ {n_post} neurons\")\n",
+ "print(f\" Connectivity: {prob*100}%\")\n",
+ "print(f\" Actual connections: {len(values):,}\")\n",
+ "print()\n",
+ "print(f\" Dense storage: {dense_size_mb:.2f} MB\")\n",
+ "print(f\" Sparse storage: {sparse_size_mb:.2f} MB\")\n",
+ "print(f\" Memory savings: {(1 - sparse_size_mb/dense_size_mb)*100:.1f}%\")\n",
+ "print(f\" Space ratio: {dense_size_mb/sparse_size_mb:.1f}ร\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build large sparse network\n",
+ "class LargeSparseNetwork(brainstate.nn.Module):\n",
+ " \"\"\"Large network with sparse connectivity.\"\"\"\n",
+ " \n",
+ " def __init__(self, n_exc=4000, n_inh=1000, p_conn=0.02):\n",
+ " super().__init__()\n",
+ " \n",
+ " # Neurons\n",
+ " self.E = bp.LIF(n_exc, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=15.*u.ms)\n",
+ " self.I = bp.LIF(n_inh, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ " \n",
+ " # Sparse projections with EventFixedProb\n",
+ " self.E2E = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=p_conn, weight=0.6*u.mS),\n",
+ " syn=bp.Expon.desc(n_exc, tau=5.*u.ms),\n",
+ " out=bp.COBA.desc(E=0.*u.mV),\n",
+ " post=self.E\n",
+ " )\n",
+ " \n",
+ " self.E2I = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, n_inh, prob=p_conn, weight=0.6*u.mS),\n",
+ " syn=bp.Expon.desc(n_inh, tau=5.*u.ms),\n",
+ " out=bp.COBA.desc(E=0.*u.mV),\n",
+ " post=self.I\n",
+ " )\n",
+ " \n",
+ " self.I2E = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=p_conn, weight=6.7*u.mS),\n",
+ " syn=bp.Expon.desc(n_exc, tau=10.*u.ms),\n",
+ " out=bp.COBA.desc(E=-80.*u.mV),\n",
+ " post=self.E\n",
+ " )\n",
+ " \n",
+ " self.I2I = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, n_inh, prob=p_conn, weight=6.7*u.mS),\n",
+ " syn=bp.Expon.desc(n_inh, tau=10.*u.ms),\n",
+ " out=bp.COBA.desc(E=-80.*u.mV),\n",
+ " post=self.I\n",
+ " )\n",
+ " \n",
+ " def update(self, inp_e, inp_i):\n",
+ " spk_e = self.E.get_spike()\n",
+ " spk_i = self.I.get_spike()\n",
+ " \n",
+ " self.E2E(spk_e)\n",
+ " self.E2I(spk_e)\n",
+ " self.I2E(spk_i)\n",
+ " self.I2I(spk_i)\n",
+ " \n",
+ " self.E(inp_e)\n",
+ " self.I(inp_i)\n",
+ " \n",
+ " return spk_e, spk_i\n",
+ "\n",
+ "# Create large network\n",
+ "large_net = LargeSparseNetwork(n_exc=4000, n_inh=1000, p_conn=0.02)\n",
+ "brainstate.nn.init_all_states(large_net)\n",
+ "\n",
+ "print(\"โ
Created large sparse network:\")\n",
+ "print(f\" Excitatory neurons: 4,000\")\n",
+ "print(f\" Inhibitory neurons: 1,000\")\n",
+ "print(f\" Total neurons: 5,000\")\n",
+ "print(f\" Connectivity: 2%\")\n",
+ "print(f\" Approximate connections: {5000*5000*0.02:,.0f}\")\n",
+ "print(f\" Estimated memory: ~20 MB (sparse) vs ~400 MB (dense)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: Batching for Parallelism\n",
+ "\n",
+ "Running multiple independent simulations (trials) can be done in parallel using batching. This is especially efficient on GPUs.\n",
+ "\n",
+ "**Batching benefits:**\n",
+ "- Run multiple trials simultaneously\n",
+ "- Amortize compilation cost\n",
+ "- Better GPU utilization\n",
+ "- Faster parameter sweeps\n",
+ "\n",
+ "**How it works:**\n",
+ "- Add batch dimension: `(batch_size, n_neurons)`\n",
+ "- Operations automatically vectorized\n",
+ "- Each trial independent"
+ ]
+ },
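+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make \"operations automatically vectorized\" concrete, the toy sketch below uses `jax.vmap` (the mechanism this kind of batching relies on) to map a single-trial update over a leading batch axis, with no Python loop. It is an illustration with made-up unitless numbers, not BrainPy's internal code."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Toy single-trial update: one Euler step for a vector of membrane potentials\n",
+    "def toy_step(V, I, tau=10.0, dt_=0.1):\n",
+    "    return V + (-(V + 65.0) + I) * (dt_ / tau)\n",
+    "\n",
+    "V_batch = jnp.full((10, 1000), -65.0)   # 10 trials x 1000 neurons\n",
+    "I_batch = jnp.ones((10, 1000)) * 2.0\n",
+    "\n",
+    "# vmap maps toy_step over axis 0; each row is an independent trial\n",
+    "out = jax.vmap(toy_step)(V_batch, I_batch)\n",
+    "print(out.shape)  # (10, 1000)"
+   ]
+  },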
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Single trial simulation\n",
+ "def simulate_single_trial(n_steps=1000):\n",
+ " net = SimpleNetwork(n_neurons=1000)\n",
+ " brainstate.nn.init_all_states(net)\n",
+ " \n",
+ " @brainstate.compile.jit\n",
+ " def step(net, inp):\n",
+ " return net(inp)\n",
+ " \n",
+ " for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(1000) * 2.0 * u.nA\n",
+ " _ = step(net, inp)\n",
+ "\n",
+ "# Batched simulation\n",
+ "def simulate_batched_trials(n_trials=10, n_steps=1000):\n",
+ " net = SimpleNetwork(n_neurons=1000)\n",
+ " brainstate.nn.init_all_states(net, batch_size=n_trials)\n",
+ " \n",
+ " @brainstate.compile.jit\n",
+ " def step(net, inp):\n",
+ " return net(inp)\n",
+ " \n",
+ " for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(n_trials, 1000) * 2.0 * u.nA\n",
+ " _ = step(net, inp)\n",
+ "\n",
+ "# Compare timing\n",
+ "n_trials = 10\n",
+ "\n",
+ "# Sequential trials\n",
+ "start = time.time()\n",
+ "for _ in range(n_trials):\n",
+ " simulate_single_trial(n_steps=100)\n",
+ "time_sequential = time.time() - start\n",
+ "\n",
+ "# Batched trials\n",
+ "start = time.time()\n",
+ "simulate_batched_trials(n_trials=n_trials, n_steps=100)\n",
+ "time_batched = time.time() - start\n",
+ "\n",
+ "print(f\"โฑ๏ธ Sequential (10 trials): {time_sequential:.3f} seconds\")\n",
+ "print(f\"โฑ๏ธ Batched (10 trials): {time_batched:.3f} seconds\")\n",
+ "print(f\"\\n๐ Batching speedup: {time_sequential/time_batched:.1f}ร\")\n",
+ "print(f\"\\n๐ก Batching is especially effective on GPUs!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: GPU Acceleration\n",
+ "\n",
+ "GPUs excel at parallel operations on large arrays. BrainPy automatically uses GPUs when available via JAX.\n",
+ "\n",
+ "**GPU benefits:**\n",
+ "- Massive parallelism (1000s of cores)\n",
+ "- High memory bandwidth\n",
+ "- Fast matrix operations\n",
+ "- 10-100ร speedup for large networks\n",
+ "\n",
+ "**Best practices:**\n",
+ "- Use large batch sizes\n",
+ "- Minimize CPU-GPU data transfer\n",
+ "- Keep data on GPU between operations\n",
+ "- Use JIT compilation"
+ ]
+ },
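+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "One way to follow \"keep data on device\" is to place arrays explicitly with `jax.device_put` and leave them there between operations; data only crosses back to the host when you read it. A minimal sketch (it uses whatever device comes first, so it also runs on CPU-only machines):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Place an array on the first available device and compute there\n",
+    "dev = jax.devices()[0]\n",
+    "x = jax.device_put(jnp.zeros((1000, 1000)), dev)\n",
+    "y = jnp.dot(x, x)  # runs on the same device; no host round-trip\n",
+    "print('available devices:', jax.devices())"
+   ]
+  },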
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check if GPU is available\n",
+ "try:\n",
+ " gpu_device = jax.devices('gpu')[0]\n",
+ " has_gpu = True\n",
+ " print(\"โ
GPU detected:\", gpu_device)\n",
+ "except:\n",
+ " has_gpu = False\n",
+ " print(\"โน๏ธ No GPU detected, using CPU\")\n",
+ "\n",
+ "if has_gpu:\n",
+ " # Compare CPU vs GPU for large operation\n",
+ " n = 10000\n",
+ " \n",
+ " # CPU\n",
+ " with jax.default_device(jax.devices('cpu')[0]):\n",
+ " x = jax.random.normal(jax.random.PRNGKey(0), (n, n))\n",
+ " \n",
+ " start = time.time()\n",
+ " y = jnp.dot(x, x)\n",
+ " y.block_until_ready() # Wait for computation\n",
+ " time_cpu = time.time() - start\n",
+ " \n",
+ " # GPU\n",
+ " with jax.default_device(gpu_device):\n",
+ " x = jax.random.normal(jax.random.PRNGKey(0), (n, n))\n",
+ " \n",
+ " start = time.time()\n",
+ " y = jnp.dot(x, x)\n",
+ " y.block_until_ready()\n",
+ " time_gpu = time.time() - start\n",
+ " \n",
+ " print(f\"\\n๐ฅ๏ธ CPU time: {time_cpu:.4f} seconds\")\n",
+ " print(f\"๐ฎ GPU time: {time_gpu:.4f} seconds\")\n",
+ " print(f\"๐ GPU speedup: {time_cpu/time_gpu:.1f}ร\")\n",
+ "else:\n",
+ " print(\"\\n๐ก To use GPU:\")\n",
+ " print(\" 1. Install JAX with GPU support\")\n",
+ " print(\" 2. Install CUDA drivers\")\n",
+ " print(\" 3. BrainPy will automatically use GPU\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Performance Profiling\n",
+ "\n",
+ "To optimize performance, you need to identify bottlenecks. Use profiling to find where time is spent.\n",
+ "\n",
+ "**Profiling strategies:**\n",
+ "1. **Time individual operations**: Find slow components\n",
+ "2. **Use JAX profiler**: Detailed GPU/TPU profiling\n",
+ "3. **Monitor memory**: Detect memory leaks\n",
+ "4. **Check compilation**: Ensure JIT is working"
+ ]
+ },
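+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Beyond wall-clock timing, JAX ships a trace profiler whose output can be inspected in TensorBoard's profile plugin. A minimal, self-contained sketch (the log directory is arbitrary, and the matrix product is a stand-in for your simulation step):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal use of JAX's built-in trace profiler (sketch)\n",
+    "import os, tempfile\n",
+    "\n",
+    "logdir = os.path.join(tempfile.gettempdir(), 'jax-trace')\n",
+    "with jax.profiler.trace(logdir):  # records compilation and execution events\n",
+    "    x = jnp.ones((2000, 2000))\n",
+    "    jnp.dot(x, x).block_until_ready()\n",
+    "print('trace written to', logdir)"
+   ]
+  },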
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simple profiling example\n",
+ "class ProfilingNetwork(brainstate.nn.Module):\n",
+ " def __init__(self, n_neurons=5000):\n",
+ " super().__init__()\n",
+ " self.lif = bp.LIF(n_neurons, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ " self.proj = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_neurons, n_neurons, prob=0.01, weight=0.5*u.mS),\n",
+ " syn=bp.Expon.desc(n_neurons, tau=5.*u.ms),\n",
+ " out=bp.CUBA.desc(),\n",
+ " post=self.lif\n",
+ " )\n",
+ " \n",
+ " def update(self, inp):\n",
+ " spk = self.lif.get_spike()\n",
+ " self.proj(spk)\n",
+ " self.lif(inp)\n",
+ " return spk\n",
+ "\n",
+ "# Profile simulation\n",
+ "net = ProfilingNetwork(n_neurons=5000)\n",
+ "brainstate.nn.init_all_states(net)\n",
+ "\n",
+ "@brainstate.compile.jit\n",
+ "def run_step(net, inp):\n",
+ " return net(inp)\n",
+ "\n",
+ "# Warmup\n",
+ "inp = brainstate.random.rand(5000) * 2.0 * u.nA\n",
+ "_ = run_step(net, inp)\n",
+ "\n",
+ "# Profile multiple steps\n",
+ "n_steps = 100\n",
+ "step_times = []\n",
+ "\n",
+ "for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(5000) * 2.0 * u.nA\n",
+ " \n",
+ " start = time.time()\n",
+ " _ = run_step(net, inp)\n",
+ " step_times.append(time.time() - start)\n",
+ "\n",
+ "step_times = np.array(step_times) * 1000 # Convert to ms\n",
+ "\n",
+ "# Plot timing distribution\n",
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
+ "\n",
+ "# Time series\n",
+ "axes[0].plot(step_times, 'b-', linewidth=1, alpha=0.7)\n",
+ "axes[0].axhline(np.mean(step_times), color='r', linestyle='--', \n",
+ " label=f'Mean: {np.mean(step_times):.2f} ms')\n",
+ "axes[0].set_xlabel('Step', fontsize=12)\n",
+ "axes[0].set_ylabel('Time (ms)', fontsize=12)\n",
+ "axes[0].set_title('Step-by-Step Timing', fontsize=14, fontweight='bold')\n",
+ "axes[0].legend()\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Histogram\n",
+ "axes[1].hist(step_times, bins=30, color='blue', alpha=0.7, edgecolor='black')\n",
+ "axes[1].axvline(np.mean(step_times), color='r', linestyle='--', linewidth=2,\n",
+ " label=f'Mean: {np.mean(step_times):.2f} ms')\n",
+ "axes[1].set_xlabel('Time (ms)', fontsize=12)\n",
+ "axes[1].set_ylabel('Frequency', fontsize=12)\n",
+ "axes[1].set_title('Timing Distribution', fontsize=14, fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3, axis='y')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(f\"๐ Performance Statistics:\")\n",
+ "print(f\" Mean time/step: {np.mean(step_times):.2f} ms\")\n",
+ "print(f\" Std deviation: {np.std(step_times):.2f} ms\")\n",
+ "print(f\" Min time: {np.min(step_times):.2f} ms\")\n",
+ "print(f\" Max time: {np.max(step_times):.2f} ms\")\n",
+ "print(f\" Throughput: {1000/np.mean(step_times):.1f} steps/second\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Optimization Checklist\n",
+ "\n",
+ "Here's a comprehensive checklist for optimizing large-scale simulations.\n",
+ "\n",
+ "### Before Optimization\n",
+ "1. **Profile first**: Identify actual bottlenecks\n",
+ "2. **Set target**: Define performance goals\n",
+ "3. **Baseline**: Measure current performance\n",
+ "\n",
+ "### Code Optimizations\n",
+ "- โ
Use JIT compilation (`@brainstate.compile.jit`)\n",
+ "- โ
Use sparse connectivity (`EventFixedProb`)\n",
+ "- โ
Use float32 instead of float64\n",
+ "- โ
Batch multiple trials together\n",
+ "- โ
Avoid Python loops (use `for_loop` or `scan`)\n",
+ "- โ
Minimize state storage\n",
+ "- โ
Use appropriate time steps (larger = faster)\n",
+ "\n",
+ "### Hardware Optimizations\n",
+ "- โ
Use GPU/TPU when available\n",
+ "- โ
Increase batch size for better GPU utilization\n",
+ "- โ
Monitor GPU memory usage\n",
+ "- โ
Keep data on accelerator (avoid CPU-GPU transfers)\n",
+ "\n",
+ "### Algorithm Optimizations\n",
+ "- โ
Simplify neuron models if possible\n",
+ "- โ
Use event-driven dynamics where appropriate\n",
+ "- โ
Reduce synaptic computations (sparse updates)\n",
+ "- โ
Cache frequently computed values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Demonstrate optimization impact\n",
+ "def benchmark_configurations():\n",
+ " \"\"\"Benchmark different optimization strategies.\"\"\"\n",
+ " \n",
+ " n_neurons = 2000\n",
+ " n_steps = 100\n",
+ " results = {}\n",
+ " \n",
+ " # 1. Baseline (no optimizations)\n",
+ " print(\"Testing: Baseline (no optimizations)...\")\n",
+ " net1 = SimpleNetwork(n_neurons)\n",
+ " brainstate.nn.init_all_states(net1)\n",
+ " \n",
+ " start = time.time()\n",
+ " for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA\n",
+ " _ = net1(inp)\n",
+ " results['Baseline'] = time.time() - start\n",
+ " \n",
+ " # 2. With JIT\n",
+ " print(\"Testing: With JIT...\")\n",
+ " net2 = SimpleNetwork(n_neurons)\n",
+ " brainstate.nn.init_all_states(net2)\n",
+ " \n",
+ " @brainstate.compile.jit\n",
+ " def step_jit(net, inp):\n",
+ " return net(inp)\n",
+ " \n",
+ " # Warmup\n",
+ " inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA\n",
+ " _ = step_jit(net2, inp)\n",
+ " \n",
+ " start = time.time()\n",
+ " for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA\n",
+ " _ = step_jit(net2, inp)\n",
+ " results['JIT'] = time.time() - start\n",
+ " \n",
+ " # 3. With JIT + Batching\n",
+ " print(\"Testing: JIT + Batching...\")\n",
+ " batch_size = 10\n",
+ " net3 = SimpleNetwork(n_neurons)\n",
+ " brainstate.nn.init_all_states(net3, batch_size=batch_size)\n",
+ " \n",
+ " # Warmup\n",
+ " inp = brainstate.random.rand(batch_size, n_neurons) * 2.0 * u.nA\n",
+ " _ = step_jit(net3, inp)\n",
+ " \n",
+ " start = time.time()\n",
+ " for _ in range(n_steps):\n",
+ " inp = brainstate.random.rand(batch_size, n_neurons) * 2.0 * u.nA\n",
+ " _ = step_jit(net3, inp)\n",
+ " results['JIT+Batch'] = time.time() - start\n",
+ " \n",
+ " return results\n",
+ "\n",
+ "# Run benchmark\n",
+ "results = benchmark_configurations()\n",
+ "\n",
+ "# Visualize results\n",
+ "fig, ax = plt.subplots(figsize=(10, 6))\n",
+ "\n",
+ "configs = list(results.keys())\n",
+ "times = list(results.values())\n",
+ "speedups = [times[0] / t for t in times]\n",
+ "\n",
+ "bars = ax.bar(configs, times, color=['red', 'orange', 'green'], alpha=0.7)\n",
+ "\n",
+ "# Add speedup labels\n",
+ "for i, (bar, speedup) in enumerate(zip(bars, speedups)):\n",
+ " height = bar.get_height()\n",
+ " ax.text(bar.get_x() + bar.get_width()/2., height,\n",
+ " f'{speedup:.1f}ร faster\\n{times[i]:.2f}s',\n",
+ " ha='center', va='bottom', fontsize=11, fontweight='bold')\n",
+ "\n",
+ "ax.set_ylabel('Time (seconds)', fontsize=12)\n",
+ "ax.set_title('Optimization Impact (2000 neurons, 100 steps)', fontsize=14, fontweight='bold')\n",
+ "ax.grid(True, alpha=0.3, axis='y')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"\\n๐ Optimization Results:\")\n",
+ "for config, t in results.items():\n",
+ " speedup = times[0] / t\n",
+ " print(f\" {config:15s}: {t:.3f}s ({speedup:.1f}ร faster)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 8: Complete Large-Scale Example\n",
+ "\n",
+ "Let's put it all together with a fully optimized large-scale simulation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optimized large-scale network\n",
+ "class OptimizedLargeNetwork(brainstate.nn.Module):\n",
+ " \"\"\"Fully optimized large-scale E-I network.\"\"\"\n",
+ " \n",
+ " def __init__(self, n_exc=8000, n_inh=2000, p_conn=0.02):\n",
+ " super().__init__()\n",
+ " \n",
+ " self.n_exc = n_exc\n",
+ " self.n_inh = n_inh\n",
+ " \n",
+ " # LIF neurons (using default float32)\n",
+ " self.E = bp.LIF(n_exc, V_rest=-65.*u.mV, V_th=-50.*u.mV, \n",
+ " V_reset=-65.*u.mV, tau=15.*u.ms)\n",
+ " self.I = bp.LIF(n_inh, V_rest=-65.*u.mV, V_th=-50.*u.mV,\n",
+ " V_reset=-65.*u.mV, tau=10.*u.ms)\n",
+ " \n",
+ " # Sparse connectivity\n",
+ " self.E2E = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=p_conn, weight=0.05*u.mS),\n",
+ " syn=bp.Expon.desc(n_exc, tau=5.*u.ms),\n",
+ " out=bp.CUBA.desc(),\n",
+ " post=self.E\n",
+ " )\n",
+ " \n",
+ " self.E2I = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, n_inh, prob=p_conn, weight=0.05*u.mS),\n",
+ " syn=bp.Expon.desc(n_inh, tau=5.*u.ms),\n",
+ " out=bp.CUBA.desc(),\n",
+ " post=self.I\n",
+ " )\n",
+ " \n",
+ " self.I2E = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=p_conn, weight=0.4*u.mS),\n",
+ " syn=bp.Expon.desc(n_exc, tau=10.*u.ms),\n",
+ " out=bp.CUBA.desc(),\n",
+ " post=self.E\n",
+ " )\n",
+ " \n",
+ " self.I2I = bp.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, n_inh, prob=p_conn, weight=0.4*u.mS),\n",
+ " syn=bp.Expon.desc(n_inh, tau=10.*u.ms),\n",
+ " out=bp.CUBA.desc(),\n",
+ " post=self.I\n",
+ " )\n",
+ " \n",
+ " def update(self, inp_e, inp_i):\n",
+ " # Get spikes\n",
+ " spk_e = self.E.get_spike()\n",
+ " spk_i = self.I.get_spike()\n",
+ " \n",
+ " # Update projections\n",
+ " self.E2E(spk_e)\n",
+ " self.E2I(spk_e)\n",
+ " self.I2E(spk_i)\n",
+ " self.I2I(spk_i)\n",
+ " \n",
+ " # Update neurons\n",
+ " self.E(inp_e)\n",
+ " self.I(inp_i)\n",
+ " \n",
+ " return spk_e, spk_i\n",
+ "\n",
+ "# Create and simulate\n",
+ "print(\"Creating large-scale network...\")\n",
+ "large_net = OptimizedLargeNetwork(n_exc=8000, n_inh=2000, p_conn=0.02)\n",
+ "brainstate.nn.init_all_states(large_net)\n",
+ "\n",
+ "print(\"\\n๐ Network Statistics:\")\n",
+ "print(f\" Total neurons: {large_net.n_exc + large_net.n_inh:,}\")\n",
+ "print(f\" Excitatory: {large_net.n_exc:,} (80%)\")\n",
+ "print(f\" Inhibitory: {large_net.n_inh:,} (20%)\")\n",
+ "print(f\" Connectivity: 2%\")\n",
+ "print(f\" Estimated connections: {10000*10000*0.02:,.0f}\")\n",
+ "print(f\" Estimated memory: ~50 MB\")\n",
+ "\n",
+ "# JIT-compiled simulation\n",
+ "@brainstate.compile.jit\n",
+ "def simulate_step(net, inp_e, inp_i):\n",
+ " return net(inp_e, inp_i)\n",
+ "\n",
+ "# Warmup\n",
+ "print(\"\\nCompiling (this takes a moment)...\")\n",
+ "inp_e = brainstate.random.rand(large_net.n_exc) * 1.0 * u.nA\n",
+ "inp_i = brainstate.random.rand(large_net.n_inh) * 1.0 * u.nA\n",
+ "_ = simulate_step(large_net, inp_e, inp_i)\n",
+ "print(\"โ
Compilation complete!\")\n",
+ "\n",
+ "# Run simulation\n",
+ "print(\"\\nRunning simulation...\")\n",
+ "n_steps = 500\n",
+ "spike_history_e = []\n",
+ "spike_history_i = []\n",
+ "\n",
+ "start = time.time()\n",
+ "for i in range(n_steps):\n",
+ " inp_e = brainstate.random.rand(large_net.n_exc) * 1.0 * u.nA\n",
+ " inp_i = brainstate.random.rand(large_net.n_inh) * 1.0 * u.nA\n",
+ " spk_e, spk_i = simulate_step(large_net, inp_e, inp_i)\n",
+ " \n",
+ " # Downsample recording (save memory)\n",
+ " if i % 5 == 0:\n",
+ " spike_history_e.append(spk_e)\n",
+ " spike_history_i.append(spk_i)\n",
+ "\n",
+ "sim_time = time.time() - start\n",
+ "\n",
+ "print(f\"\\nโฑ๏ธ Simulation complete:\")\n",
+ "print(f\" Real time: {sim_time:.2f} seconds\")\n",
+ "print(f\" Simulated time: {n_steps * 0.1} ms\")\n",
+ "print(f\" Speedup: {(n_steps * 0.1 / 1000) / sim_time:.1f}ร real-time\")\n",
+ "print(f\" Throughput: {n_steps / sim_time:.1f} steps/second\")\n",
+ "\n",
+ "# Visualize downsampled activity\n",
+ "spike_history_e = jnp.array(spike_history_e)\n",
+ "spike_history_i = jnp.array(spike_history_i)\n",
+ "\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(14, 8), sharex=True)\n",
+ "\n",
+ "# Excitatory raster (subsample neurons for visibility)\n",
+ "n_show = 500\n",
+ "times_ms = np.arange(len(spike_history_e)) * 5 * 0.1 # Downsampled times\n",
+ "\n",
+ "for neuron_idx in range(min(n_show, large_net.n_exc)):\n",
+ " spike_times = times_ms[spike_history_e[:, neuron_idx] > 0]\n",
+ " axes[0].scatter(spike_times, [neuron_idx] * len(spike_times),\n",
+ " s=0.5, c='blue', alpha=0.5)\n",
+ "\n",
+ "axes[0].set_ylabel('Excitatory Neuron', fontsize=12)\n",
+ "axes[0].set_title(f'Large-Scale Network Activity ({large_net.n_exc + large_net.n_inh:,} neurons)', \n",
+ " fontsize=14, fontweight='bold')\n",
+ "axes[0].set_ylim(0, n_show)\n",
+ "\n",
+ "# Inhibitory raster\n",
+ "for neuron_idx in range(large_net.n_inh):\n",
+ " spike_times = times_ms[spike_history_i[:, neuron_idx] > 0]\n",
+ " axes[1].scatter(spike_times, [neuron_idx] * len(spike_times),\n",
+ " s=0.5, c='red', alpha=0.5)\n",
+ "\n",
+ "axes[1].set_xlabel('Time (ms)', fontsize=12)\n",
+ "axes[1].set_ylabel('Inhibitory Neuron', fontsize=12)\n",
+ "axes[1].set_title('Inhibitory Population', fontsize=14, fontweight='bold')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"\\nโ
Successfully simulated 10,000 neuron network!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
**JIT compilation**\n",
+ " - Use `@brainstate.compile.jit` for 10-100ร speedup\n",
+ " - Functions must be pure and have static shapes\n",
+ " - Essential for large-scale simulations\n",
+ "\n",
+ "โ
**Memory optimization**\n",
+ " - Use float32 instead of float64 (2ร savings)\n",
+ " - Minimize state storage\n",
+ " - Don't accumulate full histories\n",
+ "\n",
+ "โ
**Sparse connectivity**\n",
+ " - Use `EventFixedProb` for automatic sparse operations\n",
+ " - 90-99% memory reduction for biological connectivity\n",
+ " - Faster computation (skip zero connections)\n",
+ "\n",
+ "โ
**Batching**\n",
+ " - Run multiple trials simultaneously\n",
+ " - Better hardware utilization\n",
+ " - Faster parameter sweeps\n",
+ "\n",
+ "โ
**GPU/TPU acceleration**\n",
+ " - Automatic via JAX when available\n",
+ " - 10-100ร speedup for large networks\n",
+ " - Keep data on device\n",
+ "\n",
+ "โ
**Performance profiling**\n",
+ " - Identify bottlenecks before optimizing\n",
+ " - Monitor memory usage\n",
+ " - Track throughput metrics\n",
+ "\n",
+ "**Optimization workflow:**\n",
+ "\n",
+ "```python\n",
+ "# 1. Create network with sparse connectivity\n",
+ "net = OptimizedNetwork(\n",
+ " n_neurons=10000,\n",
+ " connectivity=0.02 # Sparse!\n",
+ ")\n",
+ "\n",
+ "# 2. Initialize with batching\n",
+ "brainstate.nn.init_all_states(net, batch_size=10)\n",
+ "\n",
+ "# 3. JIT compile simulation loop\n",
+ "@brainstate.compile.jit\n",
+ "def simulate_step(net, inp):\n",
+ " return net(inp)\n",
+ "\n",
+ "# 4. Run on GPU (automatic if available)\n",
+ "for i in range(n_steps):\n",
+ " inp = get_input()\n",
+ " output = simulate_step(net, inp)\n",
+ "```\n",
+ "\n",
+ "**Scale achieved:**\n",
+ "- โ
10,000 neurons: Easy on CPU\n",
+ "- โ
100,000 neurons: Needs GPU\n",
+ "- โ
1,000,000+ neurons: Multi-GPU or TPU\n",
+ "\n",
+ "**Next steps:**\n",
+ "- Try your own large-scale models\n",
+ "- Experiment with different connectivity patterns\n",
+ "- Profile and optimize your specific use case\n",
+ "- Use specialized tutorials for specific applications\n",
+ "- Explore multi-GPU scaling (advanced)\n",
+ "\n",
+ "**References:**\n",
+ "- JAX documentation: https://jax.readthedocs.io/\n",
+ "- BrainPy optimization guide: https://brainpy.readthedocs.io/\n",
+ "- Neuromorphic computing benchmarks\n",
+ "- Large-scale brain simulation papers (Spaun, Blue Brain Project)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Exercises\n",
+ "\n",
+ "Test your understanding:\n",
+ "\n",
+ "### Exercise 1: JIT Compilation\n",
+ "Take a non-JIT network and apply JIT compilation. Measure the speedup. What happens if you violate JIT rules (e.g., use Python loops)?\n",
+ "\n",
+ "### Exercise 2: Memory Analysis\n",
+ "Estimate memory requirements for a 100,000 neuron network with 1% connectivity. Will it fit in 16GB RAM?\n",
+ "\n",
+ "### Exercise 3: Sparse vs Dense\n",
+ "Implement the same network with dense and sparse connectivity. Compare memory usage and runtime.\n",
+ "\n",
+ "### Exercise 4: Batching Strategy\n",
+ "Run 100 independent trials. Compare: (a) sequential, (b) batched 10ร10, (c) batched 100ร1. Which is fastest?\n",
+ "\n",
+ "### Exercise 5: Profiling\n",
+ "Profile a large network and identify the slowest operation. Optimize it and measure improvement.\n",
+ "\n",
+ "**Bonus Challenge:** Scale up to the largest network your hardware can handle. How many neurons can you simulate in real-time?"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/basic/01-lif-neuron.ipynb b/docs_version3/tutorials/basic/01-lif-neuron.ipynb
new file mode 100644
index 00000000..ec1edf2c
--- /dev/null
+++ b/docs_version3/tutorials/basic/01-lif-neuron.ipynb
@@ -0,0 +1,545 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 1: LIF Neuron Basics\n",
+ "\n",
+ "In this tutorial, you'll learn how to:\n",
+ "\n",
+ "- Create and configure LIF neurons\n",
+ "- Simulate neuron dynamics\n",
+ "- Analyze neuron behavior\n",
+ "- Understand different reset modes\n",
+ "- Work with physical units"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import matplotlib.pyplot as plt\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "print(f\"BrainPy version: {brainpy.__version__}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Understanding the LIF Model\n",
+ "\n",
+ "The Leaky Integrate-and-Fire (LIF) neuron is described by:\n",
+ "\n",
+ "$$\\tau \\frac{dV}{dt} = -(V - V_{rest}) + R \\cdot I(t)$$\n",
+ "\n",
+ "Where:\n",
+ "- $V$ is the membrane potential\n",
+ "- $\\tau$ is the membrane time constant\n",
+ "- $V_{rest}$ is the resting potential\n",
+ "- $R$ is the input resistance\n",
+ "- $I(t)$ is the input current\n",
+ "\n",
+ "When $V \\geq V_{th}$ (threshold), the neuron spikes and resets."
+ ]
+ },
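+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before using the built-in `brainpy.LIF` below, here is a hand-rolled Euler integration of the same equation, so the discrete update rule is explicit. This is a sketch with plain unitless numbers (voltages in mV, time in ms, and $R \\cdot I$ lumped into a 20 mV drive), not BrainPy's internal implementation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal Euler sketch of the LIF equation (assumed unitless numbers)\n",
+    "tau, V_rest, V_th, V_reset = 10.0, -65.0, -50.0, -65.0\n",
+    "dt_, RI = 0.1, 20.0  # time step (ms) and effective drive R*I (mV)\n",
+    "\n",
+    "V, trace = V_rest, []\n",
+    "for _ in range(1000):  # 100 ms of simulated time\n",
+    "    V = V + (-(V - V_rest) + RI) * (dt_ / tau)  # Euler step of the ODE\n",
+    "    if V >= V_th:  # threshold crossing: emit a spike, then hard reset\n",
+    "        V = V_reset\n",
+    "    trace.append(V)\n",
+    "\n",
+    "plt.figure(figsize=(10, 3))\n",
+    "plt.plot(jnp.arange(len(trace)) * dt_, jnp.asarray(trace))\n",
+    "plt.xlabel('Time (ms)')\n",
+    "plt.ylabel('V (mV)')\n",
+    "plt.title('Hand-rolled Euler LIF (illustration)')\n",
+    "plt.show()"
+   ]
+  },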
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Creating Your First LIF Neuron"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set simulation time step\n",
+ "brainstate.environ.set(dt=0.1 * u.ms)\n",
+ "\n",
+ "# Create a single LIF neuron\n",
+ "neuron = brainpy.LIF(\n",
+ " size=1,\n",
+ " V_rest=-65. * u.mV, # Resting potential\n",
+ " V_th=-50. * u.mV, # Spike threshold\n",
+ " V_reset=-65. * u.mV, # Reset potential\n",
+ " tau=10. * u.ms, # Membrane time constant\n",
+ " R=1. * u.ohm, # Input resistance\n",
+ " spk_reset='hard' # Reset mode\n",
+ ")\n",
+ "\n",
+ "# Initialize neuron state\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "\n",
+ "print(\"Neuron created successfully!\")\n",
+ "print(f\"Initial membrane potential: {neuron.V.value}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Response to Constant Input\n",
+ "\n",
+ "Let's see how the neuron responds to a constant input current."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reset neuron\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "\n",
+ "# Simulation parameters\n",
+ "duration = 200. * u.ms\n",
+ "dt = brainstate.environ.get_dt()\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "# Constant input current\n",
+ "I_input = 2.0 * u.nA\n",
+ "\n",
+ "# Run simulation\n",
+ "voltages = []\n",
+ "spikes = []\n",
+ "\n",
+ "for t in times:\n",
+ " neuron(I_input)\n",
+ " voltages.append(neuron.V.value)\n",
+ " spikes.append(neuron.get_spike()[0]) # Single neuron\n",
+ "\n",
+ "voltages = u.math.asarray(voltages)\n",
+ "spikes = u.math.asarray(spikes)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot results\n",
+ "times_plot = times.to_decimal(u.ms)\n",
+ "voltages_plot = voltages.to_decimal(u.mV)\n",
+ "\n",
+ "plt.figure(figsize=(12, 5))\n",
+ "\n",
+ "# Membrane potential\n",
+ "plt.subplot(2, 1, 1)\n",
+ "plt.plot(times_plot, voltages_plot, linewidth=2)\n",
+ "plt.axhline(y=-50, color='r', linestyle='--', alpha=0.7, label='Threshold')\n",
+ "plt.axhline(y=-65, color='g', linestyle='--', alpha=0.7, label='Rest/Reset')\n",
+ "plt.ylabel('Voltage (mV)')\n",
+ "plt.title(f'LIF Neuron Response to Constant Input (I = {I_input})')\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "\n",
+ "# Spike raster\n",
+ "plt.subplot(2, 1, 2)\n",
+ "spike_times = times_plot[spikes > 0]\n",
+ "plt.scatter(spike_times, [0]*len(spike_times), marker='|', s=1000, c='black')\n",
+ "plt.ylabel('Spikes')\n",
+ "plt.xlabel('Time (ms)')\n",
+ "plt.ylim([-0.5, 0.5])\n",
+ "plt.yticks([])\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Statistics\n",
+ "n_spikes = int(u.math.sum(spikes > 0))\n",
+ "firing_rate = n_spikes / (duration.to_decimal(u.second))\n",
+ "print(f\"Number of spikes: {n_spikes}\")\n",
+ "print(f\"Firing rate: {firing_rate:.2f} Hz\")\n",
+ "\n",
+ "if n_spikes > 1:\n",
+ " isis = jnp.diff(times_plot[spikes > 0]) # Inter-spike intervals\n",
+ " print(f\"Mean ISI: {jnp.mean(isis):.2f} ms\")\n",
+ " print(f\"ISI CV: {jnp.std(isis)/jnp.mean(isis):.3f}\") # Coefficient of variation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: F-I Curve (Frequency-Current Relationship)\n",
+ "\n",
+ "Let's explore how firing rate changes with input current."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Range of input currents\n",
+ "currents = u.math.linspace(0 * u.nA, 5 * u.nA, 20)\n",
+ "firing_rates = []\n",
+ "\n",
+ "duration = 500. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "for I in currents:\n",
+ " # Reset neuron\n",
+ " brainstate.nn.init_all_states(neuron)\n",
+ " \n",
+ " # Simulate\n",
+ " spike_count = 0\n",
+ " for t in times:\n",
+ " neuron(I)\n",
+ " if neuron.get_spike()[0] > 0:\n",
+ " spike_count += 1\n",
+ " \n",
+ " # Calculate firing rate\n",
+ " rate = spike_count / (duration.to_decimal(u.second))\n",
+ " firing_rates.append(rate)\n",
+ "\n",
+ "firing_rates = jnp.array(firing_rates)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot F-I curve\n",
+ "plt.figure(figsize=(8, 5))\n",
+ "plt.plot(currents.to_decimal(u.nA), firing_rates, 'o-', linewidth=2, markersize=6)\n",
+ "plt.xlabel('Input Current (nA)')\n",
+ "plt.ylabel('Firing Rate (Hz)')\n",
+ "plt.title('F-I Curve: Firing Rate vs Input Current')\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Find rheobase (minimum current for spiking)\n",
+ "spiking_currents = currents[firing_rates > 0]\n",
+ "if len(spiking_currents) > 0:\n",
+ " rheobase = spiking_currents[0]\n",
+ " print(f\"Rheobase (minimum spiking current): {rheobase}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: Soft vs Hard Reset\n",
+ "\n",
+ "LIF neurons can use different reset mechanisms:\n",
+ "\n",
+ "- **Hard reset**: $V \\leftarrow V_{reset}$ (discards extra charge)\n",
+ "- **Soft reset**: $V \\leftarrow V - V_{th}$ (preserves extra charge)\n",
+ "\n",
+ "Let's compare their behavior."
+ ]
+ },
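+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a schematic of the two rules on plain numbers (one common way to write the soft rule carries the overshoot past the reset value; this is an illustration, not BrainPy's internal code):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Suppose the membrane potential just overshot threshold by 1.3 mV\n",
+    "V, V_th, V_reset = -48.7, -50.0, -65.0\n",
+    "\n",
+    "overshoot = V - V_th          # 1.3 mV above threshold\n",
+    "V_hard = V_reset              # hard reset: the overshoot is discarded\n",
+    "V_soft = V_reset + overshoot  # soft reset: the overshoot is carried over\n",
+    "\n",
+    "print(f'hard reset -> {V_hard} mV')\n",
+    "print(f'soft reset -> {V_soft:.1f} mV (keeps the {overshoot:.1f} mV overshoot)')"
+   ]
+  },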
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create two neurons with different reset modes\n",
+ "neuron_hard = brainpy.LIF(\n",
+ " size=1,\n",
+ " V_rest=-65. * u.mV,\n",
+ " V_th=-50. * u.mV,\n",
+ " V_reset=-65. * u.mV,\n",
+ " tau=10. * u.ms,\n",
+ " spk_reset='hard'\n",
+ ")\n",
+ "\n",
+ "neuron_soft = brainpy.LIF(\n",
+ " size=1,\n",
+ " V_rest=-65. * u.mV,\n",
+ " V_th=-50. * u.mV,\n",
+ " V_reset=-65. * u.mV, # Not used in soft reset\n",
+ " tau=10. * u.ms,\n",
+ " spk_reset='soft'\n",
+ ")\n",
+ "\n",
+ "# Initialize\n",
+ "brainstate.nn.init_all_states(neuron_hard)\n",
+ "brainstate.nn.init_all_states(neuron_soft)\n",
+ "\n",
+ "# Simulate both with same input\n",
+ "duration = 200. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "I_input = 3.0 * u.nA # Strong input\n",
+ "\n",
+ "voltages_hard = []\n",
+ "voltages_soft = []\n",
+ "spikes_hard = []\n",
+ "spikes_soft = []\n",
+ "\n",
+ "for t in times:\n",
+ " neuron_hard(I_input)\n",
+ " neuron_soft(I_input)\n",
+ " \n",
+ " voltages_hard.append(neuron_hard.V.value)\n",
+ " voltages_soft.append(neuron_soft.V.value)\n",
+ " spikes_hard.append(neuron_hard.get_spike()[0])\n",
+ " spikes_soft.append(neuron_soft.get_spike()[0])\n",
+ "\n",
+ "voltages_hard = u.math.asarray(voltages_hard)\n",
+ "voltages_soft = u.math.asarray(voltages_soft)\n",
+ "spikes_hard = u.math.asarray(spikes_hard)\n",
+ "spikes_soft = u.math.asarray(spikes_soft)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot comparison\n",
+ "times_plot = times.to_decimal(u.ms)\n",
+ "\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Hard reset\n",
+ "axes[0].plot(times_plot, voltages_hard.to_decimal(u.mV), linewidth=2, label='Hard Reset')\n",
+ "axes[0].axhline(y=-50, color='r', linestyle='--', alpha=0.5, label='Threshold')\n",
+ "axes[0].set_ylabel('Voltage (mV)')\n",
+ "axes[0].set_title('Hard Reset: V โ V_reset (discards extra charge)')\n",
+ "axes[0].legend()\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Soft reset\n",
+ "axes[1].plot(times_plot, voltages_soft.to_decimal(u.mV), linewidth=2, label='Soft Reset', color='orange')\n",
+ "axes[1].axhline(y=-50, color='r', linestyle='--', alpha=0.5, label='Threshold')\n",
+ "axes[1].set_ylabel('Voltage (mV)')\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_title('Soft Reset: V โ V - V_th (preserves extra charge)')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Compare firing rates\n",
+ "n_spikes_hard = int(u.math.sum(spikes_hard > 0))\n",
+ "n_spikes_soft = int(u.math.sum(spikes_soft > 0))\n",
+ "rate_hard = n_spikes_hard / (duration.to_decimal(u.second))\n",
+ "rate_soft = n_spikes_soft / (duration.to_decimal(u.second))\n",
+ "\n",
+ "print(f\"Hard reset: {n_spikes_hard} spikes, {rate_hard:.2f} Hz\")\n",
+ "print(f\"Soft reset: {n_spikes_soft} spikes, {rate_soft:.2f} Hz\")\n",
+ "print(f\"\\nSoft reset fires {(rate_soft/rate_hard - 1)*100:.1f}% more frequently\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Population of LIF Neurons\n",
+ "\n",
+ "Now let's create a population of neurons with heterogeneous properties."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create population with varied initial conditions\n",
+ "pop_size = 50\n",
+ "neuron_pop = brainpy.LIF(\n",
+ " size=pop_size,\n",
+ " V_rest=-65. * u.mV,\n",
+ " V_th=-50. * u.mV,\n",
+ " V_reset=-65. * u.mV,\n",
+ " tau=10. * u.ms,\n",
+ " V_initializer=braintools.init.Normal(-65., 5., unit=u.mV), # Random initial V\n",
+ " spk_reset='hard'\n",
+ ")\n",
+ "\n",
+ "# Initialize\n",
+ "brainstate.nn.init_all_states(neuron_pop)\n",
+ "\n",
+ "# Simulate with step current\n",
+ "duration = 300. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "spike_history = []\n",
+ "for t in times:\n",
+ " # Step current: 0 โ 2.5 nA at t=50ms\n",
+ " I = 2.5 * u.nA if t > 50*u.ms else 0 * u.nA\n",
+ " neuron_pop(I)\n",
+ " spike_history.append(neuron_pop.get_spike())\n",
+ "\n",
+ "spike_history = u.math.asarray(spike_history)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Raster plot\n",
+ "t_indices, n_indices = u.math.where(spike_history > 0)\n",
+ "spike_times = times[t_indices].to_decimal(u.ms)\n",
+ "\n",
+ "plt.figure(figsize=(12, 6))\n",
+ "plt.scatter(spike_times, n_indices, s=2, c='black', alpha=0.6)\n",
+ "plt.axvline(x=50, color='r', linestyle='--', alpha=0.5, label='Input onset')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Neuron Index', fontsize=12)\n",
+ "plt.title('Population Activity Raster Plot (50 LIF Neurons)', fontsize=14)\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Population firing rate over time\n",
+ "bin_size = 10 * u.ms\n",
+ "bins = u.math.arange(0*u.ms, duration, bin_size)\n",
+ "pop_rate, _ = u.math.histogram(times[t_indices], bins=bins.to_decimal(u.ms))\n",
+ "pop_rate = pop_rate / (pop_size * bin_size.to_decimal(u.second)) # Convert to Hz\n",
+ "\n",
+ "plt.figure(figsize=(12, 4))\n",
+ "bin_centers = bins[:-1] + bin_size/2\n",
+ "plt.plot(bin_centers.to_decimal(u.ms), pop_rate, linewidth=2)\n",
+ "plt.axvline(x=50, color='r', linestyle='--', alpha=0.5, label='Input onset')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Population Rate (Hz)', fontsize=12)\n",
+ "plt.title('Population Firing Rate', fontsize=14)\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Effects of Different Parameters\n",
+ "\n",
+ "Let's explore how changing parameters affects neuron behavior."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Different time constants\n",
+ "taus = [5*u.ms, 10*u.ms, 20*u.ms]\n",
+ "neurons = [brainpy.LIF(1, V_rest=-65.*u.mV, V_th=-50.*u.mV, \n",
+ " V_reset=-65.*u.mV, tau=tau, spk_reset='hard') \n",
+ " for tau in taus]\n",
+ "\n",
+ "# Initialize all\n",
+ "for n in neurons:\n",
+ " brainstate.nn.init_all_states(n)\n",
+ "\n",
+ "# Simulate\n",
+ "duration = 150. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "I_input = 2.0 * u.nA\n",
+ "\n",
+ "results = {}\n",
+ "for tau, neuron in zip(taus, neurons):\n",
+ " voltages = []\n",
+ " for t in times:\n",
+ " neuron(I_input)\n",
+ " voltages.append(neuron.V.value)\n",
+ " results[tau] = u.math.asarray(voltages)\n",
+ "\n",
+ "# Plot\n",
+ "plt.figure(figsize=(12, 5))\n",
+ "for tau, voltages in results.items():\n",
+ " plt.plot(times.to_decimal(u.ms), voltages.to_decimal(u.mV), \n",
+ " linewidth=2, label=f'ฯ = {tau}')\n",
+ "\n",
+ "plt.axhline(y=-50, color='r', linestyle='--', alpha=0.5, label='Threshold')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Membrane Potential (mV)', fontsize=12)\n",
+ "plt.title('Effect of Time Constant on LIF Dynamics', fontsize=14)\n",
+ "plt.legend(fontsize=10)\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Observation: Larger ฯ โ slower integration, lower firing rate\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
How to create and configure LIF neurons with physical units\n",
+ "\n",
+ "โ
How to simulate and visualize neuron dynamics\n",
+ "\n",
+ "โ
How to compute F-I curves (frequency-current relationships)\n",
+ "\n",
+ "โ
The difference between hard and soft reset modes\n",
+ "\n",
+ "โ
How to work with populations of neurons\n",
+ "\n",
+ "โ
How parameters affect neuron behavior\n",
+ "\n",
+ "## Next Steps\n",
+ "\n",
+ "- **Tutorial 2**: Learn about [synapse models](02-synapse-models.ipynb)\n",
+ "- **Tutorial 3**: Build [connected networks](03-network-connection.ipynb)\n",
+ "- **Advanced**: Explore [other neuron models](01-other-neurons.ipynb) (LIFRef, ALIF)\n",
+ "\n",
+ "## Exercises\n",
+ "\n",
+ "Try these on your own:\n",
+ "\n",
+ "1. Create a neuron with refractory period using `brainpy.LIFRef`\n",
+ "2. Implement adaptive neuron using `brainpy.ALIF` and observe spike-frequency adaptation\n",
+ "3. Generate an F-I curve for different values of ฯ\n",
+ "4. Create a population with heterogeneous time constants (use different ฯ for each neuron)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/basic/02-synapse-models.ipynb b/docs_version3/tutorials/basic/02-synapse-models.ipynb
new file mode 100644
index 00000000..4002f4d5
--- /dev/null
+++ b/docs_version3/tutorials/basic/02-synapse-models.ipynb
@@ -0,0 +1,651 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 2: Synapse Models\n",
+ "\n",
+ "In this tutorial, you'll learn:\n",
+ "\n",
+ "- What synapses do in neural networks\n",
+ "- Different synapse models (Expon, Alpha, AMPA, GABAa)\n",
+ "- How to compare synapse dynamics\n",
+ "- When to use each synapse type\n",
+ "- How to create custom synapses"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import matplotlib.pyplot as plt\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Understanding Synapses\n",
+ "\n",
+ "Synapses perform **temporal filtering** of spike trains:\n",
+ "\n",
+ "```\n",
+ "Discrete Spikes โ [Synapse] โ Continuous Signal\n",
+ "```\n",
+ "\n",
+ "They model:\n",
+ "- Postsynaptic potentials (PSPs)\n",
+ "- Rise and decay kinetics\n",
+ "- Neurotransmitter dynamics"
+ ]
+ },
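+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Concretely, \"temporal filtering\" means convolving the spike train with a kernel. The sketch below (plain NumPy, with made-up spike times) convolves a few spikes with an exponential kernel to produce the kind of continuous trace a synapse outputs:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Spike train -> exponential kernel -> continuous conductance (illustration)\n",
+    "dt_ms, tau_ms = 0.1, 5.0\n",
+    "t = np.arange(0, 100, dt_ms)\n",
+    "spikes = np.zeros_like(t)\n",
+    "spikes[[100, 300, 350, 700]] = 1.0  # spikes at 10, 30, 35, 70 ms\n",
+    "kernel = np.exp(-np.arange(0, 5 * tau_ms, dt_ms) / tau_ms)\n",
+    "g = np.convolve(spikes, kernel)[:len(t)]\n",
+    "\n",
+    "fig, axes = plt.subplots(2, 1, figsize=(10, 4), sharex=True)\n",
+    "axes[0].vlines(t[spikes > 0], 0, 1, color='k')\n",
+    "axes[0].set_ylabel('spikes')\n",
+    "axes[1].plot(t, g)\n",
+    "axes[1].set_ylabel('g (a.u.)')\n",
+    "axes[1].set_xlabel('Time (ms)')\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },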
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Exponential Synapse (Expon)\n",
+ "\n",
+ "The simplest model: single exponential decay.\n",
+ "\n",
+ "$$\\tau \\frac{dg}{dt} = -g$$\n",
+ "\n",
+ "When spike arrives: $g \\leftarrow g + 1$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set time step\n",
+ "brainstate.environ.set(dt=0.1 * u.ms)\n",
+ "\n",
+ "# Create exponential synapse\n",
+ "expon_syn = brainpy.Expon(\n",
+ " size=1,\n",
+ " tau=5. * u.ms,\n",
+ " g_initializer=braintools.init.Constant(0. * u.mS)\n",
+ ")\n",
+ "\n",
+ "# Initialize\n",
+ "brainstate.nn.init_all_states(expon_syn)\n",
+ "\n",
+ "print(f\"Created Expon synapse with tau={expon_syn.tau}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simulate response to single spike\n",
+ "brainstate.nn.init_all_states(expon_syn)\n",
+ "\n",
+ "duration = 50. * u.ms\n",
+ "dt = brainstate.environ.get_dt()\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "responses = []\n",
+ "for i, t in enumerate(times):\n",
+ " # Spike at t=0\n",
+ " spike = 1.0 if i == 0 else 0.0\n",
+ " expon_syn(jnp.array([spike]))\n",
+ " responses.append(expon_syn.g.value[0])\n",
+ "\n",
+ "responses = u.math.asarray(responses)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot response\n",
+ "plt.figure(figsize=(10, 4))\n",
+ "plt.plot(times.to_decimal(u.ms), responses.to_decimal(u.mS), linewidth=2, label='Expon')\n",
+ "plt.axvline(x=0, color='r', linestyle='--', alpha=0.5, label='Spike time')\n",
+ "plt.xlabel('Time (ms)')\n",
+ "plt.ylabel('Synaptic Variable g (mS)')\n",
+ "plt.title('Exponential Synapse Response to Single Spike')\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Observation: Instantaneous rise, exponential decay\")"
+ ]
+ },
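+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a sanity check, the ODE above has the closed-form solution $g(t) = g(0)e^{-t/\\tau}$ after a single spike, so the simulated trace should follow a pure exponential. The cell below (a sketch reusing `times` and `responses` from above, with $\\tau = 5$ ms) overlays the analytic decay:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Overlay the analytic solution g(t) = g(0) * exp(-t/tau) on the simulation\n",
+    "t_ms = times.to_decimal(u.ms)\n",
+    "g_sim = responses.to_decimal(u.mS)\n",
+    "g_analytic = g_sim[0] * np.exp(-t_ms / 5.0)  # tau = 5 ms, anchored at the first sample\n",
+    "\n",
+    "plt.figure(figsize=(10, 4))\n",
+    "plt.plot(t_ms, g_sim, linewidth=2, label='simulated')\n",
+    "plt.plot(t_ms, g_analytic, 'k--', linewidth=1.5, label='analytic exp(-t/tau)')\n",
+    "plt.xlabel('Time (ms)')\n",
+    "plt.ylabel('g (mS)')\n",
+    "plt.legend()\n",
+    "plt.grid(True, alpha=0.3)\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },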
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Alpha Synapse\n",
+ "\n",
+ "More realistic: gradual rise and decay.\n",
+ "\n",
+ "$$\\tau \\frac{dh}{dt} = -h$$\n",
+ "$$\\tau \\frac{dg}{dt} = -g + h$$\n",
+ "\n",
+ "Response: $g(t) = \\frac{t}{\\tau} e^{-t/\\tau}$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create alpha synapse\n",
+ "alpha_syn = brainpy.Alpha(\n",
+ " size=1,\n",
+ " tau=5. * u.ms,\n",
+ " g_initializer=braintools.init.Constant(0. * u.mS)\n",
+ ")\n",
+ "\n",
+ "brainstate.nn.init_all_states(alpha_syn)\n",
+ "\n",
+ "# Simulate\n",
+ "alpha_responses = []\n",
+ "for i, t in enumerate(times):\n",
+ " spike = 1.0 if i == 0 else 0.0\n",
+ " alpha_syn(jnp.array([spike]))\n",
+ " alpha_responses.append(alpha_syn.g.value[0])\n",
+ "\n",
+ "alpha_responses = u.math.asarray(alpha_responses)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compare Expon vs Alpha\n",
+ "plt.figure(figsize=(10, 5))\n",
+ "plt.plot(times.to_decimal(u.ms), responses.to_decimal(u.mS), \n",
+ " linewidth=2, label='Expon (instantaneous rise)', color='blue')\n",
+ "plt.plot(times.to_decimal(u.ms), alpha_responses.to_decimal(u.mS), \n",
+ " linewidth=2, label='Alpha (gradual rise)', color='orange')\n",
+ "plt.axvline(x=0, color='r', linestyle='--', alpha=0.5, label='Spike time')\n",
+ "plt.xlabel('Time (ms)')\n",
+ "plt.ylabel('Synaptic Variable g (mS)')\n",
+ "plt.title('Comparison: Exponential vs Alpha Synapse')\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Key difference: Alpha has realistic rise time, peak at t=ฯ\")"
+ ]
+ },
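+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The closed form $g(t) = \\frac{t}{\\tau} e^{-t/\\tau}$ peaks exactly at $t = \\tau$. A quick numeric check on the simulated trace (a sketch reusing `times` and `alpha_responses` from above):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Locate the peak of the simulated alpha response\n",
+    "t_ms = times.to_decimal(u.ms)\n",
+    "g_alpha = alpha_responses.to_decimal(u.mS)\n",
+    "print(f'simulated peak at t = {float(t_ms[np.argmax(g_alpha)]):.1f} ms (tau = 5 ms)')"
+   ]
+  },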
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: AMPA and GABAa Synapses\n",
+ "\n",
+ "Biologically parameterized models:\n",
+ "\n",
+ "- **AMPA**: Fast excitatory (ฯ โ 2 ms)\n",
+ "- **GABAa**: Slower inhibitory (ฯ โ 10 ms)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create AMPA synapse (fast excitatory)\n",
+ "ampa_syn = brainpy.AMPA(\n",
+ " size=1,\n",
+ " tau=2. * u.ms,\n",
+ " g_initializer=braintools.init.Constant(0. * u.mS)\n",
+ ")\n",
+ "\n",
+ "# Create GABAa synapse (slower inhibitory)\n",
+ "gaba_syn = brainpy.GABAa(\n",
+ " size=1,\n",
+ " tau=10. * u.ms,\n",
+ " g_initializer=braintools.init.Constant(0. * u.mS)\n",
+ ")\n",
+ "\n",
+ "brainstate.nn.init_all_states(ampa_syn)\n",
+ "brainstate.nn.init_all_states(gaba_syn)\n",
+ "\n",
+ "print(f\"AMPA tau: {ampa_syn.tau} (fast)\")\n",
+ "print(f\"GABAa tau: {gaba_syn.tau} (slow)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simulate both\n",
+ "ampa_responses = []\n",
+ "gaba_responses = []\n",
+ "\n",
+ "for i, t in enumerate(times):\n",
+ " spike = 1.0 if i == 0 else 0.0\n",
+ " ampa_syn(jnp.array([spike]))\n",
+ " gaba_syn(jnp.array([spike]))\n",
+ " ampa_responses.append(ampa_syn.g.value[0])\n",
+ " gaba_responses.append(gaba_syn.g.value[0])\n",
+ "\n",
+ "ampa_responses = u.math.asarray(ampa_responses)\n",
+ "gaba_responses = u.math.asarray(gaba_responses)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot all four synapse types\n",
+ "plt.figure(figsize=(12, 6))\n",
+ "\n",
+ "plt.plot(times.to_decimal(u.ms), responses.to_decimal(u.mS), \n",
+ " linewidth=2, label='Expon (ฯ=5ms)', alpha=0.7)\n",
+ "plt.plot(times.to_decimal(u.ms), alpha_responses.to_decimal(u.mS), \n",
+ " linewidth=2, label='Alpha (ฯ=5ms)', alpha=0.7)\n",
+ "plt.plot(times.to_decimal(u.ms), ampa_responses.to_decimal(u.mS), \n",
+ " linewidth=2, label='AMPA (ฯ=2ms, fast excitatory)', linestyle='--')\n",
+ "plt.plot(times.to_decimal(u.ms), gaba_responses.to_decimal(u.mS), \n",
+ " linewidth=2, label='GABAa (ฯ=10ms, slow inhibitory)', linestyle='--')\n",
+ "\n",
+ "plt.axvline(x=0, color='r', linestyle=':', alpha=0.5, label='Spike time')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Synaptic Variable g (mS)', fontsize=12)\n",
+ "plt.title('Comparison of All Synapse Models', fontsize=14)\n",
+ "plt.legend(loc='upper right')\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"\\nObservations:\")\n",
+ "print(\"- AMPA: Fastest decay (excitatory transmission)\")\n",
+ "print(\"- GABAa: Slowest decay (prolonged inhibition)\")\n",
+ "print(\"- Alpha models have realistic rise time\")\n",
+ "print(\"- Expon models are computationally faster\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: Response to Spike Trains\n",
+ "\n",
+ "How do synapses integrate multiple spikes?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create spike train: 5 spikes at 50 Hz (20ms intervals)\n",
+ "brainstate.nn.init_all_states(expon_syn)\n",
+ "brainstate.nn.init_all_states(alpha_syn)\n",
+ "\n",
+ "duration = 150. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "spike_times = [0, 20, 40, 60, 80] # ms\n",
+ "\n",
+ "expon_train_resp = []\n",
+ "alpha_train_resp = []\n",
+ "\n",
+ "for i, t in enumerate(times):\n",
+ " t_ms = t.to_decimal(u.ms)\n",
+ " spike = 1.0 if any(abs(t_ms - st) < 0.1 for st in spike_times) else 0.0\n",
+ " \n",
+ " expon_syn(jnp.array([spike]))\n",
+ " alpha_syn(jnp.array([spike]))\n",
+ " \n",
+ " expon_train_resp.append(expon_syn.g.value[0])\n",
+ " alpha_train_resp.append(alpha_syn.g.value[0])\n",
+ "\n",
+ "expon_train_resp = u.math.asarray(expon_train_resp)\n",
+ "alpha_train_resp = u.math.asarray(alpha_train_resp)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot spike train responses\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Expon response\n",
+ "axes[0].plot(times.to_decimal(u.ms), expon_train_resp.to_decimal(u.mS), \n",
+ " linewidth=2, color='blue')\n",
+ "for st in spike_times:\n",
+ " axes[0].axvline(x=st, color='r', linestyle='--', alpha=0.3)\n",
+ "axes[0].set_ylabel('g (mS)')\n",
+ "axes[0].set_title('Exponential Synapse Response to Spike Train (50 Hz)')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Alpha response\n",
+ "axes[1].plot(times.to_decimal(u.ms), alpha_train_resp.to_decimal(u.mS), \n",
+ " linewidth=2, color='orange')\n",
+ "for st in spike_times:\n",
+ " axes[1].axvline(x=st, color='r', linestyle='--', alpha=0.3, label='Spike' if st == 0 else '')\n",
+ "axes[1].set_ylabel('g (mS)')\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_title('Alpha Synapse Response to Spike Train (50 Hz)')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Temporal summation: synaptic responses accumulate over time\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Effect of Time Constant\n",
+ "\n",
+ "How does ฯ affect synapse dynamics?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compare different time constants\n",
+ "taus = [2*u.ms, 5*u.ms, 10*u.ms, 20*u.ms]\n",
+ "synapses = [brainpy.Expon(1, tau=tau) for tau in taus]\n",
+ "\n",
+ "for syn in synapses:\n",
+ " brainstate.nn.init_all_states(syn)\n",
+ "\n",
+ "# Simulate\n",
+ "duration = 100. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "responses_by_tau = {}\n",
+ "for tau, syn in zip(taus, synapses):\n",
+ " resp = []\n",
+ " for i, t in enumerate(times):\n",
+ " spike = 1.0 if i == 0 else 0.0\n",
+ " syn(jnp.array([spike]))\n",
+ " resp.append(syn.g.value[0])\n",
+ " responses_by_tau[tau] = u.math.asarray(resp)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot effect of tau\n",
+ "plt.figure(figsize=(10, 6))\n",
+ "\n",
+ "colors = plt.cm.viridis(np.linspace(0, 1, len(taus)))\n",
+ "for (tau, resp), color in zip(responses_by_tau.items(), colors):\n",
+ " plt.plot(times.to_decimal(u.ms), resp.to_decimal(u.mS), \n",
+ " linewidth=2, label=f'ฯ = {tau}', color=color)\n",
+ "\n",
+ "plt.axvline(x=0, color='r', linestyle='--', alpha=0.3, label='Spike')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Synaptic Variable g (mS)', fontsize=12)\n",
+ "plt.title('Effect of Time Constant on Synapse Decay', fontsize=14)\n",
+ "plt.legend(fontsize=10)\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"\\nEffect of ฯ:\")\n",
+ "print(\"- Smaller ฯ โ faster decay, less temporal summation\")\n",
+ "print(\"- Larger ฯ โ slower decay, more temporal summation\")\n",
+ "print(\"- Choose ฯ based on biological constraints and computational needs\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Synapses in Networks (Preview)\n",
+ "\n",
+ "Synapses are used within projections. Here's a preview:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create neurons\n",
+ "pre_neurons = brainpy.LIF(10, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "post_neurons = brainpy.LIF(5, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "\n",
+ "# Create projection with exponential synapse\n",
+ "projection = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(\n",
+ " 10, 5, prob=0.5, weight=0.5*u.mS\n",
+ " ),\n",
+ " syn=brainpy.Expon.desc(5, tau=5.*u.ms), # Synapse descriptor\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=post_neurons\n",
+ ")\n",
+ "\n",
+ "print(\"Synapse integrated into projection!\")\n",
+ "print(f\"Synapse type: {type(projection.syn).__name__}\")\n",
+ "print(f\"Synapse tau: {projection.syn.tau}\")\n",
+ "print(\"\\nSee Tutorial 3 for complete network building!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 8: When to Use Each Synapse Type\n",
+ "\n",
+ "### Decision Guide\n",
+ "\n",
+ "**Use Expon when:**\n",
+ "- Speed is critical\n",
+ "- Training SNNs\n",
+ "- Large-scale simulations\n",
+ "- Precise kinetics not needed\n",
+ "\n",
+ "**Use Alpha when:**\n",
+ "- Biological realism matters\n",
+ "- Detailed cortical models\n",
+ "- Comparing to experimental data\n",
+ "- Rise time is important\n",
+ "\n",
+ "**Use AMPA when:**\n",
+ "- Excitatory synapses\n",
+ "- Fast glutamatergic transmission\n",
+ "- Cortical excitatory neurons\n",
+ "\n",
+ "**Use GABAa when:**\n",
+ "- Inhibitory synapses\n",
+ "- GABAergic interneurons\n",
+ "- Slower inhibition needed"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 9: Creating Custom Synapses\n",
+ "\n",
+ "You can create custom synapse models:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from brainpy._base import Synapse\n",
+ "\n",
+ "class DoubleExpSynapse(Synapse):\n",
+ " \"\"\"Synapse with different rise and decay time constants.\"\"\"\n",
+ " \n",
+ " def __init__(self, size, tau_rise, tau_decay, **kwargs):\n",
+ " super().__init__(size, **kwargs)\n",
+ " \n",
+ " self.tau_rise = tau_rise\n",
+ " self.tau_decay = tau_decay\n",
+ " \n",
+ " # Two state variables\n",
+ " self.h = brainstate.ShortTermState(\n",
+ " braintools.init.Constant(0., unit=u.mS)(size)\n",
+ " )\n",
+ " self.g = brainstate.ShortTermState(\n",
+ " braintools.init.Constant(0., unit=u.mS)(size)\n",
+ " )\n",
+ " \n",
+ " def update(self, spike_input):\n",
+ " dt = brainstate.environ.get_dt()\n",
+ " \n",
+ " # Rise dynamics\n",
+ " dh = -self.h.value / self.tau_rise\n",
+ " self.h.value = self.h.value + dh * dt + spike_input * u.mS\n",
+ " \n",
+ " # Decay dynamics\n",
+ " dg = -self.g.value / self.tau_decay + self.h.value / self.tau_rise\n",
+ " self.g.value = self.g.value + dg * dt\n",
+ " \n",
+ " return self.g.value\n",
+ " \n",
+ " @classmethod\n",
+ " def desc(cls, size, tau_rise, tau_decay, **kwargs):\n",
+ " def create():\n",
+ " return cls(size, tau_rise, tau_decay, **kwargs)\n",
+ " return create\n",
+ "\n",
+ "# Test custom synapse\n",
+ "custom_syn = DoubleExpSynapse(\n",
+ " size=1, \n",
+ " tau_rise=1.*u.ms, \n",
+ " tau_decay=10.*u.ms\n",
+ ")\n",
+ "brainstate.nn.init_all_states(custom_syn)\n",
+ "\n",
+ "print(\"Custom synapse created!\")\n",
+ "print(f\"Rise time: {custom_syn.tau_rise}\")\n",
+ "print(f\"Decay time: {custom_syn.tau_decay}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test custom synapse\n",
+ "custom_resp = []\n",
+ "duration = 50. * u.ms\n",
+ "times = u.math.arange(0. * u.ms, duration, dt)\n",
+ "\n",
+ "for i, t in enumerate(times):\n",
+ " spike = 1.0 if i == 0 else 0.0\n",
+ " custom_syn(jnp.array([spike]))\n",
+ " custom_resp.append(custom_syn.g.value[0])\n",
+ "\n",
+ "custom_resp = u.math.asarray(custom_resp)\n",
+ "\n",
+ "plt.figure(figsize=(10, 5))\n",
+ "plt.plot(times.to_decimal(u.ms), custom_resp.to_decimal(u.mS), linewidth=2)\n",
+ "plt.axvline(x=0, color='r', linestyle='--', alpha=0.5, label='Spike')\n",
+ "plt.xlabel('Time (ms)')\n",
+ "plt.ylabel('g (mS)')\n",
+ "plt.title('Custom Double-Exponential Synapse (ฯ_rise=1ms, ฯ_decay=10ms)')\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Custom synapse shows fast rise (1ms) and slow decay (10ms)\")"
+ ]
+ },
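+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a sanity check, the double-exponential kernel has a closed-form peak time, t_peak = (τ_rise·τ_decay / (τ_decay − τ_rise)) · ln(τ_decay/τ_rise), which we can compare against the simulated curve above (agreement is approximate, up to dt and Euler-integration error):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Analytical peak time of the double-exponential response\n",
+    "tau_r = custom_syn.tau_rise.to_decimal(u.ms)\n",
+    "tau_d = custom_syn.tau_decay.to_decimal(u.ms)\n",
+    "t_peak = tau_r * tau_d / (tau_d - tau_r) * jnp.log(tau_d / tau_r)\n",
+    "print(f\"Analytical peak: {t_peak:.2f} ms\")\n",
+    "\n",
+    "# Empirical peak from the simulated response\n",
+    "t_peak_sim = times[jnp.argmax(custom_resp.to_decimal(u.mS))].to_decimal(u.ms)\n",
+    "print(f\"Simulated peak: {t_peak_sim:.2f} ms\")"
+   ]
+  },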
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
How synapses perform temporal filtering\n",
+ "\n",
+ "โ
Four synapse models: Expon, Alpha, AMPA, GABAa\n",
+ "\n",
+ "โ
Differences in rise time and decay kinetics\n",
+ "\n",
+ "โ
Response to single spikes and spike trains\n",
+ "\n",
+ "โ
Effect of time constant ฯ\n",
+ "\n",
+ "โ
When to use each synapse type\n",
+ "\n",
+ "โ
How to create custom synapses\n",
+ "\n",
+ "## Next Steps\n",
+ "\n",
+ "- **Tutorial 3**: Learn to [build connected networks](03-network-connections.ipynb)\n",
+ "- **Core Concepts**: Read detailed [synapse documentation](../../core-concepts/synapses.rst)\n",
+ "- **Examples**: See synapses in action in the [examples gallery](../../examples/gallery.rst)\n",
+ "\n",
+ "## Exercises\n",
+ "\n",
+ "Try these on your own:\n",
+ "\n",
+ "1. Compare AMPA vs GABAa responses to a 100 Hz spike train\n",
+ "2. Find the ฯ value where peak response to a 50 Hz train is maximized\n",
+ "3. Implement an NMDA synapse (voltage-dependent)\n",
+ "4. Create a synapse with adaptation (decrease response over time)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/basic/03-network-connections.ipynb b/docs_version3/tutorials/basic/03-network-connections.ipynb
new file mode 100644
index 00000000..dbb93eae
--- /dev/null
+++ b/docs_version3/tutorials/basic/03-network-connections.ipynb
@@ -0,0 +1,649 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 3: Network Connections\n",
+ "\n",
+ "In this tutorial, you'll learn:\n",
+ "\n",
+ "- The projection architecture (Comm-Syn-Out)\n",
+ "- Connectivity patterns\n",
+ "- CUBA vs COBA output mechanisms\n",
+ "- Building excitatory-inhibitory networks\n",
+ "- Network simulation and analysis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import matplotlib.pyplot as plt\n",
+ "import jax.numpy as jnp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Understanding Projections\n",
+ "\n",
+ "BrainPy 3.0 uses a **three-stage projection architecture**:\n",
+ "\n",
+ "```\n",
+ "Presynaptic [Communication] [Synapse] [Output] Postsynaptic\n",
+ "Spikes โ Connectivity โ Dynamics โ Injection โ Neurons\n",
+ " & Weights Filtering Mechanism\n",
+ "```\n",
+ "\n",
+ "### Why This Design?\n",
+ "\n",
+ "1. **Modularity**: Each stage is independent and swappable\n",
+ "2. **Clarity**: Clear separation of concerns\n",
+ "3. **Reusability**: Mix and match components\n",
+ "4. **Flexibility**: Easy to customize any stage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Simple Projection Example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set time step\n",
+ "brainstate.environ.set(dt=0.1 * u.ms)\n",
+ "\n",
+ "# Create pre and post neurons\n",
+ "pre_neurons = brainpy.LIF(20, V_rest=-65.*u.mV, V_th=-50.*u.mV, V_reset=-65.*u.mV, tau=10.*u.ms)\n",
+ "post_neurons = brainpy.LIF(10, V_rest=-65.*u.mV, V_th=-50.*u.mV, V_reset=-65.*u.mV, tau=10.*u.ms)\n",
+ "\n",
+ "# Create projection with all three stages\n",
+ "projection = brainpy.AlignPostProj(\n",
+ " # Stage 1: Communication (connectivity + weights)\n",
+ " comm=brainstate.nn.EventFixedProb(\n",
+ " pre_num=20, \n",
+ " post_num=10, \n",
+ " prob=0.3, # 30% connection probability\n",
+ " weight=0.5*u.mS # Synaptic weight\n",
+ " ),\n",
+ " \n",
+ " # Stage 2: Synapse (temporal dynamics)\n",
+ " syn=brainpy.Expon.desc(10, tau=5.*u.ms),\n",
+ " \n",
+ " # Stage 3: Output (how to affect post neurons)\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " \n",
+ " # Target neurons\n",
+ " post=post_neurons\n",
+ ")\n",
+ "\n",
+ "print(f\"Created projection: {20} โ {10} neurons\")\n",
+ "print(f\"Connectivity: {0.3*100:.0f}% probability\")\n",
+ "print(f\"Expected connections: ~{20*10*0.3:.0f} synapses\")\n",
+ "print(f\"Synapse type: Exponential (ฯ=5ms)\")\n",
+ "print(f\"Output mechanism: CUBA (current-based)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Testing the Projection"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize all states\n",
+ "brainstate.nn.init_all_states(pre_neurons)\n",
+ "brainstate.nn.init_all_states(post_neurons)\n",
+ "\n",
+ "# Simulate one step\n",
+ "# 1. Activate pre neurons with strong input\n",
+ "pre_neurons(5.0 * u.nA)\n",
+ "\n",
+ "# 2. Get spikes from pre neurons\n",
+ "pre_spikes = pre_neurons.get_spike()\n",
+ "print(f\"Pre spikes: {u.math.sum(pre_spikes != 0)} neurons fired\")\n",
+ "\n",
+ "# 3. Propagate through projection\n",
+ "projection(pre_spikes)\n",
+ "\n",
+ "# 4. Check synaptic variable\n",
+ "print(f\"Synaptic conductance (g): {projection.syn.g.value[:5]}\")\n",
+ "\n",
+ "# 5. Update post neurons (they receive synaptic input)\n",
+ "post_neurons(0. * u.nA)\n",
+ "post_spikes = post_neurons.get_spike()\n",
+ "print(f\"Post spikes: {u.math.sum(post_spikes != 0)} neurons fired\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: CUBA vs COBA Output\n",
+ "\n",
+ "Two ways synapses can affect postsynaptic neurons:\n",
+ "\n",
+ "### CUBA (Current-Based)\n",
+ "$$I_{syn} = g$$\n",
+ "\n",
+ "- Simple: synaptic conductance directly becomes current\n",
+ "- Faster computation\n",
+ "- Less biologically realistic\n",
+ "\n",
+ "### COBA (Conductance-Based)\n",
+ "$$I_{syn} = g \\cdot (V - E_{rev})$$\n",
+ "\n",
+ "- Realistic: current depends on driving force\n",
+ "- Voltage-dependent\n",
+ "- More biologically accurate"
+ ]
+ },
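+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick unit check makes the difference concrete: with `brainunit`, conductance times voltage yields a current, so the COBA formula can be evaluated at rest directly (a sketch for illustration only):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# COBA driving-force arithmetic at the resting potential\n",
+    "g_demo = 1.0 * u.mS\n",
+    "V_demo = -65. * u.mV\n",
+    "E_demo = 0. * u.mV\n",
+    "I_demo = g_demo * (V_demo - E_demo)  # 1 mS * (-65 mV) = -65 uA\n",
+    "print(f\"COBA current at rest: {I_demo}\")\n",
+    "print(\"CUBA would instead inject g itself, independent of V\")"
+   ]
+  },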
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create neurons\n",
+ "neurons_cuba = brainpy.LIF(5, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "neurons_coba = brainpy.LIF(5, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "pre = brainpy.LIF(10, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "\n",
+ "# CUBA projection\n",
+ "proj_cuba = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(10, 5, prob=0.5, weight=1.0*u.mS),\n",
+ " syn=brainpy.Expon.desc(5, tau=5.*u.ms),\n",
+ " out=brainpy.CUBA.desc(), # Current-based\n",
+ " post=neurons_cuba\n",
+ ")\n",
+ "\n",
+ "# COBA projection (excitatory reversal potential at 0 mV)\n",
+ "proj_coba = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(10, 5, prob=0.5, weight=1.0*u.mS),\n",
+ " syn=brainpy.Expon.desc(5, tau=5.*u.ms),\n",
+ " out=brainpy.COBA.desc(E=0.*u.mV), # Conductance-based\n",
+ " post=neurons_coba\n",
+ ")\n",
+ "\n",
+ "print(\"CUBA: I_syn = g\")\n",
+ "print(\"COBA: I_syn = g * (V - E_rev)\")\n",
+ "print(\"\\nWith V=-65mV and E=0mV:\")\n",
+ "print(\"COBA current will be ~65mV larger (more driving force)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compare CUBA vs COBA\n",
+ "brainstate.nn.init_all_states(pre)\n",
+ "brainstate.nn.init_all_states(neurons_cuba)\n",
+ "brainstate.nn.init_all_states(neurons_coba)\n",
+ "\n",
+ "duration = 100. * u.ms\n",
+ "dt = brainstate.environ.get_dt()\n",
+ "times = u.math.arange(0.*u.ms, duration, dt)\n",
+ "\n",
+ "V_cuba_hist = []\n",
+ "V_coba_hist = []\n",
+ "\n",
+ "for t in times:\n",
+ " # Strong input to pre neurons\n",
+ " pre(3.0 * u.nA)\n",
+ " spikes = pre.get_spike()\n",
+ " \n",
+ " # Propagate through both projections\n",
+ " proj_cuba(spikes)\n",
+ " proj_coba(spikes)\n",
+ " \n",
+ " # Update post neurons\n",
+ " neurons_cuba(0. * u.nA)\n",
+ " neurons_coba(0. * u.nA)\n",
+ " \n",
+ " # Record voltages\n",
+ " V_cuba_hist.append(neurons_cuba.V.value[0])\n",
+ " V_coba_hist.append(neurons_coba.V.value[0])\n",
+ "\n",
+ "V_cuba_hist = u.math.asarray(V_cuba_hist)\n",
+ "V_coba_hist = u.math.asarray(V_coba_hist)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot comparison\n",
+ "plt.figure(figsize=(12, 5))\n",
+ "plt.plot(times.to_decimal(u.ms), V_cuba_hist.to_decimal(u.mV), \n",
+ " linewidth=2, label='CUBA (current-based)', alpha=0.8)\n",
+ "plt.plot(times.to_decimal(u.ms), V_coba_hist.to_decimal(u.mV), \n",
+ " linewidth=2, label='COBA (conductance-based)', alpha=0.8)\n",
+ "plt.axhline(y=-50, color='r', linestyle='--', alpha=0.5, label='Threshold')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Membrane Potential (mV)', fontsize=12)\n",
+ "plt.title('CUBA vs COBA: Postsynaptic Response', fontsize=14)\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Observation: COBA produces stronger depolarization due to larger driving force\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: Building an Excitatory-Inhibitory Network\n",
+ "\n",
+ "Let's build a classic E-I balanced network with:\n",
+ "- 80% excitatory neurons\n",
+ "- 20% inhibitory neurons\n",
+ "- All-to-all connectivity patterns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class EINetwork(brainstate.nn.Module):\n",
+ " def __init__(self, n_exc=80, n_inh=20, prob=0.1):\n",
+ " super().__init__()\n",
+ " self.n_exc = n_exc\n",
+ " self.n_inh = n_inh\n",
+ " \n",
+ " # Create neuron populations\n",
+ " self.E = brainpy.LIF(\n",
+ " n_exc, \n",
+ " V_rest=-65.*u.mV, \n",
+ " V_th=-50.*u.mV, \n",
+ " V_reset=-65.*u.mV,\n",
+ " tau=10.*u.ms,\n",
+ " V_initializer=braintools.init.Normal(-65., 5., unit=u.mV)\n",
+ " )\n",
+ " \n",
+ " self.I = brainpy.LIF(\n",
+ " n_inh,\n",
+ " V_rest=-65.*u.mV,\n",
+ " V_th=-50.*u.mV,\n",
+ " V_reset=-65.*u.mV,\n",
+ " tau=10.*u.ms,\n",
+ " V_initializer=braintools.init.Normal(-65., 5., unit=u.mV)\n",
+ " )\n",
+ " \n",
+ " # Excitatory projections (fast, AMPA-like)\n",
+ " self.E2E = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=prob, weight=0.3*u.mS),\n",
+ " syn=brainpy.Expon.desc(n_exc, tau=2.*u.ms), # Fast excitation\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=self.E\n",
+ " )\n",
+ " \n",
+ " self.E2I = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_exc, n_inh, prob=prob, weight=0.3*u.mS),\n",
+ " syn=brainpy.Expon.desc(n_inh, tau=2.*u.ms),\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=self.I\n",
+ " )\n",
+ " \n",
+ " # Inhibitory projections (slower, GABAa-like)\n",
+ " self.I2E = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=prob, weight=-2.0*u.mS),\n",
+ " syn=brainpy.Expon.desc(n_exc, tau=10.*u.ms), # Slower inhibition\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=self.E\n",
+ " )\n",
+ " \n",
+ " self.I2I = brainpy.AlignPostProj(\n",
+ " comm=brainstate.nn.EventFixedProb(n_inh, n_inh, prob=prob, weight=-2.0*u.mS),\n",
+ " syn=brainpy.Expon.desc(n_inh, tau=10.*u.ms),\n",
+ " out=brainpy.CUBA.desc(),\n",
+ " post=self.I\n",
+ " )\n",
+ " \n",
+ " def update(self, inp_e, inp_i):\n",
+ " \"\"\"Update network for one time step.\n",
+ " \n",
+ " Key: Get spikes BEFORE updating neurons!\n",
+ " \"\"\"\n",
+ " # Get spikes from previous timestep\n",
+ " spk_e = self.E.get_spike()\n",
+ " spk_i = self.I.get_spike()\n",
+ " \n",
+ " # Update projections (uses previous spikes)\n",
+ " self.E2E(spk_e)\n",
+ " self.E2I(spk_e)\n",
+ " self.I2E(spk_i)\n",
+ " self.I2I(spk_i)\n",
+ " \n",
+ " # Update neurons (receives synaptic input)\n",
+ " self.E(inp_e)\n",
+ " self.I(inp_i)\n",
+ " \n",
+ " return spk_e, spk_i\n",
+ "\n",
+ "# Create network\n",
+ "net = EINetwork(n_exc=80, n_inh=20, prob=0.1)\n",
+ "print(f\"Created E-I network:\")\n",
+ "print(f\" - {net.n_exc} excitatory neurons\")\n",
+ "print(f\" - {net.n_inh} inhibitory neurons\")\n",
+ "print(f\" - {net.n_exc + net.n_inh} total neurons\")\n",
+ "print(f\" - 10% connectivity\")\n",
+ "print(f\" - 4 projection types (E2E, E2I, I2E, I2I)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Simulate the Network"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize all states\n",
+ "brainstate.nn.init_all_states(net)\n",
+ "\n",
+ "# Simulation parameters\n",
+ "duration = 500. * u.ms\n",
+ "dt = brainstate.environ.get_dt()\n",
+ "times = u.math.arange(0.*u.ms, duration, dt)\n",
+ "\n",
+ "# External input currents\n",
+ "I_ext_e = 1.5 * u.nA # To excitatory neurons\n",
+ "I_ext_i = 1.0 * u.nA # To inhibitory neurons\n",
+ "\n",
+ "# Define simulation step\n",
+ "def sim_step(t):\n",
+ " return net.update(I_ext_e, I_ext_i)\n",
+ "\n",
+ "print(\"Running simulation...\")\n",
+ "print(f\"Duration: {duration}\")\n",
+ "print(f\"Time step: {dt}\")\n",
+ "print(f\"Steps: {len(times)}\")\n",
+ "\n",
+ "# Run simulation with progress bar\n",
+ "spikes = brainstate.transform.for_loop(\n",
+ " sim_step, \n",
+ " times,\n",
+ " pbar=brainstate.transform.ProgressBar(10)\n",
+ ")\n",
+ "\n",
+ "spk_e, spk_i = spikes\n",
+ "print(\"\\nSimulation complete!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Analyze Network Activity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Calculate statistics\n",
+ "n_spikes_e = int(u.math.sum(spk_e != 0))\n",
+ "n_spikes_i = int(u.math.sum(spk_i != 0))\n",
+ "n_spikes_total = n_spikes_e + n_spikes_i\n",
+ "\n",
+ "# Firing rates\n",
+ "duration_s = duration.to_decimal(u.second)\n",
+ "rate_e = n_spikes_e / (net.n_exc * duration_s)\n",
+ "rate_i = n_spikes_i / (net.n_inh * duration_s)\n",
+ "rate_total = n_spikes_total / ((net.n_exc + net.n_inh) * duration_s)\n",
+ "\n",
+ "print(\"Network Statistics:\")\n",
+ "print(\"=\"*50)\n",
+ "print(f\"Total spikes: {n_spikes_total}\")\n",
+ "print(f\" - Excitatory: {n_spikes_e} ({n_spikes_e/n_spikes_total*100:.1f}%)\")\n",
+ "print(f\" - Inhibitory: {n_spikes_i} ({n_spikes_i/n_spikes_total*100:.1f}%)\")\n",
+ "print(f\"\\nFiring Rates:\")\n",
+ "print(f\" - Excitatory population: {rate_e:.2f} Hz\")\n",
+ "print(f\" - Inhibitory population: {rate_i:.2f} Hz\")\n",
+ "print(f\" - Overall network: {rate_total:.2f} Hz\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 8: Visualize Network Activity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Combine E and I spikes\n",
+ "all_spikes = u.math.concatenate([spk_e, spk_i], axis=1)\n",
+ "\n",
+ "# Find spike times and neuron indices\n",
+ "t_idx, n_idx = u.math.where(all_spikes != 0)\n",
+ "spike_times = times[t_idx].to_decimal(u.ms)\n",
+ "\n",
+ "# Create raster plot\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(14, 8), \n",
+ " gridspec_kw={'height_ratios': [3, 1]})\n",
+ "\n",
+ "# Raster plot\n",
+ "colors = ['blue' if i < net.n_exc else 'red' for i in n_idx]\n",
+ "axes[0].scatter(spike_times, n_idx, s=1, c=colors, alpha=0.6)\n",
+ "axes[0].axhline(y=net.n_exc, color='black', linestyle='--', \n",
+ " alpha=0.5, linewidth=2, label='E/I boundary')\n",
+ "axes[0].set_ylabel('Neuron Index', fontsize=12)\n",
+ "axes[0].set_title('E-I Network Activity Raster Plot', fontsize=14, fontweight='bold')\n",
+ "axes[0].text(10, net.n_exc/2, 'Excitatory', fontsize=10, \n",
+ " bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))\n",
+ "axes[0].text(10, net.n_exc + net.n_inh/2, 'Inhibitory', fontsize=10,\n",
+ " bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.7))\n",
+ "axes[0].legend(loc='upper right')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Population firing rate over time\n",
+ "bin_size = 10 * u.ms\n",
+ "bins = u.math.arange(0.*u.ms, duration, bin_size)\n",
+ "bins_decimal = bins.to_decimal(u.ms)\n",
+ "\n",
+ "# Calculate rates for E and I\n",
+ "t_idx_e, _ = u.math.where(spk_e != 0)\n",
+ "t_idx_i, _ = u.math.where(spk_i != 0)\n",
+ "\n",
+ "hist_e, _ = u.math.histogram(times[t_idx_e].to_decimal(u.ms), bins=bins_decimal)\n",
+ "hist_i, _ = u.math.histogram(times[t_idx_i].to_decimal(u.ms), bins=bins_decimal)\n",
+ "\n",
+ "rate_e = hist_e / (net.n_exc * bin_size.to_decimal(u.second))\n",
+ "rate_i = hist_i / (net.n_inh * bin_size.to_decimal(u.second))\n",
+ "\n",
+ "bin_centers = bins[:-1] + bin_size/2\n",
+ "axes[1].plot(bin_centers.to_decimal(u.ms), rate_e, linewidth=2, \n",
+ " label='Excitatory', color='blue', alpha=0.7)\n",
+ "axes[1].plot(bin_centers.to_decimal(u.ms), rate_i, linewidth=2, \n",
+ " label='Inhibitory', color='red', alpha=0.7)\n",
+ "axes[1].set_xlabel('Time (ms)', fontsize=12)\n",
+ "axes[1].set_ylabel('Firing Rate (Hz)', fontsize=12)\n",
+ "axes[1].set_title('Population Firing Rates', fontsize=12)\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 9: Connectivity Patterns\n",
+ "\n",
+ "Different connectivity strategies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example: Different connectivity patterns\n",
+ "\n",
+ "# 1. Fixed probability (what we used)\n",
+ "conn_fixed_prob = brainstate.nn.EventFixedProb(\n",
+ " 20, 10, prob=0.3, weight=0.5*u.mS\n",
+ ")\n",
+ "\n",
+ "# 2. All-to-all (full connectivity)\n",
+ "# conn_all2all = brainstate.nn.EventFixedProb(\n",
+ "# 20, 10, prob=1.0, weight=0.5*u.mS\n",
+ "# )\n",
+ "\n",
+ "# 3. One-to-one (for same-sized populations)\n",
+ "# conn_one2one = brainstate.nn.EventOne2One(\n",
+ "# 10, 10, weight=0.5*u.mS\n",
+ "# )\n",
+ "\n",
+ "print(\"Connectivity patterns:\")\n",
+ "print(\"1. FixedProb: Random connections with fixed probability\")\n",
+ "print(\"2. All2All: Every neuron connects to every other\")\n",
+ "print(\"3. One2One: Neuron i connects only to neuron i\")\n",
+ "print(\"\\nYou can also create custom connectivity!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 10: Network Variations\n",
+ "\n",
+ "Experiment with different parameters:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Try these experiments:\n",
+ "\n",
+ "print(\"Experiments to try:\")\n",
+ "print(\"\\n1. Change E/I balance:\")\n",
+ "print(\" - Increase inhibitory weight โ more synchrony\")\n",
+ "print(\" - Decrease inhibitory weight โ more irregular\")\n",
+ "\n",
+ "print(\"\\n2. Change connectivity:\")\n",
+ "print(\" - Higher prob โ more correlated activity\")\n",
+ "print(\" - Lower prob โ more independent neurons\")\n",
+ "\n",
+ "print(\"\\n3. Change time constants:\")\n",
+ "print(\" - Faster inhibition โ sharper oscillations\")\n",
+ "print(\" - Slower inhibition โ smoother dynamics\")\n",
+ "\n",
+ "print(\"\\n4. Change external input:\")\n",
+ "print(\" - Stronger input โ higher firing rates\")\n",
+ "print(\" - Unbalanced input โ population bias\")"
+ ]
+ },
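+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a concrete example, experiment 4 can be run immediately with the `net`, `times`, and `duration` defined above; only the external currents change:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Experiment 4: unbalanced external input (stronger drive to E, weaker to I)\n",
+    "brainstate.nn.init_all_states(net)\n",
+    "spk_e2, spk_i2 = brainstate.transform.for_loop(\n",
+    "    lambda t: net.update(2.5 * u.nA, 0.5 * u.nA),\n",
+    "    times\n",
+    ")\n",
+    "\n",
+    "dur_s = duration.to_decimal(u.second)\n",
+    "print(f\"E rate: {int(u.math.sum(spk_e2 != 0)) / (net.n_exc * dur_s):.2f} Hz\")\n",
+    "print(f\"I rate: {int(u.math.sum(spk_i2 != 0)) / (net.n_inh * dur_s):.2f} Hz\")"
+   ]
+  },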
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
**Projection Architecture**: Comm-Syn-Out three-stage design\n",
+ "\n",
+ "โ
**Connectivity Patterns**: Fixed probability, all-to-all, one-to-one\n",
+ "\n",
+ "โ
**Output Mechanisms**: CUBA (current-based) vs COBA (conductance-based)\n",
+ "\n",
+ "โ
**E-I Networks**: Building balanced excitatory-inhibitory networks\n",
+ "\n",
+ "โ
**Network Simulation**: Running and analyzing network dynamics\n",
+ "\n",
+ "โ
**Visualization**: Raster plots and population firing rates\n",
+ "\n",
+ "## Key Concepts\n",
+ "\n",
+ "1. **Update Order**: Always get spikes BEFORE updating projections\n",
+ "2. **Modular Design**: Each projection component is independent\n",
+ "3. **E-I Balance**: Inhibition counteracts excitation for stable dynamics\n",
+ "4. **Time Constants**: Excitation fast (2ms), inhibition slow (10ms)\n",
+ "\n",
+ "## Next Steps\n",
+ "\n",
+ "- **Tutorial 4**: Learn about [Input and Output](04-input-output.ipynb)\n",
+ "- **Examples**: See [E-I networks in action](../../examples/gallery.rst)\n",
+ "- **Advanced**: Explore [oscillation mechanisms](../../examples/gallery.rst#oscillations-and-rhythms)\n",
+ "\n",
+ "## Exercises\n",
+ "\n",
+ "Try these on your own:\n",
+ "\n",
+ "1. Create a network with 3 populations (E1, E2, I)\n",
+ "2. Implement distance-dependent connectivity\n",
+ "3. Add COBA synapses with different reversal potentials\n",
+ "4. Implement a feedforward network (no recurrence)\n",
+ "5. Analyze inter-spike intervals (ISI) distribution"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/basic/04-input-output.ipynb b/docs_version3/tutorials/basic/04-input-output.ipynb
new file mode 100644
index 00000000..edc510a8
--- /dev/null
+++ b/docs_version3/tutorials/basic/04-input-output.ipynb
@@ -0,0 +1,758 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Tutorial 4: Input and Output\n",
+ "\n",
+ "In this tutorial, you'll learn:\n",
+ "\n",
+ "- Generating input patterns (Poisson, periodic, custom)\n",
+ "- Input encoding strategies\n",
+ "- Using readout layers\n",
+ "- Population coding and decoding\n",
+ "- Recording and analyzing network outputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import brainpy as bp\n",
+ "import brainstate\n",
+ "import brainunit as u\n",
+ "import braintools\n",
+ "import matplotlib.pyplot as plt\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Understanding Inputs and Outputs\n",
+ "\n",
+ "Neural networks need:\n",
+ "\n",
+ "**Inputs** โ Convert external signals to neural activity\n",
+ "- Current injection\n",
+ "- Spike trains (Poisson, regular)\n",
+ "- Temporal patterns\n",
+ "\n",
+ "**Outputs** โ Extract information from network\n",
+ "- Spike counts\n",
+ "- Population vectors\n",
+ "- Readout layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Constant Current Input\n",
+ "\n",
+ "The simplest input: constant current."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "brainstate.environ.set(dt=0.1 * u.ms)\n",
+ "\n",
+ "# Create neuron\n",
+ "neuron = bp.LIF(10, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "\n",
+ "# Simulate with constant input\n",
+ "duration = 200. * u.ms\n",
+ "times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())\n",
+ "\n",
+ "I_constant = 2.0 * u.nA\n",
+ "spikes = brainstate.transform.for_loop(\n",
+ " lambda t: neuron(I_constant),\n",
+ " times\n",
+ ")\n",
+ "\n",
+ "# Plot\n",
+ "t_idx, n_idx = u.math.where(spikes != 0)\n",
+ "plt.figure(figsize=(10, 4))\n",
+ "plt.scatter(times[t_idx].to_decimal(u.ms), n_idx, s=5, c='black')\n",
+ "plt.xlabel('Time (ms)')\n",
+ "plt.ylabel('Neuron Index')\n",
+ "plt.title('Response to Constant Current Input')\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.show()\n",
+ "\n",
+ "print(f\"Total spikes: {len(t_idx)}\")\n",
+ "print(f\"Average rate: {len(t_idx) / (10 * duration.to_decimal(u.second)):.2f} Hz\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Poisson Spike Trains\n",
+ "\n",
+ "Realistic input: random Poisson spike trains."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def poisson_input(size, rate, dt):\n",
+ " \"\"\"Generate Poisson spike train.\n",
+ " \n",
+ " Args:\n",
+ " size: Number of neurons\n",
+ " rate: Firing rate (Hz)\n",
+ " dt: Time step\n",
+ " \n",
+ " Returns:\n",
+ " Binary spike array\n",
+ " \"\"\"\n",
+ " prob = rate * dt.to_decimal(u.second)\n",
+ " return (brainstate.random.rand(size) < prob).astype(float)\n",
+ "\n",
+ "# Test Poisson input\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "rate = 50 * u.Hz\n",
+ "dt = brainstate.environ.get_dt()\n",
+ "\n",
+ "input_spikes_hist = []\n",
+ "output_spikes_hist = []\n",
+ "\n",
+ "for t in times:\n",
+ " # Generate Poisson input\n",
+ " input_spikes = poisson_input(10, rate, dt)\n",
+ " input_spikes_hist.append(input_spikes)\n",
+ " \n",
+ " # Convert spikes to current (simple model)\n",
+ " I_poisson = input_spikes * 5.0 * u.nA\n",
+ " neuron(I_poisson)\n",
+ " output_spikes_hist.append(neuron.get_spike())\n",
+ "\n",
+ "input_spikes_hist = jnp.array(input_spikes_hist)\n",
+ "output_spikes_hist = u.math.asarray(output_spikes_hist)"
+ ]
+ },
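+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before visualizing, it is worth verifying the generator: each time step is an independent Bernoulli draw with probability rate × dt (here 0.005, small enough for the approximation to hold), so the mean spike count per step divided by dt should recover the nominal rate:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Empirical rate of the generated Poisson input\n",
+    "emp_rate = input_spikes_hist.mean() / dt.to_decimal(u.second)\n",
+    "print(f\"Nominal rate: {rate}\")\n",
+    "print(f\"Empirical rate: {emp_rate:.1f} Hz\")"
+   ]
+  },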
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize input and output\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Input spikes\n",
+ "t_in, n_in = jnp.where(input_spikes_hist > 0)\n",
+ "axes[0].scatter(times[t_in].to_decimal(u.ms), n_in, s=2, c='blue', alpha=0.5)\n",
+ "axes[0].set_ylabel('Neuron Index')\n",
+ "axes[0].set_title(f'Input: Poisson Spike Train ({rate.to_decimal(u.Hz):.0f} Hz)', fontweight='bold')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Output spikes\n",
+ "t_out, n_out = u.math.where(output_spikes_hist != 0)\n",
+ "axes[1].scatter(times[t_out].to_decimal(u.ms), n_out, s=2, c='red', alpha=0.5)\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_ylabel('Neuron Index')\n",
+ "axes[1].set_title('Output: Neuron Response', fontweight='bold')\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(f\"Input spikes: {len(t_in)}\")\n",
+ "print(f\"Output spikes: {len(t_out)}\")\n",
+ "print(f\"Gain: {len(t_out) / len(t_in):.2f}x\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 4: Periodic Input Patterns\n",
+ "\n",
+ "Regular, rhythmic inputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def periodic_input(t, frequency, amplitude, phase=0):\n",
+ " \"\"\"Generate sinusoidal input current.\n",
+ " \n",
+ " Args:\n",
+ " t: Time\n",
+ " frequency: Oscillation frequency\n",
+ " amplitude: Current amplitude\n",
+ " phase: Phase offset\n",
+ " \"\"\"\n",
+ " omega = 2 * jnp.pi * frequency.to_decimal(u.Hz)\n",
+ " t_sec = t.to_decimal(u.second)\n",
+ " return amplitude * (0.5 + 0.5 * jnp.sin(omega * t_sec + phase))\n",
+ "\n",
+ "# Test periodic input\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "freq = 10 * u.Hz\n",
+ "amp = 3.0 * u.nA\n",
+ "\n",
+ "currents_hist = []\n",
+ "spikes_hist = []\n",
+ "\n",
+ "for t in times:\n",
+ " I_periodic = periodic_input(t, freq, amp)\n",
+ " currents_hist.append(I_periodic)\n",
+ " neuron(I_periodic)\n",
+ " spikes_hist.append(neuron.get_spike())\n",
+ "\n",
+ "currents_hist = u.math.asarray(currents_hist)\n",
+ "spikes_hist = u.math.asarray(spikes_hist)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot periodic input and response\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Input current\n",
+ "axes[0].plot(times.to_decimal(u.ms), currents_hist.to_decimal(u.nA), \n",
+ " linewidth=2, color='blue')\n",
+ "axes[0].set_ylabel('Current (nA)')\n",
+ "axes[0].set_title(f'Input: Periodic Current ({freq})', fontweight='bold')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Output spikes\n",
+ "t_idx, n_idx = u.math.where(spikes_hist != 0)\n",
+ "axes[1].scatter(times[t_idx].to_decimal(u.ms), n_idx, s=5, c='red', alpha=0.7)\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_ylabel('Neuron Index')\n",
+ "axes[1].set_title('Output: Phase-Locked Spiking', fontweight='bold')\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Observation: Neurons fire preferentially during high-current phases\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 5: Rate Coding\n",
+ "\n",
+ "Encode information in firing rates."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def rate_encode(values, max_rate, dt):\n",
+ " \"\"\"Encode values as Poisson spike trains.\n",
+ " \n",
+ " Args:\n",
+ " values: Array of values to encode (0 to 1)\n",
+ " max_rate: Maximum firing rate\n",
+ " dt: Time step\n",
+ " \n",
+ " Returns:\n",
+ " Binary spike array\n",
+ " \"\"\"\n",
+ " rates = values * max_rate.to_decimal(u.Hz)\n",
+ " probs = rates * dt.to_decimal(u.second)\n",
+ " return (brainstate.random.rand(len(values)) < probs).astype(float)\n",
+ "\n",
+ "# Example: encode a sine wave\n",
+ "n_neurons = 10\n",
+ "max_rate = 100 * u.Hz\n",
+ "duration = 500. * u.ms\n",
+ "times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())\n",
+ "\n",
+ "encoded_spikes = []\n",
+ "signal_values = []\n",
+ "\n",
+ "for i, t in enumerate(times):\n",
+ " # Signal to encode (sine wave)\n",
+ " signal = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * 5 * t.to_decimal(u.second))\n",
+ " signal_values.append(signal)\n",
+ " \n",
+ " # Encode as spikes for each neuron\n",
+ " values = jnp.ones(n_neurons) * signal # Same value for all neurons\n",
+ " spikes = rate_encode(values, max_rate, brainstate.environ.get_dt())\n",
+ " encoded_spikes.append(spikes)\n",
+ "\n",
+ "encoded_spikes = jnp.array(encoded_spikes)\n",
+ "signal_values = jnp.array(signal_values)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize rate coding\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Original signal\n",
+ "axes[0].plot(times.to_decimal(u.ms), signal_values, linewidth=2, color='blue')\n",
+ "axes[0].set_ylabel('Signal Value')\n",
+ "axes[0].set_title('Original Signal (to be encoded)', fontweight='bold')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Encoded spikes\n",
+ "t_idx, n_idx = jnp.where(encoded_spikes > 0)\n",
+ "axes[1].scatter(times[t_idx].to_decimal(u.ms), n_idx, s=1, c='red', alpha=0.5)\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_ylabel('Neuron Index')\n",
+ "axes[1].set_title('Rate-Coded Spike Train', fontweight='bold')\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Higher signal โ higher spike density\")"
+ ]
+ },
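+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The code can also be inverted: averaging spikes across neurons, smoothing over a short window, and dividing by max_rate × dt gives an estimate of the original signal. A minimal sketch (the 10 ms window is an arbitrary choice):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Decode the rate code with a 10 ms moving average\n",
+    "dt_s = brainstate.environ.get_dt().to_decimal(u.second)\n",
+    "window = int(10. / brainstate.environ.get_dt().to_decimal(u.ms))  # steps in 10 ms\n",
+    "kernel = jnp.ones(window) / window\n",
+    "pop_activity = encoded_spikes.mean(axis=1)  # average across neurons\n",
+    "signal_est = jnp.convolve(pop_activity, kernel, mode='same') / (max_rate.to_decimal(u.Hz) * dt_s)\n",
+    "\n",
+    "plt.figure(figsize=(10, 4))\n",
+    "plt.plot(times.to_decimal(u.ms), signal_values, linewidth=2, label='True signal')\n",
+    "plt.plot(times.to_decimal(u.ms), signal_est, alpha=0.8, label='Decoded estimate')\n",
+    "plt.xlabel('Time (ms)')\n",
+    "plt.ylabel('Value')\n",
+    "plt.legend()\n",
+    "plt.grid(True, alpha=0.3)\n",
+    "plt.show()"
+   ]
+  },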
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 6: Population Coding\n",
+ "\n",
+ "Multiple neurons encode a single value."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def population_encode(value, n_neurons, pref_values, sigma, max_rate, dt):\n",
+ " \"\"\"Encode value using population code with tuning curves.\n",
+ " \n",
+ " Args:\n",
+ " value: Value to encode (0 to 1)\n",
+ " n_neurons: Number of neurons\n",
+ " pref_values: Preferred values for each neuron\n",
+ " sigma: Tuning width\n",
+ " max_rate: Maximum firing rate\n",
+ " dt: Time step\n",
+ " \"\"\"\n",
+ " # Tuning curves: Gaussian around preferred value\n",
+ " responses = jnp.exp(-0.5 * ((value - pref_values) / sigma)**2)\n",
+ " rates = responses * max_rate.to_decimal(u.Hz)\n",
+ " probs = rates * dt.to_decimal(u.second)\n",
+ " return (brainstate.random.rand(n_neurons) < probs).astype(float)\n",
+ "\n",
+ "# Setup population\n",
+ "n_pop = 20\n",
+ "pref_values = jnp.linspace(0, 1, n_pop) # Evenly spaced preferences\n",
+ "sigma = 0.2\n",
+ "max_rate = 100 * u.Hz\n",
+ "\n",
+ "# Encode a slowly changing value\n",
+ "duration = 500. * u.ms\n",
+ "times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())\n",
+ "\n",
+ "pop_spikes = []\n",
+ "true_values = []\n",
+ "\n",
+ "for i, t in enumerate(times):\n",
+ " # Value changes over time\n",
+ " value = 0.5 + 0.3 * jnp.sin(2 * jnp.pi * 2 * t.to_decimal(u.second))\n",
+ " true_values.append(value)\n",
+ " \n",
+ " # Population encoding\n",
+ " spikes = population_encode(value, n_pop, pref_values, sigma, max_rate, \n",
+ " brainstate.environ.get_dt())\n",
+ " pop_spikes.append(spikes)\n",
+ "\n",
+ "pop_spikes = jnp.array(pop_spikes)\n",
+ "true_values = jnp.array(true_values)"
+ ]
+ },
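+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "It helps to look at the tuning curves themselves: each neuron responds maximally at its preferred value and falls off as a Gaussian of width `sigma`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot a subset of the Gaussian tuning curves used by population_encode\n",
+    "xs = jnp.linspace(0., 1., 200)\n",
+    "plt.figure(figsize=(10, 4))\n",
+    "for p in pref_values[::4]:  # every 4th neuron for clarity\n",
+    "    plt.plot(xs, jnp.exp(-0.5 * ((xs - p) / sigma)**2), alpha=0.8)\n",
+    "plt.xlabel('Encoded Value')\n",
+    "plt.ylabel('Normalized Response')\n",
+    "plt.title('Tuning Curves (every 4th neuron)')\n",
+    "plt.grid(True, alpha=0.3)\n",
+    "plt.show()"
+   ]
+  },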
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize population coding\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# True value\n",
+ "axes[0].plot(times.to_decimal(u.ms), true_values, linewidth=2, color='blue')\n",
+ "axes[0].set_ylabel('Encoded Value')\n",
+ "axes[0].set_title('True Value (to be encoded)', fontweight='bold')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Population spikes\n",
+ "t_idx, n_idx = jnp.where(pop_spikes > 0)\n",
+ "axes[1].scatter(times[t_idx].to_decimal(u.ms), n_idx, s=2, c='red', alpha=0.5)\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_ylabel('Neuron Index (Preference)')\n",
+ "axes[1].set_title('Population Code: Activity Follows Value', fontweight='bold')\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Peak activity shifts with encoded value\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 7: Population Decoding\n",
+ "\n",
+ "Extract the encoded value from population activity."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def population_decode(spike_counts, pref_values):\n",
+ " \"\"\"Decode value from population activity.\n",
+ " \n",
+ " Args:\n",
+ " spike_counts: Number of spikes per neuron\n",
+ " pref_values: Preferred values of neurons\n",
+ " \n",
+ " Returns:\n",
+ " Decoded value (population vector)\n",
+ " \"\"\"\n",
+ " # Population vector: weighted average\n",
+ " total_activity = jnp.sum(spike_counts)\n",
+ " if total_activity > 0:\n",
+ " decoded = jnp.sum(spike_counts * pref_values) / total_activity\n",
+ " return decoded\n",
+ " else:\n",
+ " return 0.5 # Default\n",
+ "\n",
+ "# Decode the population activity\n",
+ "window_size = 50 # ms\n",
+ "window_steps = int(window_size / brainstate.environ.get_dt().to_decimal(u.ms))\n",
+ "\n",
+ "decoded_values = []\n",
+ "decode_times = []\n",
+ "\n",
+ "for i in range(0, len(times) - window_steps, window_steps // 2):\n",
+ " # Count spikes in window\n",
+ " window_spikes = pop_spikes[i:i+window_steps]\n",
+ " spike_counts = jnp.sum(window_spikes, axis=0)\n",
+ " \n",
+ " # Decode\n",
+ " decoded = population_decode(spike_counts, pref_values)\n",
+ " decoded_values.append(decoded)\n",
+ " decode_times.append(times[i + window_steps//2])\n",
+ "\n",
+ "decoded_values = jnp.array(decoded_values)\n",
+ "decode_times = u.math.asarray(decode_times)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compare true and decoded values\n",
+ "plt.figure(figsize=(12, 5))\n",
+ "plt.plot(times.to_decimal(u.ms), true_values, linewidth=2, \n",
+ " label='True Value', color='blue', alpha=0.7)\n",
+ "plt.plot(decode_times.to_decimal(u.ms), decoded_values, linewidth=2, \n",
+ " label='Decoded Value', color='red', linestyle='--', alpha=0.7)\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Value', fontsize=12)\n",
+ "plt.title('Population Decoding: True vs Decoded Values', fontsize=14, fontweight='bold')\n",
+ "plt.legend(fontsize=11)\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "# Calculate decoding error\n",
+ "# Interpolate true values at decode times\n",
+ "true_at_decode = jnp.interp(\n",
+ " decode_times.to_decimal(u.ms),\n",
+ " times.to_decimal(u.ms),\n",
+ " true_values\n",
+ ")\n",
+ "error = jnp.abs(decoded_values - true_at_decode)\n",
+ "print(f\"Mean decoding error: {jnp.mean(error):.4f}\")\n",
+ "print(f\"Max decoding error: {jnp.max(error):.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 8: Readout Layers\n",
+ "\n",
+ "Use a readout layer to extract output."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create network with readout\n",
+ "class NetworkWithReadout(brainstate.nn.Module):\n",
+ " def __init__(self, n_input, n_hidden, n_output):\n",
+ " super().__init__()\n",
+ " \n",
+ " # Hidden layer (recurrent LIF neurons)\n",
+ " self.hidden = bp.LIF(n_hidden, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ " \n",
+ " # Readout layer\n",
+ " self.readout = bp.Readout(\n",
+ " n_hidden, n_output,\n",
+ " weight_initializer=braintools.init.KaimingNormal()\n",
+ " )\n",
+ " \n",
+ " def update(self, spike_input):\n",
+ " # Convert input spikes to current\n",
+ " I_input = spike_input * 5.0 * u.nA\n",
+ " \n",
+ " # Update hidden neurons\n",
+ " self.hidden(I_input)\n",
+ " spikes = self.hidden.get_spike()\n",
+ " \n",
+ " # Readout\n",
+ " output = self.readout(spikes)\n",
+ " \n",
+ " return output, spikes\n",
+ "\n",
+ "# Create network\n",
+ "net = NetworkWithReadout(n_input=10, n_hidden=50, n_output=2)\n",
+ "brainstate.nn.init_all_states(net)\n",
+ "\n",
+ "print(\"Network with readout layer created\")\n",
+ "print(f\"Hidden neurons: {50}\")\n",
+ "print(f\"Output dimensions: {2}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test readout\n",
+ "duration = 200. * u.ms\n",
+ "times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())\n",
+ "\n",
+ "outputs_hist = []\n",
+ "spikes_hist = []\n",
+ "\n",
+ "for t in times:\n",
+ " # Generate Poisson input\n",
+ " input_spikes = poisson_input(10, 50*u.Hz, brainstate.environ.get_dt())\n",
+ " \n",
+ " # Network update\n",
+ " output, spikes = net.update(input_spikes)\n",
+ " outputs_hist.append(output)\n",
+ " spikes_hist.append(spikes)\n",
+ "\n",
+ "outputs_hist = jnp.array(outputs_hist)\n",
+ "spikes_hist = u.math.asarray(spikes_hist)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize readout\n",
+ "fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
+ "\n",
+ "# Hidden layer activity\n",
+ "t_idx, n_idx = u.math.where(spikes_hist != 0)\n",
+ "axes[0].scatter(times[t_idx].to_decimal(u.ms), n_idx, s=1, c='blue', alpha=0.5)\n",
+ "axes[0].set_ylabel('Neuron Index')\n",
+ "axes[0].set_title('Hidden Layer Spikes', fontweight='bold')\n",
+ "axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ "# Readout outputs\n",
+ "axes[1].plot(times.to_decimal(u.ms), outputs_hist[:, 0], \n",
+ " linewidth=2, label='Output 1', alpha=0.7)\n",
+ "axes[1].plot(times.to_decimal(u.ms), outputs_hist[:, 1], \n",
+ " linewidth=2, label='Output 2', alpha=0.7)\n",
+ "axes[1].set_xlabel('Time (ms)')\n",
+ "axes[1].set_ylabel('Readout Value')\n",
+ "axes[1].set_title('Readout Layer Outputs', fontweight='bold')\n",
+ "axes[1].legend()\n",
+ "axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n",
+ "\n",
+ "print(\"Readout layer converts spikes to continuous values\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 9: Recording Network States\n",
+ "\n",
+ "Record variables during simulation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Manual recording example\n",
+ "neuron = bp.LIF(5, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)\n",
+ "brainstate.nn.init_all_states(neuron)\n",
+ "\n",
+ "duration = 100. * u.ms\n",
+ "times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())\n",
+ "\n",
+ "# Preallocate recording arrays\n",
+ "n_steps = len(times)\n",
+ "V_hist = []\n",
+ "spike_hist = []\n",
+ "\n",
+ "for t in times:\n",
+ " neuron(2.0 * u.nA)\n",
+ " \n",
+ " # Record states\n",
+ " V_hist.append(neuron.V.value.copy())\n",
+ " spike_hist.append(neuron.get_spike().copy())\n",
+ "\n",
+ "V_hist = u.math.asarray(V_hist) # Shape: (time, neurons)\n",
+ "spike_hist = u.math.asarray(spike_hist)\n",
+ "\n",
+ "print(f\"Recorded {n_steps} time steps\")\n",
+ "print(f\"Voltage history shape: {V_hist.shape}\")\n",
+ "print(f\"Spike history shape: {spike_hist.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot recorded states\n",
+ "plt.figure(figsize=(12, 6))\n",
+ "\n",
+ "# Plot voltage traces\n",
+ "for i in range(5):\n",
+ " V_trace = V_hist[:, i]\n",
+ " # Mark spikes\n",
+ " spike_times = times[spike_hist[:, i] > 0]\n",
+ " V_with_spikes = V_trace.copy()\n",
+ " V_with_spikes = V_with_spikes.to_decimal(u.mV)\n",
+ " \n",
+ " plt.plot(times.to_decimal(u.ms), V_with_spikes, \n",
+ " linewidth=1.5, alpha=0.7, label=f'Neuron {i}')\n",
+ "\n",
+ "plt.axhline(y=-50, color='r', linestyle='--', alpha=0.5, label='Threshold')\n",
+ "plt.xlabel('Time (ms)', fontsize=12)\n",
+ "plt.ylabel('Membrane Potential (mV)', fontsize=12)\n",
+ "plt.title('Recorded Voltage Traces', fontsize=14, fontweight='bold')\n",
+ "plt.legend(loc='upper right', fontsize=9)\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "In this tutorial, you learned:\n",
+ "\n",
+ "โ
**Input Generation**: Constant, Poisson, periodic patterns\n",
+ "\n",
+ "โ
**Rate Coding**: Encoding values as firing rates\n",
+ "\n",
+ "โ
**Population Coding**: Multiple neurons encode single values\n",
+ "\n",
+ "โ
**Population Decoding**: Extract values from spike trains\n",
+ "\n",
+ "โ
**Readout Layers**: Convert spikes to continuous outputs\n",
+ "\n",
+ "โ
**Recording States**: Track network variables over time\n",
+ "\n",
+ "## Key Concepts\n",
+ "\n",
+ "1. **Input Encoding**: Convert signals โ spike patterns\n",
+ "2. **Population Codes**: Distributed representation across neurons\n",
+ "3. **Decoding**: Extract information from population activity\n",
+ "4. **Readout**: Linear combination of spike counts\n",
+ "\n",
+ "## Next Steps\n",
+ "\n",
+ "- **Tutorial 5**: Learn [SNN training](../advanced/05-snn-training.ipynb)\n",
+ "- **Examples**: See [trained networks](../../examples/gallery.rst#snn-training)\n",
+ "- **Advanced**: Explore [reservoir computing](../../examples/gallery.rst)\n",
+ "\n",
+ "## Exercises\n",
+ "\n",
+ "1. Implement temporal coding (first-spike latency)\n",
+ "2. Create a 2D population code (e.g., for position)\n",
+ "3. Build a classifier using readout layer\n",
+ "4. Compare different decoding methods (vector, maximum likelihood)\n",
+ "5. Implement sparse coding with inhibition"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs_version3/tutorials/index.rst b/docs_version3/tutorials/index.rst
new file mode 100644
index 00000000..3c7e9013
--- /dev/null
+++ b/docs_version3/tutorials/index.rst
@@ -0,0 +1,362 @@
+Tutorials
+=========
+
+Welcome to the BrainPy 3.0 tutorials! These step-by-step guides will help you master computational neuroscience modeling with BrainPy.
+
+Learning Path
+-------------
+
+We recommend following the tutorials in order:
+
+1. **Basic Tutorials**: Learn core components (neurons, synapses, networks)
+2. **Advanced Tutorials**: Master complex topics (training, plasticity, large-scale simulations)
+3. **Specialized Topics**: Explore specific applications and techniques
+
+Basic Tutorials
+---------------
+
+Start here to learn the fundamentals of BrainPy 3.0.
+
+Tutorial 1: LIF Neuron Basics
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Learn the most important spiking neuron model.
+
+**Topics covered:**
+
+- Creating and configuring LIF neurons
+- Simulating neuron dynamics
+- Computing F-I curves
+- Hard vs soft reset modes
+- Working with neuron populations
+- Parameter effects on behavior
+
+:doc:`Go to Tutorial 1 <basic/01-lif-neuron>`
+
+**Prerequisites:** None (start here!)
+
+**Duration:** ~30 minutes
+
+----
+
+Tutorial 2: Synapse Models
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Understand temporal filtering and synaptic dynamics.
+
+**Topics covered:**
+
+- Exponential synapses
+- Alpha synapses
+- AMPA and GABA receptors
+- Comparing synapse models
+- Custom synapse creation
+
+:doc:`Go to Tutorial 2 <basic/02-synapse-models>`
+
+**Prerequisites:** Tutorial 1
+
+**Duration:** ~25 minutes
+
+----
+
+Tutorial 3: Network Connections
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Build connected neural networks.
+
+**Topics covered:**
+
+- Projection architecture (Comm-Syn-Out)
+- Fixed probability connectivity
+- CUBA vs COBA synapses
+- E-I balanced networks
+- Network visualization
+
+:doc:`Go to Tutorial 3 <basic/03-network-connections>`
+
+**Prerequisites:** Tutorials 1-2
+
+**Duration:** ~35 minutes
+
+----
+
+Tutorial 4: Input and Output
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Generate inputs and process network outputs.
+
+**Topics covered:**
+
+- Poisson spike trains
+- Periodic inputs
+- Custom input patterns
+- Readout layers
+- Population coding
+
+:doc:`Go to Tutorial 4 <basic/04-input-output>`
+
+**Prerequisites:** Tutorials 1-3
+
+**Duration:** ~20 minutes
+
+Advanced Tutorials
+------------------
+
+Dive deeper into sophisticated modeling techniques.
+
+Tutorial 5: Training Spiking Neural Networks
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Learn gradient-based training for SNNs.
+
+**Topics covered:**
+
+- Surrogate gradient methods
+- BPTT for SNNs
+- Loss functions for spikes
+- Optimizers and learning rates
+- Classification tasks
+- Training loops
+
+:doc:`Go to Tutorial 5 <advanced/05-snn-training>`
+
+**Prerequisites:** Basic Tutorials 1-4
+
+**Duration:** ~45 minutes
+
+----
+
+Tutorial 6: Synaptic Plasticity
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Implement learning rules and adaptation.
+
+**Topics covered:**
+
+- Short-term plasticity (STP)
+- Depression and facilitation
+- STDP principles
+- Homeostatic mechanisms
+- Network learning
+
+:doc:`Go to Tutorial 6 <advanced/06-synaptic-plasticity>`
+
+**Prerequisites:** Basic Tutorials, Tutorial 5
+
+**Duration:** ~40 minutes
+
+----
+
+Tutorial 7: Large-Scale Simulations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Scale up your models efficiently.
+
+**Topics covered:**
+
+- Memory optimization
+- JIT compilation best practices
+- Batching strategies
+- GPU/TPU acceleration
+- Performance profiling
+- Sparse connectivity
+
+:doc:`Go to Tutorial 7 <advanced/07-large-scale-simulations>`
+
+**Prerequisites:** All Basic Tutorials
+
+**Duration:** ~35 minutes
+
+Specialized Topics
+------------------
+
+Application-specific tutorials for advanced users.
+
+Brain Oscillations (Coming Soon)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Model and analyze network rhythms.
+
+**Topics:** Gamma oscillations, synchrony, oscillation mechanisms
+
+**Prerequisites:** Advanced
+
+----
+
+Decision-Making Networks (Coming Soon)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Build cognitive computation models.
+
+**Topics:** Attractor dynamics, competition, working memory
+
+**Prerequisites:** Advanced
+
+----
+
+Reservoir Computing (Coming Soon)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Use untrained recurrent networks for computation.
+
+**Topics:** Echo state networks, liquid state machines, readout training
+
+**Prerequisites:** Advanced
+
+Tutorial Format
+---------------
+
+Each tutorial includes:
+
+✅ **Clear learning objectives**: Know what you'll learn
+
+✅ **Runnable code**: All examples work out of the box
+
+✅ **Visualizations**: See your models in action
+
+✅ **Explanations**: Understand the "why" behind the code
+
+✅ **Exercises**: Practice what you've learned
+
+✅ **References**: Links to papers and further reading
+
+How to Use These Tutorials
+---------------------------
+
+Interactive (Recommended)
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Download and run the Jupyter notebooks:
+
+.. code-block:: bash
+
+ git clone https://github.com/brainpy/BrainPy.git
+ cd BrainPy/docs_version3/tutorials/basic
+ jupyter notebook 01-lif-neuron.ipynb
+
+Read-Only
+~~~~~~~~~
+
+Browse the tutorials online in the documentation.
+
+Binder
+~~~~~~
+
+Run tutorials in your browser without installation:
+
+.. image:: https://mybinder.org/badge_logo.svg
+ :target: https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main
+ :alt: Binder
+
+Prerequisites
+-------------
+
+Before starting the tutorials, ensure you have:
+
+✅ Python 3.10 or later
+
+✅ BrainPy 3.0 installed (see :doc:`../quickstart/installation`)
+
+✅ Basic Python knowledge (functions, classes, NumPy)
+
+✅ Basic neuroscience concepts (optional but helpful)
+
+Recommended setup:
+
+.. code-block:: bash
+
+ pip install brainpy[cpu] matplotlib jupyter -U
+
+Additional Resources
+--------------------
+
+**For Quick Start**
+ See the :doc:`../quickstart/5min-tutorial` for a rapid introduction
+
+**For Concepts**
+ Read :doc:`../core-concepts/architecture` for architectural understanding
+
+**For Examples**
+ Browse :doc:`../examples/gallery` for complete, real-world models
+
+**For Reference**
+ Consult the :doc:`../apis` for detailed API documentation
+
+Getting Help
+------------
+
+If you get stuck:
+
+- Check the :doc:`FAQ <../migration/migration-guide>` (Migration Guide has troubleshooting)
+- Search `GitHub Issues <https://github.com/brainpy/BrainPy/issues>`_
+- Ask on GitHub Discussions
+- Review the brainstate documentation
+
+Tutorial Roadmap
+----------------
+
+**Currently Available:**
+
+**Basic Tutorials:**
+
+- ✅ Tutorial 1: LIF Neuron Basics
+- ✅ Tutorial 2: Synapse Models
+- ✅ Tutorial 3: Network Connections
+- ✅ Tutorial 4: Input and Output
+
+**Advanced Tutorials:**
+
+- ✅ Tutorial 5: Training SNNs
+- ✅ Tutorial 6: Synaptic Plasticity
+- ✅ Tutorial 7: Large-Scale Simulations
+
+**Future Plans:**
+
+- Brain Oscillations
+- Decision-Making Networks
+- Reservoir Computing
+- Custom Components
+- Advanced Training Techniques
+
+We're actively developing new tutorials. Star the repository to stay updated!
+
+Contributing
+------------
+
+Want to contribute a tutorial? We'd love your help!
+
+1. Check the contribution guidelines in the `BrainPy repository <https://github.com/brainpy/BrainPy>`_
+2. Open an issue to discuss your tutorial idea
+3. Submit a pull request
+
+Good tutorial topics:
+
+- Specific neuron models (Izhikevich, AdEx, etc.)
+- Network architectures (attractor networks, etc.)
+- Analysis techniques (spike train analysis, etc.)
+- Applications (sensory processing, motor control, etc.)
+
+Let's Start!
+------------
+
+Ready to begin? Start with :doc:`Tutorial 1: LIF Neuron Basics <basic/01-lif-neuron>`!
+
+
+.. toctree::
+ :hidden:
+ :maxdepth: 1
+ :caption: Basic Tutorials
+
+ basic/01-lif-neuron.ipynb
+ basic/02-synapse-models.ipynb
+ basic/03-network-connections.ipynb
+ basic/04-input-output.ipynb
+
+
+.. toctree::
+ :hidden:
+ :maxdepth: 1
+ :caption: Advanced Tutorials
+
+ advanced/05-snn-training.ipynb
+ advanced/06-synaptic-plasticity.ipynb
+ advanced/07-large-scale-simulations.ipynb
diff --git a/examples/102_EI_net_1996.py b/examples_version3/102_EI_net_1996.py
similarity index 100%
rename from examples/102_EI_net_1996.py
rename to examples_version3/102_EI_net_1996.py
diff --git a/examples/103_COBA_2005.py b/examples_version3/103_COBA_2005.py
similarity index 100%
rename from examples/103_COBA_2005.py
rename to examples_version3/103_COBA_2005.py
diff --git a/examples/104_CUBA_2005.py b/examples_version3/104_CUBA_2005.py
similarity index 100%
rename from examples/104_CUBA_2005.py
rename to examples_version3/104_CUBA_2005.py
diff --git a/examples/104_CUBA_2005_version2.py b/examples_version3/104_CUBA_2005_version2.py
similarity index 100%
rename from examples/104_CUBA_2005_version2.py
rename to examples_version3/104_CUBA_2005_version2.py
diff --git a/examples/106_COBA_HH_2007.py b/examples_version3/106_COBA_HH_2007.py
similarity index 100%
rename from examples/106_COBA_HH_2007.py
rename to examples_version3/106_COBA_HH_2007.py
diff --git a/examples/107_gamma_oscillation_1996.py b/examples_version3/107_gamma_oscillation_1996.py
similarity index 100%
rename from examples/107_gamma_oscillation_1996.py
rename to examples_version3/107_gamma_oscillation_1996.py
diff --git a/examples/108_synfire_chains_199.py b/examples_version3/108_synfire_chains_199.py
similarity index 100%
rename from examples/108_synfire_chains_199.py
rename to examples_version3/108_synfire_chains_199.py
diff --git a/examples/109_fast_global_oscillation.py b/examples_version3/109_fast_global_oscillation.py
similarity index 100%
rename from examples/109_fast_global_oscillation.py
rename to examples_version3/109_fast_global_oscillation.py
diff --git a/examples/110_Susin_Destexhe_2021_gamma_oscillation_AI.py b/examples_version3/110_Susin_Destexhe_2021_gamma_oscillation_AI.py
similarity index 100%
rename from examples/110_Susin_Destexhe_2021_gamma_oscillation_AI.py
rename to examples_version3/110_Susin_Destexhe_2021_gamma_oscillation_AI.py
diff --git a/examples/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py b/examples_version3/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py
similarity index 100%
rename from examples/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py
rename to examples_version3/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py
diff --git a/examples/112_Susin_Destexhe_2021_gamma_oscillation_ING.py b/examples_version3/112_Susin_Destexhe_2021_gamma_oscillation_ING.py
similarity index 100%
rename from examples/112_Susin_Destexhe_2021_gamma_oscillation_ING.py
rename to examples_version3/112_Susin_Destexhe_2021_gamma_oscillation_ING.py
diff --git a/examples/113_Susin_Destexhe_2021_gamma_oscillation_PING.py b/examples_version3/113_Susin_Destexhe_2021_gamma_oscillation_PING.py
similarity index 100%
rename from examples/113_Susin_Destexhe_2021_gamma_oscillation_PING.py
rename to examples_version3/113_Susin_Destexhe_2021_gamma_oscillation_PING.py
diff --git a/examples/200_surrogate_grad_lif.py b/examples_version3/200_surrogate_grad_lif.py
similarity index 100%
rename from examples/200_surrogate_grad_lif.py
rename to examples_version3/200_surrogate_grad_lif.py
diff --git a/examples/201_surrogate_grad_lif_fashion_mnist.py b/examples_version3/201_surrogate_grad_lif_fashion_mnist.py
similarity index 100%
rename from examples/201_surrogate_grad_lif_fashion_mnist.py
rename to examples_version3/201_surrogate_grad_lif_fashion_mnist.py
diff --git a/examples/202_mnist_lif_readout.py b/examples_version3/202_mnist_lif_readout.py
similarity index 100%
rename from examples/202_mnist_lif_readout.py
rename to examples_version3/202_mnist_lif_readout.py
diff --git a/examples/Susin_Destexhe_2021_gamma_oscillation.py b/examples_version3/Susin_Destexhe_2021_gamma_oscillation.py
similarity index 100%
rename from examples/Susin_Destexhe_2021_gamma_oscillation.py
rename to examples_version3/Susin_Destexhe_2021_gamma_oscillation.py