# Tutorial 27: Frequently Asked Questions (FAQ)

In this tutorial, we'll address common questions and issues when working with BrainState.

## Learning Objectives

By the end of this tutorial, you will:
- Understand common pitfalls and their solutions
- Learn how to interpret error messages
- Know where to find help and resources
- Understand best practices
- Get answers to frequently asked questions

## Introduction

This FAQ covers:
- **General Questions**: About BrainState and JAX
- **Installation Issues**: Setup and dependencies
- **API Usage**: Common usage patterns
- **Error Messages**: Understanding and fixing errors
- **Performance**: Optimization tips
- **Community Resources**: Where to get help

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np

bst.random.seed(42)

## 1. General Questions

### Q1.1: What is BrainState?

**A:** BrainState is a state management framework built on JAX for building neural models. It provides:
- Explicit state management (ParamState, ShortTermState, etc.)
- Neural network layers and utilities
- Graph-based model composition
- JAX transformation wrappers

In [None]:
print("Q1.1: What is BrainState?")
print("=" * 60)

print("""
BrainState is designed for:
  ✓ Brain-inspired neural network models
  ✓ Spiking neural networks
  ✓ Dynamical systems
  ✓ Standard deep learning models

Built on JAX, it provides:
  ✓ Automatic differentiation
  ✓ JIT compilation
  ✓ Vectorization (vmap)
  ✓ GPU/TPU acceleration
""")

# Simple example
class SimpleNN(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.fc = bst.nn.Linear(10, 5)
    
    def __call__(self, x):
        return self.fc(x)

model = SimpleNN()
x = bst.random.randn(2, 10)
output = model(x)
print(f"Example output shape: {output.shape}")

### Q1.2: How is BrainState different from PyTorch/TensorFlow?

**A:** Key differences:

In [None]:
print("Q1.2: BrainState vs PyTorch/TensorFlow")
print("=" * 60)

comparison = [
    ("Feature", "PyTorch/TF", "BrainState/JAX"),
    ("-" * 20, "-" * 20, "-" * 20),
    ("Paradigm", "Imperative", "Functional"),
    ("State", "Mutable", "Explicit management"),
    ("Arrays", "Mutable tensors", "Immutable arrays"),
    ("Gradients", ".backward()", "jax.grad()"),
    ("JIT", "TorchScript/XLA", "Built-in JAX JIT"),
    ("Random Numbers", "Global state", "Explicit PRNG keys"),
    ("Device", "Manual .to()", "Automatic"),
    ("Focus", "General DL", "Brain modeling + DL"),
]

for row in comparison:
    print(f"{row[0]:<20} {row[1]:<20} {row[2]:<20}")

print("""
Choose BrainState when:
  ✓ Building brain-inspired models
  ✓ Need functional programming
  ✓ Want composable transformations
  ✓ Require explicit state control
  ✓ Working with JAX ecosystem
""")

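Three rows of the comparison above (immutable arrays, explicit PRNG keys, functional gradients) are easier to see in code. The following is a minimal sketch in plain JAX only; it makes no assumption about how BrainState wraps these calls internally.

```python
import jax
import jax.numpy as jnp

# Arrays are immutable: .at[...].set(...) returns a new array.
x = jnp.zeros(3)
y = x.at[0].set(1.0)

# Randomness is explicit: the same key always yields the same numbers,
# so keys are split to obtain independent streams.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (2,))
b = jax.random.normal(subkey, (2,))  # same key as `a`, so identical values

# Gradients come from transforming a function, not from calling .backward().
grad_fn = jax.grad(lambda w: jnp.sum(w ** 2))
g = grad_fn(jnp.array([1.0, 2.0]))
```

Here `x` is left unchanged by the update that produced `y`, and `a` equals `b` because both draws reuse `subkey`.
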
### Q1.3: Do I need to know JAX to use BrainState?

**A:** Basic JAX knowledge is helpful but not required initially.

In [None]:
print("Q1.3: Do I need to know JAX?")
print("=" * 60)

print("""
Required JAX Knowledge:
  ✓ Basic: jnp arrays, basic operations
  ✓ Intermediate: jax.jit, jax.grad
  ○ Advanced: vmap, pmap, lax operations

BrainState abstracts many JAX details:
  ✓ State management
  ✓ Random numbers (brainstate.random handles PRNG keys)
  ✓ Neural network layers
  ✓ Gradient computation helpers

Start with BrainState, learn JAX as needed!
""")

# Example: You can use BrainState without deep JAX knowledge
model = bst.nn.Sequential(
    bst.nn.Linear(10, 20),
    bst.nn.Linear(20, 5)
)

x = bst.random.randn(3, 10)
output = model(x)
print(f"\nSimple usage without JAX details: {output.shape}")
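
For readers who want to see the "intermediate" and "advanced" JAX items from the list above in one place, here is a small sketch using plain JAX. The `loss` function is a made-up stand-in, not part of BrainState.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Squared norm of a linear map; a stand-in for a real loss.
    return jnp.sum((x @ w) ** 2)

w = jnp.ones((3, 2))
x = jnp.arange(3.0)

fast_loss = jax.jit(loss)                    # compiled once per input shape
dloss_dw = jax.grad(loss)(w, x)              # gradient w.r.t. the first argument
batched = jax.vmap(loss, in_axes=(None, 0))  # share w, map over a batch of x

xs = jnp.stack([x, 2 * x])
per_example = batched(w, xs)                 # one loss value per row of xs
```

Each transformation takes a function and returns a new function, which is why they compose freely (for example `jax.jit(jax.vmap(...))`).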

## 2. Installation and Setup

### Q2.1: How do I install BrainState?

In [None]:
print("Q2.1: Installation")
print("=" * 60)

print("""
Installation Methods:

1. From PyPI (stable):
   pip install brainstate

2. From source (development):
   git clone https://github.com/chaobrain/brainstate
   cd brainstate
   pip install -e .

3. With GPU support:
   # First install JAX with CUDA 12 support
   pip install --upgrade "jax[cuda12]"
   # Then install brainstate
   pip install brainstate

Verify installation:
""")

import brainstate as bst
print(f"BrainState version: {bst.__version__}")
print(f"JAX devices: {jax.devices()}")

### Q2.2: Common installation issues

In [None]:
print("Q2.2: Installation Issues")
print("=" * 60)

issues = [
    (
        "ImportError: No module named 'jax'",
        "Solution: pip install jax jaxlib"
    ),
    (
        "CUDA version mismatch",
        "Solution: Install JAX version matching your CUDA version"
    ),
    (
        "'module' object has no attribute '__version__'",
        "Solution: Reinstall brainstate: pip install --upgrade brainstate"
    ),
    (
        "Slow performance on CPU",
        "Solution: JIT-compile hot functions and batch the work; use a GPU/TPU for large models"
    ),
]

for issue, solution in issues:
    print(f"\nIssue: {issue}")
    print(f"  {solution}")

## 3. API Usage Questions

### Q3.1: How do I define a custom layer?

In [None]:
print("Q3.1: Custom Layer")
print("=" * 60)

# Option 1: Inherit from brainstate.graph.Node
class CustomLayer(bst.graph.Node):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Define parameters
        self.weight = bst.ParamState(
            bst.random.randn(in_features, out_features) * 0.01
        )
        self.bias = bst.ParamState(jnp.zeros(out_features))
    
    def __call__(self, x):
        return jnp.matmul(x, self.weight.value) + self.bias.value

# Test custom layer
layer = CustomLayer(10, 5)
x = bst.random.randn(3, 10)
output = layer(x)
print(f"Custom layer output shape: {output.shape}")

# Option 2: Use brainstate.nn.Dynamics for stateful layers
class StatefulLayer(bst.nn.Dynamics):
    def __init__(self, size):
        super().__init__(in_size=size)  # Dynamics expects the input size
        self.weight = bst.ParamState(bst.random.randn(size, size) * 0.01)
        self.state = bst.ShortTermState(jnp.zeros(size))
    
    def __call__(self, x):
        new_state = jnp.tanh(jnp.matmul(x, self.weight.value) + self.state.value)
        self.state.value = new_state
        return new_state

stateful = StatefulLayer(5)
x = bst.random.randn(3, 5)
output = stateful(x)
print(f"Stateful layer output shape: {output.shape}")

### Q3.2: How do I handle different state types?

In [None]:
print("Q3.2: State Types")
print("=" * 60)

print("""
State Types and Usage:

1. ParamState - Trainable parameters
   Use for: Weights, biases
   Example: self.weight = brainstate.ParamState(jnp.zeros((10, 5)))

2. ShortTermState - Temporary state
   Use for: Hidden states, activations
   Example: self.hidden = brainstate.ShortTermState(jnp.zeros(128))
   Reset between episodes/sequences

3. LongTermState - Accumulated state
   Use for: Running statistics, counters
   Example: self.running_mean = brainstate.LongTermState(jnp.zeros(10))
   Persists across training

4. HiddenState - Hidden dynamical state (a kind of ShortTermState)
   Use for: Membrane potentials, recurrent hidden vectors
   Example: self.V = brainstate.HiddenState(jnp.zeros(128))
""")

# Example combining ParamState, ShortTermState, and LongTermState
class MultiStateModel(bst.graph.Node):
    def __init__(self, size):
        super().__init__()
        self.weight = bst.ParamState(bst.random.randn(size, size) * 0.01)
        self.hidden = bst.ShortTermState(jnp.zeros(size))
        self.counter = bst.LongTermState(jnp.array(0))
    
    def __call__(self, x):
        self.counter.value = self.counter.value + 1
        new_hidden = jnp.tanh(jnp.matmul(x, self.weight.value) + self.hidden.value)
        self.hidden.value = new_hidden
        return new_hidden

model = MultiStateModel(5)
x = bst.random.randn(1, 5)
for i in range(3):
    _ = model(x)
    print(f"Step {i+1}: counter = {model.counter.value}")

### Q3.3: How do I save and load models?

In [None]:
print("Q3.3: Save/Load Models")
print("=" * 60)

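# Method 1: Save/load parameters as a NumPy archive
def _key_to_str(key):
    """np.savez needs string names; state keys may be path tuples."""
    return ".".join(map(str, key)) if isinstance(key, tuple) else str(key)

def save_model(model, filepath):
    """Save model parameters."""
    # bst.graph.states collects the states of any graph node
    states = bst.graph.states(model, bst.ParamState)
    state_dict = {_key_to_str(k): np.asarray(v.value) for k, v in states.items()}
    np.savez(filepath, **state_dict)
    print(f"Model saved to {filepath}")

def load_model(model, filepath):
    """Load parameters into an already constructed model."""
    found = bst.graph.states(model, bst.ParamState)
    states = {_key_to_str(k): v for k, v in found.items()}
    with np.load(filepath) as loaded:
        for name in loaded.files:
            if name in states:
                states[name].value = jnp.asarray(loaded[name])
            else:
                print(f"Skipping unknown entry: {name}")
    print(f"Model loaded from {filepath}")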

# Example
model = CustomLayer(10, 5)
x = bst.random.randn(1, 10)
_ = model(x)

# Save
save_model(model, "model.npz")

# Load into new model
new_model = CustomLayer(10, 5)
_ = new_model(x)
load_model(new_model, "model.npz")

print("\nSee Tutorial 22 for complete save/load examples")
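
One detail of the `np.savez` approach is worth isolating: archive entries are named by keyword, so every key must be a string. The sketch below assumes a hypothetical parameter collection keyed by path tuples (the `params` dict and `flatten_key` helper are illustrative, not BrainState API) and shows a round trip using NumPy only.

```python
import os
import tempfile

import numpy as np

def flatten_key(path):
    """Join a path such as ('fc', 'weight') into the string 'fc.weight'."""
    return ".".join(map(str, path)) if isinstance(path, tuple) else str(path)

# Hypothetical parameter collection keyed by path tuples.
params = {("fc", "weight"): np.ones((2, 3)), ("fc", "bias"): np.zeros(3)}

with tempfile.TemporaryDirectory() as tmp:
    filepath = os.path.join(tmp, "params.npz")
    # Keyword names become the entry names inside the archive.
    np.savez(filepath, **{flatten_key(k): v for k, v in params.items()})
    # Read everything back while the file is still open.
    with np.load(filepath) as loaded:
        restored = {name: loaded[name] for name in loaded.files}
```

The same key function has to be used on both the save and the load side, otherwise the names will not match.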

## 4. Common Error Messages

### Q4.1: "ConcretizationTypeError: Abstract tracer value"

In [None]:
print("Q4.1: Tracer Error")
print("=" * 60)

print("""
Error: ConcretizationTypeError: Abstract tracer value encountered

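Cause: Python control flow (if/while) needs a concrete True/False, but inside
JIT the value is an abstract tracer. Recent JAX versions report the `if` case
as TracerBoolConversionError, a subclass of ConcretizationTypeError.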

Example that causes error:
  @jax.jit
  def bad_function(x):
      if x > 0:  # ERROR: Can't use traced x in if!
          return x * 2
      return x * 3

Solutions:
  1. Use jax.lax.cond:
     return jax.lax.cond(x > 0, lambda x: x*2, lambda x: x*3, x)
  
  2. Use jnp.where for element-wise:
     return jnp.where(x > 0, x*2, x*3)
  
  3. Use static_argnums (via functools.partial, which works on all JAX versions):
     @functools.partial(jax.jit, static_argnums=(1,))
     def func(x, condition):
         if condition:  # OK: condition is static
             return x * 2
         return x * 3
""")

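The three fixes listed above can be written as runnable functions. This is a sketch in plain JAX; the function names are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def with_cond(x):
    # Scalar predicate: both branches are traced, one is executed.
    return jax.lax.cond(x > 0, lambda v: v * 2, lambda v: v * 3, x)

@jax.jit
def with_where(x):
    # Element-wise selection; works for arrays of any shape.
    return jnp.where(x > 0, x * 2, x * 3)

@partial(jax.jit, static_argnums=(1,))
def with_static(x, condition):
    # `condition` is an ordinary Python bool while tracing.
    return x * 2 if condition else x * 3

r_cond = with_cond(jnp.float32(-1.0))
r_where = with_where(jnp.array([-1.0, 2.0]))
r_static = with_static(jnp.float32(2.0), True)
```

Note that a static argument must be hashable, and every distinct value triggers a fresh compilation, so it suits flags and configuration rather than data.
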
### Q4.2: "TypeError: Shapes must match"

In [None]:
print("Q4.2: Shape Mismatch")
print("=" * 60)

print("""
Error: TypeError: Shapes must match but got...

Common causes:
  1. Matrix multiplication dimension mismatch
  2. Broadcasting incompatibility
  3. Wrong input shape to layer

Debugging steps:
  1. Print shapes:
     print(f"x.shape = {x.shape}")
     print(f"y.shape = {y.shape}")
  
  2. Check expected vs actual:
     expected_shape = (batch_size, features)
     assert x.shape == expected_shape
  
  3. Use reshape if needed:
     x = x.reshape(batch_size, -1)  # Flatten
  
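  4. Check broadcasting rules (align shapes from the right; sizes must match or be 1):
     # (3, 2) + (4,) -> Error! (2 vs 4)
     # (3, 1) + (4,) -> OK, broadcasts to (3, 4)
     # (3, 4) + (4,) -> OK, broadcasts to (3, 4)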
""")

# Example fix
x = jnp.array([[1, 2, 3]])  # Shape: (1, 3)
y = jnp.array([10, 20, 30])  # Shape: (3,)

print(f"\nx.shape = {x.shape}")
print(f"y.shape = {y.shape}")

result = x + y  # Broadcasting works
print(f"result.shape = {result.shape}")
print(f"result = {result}")
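
Broadcast compatibility can also be checked without allocating any arrays. The helper below is a small sketch built on `jnp.broadcast_shapes`; the name `try_broadcast` is made up for this example.

```python
import jax.numpy as jnp

def try_broadcast(*shapes):
    """Return the broadcast shape, or None if the shapes are incompatible."""
    try:
        return tuple(jnp.broadcast_shapes(*shapes))
    except (ValueError, TypeError):
        return None

ok_trailing = try_broadcast((3, 4), (4,))   # trailing dims match
ok_singleton = try_broadcast((3, 1), (4,))  # size-1 dim stretches
bad = try_broadcast((3, 2), (4,))           # 2 vs 4: incompatible
```

Shapes are aligned from the right, and each pair of dimensions must either be equal or contain a 1.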

### Q4.3: "RuntimeError: Invalid argument"

In [None]:
print("Q4.3: Invalid Argument Error")
print("=" * 60)

print("""
Error: RuntimeError: Invalid argument

Common causes:
  1. NaN or Inf in computation
  2. Division by zero
  3. Invalid operation (e.g., sqrt of negative)
  4. dtype mismatch

Solutions:
  1. Check for NaN/Inf:
     assert not jnp.any(jnp.isnan(x))
     assert not jnp.any(jnp.isinf(x))
  
  2. Enable NaN debugging:
     jax.config.update('jax_debug_nans', True)
  
  3. Add numerical stability:
     # Bad: x / y
     # Better (when y >= 0): x / (y + 1e-8)
  
  4. Clip values:
     x = jnp.clip(x, min_val, max_val)
""")

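Adding a small epsilon only helps when the denominator is non-negative. A sign-agnostic alternative, sketched below under the assumption that a fill value is acceptable where the denominator is zero, masks the denominator before dividing. `safe_divide` and `has_bad_values` are illustrative helper names.

```python
import jax.numpy as jnp

def safe_divide(x, y, fill=0.0):
    """Divide x by y, returning `fill` wherever y is exactly zero."""
    # Replace zeros in the denominator first so no NaN/Inf is ever produced.
    y_safe = jnp.where(y == 0, 1.0, y)
    return jnp.where(y == 0, fill, x / y_safe)

def has_bad_values(x):
    """True if x contains any NaN or Inf."""
    return bool(jnp.any(~jnp.isfinite(x)))

x = jnp.array([1.0, 0.0, -2.0])
y = jnp.array([2.0, 0.0, 0.0])
naive = x / y            # contains NaN (0/0) and -Inf (-2/0)
safe = safe_divide(x, y)
```

Masking before the division, rather than after, also keeps gradients finite if the function is later differentiated.
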
## 5. Performance Questions

### Q5.1: Why is my code slow?

In [None]:
print("Q5.1: Performance Issues")
print("=" * 60)

print("""
Common performance issues:

1. Not using JIT:
   ❌ def slow(x): return x * 2
   ✓ @jax.jit
     def fast(x): return x * 2

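2. Unnecessary recompilation:
   - Changing input shapes or dtypes between calls
   - Passing a new value for a static argument on every call
   Solution: Keep shapes fixed (pad or bucket inputs) and mark only
   rarely-changing arguments with static_argnums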

3. Host-device transfers:
   ❌ for i in range(1000):
       x = np.array(jax_array)  # Slow!
   ✓ Keep data on device

4. Python loops:
   ❌ for i in range(n):
       result = func(result)
   ✓ jax.lax.fori_loop or jax.lax.scan

5. Small batch sizes:
   - GPU/TPU need large batches
   - Benchmark different batch sizes

See Tutorial 25 for detailed optimization guide.
""")

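Item 4 above (replacing Python loops) is sketched below with both `jax.lax.fori_loop` and `jax.lax.scan`. The `step` function is an arbitrary example update, chosen only because it converges to a known value.

```python
import jax
import jax.numpy as jnp

def step(x):
    # Example update with fixed point x = 2.
    return 0.5 * x + 1.0

def python_loop(x, n):
    # Under jit this would be unrolled, so compile time grows with n.
    for _ in range(n):
        x = step(x)
    return x

@jax.jit
def fori(x):
    # One compiled loop body, regardless of the iteration count.
    return jax.lax.fori_loop(0, 100, lambda i, v: step(v), x)

@jax.jit
def scanned(x):
    # scan additionally collects every intermediate value.
    def body(carry, _):
        new = step(carry)
        return new, new
    return jax.lax.scan(body, x, xs=None, length=100)

x0 = jnp.float32(0.0)
a = python_loop(x0, 100)
b = fori(x0)
c, history = scanned(x0)
```

Use `fori_loop` when only the final value matters and `scan` when the trajectory is needed, for example the hidden states of a recurrent model over time.
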
### Q5.2: How do I use GPU/TPU?

In [None]:
print("Q5.2: GPU/TPU Usage")
print("=" * 60)

print("""
GPU/TPU Setup:

1. Check available devices:
""")
print(f"   Devices: {jax.devices()}")

print("""
2. JAX automatically uses best device:
   # No need to manually move tensors!
   x = jnp.array([1, 2, 3])  # Automatically on GPU if available

3. Force specific device:
   with jax.default_device(jax.devices('gpu')[0]):
       x = jnp.array([1, 2, 3])

4. Multi-device with pmap:
   @jax.pmap
   def parallel_func(x):
       return x * 2

5. For GPU, install CUDA-enabled JAX:
   pip install --upgrade "jax[cuda12]"

6. Verify GPU usage:
   - Check nvidia-smi during execution
   - Use JAX profiler
""")

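To confirm where an array actually lives, the sketch below uses only plain JAX calls and runs on CPU as well as on accelerators.

```python
import jax
import jax.numpy as jnp

devices = jax.devices()                   # devices of the default backend
x = jnp.arange(4.0)                       # created on the default device
x_placed = jax.device_put(x, devices[0])  # explicit placement

# jax.Array.devices() returns the set of devices holding the array.
print(f"backend: {jax.default_backend()}, x on: {x_placed.devices()}")

# JAX dispatches work asynchronously; block_until_ready() waits for the
# result, so call it before stopping a timer when benchmarking.
y = (x_placed * 2).block_until_ready()
```

If the printed backend is `cpu` on a machine with a GPU, the CUDA-enabled JAX build is most likely not installed.
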
## 6. Best Practices

### Q6.1: What are BrainState best practices?

In [None]:
print("Q6.1: Best Practices")
print("=" * 60)

best_practices = [
    ("1. State Management", [
        "Use appropriate state types (Param, ShortTerm, LongTerm)",
        "Reset ShortTermState between episodes",
        "Initialize models with dummy input",
    ]),
    ("2. Performance", [
        "JIT-compile hot paths (brainstate.transform.jit when states are read or written)",
        "Use vmap instead of Python loops",
        "Profile before optimizing",
        "Batch operations when possible",
    ]),
    ("3. Debugging", [
        "Print shapes early and often",
        "Use jax.debug.print inside JIT",
        "Enable NaN checking during development",
        "Disable JIT temporarily for debugging",
    ]),
    ("4. Code Organization", [
        "Inherit from brainstate.graph.Node for models",
        "Use brainstate.nn.Dynamics for stateful layers",
        "Keep models modular and composable",
        "Document expected shapes",
    ]),
    ("5. Gradients", [
        "Use brainstate.transform.grad with grad_states",
        "Monitor gradient magnitudes",
        "Implement gradient clipping if needed",
        "Check for NaN/Inf in gradients",
    ]),
]

for category, practices in best_practices:
    print(f"\n{category}:")
    for practice in practices:
        print(f"  ✓ {practice}")
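
The gradient practices above (monitoring magnitudes and clipping) can be sketched in plain JAX on a parameter pytree. The helper names and the toy `loss` are illustrative; this is not the BrainState gradient API, and Optax offers an equivalent ready-made transformation.

```python
import jax
import jax.numpy as jnp

def global_norm(grads):
    """L2 norm over every leaf of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))

def clip_by_global_norm(grads, max_norm):
    """Rescale the whole pytree so its global norm is at most max_norm."""
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

def loss(params, x):
    return jnp.sum((x @ params["w"] + params["b"]) ** 2)

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
grads = jax.grad(loss)(params, jnp.ones((4, 3)) * 10.0)
clipped = clip_by_global_norm(grads, max_norm=1.0)
```

Logging `global_norm(grads)` each step is a cheap way to spot exploding or vanishing gradients before they turn into NaNs.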

## 7. Community and Resources

### Q7.1: Where can I get help?

In [None]:
print("Q7.1: Community Resources")
print("=" * 60)

resources = [
    ("Documentation", [
        ("BrainState Docs", "https://brainstate.readthedocs.io/"),
        ("BrainPy Docs", "https://brainpy.readthedocs.io/"),
        ("JAX Docs", "https://jax.readthedocs.io/"),
    ]),
    ("Code & Issues", [
        ("BrainState GitHub", "https://github.com/chaobrain/brainstate"),
        ("Report Issues", "https://github.com/chaobrain/brainstate/issues"),
        ("Contribute", "https://github.com/chaobrain/brainstate/pulls"),
    ]),
    ("Learning", [
        ("BrainState Tutorials", "docs/tutorials/"),
        ("Examples", "examples/"),
        ("JAX Tutorial", "https://jax.readthedocs.io/en/latest/notebooks/quickstart.html"),
    ]),
    ("Community", [
        ("GitHub Discussions", "https://github.com/chaobrain/brainstate/discussions"),
        ("Google Group", "Search for BrainPy/BrainState groups"),
    ]),
]

for category, links in resources:
    print(f"\n{category}:")
    for name, url in links:
        print(f"  • {name}: {url}")

### Q7.2: How do I report a bug?

In [None]:
print("Q7.2: Reporting Bugs")
print("=" * 60)

print("""
When reporting bugs, include:

1. Environment Information:
   - BrainState version: brainstate.__version__
   - JAX version: jax.__version__
   - Python version
   - Operating system
   - GPU/TPU info (if applicable)

2. Minimal Reproducible Example:
   import brainstate
   import jax.numpy as jnp
   
   # Minimal code that reproduces the bug
   model = brainstate.nn.Linear(10, 5)
   x = brainstate.random.randn(2, 10)
   # ... error happens here

3. Error Message:
   - Full traceback
   - Error type and message

4. Expected vs Actual Behavior:
   - What you expected to happen
   - What actually happened

5. Additional Context:
   - What you've tried
   - Related issues (if any)

Submit to: https://github.com/chaobrain/brainstate/issues
""")

# Example environment info
print("\nExample environment info:")
print(f"BrainState: {bst.__version__}")
print(f"JAX: {jax.__version__}")
print(f"Devices: {jax.devices()}")
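
If the import itself is what fails, the versions can still be collected from package metadata. The sketch below uses only the standard library; `collect_env` is a hypothetical helper name, and the default package list is an assumption about what is relevant.

```python
import platform
import sys
from importlib import metadata

def collect_env(packages=("brainstate", "jax", "jaxlib", "numpy")):
    """Return version strings without importing the packages themselves."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info

for key, value in collect_env().items():
    print(f"{key}: {value}")
```

The printed lines can be pasted directly into the "Environment Information" part of a bug report.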

## 8. Advanced Topics

### Q8.1: How do I integrate with other JAX libraries?

In [None]:
print("Q8.1: Integration with JAX Ecosystem")
print("=" * 60)

print("""
BrainState works with JAX ecosystem:

1. Optax (Optimizers):
   import optax
   
   optimizer = optax.adam(learning_rate=1e-3)
   opt_state = optimizer.init(params)
   
   # In training loop:
   updates, opt_state = optimizer.update(grads, opt_state)
   params = optax.apply_updates(params, updates)

2. Flax (Compatible patterns):
   - BrainState models can be used with Flax utilities
   - State management is explicit in BrainState

3. JAX-MD (Molecular dynamics):
   - Use JAX-MD for physics
   - Use BrainState for neural control

4. Haiku (Neural networks):
   - Can mix Haiku and BrainState layers
   - State management differs

5. Custom JAX code:
   - Any JAX function works with BrainState
   - Just follow functional programming principles
""")

### Q8.2: Can I use BrainState for production?

In [None]:
print("Q8.2: Production Deployment")
print("=" * 60)

print("""
Production Readiness:

✓ Yes, BrainState can be used in production:

Advantages:
  • Built on stable JAX backend
  • JIT compilation for performance
  • GPU/TPU acceleration
  • Reproducible (explicit PRNG)
  • Save/load model states

Considerations:
  • Ensure thorough testing
  • Monitor for numerical issues
  • Version pin dependencies
  • Benchmark performance
  • Have rollback strategy

Production Checklist:
  1. ✓ Model serialization implemented
  2. ✓ Inference optimized (JIT, batching)
  3. ✓ Error handling in place
  4. ✓ Monitoring and logging
  5. ✓ Performance benchmarks met
  6. ✓ Unit and integration tests
  7. ✓ Documentation complete

See Tutorial 22 for deployment guide.
""")

## 9. Quick Reference

### Common Commands Cheat Sheet

In [None]:
print("Quick Reference")
print("=" * 80)

cheatsheet = [
    ("Imports", [
        "import brainstate",
        "import jax.numpy as jnp",
        "import jax",
    ]),
    ("Model Definition", [
        "class Model(brainstate.graph.Node): ...",
        "class Dynamics(brainstate.nn.Dynamics): ...",
    ]),
    ("State Types", [
        "brainstate.ParamState(init_value)",
        "brainstate.ShortTermState(init_value)",
        "brainstate.LongTermState(init_value)",
    ]),
    ("Layers", [
        "brainstate.nn.Linear(in_features, out_features)",
        "brainstate.nn.Conv2d(in_size, out_channels, kernel_size)",
        "brainstate.nn.BatchNorm2d(in_size)",
    ]),
    ("Activations", [
        "jax.nn.relu(x)",
        "jax.nn.sigmoid(x)",
        "jax.nn.softmax(x, axis=-1)",
    ]),
    ("Transformations", [
        "@brainstate.transform.jit",
        "brainstate.transform.grad(fn, grad_states=params)",
        "jax.vmap(fn, in_axes=0)",
    ]),
    ("Random Numbers", [
        "brainstate.random.seed(42)",
        "brainstate.random.randn(d0, d1)",
        "brainstate.random.rand(d0, d1)",
    ]),
    ("Debugging", [
        "jax.debug.print('x = {}', x)",
        "with jax.disable_jit(): ...",
        "jax.config.update('jax_debug_nans', True)",
    ]),
]

for category, commands in cheatsheet:
    print(f"\n{category}:")
    for cmd in commands:
        print(f"  {cmd}")

## Summary

This FAQ covered:

1. **General Questions**: What is BrainState, differences from PyTorch
2. **Installation**: Setup and common issues
3. **API Usage**: Custom layers, state types, save/load
4. **Error Messages**: Common errors and solutions
5. **Performance**: Optimization tips, GPU usage
6. **Best Practices**: Code organization, debugging, gradients
7. **Community**: Where to get help, reporting bugs
8. **Advanced Topics**: JAX ecosystem, production deployment
9. **Quick Reference**: Command cheat sheet

### Key Resources:

- **Documentation**: https://brainstate.readthedocs.io/
- **GitHub**: https://github.com/chaobrain/brainstate
- **Issues**: https://github.com/chaobrain/brainstate/issues
- **Tutorials**: Complete tutorial series in docs/tutorials/

### Getting Help:

1. Check this FAQ first
2. Search existing GitHub issues
3. Read relevant tutorials
4. Ask in GitHub Discussions
5. Report bugs with reproducible examples

Happy coding with BrainState! 🧠✨