# Tutorial 26: Debugging Tips and Techniques

In this tutorial, we'll explore debugging strategies for BrainState and JAX code.

## Learning Objectives

By the end of this tutorial, you will:
- Understand common JAX/BrainState errors
- Use JAX debugging tools effectively
- Debug shape mismatches and type errors
- Handle tracer errors
- Identify and fix NaN/Inf issues
- Debug JIT compilation problems
- Use visualization for debugging

## Introduction

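Debugging code that runs under JAX transformations (JIT, vmap, grad) requires different approaches than traditional imperative debugging. Transformed functions are first traced with abstract placeholder values, so Python `print` statements, breakpoints, and `if` statements do not see the concrete numbers you might expect.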

Common challenges:
- **Abstract tracers**: Placeholder values that stand in for arrays while JAX traces a function, carrying only shape and dtype
- **Shape errors**: Dimension mismatches
- **Type errors**: Incompatible dtypes
- **NaN/Inf propagation**: Numerical instability
- **State management**: Tracking state updates

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Any

bst.random.seed(42)

## 1. Common Errors and Solutions

### 1.1 Shape Mismatch Errors

In [None]:
print("Error 1: Shape Mismatch")
print("=" * 60)

# ❌ WRONG: Shape mismatch
print("\n❌ Common mistake:")
print("""
x = jnp.array([[1, 2, 3]])  # Shape: (1, 3)
y = jnp.array([1, 2])       # Shape: (2,)
result = x + y  # ERROR: Incompatible shapes!
""")

# ✓ SOLUTION 1: Check shapes before operations
print("✓ Solution 1: Check shapes")
x = jnp.array([[1, 2, 3]])
y = jnp.array([1, 2])

print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")
print("Shapes incompatible for addition: trailing dims 3 and 2 differ, and neither is 1")

# ✓ SOLUTION 2: Use a y whose length matches x's last dimension
print("\n✓ Solution 2: Make the shapes compatible")
y_fixed = jnp.array([1, 2, 3])  # Shape (3,) matches x.shape[-1]
result = x + y_fixed
print(f"x.shape = {x.shape}")
print(f"y_fixed.shape = {y_fixed.shape}")
print(f"result.shape = {result.shape}")
print(f"result = {result}")

# ✓ SOLUTION 3: Use broadcasting correctly
print("\n✓ Solution 3: Broadcasting")
x = jnp.array([[1, 2, 3], [4, 5, 6]])  # (2, 3)
y = jnp.array([10, 20, 30])            # (3,)
result = x + y  # Broadcasting: (2, 3) + (3,) -> (2, 3)
print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")
print(f"result.shape = {result.shape}")
print(f"result = \n{result}")
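
You can also ask JAX whether shapes broadcast before running the operation. The sketch below wraps `jnp.broadcast_shapes`; the helper name `broadcast_or_none` is our own, and we assume JAX reports incompatible shapes as a `ValueError` (`TypeError` is caught defensively).

```python
import jax.numpy as jnp

def broadcast_or_none(*shapes):
    """Return the broadcast result shape, or None if the shapes are incompatible."""
    try:
        return tuple(jnp.broadcast_shapes(*shapes))
    except (ValueError, TypeError):
        return None

# Broadcasting aligns shapes from the right; each dim pair must match or be 1.
for shapes in [((1, 3), (2,)), ((2, 3), (3,)), ((2, 1), (1, 3))]:
    print(shapes, "->", broadcast_or_none(*shapes))
```

The first pair fails for the same reason as the `x + y` example above, while `(2, 1)` and `(1, 3)` broadcast to `(2, 3)` because each size-1 dimension stretches.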

### 1.2 Tracer Errors

In [None]:
print("\nError 2: Tracer Errors")
print("=" * 60)

# ❌ WRONG: Using traced values in Python control flow
print("\n❌ Common mistake:")
print("""
@jax.jit
def wrong_function(x):
    if x > 0:  # ERROR: Can't use traced value in if!
        return x * 2
    else:
        return x * 3
""")

# ✓ SOLUTION 1: Use jax.lax.cond
print("✓ Solution 1: Use jax.lax.cond")
@jax.jit
def correct_function(x):
    return jax.lax.cond(
        x > 0,
        lambda x: x * 2,
        lambda x: x * 3,
        x
    )

result = correct_function(5.0)
print(f"correct_function(5.0) = {result}")

result = correct_function(-3.0)
print(f"correct_function(-3.0) = {result}")

# ✓ SOLUTION 2: Use jnp.where for element-wise conditionals
print("\n✓ Solution 2: Use jnp.where")
@jax.jit
def vectorized_function(x):
    return jnp.where(x > 0, x * 2, x * 3)

x = jnp.array([-2, -1, 0, 1, 2])
result = vectorized_function(x)
print(f"Input: {x}")
print(f"Output: {result}")

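# ✓ SOLUTION 3: Use static_argnums for control flow arguments
# (each distinct static value triggers a separate compilation)
print("\n✓ Solution 3: static_argnums")
from functools import partial

@partial(jax.jit, static_argnums=(1,))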
def function_with_static(x, mode):
    if mode == "double":
        return x * 2
    else:
        return x * 3

print(f"function_with_static(5.0, 'double') = {function_with_static(5.0, 'double')}")
print(f"function_with_static(5.0, 'triple') = {function_with_static(5.0, 'triple')}")
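
To see the failure from the "common mistake" above for yourself, run it and catch the error. This is a small sketch: we catch `jax.errors.ConcretizationTypeError`, which recent JAX versions may specialize as `TracerBoolConversionError` (a subclass).

```python
import jax

@jax.jit
def wrong_function(x):
    # Python `if` needs a concrete bool, but under jit `x` is an abstract tracer.
    if x > 0:
        return x * 2
    return x * 3

try:
    wrong_function(5.0)
    error_name = None
except jax.errors.ConcretizationTypeError as err:
    error_name = type(err).__name__  # the specific subclass, if any

print("Raised:", error_name)
```

Reading the exception class name (and its message, which points at the offending line) is usually the fastest way to tell a tracer problem apart from a shape or dtype problem.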

### 1.3 NaN and Inf Detection

In [None]:
print("Error 3: NaN and Inf Values")
print("=" * 60)

# Common sources of NaN/Inf
print("\nCommon sources:")
print("1. Division by zero")
print("2. Logarithm of zero or negative numbers")
print("3. Numerical overflow")
print("4. Invalid mathematical operations")

# Detection utilities
def check_nan_inf(x: jnp.ndarray, name: str = "tensor") -> bool:
    """
    Check for NaN or Inf values (eager only: the Python `if` below needs
    concrete values, so do not call this inside jax.jit).
    
    Args:
        x: Array to check
        name: Name for error message
        
    Returns:
        True if any NaN or Inf found
    """
    has_nan = jnp.any(jnp.isnan(x))
    has_inf = jnp.any(jnp.isinf(x))
    
    if has_nan:
        print(f"⚠️  WARNING: {name} contains NaN values!")
        return True
    if has_inf:
        print(f"⚠️  WARNING: {name} contains Inf values!")
        return True
    
    return False

# Example: Division by zero
print("\nExample 1: Division by zero")
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 0.0, 1.0])
result = x / y
print(f"x / y = {result}")
check_nan_inf(result, "division result")

# Solution: Add epsilon to the denominator (assumes y >= 0; it also shifts every result slightly)
print("\n✓ Solution: Add small epsilon")
epsilon = 1e-8
result_safe = x / (y + epsilon)
print(f"x / (y + epsilon) = {result_safe}")
check_nan_inf(result_safe, "safe division")

# Example: Log of zero
print("\nExample 2: Logarithm of zero")
x = jnp.array([1.0, 0.0, 2.0])
result = jnp.log(x)
print(f"log(x) = {result}")
check_nan_inf(result, "log result")

# Solution: Clipping
print("\n✓ Solution: Clip values")
result_safe = jnp.log(jnp.maximum(x, epsilon))
print(f"log(max(x, epsilon)) = {result_safe}")
check_nan_inf(result_safe, "safe log")

# Enable NaN debugging
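print("\n✓ Enable JAX NaN/Inf checking (for development):")
print("""
# At the start of your script:
jax.config.update('jax_debug_nans', True)  # raise FloatingPointError when an op produces NaN
jax.config.update('jax_debug_infs', True)  # same for Inf (e.g., the x / 0.0 case above)

# Both checks add overhead, so turn them off for production runs
""")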

### 1.4 Type Errors

In [None]:
print("Error 4: Type Mismatches")
print("=" * 60)

# ❌ Common pitfall: mixing dtypes silently promotes the result
print("\n❌ Common pitfall:")
x_int = jnp.array([1, 2, 3], dtype=jnp.int32)
y_float = jnp.array([1.5, 2.5, 3.5], dtype=jnp.float32)
print(f"x dtype: {x_int.dtype}")
print(f"y dtype: {y_float.dtype}")

result = x_int + y_float  # Implicit type promotion: int32 + float32 -> float32
print(f"result dtype: {result.dtype}")
print(f"result: {result}")

# ✓ Best practice: Explicit casting
print("\n✓ Solution: Explicit casting")
x_float = x_int.astype(jnp.float32)
result = x_float + y_float
print(f"result dtype: {result.dtype}")
print(f"result: {result}")

# Check dtypes
def check_dtype(x: jnp.ndarray, expected_dtype: jnp.dtype, name: str = "tensor"):
    """Check if array has expected dtype."""
    if x.dtype != expected_dtype:
        print(f"⚠️  WARNING: {name} has dtype {x.dtype}, expected {expected_dtype}")
        return False
    return True

print("\nType checking:")
check_dtype(x_int, jnp.float32, "x_int")
check_dtype(y_float, jnp.float32, "y_float")
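
When a dtype surprises you, `jnp.result_type` tells you what JAX will promote to, and it helps to remember that Python scalars are "weakly typed". The sketch below shows both; the specific dtype pairs are just illustrations.

```python
import jax.numpy as jnp

# Array-with-array promotion follows JAX's type-promotion lattice.
print(jnp.result_type(jnp.int32, jnp.float32))

# Python scalars are weakly typed: they adopt the array's dtype
# instead of forcing a wider one.
half = jnp.ones(3, dtype=jnp.float16)
print((half * 2.0).dtype)

# A float32 *array*, by contrast, promotes the float16 array.
print((half + jnp.ones(3, dtype=jnp.float32)).dtype)
```

This is why a model can silently drift from `float16` to `float32` when one array in the computation has a wider dtype, while literal constants like `2.0` never cause that drift.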

## 2. JAX Debugging Tools

### 2.1 Debug Print

In [None]:
print("JAX Debug Printing")
print("=" * 60)

# Problem: Python print runs only while JAX traces the function
print("\n❌ Regular print in JIT:")
@jax.jit
def function_with_print(x):
    print(f"x = {x}")  # Runs during tracing and shows an abstract tracer, not the value
    return x * 2

print("First call (traces and compiles):")
result = function_with_print(5.0)
print("\nSecond call (cached: same shape and dtype):")
result = function_with_print(10.0)
print("Notice: print ran only once, during tracing, and showed a tracer!\n")

# ✓ Solution: jax.debug.print
print("✓ Solution: jax.debug.print")
@jax.jit
def function_with_debug_print(x):
    jax.debug.print("x = {}", x)
    return x * 2

print("First call:")
result = function_with_debug_print(5.0)
print("\nSecond call:")
result = function_with_debug_print(10.0)
print("\njax.debug.print works during execution!")

# Advanced: Conditional debug printing
print("\n✓ Conditional debug printing:")
@jax.jit
def debug_conditional(x):
    jax.debug.print("Input: x = {}", x)
    result = jnp.sqrt(x)
    jax.debug.print("Output: sqrt(x) = {}", result)
    
    # Print only if result is large
    jax.lax.cond(
        result > 5.0,
        lambda: jax.debug.print("⚠️  Large result: {}", result),
        lambda: None,
    )
    return result

print("\nTesting with small input:")
_ = debug_conditional(4.0)
print("\nTesting with large input:")
_ = debug_conditional(36.0)
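
When you want to collect values rather than just print them, `jax.debug.callback` calls an ordinary Python function on the host with concrete values from inside a jitted function. A minimal sketch; `record` and `step` are illustrative names.

```python
import jax
import jax.numpy as jnp

recorded = []

def record(value):
    # Runs on the host with a concrete NumPy value, even though `step` is jitted.
    recorded.append(float(value))

@jax.jit
def step(x):
    y = jnp.tanh(x)
    jax.debug.callback(record, y)
    return y * 2

step(0.5)
step(1.5)
jax.effects_barrier()  # wait until all pending callbacks have run
print(recorded)
```

Unlike Python `print`, the callback fires on every call, not just during tracing, so `recorded` ends up with one entry per call.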

### 2.2 Disabling JIT for Debugging

In [None]:
print("Disabling JIT for Debugging")
print("=" * 60)

@jax.jit
def complex_function(x):
    # Complex computation
    y = jnp.sin(x)
    z = jnp.exp(y)
    return jnp.sum(z)

# Option 1: Temporarily disable JIT globally
print("\nOption 1: Disable JIT globally")
print("with jax.disable_jit():")
print("    result = complex_function(x)")
print("    # Now you can use print(), debugger, etc.")

x = jnp.array([1.0, 2.0, 3.0])

with jax.disable_jit():
    result = complex_function(x)
    print(f"Result (no JIT): {result}")

# Option 2: Comment out @jax.jit decorator
print("\nOption 2: Remove decorator temporarily")
print("# @jax.jit  # Commented out")
print("def complex_function(x):")
print("    ...")

# Option 3: Conditional JIT
print("\nOption 3: Conditional JIT")
DEBUG = True

def maybe_jit(func):
    if DEBUG:
        return func
    else:
        return jax.jit(func)

@maybe_jit
def debuggable_function(x):
    print(f"Debug: x = {x}")  # Works when DEBUG=True
    return x * 2

result = debuggable_function(5.0)
print(f"Result: {result}")
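
A quick way to confirm what `jax.disable_jit()` changes: inside a jitted function, converting the argument to a Python `float` fails because it is a tracer; under `disable_jit` the same code sees an ordinary array. This is a sketch for illustration only.

```python
import jax
import jax.numpy as jnp

observed = []

@jax.jit
def f(x):
    try:
        float(x)  # only works on a concrete value
        observed.append("concrete")
    except jax.errors.ConcretizationTypeError:
        observed.append("tracer")
    return x + 1

f(jnp.array(1.0))      # traced: x is an abstract tracer
with jax.disable_jit():
    f(jnp.array(2.0))  # eager: x is an ordinary array
print(observed)
```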

### 2.3 Inspecting Computation Graphs

In [None]:
print("Inspecting JAX Computations")
print("=" * 60)

def simple_function(x, y):
    return jnp.sin(x) + jnp.cos(y)

# Get the jaxpr (JAX expression)
print("\nJAXPR (JAX intermediate representation):")
x = jnp.array(1.0)
y = jnp.array(2.0)

jaxpr = jax.make_jaxpr(simple_function)(x, y)
print(jaxpr)

print("\n✓ Use make_jaxpr to see the primitives JAX records while tracing")
print("  Useful for:")
print("  - Understanding what transformations (grad, vmap) produce")
print("  - Debugging shape and dtype inference")
print("  - Spotting unexpected operations before XLA compiles the program")
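
The jaxpr can also be inspected programmatically, and the jit-lowering APIs show the program at later stages. The sketch below lists the traced primitives, then prints the StableHLO handed to XLA (`lowered.as_text()`) and the module after XLA's optimization passes (`compiled.as_text()`); output is truncated for readability.

```python
import jax
import jax.numpy as jnp

def simple_function(x, y):
    return jnp.sin(x) + jnp.cos(y)

x, y = jnp.array(1.0), jnp.array(2.0)

# Each equation in the jaxpr records one primitive applied during tracing.
closed_jaxpr = jax.make_jaxpr(simple_function)(x, y)
primitives = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitives)

# Lowered: what JAX hands to the compiler. Compiled: after XLA optimizations.
lowered = jax.jit(simple_function).lower(x, y)
print(lowered.as_text()[:400])
compiled = lowered.compile()
print(compiled.as_text()[:400])
```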

## 3. Debugging Model Code

### 3.1 Shape Debugging

In [None]:
print("Model Shape Debugging")
print("=" * 60)

class DebuggableModel(bst.graph.Node):
    """Model with shape debugging."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, debug=True):
        super().__init__()
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, output_dim)
        self.debug = debug
    
    def __call__(self, x):
        if self.debug:
            print(f"Input shape: {x.shape}")
        
        x = self.fc1(x)
        if self.debug:
            print(f"After fc1: {x.shape}")
        
        x = jax.nn.relu(x)
        if self.debug:
            print(f"After relu: {x.shape}")
        
        x = self.fc2(x)
        if self.debug:
            print(f"Output shape: {x.shape}")
        
        return x

# Test with debug mode
print("\nWith debug=True:")
model = DebuggableModel(10, 20, 5, debug=True)
x = bst.random.randn(3, 10)
output = model(x)

# Turn off debug for production
print("\nWith debug=False:")
model.debug = False
output = model(x)
print(f"Final output shape: {output.shape}")
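
Printing shapes works even under JIT because shapes are static, but you can also check shapes without running any computation at all. `jax.eval_shape` traces a function with `jax.ShapeDtypeStruct` placeholders. The `mlp` below is a hypothetical plain-JAX stand-in that mirrors `DebuggableModel`'s 10 → 20 → 5 sizes, not the BrainState model itself.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # Hypothetical two-layer MLP with the same layer sizes as DebuggableModel.
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

spec = jax.ShapeDtypeStruct
param_specs = {
    "w1": spec((10, 20), jnp.float32), "b1": spec((20,), jnp.float32),
    "w2": spec((20, 5), jnp.float32), "b2": spec((5,), jnp.float32),
}

# eval_shape traces abstractly: no FLOPs and no memory for real arrays.
out = jax.eval_shape(mlp, param_specs, spec((3, 10), jnp.float32))
print(out.shape, out.dtype)

# A wrong input width fails at trace time with a shape error.
try:
    jax.eval_shape(mlp, param_specs, spec((3, 11), jnp.float32))
    bad_input_error = None
except (TypeError, ValueError) as err:
    bad_input_error = type(err).__name__
print(bad_input_error)
```

This is handy for large models, where allocating real inputs just to find a shape bug would be slow or memory-hungry.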

### 3.2 Gradient Debugging

In [None]:
print("Gradient Debugging")
print("=" * 60)

class SimpleModel(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.fc = bst.nn.Linear(10, 5)
    
    def __call__(self, x):
        return self.fc(x)

model = SimpleModel()
x = bst.random.randn(2, 10)
_ = model(x)

# Define loss
def loss_fn(x):
    output = model(x)
    return jnp.sum(output ** 2)

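# Compute gradients with respect to the model's parameters.
# With return_value=True, BrainState's grad returns (grads, loss).
params = model.states(bst.ParamState)
grads, loss = bst.transform.grad(loss_fn, grad_states=params, return_value=True)(x)
print(f"Loss: {float(loss):.6f}")

print("\nGradient Statistics:")
# A ParamState can hold a dict of arrays (e.g., weight and bias), so flatten
# the gradient pytree and inspect every leaf array individually.
for path, grad in jax.tree_util.tree_flatten_with_path(grads)[0]:
    max_abs = float(jnp.max(jnp.abs(grad)))
    print(f"\n{jax.tree_util.keystr(path)}:")
    print(f"  Shape: {grad.shape}")
    print(f"  Mean: {float(jnp.mean(grad)):.6f}")
    print(f"  Std: {float(jnp.std(grad)):.6f}")
    print(f"  Max |grad|: {max_abs:.6f}")
    
    # Check for problems (the thresholds are rough heuristics)
    if jnp.any(jnp.isnan(grad)):
        print("  ⚠️  Contains NaN!")
    if jnp.any(jnp.isinf(grad)):
        print("  ⚠️  Contains Inf!")
    if max_abs > 10.0:
        print("  ⚠️  Potentially exploding gradient!")
    if max_abs < 1e-7:
        print("  ⚠️  Potentially vanishing gradient!")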
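
A frequent source of NaN gradients is a `jnp.where` that looks safe in the forward pass. The backward pass still differentiates the unselected branch, and a zero cotangent times an infinite derivative gives NaN. The usual workaround, often called the "double where" trick, is sketched below with a log function chosen purely for illustration.

```python
import jax
import jax.numpy as jnp

def naive_safe_log(x):
    # Forward value is fine (0.0 at x = 0), but the gradient of log at 0
    # is infinite, and 0 * inf = NaN leaks through where's backward pass.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def double_where_log(x):
    # First where: keep log's input inside its safe domain everywhere.
    safe_x = jnp.where(x > 0, x, 1.0)
    # Second where: select the output exactly as before.
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

print("naive grad at 0:", jax.grad(naive_safe_log)(0.0))
print("double-where grad at 0:", jax.grad(double_where_log)(0.0))
```

Both functions return the same forward values; only the second has a finite gradient at `x = 0`.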

### 3.3 State Debugging

In [None]:
print("State Debugging")
print("=" * 60)

class StatefulModel(bst.graph.Node):
    def __init__(self, size):
        super().__init__()
        self.fc = bst.nn.Linear(size, size)
        self.hidden = bst.ShortTermState(jnp.zeros(size))
        self.counter = bst.LongTermState(jnp.array(0))
    
    def __call__(self, x):
        # Update counter
        self.counter.value = self.counter.value + 1
        
        # Update hidden state
        new_hidden = jnp.tanh(self.fc(x) + self.hidden.value)
        self.hidden.value = new_hidden
        
        return new_hidden
    
    def debug_states(self):
        """Print state information."""
        print("\nState Debug Info:")
        print(f"Counter: {self.counter.value}")
        print(f"Hidden state:")
        print(f"  Shape: {self.hidden.value.shape}")
        print(f"  Mean: {jnp.mean(self.hidden.value):.6f}")
        print(f"  Std: {jnp.std(self.hidden.value):.6f}")
        print(f"  Range: [{jnp.min(self.hidden.value):.6f}, {jnp.max(self.hidden.value):.6f}]")

# Test stateful model
model = StatefulModel(size=10)
x = bst.random.randn(10)  # Unbatched input, so fc(x) matches hidden's shape (10,)

print("Initial state:")
model.debug_states()

for i in range(3):
    _ = model(x)
    print(f"\nAfter step {i+1}:")
    model.debug_states()
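
Printing statistics tells you what a state looks like; diffing snapshots tells you what actually changed between steps, including silent shape changes (e.g., a state created as `(10,)` becoming `(1, 10)` after a batched update). The helper below is a hypothetical utility that works on any pytree. For the model above, you might snapshot something like `{str(k): s.value for k, s in model.states().items()}`, assuming `states()` returns a name-to-state mapping as it is used elsewhere in this tutorial.

```python
import jax
import jax.numpy as jnp

def changed_leaves(before, after, rtol=1e-6, atol=0.0):
    """List pytree paths whose leaf changed shape or value between two snapshots."""
    flat_before, treedef = jax.tree_util.tree_flatten_with_path(before)
    if jax.tree_util.tree_structure(after) != treedef:
        raise ValueError("snapshots have different tree structures")
    changed = []
    for (path, old), new in zip(flat_before, jax.tree_util.tree_leaves(after)):
        old, new = jnp.asarray(old), jnp.asarray(new)
        # Compare shapes first, since allclose would silently broadcast.
        if old.shape != new.shape or not bool(jnp.allclose(old, new, rtol=rtol, atol=atol)):
            changed.append(jax.tree_util.keystr(path))
    return changed

before = {"counter": jnp.array(0), "hidden": jnp.zeros(3)}
after = {"counter": jnp.array(1), "hidden": jnp.zeros(3)}
print(changed_leaves(before, after))
```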

## 4. Visualization for Debugging

### 4.1 Weight Distribution

In [None]:
print("Visualizing Weight Distributions")
print("=" * 60)

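def visualize_weights(model: bst.graph.Node, max_panels: int = 4):
    """Plot a histogram of each parameter array (up to max_panels panels)."""
    params = model.states(bst.ParamState)
    
    # A ParamState may hold a dict of arrays (e.g., weight and bias), so
    # flatten all state values into (path, array) leaves before plotting.
    leaves, _ = jax.tree_util.tree_flatten_with_path(
        {str(name): param.value for name, param in params.items()})
    leaves = leaves[:max_panels]
    
    # squeeze=False always returns a 2-D grid of Axes, even for a single panel
    fig, axes = plt.subplots(1, len(leaves), figsize=(15, 3), squeeze=False)
    
    for ax, (path, value) in zip(axes[0], leaves):
        weights = np.asarray(value).ravel()
        mean = float(weights.mean())
        
        ax.hist(weights, bins=50, alpha=0.7, edgecolor='black')
        ax.set_title(f"{jax.tree_util.keystr(path)}\nMean: {mean:.4f}")
        ax.set_xlabel('Value')
        ax.set_ylabel('Count')
        ax.grid(True, alpha=0.3)
        
        # Mark the mean
        ax.axvline(mean, color='red', linestyle='--',
                   linewidth=2, label='Mean')
        ax.legend()
    
    plt.tight_layout()
    plt.show()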

# Create and visualize model
model = SimpleModel()
x = bst.random.randn(1, 10)
_ = model(x)

print("\nWeight distributions:")
visualize_weights(model)
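
On a headless machine (CI, a remote job), a text summary of the same information is often more practical than a plot. The helper below is a hypothetical sketch that reports shape, mean, std, and whether all values are finite for every array leaf of any pytree; the example parameter dict is made up for illustration.

```python
import jax
import jax.numpy as jnp

def summarize_tree(tree):
    """Return (path, shape, mean, std, all_finite) for every array leaf."""
    rows = []
    for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
        leaf = jnp.asarray(leaf)
        rows.append((
            jax.tree_util.keystr(path),
            leaf.shape,
            float(jnp.mean(leaf)),
            float(jnp.std(leaf)),
            bool(jnp.all(jnp.isfinite(leaf))),
        ))
    return rows

key = jax.random.PRNGKey(0)
params = {"fc": {"weight": jax.random.normal(key, (10, 5)) * 0.1,
                 "bias": jnp.zeros(5)}}
for name, shape, mean, std, finite in summarize_tree(params):
    print(f"{name:18s} shape={str(shape):8s} mean={mean:+.4f} std={std:.4f} finite={finite}")
```

A std of exactly zero (like the freshly initialized bias here) is expected at initialization, but on a weight matrix after training it usually means the parameter is not being updated.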

### 4.2 Activation Visualization

In [None]:
print("Visualizing Activations")
print("=" * 60)

class ActivationTracker(bst.graph.Node):
    """Model that tracks activations."""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, output_dim)
        self.activations = []
    
    def __call__(self, x):
        self.activations = []  # Eager only: under jax.jit these entries would be leaked tracers
        
        self.activations.append(('input', x))
        
        x = self.fc1(x)
        self.activations.append(('fc1_pre', x))
        
        x = jax.nn.relu(x)
        self.activations.append(('fc1_post', x))
        
        x = self.fc2(x)
        self.activations.append(('output', x))
        
        return x
    
    def plot_activations(self):
        """Plot activation statistics."""
        names = [name for name, _ in self.activations]
        means = [float(jnp.mean(act)) for _, act in self.activations]
        stds = [float(jnp.std(act)) for _, act in self.activations]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        x_pos = range(len(names))
        
        ax1.bar(x_pos, means, alpha=0.7, edgecolor='black')
        ax1.set_xticks(x_pos)
        ax1.set_xticklabels(names, rotation=45, ha='right')
        ax1.set_ylabel('Mean Activation')
        ax1.set_title('Mean Activations')
        ax1.grid(True, alpha=0.3, axis='y')
        
        ax2.bar(x_pos, stds, alpha=0.7, color='orange', edgecolor='black')
        ax2.set_xticks(x_pos)
        ax2.set_xticklabels(names, rotation=45, ha='right')
        ax2.set_ylabel('Std Activation')
        ax2.set_title('Activation Standard Deviation')
        ax2.grid(True, alpha=0.3, axis='y')
        
        plt.tight_layout()
        plt.show()

# Test activation tracking
model = ActivationTracker(10, 20, 5)
x = bst.random.randn(32, 10)
output = model(x)

print("\nActivation statistics:")
for name, act in model.activations:
    print(f"{name:12s}: shape={str(act.shape):12s} "
          f"mean={jnp.mean(act):7.4f} std={jnp.std(act):7.4f}")

model.plot_activations()
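
Storing activations on the object only works eagerly. A JIT-friendly alternative is to return intermediates as extra outputs, then compute diagnostics such as the fraction of "dead" ReLU units (units that output zero for every example in the batch). This is a sketch with hypothetical plain-JAX parameters standing in for `ActivationTracker`'s layers.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(10, 20, 5)):
    # Hypothetical parameters mirroring ActivationTracker(10, 20, 5).
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (sizes[0], sizes[1])) / jnp.sqrt(sizes[0]),
        "b1": jnp.zeros(sizes[1]),
        "w2": jax.random.normal(k2, (sizes[1], sizes[2])) / jnp.sqrt(sizes[1]),
        "b2": jnp.zeros(sizes[2]),
    }

@jax.jit
def forward_with_activations(params, x):
    # Return intermediates as outputs so the function stays pure under jit.
    pre = x @ params["w1"] + params["b1"]
    post = jax.nn.relu(pre)
    out = post @ params["w2"] + params["b2"]
    return out, {"fc1_pre": pre, "fc1_post": post}

def dead_unit_fraction(post_relu):
    # A unit counts as dead for this batch if it outputs 0 for every example.
    return float(jnp.mean(jnp.all(post_relu <= 0, axis=0)))

key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key_p)
x = jax.random.normal(key_x, (32, 10))
out, acts = forward_with_activations(params, x)
print(out.shape, acts["fc1_post"].shape)
print(f"dead fc1 units: {dead_unit_fraction(acts['fc1_post']):.2%}")
```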

## 5. Debugging Checklist

In [None]:
print("Debugging Checklist")
print("=" * 80)

checklist = [
    ("Shape Errors", [
        "Print shapes at each step",
        "Check input/output dimensions",
        "Verify broadcasting behavior",
        "Use .shape attribute liberally",
    ]),
    ("Type Errors", [
        "Check dtypes (float32, int32, etc.)",
        "Explicit casting when needed",
        "Consistent types throughout",
    ]),
    ("NaN/Inf Issues", [
        "Enable jax_debug_nans during development",
        "Check for division by zero",
        "Validate log inputs (> 0)",
        "Monitor gradient magnitudes",
        "Use gradient clipping",
    ]),
    ("JIT/Tracer Errors", [
        "Use jax.lax.cond or jnp.where for conditions on traced values",
        "Use jax.lax.fori_loop or jax.lax.scan for loops with traced bounds",
        "Mark Python-valued control-flow args as static (static_argnums)",
        "Use jax.debug.print for JIT",
        "Disable JIT temporarily with disable_jit()",
    ]),
    ("State Management", [
        "Verify state updates happen",
        "Check state types (Param vs ShortTerm)",
        "Reset states when needed",
        "Track state statistics",
    ]),
    ("Performance Issues", [
        "Profile code to find bottlenecks",
        "Check for unnecessary recompilation",
        "Verify JIT is being used",
        "Monitor memory usage",
    ]),
]

for category, items in checklist:
    print(f"\n{category}:")
    for item in items:
        print(f"  ✓ {item}")

print("\n" + "=" * 80)
print("Quick Debug Commands:")
print("=" * 80)
commands = [
    ("Check shape", "print(f'Shape: {x.shape}')"),
    ("Check dtype", "print(f'Dtype: {x.dtype}')"),
    ("Check for NaN", "assert not jnp.any(jnp.isnan(x))"),
    ("Check for Inf", "assert not jnp.any(jnp.isinf(x))"),
    ("Check range", "print(f'Range: [{jnp.min(x)}, {jnp.max(x)}]')"),
    ("Debug in JIT", "jax.debug.print('Value: {}', x)"),
    ("Disable JIT", "with jax.disable_jit(): ..."),
    ("Enable NaN check", "jax.config.update('jax_debug_nans', True)"),
]

for desc, cmd in commands:
    print(f"{desc:20s}: {cmd}")
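
For the "check for unnecessary recompilation" item, one simple trick is a Python-side counter inside the jitted function: Python side effects run only when JAX (re)traces, so the counter counts traces. A minimal sketch; `scaled` and `trace_count` are illustrative names. JAX can also log every compilation if you set `jax.config.update("jax_log_compiles", True)`.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0

@partial(jax.jit, static_argnames="mode")
def scaled(x, mode):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return x * 2 if mode == "double" else x * 3

scaled(jnp.ones(3), mode="double")   # first call: trace + compile
scaled(jnp.zeros(3), mode="double")  # same shape/dtype/static value: cache hit
scaled(jnp.ones(3), mode="triple")   # new static value: retrace
scaled(jnp.ones(4), mode="triple")   # new input shape: retrace
print("traces:", trace_count)
```

If the count keeps growing in a training loop, look for inputs whose shapes vary from step to step or static arguments that take many distinct values.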

## Summary

In this tutorial, we covered:

1. **Common Errors**:
   - Shape mismatches and broadcasting
   - Tracer errors in JIT
   - NaN and Inf detection
   - Type mismatches

2. **JAX Debugging Tools**:
   - jax.debug.print for JIT
   - Disabling JIT temporarily
   - Inspecting JAXPRs

3. **Model Debugging**:
   - Shape debugging
   - Gradient checking
   - State monitoring

4. **Visualization**:
   - Weight distributions
   - Activation patterns
   - Statistical analysis

5. **Debugging Checklist**:
   - Systematic approach
   - Quick commands
   - Best practices

### Key Takeaways:

- **Use jax.debug.print** inside JIT functions
- **Check shapes early and often**
- **Enable NaN debugging** during development
- **Disable JIT temporarily** for detailed debugging
- **Visualize** weights and activations
- **Monitor gradients** for training issues

## Next Steps

- Practice debugging on your own code
- Build debugging utilities for your workflow
- Learn JAX internals for advanced debugging

For more information:
- [JAX Debugging Guide](https://jax.readthedocs.io/en/latest/debugging/index.html)
- [BrainState Documentation](https://brainstate.readthedocs.io/)