# Brainstate Comprehensive Tutorial

This notebook curates every script under `examples/` into a single, end-to-end guide. Each section summarises the original program, explains the key Brainstate features it showcases, and provides runnable notebook snippets or references back to the full script for deeper study.

## How to Use This Notebook

- Install the optional dependencies that individual sections require (listed in each section's note block).
- Toggle the runtime flags in the next cell to control whether expensive simulations or training loops execute.
- When a section only references the original script, open the file path locally for the full code listing.

In [None]:
from pathlib import Path

RUN_TRAINING = False       # Enable to execute longer optimisation loops (001, 002, 003, 300)
RUN_NEURO_SIM = False      # Enable biophysical simulations (100-series, braincell, brainpy examples)
RUN_BRAINSCALE = False     # Reserved for the brainscale demos (203, 301); no cell reads it, so run those scripts directly
DATA_DIR = Path('notebook_data')
DATA_DIR.mkdir(exist_ok=True)
print('Flags set. Toggle to run heavier workloads as needed.')


## 0xx — Deep Neural Network Foundations

The 0xx scripts demonstrate core neural-network workflows on top of Brainstate's stateful module system. They read like Flax or Haiku examples but lean on `brainstate.graph` and `brainstate.transform` utilities.

### Example 001 — Functional API Regression (`examples/001_functional_api.py`)

**Highlights**
- Splitting a stateful graph into parameter and auxiliary collections with `treefy_split`.
- Re-merging the graph inside JAX-transformed training/evaluation steps.
- Managing non-trainable counters via custom `brainstate.State` subclasses.
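
The split/merge round trip can be pictured without any Brainstate machinery. The plain-Python sketch below is a conceptual illustration, not Brainstate's implementation: the `Param` and `Count` classes and the `split_by_type` and `merge_collections` helpers are hypothetical. Each leaf of a nested module tree is routed to the collection of the first type it matches, keyed by its path, and merging rebuilds the nested structure from those flat collections.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Param:
    """Hypothetical stand-in for a trainable state such as brainstate.ParamState."""
    value: Any


@dataclass
class Count:
    """Hypothetical stand-in for a non-trainable counter state such as CallCount."""
    value: Any


def split_by_type(tree, *types, prefix=()):
    """Route each leaf to the collection of the first type it matches, keyed by its path.

    Leaves that match none of the types are dropped.
    """
    collections = [{} for _ in types]
    for key, node in tree.items():
        path = prefix + (key,)
        if isinstance(node, dict):
            # Recurse into a submodule and fold its collections into ours.
            for acc, sub in zip(collections, split_by_type(node, *types, prefix=path)):
                acc.update(sub)
            continue
        for acc, kind in zip(collections, types):
            if isinstance(node, kind):
                acc[path] = node
                break
    return collections


def merge_collections(*collections):
    """Rebuild the nested tree from flat {path: leaf} collections."""
    tree = {}
    for coll in collections:
        for path, leaf in coll.items():
            node = tree
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = leaf
    return tree


model_tree = {
    "count": Count(0),
    "linear1": {"w": Param([[0.5]]), "b": Param([0.0])},
    "linear2": {"w": Param([[1.5]]), "b": Param([0.1])},
}
sketch_params, sketch_counts = split_by_type(model_tree, Param, Count)
print(sorted(sketch_params))   # [('linear1', 'b'), ('linear1', 'w'), ('linear2', 'b'), ('linear2', 'w')]
print(sketch_counts)           # {('count',): Count(value=0)}
restored_tree = merge_collections(sketch_params, sketch_counts)
```

Only the parameter collection would be handed to a gradient function; the counter collection travels alongside it, which mirrors how `train_step` in the next cell returns `count_states` from `jax.grad` as auxiliary output.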

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import brainstate

rng = np.random.default_rng(0)
X = np.linspace(0.0, 1.0, 100)[:, None]
Y = 0.8 * X ** 2 + 0.1 + rng.normal(0.0, 0.1, size=X.shape)

def dataset(batch_size: int):
    while True:
        idx = rng.choice(len(X), size=batch_size, replace=True)
        yield X[idx], Y[idx]

class Linear(brainstate.nn.Module):
    def __init__(self, din: int, dout: int):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))

    def __call__(self, x):
        return x @ self.w.value + self.b.value

class CallCount(brainstate.State):
    pass

class RegressionMLP(brainstate.graph.Node):
    def __init__(self, din: int, dhidden: int, dout: int):
        self.count = CallCount(jnp.array(0))
        self.linear1 = Linear(din, dhidden)
        self.linear2 = Linear(dhidden, dout)

    def __call__(self, x):
        self.count.value += 1
        x = self.linear1(x)
        x = jax.nn.relu(x)
        return self.linear2(x)

graphdef, params, counts = brainstate.graph.treefy_split(
    RegressionMLP(1, 32, 1), brainstate.ParamState, CallCount
)

@jax.jit
def train_step(param_states, count_states, batch):
    x, y = batch

    def loss_fn(pstates):
        model = brainstate.graph.treefy_merge(graphdef, pstates, count_states)
        y_pred = model(x)
        new_counts = brainstate.graph.treefy_states(model, CallCount)
        loss = jnp.mean((y - y_pred) ** 2)
        return loss, new_counts

    grads, count_states = jax.grad(loss_fn, has_aux=True)(param_states)
    param_states = jax.tree.map(lambda w, g: w - 0.1 * g, param_states, grads)
    return param_states, count_states

@jax.jit
def eval_step(param_states, count_states, batch):
    x, y = batch
    model = brainstate.graph.treefy_merge(graphdef, param_states, count_states)
    y_pred = model(x)
    return {'loss': jnp.mean((y - y_pred) ** 2)}

if RUN_TRAINING:
    for step, batch in zip(range(600), dataset(32)):
        params, counts = train_step(params, counts, batch)
        if step % 200 == 0:
            logs = eval_step(params, counts, (X, Y))
            print(f"step {step}: loss={logs['loss']:.4f}")
else:
    print('Training skipped (set RUN_TRAINING=True to run SGD).')

model = brainstate.graph.treefy_merge(graphdef, params, counts)
print(f"call count state: {model.count.value}")
y_pred = model(X)
plt.scatter(X, Y, color='steelblue', label='data')
plt.plot(X, y_pred, color='black', label='fit')
plt.legend()
plt.show()


### Example 002 — Lifted Transforms (`examples/002_lifted_transforms.py`)

**Highlights**
- Decorating training/evaluation with `@brainstate.transform.jit`.
- Using `brainstate.transform.grad` to obtain gradients that respect state collections.
- Managing optimiser state with `braintools.optim.SGD`.
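
`brainstate.transform.grad(loss_fn, states)` differentiates a zero-argument closure with respect to the state objects it is given, rather than with respect to explicit function arguments. The NumPy sketch below mimics only that interface, using central finite differences; the `ToyState` class and the `lifted_grad` helper are hypothetical, and Brainstate itself relies on JAX autodiff, not numerical differentiation.

```python
import numpy as np


class ToyState:
    """Hypothetical mutable container standing in for brainstate.ParamState."""

    def __init__(self, value):
        self.value = np.array(value, dtype=float)


def lifted_grad(loss_fn, states, eps=1e-6):
    """Central-difference gradient of a zero-argument loss w.r.t. the given states only."""
    grads = {}
    for name, state in states.items():
        grad = np.zeros_like(state.value)
        for idx in np.ndindex(state.value.shape):
            original = state.value[idx]
            state.value[idx] = original + eps
            loss_plus = loss_fn()
            state.value[idx] = original - eps
            loss_minus = loss_fn()
            state.value[idx] = original          # restore the state after probing
            grad[idx] = (loss_plus - loss_minus) / (2 * eps)
        grads[name] = grad
    return grads


x_demo = np.array([1.0, 2.0, 3.0])
y_demo = 3.0 * x_demo + 1.0
w_state, b_state = ToyState([2.0]), ToyState([0.5])


def toy_loss():
    # Reads the states implicitly, like loss_fn inside train_step_lifted.
    return np.mean((w_state.value * x_demo + b_state.value - y_demo) ** 2)


# Differentiate w.r.t. w only; b is treated as a constant.
toy_grads = lifted_grad(toy_loss, {"w": w_state})
print(toy_grads["w"])   # approximately [-11.333]
```

Because `b_state` is not passed to `lifted_grad`, it receives no gradient, which is the same selective behaviour that passing `optimizer.param_states` provides in the cell below.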

In [None]:
import braintools

class LiftedMLP(brainstate.nn.Module):
    def __init__(self, din: int, dhidden: int, dout: int):
        super().__init__()
        self.count = CallCount(jnp.array(0))
        self.linear1 = Linear(din, dhidden)
        self.linear2 = Linear(dhidden, dout)

    def __call__(self, x):
        self.count.value += 1
        x = self.linear1(x)
        x = jax.nn.relu(x)
        return self.linear2(x)

model2 = LiftedMLP(1, 32, 1)
optimizer = braintools.optim.SGD(1e-3)
optimizer.register_trainable_weights(model2.states(brainstate.ParamState))

@brainstate.transform.jit
def train_step_lifted(batch):
    x, y = batch

    def loss_fn():
        preds = model2(x)
        return jnp.mean((y - preds) ** 2)

    grads = brainstate.transform.grad(
        loss_fn, optimizer.param_states.to_pytree()
    )()
    optimizer.update(grads)

@brainstate.transform.jit
def eval_step_lifted(batch):
    x, y = batch
    preds = model2(x)
    return {'loss': jnp.mean((y - preds) ** 2)}

if RUN_TRAINING:
    for step, batch in zip(range(600), dataset(32)):
        train_step_lifted(batch)
        if step % 200 == 0:
            logs = eval_step_lifted((X, Y))
            print(f"[lifted] step {step}: loss={logs['loss']:.4f}")
else:
    print('Lifted training skipped (set RUN_TRAINING=True).')

print(f"lifted forward count: {model2.count.value}")
plt.scatter(X, Y, color='steelblue', label='data')
plt.plot(X, model2(X), color='darkorange', label='lifted fit')
plt.legend()
plt.show()


### Example 003 — Variational Auto-Encoder (`examples/003_vae.py`)

**Highlights**
- Combining multiple `brainstate.nn.Module` components to build an encoder/decoder pair.
- Storing auxiliary losses (KL divergence) in custom `brainstate.State` nodes.
- Mixing Brainstate transforms with Optax losses.

> **Note:** The full script requires `datasets`, `optax`, and `matplotlib` and takes several minutes to train. The cell below runs only a short smoke test, and only when `RUN_TRAINING=True` and `optax`/`datasets` are installed.
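
The `Encoder` computes its KL term in closed form: for a Gaussian posterior with mean `mean` and standard deviation `std` against a standard-normal prior, each latent dimension contributes `0.5 * (-log(std**2) - 1 + std**2 + mean**2)`. The encoder averages this over latent dimensions, whereas the textbook KL sums over them; the two differ only by a factor of the latent size. The NumPy check below, with hypothetical helper names, compares the closed form against a Monte Carlo estimate drawn with the same reparameterisation `z = mean + std * eps` that the encoder uses.

```python
import numpy as np


def kl_to_standard_normal(mean, std):
    """Closed-form KL(N(mean, std^2) || N(0, 1)), averaged over the last axis as in Encoder."""
    return np.mean(0.5 * (-np.log(std ** 2) - 1.0 + std ** 2 + mean ** 2), axis=-1)


def reparameterise(mean, std, rng):
    """z = mean + std * eps with eps ~ N(0, 1); the noise does not depend on mean or std."""
    return mean + std * rng.standard_normal(np.shape(mean))


def kl_monte_carlo(mean, std, rng, n=200_000):
    """Estimate E_q[log q(z) - log p(z)] for one scalar latent from reparameterised samples."""
    z = reparameterise(np.full(n, mean), np.full(n, std), rng)
    log_q = -0.5 * ((z - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)
    log_p = -0.5 * z ** 2 - 0.5 * np.log(2 * np.pi)
    return np.mean(log_q - log_p)


kl_rng = np.random.default_rng(0)
closed = kl_to_standard_normal(np.array([0.5]), np.array([0.8]))
estimate = kl_monte_carlo(0.5, 0.8, kl_rng)
print(float(closed), float(estimate))   # the two values should agree to roughly two decimals
```

The KL vanishes exactly when `mean = 0` and `std = 1`, which is why this term pulls the encoder's posterior towards the prior during training.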

In [None]:
import typing as tp
try:
    import optax
    from datasets import load_dataset
except ModuleNotFoundError:
    optax = None
    load_dataset = None

latent_size = 32
image_shape: tp.Sequence[int] = (28, 28)

class Loss(brainstate.State):
    pass

class Encoder(brainstate.nn.Module):
    def __init__(self, din: int, dmid: int, dout: int):
        super().__init__()
        self.linear1 = brainstate.nn.Linear(din, dmid)
        self.linear_mean = brainstate.nn.Linear(dmid, dout)
        self.linear_std = brainstate.nn.Linear(dmid, dout)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x.reshape((x.shape[0], -1))
        x = self.linear1(x)
        x = jax.nn.relu(x)
        mean = self.linear_mean(x)
        std = jnp.exp(self.linear_std(x))
        loss = jnp.mean(0.5 * jnp.mean(-jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
        self.kl_loss = Loss(loss)
        z = mean + std * brainstate.random.normal(size=mean.shape)
        return z

class Decoder(brainstate.nn.Module):
    def __init__(self, din: int, dmid: int, dout: int):
        super().__init__()
        self.linear1 = brainstate.nn.Linear(din, dmid)
        self.linear2 = brainstate.nn.Linear(dmid, dout)

    def __call__(self, z: jax.Array) -> jax.Array:
        z = self.linear1(z)
        z = jax.nn.relu(z)
        return self.linear2(z)

class VAE(brainstate.nn.Module):
    def __init__(self, din: int, hidden_size: int, latent_size: int, output_shape: tp.Sequence[int]):
        super().__init__()
        self.output_shape = output_shape
        self.encoder = Encoder(din, hidden_size, latent_size)
        self.decoder = Decoder(latent_size, hidden_size, int(np.prod(output_shape)))

    def __call__(self, x: jax.Array) -> jax.Array:
        logits = self.decoder(self.encoder(x))
        return jnp.reshape(logits, (-1, *self.output_shape))

    def generate(self, z):
        logits = self.decoder(z)
        return jax.nn.sigmoid(logits.reshape((-1, *self.output_shape)))

if RUN_TRAINING and optax is not None and load_dataset is not None:
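    # Use a distinct name so the `dataset` generator from Example 001 is not shadowed.
    mnist = load_dataset('mnist')
    # Binarise pixel intensities to {0, 1} for the sigmoid cross-entropy (Bernoulli) likelihood.
    X_train = (np.stack(mnist['train']['image']) > 0).astype(np.float32)
    optimizer = braintools.optim.Adam(1e-3)
    model = VAE(int(np.prod(image_shape)), 256, latent_size, image_shape)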
    optimizer.register_trainable_weights(model.states(brainstate.ParamState))

    @brainstate.transform.jit
    def train_step(x: jax.Array):
        def loss_fn():
            logits = model(x)
            losses = brainstate.graph.treefy_states(model, Loss)
            kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0)
            recon_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, x))
            return recon_loss + 0.1 * kl_loss
        grads, loss = brainstate.transform.grad(
            loss_fn, optimizer.param_states.to_pytree(), return_value=True)()
        optimizer.update(grads)
        return loss

    # Three single-batch steps as a smoke test, not full epochs.
    for step in range(3):
        batch_idx = np.random.randint(0, len(X_train), size=(64,))
        loss = train_step(X_train[batch_idx])
        print(f"step {step}: loss={float(loss):.4f}")
else:
    print('VAE training skipped (install optax/datasets and set RUN_TRAINING=True).')


### Example 004 — Scan Over Layers (`examples/004_scan_over_layers.py`)

**Highlights**
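- Building repeated submodules and applying them with a plain Python loop. When traced by JAX, the loop is unrolled into one copy per layer, whereas `jax.lax.scan` would trace a single block and iterate over stacked parameters.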
- Combining dense, batch-norm, and dropout primitives within a single block.

This example is fast to run; no additional dependencies are required.
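
`jax.lax.scan` threads a carry through a function applied to each slice of its input along the leading axis, returning the final carry and the stacked per-step outputs. The NumPy sketch below uses a pure-Python reference for that `(carry, xs) -> (carry, ys)` contract (not JAX's implementation) and a simplified `tanh` block without batch norm or dropout, to show that scanning over per-layer weights stacked on axis 0 computes the same result as the unrolled loop in `ScanMLP`.

```python
import numpy as np


def scan(f, carry, xs):
    """Pure-Python reference for jax.lax.scan semantics over the leading axis of xs."""
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def block(h, w):
    """Simplified stand-in for Block: one dense layer followed by tanh."""
    return np.tanh(h @ w)


def layer_step(h, w):
    h = block(h, w)
    return h, h   # new carry, plus the layer's activation as the scanned output


scan_rng = np.random.default_rng(0)
n_layers, dim = 5, 10
stacked_w = scan_rng.normal(scale=0.3, size=(n_layers, dim, dim))  # one matrix per layer
x_demo = np.ones((3, dim))

h_loop = x_demo
for w_layer in stacked_w:   # unrolled loop, as in ScanMLP.__call__
    h_loop = block(h_loop, w_layer)

h_scan, per_layer = scan(layer_step, x_demo, stacked_w)
print(np.allclose(h_loop, h_scan), per_layer.shape)   # True (5, 3, 10)
```

The practical difference in JAX is compile cost: the unrolled loop traces `n_layers` copies of the block, while scan traces one.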

In [None]:
class Block(brainstate.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = brainstate.nn.Linear(dim, dim)
        self.bn = brainstate.nn.BatchNorm0d([dim])
        self.dropout = brainstate.nn.Dropout(0.5)

    def __call__(self, x: jax.Array):
        return jax.nn.gelu(self.dropout(self.bn(self.linear(x))))

class ScanMLP(brainstate.nn.Module):
    def __init__(self, dim: int, *, n_layers: int):
        super().__init__()
        self.layers = [Block(dim) for _ in range(n_layers)]

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x

with brainstate.environ.context(fit=True):
    model = ScanMLP(10, n_layers=5)
    y = model(jnp.ones((3, 10)))

print(jax.tree.map(jnp.shape, brainstate.graph.treefy_states(model)))
print('Output shape:', y.shape)


### Example 005 — Saving and Loading Checkpoints (`examples/005_save_load_checkpoints.py`)

**Highlights**
- Capturing the complete state tree of a model.
- Restoring into an abstractly initialised model with Orbax (requires the `orbax-checkpoint` package).
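
Orbax handles the on-disk format in this example, but the restore pattern itself is simple: flatten the state tree to path-keyed arrays, save them, and later read them back into a template that only describes shapes and dtypes (the role the abstractly initialised model plays in `load_model`). The sketch below uses `np.savez` instead of Orbax, and its helpers (`flatten`, `unflatten`, `save_checkpoint`, `restore_checkpoint`) are hypothetical; it illustrates the pattern and is not compatible with Orbax checkpoints.

```python
import os
import tempfile

import numpy as np


def flatten(tree, prefix=""):
    """Map a nested dict to {'a.b': leaf} using dotted paths."""
    flat = {}
    for key, node in tree.items():
        name = f"{prefix}{key}"
        if isinstance(node, dict):
            flat.update(flatten(node, name + "."))
        else:
            flat[name] = node
    return flat


def unflatten(flat):
    """Inverse of flatten: rebuild the nested dict from dotted paths."""
    tree = {}
    for name, leaf in flat.items():
        *parents, last = name.split(".")
        node = tree
        for key in parents:
            node = node.setdefault(key, {})
        node[last] = leaf
    return tree


def save_checkpoint(path, tree):
    np.savez(path, **flatten(tree))


def restore_checkpoint(path, abstract_tree):
    """Load arrays into the structure of abstract_tree, whose leaves are (shape, dtype) specs."""
    restored = {}
    with np.load(path) as data:
        for name, (shape, dtype) in flatten(abstract_tree).items():
            array = data[name]
            if array.shape != tuple(shape) or array.dtype != np.dtype(dtype):
                raise ValueError(f"{name}: expected {shape}/{dtype}, got {array.shape}/{array.dtype}")
            restored[name] = array
    return unflatten(restored)


ckpt_rng = np.random.default_rng(42)
state = {
    "dense1": {"weight": ckpt_rng.normal(size=(10, 20)), "bias": np.zeros(20)},
    "dense2": {"weight": ckpt_rng.normal(size=(20, 30)), "bias": np.zeros(30)},
}
# "Abstract" template: the structure plus shapes and dtypes, without any values.
abstract = {layer: {k: (v.shape, v.dtype) for k, v in leaves.items()} for layer, leaves in state.items()}

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "state.npz")
    save_checkpoint(ckpt_path, state)
    restored_state = restore_checkpoint(ckpt_path, abstract)

print(restored_state["dense2"]["weight"].shape)   # (20, 30)
```

Checking shapes and dtypes against the template catches a checkpoint that was saved from a differently configured model before any values are used.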

In [None]:
import os
from tempfile import TemporaryDirectory
import orbax.checkpoint as orbax

class CheckpointMLP(brainstate.nn.Module):
    def __init__(self, din: int, dmid: int, dout: int):
        super().__init__()
        self.dense1 = brainstate.nn.Linear(din, dmid)
        self.dense2 = brainstate.nn.Linear(dmid, dout)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.dense2(jax.nn.relu(self.dense1(x)))

def create_model(seed: int):
    brainstate.random.seed(seed)
    return CheckpointMLP(10, 20, 30)

def create_and_save(seed: int, path: str):
    model = create_model(seed)
    state_tree = brainstate.graph.treefy_states(model)
    orbax.PyTreeCheckpointer().save(os.path.join(path, 'state'), state_tree)

def load_model(path: str) -> CheckpointMLP:
    model = brainstate.transform.abstract_init(lambda: create_model(0))
    state_tree = brainstate.graph.treefy_states(model)
    state_tree = orbax.PyTreeCheckpointer().restore(os.path.join(path, 'state'), item=state_tree)
    brainstate.graph.update_states(model, state_tree)
    return model

with TemporaryDirectory() as tmpdir:
    create_and_save(42, tmpdir)
    restored = load_model(tmpdir)
    y = restored(jnp.ones((1, 10)))
    print(restored)
    print('Sample output shape:', y.shape)


### Example 007 — Parameter Surgery (`examples/007_parameter_surgery.py`)

**Highlights**
- Swapping submodules with pretrained replacements.
- Filtering parameter trees to separate frozen and trainable weights.
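
`treefy_split` accepts a sequence of filters, and the trailing `...` collects whatever the earlier filters did not claim. The pure-Python sketch below assumes first-match semantics (each leaf goes to the first filter that accepts it); the `partition` helper, the `not_backbone` predicate, and the flat path-keyed layout are hypothetical illustrations, not Brainstate internals.

```python
def partition(flat_state, *filters):
    """Assign each (path, leaf) pair to the first filter that accepts it; `...` accepts everything."""
    groups = [{} for _ in filters]
    for path, leaf in flat_state.items():
        for group, flt in zip(groups, filters):
            if flt is ... or flt(path, leaf):
                group[path] = leaf
                break
    return groups


def not_backbone(path, leaf):
    # Freeze everything under the pretrained backbone.
    return "backbone" not in path


# Hypothetical layout: (module, parameter) paths mapped to parameter shapes.
flat_state = {
    ("backbone", "weight"): (784, 128),
    ("backbone", "bias"): (128,),
    ("head", "weight"): (128, 10),
    ("head", "bias"): (10,),
}
trainable, frozen = partition(flat_state, not_backbone, ...)
print(sorted(trainable))   # [('head', 'bias'), ('head', 'weight')]
print(sorted(frozen))      # [('backbone', 'bias'), ('backbone', 'weight')]
```

Under first-match semantics the order of filters matters: placing `...` first would sweep every leaf into the first group, so the catch-all belongs last.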

In [None]:
def load_pretrained():
    # Stand-in for real pretrained weights: returns a freshly initialised layer of the same shape.
    return brainstate.nn.Linear(784, 128)

class Classifier(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = brainstate.nn.Linear(784, 128)
        self.head = brainstate.nn.Linear(128, 10)

    def __call__(self, x):
        x = jax.nn.relu(self.backbone(x))
        return self.head(x)

model = Classifier()
model.backbone = load_pretrained()

def is_trainable(path, node):
    return 'backbone' not in path and issubclass(node.type, brainstate.ParamState)

graphdef, trainable_params, frozen_params = brainstate.graph.treefy_split(model, is_trainable, ...)
print('Trainable shapes:', jax.tree.map(jnp.shape, trainable_params))
print('Frozen shapes:', jax.tree.map(jnp.shape, frozen_params))


## 1xx — Biophysical Dynamics and Continuous Attractors

These examples move from abstract modules to conductance-based neurons and population dynamics.

### Example 100 — Hodgkin–Huxley Neuron (`examples/100_hh_neuron_model.py`)

**Highlights**
- Implementing the classic Hodgkin–Huxley system inside a `brainstate.nn.Dynamics` subclass.
- Using `brainstate.nn.exp_euler_step` with units from `brainunit`.
- Iterating the model with `brainstate.transform.for_loop`.

> Enable `RUN_NEURO_SIM=True` to execute the 10 ms simulation in the cell below.
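
Two numerical details in the `HH` class can be checked in isolation. First, `exprel(x)` computes `(exp(x) - 1) / x`, so `1 / exprel(-(V + 40) / 10)` equals the textbook `0.1 * (V + 40) / (1 - exp(-(V + 40) / 10))` while staying finite at `V = -40 mV`. Second, each gating variable obeys the linear ODE `dm/dt = alpha * (1 - m) - beta * m`, for which an exponential-Euler step with `V` held fixed over the step is exact. The plain-NumPy sketch below (voltages in mV as bare floats, rates in 1/ms; `exprel`, `steady_state`, and `gate_exp_euler` are hypothetical helpers, not the `brainunit` or `brainstate` functions) reproduces the resting gating values that `init_state` assigns at -65 mV.

```python
import numpy as np


def exprel(x):
    """(exp(x) - 1) / x, with the removable singularity at x = 0 set to its limit 1."""
    x = np.asarray(x, dtype=float)
    safe = np.where(x == 0.0, 1.0, x)   # keeps the discarded branch free of 0/0
    return np.where(x == 0.0, 1.0, np.expm1(safe) / safe)


# Rate functions copied from the HH class; V is in mV and rates are in 1/ms.
def m_alpha(V):
    return 1.0 / exprel(-(V + 40) / 10)


def m_beta(V):
    return 4.0 * np.exp(-(V + 65) / 18)


def h_alpha(V):
    return 0.07 * np.exp(-(V + 65) / 20)


def h_beta(V):
    return 1.0 / (1.0 + np.exp(-(V + 35) / 10))


def n_alpha(V):
    return 0.1 / exprel(-(V + 55) / 10)


def n_beta(V):
    return 0.125 * np.exp(-(V + 65) / 80)


def steady_state(alpha, beta):
    return alpha / (alpha + beta)


def gate_exp_euler(g, alpha, beta, dt):
    """Exact update of dg/dt = alpha*(1 - g) - beta*g over dt with alpha, beta held fixed."""
    tau = 1.0 / (alpha + beta)
    g_inf = alpha * tau
    return g_inf + (g - g_inf) * np.exp(-dt / tau)


V_rest = -65.0
m0 = steady_state(m_alpha(V_rest), m_beta(V_rest))
h0 = steady_state(h_alpha(V_rest), h_beta(V_rest))
n0 = steady_state(n_alpha(V_rest), n_beta(V_rest))
print(round(float(m0), 4), round(float(h0), 4), round(float(n0), 4))   # 0.0529 0.5961 0.3177
print(float(m_alpha(-40.0)))   # 1.0, the limit at the 0/0 point of the textbook formula
```

Because the exponential update relaxes towards `g_inf` with time constant `tau`, it stays in `[0, 1]` for any step size, unlike a forward-Euler step on the same equation.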

In [None]:
import brainunit as u

class HH(brainstate.nn.Dynamics):
    def __init__(self, in_size):
        super().__init__(in_size)
        self.ENa = 50. * u.mV
        self.EK = -77. * u.mV
        self.EL = -54.387 * u.mV
        self.gNa = 120. * u.mS / u.cm ** 2
        self.gK = 36. * u.mS / u.cm ** 2
        self.gL = 0.03 * u.mS / u.cm ** 2
        self.C = 1.0 * u.uF / u.cm ** 2
        self.V_th = 20. * u.mV

    m_alpha = lambda self, V: 1. / u.math.exprel(-(V / u.mV + 40) / 10)
    m_beta = lambda self, V: 4.0 * jnp.exp(-(V / u.mV + 65) / 18)
    h_alpha = lambda self, V: 0.07 * jnp.exp(-(V / u.mV + 65) / 20.)
    h_beta = lambda self, V: 1 / (1 + jnp.exp(-(V / u.mV + 35) / 10))
    n_alpha = lambda self, V: 0.1 / u.math.exprel(-(V / u.mV + 55) / 10)
    n_beta = lambda self, V: 0.125 * jnp.exp(-(V / u.mV + 65) / 80)

    def init_state(self, batch_size=None):
        V0 = jnp.ones(self.varshape, brainstate.environ.dftype()) * -65. * u.mV
        self.V = brainstate.HiddenState(V0)
        self.m = brainstate.HiddenState(self.m_alpha(V0) / (self.m_alpha(V0) + self.m_beta(V0)))
        self.h = brainstate.HiddenState(self.h_alpha(V0) / (self.h_alpha(V0) + self.h_beta(V0)))
        self.n = brainstate.HiddenState(self.n_alpha(V0) / (self.n_alpha(V0) + self.n_beta(V0)))

    def dV(self, V, t, m, h, n, I):
        I = self.sum_current_inputs(I, V)
        I_Na = (self.gNa * m ** 3 * h) * (V - self.ENa)
        I_K = (self.gK * n ** 4) * (V - self.EK)
        I_leak = self.gL * (V - self.EL)
        return (- I_Na - I_K - I_leak + I) / self.C

    def update(self, x=0. * u.mA / u.cm ** 2):
        t = brainstate.environ.get('t')
        V = brainstate.nn.exp_euler_step(self.dV, self.V.value, t, self.m.value, self.h.value, self.n.value, x)
        m = brainstate.nn.exp_euler_step(lambda m, t, V: (self.m_alpha(V) * (1 - m) - self.m_beta(V) * m) / u.ms, self.m.value, t, self.V.value)
        h = brainstate.nn.exp_euler_step(lambda h, t, V: (self.h_alpha(V) * (1 - h) - self.h_beta(V) * h) / u.ms, self.h.value, t, self.V.value)
        n = brainstate.nn.exp_euler_step(lambda n, t, V: (self.n_alpha(V) * (1 - n) - self.n_beta(V) * n) / u.ms, self.n.value, t, self.V.value)
        V = self.sum_delta_inputs(init=V)
        spike = jnp.logical_and(self.V.value < self.V_th, V >= self.V_th)
        self.V.value, self.m.value, self.h.value, self.n.value = V, m, h, n
        return spike

if RUN_NEURO_SIM:
    hh = HH(10)
    brainstate.nn.init_all_states(hh)
    dt = 0.01 * u.ms

    def run(t, inp):
        with brainstate.environ.context(t=t, dt=dt):
            hh(inp)
        return hh.V.value

    times = u.math.arange(0. * u.ms, 10. * u.ms, dt)
    vs = brainstate.transform.for_loop(run, times, brainstate.random.uniform(1., 5., times.shape) * u.uA / u.cm ** 2)
    plt.plot(times.to_decimal(u.ms), vs.to_decimal(u.mV))
    plt.xlabel('Time (ms)')
    plt.ylabel('Membrane potential (mV)')
    plt.show()
else:
    print('HH simulation skipped (set RUN_NEURO_SIM=True).')


### Example 101 — Continuous Attractor Neural Network (`examples/101_cann_1d_oscillatory_tracking.py`)

**Highlights**
- Building a 1D CANN with spike-frequency adaptation.
- Using `brainstate.transform.for_loop` for iterative simulation and animation.
- Integrating external inputs with Brainstate's environment context.

> The full script depends on `matplotlib.animation` and is computationally heavy, so the cell below only points to it. Run the script directly to reproduce the animated tracking simulation.
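
A 1D CANN is built around a translation-invariant recurrent kernel on a ring: the weight between two neurons depends only on the periodic distance between their preferred positions. The sketch below assumes the Gaussian kernel `J0 / (sqrt(2*pi) * a) * exp(-d**2 / (2 * a**2))` and the divisive-normalisation rate `u**2 / (1 + k * sum(u**2))` that are common in CANN models; the parameter values and helper names are illustrative assumptions, and the script's exact constants (and its adaptation term) are not reproduced here.

```python
import numpy as np


def periodic_distance(d, period=2 * np.pi):
    """Unsigned shortest distance between two points on a ring of the given period."""
    return np.abs((d + period / 2) % period - period / 2)


def ring_kernel(num=128, a=0.5, J0=4.0, z_min=-np.pi, z_max=np.pi):
    """Gaussian connection matrix for neurons evenly spaced on [z_min, z_max)."""
    positions = np.linspace(z_min, z_max, num, endpoint=False)
    d = periodic_distance(positions[:, None] - positions[None, :], period=z_max - z_min)
    return positions, J0 / (np.sqrt(2 * np.pi) * a) * np.exp(-0.5 * (d / a) ** 2)


def firing_rate(u, k=8.1):
    """Divisive normalisation r = u^2 / (1 + k * sum(u^2)), an assumed (common) choice."""
    u2 = np.square(u)
    return u2 / (1.0 + k * np.sum(u2))


positions, conn = ring_kernel()
bump = np.exp(-0.25 * positions ** 2)   # activity bump centred at position 0
recurrent_input = conn @ firing_rate(bump)
print(np.argmax(recurrent_input) == np.argmin(np.abs(positions)))   # True
```

Because the kernel is circulant, the recurrent drive peaks wherever the bump sits, which is what lets a CANN hold an activity bump at any position and track a moving input.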

In [None]:
if not RUN_NEURO_SIM:
    print('CANN simulation skipped (set RUN_NEURO_SIM=True).')
else:
    print('Open `examples/101_cann_1d_oscillatory_tracking.py` for the full animated simulation.')


### Example 105 — COBA HH Network with BrainCell (`examples/105_COBA_HH_2007_braincell.py`)

**Highlights**
- Combining Brainstate modules with the BrainCell biophysical library.
- Driving an excitatory/inhibitory network via event-based projections.

> Requires `braincell`, `brainpy`, and `brainunit`. The cell below only points to the script; run it directly to produce the raster plot.
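
Two ideas underlie an event-driven conductance-based (COBA) network: spikes are sparse, so a projection only needs to sum the weight rows of presynaptic neurons that fired, and the resulting conductance enters the membrane equation as `g * (E - V)`, whose sign depends on the reversal potential. The NumPy sketch below illustrates both under stated assumptions: the exponential-decay synapse, the helper names, and the numeric values (weights, time constant, reversal potentials) are illustrative choices, not taken from the script.

```python
import numpy as np


def event_matvec(spikes, weights):
    """Event-driven spikes @ weights: sum only the rows of presynaptic neurons that fired."""
    active = np.flatnonzero(spikes)
    return weights[active].sum(axis=0)


def step_conductance(g, spikes, weights, tau, dt):
    """Exponentially decaying conductance plus spike-triggered increments (exact decay per step)."""
    return g * np.exp(-dt / tau) + event_matvec(spikes, weights)


def coba_current(g, V, E):
    """Conductance-based synaptic current g * (E - V)."""
    return g * (E - V)


coba_rng = np.random.default_rng(1)
n_pre, n_post = 400, 100
w_exc = 0.01 * coba_rng.random((n_pre, n_post))   # illustrative excitatory weights
spikes = coba_rng.random(n_pre) < 0.02            # roughly 2% of presynaptic cells fire this step
g_exc = step_conductance(np.zeros(n_post), spikes, w_exc, tau=5.0, dt=0.1)
v_post = np.full(n_post, -60.0)                   # membrane potentials in mV
I_exc = coba_current(g_exc, v_post, E=0.0)        # E = 0 mV lies above V, so the current is depolarising
print(np.allclose(event_matvec(spikes, w_exc), spikes.astype(float) @ w_exc))   # True
```

With a hyperpolarised reversal potential (for example -80 mV) the same formula yields an inhibitory current whenever `V` sits above `E`.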

In [None]:
if RUN_NEURO_SIM:
    print('Run `python examples/105_COBA_HH_2007_braincell.py` for the full raster plot (heavy dependencies).')
else:
    print('BrainCell COBA network skipped (enable RUN_NEURO_SIM to execute).')


## 2xx — Brain-Inspired Computing and Brainscale Training

Example 203 introduces Brainscale tools for training spiking neural networks (SNNs) with surrogate gradients. The script is long-running and relies on `brainscale`, `torch`, `numba`, and other packages.

### Example 203 — Brainscale for SNNs (`examples/203_brainscale_for_snns.py`)

**Highlights**
- Command-line flag parsing for different eligibility-trace methods (diag, expsm_diag, bptt).
- Custom GIF neuron with eligibility traces (`brainscale.ETraceState`).
- Integration with PyTorch dataloaders and NumPy-based preprocessing.

> The notebook does not replicate the full training loop. Use the original script with `python examples/203_brainscale_for_snns.py --help` to explore all options.
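
Online eligibility-trace methods avoid storing the full history that BPTT needs: each parameter instead carries a running derivative of the hidden state. For a diagonal linear recurrence `h_t = alpha * h_{t-1} + w * x_t`, the trace `e_t = alpha * e_{t-1} + x_t` equals `dh_t/dw` exactly, so accumulating `(h_t - y_t) * e_t` online reproduces the full gradient. The NumPy sketch below is a generic illustration under these simplifying assumptions and is not brainscale's algorithm; in general nonlinear recurrent networks, traces that keep only diagonal terms approximate the exact gradient.

```python
import numpy as np


def loss_with_trace_grad(w, xs, targets, alpha):
    """Run h_t = alpha*h_{t-1} + w*x_t, returning L = 0.5*sum((h_t - y_t)^2) and dL/dw computed online."""
    h = e = loss = grad = 0.0
    for x, y in zip(xs, targets):
        h = alpha * h + w * x
        e = alpha * e + x   # eligibility trace: e_t = dh_t/dw for this recurrence
        err = h - y         # learning signal at time t
        loss += 0.5 * err ** 2
        grad += err * e
    return loss, grad


trace_rng = np.random.default_rng(0)
seq_x = trace_rng.normal(size=50)
seq_y = 0.5 * np.cumsum(seq_x)
trace_loss, trace_grad = loss_with_trace_grad(0.3, seq_x, seq_y, alpha=0.9)

# Compare with a central finite difference of the full-sequence loss.
eps = 1e-5
loss_plus, _ = loss_with_trace_grad(0.3 + eps, seq_x, seq_y, alpha=0.9)
loss_minus, _ = loss_with_trace_grad(0.3 - eps, seq_x, seq_y, alpha=0.9)
fd_grad = (loss_plus - loss_minus) / (2 * eps)
print(trace_grad, fd_grad)   # the two values agree closely
```

The memory cost of the trace is one value per parameter regardless of sequence length, which is the property that makes online learning practical for long SNN simulations.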

## 3xx — Rate-Based RNNs and Brainscale RNN Training

The 3xx series covers classic RNN training and a Brainscale-enhanced workflow.

### Example 300 — Integrator RNN (`examples/300_integrator_rnn.py`)

**Highlights**
- Training a simple integrator to reproduce cumulative sums of noisy inputs.
- Demonstrating stateful RNN cells with trainable initial states.
- Combining Brainstate transforms with `braintools.optim.Adam`.

> Set `RUN_TRAINING=True` to execute a short training loop.
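
The target in this task is exactly what a linear recurrence with unit recurrent and input weights computes, `h_t = h_{t-1} + x_t`, so the `tanh` RNN in the next cell has to learn to approximate this linear integrator with saturating units. The NumPy sketch below mirrors the data generator in that cell (the `build_inputs_and_targets_np` and `perfect_integrator` names are hypothetical) and checks that a hand-built integrator reproduces the targets.

```python
import numpy as np


def build_inputs_and_targets_np(rng, mean=0.025, scale=0.01, batch_size=10, dt=0.04):
    """NumPy mirror of build_inputs_and_targets: a per-trial drift plus noise, and its running sum."""
    num_step = int(round(1.0 / dt))
    sample = rng.normal(size=(1, batch_size, 1))
    bias = mean * 2.0 * (sample - 0.5)   # one drift value per trial, shared across time
    noise = scale / np.sqrt(dt) * rng.normal(size=(num_step, batch_size, 1))
    inputs = bias + noise
    return inputs, np.cumsum(inputs, axis=0)


def perfect_integrator(inputs):
    """Linear RNN with recurrent weight 1 and input weight 1, starting from h = 0."""
    h = np.zeros(inputs.shape[1:])
    outputs = []
    for x_t in inputs:   # iterate over the time axis
        h = h + x_t
        outputs.append(h)
    return np.stack(outputs)


task_rng = np.random.default_rng(0)
task_inputs, task_targets = build_inputs_and_targets_np(task_rng, batch_size=16)
print(task_inputs.shape, np.allclose(perfect_integrator(task_inputs), task_targets))   # (25, 16, 1) True
```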

In [None]:
dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128

@brainstate.transform.jit(static_argnums=2)
def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
    sample = brainstate.random.normal(size=(1, batch_size, 1))
    bias = mean * 2.0 * (sample - 0.5)
    noise_t = scale / dt ** 0.5 * brainstate.random.normal(size=(num_step, batch_size, 1))
    inputs = bias + noise_t
    targets = jnp.cumsum(inputs, axis=0)
    return inputs, targets

def train_data():
    for _ in range(100):
        yield build_inputs_and_targets(0.025, 0.01, num_batch)

class RNNCell(brainstate.nn.Module):
    def __init__(self, num_in: int, num_out: int):
        super().__init__()
        self.num_out = num_out
        self.W = brainstate.ParamState(braintools.init.XavierNormal()((num_in + num_out, num_out)))
        self.b = brainstate.ParamState(braintools.init.ZeroInit()((num_out,)))
        self.state_param = brainstate.ParamState(braintools.init.ZeroInit()((num_out,)))

    def init_state(self, batch_size=None, **kwargs):
        base = self.state_param.value
        if batch_size is None:
            self.state = brainstate.HiddenState(base)
        else:
            self.state = brainstate.HiddenState(jnp.repeat(base[None, :], batch_size, axis=0))

    def update(self, x):
        x = jnp.concatenate([x, self.state.value], axis=-1)
        h = jax.nn.tanh(x @ self.W.value + self.b.value)
        self.state.value = h
        return h

class IntegratorRNN(brainstate.nn.Module):
    def __init__(self, num_in, num_hidden):
        super().__init__()
        self.cell = RNNCell(num_in, num_hidden)
        self.out = brainstate.nn.Linear(num_hidden, 1)

    def update(self, x):
        return x >> self.cell >> self.out

model = IntegratorRNN(1, 64)
weights = model.states(brainstate.ParamState)

@brainstate.transform.jit
def f_predict(inputs):
    brainstate.nn.init_all_states(model, batch_size=inputs.shape[1])
    return brainstate.transform.for_loop(model.update, inputs)

def f_loss(inputs, targets, l2_reg=2e-4):
    preds = f_predict(inputs)
    mse = braintools.metric.squared_error(preds, targets).mean()
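    # Square every array inside each ParamState's value (a single state may hold a dict of arrays).
    l2 = sum(jnp.sum(leaf ** 2) for w in weights.values() for leaf in jax.tree.leaves(w.value))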
    return mse + l2_reg * l2

opt = braintools.optim.Adam(lr=braintools.optim.ExponentialDecayLR(0.01, 1, 0.999))
opt.register_trainable_weights(weights)

@brainstate.transform.jit
def f_train(inputs, targets):
    grads, loss = brainstate.transform.grad(f_loss, weights, return_value=True)(inputs, targets)
    opt.update(grads)
    return loss

if RUN_TRAINING:
    for step, (inp, tar) in zip(range(10), train_data()):
        loss = f_train(inp, tar)
        print(f"batch {step}: loss={float(loss):.5f}")
else:
    print('Integrator training skipped (set RUN_TRAINING=True).')


### Example 301 — Brainscale for RNNs (`examples/301_brainscale_for_rnns.py`)

**Highlights**
- Brainscale eligibility-trace training for recurrent models.
- Dataset streaming with PyTorch `IterableDataset`.
- Checkpointing via Orbax `CheckpointManager`.

> Requires `brainscale`, `torch`, and `orbax`. Execute the original script directly for the full experiment.
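
A checkpoint manager mainly adds bookkeeping on top of saving: numbered step directories plus a retention policy. Orbax's `CheckpointManager` exposes retention through options such as `max_to_keep`; the standard-library sketch below illustrates only the keep-the-latest-N idea with a hypothetical `KeepLatest` class and makes no attempt to match Orbax's directory layout or API.

```python
import json
import os
import shutil
import tempfile


class KeepLatest:
    """Hypothetical manager: one directory per step, pruning all but the newest max_to_keep."""

    def __init__(self, root, max_to_keep=2):
        if max_to_keep < 1:
            raise ValueError("max_to_keep must be at least 1")
        self.root = root
        self.max_to_keep = max_to_keep
        os.makedirs(root, exist_ok=True)

    def all_steps(self):
        return sorted(int(name) for name in os.listdir(self.root) if name.isdigit())

    def save(self, step, metrics):
        step_dir = os.path.join(self.root, str(step))
        os.makedirs(step_dir, exist_ok=True)
        with open(os.path.join(step_dir, "metrics.json"), "w") as f:
            json.dump(metrics, f)
        for old in self.all_steps()[:-self.max_to_keep]:   # prune the oldest steps
            shutil.rmtree(os.path.join(self.root, str(old)))

    def latest_step(self):
        steps = self.all_steps()
        return steps[-1] if steps else None


with tempfile.TemporaryDirectory() as tmp:
    manager = KeepLatest(os.path.join(tmp, "ckpts"), max_to_keep=2)
    for step in range(5):
        manager.save(step, {"loss": 1.0 / (step + 1)})
    kept = manager.all_steps()
    latest = manager.latest_step()
print(kept, latest)   # [3, 4] 4
```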

## BrainCell Interoperation (`examples/braincell-interoperation/`)

The following scripts illustrate how Brainstate talks to the BrainCell ecosystem for detailed cellular modelling:

- `SC01_fitting_a_hh_neuron.py`: Parameter fitting of a Hodgkin–Huxley neuron using experimental traces.
- `SC03_COBA_HH_2007_braincell.py`: COBA network with HH neurons (BrainCell primitives).
- `SC04_hh_neuron.py`: Standalone BrainCell HH neuron driven by Brainstate.
- `SC05_thalamus_single_compartment_neurons.py`: Building thalamic cell types (RTC, TRN, HTC, IN).
- `SC06_unified_thalamus_model.py`: Multi-population thalamus network (Li et al., 2017).
- `SC07_Straital_beta_oscillation_2011.py`: Striatal beta oscillation model.

> These programmes depend on `braincell`, `brainunit`, `braintools`, `numba`, and dataset files (for SC01). Run them directly from the `examples/braincell-interoperation/` folder.

## BrainMass Interoperation (`examples/brainmass-interoperation/`)

The BrainMass demos model meso-scale neural mass dynamics with Brainstate integration:

- `00-hopf-osillator.py`: Hopf oscillator bifurcation and noise-driven dynamics.
- `01-wilsonwowan-osillator.py`: Wilson–Cowan oscillations and nullclines.
- `02-fhn-osillator.py`: FitzHugh–Nagumo excitability.
- `03-jansenrit_single_node_simulation.py`: Jansen–Rit cortical column.
- `Modeling_resting_state_MEG_data.py`: Resting-state MEG fitting with multi-region coupling.

> All scripts rely on `brainmass`, `brainunit`, and `matplotlib`, and some also need `pandas`. They are written as notebook-style scripts with `#%%` cell markers, so they can be run cell by cell in an IDE or converted to Jupyter notebooks.
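
The Hopf oscillator is commonly written in its normal (Stuart-Landau) form `dz/dt = (a + i*omega) * z - |z|**2 * z`, whose amplitude obeys `dr/dt = a*r - r**3`: for `a < 0` activity decays to the fixed point, while for `a > 0` it settles on a limit cycle of radius `sqrt(a)`. The forward-Euler sketch below assumes this textbook form; brainmass's own parameterisation, noise terms, and integrator may differ.

```python
import numpy as np


def simulate_hopf(a, omega=1.0, z0=0.1 + 0.0j, dt=1e-3, steps=20_000):
    """Forward-Euler integration of dz/dt = (a + i*omega) z - |z|^2 z; returns the trajectory."""
    z = complex(z0)
    trajectory = np.empty(steps, dtype=complex)
    for k in range(steps):
        z = z + dt * ((a + 1j * omega) * z - abs(z) ** 2 * z)
        trajectory[k] = z
    return trajectory


for a in (-0.5, 1.0):
    final_radius = abs(simulate_hopf(a)[-1])
    # Below the bifurcation (a < 0) the radius decays towards 0; above it, it approaches sqrt(a).
    print(a, round(final_radius, 3))
```

Sweeping `a` through zero is the bifurcation the demo explores: the fixed point loses stability and an oscillation of growing amplitude appears.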

## BrainPy Interoperation (`examples/brainpy-interoperation/`)

BrainPy 3.x is rebuilt on top of Brainstate. This directory contains numerous examples spanning network simulations, gamma oscillation regimes, and surrogate-gradient training. Refer to `examples/brainpy-interoperation/README.md` for a curated list:

- 100-series: balanced E/I networks, synfire chains, gamma oscillations.
- 110-series: Susin & Destexhe gamma models (AI, CHING, ING, PING).
- 200-series: Surrogate gradient training on MNIST/Fashion-MNIST.

> Install `brainpy` with the appropriate extras (`brainpy[cpu]` or `brainpy[cuda12]`). Each script can be run individually via `python <filename>`.
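
Surrogate-gradient training keeps the non-differentiable spike (a Heaviside step of membrane potential minus threshold) in the forward pass but substitutes a smooth pseudo-derivative in the backward pass. The NumPy sketch below uses the derivative of a sigmoid as one common surrogate; the function names and the `alpha` sharpness parameter are assumptions for illustration, and BrainPy provides its own surrogate functions with their own names and defaults.

```python
import numpy as np


def spike_forward(v, threshold=1.0):
    """Forward pass: Heaviside step, 1 where the membrane potential reaches threshold."""
    return (v >= threshold).astype(float)


def spike_surrogate_grad(v, threshold=1.0, alpha=4.0):
    """Backward pass: derivative of sigmoid(alpha * (v - threshold)) used in place of the step's."""
    s = 1.0 / (1.0 + np.exp(-alpha * (v - threshold)))
    return alpha * s * (1.0 - s)


potentials = np.array([0.0, 0.9, 1.0, 1.1, 2.0])
print(spike_forward(potentials))          # [0. 0. 1. 1. 1.]
print(spike_surrogate_grad(potentials))   # peaks at the threshold (alpha / 4 = 1.0) and decays away from it
```

A larger `alpha` makes the surrogate closer to the true (zero almost everywhere) derivative but concentrates gradient flow on neurons near threshold.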

## Benchmarks (`benchmark/`)

Although not part of the `examples/` folder, the `benchmark/` directory contains the CUBA and COBA reference benchmarks used across multiple demos. Consult these scripts when you need baseline parameters for large-scale simulations.

## Next Steps

1. Copy individual sections of this notebook into new experiments and customise the models.
2. Explore the `docs/tutorials/` notebooks for guided lessons on modules, transforms, and advanced APIs.
3. Combine Brainstate with the interoperation libraries to build hybrid simulations that mix detailed neurons with learning systems.