# Tutorial 24: BrainPy Integration

In this tutorial, we'll explore how BrainState integrates with BrainPy and when to use each framework.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Explain the relationship between BrainState and BrainPy
- Decide when to use BrainState and when to use BrainPy
- Combine both frameworks in a single model
- Describe where their APIs interoperate
- Build hybrid models that use both libraries
- Leverage the strengths of each framework

## Introduction

**BrainState** and **BrainPy** are complementary frameworks from the same ecosystem:

### BrainState
- **Focus**: Low-level state management and transformations
- **Purpose**: Foundation library for building neural models
- **Key Features**:
  - Explicit state management (ParamState, ShortTermState, etc.)
  - JAX transformations (JIT, grad, vmap)
  - Basic neural network layers
  - Graph operations and utilities

### BrainPy
- **Focus**: Brain dynamics and computational neuroscience
- **Purpose**: High-level brain modeling and simulation
- **Key Features**:
  - Spiking neural networks (SNNs)
  - Neuronal models (HH, LIF, Izhikevich, etc.)
  - Synaptic plasticity (STDP, BCM, etc.)
  - Network connectivity and dynamics
  - Differential equation solvers

**Relationship**: BrainPy is built on top of BrainState, using its state management system and transformations.

In [None]:
import brainstate as bst
# Note: BrainPy would be imported as:
# import brainpy as bp
# For this tutorial, we'll demonstrate the concepts with BrainState

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional

bst.random.seed(42)

## 1. Framework Comparison

### 1.1 When to Use BrainState

Use **BrainState** when you need:

1. **Custom State Management**: Fine-grained control over state types
2. **Deep Learning Models**: Standard neural networks (MLPs, CNNs, Transformers)
3. **JAX Transformations**: Direct access to JIT, vmap, grad
4. **Lightweight Solution**: Minimal dependencies for simple models
5. **Building Blocks**: Foundation for custom frameworks

In [None]:
# Example: BrainState is ideal for standard deep learning
class StandardMLP(bst.graph.Node):
    """Simple MLP using BrainState."""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, output_dim)
    
    def __call__(self, x):
        x = jax.nn.relu(self.fc1(x))
        return self.fc2(x)

# Use case: Standard classification task
model = StandardMLP(784, 128, 10)
x = bst.random.randn(32, 784)
output = model(x)
print(f"BrainState MLP output shape: {output.shape}")
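
As a quick sanity check on the model above, the parameter count of a fully connected layer is `in_features * out_features` weights plus `out_features` biases. The sketch below works this out in plain Python for the 784-128-10 MLP. It assumes each `bst.nn.Linear` layer carries a bias term, and `linear_param_count` is a hypothetical helper, not a BrainState function.

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of scalars in a dense layer: a weight matrix plus an optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)

# Layer sizes of StandardMLP(784, 128, 10)
layer_sizes = [(784, 128), (128, 10)]
total = sum(linear_param_count(i, o) for i, o in layer_sizes)
print(f"Expected MLP parameter count: {total:,}")  # 101,770
```

If the count reported by `model.states(bst.ParamState)` differs, the most likely reason is that the layers were built without a bias.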

### 1.2 When to Use BrainPy

Use **BrainPy** when you need:

1. **Brain Modeling**: Biologically realistic neural simulations
2. **Spiking Networks**: Event-driven computation with spikes
3. **Neuronal Dynamics**: Complex differential equations (HH, Izhikevich)
4. **Synaptic Plasticity**: STDP, BCM, and other learning rules
5. **Neuroscience Research**: Tools for computational neuroscience

In [None]:
# Example: BrainPy-style neuron model (implemented with BrainState)
# This demonstrates the concepts that BrainPy provides

class LIFNeuron(bst.nn.Dynamics):
    """
    Leaky Integrate-and-Fire neuron.
    This is the type of model you'd use BrainPy for.
    """
    
    def __init__(self, size: int, tau: float = 10.0, 
                 V_rest: float = -70.0, V_th: float = -50.0, 
                 V_reset: float = -70.0, dt: float = 0.1):
        super().__init__()
        self.size = size
        self.tau = tau
        self.V_rest = V_rest
        self.V_th = V_th
        self.V_reset = V_reset
        self.dt = dt
        self.R = 1.0
        
        # State variable
        self.V = bst.ShortTermState(jnp.ones(size) * V_rest)
    
    def __call__(self, I: jnp.ndarray) -> jnp.ndarray:
        """
        Update neuron state given input current.
        
        Args:
            I: Input current
            
        Returns:
            Spike output (1 if spike, 0 otherwise)
        """
        # Differential equation: tau * dV/dt = -(V - V_rest) + R*I
        dV = (-(self.V.value - self.V_rest) + self.R * I) / self.tau
        V_new = self.V.value + dV * self.dt
        
        # Spike and reset
        spike = V_new >= self.V_th
        V_new = jnp.where(spike, self.V_reset, V_new)
        
        # Update state
        self.V.value = V_new
        
        return spike.astype(jnp.float32)

# Use case: Spiking neural network
lif = LIFNeuron(size=100, tau=10.0)
current_input = bst.random.randn(100) * 5 + 10  # Noisy input current
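spikes = lif(current_input)
# One Euler step moves V only dt/tau = 1% of the way toward V_rest + R*I,
# so starting from rest this single update is expected to report 0 spikes.
print(f"Number of spikes: {jnp.sum(spikes)}")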
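
A single update from rest cannot produce a spike, because each Euler step moves the potential only a fraction `dt / tau` of the way toward its steady state `V_rest + R * I`. The neuron fires only if that steady state lies above threshold, that is, if `I > (V_th - V_rest) / R` (20 with the defaults above). The following plain-Python sketch mirrors the update rule of `LIFNeuron` for one neuron and a constant current; `simulate_lif` is a hypothetical helper written for this illustration.

```python
def simulate_lif(I, num_steps, tau=10.0, V_rest=-70.0, V_th=-50.0,
                 V_reset=-70.0, R=1.0, dt=0.1):
    """Forward-Euler LIF with constant input; returns the number of spikes."""
    V = V_rest
    n_spikes = 0
    for _ in range(num_steps):
        # tau * dV/dt = -(V - V_rest) + R * I
        V += dt * (-(V - V_rest) + R * I) / tau
        if V >= V_th:
            n_spikes += 1
            V = V_reset
    return n_spikes

rheobase = (-50.0 - (-70.0)) / 1.0  # minimum constant current that can reach threshold

weak = simulate_lif(I=10.0, num_steps=5000)    # steady state -60 stays below threshold
strong = simulate_lif(I=30.0, num_steps=5000)  # steady state -40 lies above threshold
print(f"rheobase={rheobase}, spikes at I=10: {weak}, spikes at I=30: {strong}")
```

To see spikes in the cell above, the neuron therefore needs both a stronger current and many update steps, not a single call.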

### 1.3 Feature Comparison Table

In [None]:
# Create comparison table
comparison = [
    ("Feature", "BrainState", "BrainPy"),
    ("-" * 30, "-" * 25, "-" * 25),
    ("State Management", "✓ Core feature", "✓ Built on BrainState"),
    ("Neural Network Layers", "✓ Basic layers", "✓ + Specialized layers"),
    ("JAX Transformations", "✓ Direct access", "✓ Via BrainState"),
    ("Spiking Neural Networks", "○ Manual impl.", "✓ Built-in models"),
    ("Neuronal Models", "○ Manual impl.", "✓ HH, Izhikevich, etc."),
    ("Synaptic Plasticity", "○ Manual impl.", "✓ STDP, BCM, etc."),
    ("Network Connectivity", "○ Manual impl.", "✓ Built-in patterns"),
    ("ODE Solvers", "○ Basic", "✓ Advanced solvers"),
    ("Deep Learning Focus", "✓ Primary", "○ Secondary"),
    ("Neuroscience Focus", "○ Secondary", "✓ Primary"),
    ("Package Size", "Lightweight", "Full-featured"),
    ("Learning Curve", "Moderate", "Steeper"),
]

print("BrainState vs BrainPy Feature Comparison")
print("=" * 80)
for row in comparison:
    print(f"{row[0]:<30} {row[1]:<25} {row[2]:<25}")

print("\nLegend: ✓ = Full support, ○ = Partial/Manual implementation")

## 2. Shared State Management

Because BrainPy is built on top of BrainState, models in both frameworks declare their variables with the same state types.

In [None]:
print("Shared State Types")
print("=" * 60)

# State types used by both frameworks
state_types = [
    ("ParamState", "Trainable parameters (weights, biases)", "Both"),
    ("ShortTermState", "Temporary state (hidden states, membrane potentials)", "Both"),
    ("LongTermState", "Accumulated state (running stats, counters)", "Both"),
    ("HiddenState", "Internal hidden variables", "Both"),
]

# Column widths are sized to the longest description so the table stays aligned
print(f"{'State Type':<20} {'Description':<55} {'Used By':<10}")
print("-" * 87)
for state_type, desc, used in state_types:
    print(f"{state_type:<20} {desc:<55} {used:<10}")

# Example: Same state system
class HybridModel(bst.graph.Node):
    """Model using shared state types."""
    
    def __init__(self, size):
        super().__init__()
        # Parameters (trainable)
        self.W = bst.ParamState(bst.random.randn(size, size) * 0.1)
        
        # Short-term state (reset per episode)
        self.activity = bst.ShortTermState(jnp.zeros(size))
        
        # Long-term state (accumulated)
        self.total_spikes = bst.LongTermState(jnp.array(0))
    
    def __call__(self, x):
        # Compute new activity
        new_activity = jax.nn.tanh(self.W.value @ x + self.activity.value)
        
        # Update states
        self.activity.value = new_activity
        self.total_spikes.value = self.total_spikes.value + jnp.sum(new_activity > 0.5)
        
        return new_activity

model = HybridModel(10)
x = bst.random.randn(10)
y = model(x)

print(f"\nState counts:")
print(f"  ParamState: {len(model.states(bst.ParamState))}")
print(f"  ShortTermState: {len(model.states(bst.ShortTermState))}")
print(f"  LongTermState: {len(model.states(bst.LongTermState))}")

## 3. API Interoperability

### 3.1 Common Base: bst.graph.Node

Both frameworks use `bst.graph.Node` as the base class.

In [None]:
# BrainState layer
class BrainStateLayer(bst.graph.Node):
    def __init__(self, dim):
        super().__init__()
        self.linear = bst.nn.Linear(dim, dim)
    
    def __call__(self, x):
        return jax.nn.relu(self.linear(x))

# BrainPy-style dynamics layer
class BrainPyStyleLayer(bst.nn.Dynamics):
    def __init__(self, dim, tau=10.0):
        super().__init__()
        self.linear = bst.nn.Linear(dim, dim)
        self.tau = tau
        self.state = bst.ShortTermState(jnp.zeros(dim))
    
    def __call__(self, x):
        # Forward-Euler step of tau * dstate/dt = -state + linear(x), with dt = 0.1
        dx = (-self.state.value + self.linear(x)) / self.tau
        self.state.value = self.state.value + dx * 0.1
        return self.state.value

# Both can be combined!
class CombinedModel(bst.graph.Node):
    def __init__(self, dim):
        super().__init__()
        self.bst_layer = BrainStateLayer(dim)
        self.bp_layer = BrainPyStyleLayer(dim)
    
    def __call__(self, x):
        x = self.bst_layer(x)  # BrainState processing
        x = self.bp_layer(x)   # BrainPy-style dynamics
        return x

# Test combined model
combined = CombinedModel(20)
x = bst.random.randn(5, 20)
output = combined(x)
print(f"Combined model output shape: {output.shape}")
print(f"Total parameters: {sum(p.value.size for p in combined.states(bst.ParamState).values())}")

### 3.2 Shared Transformations

Both frameworks use the same JAX transformations.

In [None]:
print("Shared JAX Transformations")
print("=" * 60)

# JIT compilation works for both
model = CombinedModel(10)
x_test = bst.random.randn(3, 10)

# JIT compile
@bst.transform.jit
def forward_jit(x):
    return model(x)

output_jit = forward_jit(x_test)
print(f"JIT output shape: {output_jit.shape}")

# Gradient computation
def loss_fn(x):
    return jnp.sum(model(x) ** 2)

params = model.states(bst.ParamState)
# With return_value=True the transformed function returns (grads, loss), in that order
grads, loss = bst.transform.grad(loss_fn, grad_states=params, return_value=True)(x_test)
print(f"\nLoss: {loss:.4f}")
print(f"Gradient keys: {list(grads.keys())[:3]}...")  # Show first 3

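# Vectorization
# Caution: jax.vmap does not track BrainState states. This model writes to a
# ShortTermState on every call, so for stateful modules prefer a state-aware
# vmap if your BrainState version provides one.
def process_single(x_single):
    return jnp.sum(model(x_single.reshape(1, -1)))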

batch_vmap = jax.vmap(process_single)
results = batch_vmap(x_test)
print(f"\nvmap results shape: {results.shape}")

## 4. Building Hybrid Models

Combine deep learning (BrainState) with brain dynamics (BrainPy-style).

In [None]:
class SpikingCNN(bst.graph.Node):
    """
    Hybrid model: CNN feature extraction + Spiking output layer.
    Demonstrates combining traditional DL with neuromorphic computing.
    """
    
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Standard CNN layers (BrainState)
        self.conv1 = bst.nn.Conv2d(1, 32, kernel_size=(3, 3), padding='SAME')
        self.conv2 = bst.nn.Conv2d(32, 64, kernel_size=(3, 3), padding='SAME')
        self.fc = bst.nn.Linear(64 * 7 * 7, 256)
        
        # Spiking output layer (BrainPy-style)
        self.spiking_out = LIFNeuron(
            size=num_classes,
            tau=10.0,
            V_th=-50.0,
            V_reset=-70.0
        )
        
        # Readout weights
        self.readout = bst.nn.Linear(256, num_classes)
    
    def __call__(self, x: jnp.ndarray, return_spikes: bool = False) -> jnp.ndarray:
        """
        Forward pass.
        
        Args:
            x: Input image (batch, channels, height, width)
            return_spikes: If True, return spike trains
            
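        Returns:
            Output logits, or spike counts accumulated over 10 time steps.
            In spike mode only the first sample of the batch is simulated;
            the rows for the other samples stay zero.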
        """
        # CNN feature extraction
        x = jax.nn.relu(self.conv1(x))
        x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        
        x = jax.nn.relu(self.conv2(x))
        x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        
        # Flatten and FC
        x = x.reshape(x.shape[0], -1)
        x = jax.nn.relu(self.fc(x))
        
        if return_spikes:
            # Convert to current and simulate spiking
            current = self.readout(x)
            spike_counts = jnp.zeros_like(current)
            
            # Simulate for multiple time steps
            for t in range(10):
                spikes = self.spiking_out(current[0])  # Process first sample
                spike_counts = spike_counts.at[0].add(spikes)
            
            return spike_counts
        else:
            # Standard output
            return self.readout(x)

# Create and test hybrid model
hybrid_model = SpikingCNN(num_classes=10)
test_image = bst.random.randn(1, 1, 28, 28)

# Standard forward pass
logits = hybrid_model(test_image, return_spikes=False)
print(f"Logits shape: {logits.shape}")
print(f"Logits: {logits[0]}")

# Spiking output
spike_counts = hybrid_model(test_image, return_spikes=True)
print(f"\nSpike counts shape: {spike_counts.shape}")
print(f"Spike counts: {spike_counts[0]}")
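
The input size of the first fully connected layer (`64 * 7 * 7`) follows from the shapes: `'SAME'` padding with stride 1 preserves height and width, and each 2x2 pooling with stride 2 halves them. The sketch below reproduces that arithmetic in plain Python. `flattened_features` is a hypothetical helper, and it assumes stride-1 convolutions and spatial sizes that are divisible by two at every pooling stage.

```python
def flattened_features(image_size: int, out_channels: int, num_pools: int) -> int:
    """Flattened feature count after `num_pools` rounds of SAME conv + 2x2/stride-2 pooling."""
    size = image_size
    for _ in range(num_pools):
        if size % 2 != 0:
            raise ValueError(f"spatial size {size} is not divisible by 2")
        size //= 2  # SAME conv keeps the size; pooling halves it
    return out_channels * size * size

print(flattened_features(28, 64, 2))   # 28 -> 14 -> 7, so 64 * 7 * 7 = 3136
print(flattened_features(32, 128, 2))  # 32 -> 16 -> 8, so 128 * 8 * 8 = 8192
```

Changing the input resolution or the number of pooling stages means the `Linear` input size has to be recomputed the same way.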

## 5. Use Case Scenarios

### 5.1 Scenario 1: Image Classification (BrainState)

In [None]:
print("Use Case 1: Standard Image Classification")
print("=" * 60)
print("Framework: BrainState")
print("Reason: Standard deep learning task, no brain dynamics needed\n")

class ImageClassifier(bst.graph.Node):
    """Standard CNN for image classification."""
    
    def __init__(self):
        super().__init__()
        self.conv1 = bst.nn.Conv2d(3, 64, kernel_size=(3, 3), padding='SAME')
        self.conv2 = bst.nn.Conv2d(64, 128, kernel_size=(3, 3), padding='SAME')
        self.fc1 = bst.nn.Linear(128 * 8 * 8, 256)
        self.fc2 = bst.nn.Linear(256, 10)
    
    def __call__(self, x):
        x = jax.nn.relu(self.conv1(x))
        x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = jax.nn.relu(self.conv2(x))
        x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape(x.shape[0], -1)
        x = jax.nn.relu(self.fc1(x))
        return self.fc2(x)

classifier = ImageClassifier()
sample = bst.random.randn(4, 3, 32, 32)  # CIFAR-like
preds = classifier(sample)
print(f"Predictions shape: {preds.shape}")
print(f"Parameters: {sum(p.value.size for p in classifier.states(bst.ParamState).values()):,}")

### 5.2 Scenario 2: Spiking Network Simulation (BrainPy-style)

In [None]:
print("\nUse Case 2: Spiking Neural Network Simulation")
print("=" * 60)
print("Framework: BrainPy (demonstrated with BrainState)")
print("Reason: Need biological neuron models and spike dynamics\n")

class SpikingNetwork(bst.graph.Node):
    """Simple spiking neural network."""
    
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        
        # Synaptic weights
        self.W_in = bst.ParamState(bst.random.randn(input_size, hidden_size) * 0.1)
        self.W_out = bst.ParamState(bst.random.randn(hidden_size, output_size) * 0.1)
        
        # Neuron populations
        self.hidden_neurons = LIFNeuron(hidden_size, tau=10.0)
        self.output_neurons = LIFNeuron(output_size, tau=20.0)
    
    def __call__(self, spike_input: jnp.ndarray, num_steps: int = 100) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Simulate network for multiple time steps.
        
        Args:
            spike_input: Input spike train (time, input_size)
            num_steps: Number of simulation steps
            
        Returns:
            Hidden and output spike trains
        """
        hidden_spikes = []
        output_spikes = []
        
        for t in range(num_steps):
            # Get input for this timestep
            if t < len(spike_input):
                inp = spike_input[t]
            else:
                inp = jnp.zeros(spike_input.shape[1])
            
            # Hidden layer
            hidden_current = inp @ self.W_in.value
            h_spikes = self.hidden_neurons(hidden_current)
            hidden_spikes.append(h_spikes)
            
            # Output layer
            output_current = h_spikes @ self.W_out.value
            o_spikes = self.output_neurons(output_current)
            output_spikes.append(o_spikes)
        
        return jnp.stack(hidden_spikes), jnp.stack(output_spikes)

# Create spiking network
snn = SpikingNetwork(input_size=50, hidden_size=100, output_size=10)

# Generate random spike input
spike_input = (bst.random.rand(50, 50) < 0.1).astype(jnp.float32)  # Sparse spikes

# Simulate
hidden_activity, output_activity = snn(spike_input, num_steps=50)

print(f"Hidden activity shape: {hidden_activity.shape}")
print(f"Output activity shape: {output_activity.shape}")
print(f"Total output spikes: {jnp.sum(output_activity)}")

# Visualize spike raster
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))

# Hidden layer spikes
spike_times, spike_neurons = jnp.where(hidden_activity > 0.5)
ax1.scatter(spike_times, spike_neurons, s=1, c='black', alpha=0.5)
ax1.set_ylabel('Neuron Index')
ax1.set_title('Hidden Layer Spike Raster')
ax1.set_xlim(0, 50)

# Output layer spikes
spike_times, spike_neurons = jnp.where(output_activity > 0.5)
ax2.scatter(spike_times, spike_neurons, s=5, c='red', alpha=0.7)
ax2.set_xlabel('Time Step')
ax2.set_ylabel('Neuron Index')
ax2.set_title('Output Layer Spike Raster')
ax2.set_xlim(0, 50)

plt.tight_layout()
plt.show()
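
Whether the raster plots show any spikes depends on how the input current compares with the threshold current of the neurons. With Bernoulli input spikes of probability `p`, `n` inputs, and zero-mean Gaussian weights of standard deviation `sigma`, the summed current has mean 0 and standard deviation `sigma * sqrt(n * p)`. The plain-Python sketch below evaluates this for the settings above and compares it with the rheobase `(V_th - V_rest) / R` of `LIFNeuron`. The independence of spikes and weights is an assumption of this estimate, and `input_current_std` is a hypothetical helper.

```python
import math

def input_current_std(n_inputs: int, spike_prob: float, weight_std: float) -> float:
    """Std of sum_i s_i * w_i with s_i ~ Bernoulli(p) and w_i ~ N(0, weight_std^2), independent."""
    return weight_std * math.sqrt(n_inputs * spike_prob)

current_std = input_current_std(n_inputs=50, spike_prob=0.1, weight_std=0.1)
rheobase = (-50.0 - (-70.0)) / 1.0  # (V_th - V_rest) / R for the LIFNeuron defaults

print(f"typical hidden current: {current_std:.3f}, needed to spike: {rheobase:.1f}")
# Gain that would bring a typical current up to the rheobase
print(f"suggested weight gain: {rheobase / current_std:.0f}x")
```

Under these assumptions the hidden current is roughly two orders of magnitude below what the neurons need, so the example network is expected to stay silent unless the weights or the input are scaled up.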

### 5.3 Scenario 3: Hybrid Model (Both)

In [None]:
print("Use Case 3: Neuromorphic Vision System")
print("=" * 60)
print("Framework: Both BrainState + BrainPy")
print("Reason: CNN feature extraction + spiking decision making\n")

class NeuromorphicVision(bst.graph.Node):
    """Vision system with CNN features and spiking decision layer."""
    
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Feature extraction (BrainState - standard DL)
        self.features = bst.nn.Sequential(
            bst.nn.Conv2d(1, 32, kernel_size=(3, 3), padding='SAME'),
            bst.nn.Conv2d(32, 64, kernel_size=(3, 3), padding='SAME'),
        )
        
        # Feature to current converter
        self.fc_current = bst.nn.Linear(64 * 7 * 7, 100)
        
        # Decision layer (BrainPy-style - spiking)
        self.decision = LIFNeuron(size=num_classes, tau=10.0)
        
        # Current weights to decision neurons
        self.W_decision = bst.ParamState(bst.random.randn(100, num_classes) * 0.1)
    
    def extract_features(self, x: jnp.ndarray) -> jnp.ndarray:
        """Extract features using CNN."""
        for layer in self.features.children().values():
            x = jax.nn.relu(layer(x))
            x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x.reshape(x.shape[0], -1)
    
    def decide(self, features: jnp.ndarray, num_steps: int = 20) -> jnp.ndarray:
        """Make decision using spiking neurons."""
        # Convert features to current
        current_base = self.fc_current(features)
        
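        # Project the features onto the decision neurons once; the drive is
        # constant across time steps (no noise is added here)
        current = current_base[0] @ self.W_decision.value  # first sample only
        
        # Accumulate spikes over time, one counter per decision neuron
        spike_counts = jnp.zeros(current.shape[0])
        
        for t in range(num_steps):
            spikes = self.decision(current)
            spike_counts = spike_counts + spikes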
        
        return spike_counts
    
    def __call__(self, x: jnp.ndarray, mode: str = 'spike') -> jnp.ndarray:
        """
        Forward pass.
        
        Args:
            x: Input image
            mode: 'spike' for spike-based, 'rate' for rate-based
        """
        features = self.extract_features(x)
        
        if mode == 'spike':
            return self.decide(features, num_steps=20)
        else:
            # Rate-based (standard)
            current = self.fc_current(features)
            return current @ self.W_decision.value

# Test neuromorphic vision
neuro_vision = NeuromorphicVision(num_classes=10)
test_img = bst.random.randn(1, 1, 28, 28)

# Spike-based decision
spike_output = neuro_vision(test_img, mode='spike')
print(f"Spike-based output (counts): {spike_output}")
print(f"Predicted class: {jnp.argmax(spike_output)}")

# Rate-based decision
rate_output = neuro_vision(test_img, mode='rate')
print(f"\nRate-based output: {rate_output[0]}")
print(f"Predicted class: {jnp.argmax(rate_output[0])}")

## 6. Migration Between Frameworks

### 6.1 Converting BrainState to BrainPy-style

In [None]:
print("Converting BrainState Model to BrainPy-style")
print("=" * 60)

# Original BrainState model
class OriginalBrainState(bst.graph.Node):
    def __init__(self, dim):
        super().__init__()
        self.fc = bst.nn.Linear(dim, dim)
    
    def __call__(self, x):
        return jax.nn.relu(self.fc(x))

# Converted to BrainPy-style with dynamics
class ConvertedBrainPyStyle(bst.nn.Dynamics):
    def __init__(self, dim, tau=10.0, dt=0.1):
        super().__init__()
        self.fc = bst.nn.Linear(dim, dim)
        self.tau = tau
        self.dt = dt
        
        # Add membrane potential dynamics
        self.V = bst.ShortTermState(jnp.zeros(dim))
    
    def __call__(self, x):
        # Compute target activity
        target = jax.nn.relu(self.fc(x))
        
        # Integrate with first-order dynamics
        dV = (target - self.V.value) / self.tau
        self.V.value = self.V.value + dV * self.dt
        
        return self.V.value

# Compare
dim = 10
x = bst.random.randn(5, dim)

orig = OriginalBrainState(dim)
conv = ConvertedBrainPyStyle(dim)

# Warm-up call (note: this already advances conv.V by one Euler step)
_ = orig(x)
_ = conv(x)

# Copy weights
orig_params = orig.states(bst.ParamState)
conv_params = conv.states(bst.ParamState)
for k in conv_params:
    if k in orig_params:
        conv_params[k].value = orig_params[k].value

# Run both
out_orig = orig(x)
out_conv = conv(x)

print(f"Original output (instant): {out_orig[0, :5]}")
print(f"Converted output (dynamics): {out_conv[0, :5]}")
print("\nNote: Converted version has temporal dynamics!")
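
How far the converted output lags behind the instantaneous one can be worked out in closed form. For a constant target, each Euler step multiplies the remaining gap by `1 - dt / tau`, so after `n` steps starting from zero the state equals `target * (1 - (1 - dt / tau) ** n)`. The sketch below checks this against the step-by-step update and estimates how many calls are needed to reach 95% of the target. It is plain Python, and the function names are illustrative.

```python
import math

def fraction_reached(n_steps: int, tau: float = 10.0, dt: float = 0.1) -> float:
    """Closed-form fraction of a constant target reached after n Euler steps from zero."""
    return 1.0 - (1.0 - dt / tau) ** n_steps

def steps_to_fraction(fraction: float, tau: float = 10.0, dt: float = 0.1) -> int:
    """Smallest n with fraction_reached(n) >= fraction."""
    return math.ceil(math.log(1.0 - fraction) / math.log(1.0 - dt / tau))

# Iterate the same update used in ConvertedBrainPyStyle with target = 1
V = 0.0
for _ in range(100):
    V += (1.0 - V) / 10.0 * 0.1

print(f"after 1 step: {fraction_reached(1):.3f}")
print(f"after 100 steps: {V:.3f} (closed form {fraction_reached(100):.3f})")
print(f"steps to reach 95%: {steps_to_fraction(0.95)}")
```

With the defaults, one call moves the state only 1% of the way to the target, which is why the converted output printed above is much smaller than the original.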

## 7. Best Practices

### 7.1 Decision Guide

In [None]:
print("Framework Selection Decision Tree")
print("=" * 80)

decision_tree = """
START: What are you building?
│
├─ Standard ML/DL model (CNN, Transformer, etc.)
│  └─> Use BrainState ✓
│
├─ Brain-inspired but rate-based neural network
│  └─> Use BrainState ✓
│
├─ Spiking neural network
│  ├─ Simple LIF neurons
│  │  └─> Either (BrainState can do it)
│  └─ Complex neuron models (HH, Izhikevich)
│     └─> Use BrainPy ✓✓
│
├─ Computational neuroscience simulation
│  └─> Use BrainPy ✓✓
│
├─ Need synaptic plasticity (STDP, BCM)
│  └─> Use BrainPy ✓✓
│
└─ Hybrid: DL features + brain dynamics
   └─> Use Both BrainState + BrainPy ✓✓
"""

print(decision_tree)

print("\nKey Questions to Ask:")
questions = [
    "1. Do I need biologically realistic neuron models? → BrainPy",
    "2. Am I doing standard deep learning? → BrainState",
    "3. Do I need spike timing dependent plasticity? → BrainPy",
    "4. Am I building a simple feedforward network? → BrainState",
    "5. Do I need differential equation solvers? → BrainPy",
    "6. Am I prototyping quickly? → BrainState (simpler)",
    "7. Is this for neuroscience research? → BrainPy",
]

for q in questions:
    print(f"  {q}")
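
The decision tree and questions above can be condensed into a small helper. This is one possible encoding of the guidance in this tutorial, not an official rule; `choose_framework` and its arguments are hypothetical names.

```python
def choose_framework(spiking: bool = False,
                     complex_neurons: bool = False,
                     plasticity: bool = False,
                     neuroscience_simulation: bool = False,
                     deep_learning_features: bool = False) -> str:
    """Map project requirements to a framework suggestion, following the decision tree."""
    needs_brainpy = complex_neurons or plasticity or neuroscience_simulation
    if needs_brainpy and deep_learning_features:
        return "BrainState + BrainPy"
    if needs_brainpy:
        return "BrainPy"
    if spiking and deep_learning_features:
        return "BrainState + BrainPy"
    if spiking:
        return "Either"  # simple LIF networks can be written directly in BrainState
    return "BrainState"

print(choose_framework())                     # BrainState
print(choose_framework(spiking=True))         # Either
print(choose_framework(plasticity=True))      # BrainPy
print(choose_framework(spiking=True, deep_learning_features=True))  # BrainState + BrainPy
```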

### 7.2 Performance Tips

In [None]:
print("\nPerformance Tips for Both Frameworks")
print("=" * 60)

tips = [
    ("JIT Compilation", "JIT-compile hot paths with @bst.transform.jit", "Both"),
    ("Vectorization", "Use vmap over batches instead of Python loops", "Both"),
    ("State Management", "Use appropriate state types (Param vs ShortTerm)", "Both"),
    ("Memory", "Reset ShortTermState between episodes", "Both"),
    ("Gradient Computation", "Specify grad_states to avoid unnecessary grads", "Both"),
    ("Batch Processing", "Process multiple samples simultaneously", "Both"),
    ("Time Steps", "For spiking: Balance accuracy vs speed", "BrainPy"),
    ("Connectivity", "Use sparse matrices for large networks", "BrainPy"),
]

# Column widths are sized to the longest tip so the table stays aligned
print(f"{'Aspect':<20} {'Tip':<50} {'Framework':<10}")
print("-" * 82)
for aspect, tip, framework in tips:
    print(f"{aspect:<20} {tip:<50} {framework:<10}")
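
The time-step tip can be made concrete: forward Euler, as used in `LIFNeuron`, has an error that shrinks roughly in proportion to `dt`, while the number of steps grows as `1 / dt`. The plain-Python sketch below integrates the subthreshold membrane equation up to `t = tau` for three step sizes and compares the result with the exact solution `V_rest + R * I * (1 - exp(-t / tau))`. The constant current `I = 10` is an arbitrary choice that keeps the neuron below threshold, and `euler_error` is a hypothetical helper.

```python
import math

def euler_error(dt, I=10.0, tau=10.0, V_rest=-70.0, R=1.0, t_end=10.0):
    """Absolute error of forward Euler at t_end for tau * dV/dt = -(V - V_rest) + R * I."""
    n_steps = round(t_end / dt)
    V = V_rest
    for _ in range(n_steps):
        V += dt * (-(V - V_rest) + R * I) / tau
    exact = V_rest + R * I * (1.0 - math.exp(-t_end / tau))
    return abs(V - exact), n_steps

for dt in (1.0, 0.1, 0.01):
    err, n_steps = euler_error(dt)
    print(f"dt={dt:<5} steps={n_steps:<5} error={err:.5f}")
```

Each tenfold reduction in `dt` buys about one more digit of accuracy at ten times the cost, so the step size is worth tuning against the precision the model actually needs.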

## Summary

### Key Takeaways:

1. **Complementary Frameworks**:
   - BrainState: Foundation for state management and basic DL
   - BrainPy: Advanced brain modeling built on BrainState

2. **Shared Foundation**:
   - Same state management system
   - Same base class (bst.graph.Node)
   - Compatible JAX transformations

3. **When to Use Each**:
   - BrainState: Standard ML/DL, simple models, rapid prototyping
   - BrainPy: Brain simulation, SNNs, neuroscience research
   - Both: Hybrid models combining DL with brain dynamics

4. **Interoperability**:
   - Layers from both can be composed inside a single `bst.graph.Node`
   - Share parameters and states
   - Use same training infrastructure

5. **Best Practices**:
   - Start with BrainState for simplicity
   - Add BrainPy when you need biological realism
   - Use JIT and vmap for performance in both

## Next Steps

- Explore BrainPy documentation for advanced features
- Experiment with hybrid models
- Try converting existing models between frameworks
- Build neuromorphic applications

For more information:
- [BrainState Documentation](https://brainstate.readthedocs.io/)
- [BrainPy Documentation](https://brainpy.readthedocs.io/)