# Tutorial 22: Model Deployment

In this tutorial, we'll learn how to save, load, and deploy BrainState models for production use.

## Learning Objectives

By the end of this tutorial, you will:
- Understand how to save and load model states
- Learn model serialization techniques
- Implement checkpointing strategies
- Optimize models for inference
- Create batch processing pipelines
- Implement model versioning
- Deploy models for production use

## Introduction

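Model deployment turns a trained model into something other programs can use. BrainState stores a model's trainable parameters as `bst.ParamState` objects that can be collected with `model.states(...)`. As a result, a model can be saved, loaded, and served with standard tools: NumPy archives, `pickle`, and JIT compilation.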

Key concepts:
- **State Management**: Saving and loading model parameters and states
- **Serialization**: Converting model states to disk-friendly formats
- **Checkpointing**: Periodic saving during training
- **Inference Optimization**: Making models faster for production
- **Batch Processing**: Handling multiple inputs efficiently
- **Versioning**: Managing different model versions

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import pickle
import json
from pathlib import Path
from typing import Dict, Any, Optional, List
from datetime import datetime
import time

# Set random seed for reproducibility
bst.random.seed(42)

## 1. Basic Model Saving and Loading

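The simplest approach is to collect the model's parameters with `model.states(bst.ParamState)` and write their values to disk. This call returns a dictionary of the model's `ParamState` objects.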

In [None]:
# Create a simple model
class SimpleClassifier(bst.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = bst.nn.Linear(hidden_dim, output_dim)
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = jax.nn.relu(self.fc1(x))
        x = jax.nn.relu(self.fc2(x))
        return self.fc3(x)

# Create and initialize model
model = SimpleClassifier(input_dim=784, hidden_dim=128, output_dim=10)

# Initialize with dummy input
dummy_input = bst.random.randn(1, 784)
_ = model(dummy_input)

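# A state's value may be a single array or a pytree (e.g. a dict of weight and bias),
# so count the elements of every leaf
num_params = sum(leaf.size
                 for p in model.states(bst.ParamState).values()
                 for leaf in jax.tree_util.tree_leaves(p.value))
print(f"Model created with {num_params} parameters")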

### 1.1 Saving Model States

To save the model, we convert each parameter value to a NumPy array and write all of the arrays to a single compressed `.npz` archive.

In [None]:
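def _state_key(path) -> str:
    """Convert a state path (e.g. a tuple such as ('fc1', 'weight')) into a string key."""
    return '/'.join(map(str, path)) if isinstance(path, tuple) else str(path)


def save_model_states(model: bst.graph.Node, filepath: str):
    """
    Save model states to a file.
    
    Args:
        model: The model to save
        filepath: Path to save the model
    """
    # Get all parameter states
    states = model.states(bst.ParamState)
    
    # A state's value may be a single array or a pytree (e.g. a dict holding
    # weight and bias), so store every leaf under "<state path>#<leaf index>".
    # Keys must be strings because np.savez_compressed takes them as keyword arguments.
    state_dict = {}
    for path, state in states.items():
        for i, leaf in enumerate(jax.tree_util.tree_leaves(state.value)):
            state_dict[f"{_state_key(path)}#{i}"] = np.asarray(leaf)
    
    # Save using numpy's compressed format
    np.savez_compressed(filepath, **state_dict)
    
    print(f"Model saved to {filepath}")
    print(f"File size: {Path(filepath).stat().st_size / 1024:.2f} KB")

# Save the model
save_model_states(model, "simple_classifier.npz")

### 1.2 Loading Model States

We can load the saved states back into a freshly constructed model with the same architecture. The loader validates every entry, checking that it is present and has the right shape, before it overwrites any parameter.

In [None]:
def load_model_states(model: bst.graph.Node, filepath: str):
    """
    Load model states from a file.
    
    Args:
        model: The model to load states into
        filepath: Path to load the model from
    """
    # Load the saved arrays (the context manager closes the file handle)
    with np.load(filepath) as loaded_data:
        saved = {k: loaded_data[k] for k in loaded_data.files}
    
    # Get model states
    model_states = model.states(bst.ParamState)
    
    # Rebuild each state's value with its original pytree structure, validating first
    new_values = {}
    expected = set()
    for path, state in model_states.items():
        leaves, treedef = jax.tree_util.tree_flatten(state.value)
        new_leaves = []
        for i, leaf in enumerate(leaves):
            key = f"{_state_key(path)}#{i}"
            expected.add(key)
            if key not in saved:
                raise ValueError(f"Missing entry in checkpoint: {key}")
            if saved[key].shape != leaf.shape:
                raise ValueError(f"Shape mismatch for {key}: expected {leaf.shape}, got {saved[key].shape}")
            new_leaves.append(jnp.asarray(saved[key]))
        new_values[path] = jax.tree_util.tree_unflatten(treedef, new_leaves)
    
    # Verify compatibility: the file must not contain entries the model lacks
    unexpected = set(saved) - expected
    if unexpected:
        raise ValueError(f"Unexpected entries in checkpoint: {sorted(unexpected)}")
    
    # Assign new values only after every check has passed
    for path, value in new_values.items():
        model_states[path].value = value
    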
    print(f"Model loaded from {filepath}")

# Create a new model and load states
new_model = SimpleClassifier(input_dim=784, hidden_dim=128, output_dim=10)
_ = new_model(dummy_input)  # Initialize

# Verify models are different before loading
original_output = model(dummy_input)
new_output_before = new_model(dummy_input)
print(f"Outputs differ before loading: {not jnp.allclose(original_output, new_output_before)}")

# Load states
load_model_states(new_model, "simple_classifier.npz")

# Verify models are identical after loading
new_output_after = new_model(dummy_input)
print(f"Outputs match after loading: {jnp.allclose(original_output, new_output_after)}")
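
Why flatten to string keys? `np.savez_compressed` stores each array under the keyword-argument name it was given. Nested parameters therefore have to become flat string keys before saving and be rebuilt after loading. The following framework-free sketch shows that round trip with plain NumPy and an in-memory buffer. `flatten_params`, `unflatten_params`, and the toy `params` dictionary are hypothetical names for illustration, not BrainState APIs.

```python
import io

import numpy as np


def flatten_params(tree, prefix=""):
    """Flatten a nested dict of arrays into {'fc1/weight': array, ...}."""
    flat = {}
    for name, value in tree.items():
        key = f"{prefix}/{name}" if prefix else str(name)
        if isinstance(value, dict):
            flat.update(flatten_params(value, key))
        else:
            flat[key] = np.asarray(value)
    return flat


def unflatten_params(flat):
    """Rebuild the nested dict from '/'-separated keys."""
    tree = {}
    for key, value in flat.items():
        *parents, leaf = key.split("/")
        node = tree
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = value
    return tree


# Toy parameters shaped like two dense layers (hypothetical values)
params = {
    "fc1": {"weight": np.ones((4, 3), dtype=np.float32), "bias": np.zeros(3, dtype=np.float32)},
    "fc2": {"weight": np.full((3, 2), 0.5, dtype=np.float32)},
}

buffer = io.BytesIO()
np.savez_compressed(buffer, **flatten_params(params))
buffer.seek(0)
with np.load(buffer) as data:
    restored = unflatten_params({k: data[k] for k in data.files})

print(sorted(flatten_params(restored)))  # ['fc1/bias', 'fc1/weight', 'fc2/weight']
```

The same idea underlies `save_model_states`: whatever the key format, saving and loading must agree on it exactly.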

## 2. Complete Model Serialization

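A bare weight file does not record how it was produced. `ModelSerializer` stores the parameter values together with metadata: the model class name, the parameter count, a timestamp, and any training information you pass in. It does not serialize the architecture itself, so you must re-create the model with the same constructor arguments before loading its states. The file is written with `pickle`, which can execute arbitrary code when loaded, so only load files from sources you trust.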

In [None]:
class ModelSerializer:
    """
    Complete model serialization with metadata.
    """
    
    @staticmethod
    def save(model: bst.graph.Node, 
             filepath: str, 
             metadata: Optional[Dict[str, Any]] = None):
        """
        Save model with metadata.
        
        Args:
            model: Model to save
            filepath: Path to save to
            metadata: Additional metadata (training info, hyperparameters, etc.)
        """
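        # Get model states. A state's value may be a single array or a pytree
        # (e.g. a dict of weight and bias), so convert every leaf to NumPy.
        states = model.states(bst.ParamState)
        state_dict = {k: jax.tree_util.tree_map(np.asarray, v.value) for k, v in states.items()}
        
        # Prepare metadata
        meta = {
            'timestamp': datetime.now().isoformat(),
            'model_class': model.__class__.__name__,
            'num_parameters': int(sum(leaf.size for v in state_dict.values()
                                      for leaf in jax.tree_util.tree_leaves(v))),
            'state_keys': list(state_dict.keys()),
        }
        
        if metadata:
            meta.update(metadata)
        
        # Save everything
        save_data = {
            'states': state_dict,
            'metadata': meta
        }
        
        with open(filepath, 'wb') as f:
            pickle.dump(save_data, f)
        
        print(f"Model serialized to {filepath}")
        print(f"Metadata: {json.dumps(meta, indent=2)}")
    
    @staticmethod
    def load(model: bst.graph.Node, filepath: str) -> Dict[str, Any]:
        """
        Load model and return metadata.
        
        Args:
            model: Model to load states into
            filepath: Path to load from
            
        Returns:
            Metadata dictionary
        """
        # pickle can execute arbitrary code: only load files you trust
        with open(filepath, 'rb') as f:
            save_data = pickle.load(f)
        
        # Convert every stored leaf back to a JAX array, keeping the pytree structure
        state_dict = {k: jax.tree_util.tree_map(jnp.asarray, v) for k, v in save_data['states'].items()}
        metadata = save_data['metadata']
        
        # Load states, reporting saved entries that the model does not have
        model_states = model.states(bst.ParamState)
        skipped = []
        for k, v in state_dict.items():
            if k in model_states:
                model_states[k].value = v
            else:
                skipped.append(k)
        if skipped:
            print(f"Warning: skipped {len(skipped)} saved states not found in the model: {skipped}")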
        
        print(f"Model deserialized from {filepath}")
        print(f"Saved: {metadata['timestamp']}")
        
        return metadata

# Example usage
metadata = {
    'accuracy': 0.95,
    'loss': 0.15,
    'epochs': 10,
    'learning_rate': 0.001,
    'dataset': 'MNIST'
}

ModelSerializer.save(model, "model_with_metadata.pkl", metadata)

In [None]:
# Load and verify
loaded_model = SimpleClassifier(input_dim=784, hidden_dim=128, output_dim=10)
_ = loaded_model(dummy_input)

loaded_metadata = ModelSerializer.load(loaded_model, "model_with_metadata.pkl")
print(f"\nLoaded metadata: {json.dumps(loaded_metadata, indent=2)}")
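
`pickle.load` can run arbitrary code embedded in the file, so a pickled checkpoint is only safe to load from a trusted source. One pickle-free alternative keeps the weights in an `.npz` archive and the metadata in a JSON file side by side. The sketch below assumes the parameters have already been flattened into string-keyed NumPy arrays. The `save_bundle`/`load_bundle` helpers and the `weights.npz`/`metadata.json` layout are hypothetical.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def save_bundle(directory, arrays, metadata):
    """Write arrays to weights.npz and metadata to metadata.json (hypothetical layout)."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(directory / "weights.npz", **arrays)
    (directory / "metadata.json").write_text(json.dumps(metadata, indent=2))


def load_bundle(directory):
    """Read the bundle back; allow_pickle=False makes np.load refuse object arrays."""
    directory = Path(directory)
    with np.load(directory / "weights.npz", allow_pickle=False) as data:
        arrays = {k: data[k] for k in data.files}
    metadata = json.loads((directory / "metadata.json").read_text())
    return arrays, metadata


with tempfile.TemporaryDirectory() as tmp:
    arrays = {"fc1/weight": np.eye(3, dtype=np.float32), "fc1/bias": np.zeros(3, dtype=np.float32)}
    save_bundle(tmp, arrays, {"accuracy": 0.95, "dataset": "MNIST"})
    loaded_arrays, loaded_meta = load_bundle(tmp)

print(loaded_meta["dataset"], loaded_arrays["fc1/weight"].shape)  # MNIST (3, 3)
```

`allow_pickle=False` is already the default in current NumPy; it is spelled out here to make the intent explicit. Metadata must be JSON-serializable, which rules out storing arbitrary Python objects.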

## 3. Checkpoint Management

During training, we save a checkpoint periodically, cap how many checkpoints stay on disk, and keep a separate copy of the best model so far. Here, lower metric values are better.

In [None]:
class CheckpointManager:
    """
    Manages model checkpoints during training.
    """
    
    def __init__(self, 
                 checkpoint_dir: str,
                 max_to_keep: int = 5,
                 keep_best: bool = True):
        """
        Args:
            checkpoint_dir: Directory to save checkpoints
            max_to_keep: Maximum number of checkpoints to keep
            keep_best: Whether to keep the best checkpoint separately
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_to_keep = max_to_keep
        self.keep_best = keep_best
        self.checkpoints = []  # List of (filepath, metric) tuples
        self.best_metric = None
        self.best_checkpoint = None
    
    def save_checkpoint(self,
                       model: bst.graph.Node,
                       epoch: int,
                       metric: float,
                       metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a checkpoint.
        
        Args:
            model: Model to save
            epoch: Current epoch
            metric: Metric value (lower is better)
            metadata: Additional metadata
            
        Returns:
            Path to saved checkpoint
        """
        # Create checkpoint filename
        checkpoint_name = f"checkpoint_epoch_{epoch:04d}_metric_{metric:.4f}.pkl"
        checkpoint_path = self.checkpoint_dir / checkpoint_name
        
        # Prepare metadata
        meta = {
            'epoch': epoch,
            'metric': metric,
        }
        if metadata:
            meta.update(metadata)
        
        # Save checkpoint
        ModelSerializer.save(model, str(checkpoint_path), meta)
        
        # Add to checkpoint list
        self.checkpoints.append((str(checkpoint_path), metric))
        
        # Update best checkpoint
        if self.keep_best and (self.best_metric is None or metric < self.best_metric):
            self.best_metric = metric
            best_path = self.checkpoint_dir / "best_model.pkl"
            ModelSerializer.save(model, str(best_path), meta)
            self.best_checkpoint = str(best_path)
            print(f"New best model saved with metric: {metric:.4f}")
        
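        # Remove the worst checkpoint if we exceed the limit. The list is kept in
        # save order (not re-sorted), so self.checkpoints[-1] stays the most recent
        # checkpoint that load_latest() expects.
        if len(self.checkpoints) > self.max_to_keep:
            # Never delete the checkpoint that was just written
            candidates = self.checkpoints[:-1] or self.checkpoints
            to_remove = max(candidates, key=lambda x: x[1])  # highest metric is worst
            self.checkpoints.remove(to_remove)
            Path(to_remove[0]).unlink(missing_ok=True)
            print(f"Removed worst checkpoint: {Path(to_remove[0]).name}")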
        
        return str(checkpoint_path)
    
    def load_best(self, model: bst.graph.Node) -> Optional[Dict[str, Any]]:
        """
        Load the best checkpoint.
        
        Args:
            model: Model to load into
            
        Returns:
            Metadata if best checkpoint exists, None otherwise
        """
        if self.best_checkpoint and Path(self.best_checkpoint).exists():
            return ModelSerializer.load(model, self.best_checkpoint)
        return None
    
    def load_latest(self, model: bst.graph.Node) -> Optional[Dict[str, Any]]:
        """
        Load the latest checkpoint.
        
        Args:
            model: Model to load into
            
        Returns:
            Metadata if checkpoint exists, None otherwise
        """
        if self.checkpoints:
            latest_checkpoint = self.checkpoints[-1][0]
            return ModelSerializer.load(model, latest_checkpoint)
        return None

# Example usage
checkpoint_manager = CheckpointManager("checkpoints", max_to_keep=3)

# Simulate training with checkpoints
for epoch in range(10):
    # Simulate metric improvement
    metric = 1.0 - epoch * 0.08 + np.random.randn() * 0.05
    
    metadata = {
        'train_loss': metric,
        'learning_rate': 0.001 * (0.95 ** epoch)
    }
    
    checkpoint_manager.save_checkpoint(model, epoch, metric, metadata)
    print(f"Epoch {epoch}: metric={metric:.4f}")
    print("-" * 50)
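
The retention rule matters for `load_latest`. If the checkpoint list is re-sorted by metric, its last element is no longer the newest save. Below is a pure-Python sketch of a rule with three properties. It keeps the list in save order, drops the worst (highest-metric) entry when the limit is exceeded, and never drops the checkpoint just written. `apply_retention` is a hypothetical helper that replays `(epoch, metric)` pairs instead of touching files.

```python
from typing import List, Tuple


def apply_retention(history: List[Tuple[int, float]], max_to_keep: int) -> List[Tuple[int, float]]:
    """Replay (epoch, metric) saves and return the surviving checkpoints in save order.

    After each save, if more than max_to_keep checkpoints exist, the one with the
    worst (highest) metric is dropped, but never the checkpoint that was just saved.
    """
    kept: List[Tuple[int, float]] = []
    for entry in history:
        kept.append(entry)
        if len(kept) > max_to_keep:
            candidates = kept[:-1] or kept  # exclude the newest save when possible
            kept.remove(max(candidates, key=lambda c: c[1]))
    return kept


history = [(0, 0.90), (1, 0.70), (2, 0.80), (3, 0.60), (4, 0.95)]
print(apply_retention(history, max_to_keep=3))  # [(1, 0.7), (3, 0.6), (4, 0.95)]
```

Epoch 4 survives even though it has the worst metric, because it is the newest save. Since the list stays in save order, its last element is always the most recent checkpoint.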

In [None]:
# Load best model
print("\nLoading best model:")
best_metadata = checkpoint_manager.load_best(model)
if best_metadata:
    print(f"Best model from epoch {best_metadata['epoch']} with metric {best_metadata['metric']:.4f}")
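
Writing `best_model.pkl` in place means a crash mid-write can leave a truncated file that no longer loads. A common remedy is to write to a temporary file in the same directory and then rename it over the target with `os.replace`. The Python documentation describes this rename as atomic on POSIX systems (both paths must be on the same filesystem). `atomic_pickle_dump` is a hypothetical helper, and wiring it into `ModelSerializer.save` is not shown.

```python
import os
import pickle
import tempfile
from pathlib import Path


def atomic_pickle_dump(obj, filepath) -> None:
    """Pickle obj to filepath without ever exposing a half-written file.

    The data is written to a temporary file in the target directory, flushed and
    fsync'ed, then renamed over the target with os.replace (atomic on POSIX when
    both paths are on the same filesystem).
    """
    filepath = Path(filepath)
    fd, tmp_path = tempfile.mkstemp(dir=str(filepath.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, filepath)
    except BaseException:
        # Clean up the temporary file if anything failed before the rename
        Path(tmp_path).unlink(missing_ok=True)
        raise


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "best_model.pkl"
    atomic_pickle_dump({"epoch": 3, "metric": 0.42}, target)
    with open(target, "rb") as f:
        loaded_ckpt = pickle.load(f)
    leftovers = [p.name for p in Path(tmp).iterdir() if p.suffix == ".tmp"]

print(loaded_ckpt, leftovers)  # {'epoch': 3, 'metric': 0.42} []
```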

## 4. Inference Optimization

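For production deployment, we wrap the model so that prediction runs through a JIT-compiled function. We also warm that function up once, so the first real request does not pay the compilation cost.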

In [None]:
class InferenceModel:
    """
    Optimized model wrapper for inference.
    """
    
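    def __init__(self, model: bst.graph.Node, input_shape: tuple = (784,)):
        """
        Args:
            model: Trained model to wrap
            input_shape: Shape of a single input (excluding batch), used for warmup
        """
        self.model = model
        self.input_shape = tuple(input_shape)
        # Create JIT-compiled inference function
        self._predict_fn = bst.transform.jit(self._predict)
        # Warmup
        self._warmup()
    
    def _predict(self, x: jnp.ndarray) -> jnp.ndarray:
        """Single prediction."""
        return self.model(x)
    
    def _warmup(self):
        """Compile for a batch of one so the first request does not pay the tracing cost.
        
        JIT compiles once per input shape, so other batch sizes still compile on first use.
        """
        dummy = jnp.zeros((1,) + self.input_shape)
        _ = self._predict_fn(dummy)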
    
    def predict(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Make predictions.
        
        Args:
            x: Input array
            
        Returns:
            Predictions
        """
        return self._predict_fn(x)
    
    def predict_proba(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Predict class probabilities.
        
        Args:
            x: Input array
            
        Returns:
            Class probabilities
        """
        logits = self.predict(x)
        return jax.nn.softmax(logits, axis=-1)
    
    def predict_class(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Predict class labels.
        
        Args:
            x: Input array
            
        Returns:
            Class labels
        """
        logits = self.predict(x)
        return jnp.argmax(logits, axis=-1)

# Create inference model
inference_model = InferenceModel(model)

# Test predictions
test_input = bst.random.randn(5, 784)
predictions = inference_model.predict_class(test_input)
probabilities = inference_model.predict_proba(test_input)

print("Predictions:", predictions)
print("\nTop-3 probabilities for first sample:")
top_3_idx = jnp.argsort(probabilities[0])[-3:][::-1]
for idx in top_3_idx:
    print(f"  Class {idx}: {probabilities[0, idx]:.4f}")

### 4.1 Benchmarking Inference Speed

In [None]:
def benchmark_inference(model_fn, input_shape, num_runs=100, batch_sizes=(1, 10, 50, 100)):
    """
    Benchmark inference speed.
    
    Args:
        model_fn: Function to call for inference
        input_shape: Shape of input (excluding batch)
        num_runs: Number of runs per batch size
        batch_sizes: Batch sizes to test
    """
    results = []
    
    for batch_size in batch_sizes:
        # Create test data
        test_data = jnp.ones((batch_size,) + input_shape)
        
        # Warmup (compiles the function for this batch shape)
        jax.block_until_ready(model_fn(test_data))
        
        # Benchmark. JAX dispatches work asynchronously, so wait for each result;
        # otherwise the timer would measure little more than dispatch overhead.
        start_time = time.perf_counter()
        for _ in range(num_runs):
            jax.block_until_ready(model_fn(test_data))
        end_time = time.perf_counter()
        
        total_time = end_time - start_time
        time_per_sample = (total_time / num_runs / batch_size) * 1000  # ms, amortized over the batch
        throughput = (num_runs * batch_size) / total_time
        
        results.append({
            'batch_size': batch_size,
            'time_per_sample_ms': time_per_sample,
            'throughput_samples_per_sec': throughput
        })
        
        print(f"Batch size {batch_size:3d}: {time_per_sample:.3f} ms/sample, "
              f"{throughput:.1f} samples/sec")
    
    return results

print("Benchmarking inference performance:\n")
benchmark_results = benchmark_inference(
    inference_model.predict,
    input_shape=(784,),
    num_runs=100,
    batch_sizes=[1, 10, 50, 100]
)
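
JAX dispatches computations asynchronously: a call returns as soon as the work is queued, before the result is computed. A timer stopped without waiting therefore measures dispatch overhead rather than inference time. The sketch below times a toy jitted function and calls `jax.block_until_ready` after every run. The first call is excluded because it triggers compilation. `dense_relu` is a hypothetical stand-in for the model and `time_jitted` a hypothetical helper.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, x, num_runs=50):
    """Return mean wall-clock seconds per call of fn(x), excluding compilation.

    Each timed call is followed by jax.block_until_ready so the timer includes
    the actual computation, not just the asynchronous dispatch.
    """
    jax.block_until_ready(fn(x))  # warmup: trigger compilation for this input shape
    start = time.perf_counter()
    for _ in range(num_runs):
        jax.block_until_ready(fn(x))
    return (time.perf_counter() - start) / num_runs


@jax.jit
def dense_relu(x):
    # Toy model: one dense layer with a fixed weight matrix, then ReLU
    w = jnp.ones((x.shape[-1], 16)) / x.shape[-1]
    return jax.nn.relu(x @ w)


seconds = time_jitted(dense_relu, jnp.ones((32, 784)))
print(f"{seconds * 1e3:.3f} ms per batch of 32")
```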

In [None]:
# Visualize benchmark results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

batch_sizes = [r['batch_size'] for r in benchmark_results]
latencies = [r['time_per_sample_ms'] for r in benchmark_results]
throughputs = [r['throughput_samples_per_sec'] for r in benchmark_results]

ax1.plot(batch_sizes, latencies, 'o-', linewidth=2, markersize=8)
ax1.set_xlabel('Batch Size')
ax1.set_ylabel('Amortized time per sample (ms)')
ax1.set_title('Per-sample Inference Time vs Batch Size')
ax1.grid(True, alpha=0.3)

ax2.plot(batch_sizes, throughputs, 's-', linewidth=2, markersize=8, color='green')
ax2.set_xlabel('Batch Size')
ax2.set_ylabel('Throughput (samples/sec)')
ax2.set_title('Inference Throughput vs Batch Size')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Batch Processing Pipeline

For production systems, we often need to process large datasets efficiently.

In [None]:
class BatchProcessor:
    """
    Efficient batch processing for inference.
    """
    
    def __init__(self, 
                 model: InferenceModel,
                 batch_size: int = 32,
                 show_progress: bool = True):
        """
        Args:
            model: Inference model
            batch_size: Batch size for processing
            show_progress: Whether to show progress
        """
        self.model = model
        self.batch_size = batch_size
        self.show_progress = show_progress
    
    def process(self, data: np.ndarray) -> np.ndarray:
        """
        Process data in batches.
        
        Args:
            data: Input data array
            
        Returns:
            Predictions for all data
        """
        num_samples = len(data)
        num_batches = (num_samples + self.batch_size - 1) // self.batch_size
        
        predictions = []
        
        for i in range(num_batches):
            start_idx = i * self.batch_size
            end_idx = min(start_idx + self.batch_size, num_samples)
            
            batch = jnp.array(data[start_idx:end_idx])
            batch_predictions = self.model.predict_class(batch)
            predictions.append(np.array(batch_predictions))
            
            if self.show_progress:
                print(f"\rProcessed {end_idx}/{num_samples} samples", end="")
        
        if self.show_progress:
            print()  # New line
        
        return np.concatenate(predictions, axis=0)
    
    def process_with_probabilities(self, data: np.ndarray) -> tuple:
        """
        Process data and return both predictions and probabilities.
        
        Args:
            data: Input data array
            
        Returns:
            Tuple of (predictions, probabilities)
        """
        num_samples = len(data)
        num_batches = (num_samples + self.batch_size - 1) // self.batch_size
        
        all_predictions = []
        all_probabilities = []
        
        for i in range(num_batches):
            start_idx = i * self.batch_size
            end_idx = min(start_idx + self.batch_size, num_samples)
            
            batch = jnp.array(data[start_idx:end_idx])
            batch_probabilities = self.model.predict_proba(batch)
            batch_predictions = jnp.argmax(batch_probabilities, axis=-1)
            
            all_predictions.append(np.array(batch_predictions))
            all_probabilities.append(np.array(batch_probabilities))
            
            if self.show_progress:
                print(f"\rProcessed {end_idx}/{num_samples} samples", end="")
        
        if self.show_progress:
            print()
        
        predictions = np.concatenate(all_predictions, axis=0)
        probabilities = np.concatenate(all_probabilities, axis=0)
        
        return predictions, probabilities

# Example usage
processor = BatchProcessor(inference_model, batch_size=32)

# Generate test dataset
test_dataset = np.random.randn(1000, 784)

# Process all data
start_time = time.time()
predictions, probabilities = processor.process_with_probabilities(test_dataset)
end_time = time.time()

print(f"\nProcessed {len(test_dataset)} samples in {end_time - start_time:.2f} seconds")
print(f"Throughput: {len(test_dataset) / (end_time - start_time):.1f} samples/sec")
print(f"\nPrediction distribution:")
for class_id in range(10):
    count = np.sum(predictions == class_id)
    print(f"  Class {class_id}: {count} samples ({count/len(predictions)*100:.1f}%)")
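
`BatchProcessor` feeds a smaller final batch whenever the dataset size is not a multiple of `batch_size`: 1000 samples in batches of 32 leave 8. A JIT-compiled function is compiled again for each new input shape. One common workaround is to zero-pad the last batch to full size and drop the padded rows from the output. The sketch below shows this with NumPy. `predict_in_fixed_batches` and `toy_predict` (a fake model that records the input shapes it sees) are hypothetical, not part of BrainState.

```python
import numpy as np


def predict_in_fixed_batches(predict_fn, data, batch_size):
    """Run predict_fn over non-empty data in batches of exactly batch_size rows.

    The final partial batch is zero-padded up to batch_size and the padded rows
    are dropped from the output, so a JIT-compiled predict_fn only ever sees one
    input shape and compiles once.
    """
    outputs = []
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        n_valid = len(batch)
        if n_valid < batch_size:
            pad = np.zeros((batch_size - n_valid,) + batch.shape[1:], dtype=batch.dtype)
            batch = np.concatenate([batch, pad], axis=0)
        outputs.append(np.asarray(predict_fn(batch))[:n_valid])
    return np.concatenate(outputs, axis=0)


seen_shapes = set()


def toy_predict(batch):
    # Hypothetical model: record the input shape and return row sums
    seen_shapes.add(batch.shape)
    return batch.sum(axis=1)


data = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)
result = predict_in_fixed_batches(toy_predict, data, batch_size=4)
print(result.shape, seen_shapes)  # (10,) {(4, 3)}
```

The cost is some wasted computation on padding rows, which is usually small compared with an extra compilation.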

## 6. Model Versioning

Managing multiple versions of models is crucial for production systems.

In [None]:
class ModelRegistry:
    """
    Model version registry for managing multiple model versions.
    """
    
    def __init__(self, registry_dir: str):
        self.registry_dir = Path(registry_dir)
        self.registry_dir.mkdir(parents=True, exist_ok=True)
        self.index_file = self.registry_dir / "registry.json"
        self.versions = self._load_index()
    
    def _load_index(self) -> Dict:
        """Load version index."""
        if self.index_file.exists():
            with open(self.index_file, 'r') as f:
                return json.load(f)
        return {}
    
    def _save_index(self):
        """Save version index."""
        with open(self.index_file, 'w') as f:
            json.dump(self.versions, f, indent=2)
    
    def register_model(self,
                      model: bst.graph.Node,
                      version: str,
                      description: str = "",
                      metrics: Optional[Dict[str, float]] = None,
                      tags: Optional[List[str]] = None) -> str:
        """
        Register a new model version.
        
        Args:
            model: Model to register
            version: Version string (e.g., "1.0.0")
            description: Description of this version
            metrics: Performance metrics
            tags: Tags for categorization
            
        Returns:
            Path to saved model
        """
        if version in self.versions:
            raise ValueError(f"Version {version} already exists")
        
        # Create version directory
        version_dir = self.registry_dir / f"v_{version}"
        version_dir.mkdir(exist_ok=True)
        
        # Save model
        model_path = version_dir / "model.pkl"
        metadata = {
            'version': version,
            'description': description,
            'metrics': metrics or {},
            'tags': tags or [],
            'registered_at': datetime.now().isoformat()
        }
        ModelSerializer.save(model, str(model_path), metadata)
        
        # Update registry
        self.versions[version] = {
            'path': str(model_path),
            'description': description,
            'metrics': metrics or {},
            'tags': tags or [],
            'registered_at': metadata['registered_at']
        }
        self._save_index()
        
        print(f"Model version {version} registered")
        return str(model_path)
    
    def load_version(self, model: bst.graph.Node, version: str) -> Dict:
        """
        Load a specific model version.
        
        Args:
            model: Model to load into
            version: Version to load
            
        Returns:
            Version metadata
        """
        if version not in self.versions:
            raise ValueError(f"Version {version} not found")
        
        model_path = self.versions[version]['path']
        metadata = ModelSerializer.load(model, model_path)
        
        return metadata
    
    def list_versions(self, tags: Optional[List[str]] = None) -> List[Dict]:
        """
        List all registered versions.
        
        Args:
            tags: If given, only return versions that have at least one of these tags
            
        Returns:
            List of version info dictionaries
        """
        versions = []
        for version, info in self.versions.items():
            if tags is None or any(tag in info['tags'] for tag in tags):
                versions.append({
                    'version': version,
                    **info
                })
        return versions
    
    def get_best_version(self, metric: str = 'accuracy') -> Optional[str]:
        """
        Get the version with the best metric.
        
        Args:
            metric: Metric name to maximize (assumes higher is better)
            
        Returns:
            Best version string or None
        """
        best_version = None
        best_value = -float('inf')
        
        for version, info in self.versions.items():
            if metric in info['metrics']:
                value = info['metrics'][metric]
                if value > best_value:
                    best_value = value
                    best_version = version
        
        return best_version

# Example usage
registry = ModelRegistry("model_registry")

# Register multiple versions
versions_to_register = [
    ("1.0.0", "Initial release", {'accuracy': 0.85, 'f1': 0.83}, ['baseline']),
    ("1.1.0", "Improved architecture", {'accuracy': 0.89, 'f1': 0.87}, ['production']),
    ("1.2.0", "Fine-tuned hyperparameters", {'accuracy': 0.92, 'f1': 0.90}, ['production', 'best']),
    ("2.0.0", "Major redesign", {'accuracy': 0.91, 'f1': 0.89}, ['experimental']),
]

for version, desc, metrics, tags in versions_to_register:
    try:
        registry.register_model(model, version, desc, metrics, tags)
    except ValueError as e:
        print(f"Skipping {version}: {e}")

In [None]:
# List all versions
print("\nAll registered versions:")
all_versions = registry.list_versions()
for v in all_versions:
    print(f"\nVersion: {v['version']}")
    print(f"  Description: {v['description']}")
    print(f"  Metrics: {v['metrics']}")
    print(f"  Tags: {v['tags']}")
    print(f"  Registered: {v['registered_at']}")

In [None]:
# Get best version
best_version = registry.get_best_version('accuracy')
print(f"\nBest version by accuracy: {best_version}")

# List production versions
print("\nProduction versions:")
prod_versions = registry.list_versions(tags=['production'])
for v in prod_versions:
    print(f"  {v['version']}: accuracy={v['metrics']['accuracy']:.2f}")
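
Version strings compare as text by default, so `'10.0.0'` sorts before `'2.0.0'`. When the registry needs the newest release rather than the best-scoring one, parsing each `MAJOR.MINOR.PATCH` string into a tuple of integers gives numeric ordering. The helpers below are a hypothetical sketch that assumes purely numeric versions. Pre-release suffixes such as `2.0.0-rc1` would need a fuller parser.

```python
from typing import List, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn 'MAJOR.MINOR.PATCH' into a tuple of ints so versions compare numerically."""
    return tuple(int(part) for part in version.split("."))


def sort_versions(versions: List[str]) -> List[str]:
    return sorted(versions, key=parse_version)


def latest_version(versions: List[str]) -> str:
    return max(versions, key=parse_version)


tags = ["1.2.0", "10.0.0", "2.0.0", "1.10.0", "1.9.1"]
print(sorted(tags))         # ['1.10.0', '1.2.0', '1.9.1', '10.0.0', '2.0.0']
print(sort_versions(tags))  # ['1.2.0', '1.9.1', '1.10.0', '2.0.0', '10.0.0']
```

With the registry above, `latest_version(list(registry.versions))` would pick the highest registered version number.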

## 7. Production Deployment Example

Let's put it all together in a complete deployment scenario.

In [None]:
class ProductionModel:
    """
    Production-ready model wrapper.
    """
    
    def __init__(self, 
                 model_class,
                 model_kwargs: Dict[str, Any],
                 registry: ModelRegistry,
                 version: Optional[str] = None):
        """
        Args:
            model_class: Model class to instantiate
            model_kwargs: Keyword arguments for model initialization
            registry: Model registry
            version: Version to load (if None, loads best)
        """
        # Create model
        self.model = model_class(**model_kwargs)
        
        # Initialize model
        dummy_input = jnp.zeros((1, model_kwargs['input_dim']))
        _ = self.model(dummy_input)
        
        # Load version
        if version is None:
            version = registry.get_best_version('accuracy')
            if version is None:
                raise ValueError("No registered version has an 'accuracy' metric to select by")
            print(f"Loading best version: {version}")
        
        self.metadata = registry.load_version(self.model, version)
        self.version = version
        
        # Create inference model
        self.inference_model = InferenceModel(self.model)
        
        # Create batch processor
        self.processor = BatchProcessor(self.inference_model, batch_size=32, show_progress=False)
        
        print(f"Production model ready (version {self.version})")
        print(f"Metrics: {self.metadata.get('metrics', {})}")
    
    def predict(self, x: np.ndarray, return_probabilities: bool = False):
        """
        Make predictions.
        
        Args:
            x: Input data
            return_probabilities: Whether to return probabilities
            
        Returns:
            Predictions (and probabilities if requested)
        """
        if len(x) == 1:
            # Single prediction
            x_jnp = jnp.array(x)
            if return_probabilities:
                probs = self.inference_model.predict_proba(x_jnp)
                preds = jnp.argmax(probs, axis=-1)
                return np.array(preds), np.array(probs)
            else:
                preds = self.inference_model.predict_class(x_jnp)
                return np.array(preds)
        else:
            # Batch prediction
            if return_probabilities:
                return self.processor.process_with_probabilities(x)
            else:
                return self.processor.process(x)
    
    def get_info(self) -> Dict[str, Any]:
        """
        Get model information.
        
        Returns:
            Model info dictionary
        """
        return {
            'version': self.version,
            'metadata': self.metadata,
            'num_parameters': sum(leaf.size
                                  for p in self.model.states(bst.ParamState).values()
                                  for leaf in jax.tree_util.tree_leaves(p.value))
        }

# Deploy production model
production_model = ProductionModel(
    model_class=SimpleClassifier,
    model_kwargs={'input_dim': 784, 'hidden_dim': 128, 'output_dim': 10},
    registry=registry,
    version='1.2.0'  # Or None to load best
)

In [None]:
# Test production model
test_samples = np.random.randn(100, 784)

# Single prediction
single_pred, single_prob = production_model.predict(test_samples[:1], return_probabilities=True)
print(f"Single prediction: class {single_pred[0]}")
print(f"Confidence: {single_prob[0, single_pred[0]]:.4f}")

# Batch predictions
print("\nBatch predictions:")
batch_preds = production_model.predict(test_samples)
print(f"Processed {len(batch_preds)} samples")

# Get model info
info = production_model.get_info()
print(f"\nModel info:")
print(f"  Version: {info['version']}")
print(f"  Parameters: {info['num_parameters']:,}")
print(f"  Accuracy: {info['metadata'].get('metrics', {}).get('accuracy', 'N/A')}")
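
`ProductionModel.predict` passes requests straight to the model. A wrongly shaped input therefore surfaces as a shape error from inside the model, and NaN values propagate into the outputs without raising. A small validation step in front of the model makes failures explicit. `validate_batch` is a hypothetical helper sketched with NumPy; calling it at the top of `predict` is one possible integration.

```python
import numpy as np


def validate_batch(x, input_dim: int) -> np.ndarray:
    """Check a request before it reaches the model; return a float32 (batch, input_dim) array.

    Raises ValueError for wrong rank, an empty batch, a wrong feature count,
    or non-finite values, so bad requests fail fast with a clear message.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.ndim == 1:
        x = x[None, :]  # treat a single feature vector as a batch of one
    if x.ndim != 2:
        raise ValueError(f"expected a 1-D or 2-D input, got shape {x.shape}")
    if x.shape[0] == 0:
        raise ValueError("empty batch")
    if x.shape[1] != input_dim:
        raise ValueError(f"expected {input_dim} features, got {x.shape[1]}")
    if not np.all(np.isfinite(x)):
        raise ValueError("input contains NaN or infinite values")
    return x


ok = validate_batch(np.zeros(784), input_dim=784)
print(ok.shape)  # (1, 784)
try:
    validate_batch(np.zeros((2, 10)), input_dim=784)
except ValueError as err:
    print(err)  # expected 784 features, got 10
```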

## 8. Monitoring and Logging

Production models need monitoring and logging capabilities.

In [None]:
class PredictionLogger:
    """
    Logger for tracking predictions and performance.
    """
    
    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.predictions = []
        self.latencies = []
        self.start_time = time.time()
    
    def log_prediction(self, 
                      input_data: np.ndarray,
                      prediction: int,
                      probability: float,
                      latency_ms: float):
        """
        Log a single prediction.
        
        Args:
            input_data: Input that was predicted on (accepted but not stored by this minimal logger)
            prediction: Predicted class
            probability: Prediction confidence
            latency_ms: Inference latency in milliseconds
        """
        entry = {
            'timestamp': datetime.now().isoformat(),
            'prediction': int(prediction),
            'probability': float(probability),
            'latency_ms': float(latency_ms)
        }
        self.predictions.append(entry)
        self.latencies.append(latency_ms)
    
    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about logged predictions.
        
        Returns:
            Statistics dictionary
        """
        if not self.predictions:
            return {}
        
        latencies = np.array(self.latencies)
        predictions = [p['prediction'] for p in self.predictions]
        probabilities = [p['probability'] for p in self.predictions]
        
        return {
            'total_predictions': len(self.predictions),
            'uptime_seconds': time.time() - self.start_time,
            'latency_mean_ms': float(np.mean(latencies)),
            'latency_std_ms': float(np.std(latencies)),
            'latency_p50_ms': float(np.percentile(latencies, 50)),
            'latency_p95_ms': float(np.percentile(latencies, 95)),
            'latency_p99_ms': float(np.percentile(latencies, 99)),
            'avg_confidence': float(np.mean(probabilities)),
            'prediction_distribution': {int(i): predictions.count(i) for i in set(predictions)}
        }
    
    def save_logs(self):
        """
        Save logs to disk.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = self.log_dir / f"predictions_{timestamp}.json"
        
        log_data = {
            'predictions': self.predictions,
            'statistics': self.get_stats()
        }
        
        with open(log_file, 'w') as f:
            json.dump(log_data, f, indent=2)
        
        print(f"Logs saved to {log_file}")

# Create logger
logger = PredictionLogger("prediction_logs")

# Simulate predictions with logging
print("Simulating logged predictions...\n")
for i in range(50):
    # Generate test input
    test_input = np.random.randn(1, 784)
    
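    # Time prediction (perf_counter is monotonic and high-resolution;
    # the first call may also include JIT compilation time)
    start = time.perf_counter()
    pred, prob = production_model.predict(test_input, return_probabilities=True)
    latency = (time.perf_counter() - start) * 1000  # Convert to ms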
    
    # Log
    logger.log_prediction(
        test_input,
        pred[0],
        prob[0, pred[0]],
        latency
    )

# Get statistics
stats = logger.get_stats()
print("Prediction Statistics:")
print(f"  Total predictions: {stats['total_predictions']}")
print(f"  Average latency: {stats['latency_mean_ms']:.2f} ms")
print(f"  P95 latency: {stats['latency_p95_ms']:.2f} ms")
print(f"  P99 latency: {stats['latency_p99_ms']:.2f} ms")
print(f"  Average confidence: {stats['avg_confidence']:.4f}")
print("\n  Prediction distribution:")
for class_id, count in sorted(stats['prediction_distribution'].items()):
    print(f"    Class {class_id}: {count}")

# Save logs
logger.save_logs()
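The logging loop above times each `predict` call with a wall-clock timer. Two JAX behaviours can distort such numbers: a jitted function is compiled on its first call, and JAX dispatches work asynchronously, so a timer can stop before the computation has finished. The sketch below shows one common pattern: a warm-up call that is excluded from the statistics, followed by `block_until_ready()` inside the timed region. `toy_forward`, `measure_latencies_ms`, and the random weights are hypothetical stand-ins, not part of the production model used above.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def toy_forward(w: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical stand-in for a model's forward pass: softmax over one linear layer
    return jax.nn.softmax(x @ w, axis=-1)


def measure_latencies_ms(fn, *args, n_runs: int = 20) -> np.ndarray:
    """Time fn(*args) n_runs times in milliseconds, excluding the compiling call."""
    # Warm-up: triggers JIT compilation so it is not counted as latency
    fn(*args).block_until_ready()
    latencies = []
    for _ in range(n_runs):
        start = time.perf_counter()
        # block_until_ready() waits for the asynchronously dispatched computation
        fn(*args).block_until_ready()
        latencies.append((time.perf_counter() - start) * 1000)
    return np.array(latencies)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w_demo = jax.random.normal(key_w, (784, 10))
x_demo = jax.random.normal(key_x, (1, 784))

lat = measure_latencies_ms(toy_forward, w_demo, x_demo)
print(f"p50 = {np.percentile(lat, 50):.3f} ms, p95 = {np.percentile(lat, 95):.3f} ms")
```

If a production wrapper converts its outputs to NumPy (for example with `np.asarray`), that conversion already waits for the result, so the explicit `block_until_ready()` is implied; excluding the warm-up call is still worthwhile.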

## 9. Advanced: Model Export

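To use a trained model outside BrainState (for example, in another framework or a NumPy-only serving environment), we can export its weights as NumPy arrays and its architecture as a JSON description.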

In [None]:
class ModelExporter:
    """
    Export models to various formats.
    """
    
    @staticmethod
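    def export_to_numpy(model: bst.graph.Node, filepath: str) -> Dict[str, np.ndarray]:
        """
        Export model weights to a compressed NumPy ``.npz`` archive.
        
        Each array is stored under its full dotted path (e.g. ``fc1.weight``), so
        parameters that share a name in different layers do not overwrite each other.
        
        Args:
            model: Model to export
            filepath: Output file path
        
        Returns:
            Flat dictionary mapping parameter names to NumPy arrays
        """
        states = model.states(bst.ParamState)
        
        weights = {}
        for key, state in states.items():
            # State keys may be path tuples such as ('fc1', 'weight') or plain strings
            prefix = '.'.join(map(str, key)) if isinstance(key, tuple) else str(key)
            
            # A state value may itself be a pytree (e.g. a dict of weight and bias),
            # so flatten it and give every array leaf its own name
            leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(state.value)
            for path, leaf in leaves_with_paths:
                # DictKey exposes .key, SequenceKey .idx, GetAttrKey .name
                suffix = '.'.join(
                    str(getattr(p, 'key', getattr(p, 'idx', getattr(p, 'name', p))))
                    for p in path
                )
                name = f"{prefix}.{suffix}" if suffix else prefix
                weights[name] = np.asarray(leaf)
        
        # Save with unique keys so no layer's parameters are silently dropped
        np.savez_compressed(filepath, **weights)
        print(f"Model exported to {filepath} ({len(weights)} arrays)")
        
        return weights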
    
    @staticmethod
    def export_architecture(model: bst.graph.Node, filepath: str):
        """
        Export model architecture description.
        
        Args:
            model: Model to export
            filepath: Output file path
        """
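        # Count every array leaf: a single ParamState may hold a pytree
        # (e.g. a dict of weight and bias) rather than one array
        num_parameters = sum(
            int(leaf.size)
            for state in model.states(bst.ParamState).values()
            for leaf in jax.tree_util.tree_leaves(state.value)
        )
        
        architecture = {
            'model_class': model.__class__.__name__,
            'num_parameters': num_parameters,
            'layers': []
        }
        
        # Walk the node tree; keys are attribute paths (the root model may have an empty path)
        for path, node in model.nodes().items():
            name = '/'.join(map(str, path)) if isinstance(path, tuple) else str(path)
            layer_info = {
                'name': name or '<root>',
                'class': node.__class__.__name__,
                'type': str(type(node)),
            }
            
            # Record simple, JSON-serializable attributes such as sizes and flags,
            # without overwriting the keys set above
            for attr, value in getattr(node, '__dict__', {}).items():
                if attr.startswith('_'):
                    continue
                if isinstance(value, (int, float, str, bool)):
                    layer_info.setdefault(attr, value)
                elif isinstance(value, (tuple, list)) and all(
                    isinstance(v, (int, float, str, bool)) for v in value
                ):
                    layer_info.setdefault(attr, list(value))
            
            architecture['layers'].append(layer_info)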
        
        # Save
        with open(filepath, 'w') as f:
            json.dump(architecture, f, indent=2)
        
        print(f"Architecture exported to {filepath}")
        return architecture

# Export model
exporter = ModelExporter()

# Export weights
weights = exporter.export_to_numpy(model, "exported_weights.npz")

# Export architecture
arch = exporter.export_architecture(model, "exported_architecture.json")

print("\nExported architecture:")
print(json.dumps(arch, indent=2))
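Once weights are in an `.npz` file, any NumPy-capable runtime can load them without BrainState. The sketch below round-trips a set of weights through `np.savez_compressed` and `np.load`, then runs a forward pass in plain NumPy that mirrors the three-layer ReLU structure of `SimpleClassifier`. The layer sizes, the parameter names (`fc1.weight`, `fc1.bias`, ...), and the `x @ W + b` layout with weights shaped `(in, out)` are all assumptions for illustration; inspecting `np.load(...).files` on your own export shows the actual names and shapes.

```python
import os
import tempfile
from typing import Dict, List

import numpy as np


def numpy_mlp_forward(params: Dict[str, np.ndarray], x: np.ndarray,
                      layer_names: List[str]) -> np.ndarray:
    """Run a ReLU MLP using only NumPy arrays loaded from an .npz export."""
    h = x
    for i, name in enumerate(layer_names):
        # Assumed layout: weight is (in_features, out_features), bias is (out_features,)
        h = h @ params[f"{name}.weight"] + params[f"{name}.bias"]
        if i < len(layer_names) - 1:  # no activation after the output layer
            h = np.maximum(h, 0.0)
    return h


# Hypothetical weights for a 784 -> 128 -> 128 -> 10 network
demo_rng = np.random.default_rng(0)
dims = [784, 128, 128, 10]
layer_names = ["fc1", "fc2", "fc3"]
demo_weights = {}
for name, d_in, d_out in zip(layer_names, dims[:-1], dims[1:]):
    demo_weights[f"{name}.weight"] = (
        demo_rng.standard_normal((d_in, d_out)).astype(np.float32) * 0.05
    )
    demo_weights[f"{name}.bias"] = np.zeros(d_out, dtype=np.float32)

with tempfile.TemporaryDirectory() as tmp:
    npz_path = os.path.join(tmp, "weights.npz")
    np.savez_compressed(npz_path, **demo_weights)
    # np.load returns a lazy NpzFile; copy the arrays out before it is closed
    with np.load(npz_path) as archive:
        loaded = {k: archive[k] for k in archive.files}

x_demo = demo_rng.standard_normal((4, 784)).astype(np.float32)
logits = numpy_mlp_forward(loaded, x_demo, layer_names)
print(logits.shape, logits.argmax(axis=-1))
```

Comparing such a NumPy forward pass with the original model on the same inputs (for example with `np.allclose`) is a quick check that an export preserved both the weights and their layout.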

## Summary

In this tutorial, we covered:

1. **Basic Saving/Loading**: Simple state serialization with NumPy
2. **Complete Serialization**: Saving models with metadata using pickle
3. **Checkpoint Management**: Periodic saving during training with best model tracking
4. **Inference Optimization**: JIT-compiled inference models with benchmarking
5. **Batch Processing**: Efficient processing of large datasets
6. **Model Versioning**: Registry system for managing multiple model versions
7. **Production Deployment**: Complete production-ready model wrapper
8. **Monitoring**: Prediction logging and performance tracking
9. **Model Export**: Exporting weights and architecture for other platforms

### Key Takeaways:

- Use `model.states(bst.ParamState)` to access all model parameters
- Save checkpoints periodically during training
- Use JIT compilation for faster inference
- Implement batch processing for large datasets
- Maintain a model registry for version control
- Log predictions for monitoring and debugging
- Benchmark inference performance for optimization

## Next Steps

Now that you've completed all BrainState tutorials, you can:

1. Build production ML systems with BrainState
2. Implement custom deployment pipelines
3. Integrate with serving frameworks (Flask, FastAPI, etc.)
4. Deploy to cloud platforms (AWS, GCP, Azure)
5. Contribute to the BrainState project
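As a starting point for serving integration, the sketch below exposes a prediction function as a JSON HTTP endpoint using only Python's standard-library `wsgiref` module; frameworks such as Flask and FastAPI wrap the same request/response cycle with more convenience. Everything here is hypothetical: `make_prediction_app`, `argmax_predict`, `serve`, the port, and the `{"inputs": [[...], ...]}` request shape. In practice, `predict_fn` would call something like the production model's `predict` method.

```python
import json
from typing import Callable
from wsgiref.simple_server import make_server

import numpy as np


def make_prediction_app(predict_fn: Callable[[np.ndarray], np.ndarray]):
    """Build a WSGI app that accepts {"inputs": [[...], ...]} and returns predictions."""

    def app(environ, start_response):
        headers = [("Content-Type", "application/json")]
        if environ.get("REQUEST_METHOD") != "POST":
            start_response("405 Method Not Allowed", headers)
            return [b'{"error": "use POST"}']
        try:
            size = int(environ.get("CONTENT_LENGTH") or 0)
            payload = json.loads(environ["wsgi.input"].read(size))
            inputs = np.asarray(payload["inputs"], dtype=np.float32)
            preds = predict_fn(inputs)
            body = {"predictions": np.asarray(preds).tolist()}
            status = "200 OK"
        except (ValueError, KeyError, TypeError) as exc:
            # Malformed JSON, a missing "inputs" key, or ragged input arrays
            body = {"error": str(exc)}
            status = "400 Bad Request"
        start_response(status, headers)
        return [json.dumps(body).encode("utf-8")]

    return app


def argmax_predict(x: np.ndarray) -> np.ndarray:
    # Hypothetical placeholder model: predicted class = index of the largest feature
    return x.argmax(axis=-1)


prediction_app = make_prediction_app(argmax_predict)


def serve(port: int = 8000) -> None:
    """Start a blocking development server on localhost (stop with Ctrl+C)."""
    with make_server("127.0.0.1", port, prediction_app) as server:
        server.serve_forever()
```

Calling `serve()` blocks the current process, so it is better run from a standalone script than from a notebook cell. The `wsgiref` server is a minimal reference implementation; production deployments usually run the app under a dedicated WSGI server.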

For more information, visit the [BrainState documentation](https://brainstate.readthedocs.io/).