# 9.3 Real-time Inference & Model Serving Interactive Notebook

This notebook walks through hands-on implementations of real-time ML inference for semiconductor manufacturing. You will build the core pieces of serving infrastructure (latency tracking, request caching, and batching) that production systems need to meet low-latency, high-throughput, and reliability targets.

## Outline:
1. Import Required Libraries
2. Understanding Real-time vs Batch Inference
3. Generate Sample Model & Data
4. Simple Model Server with Latency Tracking (p50, p95, p99)
5. Request Caching with TTL
6. Dynamic Batching for Throughput
7. Summary & Best Practices

> **Semiconductor Context**: Inline inspection systems require real-time defect detection (<50ms) to keep up with production speed. This notebook shows how to build serving infrastructure that meets strict latency and throughput requirements.

## 1. Import Required Libraries

Import libraries for model serving, API development, caching, and monitoring.

In [None]:
# Core libraries
import numpy as np
import pandas as pd
from pathlib import Path
import time
import asyncio
import json
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import deque
import warnings
warnings.filterwarnings('ignore')

# ML libraries
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import joblib

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid', context='notebook')
%matplotlib inline

# FastAPI (optional: the cells below do not use it; it is only needed to expose an HTTP server)
# Install with: pip install fastapi uvicorn
try:
    from fastapi import FastAPI, HTTPException, BackgroundTasks
    from pydantic import BaseModel
    HAS_FASTAPI = True
except ImportError:
    print("⚠️ FastAPI not installed. Install with: pip install fastapi uvicorn")
    HAS_FASTAPI = False

# Reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

print('✓ Libraries imported successfully')
print(f'Random seed: {RANDOM_SEED}')
print(f'FastAPI available: {HAS_FASTAPI}')

# Helper function
def section(title: str):
    print(f"\n{'='*len(title)}\n{title}\n{'='*len(title)}")

## 2. Understanding Real-time vs Batch Inference

### Key Differences:

| Aspect | Batch Inference | Real-time Inference |
|--------|----------------|---------------------|
| **Latency** | Minutes to hours | Milliseconds to seconds |
| **Throughput** | High (thousands/sec) | Variable (depends on request rate) |
| **Use Case** | Offline analysis, reports | Online decisions, interactive |
| **Complexity** | Low (a loop or scheduled job) | High (servers, caching, monitoring) |
| **Cost** | Lower (can batch) | Higher (always running) |

### Semiconductor Manufacturing Applications:

**Real-time (Inline Inspection):**
- Defect detection during production (<50ms)
- Process control adjustments (immediate feedback)
- Equipment health monitoring (continuous)
- Yield prediction for WIP wafers (on-demand)

**Batch Processing (Offline Analysis):**
- Daily yield reports
- Historical trend analysis
- Model retraining on accumulated data
- Root cause analysis (can wait hours)

### Performance Requirements:
- **Latency targets**: p99 < 50ms for inline, < 200ms for interactive
- **Throughput targets**: 100-1000 req/sec for fab-wide deployment
- **Availability**: 99.9% uptime (production can't stop)
- **SLA tracking**: Monitor p50, p95, p99 latency percentiles
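The availability and throughput targets above translate into concrete budgets. The small sketch below uses illustrative helper names. At 99.9% availability, a 30-day month allows about 43 minutes of downtime. Little's law (average requests in flight = arrival rate × average time in system) then estimates how many requests a deployment must hold at once. At the upper target of 1000 req/s, with each request taking 50 ms (the inline p99 budget), that comes to about 50 concurrent requests. Using the p99 rather than the mean latency makes this a conservative estimate.

```python
def downtime_budget_minutes(availability: float, period_days: float = 30) -> float:
    """Allowed downtime in minutes over a period, given availability as a fraction (0.999 = 99.9%)."""
    total_minutes = period_days * 24 * 60
    return total_minutes * (1 - availability)


def in_flight_requests(rate_per_sec: float, latency_ms: float) -> float:
    """Little's law: mean requests in the system = arrival rate x mean time in the system."""
    return rate_per_sec * (latency_ms / 1000)


print(f"99.9% over 30 days:  {downtime_budget_minutes(0.999):.1f} min of downtime")
print(f"99.9% over 365 days: {downtime_budget_minutes(0.999, 365) / 60:.2f} h of downtime")
# 1000 req/s with each request spending 50 ms in the server -> about 50 in flight at once
print(f"In-flight requests at 1000 req/s, 50 ms: {in_flight_requests(1000, 50):.0f}")
```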

## 3. Generate Sample Model & Data

Create a simple defect detection model for demonstration.

In [None]:
section('Generate Sample Defect Detection Model')

def generate_synthetic_wafer_data(n_samples: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic wafer process data for defect prediction.
    
    Features: temperature, pressure, gas_flow, particle_count, humidity
    Target: defect (0=pass, 1=fail)
    """
    rng = np.random.default_rng(RANDOM_SEED)
    
    # Generate features
    X = np.column_stack([
        rng.normal(25, 5, n_samples),    # temperature (C)
        rng.normal(1.0, 0.2, n_samples), # pressure (atm)
        rng.normal(100, 20, n_samples),  # gas_flow (sccm)
        rng.normal(50, 15, n_samples),   # particle_count
        rng.normal(45, 10, n_samples)    # humidity (%)
    ])
    
    # Generate labels with realistic decision boundary
    defect_score = (
        0.5 * (X[:, 0] - 25) / 5 +  # Temperature deviation
        0.3 * (X[:, 3] - 50) / 15 +  # Particle count
        -0.2 * (X[:, 4] - 45) / 10   # Low humidity increases defects
    )
    y = (defect_score + rng.normal(0, 0.3, n_samples) > 0.5).astype(int)
    
    return X, y

# Generate and split data
X, y = generate_synthetic_wafer_data(n_samples=1000)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_SEED
)

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")
print(f"Defect rate: {y.mean()*100:.1f}%")

# Train simple model
print("\n🔄 Training Random Forest model...")
model = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=RANDOM_SEED)
scaler = StandardScaler()

X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

model.fit(X_train_scaled, y_train)

# Evaluate
train_acc = model.score(X_train_scaled, y_train)
test_acc = model.score(X_test_scaled, y_test)

print(f"✓ Training complete")
print(f"Train accuracy: {train_acc:.3f}")
print(f"Test accuracy:  {test_acc:.3f}")

# Save model (for serving)
model_path = Path('defect_model.joblib')
scaler_path = Path('defect_scaler.joblib')
joblib.dump(model, model_path)
joblib.dump(scaler, scaler_path)
print(f"\n✓ Model saved to {model_path}")

## 4. Simple Model Server with Latency Tracking

Build a basic model serving class that tracks inference latency.

In [None]:
section('Simple Model Server with Latency Tracking')

@dataclass
class LatencyMetrics:
    """Track latency metrics for monitoring."""
    latencies: deque = field(default_factory=lambda: deque(maxlen=1000))
    
    def record(self, latency_ms: float):
        """Record a latency measurement."""
        self.latencies.append(latency_ms)
    
    def get_percentiles(self) -> Dict[str, float]:
        """Calculate latency percentiles over the most recent (up to 1000) measurements."""
        if not self.latencies:
            # Same keys as the non-empty case so callers can always read 'count'
            return {'p50': 0.0, 'p95': 0.0, 'p99': 0.0, 'mean': 0.0, 'count': 0}
        
        latencies_array = np.array(self.latencies)
        # Compute all three percentiles in one call
        p50, p95, p99 = np.percentile(latencies_array, [50, 95, 99])
        return {
            'p50': float(p50),
            'p95': float(p95),
            'p99': float(p99),
            'mean': float(latencies_array.mean()),
            'count': len(latencies_array)
        }

class ModelServer:
    """Simple model serving class with latency tracking."""
    
    def __init__(self, model_path: Path, scaler_path: Path):
        self.model = joblib.load(model_path)
        self.scaler = joblib.load(scaler_path)
        self.metrics = LatencyMetrics()
        print(f"✓ Model loaded from {model_path}")
    
    def predict(self, features: np.ndarray) -> Dict:
        """Make prediction with latency tracking."""
        start_time = time.perf_counter()
        
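        # Preprocess: the scaler expects a 2-D array of shape (n_samples, n_features)
        features_scaled = self.scaler.transform(features.reshape(1, -1))
        
        # Predict: a single predict_proba call, with the label taken as the argmax class,
        # avoids running the forest twice (predict() would recompute the probabilities)
        probability = self.model.predict_proba(features_scaled)[0]
        prediction = self.model.classes_[np.argmax(probability)]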
        
        # Calculate latency
        latency_ms = (time.perf_counter() - start_time) * 1000
        self.metrics.record(latency_ms)
        
        return {
            'prediction': int(prediction),
            'probability': float(probability[1]),
            'latency_ms': latency_ms
        }
    
    def get_metrics(self) -> Dict:
        """Get current latency metrics."""
        return self.metrics.get_percentiles()

# Test server
server = ModelServer(model_path, scaler_path)

print("\n🧪 Testing inference...\n")
for i in range(5):
    test_features = X_test[i]
    result = server.predict(test_features)
    print(f"Sample {i+1}: prediction={result['prediction']}, "
          f"probability={result['probability']:.3f}, "
          f"latency={result['latency_ms']:.2f}ms")

# Run benchmark
print("\n🏃 Running latency benchmark (100 requests)...")
for i in range(100):
    idx = np.random.randint(0, len(X_test))
    server.predict(X_test[idx])

metrics = server.get_metrics()
print("\n📊 Latency Metrics:")
print("="*40)
print(f"Mean:  {metrics['mean']:.2f} ms")
print(f"p50:   {metrics['p50']:.2f} ms")
print(f"p95:   {metrics['p95']:.2f} ms")
print(f"p99:   {metrics['p99']:.2f} ms")
print(f"Count: {metrics['count']} requests")

# Visualize latency distribution
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(server.metrics.latencies, bins=30, color='steelblue', edgecolor='black', alpha=0.7)
plt.axvline(metrics['p50'], color='green', linestyle='--', linewidth=2, label=f"p50: {metrics['p50']:.1f}ms")
plt.axvline(metrics['p95'], color='orange', linestyle='--', linewidth=2, label=f"p95: {metrics['p95']:.1f}ms")
plt.axvline(metrics['p99'], color='red', linestyle='--', linewidth=2, label=f"p99: {metrics['p99']:.1f}ms")
plt.xlabel('Latency (ms)')
plt.ylabel('Frequency')
plt.title('Inference Latency Distribution')
plt.legend()
plt.grid(alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(list(server.metrics.latencies), color='navy', alpha=0.6)
plt.axhline(metrics['p95'], color='orange', linestyle='--', linewidth=1, alpha=0.7)
plt.xlabel('Request Number')
plt.ylabel('Latency (ms)')
plt.title('Latency Over Time (with p95 line)')
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\n💡 Interpretation:")
print("- p50 (median): Typical latency for most requests")
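print("- p95: 95% of requests complete at or below this latency")
print("- p99: 99% of requests complete at or below this latency (common SLA target)")
print("- Use p99 for SLAs: it captures tail latency that the mean hides")
print("- With ~100 samples, p99 rests on the one or two slowest requests; collect more for a stable estimate")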
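You can check the interpretation above with a stdlib-only sketch (`statistics.quantiles`, Python 3.8+). The latencies here are synthetic, not measured. With 100 samples, two slow requests (2%) are enough to push p99 from under 2.5 ms to 40 ms, while p50 and p95 stay below 2.5 ms. For that reason, p99 needs both attention and enough samples before you can trust it.

```python
import random
import statistics


def latency_percentiles(samples_ms):
    """p50/p95/p99 with linear interpolation between sorted samples."""
    # n=100 returns 99 cut points: index 49 -> p50, 94 -> p95, 98 -> p99.
    # method='inclusive' interpolates like numpy.percentile's default ('linear').
    cuts = statistics.quantiles(samples_ms, n=100, method='inclusive')
    return {'p50': cuts[49], 'p95': cuts[94], 'p99': cuts[98]}


rng = random.Random(42)
# Synthetic latencies (not measured), uniformly spread over 2.0-2.5 ms
base = [2.0 + 0.5 * rng.random() for _ in range(100)]
# Same list with its last two entries replaced by 40 ms outliers (2% of requests)
with_outliers = base[:-2] + [40.0, 40.0]

for name, data in [('base', base), ('2 outliers', with_outliers)]:
    p = latency_percentiles(data)
    print(name, {k: round(v, 2) for k, v in p.items()})
```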

## 5. Request Caching with TTL

Implement caching to avoid redundant inference for identical requests.
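The speedup estimate at the end of the next cell is a weighted average of hit and miss latency. With hit rate $h$, cache-lookup latency $t_{\text{hit}}$, and full-inference latency $t_{\text{miss}}$:

$$
\bar{t} = h\,t_{\text{hit}} + (1 - h)\,t_{\text{miss}}, \qquad
\text{speedup} = \frac{t_{\text{miss}}}{\bar{t}} \le \frac{1}{1 - h}
$$

The bound follows from $t_{\text{hit}} \ge 0$, so a 50% hit rate can at most halve the mean latency. Caching also does little for the tail whenever more than 1% of requests miss, because p99 is then set by cache misses.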

In [None]:
section('Request Caching with TTL')

@dataclass
class CacheEntry:
    """Cache entry with timestamp for TTL."""
    result: Dict
    timestamp: datetime

class CachedModelServer(ModelServer):
    """Model server with request caching."""
    
    def __init__(self, model_path: Path, scaler_path: Path, cache_ttl_seconds: int = 60):
        super().__init__(model_path, scaler_path)
        # Unbounded dict: an expired entry is removed only when its key is looked up again
        self.cache: Dict[bytes, CacheEntry] = {}
        self.cache_ttl = timedelta(seconds=cache_ttl_seconds)
        self.cache_hits = 0
        self.cache_misses = 0
        print(f"✓ Caching enabled (TTL: {cache_ttl_seconds}s)")
    
    def _make_cache_key(self, features: np.ndarray) -> bytes:
        """Create a hashable cache key from the feature vector."""
        # Round so tiny floating-point differences map to the same key. Side effect: inputs
        # whose features all round to the same 2-decimal values share one cached prediction.
        # Adding 0.0 turns -0.0 into 0.0 so both give identical bytes.
        rounded = np.round(np.asarray(features, dtype=np.float64), decimals=2) + 0.0
        return rounded.tobytes()
    
    def _is_cache_valid(self, entry: CacheEntry) -> bool:
        """Check if cache entry is still valid."""
        return datetime.now() - entry.timestamp < self.cache_ttl
    
    def predict(self, features: np.ndarray, use_cache: bool = True) -> Dict:
        """Make prediction with caching."""
        start_time = time.perf_counter()
        cache_key = self._make_cache_key(features) if use_cache else None
        
        if use_cache:
            entry = self.cache.get(cache_key)
            if entry is not None:
                if self._is_cache_valid(entry):
                    self.cache_hits += 1
                    result = entry.result.copy()
                    result['cache_hit'] = True
                    # Measured lookup time (key construction + dict access), not a constant
                    result['latency_ms'] = (time.perf_counter() - start_time) * 1000
                    return result
                # Cache entry expired
                del self.cache[cache_key]
            self.cache_misses += 1
        
        # Cache miss (or caching disabled): run full inference
        result = super().predict(features)
        result['cache_hit'] = False
        
        # Store a copy so later changes to the returned dict cannot alter the cached entry
        if use_cache:
            self.cache[cache_key] = CacheEntry(result.copy(), datetime.now())
        
        return result
    
    def get_cache_stats(self) -> Dict:
        """Get cache statistics."""
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0
        
        return {
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'hit_rate': hit_rate,
            'cache_size': len(self.cache)
        }

# Test caching
cached_server = CachedModelServer(model_path, scaler_path, cache_ttl_seconds=30)

print("\n🧪 Testing cache performance...\n")

# First request (cache miss)
test_sample = X_test[0]
result1 = cached_server.predict(test_sample)
print(f"Request 1: latency={result1['latency_ms']:.2f}ms, cache_hit={result1['cache_hit']}")

# Second request with same features (cache hit)
result2 = cached_server.predict(test_sample)
print(f"Request 2: latency={result2['latency_ms']:.2f}ms, cache_hit={result2['cache_hit']}")

# Simulate workload with repeated requests
print("\n🏃 Simulating workload (50% of requests drawn from a hot set of 20 samples)...")
unique_samples = X_test[:20]

for _ in range(200):
    # 50% chance of using a repeated sample
    if np.random.random() < 0.5:
        idx = np.random.randint(0, len(unique_samples))
        cached_server.predict(unique_samples[idx])
    else:
        idx = np.random.randint(0, len(X_test))
        cached_server.predict(X_test[idx])

# Show cache statistics
cache_stats = cached_server.get_cache_stats()
print("\n📊 Cache Statistics:")
print("="*40)
print(f"Cache hits:    {cache_stats['cache_hits']:>6}")
print(f"Cache misses:  {cache_stats['cache_misses']:>6}")
print(f"Hit rate:      {cache_stats['hit_rate']:>6.1%}")
print(f"Cache size:    {cache_stats['cache_size']:>6} entries")

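# Estimate speedup from measured latencies instead of assumed constants
# Miss latency: ModelServer.predict records only full inferences, i.e. cache misses
avg_inference_time = np.mean(cached_server.metrics.latencies)
# Hit latency: time repeated lookups of a sample that is already in the cache
hit_times = []
for _ in range(100):
    t0 = time.perf_counter()
    cached_server.predict(test_sample)
    hit_times.append((time.perf_counter() - t0) * 1000)
avg_cache_time = np.mean(hit_times)

hits, misses = cache_stats['cache_hits'], cache_stats['cache_misses']
time_saved = hits * (avg_inference_time - avg_cache_time)
total_time = hits * avg_cache_time + misses * avg_inference_time
speedup = ((hits + misses) * avg_inference_time) / total_time

print(f"\nMean miss latency: {avg_inference_time:.3f}ms, mean hit latency: {avg_cache_time:.3f}ms")
print(f"Estimated speedup: {speedup:.2f}x")
print(f"Time saved: ~{time_saved:.1f}ms total")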

print("\n💡 Caching Benefits:")
print("- Reduces latency for repeated requests")
print("- Lowers compute load on model")
print("- TTL prevents stale predictions")
print("- Most effective when requests have patterns (e.g., periodic monitoring)")
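`CachedModelServer` keeps every entry until an expired key happens to be looked up again, so memory grows with the number of distinct inputs. The sketch below shows a bounded alternative that uses only the standard library:

- an `OrderedDict` provides least-recently-used (LRU) eviction once `max_entries` is reached;
- the TTL is checked on read, using `time.monotonic` so wall-clock adjustments do not affect expiry;
- a lock makes the cache safe to share across threads.

The class name `TTLLRUCache` and its parameters are illustrative. The third-party `cachetools` package provides a `TTLCache` with similar behavior.

```python
import threading
import time
from collections import OrderedDict
from typing import Any, Callable, Hashable, Optional


class TTLLRUCache:
    """Thread-safe cache with a per-entry TTL and LRU eviction at a fixed size."""

    def __init__(self, max_entries: int = 10_000, ttl_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_entries = max_entries
        self.ttl = ttl_seconds
        self._clock = clock            # injectable so tests can control time
        self._data: OrderedDict = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key: Hashable) -> Optional[Any]:
        """Return the cached value, or None on a miss or an expired entry."""
        with self._lock:
            item = self._data.get(key)
            if item is None:
                return None
            value, stored_at = item
            if self._clock() - stored_at >= self.ttl:
                del self._data[key]     # expired: drop it and report a miss
                return None
            self._data.move_to_end(key)  # mark as most recently used
            return value

    def put(self, key: Hashable, value: Any) -> None:
        with self._lock:
            self._data[key] = (value, self._clock())
            self._data.move_to_end(key)
            while len(self._data) > self.max_entries:
                self._data.popitem(last=False)  # evict the least recently used entry

    def __len__(self) -> int:
        with self._lock:
            return len(self._data)


# Usage with a fake clock so expiry is deterministic
now = [0.0]
cache = TTLLRUCache(max_entries=2, ttl_seconds=30, clock=lambda: now[0])
cache.put('a', {'prediction': 0})
cache.put('b', {'prediction': 1})
cache.get('a')                         # 'a' becomes most recently used
cache.put('c', {'prediction': 1})      # evicts 'b', the least recently used
print(cache.get('b'), cache.get('a'))  # None {'prediction': 0}
now[0] = 31.0
print(cache.get('a'))                  # None (expired)
```

Because `get` returns `None` to signal a miss, this sketch cannot cache `None` values. That is not a problem for prediction dicts.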

## 6. Dynamic Batching for Throughput

Group multiple requests into a single inference call to amortize per-call overhead, improving CPU/GPU utilization and throughput.

**Key concepts:**
- Collect requests up to max batch size or timeout
- Process batch in single inference call
- Return individual results to each requester
- Trade-off: Small latency increase for much higher throughput

In [None]:
section('Dynamic Batching')

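class BatchedModelServer:
    """Model server that runs one inference call per pre-assembled batch.
    
    Note: this class does not collect incoming requests itself; max_wait_ms is stored
    as the intended collection window but predict_batch does not use it.
    """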
    
    def __init__(self, model_path: Path, scaler_path: Path, 
                 max_batch_size: int = 32, max_wait_ms: float = 10):
        self.model = joblib.load(model_path)
        self.scaler = joblib.load(scaler_path)
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.batch_sizes = []
        print(f"✓ Batching enabled (max_size={max_batch_size}, max_wait={max_wait_ms}ms)")
    
    def predict_batch(self, features_list: List[np.ndarray]) -> List[Dict]:
        """Process a batch of requests."""
        start_time = time.perf_counter()
        
        # Stack features into batch
        features_batch = np.vstack(features_list)
        features_scaled = self.scaler.transform(features_batch)
        
        # Batch inference: one predict_proba call; labels are the argmax class per row
        probabilities = self.model.predict_proba(features_scaled)
        predictions = self.model.classes_[np.argmax(probabilities, axis=1)]
        
        batch_latency_ms = (time.perf_counter() - start_time) * 1000
        per_sample_latency = batch_latency_ms / len(features_list)
        
        # Record batch size for statistics
        self.batch_sizes.append(len(features_list))
        
        # Format results
        results = []
        for pred, prob in zip(predictions, probabilities):
            results.append({
                'prediction': int(pred),
                'probability': float(prob[1]),
                'batch_size': len(features_list),
                'batch_latency_ms': batch_latency_ms,
                'per_sample_latency_ms': per_sample_latency
            })
        
        return results
    
    def get_batch_stats(self) -> Dict:
        """Get batching statistics."""
        if not self.batch_sizes:
            return {'mean_batch_size': 0, 'total_batches': 0}
        
        return {
            'mean_batch_size': np.mean(self.batch_sizes),
            'max_batch_size': np.max(self.batch_sizes),
            'total_batches': len(self.batch_sizes),
            'total_samples': sum(self.batch_sizes)
        }

# Test batching
batch_server = BatchedModelServer(model_path, scaler_path, max_batch_size=32)

print("\n🧪 Testing batch inference...\n")

# Single sample (batch size 1)
results = batch_server.predict_batch([X_test[0]])
print(f"Batch size 1: latency={results[0]['batch_latency_ms']:.2f}ms")

# Small batch
results = batch_server.predict_batch(X_test[:8].tolist())
print(f"Batch size 8: batch_latency={results[0]['batch_latency_ms']:.2f}ms, "
      f"per_sample={results[0]['per_sample_latency_ms']:.2f}ms")

# Large batch
results = batch_server.predict_batch(X_test[:32].tolist())
print(f"Batch size 32: batch_latency={results[0]['batch_latency_ms']:.2f}ms, "
      f"per_sample={results[0]['per_sample_latency_ms']:.2f}ms")

# Benchmark different batch sizes
print("\n🏃 Benchmarking batch sizes...")
batch_sizes = [1, 2, 4, 8, 16, 32]
latencies = []
throughputs = []

for bs in batch_sizes:
    batch = X_test[:bs].tolist()
    
    # Warmup
    batch_server.predict_batch(batch)
    
    # Benchmark
    times = []
    for _ in range(10):
        start = time.perf_counter()
        batch_server.predict_batch(batch)
        times.append((time.perf_counter() - start) * 1000)
    
    avg_latency = np.mean(times)
    throughput = (bs / avg_latency) * 1000  # samples/sec
    
    latencies.append(avg_latency)
    throughputs.append(throughput)

# Visualize batch performance
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Latency vs batch size
axes[0].plot(batch_sizes, latencies, marker='o', linewidth=2, markersize=8, color='steelblue')
axes[0].set_xlabel('Batch Size')
axes[0].set_ylabel('Batch Latency (ms)')
axes[0].set_title('Batch Latency vs Batch Size')
axes[0].grid(alpha=0.3)

# Throughput vs batch size
axes[1].plot(batch_sizes, throughputs, marker='s', linewidth=2, markersize=8, color='coral')
axes[1].set_xlabel('Batch Size')
axes[1].set_ylabel('Throughput (samples/sec)')
axes[1].set_title('Throughput vs Batch Size')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\n📊 Batch Performance Summary:")
print("="*60)
print(f"{'Batch Size':>12} {'Latency (ms)':>15} {'Throughput (samples/s)':>25}")
print("="*60)
for bs, lat, tput in zip(batch_sizes, latencies, throughputs):
    print(f"{bs:>12} {lat:>15.2f} {tput:>25.1f}")

print("\n💡 Batching Benefits:")
print("- Raises throughput by amortizing fixed per-call overhead (see the table above; the gain depends on model and hardware)")
print("- Better GPU/CPU utilization")
print("- Per-request latency rises (queue wait + full-batch inference), acceptable when throughput matters more")
print("- Essential for serving at scale (1000s of requests/sec)")
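`BatchedModelServer.predict_batch` runs inference on a list the caller has already assembled. It never waits for requests to arrive and never uses `max_wait_ms`. The collection step from the key concepts can be sketched with `asyncio`:

1. Each caller runs `await predict(x)`.
2. A background task gathers up to `max_batch_size` requests, or stops collecting once `max_wait_ms` has passed since the first request arrived.
3. The task makes one batched call and resolves each caller's future with its own result.

`AsyncBatcher` is a hypothetical name. `batch_fn` stands in for any function that maps a list of inputs to a list of outputs.

```python
import asyncio
from typing import Any, Callable, List, Optional, Tuple


class AsyncBatcher:
    """Groups concurrent requests into batches of at most max_batch_size,
    waiting no longer than max_wait_ms after the first request of a batch."""

    def __init__(self, batch_fn: Callable[[List[Any]], List[Any]],
                 max_batch_size: int = 32, max_wait_ms: float = 10.0):
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_ms / 1000
        self.queue: asyncio.Queue = asyncio.Queue()  # create inside a running event loop
        self.batch_sizes: List[int] = []
        self._worker: Optional[asyncio.Task] = None

    def start(self) -> None:
        self._worker = asyncio.create_task(self._run())

    async def stop(self) -> None:
        if self._worker is not None:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass

    async def predict(self, x: Any) -> Any:
        """Enqueue one request and wait for its own result."""
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((x, fut))
        return await fut

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Block until the first request of the next batch arrives
            batch = [await self.queue.get()]
            deadline = loop.time() + self.max_wait_s
            # Keep collecting until the batch is full or the wait budget is used up
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            try:
                # One call for the whole batch. batch_fn is synchronous, so it blocks the
                # event loop while it runs; a heavy model would go through run_in_executor.
                outputs = self.batch_fn([x for x, _ in batch])
                if len(outputs) != len(batch):
                    raise ValueError('batch_fn must return one output per input')
                for (_, fut), out in zip(batch, outputs):
                    if not fut.done():
                        fut.set_result(out)
            except Exception as exc:
                # Propagate the failure to every caller waiting on this batch
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)


async def demo() -> Tuple[List[int], List[int]]:
    # Toy batch function standing in for a model: doubles each input
    batcher = AsyncBatcher(lambda xs: [2 * x for x in xs], max_batch_size=8, max_wait_ms=5)
    batcher.start()
    # 20 concurrent callers; the worker serves them in batches of at most 8
    results = await asyncio.gather(*(batcher.predict(i) for i in range(20)))
    await batcher.stop()
    return results, batcher.batch_sizes
```

**Running the demo.** In a script, call `results, sizes = asyncio.run(demo())`. In a Jupyter cell, an event loop is already running, so use `results, sizes = await demo()` instead. The exact batch sizes depend on arrival timing.

**Connecting to the notebook's model.** You could pass `batch_server.predict_batch` as `batch_fn`. It takes a list of feature vectors and returns one result dict per input.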

## 7. Summary & Best Practices

### Key Takeaways:

**Latency Optimization:**
- Track p50, p95, p99 percentiles (not just mean)
- Use p99 for SLA targets (it captures the tail latency that the mean hides)
- Optimize model (quantization, pruning) for edge deployment
- Pre-load models to avoid cold start
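Pre-loading pairs naturally with a readiness probe. The process reports ready only after the model is loaded and a few warm-up predictions have run. As a result, a load balancer or Kubernetes readiness probe does not route traffic to a cold replica.

The following is a minimal stdlib sketch. `ServingState`, `load_fn`, and `warmup_input` are hypothetical names. In a real service, `liveness()` and `readiness()` would back HTTP endpoints such as `/healthz` and `/ready`; those paths are a common convention, not a requirement.

```python
import threading
import time
from typing import Any, Callable, Dict, Optional


class ServingState:
    """Tracks liveness (process is up) separately from readiness (model loaded and warmed up)."""

    def __init__(self, load_fn: Callable[[], Callable[[Any], Any]],
                 warmup_input: Any, warmup_runs: int = 3):
        self._load_fn = load_fn              # returns a predict callable
        self._warmup_input = warmup_input
        self._warmup_runs = warmup_runs
        self.predict_fn: Optional[Callable[[Any], Any]] = None
        self._ready = threading.Event()
        self.error: Optional[str] = None

    def load_and_warm(self) -> None:
        try:
            predict_fn = self._load_fn()
            # Warm-up calls pay one-time first-call costs before real traffic arrives
            for _ in range(self._warmup_runs):
                predict_fn(self._warmup_input)
            self.predict_fn = predict_fn
            self._ready.set()
        except Exception as exc:
            self.error = repr(exc)           # stay not-ready and expose the reason

    def liveness(self) -> Dict[str, Any]:
        return {'status': 'alive'}

    def readiness(self) -> Dict[str, Any]:
        if self._ready.is_set():
            return {'ready': True}
        return {'ready': False, 'error': self.error}


def slow_loader():
    time.sleep(0.05)                         # stands in for joblib.load(...) of a real model
    return lambda x: int(sum(x) > 0)         # toy predict function


state = ServingState(slow_loader, warmup_input=[0.1, -0.2, 0.3])
print(state.readiness())                     # {'ready': False, 'error': None}
loader = threading.Thread(target=state.load_and_warm)
loader.start()                               # liveness can be answered while this runs
loader.join()
print(state.readiness())                     # {'ready': True}
```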

**Throughput Optimization:**
- Use dynamic batching to raise throughput (measure the gain; it depends on model and hardware)
- Adjust batch size vs latency trade-off per use case
- Leverage GPU acceleration when available
- Consider async processing for I/O-bound operations

**Caching Strategy:**
- Implement TTL caching for repeated requests
- Monitor cache hit rates; a persistently low rate means the cache adds memory and lookup cost without saving compute
- Use content-based cache keys
- Set TTL based on model update frequency

**Production Deployment:**
- Implement health checks and readiness probes
- Version models for A/B testing
- Monitor latency, throughput, error rates
- Use load balancers for horizontal scaling
- Plan for graceful degradation
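For A/B testing, one common approach is to route by a hash of a stable identifier, here a wafer ID. The same wafer then always goes to the same model version. The route depends only on the ID and the weights, so it is reproducible across processes and restarts. Changing the weights, however, reassigns some IDs.

The following is a minimal sketch. `route_version`, the version labels, and the 10% canary share are illustrative assumptions. Python's built-in `hash()` is avoided because string hashing is randomized per process.

```python
import hashlib
from typing import Dict


def route_version(request_id: str, weights: Dict[str, float]) -> str:
    """Deterministically map a request ID to a model version according to traffic weights."""
    total = sum(weights.values())
    if not weights or total <= 0:
        raise ValueError('weights must contain at least one positive value')
    # Map the ID to a stable point in [0, 1] using the first 8 bytes of SHA-256
    digest = hashlib.sha256(request_id.encode('utf-8')).digest()
    point = int.from_bytes(digest[:8], 'big') / 2**64
    cumulative = 0.0
    for version, weight in weights.items():
        cumulative += weight / total
        if point < cumulative:
            return version
    # Reached only if float rounding leaves point >= cumulative at the upper edge
    return version


weights = {'v1': 0.9, 'v2-canary': 0.1}
counts = {v: 0 for v in weights}
for i in range(10_000):
    counts[route_version(f'WAFER-{i:05d}', weights)] += 1
print(counts)
```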

### Semiconductor Manufacturing Recommendations:

1. **Inline Inspection (<50ms)**: Optimize for latency with edge deployment
2. **Interactive Analysis (<200ms)**: Use caching + batching
3. **High Throughput (1000s/sec)**: Dynamic batching + horizontal scaling
4. **Mission Critical**: Implement redundancy, monitoring, auto-scaling
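For inline (<50ms) and mission-critical cases, graceful degradation usually means two things: enforce a deadline, and return a safe fallback instead of blocking the line.

The following is a minimal asyncio sketch. `predict_with_deadline` is a hypothetical helper. The fallback shown, flagging the wafer for manual review, is an illustrative policy.

`asyncio.wait_for` cancels the timed-out coroutine. If the model instead runs synchronously in an executor thread, that thread keeps running after the timeout and still consumes capacity.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict


async def predict_with_deadline(predict_coro_fn: Callable[[Any], Awaitable[Dict]],
                                features: Any, deadline_ms: float = 50.0) -> Dict:
    """Return the model result if it arrives within deadline_ms, else a fallback."""
    try:
        result = await asyncio.wait_for(predict_coro_fn(features), timeout=deadline_ms / 1000)
        return {**result, 'degraded': False}
    except asyncio.TimeoutError:
        # Illustrative fallback policy: do not guess a label; route the wafer to manual review
        return {'prediction': None, 'action': 'manual_review', 'degraded': True}


async def fast_model(x):
    await asyncio.sleep(0.005)   # ~5 ms inference
    return {'prediction': 0}


async def slow_model(x):
    await asyncio.sleep(0.2)     # 200 ms: misses a 50 ms deadline
    return {'prediction': 1}


async def demo_deadline():
    ok = await predict_with_deadline(fast_model, [1.0, 2.0])
    late = await predict_with_deadline(slow_model, [1.0, 2.0])
    return ok, late
```

In a script, run `asyncio.run(demo_deadline())`. In a Jupyter cell, use `await demo_deadline()`. Count the degraded responses as well, because a rising fallback rate is itself an alert condition.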

### Next Steps:
- Explore FastAPI implementation in 9.3-realtime-quick-ref.md
- Review assessment questions in assessments/module-9/9.3-questions.json
- Check module-9 fundamentals for deep theory
- Try production deployment with Docker + Kubernetes

In [None]:
section('Notebook Complete!')

print("✅ You have completed the Real-time Inference & Model Serving tutorial!")
print("\nKey Skills Acquired:")
print("  • Understanding real-time vs batch inference trade-offs")
print("  • Building model servers with latency tracking")
print("  • Implementing request caching with TTL")
print("  • Dynamic batching for throughput optimization")
print("  • Monitoring p50/p95/p99 latency percentiles")
print("  • Production deployment considerations")
print("\n📚 Next Steps:")
print("  • Explore 9.3-realtime-inference-quick-ref.md for FastAPI examples")
print("  • Review assessment questions in assessments/module-9/9.3-questions.json")
print("  • Check module-9 fundamentals for serving architecture patterns")
print("  • Deploy to production with Docker, Kubernetes, cloud platforms")