# Tutorial 15: Mixin System and Computation Modes

In this tutorial, we'll explore BrainState's mixin system and its computation modes. Together they let the same component behave differently depending on context, for example during training versus evaluation.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Understand the Mixin base class and its purpose
- Use Mode for computation behavior control
- Work with JointMode for combining modes
- Apply Batching mode for batch operations
- Use Training mode for train/eval switching
- Create custom mixins for specialized behavior
- Design mode-aware neural network components

## What are Mixins?

Mixins are small classes that add one specific capability to other classes through multiple inheritance. They are not meant to serve as standalone base classes. In BrainState, mixins and the related mode flags control:
- **Computation modes**: Training vs evaluation, batching behavior
- **Dynamic behavior**: Change functionality based on context
- **Composability**: Combine multiple behaviors
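
In Python, a mixin is an ordinary class that is combined with others through multiple inheritance. The method resolution order (MRO) determines which class supplies each method. The framework-free sketch below uses hypothetical class names (`DescribeMixin`, `Layer`) to show the pattern that the later `DebugMode` and `PrecisionMode` examples follow:

```python
class DescribeMixin:
    """A mixin: adds one capability and expects the host class to define `name`."""

    def describe(self) -> str:
        return f"{type(self).__name__}({self.name})"


class Layer:
    """Hypothetical stand-in for a framework base class."""

    def __init__(self, name: str):
        self.name = name


# The mixin is listed first, so its methods take precedence in the MRO
class DescribedLayer(DescribeMixin, Layer):
    pass


layer = DescribedLayer("fc1")
print(layer.describe())  # DescribedLayer(fc1)
print([cls.__name__ for cls in DescribedLayer.__mro__])
# ['DescribedLayer', 'DescribeMixin', 'Layer', 'object']
```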

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Set random seed
bst.random.seed(42)

## 1. The Mode System

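`brainstate.mixin.Mode` is the base class for computation modes such as `Training` and `Batching`. BrainState also keeps global computation flags in `bst.environ`. Use `environ.context(...)` to set flags temporarily and `environ.get(name, default=...)` to read them. Most examples in this tutorial use these flags. The first one is `fit`, which is `True` during training and `False` during evaluation.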

In [None]:
# Understanding Mode basics
from brainstate.mixin import Mode

# Global computation flags (such as 'fit') live in bst.environ
print("Mode System Basics:")
print("=" * 60)

# Check current mode
print(f"Current fit mode: {bst.environ.get('fit', default=True)}")

# Change mode temporarily
with bst.environ.context(fit=False):
    print(f"Inside context - fit mode: {bst.environ.get('fit')}")

print(f"After context - fit mode: {bst.environ.get('fit', default=True)}")
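
The scoping behavior of `environ.context` can be modeled with a small stdlib sketch built on `contextvars`. This is an illustrative model only, not BrainState's implementation, and `get_flag` and `flag_context` are hypothetical names. Entering a context layers new values over the current ones, and leaving it restores the previous state:

```python
import contextlib
import contextvars

# One ContextVar holds an immutable snapshot (a dict) of the active flags
_FLAGS = contextvars.ContextVar("flags", default={})
_MISSING = object()


def get_flag(name, default=_MISSING):
    """Read a flag; fall back to `default`, or raise KeyError if none was given."""
    flags = _FLAGS.get()
    if name in flags:
        return flags[name]
    if default is _MISSING:
        raise KeyError(f"flag {name!r} is not set")
    return default


@contextlib.contextmanager
def flag_context(**overrides):
    """Temporarily set flags; the previous values are restored on exit."""
    token = _FLAGS.set({**_FLAGS.get(), **overrides})  # new dict, never mutated
    try:
        yield
    finally:
        _FLAGS.reset(token)


outside = get_flag("fit", default=True)
with flag_context(fit=False):
    inside = get_flag("fit")
    with flag_context(debug=True):
        nested = (get_flag("fit"), get_flag("debug"))  # outer flag is inherited
after = get_flag("fit", default=True)
print(outside, inside, nested, after)  # True False (False, True) True
```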

### Mode-Aware Components

In [None]:
# Create mode-aware dropout
class ModeAwareDropout(bst.graph.Node):
    """Dropout that respects training mode."""
    
    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate
    
    def __call__(self, x):
        # Check if we're in training mode
        is_training = bst.environ.get('fit', default=True)
        
        if is_training:
            # Apply dropout during training
            keep_prob = 1 - self.rate
            mask = bst.random.bernoulli(keep_prob, x.shape)
            return jnp.where(mask, x / keep_prob, 0)
        else:
            # No dropout during evaluation
            return x

# Test mode-aware dropout
dropout = ModeAwareDropout(rate=0.5)
x = jnp.ones((5, 10))

print("Training mode (dropout active):")
with bst.environ.context(fit=True):
    out_train = dropout(x)
    print(f"  Non-zero elements: {jnp.sum(out_train != 0)} / {out_train.size}")
    print(f"  Mean: {jnp.mean(out_train):.3f} (should be ~1.0)")

print("\nEvaluation mode (dropout disabled):")
with bst.environ.context(fit=False):
    out_eval = dropout(x)
    print(f"  Non-zero elements: {jnp.sum(out_eval != 0)} / {out_eval.size}")
    print(f"  Mean: {jnp.mean(out_eval):.3f}")
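
Scaling the surviving values by 1/(1 - rate), known as "inverted dropout", keeps the expected activation unchanged. That is why the training-mode mean above is close to 1.0 even though about half the entries are zero. Here is a framework-free check using Python's `random` module. The function name `inverted_dropout` is illustrative:

```python
import random


def inverted_dropout(values, rate, rng):
    """Zero each value with probability `rate`; scale survivors by 1 / (1 - rate)."""
    keep_prob = 1.0 - rate
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


rng = random.Random(0)
ones = [1.0] * 100_000
out = inverted_dropout(ones, rate=0.5, rng=rng)

# Each entry is 2.0 with probability 0.5 and 0.0 otherwise: E[out] = 0.5 * 2.0 = 1.0
mean = sum(out) / len(out)
kept = sum(1 for v in out if v != 0.0) / len(out)
print(f"kept fraction ~ {kept:.3f}, mean ~ {mean:.3f}")
```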

## 2. Training Mode

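The `Training` mixin marks a computation as training. In practice, BrainState's built-in `Dropout` and batch-normalization layers read the `fit` flag from `bst.environ`. You switch between training and evaluation with `environ.context(fit=...)`.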

In [None]:
# Training mode in practice
class TrainingAwareNetwork(bst.graph.Node):
    """Network that behaves differently in train vs eval mode."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.3):
        super().__init__()
        self.linear1 = bst.nn.Linear(input_dim, hidden_dim)
        self.dropout = bst.nn.Dropout(1 - dropout_rate)  # bst.nn.Dropout expects the keep probability
        self.batchnorm = bst.nn.BatchNorm1d(hidden_dim)
        self.linear2 = bst.nn.Linear(hidden_dim, output_dim)
    
    def __call__(self, x):
        x = self.linear1(x)
        x = self.batchnorm(x)  # BatchNorm behaves differently in train/eval
        x = jax.nn.relu(x)
        x = self.dropout(x)    # Dropout only active in training
        x = self.linear2(x)
        return x

# Create network
net = TrainingAwareNetwork(input_dim=10, hidden_dim=20, output_dim=5)
x = bst.random.randn(32, 10)

# Training mode
print("Training Mode:")
with bst.environ.context(fit=True):
    out_train1 = net(x)
    out_train2 = net(x)
    # Outputs differ due to dropout randomness
    diff_train = jnp.mean(jnp.abs(out_train1 - out_train2))
    print(f"  Mean absolute difference between passes: {diff_train:.4f}")

# Evaluation mode
print("\nEvaluation Mode:")
with bst.environ.context(fit=False):
    out_eval1 = net(x)
    out_eval2 = net(x)
    # Outputs identical (deterministic)
    diff_eval = jnp.mean(jnp.abs(out_eval1 - out_eval2))
    print(f"  Mean absolute difference between passes: {diff_eval:.4f}")
    print(f"  Outputs are deterministic: {jnp.allclose(out_eval1, out_eval2)}")
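
Batch normalization is the other layer whose behavior depends on the mode. During training it normalizes with the current batch's statistics and updates running estimates. During evaluation it uses those running estimates instead, which makes its output deterministic. Below is a one-feature, stdlib-only sketch of that logic. The momentum-weighted running average is a common convention, and the momentum value here is an assumption, not BrainState's default:

```python
import math


class TinyBatchNorm:
    """One-feature batch-norm sketch: batch stats when training, running stats in eval."""

    def __init__(self, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0

    def __call__(self, xs, training):
        if training:
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            # Running statistics are updated only while training
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mean
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in xs]


bn = TinyBatchNorm()
batch = [1.0, 2.0, 3.0, 4.0]
train_out = bn(batch, training=True)   # uses this batch's mean (2.5) and variance (1.25)
eval_out = bn(batch, training=False)   # uses the running estimates instead
print([round(v, 3) for v in train_out])
print(round(bn.running_mean, 3), [round(v, 3) for v in eval_out])
```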

### Training Context Manager

In [None]:
# Training loop with mode switching
def train_epoch(model, train_data, learning_rate=0.01):
    """Train for one epoch."""
    x_train, y_train = train_data
    
    # Ensure training mode
    with bst.environ.context(fit=True):
        def loss_fn():
            pred = model(x_train)
            return jnp.mean((pred - y_train) ** 2)
        
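        # With return_value=True, grad returns (gradients, loss value)
        param_states = model.states(bst.ParamState)
        grads, loss = bst.augment.grad(
            loss_fn,
            param_states,
            return_value=True
        )()
        
        # Update parameters. A state's value can be a nested dict
        # (e.g. a Linear layer's weight and bias), so update it leaf by leaf.
        for key, grad in grads.items():
            param_states[key].value = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, param_states[key].value, grad
            )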
        
        return float(loss)

def evaluate(model, test_data):
    """Evaluate model."""
    x_test, y_test = test_data
    
    # Ensure evaluation mode
    with bst.environ.context(fit=False):
        pred = model(x_test)
        loss = jnp.mean((pred - y_test) ** 2)
        return float(loss)

# Synthetic data: targets are random noise, so only the training loss is expected to decrease
x_train = bst.random.randn(100, 10)
y_train = bst.random.randn(100, 5)
x_test = bst.random.randn(20, 10)
y_test = bst.random.randn(20, 5)

# Train
model = TrainingAwareNetwork(10, 20, 5, dropout_rate=0.2)
train_losses = []
test_losses = []

for epoch in range(50):
    train_loss = train_epoch(model, (x_train, y_train))
    test_loss = evaluate(model, (x_test, y_test))
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

# Plot results
plt.figure(figsize=(10, 4))
plt.plot(train_losses, label='Train Loss', alpha=0.7)
plt.plot(test_losses, label='Test Loss', alpha=0.7)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training with Mode Switching')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
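
The parameter update in `train_epoch` maps over each state's value leaf by leaf. This is needed because a single parameter state may hold a nested structure, such as a dict with a weight and a bias, rather than one array. Below is a stdlib-only sketch of the same SGD rule, `p - lr * g` applied at every leaf. The parameter layout is hypothetical, and the tiny `tree_map` stands in for `jax.tree_util.tree_map`:

```python
def tree_map(fn, tree, *rest):
    """Apply `fn` leaf-wise over matching nested dicts (a tiny stand-in for jax.tree_util.tree_map)."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    return fn(tree, *rest)


def sgd_step(params, grads, lr):
    """Return new parameters with p - lr * g at every leaf; inputs are not modified."""
    return tree_map(lambda p, g: p - lr * g, params, grads)


# Hypothetical layout: one entry per layer, each holding a weight and a bias
params = {"linear1": {"weight": 0.5, "bias": 0.0}}
grads = {"linear1": {"weight": 1.0, "bias": -2.0}}
new_params = sgd_step(params, grads, lr=0.1)
print(new_params)  # {'linear1': {'weight': 0.4, 'bias': 0.2}}
```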

## 3. Batching Mode

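In BrainState, `Batching` is a `Mode` subclass that marks a computation as batched. The example below imitates that idea with a custom `batch` flag set through `bst.environ`. The flag name is our own choice, not a built-in.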

In [None]:
# Batching mode example
class BatchAwareLayer(bst.graph.Node):
    """Layer that adapts to batching mode."""
    
    def __init__(self, features):
        super().__init__()
        self.features = features
        self.weight = bst.ParamState(bst.random.randn(features, features) * 0.1)
    
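    def __call__(self, x):
        # Read the custom 'batch' flag (defaults to batched input)
        batch_mode = bst.environ.get('batch', default=True)
        
        # The flag tells the layer which input rank to expect
        expected_ndim = 2 if batch_mode else 1
        if x.ndim != expected_ndim:
            raise ValueError(
                f"expected a {expected_ndim}-D input when batch={batch_mode}, "
                f"got shape {x.shape}"
            )
        
        # `@` handles both (batch, features) and (features,) inputs
        return x @ self.weight.value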

layer = BatchAwareLayer(features=5)

# Test with batch
x_batch = bst.random.randn(10, 5)
print("Batched input:")
with bst.environ.context(batch=True):
    out_batch = layer(x_batch)
    print(f"  Input shape: {x_batch.shape}")
    print(f"  Output shape: {out_batch.shape}")

# Test with single sample
x_single = bst.random.randn(5)
print("\nSingle sample:")
with bst.environ.context(batch=False):
    out_single = layer(x_single)
    print(f"  Input shape: {x_single.shape}")
    print(f"  Output shape: {out_single.shape}")
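
A mode object, unlike a plain boolean flag, can carry extra data such as the batch size. That helps when a layer must allocate state before it sees any input. The sketch below is hypothetical: `BatchingSketch` and `state_shape` are illustrative names, not BrainState API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchingSketch:
    """Hypothetical mode object recording that computation is batched."""
    batch_size: int = 1
    batch_axis: int = 0


def state_shape(features, mode=None):
    """Shape for a per-layer state: add a batch axis only in batching mode."""
    if isinstance(mode, BatchingSketch):
        shape = [features]
        shape.insert(mode.batch_axis, mode.batch_size)
        return tuple(shape)
    return (features,)


print(state_shape(5))                                   # (5,)
print(state_shape(5, BatchingSketch(10)))               # (10, 5)
print(state_shape(5, BatchingSketch(3, batch_axis=1)))  # (5, 3)
```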

## 4. JointMode: Combining Multiple Modes

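`JointMode` combines several mode objects, for example a `Batching` mode and a `Training` mode, into a single mode. With environment flags, the equivalent is to set several flags in one `environ.context(...)` call, as in the example below.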

In [None]:
# Multiple modes simultaneously
class MultiModalLayer(bst.graph.Node):
    """Layer that responds to multiple mode flags."""
    
    def __init__(self, features, dropout_rate=0.3):
        super().__init__()
        self.features = features
        self.dropout_rate = dropout_rate
        self.weight = bst.ParamState(bst.random.randn(features, features) * 0.1)
        self.call_count = bst.ShortTermState(jnp.array(0))
    
    def __call__(self, x):
        # Get mode flags
        is_training = bst.environ.get('fit', default=True)
        is_batched = bst.environ.get('batch', default=True)
        verbose = bst.environ.get('verbose', default=False)
        
        if verbose:
            print(f"  Mode: training={is_training}, batched={is_batched}")
        
        # Increment counter
        self.call_count.value += 1
        
        # Apply transformation
        output = x @ self.weight.value
        
        # Apply dropout if training
        if is_training:
            keep_prob = 1 - self.dropout_rate
            mask = bst.random.bernoulli(keep_prob, output.shape)
            output = jnp.where(mask, output / keep_prob, 0)
        
        return output

layer = MultiModalLayer(features=8)
x = bst.random.randn(5, 8)

# Test different mode combinations
print("Mode Combinations:")
print("=" * 60)

# Training + Batched
print("\n1. Training + Batched:")
with bst.environ.context(fit=True, batch=True, verbose=True):
    out = layer(x)
    print(f"   Output shape: {out.shape}")

# Evaluation + Batched
print("\n2. Evaluation + Batched:")
with bst.environ.context(fit=False, batch=True, verbose=True):
    out = layer(x)
    print(f"   Output shape: {out.shape}")

# Evaluation + Single
print("\n3. Evaluation + Single sample:")
with bst.environ.context(fit=False, batch=False, verbose=True):
    out = layer(x[0])
    print(f"   Output shape: {out.shape}")

print(f"\nTotal calls to layer: {layer.call_count.value}")
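
A joint mode can be pictured as a container whose membership test looks at each of its components. The toy classes below (`ToyMode`, `ToyTraining`, `ToyBatching`, `ToyJointMode`) are a hypothetical re-implementation for illustration, not BrainState's source. The assumed semantics: a joint mode `has` a mode type if any of its components is an instance of that type.

```python
class ToyMode:
    def has(self, mode_type):
        return isinstance(self, mode_type)


class ToyTraining(ToyMode):
    pass


class ToyBatching(ToyMode):
    def __init__(self, batch_size=1):
        self.batch_size = batch_size


class ToyJointMode(ToyMode):
    """Bundle several modes; `has` is true if any component matches."""

    def __init__(self, *modes):
        self.modes = tuple(modes)

    def has(self, mode_type):
        return any(isinstance(m, mode_type) for m in self.modes)


mode = ToyJointMode(ToyBatching(batch_size=32), ToyTraining())
print(mode.has(ToyTraining), mode.has(ToyBatching))  # True True
print(ToyJointMode(ToyBatching()).has(ToyTraining))  # False
```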

## 5. Custom Mixins

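Custom mixins package a reusable behavior, such as a debug switch, that any layer can inherit. These flags are read by ordinary Python code. Inside a JIT-compiled function they are evaluated when the function is traced, not on every call. For the same reason, Python `print` calls run only during tracing; use `jax.debug.print` to print values at run time.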

In [None]:
# Custom mixin for debugging
class DebugMode:
    """Mixin for debug-aware components."""
    
    @staticmethod
    def is_debug():
        return bst.environ.get('debug', default=False)
    
    @staticmethod
    def debug_print(msg):
        if DebugMode.is_debug():
            print(f"[DEBUG] {msg}")

class DebugLayer(bst.graph.Node, DebugMode):
    """Layer with debug capabilities."""
    
    def __init__(self, features):
        super().__init__()
        self.features = features
        self.weight = bst.ParamState(bst.random.randn(features, features) * 0.1)
        self.debug_print(f"Initialized DebugLayer with {features} features")
    
    def __call__(self, x):
        self.debug_print(f"Input shape: {x.shape}")
        self.debug_print(f"Weight stats: mean={jnp.mean(self.weight.value):.3f}, "
                        f"std={jnp.std(self.weight.value):.3f}")
        
        output = x @ self.weight.value
        
        self.debug_print(f"Output shape: {output.shape}")
        self.debug_print(f"Output stats: mean={jnp.mean(output):.3f}, "
                        f"std={jnp.std(output):.3f}")
        
        return output

# Test without debug
print("Normal mode (no debug output):")
layer = DebugLayer(features=5)
x = bst.random.randn(3, 5)
out = layer(x)

print("\n" + "=" * 60)
print("Debug mode enabled:")
with bst.environ.context(debug=True):
    layer2 = DebugLayer(features=5)
    out = layer2(x)

### Custom Precision Mode

In [None]:
# Precision mode mixin
class PrecisionMode:
    """Control computation precision."""
    
    @staticmethod
    def get_dtype():
        precision = bst.environ.get('precision', default='float32')
        if precision == 'float16':
            return jnp.float16
        elif precision == 'float64':
            return jnp.float64
        else:
            return jnp.float32

class PrecisionAwareLayer(bst.graph.Node, PrecisionMode):
    """Layer that adapts to precision mode."""
    
    def __init__(self, features):
        super().__init__()
        dtype = self.get_dtype()
        self.weight = bst.ParamState(
            bst.random.randn(features, features).astype(dtype) * 0.1
        )
    
    def __call__(self, x):
        dtype = self.get_dtype()
        x = x.astype(dtype)
        return x @ self.weight.value

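# Test different precisions.
# float64 needs 64-bit mode (jax.config.update("jax_enable_x64", True));
# without it, JAX falls back to float32 and emits a warning.
x = bst.random.randn(3, 4)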

print("Precision Modes:")
print("=" * 60)

for precision in ['float16', 'float32', 'float64']:
    with bst.environ.context(precision=precision):
        layer = PrecisionAwareLayer(features=4)
        out = layer(x)
        print(f"{precision}: output dtype = {out.dtype}, "
              f"memory = {out.nbytes} bytes")
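
The `memory` column is the element count times the bytes per element. The (3, 4) output has 12 elements, stored at 2, 4, or 8 bytes each. The stdlib check below uses `struct.calcsize` with the half (`e`), single (`f`) and double (`d`) format codes. If 64-bit mode is off, the float64 row in the cell above will report float32 and 48 bytes instead.

```python
import struct

# Bytes per element for IEEE half, single and double precision
itemsize = {name: struct.calcsize(code)
            for name, code in [("float16", "e"), ("float32", "f"), ("float64", "d")]}

n_elements = 3 * 4  # output shape (3, 4) from the cell above
for name, size in itemsize.items():
    print(f"{name}: {n_elements} x {size} B = {n_elements * size} bytes")
# float16: 12 x 2 B = 24 bytes
# float32: 12 x 4 B = 48 bytes
# float64: 12 x 8 B = 96 bytes
```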

## 6. Practical Example: Multi-Mode Network

In [None]:
# Complete example with multiple modes
class ProductionNetwork(bst.graph.Node):
    """Production-ready network with multiple modes."""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = bst.nn.Linear(input_dim, hidden_dim)
        self.bn1 = bst.nn.BatchNorm1d(hidden_dim)
        self.dropout1 = bst.nn.Dropout(0.7)  # keep probability 0.7, i.e. drop 30%
        
        self.layer2 = bst.nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = bst.nn.BatchNorm1d(hidden_dim)
        self.dropout2 = bst.nn.Dropout(0.7)  # keep probability 0.7, i.e. drop 30%
        
        self.layer3 = bst.nn.Linear(hidden_dim, output_dim)
        
        # Statistics
        self.forward_count = bst.ShortTermState(jnp.array(0))
    
    def __call__(self, x):
        # Track calls
        self.forward_count.value += 1
        
        # Debug info
        if bst.environ.get('debug', default=False):
            is_training = bst.environ.get('fit', default=True)
            mode_str = "TRAIN" if is_training else "EVAL"
            print(f"[{mode_str}] Forward pass #{self.forward_count.value}")
        
        # Layer 1
        x = self.layer1(x)
        x = self.bn1(x)
        x = jax.nn.relu(x)
        x = self.dropout1(x)
        
        # Layer 2
        x = self.layer2(x)
        x = self.bn2(x)
        x = jax.nn.relu(x)
        x = self.dropout2(x)
        
        # Output
        x = self.layer3(x)
        
        return x
    
    def reset_stats(self):
        """Reset statistics."""
        self.forward_count.value = jnp.array(0)

# Create network
net = ProductionNetwork(input_dim=20, hidden_dim=50, output_dim=10)
x_train = bst.random.randn(32, 20)
x_test = bst.random.randn(8, 20)

print("Production Network Demo:")
print("=" * 60)

# Training mode
print("\n1. Training (with debug):")
with bst.environ.context(fit=True, debug=True):
    out_train = net(x_train)

# Evaluation mode
print("\n2. Evaluation (with debug):")
with bst.environ.context(fit=False, debug=True):
    out_test = net(x_test)

# Silent evaluation
print("\n3. Silent evaluation:")
with bst.environ.context(fit=False, debug=False):
    for _ in range(5):
        _ = net(x_test)
    print(f"   Completed 5 forward passes silently")

print(f"\nTotal forward passes: {net.forward_count.value}")

## 7. Mode Inheritance and Composition

In [None]:
# Composing multiple custom modes
class ProfilingMode:
    """Profiling mode for performance analysis."""
    
    @staticmethod
    def should_profile():
        return bst.environ.get('profile', default=False)

class LoggingMode:
    """Logging mode for tracking operations."""
    
    @staticmethod
    def should_log():
        return bst.environ.get('log', default=False)

class AdvancedLayer(bst.graph.Node, ProfilingMode, LoggingMode):
    """Layer with profiling and logging."""
    
    def __init__(self, features):
        super().__init__()
        self.features = features
        self.weight = bst.ParamState(bst.random.randn(features, features) * 0.1)
    
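    def __call__(self, x):
        import time
        
        # Logging
        if self.should_log():
            print(f"[LOG] Processing input of shape {x.shape}")
        
        # Profiling: read the flag once so start and stop agree;
        # perf_counter is the clock meant for short intervals
        profile = self.should_profile()
        if profile:
            start = time.perf_counter()
        
        # Computation
        output = x @ self.weight.value
        
        # Profiling: JAX dispatches work asynchronously, so wait for
        # the result before stopping the clock
        if profile:
            output.block_until_ready()
            elapsed = (time.perf_counter() - start) * 1000
            print(f"[PROFILE] Computation took {elapsed:.3f} ms")
        
        # Logging
        if self.should_log():
            print(f"[LOG] Output shape: {output.shape}")
        
        return output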

layer = AdvancedLayer(features=100)
x = bst.random.randn(50, 100)

print("Normal mode:")
_ = layer(x)
print("  (no output)")

print("\nWith logging:")
with bst.environ.context(log=True):
    _ = layer(x)

print("\nWith profiling:")
with bst.environ.context(profile=True):
    _ = layer(x)

print("\nWith both:")
with bst.environ.context(log=True, profile=True):
    _ = layer(x)
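
The start/stop timing pattern can be factored into a reusable, flag-gated context manager. `time.perf_counter` is the clock intended for short intervals. When the timed block produces JAX arrays, it should also call `.block_until_ready()` on them, because JAX dispatches work asynchronously. Below is a stdlib sketch; `maybe_profile` is a hypothetical name:

```python
import time
from contextlib import contextmanager


@contextmanager
def maybe_profile(enabled, label, sink):
    """When `enabled`, time the enclosed block and append (label, milliseconds) to `sink`."""
    if not enabled:
        yield
        return
    start = time.perf_counter()
    try:
        yield
    finally:
        sink.append((label, (time.perf_counter() - start) * 1000.0))


timings = []
with maybe_profile(True, "sum", timings):
    total = sum(range(100_000))
with maybe_profile(False, "skipped", timings):
    total += 1
print(timings)  # a single ('sum', <milliseconds>) entry
```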

## 8. Best Practices and Patterns

In [None]:
# Best practice: Mode-aware configuration
class ConfigurableModel(bst.graph.Node):
    """Model with mode-based configuration."""
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Build layers based on config
        self.layers = []
        for i, (in_dim, out_dim) in enumerate(zip(
            config['layer_dims'][:-1],
            config['layer_dims'][1:]
        )):
            layer = bst.nn.Linear(in_dim, out_dim)
            self.layers.append(layer)
            setattr(self, f'layer_{i}', layer)
        
        # Optional components based on config
        if config.get('use_dropout', False):
            self.dropout = bst.nn.Dropout(1 - config['dropout_rate'])  # Dropout expects the keep probability
        else:
            self.dropout = None
        
        if config.get('use_batchnorm', False):
            self.batchnorms = [
                bst.nn.BatchNorm1d(dim) 
                for dim in config['layer_dims'][1:-1]
            ]
            for i, bn in enumerate(self.batchnorms):
                setattr(self, f'bn_{i}', bn)
        else:
            self.batchnorms = None
    
    def __call__(self, x):
        for i, layer in enumerate(self.layers[:-1]):
            x = layer(x)
            
            # Optional batch norm
            if self.batchnorms is not None:
                x = self.batchnorms[i](x)
            
            x = jax.nn.relu(x)
            
            # Optional dropout
            if self.dropout is not None:
                x = self.dropout(x)
        
        # Final layer
        x = self.layers[-1](x)
        return x

# Different configurations
configs = {
    'simple': {
        'layer_dims': [10, 20, 5],
        'use_dropout': False,
        'use_batchnorm': False
    },
    'regularized': {
        'layer_dims': [10, 20, 15, 5],
        'use_dropout': True,
        'dropout_rate': 0.3,
        'use_batchnorm': True
    }
}

print("Model Configurations:")
print("=" * 60)

for name, config in configs.items():
    model = ConfigurableModel(config)
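    # A state's value may be a nested dict (e.g. weight and bias), so count every leaf
    n_params = sum(
        leaf.size
        for p in model.states(bst.ParamState).values()
        for leaf in jax.tree_util.tree_leaves(p.value)
    )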
    print(f"\n{name.capitalize()} model:")
    print(f"  Layers: {config['layer_dims']}")
    print(f"  Dropout: {config.get('use_dropout', False)}")
    print(f"  BatchNorm: {config.get('use_batchnorm', False)}")
    print(f"  Parameters: {n_params:,}")
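
The printed counts can be checked by hand. The sketch below assumes that a Linear layer stores a weight matrix (in x out) and a bias vector (out). It also assumes that each hidden batch-norm layer, with its default affine option, stores one scale and one shift per feature. If the library's defaults differ, the printed count will differ too. `count_params` is a hypothetical helper:

```python
def count_params(layer_dims, use_batchnorm=False):
    """Count trainable scalars under the Linear / batch-norm assumptions stated above."""
    linear = sum(i * o + o for i, o in zip(layer_dims[:-1], layer_dims[1:]))
    # Batch norm is applied to every hidden layer, i.e. layer_dims[1:-1]
    norm = sum(2 * d for d in layer_dims[1:-1]) if use_batchnorm else 0
    return linear + norm


print(count_params([10, 20, 5]))                          # 325 = (200 + 20) + (100 + 5)
print(count_params([10, 20, 15, 5], use_batchnorm=True))  # 685 = 615 + 2 * (20 + 15)
```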

## Summary

In this tutorial, we covered:

1. **Mode System**: Foundation for controlling computation behavior
2. **Training Mode**: Train/eval switching for dropout and batch norm
3. **Batching Mode**: Handling batched vs single inputs
4. **JointMode**: Combining multiple modes simultaneously
5. **Custom Mixins**: Creating specialized behavior (debug, precision, profiling)
6. **Production Patterns**: Multi-mode networks
7. **Mode Composition**: Combining multiple mixins
8. **Best Practices**: Configuration-based models

## Key Takeaways

- **Modes control behavior** without changing code structure
- Use **context managers** to temporarily change modes
- **Training mode** affects dropout, batch normalization
- **Custom mixins** enable specialized functionality
- Modes are **composable** - combine multiple behaviors
- **environ.context** is the primary interface for mode control

## Best Practices

1. Always use context managers for temporary mode changes
2. Check mode flags explicitly when behavior differs
3. Document which modes your components respond to
4. Create custom mixins for reusable behavior patterns
5. Use meaningful mode names (fit, debug, profile, etc.)
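6. Design components to work correctly in all modes, and remember that flags are read in Python: under JIT compilation they take effect when a function is traced, not on every call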

## Next Steps

In the next tutorial, we'll explore:
- **Type System**: OneOfTypes, JointTypes, type annotations
- Type checking and validation
- Best practices for type hints in neural networks