# Tutorial 17: Utility Functions and Tools

In this final advanced tutorial, we'll explore BrainState's utility functions and helper tools that make working with neural networks more convenient and productive.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Select states by type with filters
- Flatten, unflatten, and merge nested dictionaries
- Create readable outputs with PrettyObject
- Manage dictionaries with DictManager
- Use specialized dict types (FrozenDict, DotDict, etc.)
- Apply utility functions for common tasks
- Leverage helper functions for debugging

## Why Utilities Matter

Utility functions provide:
- **Convenience**: Common operations made easy
- **Readability**: Clean, expressive code
- **Debugging**: Better introspection and visualization
- **Productivity**: Less boilerplate code

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Any

# Set random seed
bst.random.seed(42)

## 1. Filter Utilities

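Passing a state type to `.states()` acts as a filter: only states of that type are returned. This lets you pick out, for example, just the trainable parameters (`bst.ParamState`) or just the short-term states (`bst.ShortTermState`).

In [None]: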
# Create a model
class SampleModel(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.layer1 = bst.nn.Linear(10, 20)
        self.layer2 = bst.nn.Linear(20, 10)
        self.dropout = bst.nn.Dropout(0.3)
        self.counter = bst.ShortTermState(jnp.array(0))
    
    def __call__(self, x):
        self.counter.value += 1
        x = self.layer1(x)
        x = jax.nn.relu(x)
        x = self.dropout(x)
        x = self.layer2(x)
        return x

model = SampleModel()

# Get all states
all_states = model.states()
print(f"Total states: {len(all_states)}")
print(f"State names: {list(all_states.keys())[:5]}...")  # Show first 5

# Filter only parameters
params = model.states(bst.ParamState)
print(f"\nParameters: {len(params)}")

# Filter only hidden states
hidden = model.states(bst.ShortTermState)
print(f"Hidden states: {len(hidden)}")
print(f"Hidden state names: {list(hidden.keys())}")
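Conceptually, a type-based filter is an `isinstance` check applied to every state. The plain-Python sketch below shows that idea. The classes `State`, `ParamState`, and `ShortTermState` and the `select_states` helper are hypothetical stand-ins, not brainstate's implementation.

```python
from typing import Dict, Tuple, Type


# Hypothetical stand-ins for a state hierarchy (not brainstate's classes).
class State:
    def __init__(self, value):
        self.value = value


class ParamState(State):
    pass


class ShortTermState(State):
    pass


def select_states(states: Dict[Tuple[str, ...], State],
                  state_type: Type[State] = State) -> Dict[Tuple[str, ...], State]:
    """Keep only the states that are instances of `state_type`."""
    return {path: s for path, s in states.items() if isinstance(s, state_type)}


states = {
    ('layer1', 'weight'): ParamState([0.0] * 4),
    ('layer2', 'weight'): ParamState([0.0] * 2),
    ('counter',): ShortTermState(0),
}

print(len(select_states(states)))                   # every state matches the base class
print(list(select_states(states, ShortTermState)))  # only the counter's path
```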

## 2. Structure Utilities

`brainstate.util` also provides helpers for flattening, unflattening, and merging nested dictionaries.

In [None]:
from brainstate.util import flatten_dict, unflatten_dict, merge_dicts

# Nested dictionary
nested_dict = {
    'model': {
        'layer1': {
            'weight': jnp.ones((10, 20)),
            'bias': jnp.zeros(20)
        },
        'layer2': {
            'weight': jnp.ones((20, 10)),
            'bias': jnp.zeros(10)
        }
    },
    'config': {
        'learning_rate': 0.01,
        'batch_size': 32
    }
}

# Flatten nested structure
flat_dict = flatten_dict(nested_dict, sep='/')
print("Flattened dictionary:")
for key in flat_dict.keys():
    print(f"  {key}")

# Unflatten back
unflat_dict = unflatten_dict(flat_dict, sep='/')
print("\nUnflattened keys:")
print(f"  {list(unflat_dict.keys())}")
print(f"  Top-level keys match original: {list(unflat_dict.keys()) == list(nested_dict.keys())}")
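Flattening joins each path of keys with a separator, and unflattening splits the keys and rebuilds the nesting. The stdlib sketch below shows this round trip. `flatten` and `unflatten` are hypothetical helpers, and we make no claim that they match the exact behavior of `flatten_dict` and `unflatten_dict`.

```python
from typing import Any, Dict


def flatten(nested: Dict[str, Any], sep: str = '/', prefix: str = '') -> Dict[str, Any]:
    """Join nested keys with `sep`; non-dict (or empty-dict) values become leaves."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict) and value:
            flat.update(flatten(value, sep, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat: Dict[str, Any], sep: str = '/') -> Dict[str, Any]:
    """Split each key on `sep` and rebuild the nested dictionaries."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


cfg = {'model': {'layer1': {'bias': 0.0}}, 'config': {'lr': 0.01}}
flat = flatten(cfg)
print(flat)                    # {'model/layer1/bias': 0.0, 'config/lr': 0.01}
print(unflatten(flat) == cfg)  # True
```

The round trip only works if no original key contains the separator. Pick a separator that cannot appear in your key names.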

### Merging Dictionaries

In [None]:
# Merge multiple dictionaries
dict1 = {'a': 1, 'b': 2}
dict2 = {'b': 3, 'c': 4}
dict3 = {'c': 5, 'd': 6}

merged = merge_dicts([dict1, dict2, dict3])
print("Merged dictionary:")
print(f"  {merged}")
print("\nNote: Later values override earlier ones")
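The "later values win" rule gives the same result as calling `dict.update` once per dictionary, in order. The hypothetical `merge_left_to_right` helper below is a shallow stdlib sketch of that rule. It does not claim to mirror `merge_dicts`'s signature, or to say whether that function merges nested dictionaries recursively.

```python
from typing import Any, Dict


def merge_left_to_right(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Shallow merge: for duplicate keys, the later dictionary wins."""
    merged: Dict[str, Any] = {}
    for d in dicts:
        merged.update(d)
    return merged


result = merge_left_to_right({'a': 1, 'b': 2}, {'b': 3, 'c': 4}, {'c': 5, 'd': 6})
print(result)  # {'a': 1, 'b': 3, 'c': 5, 'd': 6}
```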

## 3. PrettyObject: Readable Output

`PrettyObject` provides clean string representations of objects.

In [None]:
from brainstate.util import PrettyObject, PrettyAttr, PrettyType

class PrettyModel(PrettyObject):
    """Model with pretty printing."""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.layers = [
            {'type': 'Linear', 'in': input_dim, 'out': hidden_dim},
            {'type': 'ReLU'},
            {'type': 'Linear', 'in': hidden_dim, 'out': output_dim}
        ]
        # Weights plus biases of both linear layers
        self.n_parameters = (input_dim * hidden_dim + hidden_dim
                             + hidden_dim * output_dim + output_dim)

# Create and print
model = PrettyModel(10, 20, 5)
print("Pretty printed model:")
print(model)

# Custom representation
print("\nManual formatting:")
print(f"  Dimensions: {model.input_dim} → {model.hidden_dim} → {model.output_dim}")
print(f"  Parameters: {model.n_parameters:,}")
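A common way to get attribute-driven output like this is to build `__repr__` from `vars(self)`. The hypothetical `AutoReprMixin` below sketches that approach. It is not how `PrettyObject` is implemented and is much simpler: there is no nesting, truncation, or cycle handling.

```python
class AutoReprMixin:
    """Hypothetical mixin: builds __repr__ from the instance's attributes."""

    def __repr__(self) -> str:
        fields = ",\n".join(f"  {name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}(\n{fields}\n)"


class TinyModel(AutoReprMixin):
    def __init__(self, input_dim: int, output_dim: int):
        self.input_dim = input_dim
        self.output_dim = output_dim


print(TinyModel(10, 5))
```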

### Custom Formatting with `__repr__`

In [None]:
# Override __repr__ to summarize large attributes (like the weight array) instead of printing them in full
class Network(PrettyObject):
    """Network with custom pretty attributes."""
    
    def __init__(self):
        self.name = "MyNetwork"
        self.version = "1.0.0"
        self.config = {
            'learning_rate': 0.001,
            'batch_size': 32,
            'epochs': 100
        }
        self.weights = jnp.ones((100, 100))  # Large array
    
    def __repr__(self):
        lines = [
            f"{self.__class__.__name__}(",
            f"  name={self.name!r},",
            f"  version={self.version!r},",
            f"  config={self.config},",
            f"  weights=<{self.weights.shape} array>",
            ")"
        ]
        return "\n".join(lines)

net = Network()
print(net)

## 4. Specialized Dictionary Types

### FrozenDict: Immutable Dictionary

In [None]:
from brainstate.util import FrozenDict, freeze, unfreeze

# Create frozen dict
config = {'learning_rate': 0.01, 'batch_size': 32}
frozen_config = freeze(config)

print(f"Frozen config: {frozen_config}")
print(f"Type: {type(frozen_config)}")

# Try to modify (will fail)
try:
    frozen_config['learning_rate'] = 0.001
    print("Modification succeeded")
except (TypeError, AttributeError) as e:
    print(f"Cannot modify frozen dict (expected): {type(e).__name__}")

# Unfreeze to modify
mutable_config = unfreeze(frozen_config)
mutable_config['learning_rate'] = 0.001
print(f"\nModified unfrozen dict: {mutable_config}")
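For comparison, Python's standard library has a read-only mapping, `types.MappingProxyType`. It is only a *view* of the underlying dict, so changes to that dict still show through. That is one reason a dedicated frozen type is useful for configurations. The sketch below shows stdlib behavior only and says nothing about how `freeze` works internally.

```python
from types import MappingProxyType

config = {'learning_rate': 0.01, 'batch_size': 32}
read_only = MappingProxyType(config)  # read-only view of `config`

try:
    read_only['learning_rate'] = 0.001
except TypeError:
    print("MappingProxyType rejects item assignment")

# The view is live: changes to the underlying dict show through.
config['batch_size'] = 64
print(read_only['batch_size'])  # 64

# Copy into a fresh dict to get an independent, mutable version.
mutable = dict(read_only)
mutable['learning_rate'] = 0.001
print(read_only['learning_rate'])  # 0.01
```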

### DotDict: Attribute Access

In [None]:
from brainstate.util import DotDict

# Create dot-accessible dict
config = DotDict({
    'model': {
        'type': 'MLP',
        'hidden_dim': 256
    },
    'training': {
        'learning_rate': 0.001,
        'epochs': 100
    }
})

# Access with dot notation
print("Dot notation access:")
print(f"  Model type: {config.model.type}")
print(f"  Hidden dim: {config.model.hidden_dim}")
print(f"  Learning rate: {config.training.learning_rate}")

# Still works as dict
print(f"\nDict access: {config['model']['type']}")

# Modify
config.training.learning_rate = 0.01
print(f"Modified LR: {config.training.learning_rate}")
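Dot access typically works by routing failed attribute lookups to the dictionary's items and wrapping nested dicts so the chain continues. The hypothetical `AttrDict` below is a minimal sketch of that mechanism, not `DotDict`'s actual code.

```python
from typing import Any


class AttrDict(dict):
    """Hypothetical minimal dict with attribute access; nested dicts are wrapped too."""

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Wrap nested plain dicts so chained access (cfg.a.b) keeps working.
        for key, value in list(self.items()):
            if isinstance(value, dict) and not isinstance(value, AttrDict):
                self[key] = AttrDict(value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value


cfg = AttrDict({'training': {'learning_rate': 0.001, 'epochs': 100}})
cfg.training.learning_rate = 0.01
print(cfg['training']['learning_rate'])  # 0.01
```

With this design, keys that collide with dict methods (such as `items` or `keys`) must still be read with `cfg['items']`, because normal attribute lookup finds the method first.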

### FlattedDict: Flat Key Access

In [None]:
from brainstate.util import FlattedDict

# Create flatted dict
nested = {
    'a': {
        'b': {
            'c': 1,
            'd': 2
        },
        'e': 3
    },
    'f': 4
}

flatted = FlattedDict(nested, sep='.')

print("Flatted dict keys:")
for key in flatted.keys():
    print(f"  {key}: {flatted[key]}")

# Access with flat keys
print(f"\nAccess 'a.b.c': {flatted['a.b.c']}")
print(f"Access 'a.e': {flatted['a.e']}")
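A flat key such as `'a.b.c'` is simply a path through the nested dictionary. The hypothetical `get_path` helper below resolves such a path directly with `functools.reduce`. This is handy for a single lookup when you do not need a flattened view of the whole structure.

```python
from functools import reduce
from typing import Any, Mapping


def get_path(nested: Mapping[str, Any], path: str, sep: str = '.') -> Any:
    """Follow `path` (e.g. 'a.b.c') through nested mappings."""
    return reduce(lambda node, key: node[key], path.split(sep), nested)


nested = {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}, 'f': 4}
print(get_path(nested, 'a.b.c'))  # 1
print(get_path(nested, 'a.e'))    # 3
```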

## 5. DictManager: Advanced Dictionary Management

In [None]:
from brainstate.util import DictManager

# DictManager for managing model states
class ManagedModel(bst.graph.Node):
    """Model with managed state dictionaries."""
    
    def __init__(self):
        super().__init__()
        self.linear = bst.nn.Linear(10, 5)
        
        # Create dict manager
        self.manager = DictManager()
    
    def save_checkpoint(self, name: str):
        """Save model state."""
        states = {k: v.value for k, v in self.states(bst.ParamState).items()}
        self.manager[name] = states
        print(f"Saved checkpoint '{name}' with {len(states)} states")
    
    def load_checkpoint(self, name: str):
        """Load model state."""
        if name not in self.manager:
            raise KeyError(f"Checkpoint '{name}' not found")
        
        states = self.manager[name]
        param_states = self.states(bst.ParamState)  # look up the parameter states once
        for key, value in states.items():
            if key in param_states:
                param_states[key].value = value
        
        print(f"Loaded checkpoint '{name}'")
    
    def list_checkpoints(self):
        """List saved checkpoints."""
        return list(self.manager.keys())

# Test checkpoint management
model = ManagedModel()

# Save initial state
model.save_checkpoint('init')

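# Modify parameters; tree_map handles states whose value is a pytree (e.g. a weight/bias dict)
for state in model.states(bst.ParamState).values():
    state.value = jax.tree_util.tree_map(lambda v: v * 2, state.value)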

# Save modified state
model.save_checkpoint('modified')

# List checkpoints
print(f"\nCheckpoints: {model.list_checkpoints()}")

# Restore initial
model.load_checkpoint('init')
print("Restored to initial state")
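This checkpoint pattern is safe here because the update assigns a *new* value to each state, and JAX arrays cannot be modified in place. If a snapshot held mutable Python containers that were later edited in place, the snapshot would change along with them. The plain-Python sketch below therefore copies on save and on load. `CheckpointStore` is a hypothetical illustration, not part of brainstate.

```python
import copy
from typing import Any, Dict, List


class CheckpointStore:
    """Hypothetical store of named parameter snapshots (plain Python)."""

    def __init__(self) -> None:
        self._snapshots: Dict[str, Dict[str, Any]] = {}

    def save(self, name: str, params: Dict[str, Any]) -> None:
        # Deep-copy so later in-place edits to `params` cannot alter the snapshot.
        self._snapshots[name] = copy.deepcopy(params)

    def load(self, name: str) -> Dict[str, Any]:
        if name not in self._snapshots:
            raise KeyError(f"Checkpoint '{name}' not found")
        # Return a copy so callers cannot corrupt the stored snapshot either.
        return copy.deepcopy(self._snapshots[name])

    def names(self) -> List[str]:
        return list(self._snapshots)


params = {'linear': {'weight': [1.0, 2.0], 'bias': [0.0]}}
store = CheckpointStore()
store.save('init', params)

params['linear']['weight'][0] = 99.0           # in-place edit after saving
print(store.load('init')['linear']['weight'])  # [1.0, 2.0]
```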

## 6. Unique Name Generation

In [None]:
from brainstate.util import get_unique_name

# Generate unique names
existing_names = set()

print("Generating unique names:")
for i in range(10):
    name = get_unique_name('layer', existing_names)
    existing_names.add(name)
    print(f"  {i}: {name}")

# Try to generate more with same prefix
print("\nMore names with existing set:")
for i in range(3):
    name = get_unique_name('layer', existing_names)
    existing_names.add(name)
    print(f"  {name}")
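One simple strategy for unique names is to append the smallest counter that is not already taken. In the hypothetical `unique_name` helper below, the `<prefix>_<n>` format is an assumption; `get_unique_name` may format names differently.

```python
from itertools import count
from typing import Set


def unique_name(prefix: str, taken: Set[str]) -> str:
    """Return the first '<prefix>_<n>' (n = 0, 1, ...) not already in `taken`."""
    candidates = (f"{prefix}_{n}" for n in count())
    return next(name for name in candidates if name not in taken)


taken = {'layer_0', 'layer_2'}
print(unique_name('layer', taken))  # layer_1
```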

## 7. Memory Management

In [None]:
from brainstate.util import clear_buffer_memory

# Memory management utility
class BufferedModel(bst.graph.Node):
    """Model with internal buffers."""
    
    def __init__(self):
        super().__init__()
        self.layer = bst.nn.Linear(100, 100)
        # Buffers for intermediate computations
        self.buffers = []
    
    def __call__(self, x):
        # Store intermediate values
        intermediate = self.layer(x)
        self.buffers.append(intermediate)
        
        # Keep only last 10
        if len(self.buffers) > 10:
            self.buffers.pop(0)
        
        return jax.nn.relu(intermediate)
    
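    def clear_buffers(self):
        """Clear internal buffers and release JAX device memory."""
        self.buffers.clear()
        # clear_buffer_memory() frees on-device JAX arrays globally, not just
        # the ones in self.buffers; arrays it deletes (possibly including this
        # model's parameters) cannot be used afterwards.
        clear_buffer_memory()
        print("Buffers cleared")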

# Test buffering
model = BufferedModel()

# Process many batches
for i in range(15):
    x = bst.random.randn(32, 100)
    _ = model(x)

print(f"Accumulated {len(model.buffers)} buffers")
model.clear_buffers()
print(f"Buffers after clear: {len(model.buffers)}")
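Keeping only the last N items fits the standard library's `collections.deque(maxlen=N)` naturally. It drops the oldest entry automatically and in constant time, whereas `list.pop(0)` shifts every remaining element. The sketch uses integers in place of arrays.

```python
from collections import deque

# A deque with maxlen discards the oldest item automatically when full,
# replacing the manual `if len(...) > 10: pop(0)` bookkeeping.
buffers = deque(maxlen=10)
for step in range(15):
    buffers.append(step)

print(len(buffers))  # 10
print(buffers[0])    # 5  (steps 0-4 were dropped)
```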

## 8. Practical Example: Model Registry

In [None]:
# Complete example using multiple utilities
from brainstate.util import (
    DictManager, PrettyObject, DotDict, 
    flatten_dict, get_unique_name
)

class ModelRegistry(PrettyObject):
    """Registry for managing multiple models."""
    
    def __init__(self):
        self.models = DictManager()
        self.configs = DictManager()
        self.model_names = set()
    
    def register(self, model: bst.graph.Node, config: dict, name: str = None):
        """Register a model with configuration.
        
        Args:
            model: Model to register
            config: Model configuration
            name: Optional name (auto-generated if None)
        """
        if name is None:
            name = get_unique_name('model', self.model_names)
        
        self.model_names.add(name)
        self.models[name] = model
        self.configs[name] = DotDict(config)
        
        print(f"Registered model '{name}'")
        return name
    
    def get_model(self, name: str):
        """Get registered model."""
        if name not in self.models:
            raise KeyError(f"Model '{name}' not found")
        return self.models[name]
    
    def get_config(self, name: str):
        """Get model configuration."""
        if name not in self.configs:
            raise KeyError(f"Config for '{name}' not found")
        return self.configs[name]
    
    def list_models(self):
        """List all registered models."""
        return list(self.models.keys())
    
    def summary(self):
        """Print registry summary."""
        print("=" * 60)
        print("Model Registry Summary")
        print("=" * 60)
        
        for name in self.list_models():
            model = self.models[name]
            config = self.configs[name]
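            # Count every leaf, since a ParamState may hold a pytree (e.g. weight and bias)
            n_params = sum(
                leaf.size
                for p in model.states(bst.ParamState).values()
                for leaf in jax.tree_util.tree_leaves(p.value)
            )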
            
            print(f"\n{name}:")
            print(f"  Type: {model.__class__.__name__}")
            print(f"  Parameters: {n_params:,}")
            print(f"  Config: {dict(flatten_dict(dict(config), sep='.'))}") 
        
        print("\n" + "=" * 60)

# Create registry
registry = ModelRegistry()

# Register models
model1 = bst.nn.Linear(10, 5)
registry.register(
    model1,
    config={'type': 'linear', 'input_dim': 10, 'output_dim': 5},
    name='small_linear'
)

model2 = bst.nn.Linear(100, 50)
registry.register(
    model2,
    config={'type': 'linear', 'input_dim': 100, 'output_dim': 50},
    name='large_linear'
)

# Auto-generated name
model3 = bst.nn.Linear(20, 20)
auto_name = registry.register(
    model3,
    config={'type': 'linear', 'input_dim': 20, 'output_dim': 20}
)

# Show summary
registry.summary()

# Access specific model
print("\nRetrieved model config:")
config = registry.get_config('small_linear')
print(f"  Input dim: {config.input_dim}")
print(f"  Output dim: {config.output_dim}")

## 9. Debugging Utilities

In [None]:
from brainstate.util import pretty_repr

# Pretty representation for debugging
class DebugModel(bst.graph.Node):
    """Model with debug utilities."""
    
    def __init__(self):
        super().__init__()
        self.layer1 = bst.nn.Linear(10, 20)
        self.layer2 = bst.nn.Linear(20, 5)
        self.activations = []
    
    def __call__(self, x, debug=False):
        if debug:
            print(f"Input shape: {x.shape}")
            print(f"Input stats: mean={jnp.mean(x):.3f}, std={jnp.std(x):.3f}")
        
        x = self.layer1(x)
        if debug:
            print(f"After layer1: shape={x.shape}, mean={jnp.mean(x):.3f}")
        
        x = jax.nn.relu(x)
        self.activations.append(x)
        
        x = self.layer2(x)
        if debug:
            print(f"After layer2: shape={x.shape}, mean={jnp.mean(x):.3f}")
        
        return x
    
    def debug_info(self):
        """Get debug information."""
        info = {
            'n_activations': len(self.activations),
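            # Count every leaf, since a ParamState may hold a pytree (e.g. weight and bias)
            'layer1_params': sum(a.size for a in jax.tree_util.tree_leaves(self.layer1.weight.value)),
            'layer2_params': sum(a.size for a in jax.tree_util.tree_leaves(self.layer2.weight.value)),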
        }
        
        if self.activations:
            latest = self.activations[-1]
            info['latest_activation'] = {
                'shape': latest.shape,
                'mean': float(jnp.mean(latest)),
                'std': float(jnp.std(latest)),
                'min': float(jnp.min(latest)),
                'max': float(jnp.max(latest)),
            }
        
        return info

# Test debug model
model = DebugModel()
x = bst.random.randn(4, 10)

print("Debug mode:")
print("=" * 60)
output = model(x, debug=True)

print("\n" + "=" * 60)
print("Debug info:")
info = model.debug_info()
print(pretty_repr(info))
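`debug_info` reduces an activation to a handful of summary statistics. The same kind of summary can be computed without JAX, which is convenient when logging values that have already been pulled back to the host. `summarize` is a hypothetical helper built on the standard `statistics` module.

```python
import statistics
from typing import Dict, Sequence


def summarize(values: Sequence[float]) -> Dict[str, float]:
    """Summary statistics for a flat sequence of numbers."""
    return {
        'count': len(values),
        'mean': statistics.fmean(values),
        'std': statistics.pstdev(values),  # population std (ddof=0), like jnp.std's default
        'min': min(values),
        'max': max(values),
    }


print(summarize([0.0, 1.0, 2.0, 3.0]))
```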

## Summary

In this tutorial, we covered:

1. **Filter Utilities**: Selective state operations
2. **Structure Utilities**: flatten_dict, unflatten_dict, merge_dicts
3. **PrettyObject**: Readable object representations
4. **Specialized Dicts**: FrozenDict, DotDict, FlattedDict
5. **DictManager**: Advanced dictionary management
6. **Unique Names**: Name generation utilities
7. **Memory Management**: Buffer clearing
8. **Model Registry**: Practical multi-utility example
9. **Debugging**: pretty_repr and debug utilities

## Key Takeaways

- **Utilities save time** and reduce boilerplate
- **PrettyObject** makes debugging easier
- **Specialized dicts** provide convenient access patterns
- **Structure utilities** simplify nested data manipulation
- **DictManager** provides a dictionary-based store for named snapshots such as checkpoints
- Utilities are **composable** - use multiple together

## Best Practices

1. Use FrozenDict for immutable configurations
2. Use DotDict for clean config access
3. Flatten dicts for serialization/logging
4. Leverage PrettyObject for debugging
5. Use DictManager for checkpoint management
6. Clear buffers in long-running processes, once the arrays involved are no longer needed
7. Generate unique names programmatically

## Congratulations!

You've completed all advanced tutorials! You now have comprehensive knowledge of:
- Graph operations and node system
- Mixin system and computation modes
- Type system and annotations
- Utility functions and tools

These advanced features enable you to build sophisticated, production-ready neural networks with BrainState!