# Lecture 83 – Real-Time Inference

## Learning Objectives
- Understand batch vs real-time inference
- Implement a WebSocket endpoint for streaming predictions
- Use FastAPI background tasks
- Monitor inference performance
- Optimize for low latency

## Expected Runtime
~5 minutes

---

In [None]:
# !pip install fastapi uvicorn websockets prometheus-client numpy

In [None]:
import asyncio
import time
from fastapi import FastAPI, WebSocket, BackgroundTasks
from typing import List
import numpy as np

print("✓ Imports successful")

## 1. Batch vs Real-Time Inference

| Aspect | Batch | Real-Time |
|--------|-------|----------|
| Latency | High (minutes-hours) | Low (ms-seconds) |
| Throughput | Very high | Moderate |
| Cost | Lower | Higher |
| Use Cases | Reports, analytics | User-facing apps |
| Complexity | Simpler | More complex |
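
The latency and throughput rows can be made concrete with a toy cost model. The sketch below assumes that each model call pays a fixed overhead plus a per-item cost; the helper name `batch_stats` and the 5 ms / 1 ms figures are hypothetical and chosen only for illustration, not measured from any real model.

```python
def batch_stats(batch_size, overhead_ms=5.0, per_item_ms=1.0):
    """Return (latency_ms, items_per_second) for one call under a toy cost model.

    Assumption: a call costs a fixed overhead plus a per-item cost.
    """
    latency_ms = overhead_ms + per_item_ms * batch_size
    throughput = batch_size / latency_ms * 1000.0
    return latency_ms, throughput


# Larger batches amortize the fixed overhead (higher throughput),
# but every item in the batch waits for the whole call (higher latency).
for size in (1, 8, 32):
    latency, throughput = batch_stats(size)
    print(f"batch={size:>2}  latency={latency:5.1f} ms  throughput={throughput:6.1f} items/s")
```

Under these assumptions, growing the batch raises both numbers at once, which is the trade-off the table summarizes.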

## 2. FastAPI Background Tasks

In [None]:
# Example: Async prediction with background logging
import time
import uuid

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def log_prediction(request_id: str, result: dict):
    """Background task for logging."""
    # Simulate logging to database. FastAPI runs sync tasks like this one
    # in a thread pool, so the blocking sleep does not stall the event loop.
    time.sleep(0.1)
    print(f"Logged prediction {request_id}: {result}")

@app.post("/predict-async")
async def predict_async(data: dict, background_tasks: BackgroundTasks):
    """Predict with background logging."""
    # Use a UUID: timestamps can collide under concurrent requests
    request_id = uuid.uuid4().hex
    
    # Make prediction (fast)
    result = {"prediction": "class_1", "confidence": 0.95}
    
    # Schedule background task (don't wait)
    background_tasks.add_task(log_prediction, request_id, result)
    
    return result

print("✓ Background task example created")

## 3. WebSocket for Streaming Predictions

In [None]:
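# WebSocket endpoint for streaming
from fastapi import WebSocketDisconnect

@app.websocket("/ws/predict")
async def websocket_predict(websocket: WebSocket):
    """Stream predictions via WebSocket."""
    await websocket.accept()
    
    try:
        while True:
            # Receive data
            data = await websocket.receive_json()
            
            # Simulate prediction. NumPy scalars are not JSON serializable,
            # so cast them to built-in types before sending.
            prediction = {
                "class": int(np.random.randint(0, 10)),
                "confidence": float(np.random.random()),
                "timestamp": time.time()
            }
            
            # Send result
            await websocket.send_json(prediction)
            
    except WebSocketDisconnect:
        # The client closed the connection, so there is nothing left to close
        print("WebSocket client disconnected")
    except Exception as e:
        print(f"WebSocket error: {e}")
        try:
            # 1011 = internal error; the socket may already be closed
            await websocket.close(code=1011)
        except RuntimeError:
            pass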

print("✓ WebSocket endpoint created")

## 4. WebSocket Client Example

In [None]:
# Client code example
client_code = '''
import asyncio
import websockets
import json

async def stream_predictions():
    uri = "ws://localhost:8000/ws/predict"
    
    async with websockets.connect(uri) as websocket:
        # Send requests
        for i in range(10):
            data = {"image": [[0]*28]*28}
            await websocket.send(json.dumps(data))
            
            # Receive prediction
            response = await websocket.recv()
            print(f"Prediction {i}: {response}")
            
            await asyncio.sleep(0.1)

# Run
asyncio.run(stream_predictions())
'''

print("WebSocket Client Example:")
print(client_code)

## 5. Performance Monitoring

In [None]:
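from fastapi import Request, Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, generate_latest
import time

# Prometheus metrics
REQUEST_COUNT = Counter('http_requests_total', 'Total HTTP requests', ['method', 'endpoint'])
REQUEST_LATENCY = Histogram('http_request_duration_seconds', 'HTTP request latency')

@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    """Middleware to collect metrics."""
    # perf_counter is monotonic, so it is safer than time.time() for durations
    start_time = time.perf_counter()
    
    # Process request
    response = await call_next(request)
    
    # Record metrics. Raw paths as label values can create many time series
    # if URLs embed IDs; prefer route templates in that case.
    duration = time.perf_counter() - start_time
    REQUEST_COUNT.labels(method=request.method, endpoint=request.url.path).inc()
    REQUEST_LATENCY.observe(duration)
    
    return response

@app.get("/metrics")
async def metrics():
    """Prometheus metrics endpoint."""
    # generate_latest() returns bytes in the Prometheus text format;
    # return them as-is instead of letting FastAPI encode them as JSON.
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)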

print("✓ Monitoring middleware created")
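
Latency histograms such as `REQUEST_LATENCY` are usually read as percentiles rather than averages, because a handful of slow requests can dominate user experience. The following is a minimal standard-library sketch of that idea; `latency_summary` is a hypothetical helper and the sample durations are synthetic, not taken from a real service.

```python
import statistics


def latency_summary(durations_s):
    """Return p50/p95/p99 latency in milliseconds from durations in seconds."""
    # quantiles(n=100) returns 99 cut points; index k-1 is the k-th percentile
    cuts = statistics.quantiles(durations_s, n=100, method="inclusive")
    return {
        "p50_ms": cuts[49] * 1000.0,
        "p95_ms": cuts[94] * 1000.0,
        "p99_ms": cuts[98] * 1000.0,
    }


# Synthetic data: 95 fast requests (10 ms) and 5 slow ones (200 ms)
samples = [0.010] * 95 + [0.200] * 5
summary = latency_summary(samples)
print(summary)
print(f"mean_ms={statistics.mean(samples) * 1000.0:.1f}")
```

In this synthetic sample the median stays at the fast value while p99 reflects the slow tail, which is why tail percentiles are a common alerting target for real-time endpoints.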

## 6. Latency Optimization Techniques

In [None]:
optimization_tips = '''
### Model Optimization
1. **Quantization**: Convert FP32 to INT8
   - TensorFlow Lite, ONNX Runtime
   - Roughly 4x smaller weights; often 2-4x faster, depending on hardware and operator support

2. **Pruning**: Remove unnecessary weights
   - Can shrink some models by up to ~90%, usually with fine-tuning to recover accuracy

3. **Knowledge Distillation**: Train smaller model
   - Teacher-student approach

### Serving Optimization
1. **Batch predictions**: Group requests
   - Higher throughput
   - Trade latency for efficiency

2. **Model caching**: Keep model in memory
   - Load once at startup
   - Hold it in a module-level variable or app state, not per request

3. **GPU acceleration**: Use CUDA
   - Can give 10-100x speedups for large models; small models may gain little

4. **TensorRT / ONNX**: Optimize graph
   - Fuse operations
   - Hardware-specific optimization

### Infrastructure
1. **Load balancing**: Distribute requests
   - Multiple replicas
   - Auto-scaling

2. **Caching**: Redis for frequent requests
   - Retrieval is typically around a millisecond or less on a local network

3. **CDN**: Edge deployment
   - Reduce network latency
'''

print(optimization_tips)
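
The size reduction quoted for INT8 quantization follows from storing one byte per weight instead of four. The sketch below shows symmetric per-tensor quantization in plain Python to make the mechanics visible. It is an illustration only: real toolchains such as TensorFlow Lite or ONNX Runtime use calibrated and often per-channel schemes, and `quantize_int8` and `dequantize` are hypothetical helper names.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    # One shared scale for the whole tensor; fall back to 1.0 for all-zero input
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float weights from the integers and the scale."""
    return [q * scale for q in quantized]


weights = [0.82, -1.27, 0.05, 0.0, 0.631]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

# The rounding error per weight is bounded by half a quantization step
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, scale, max_error)
```

Each integer fits in one byte, compared with four bytes for an FP32 weight, at the cost of a bounded rounding error per weight.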

## 7. Batching for Throughput

In [None]:
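import asyncio
from collections import deque

class BatchPredictor:
    """Batch predictions for higher throughput."""
    
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time  # seconds before a partial batch is flushed
        self.queue = deque()
        self.lock = asyncio.Lock()
        self._flush_task = None  # pending timer for a partial batch
        self._tasks = set()      # strong references so tasks are not garbage collected
    
    def _spawn(self, coro):
        """Start a task and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task
    
    async def predict(self, data):
        """Add to batch and wait for result."""
        future = asyncio.get_running_loop().create_future()
        
        async with self.lock:
            self.queue.append((data, future))
            batch_full = len(self.queue) >= self.max_batch_size
        
        if batch_full:
            # Trigger batch immediately if full
            self._spawn(self._process_batch())
        elif self._flush_task is None:
            # Otherwise flush the partial batch after max_wait_time,
            # so a request never waits forever for the batch to fill
            self._flush_task = self._spawn(self._flush_after_wait())
        
        return await future
    
    async def _flush_after_wait(self):
        """Flush whatever has accumulated once max_wait_time has elapsed."""
        await asyncio.sleep(self.max_wait_time)
        self._flush_task = None  # let later requests start a new timer
        await self._process_batch()
    
    async def _process_batch(self):
        """Process accumulated batch."""
        async with self.lock:
            if not self.queue:
                return
            
            batch = list(self.queue)
            self.queue.clear()
        
        # Extract data
        data_batch = [d for d, _ in batch]
        futures = [f for _, f in batch]
        
        try:
            # Batch predict in a worker thread so the event loop stays responsive
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(None, self.model.predict, np.array(data_batch))
        except Exception as exc:
            # Propagate the failure to every waiting caller instead of hanging
            for future in futures:
                if not future.done():
                    future.set_exception(exc)
            return
        
        # Set results
        for future, result in zip(futures, results):
            if not future.done():
                future.set_result(result)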

print("✓ Batch predictor created")

## Summary

✓ Compared batch and real-time inference  
✓ Implemented background tasks  
✓ Created WebSocket streaming  
✓ Added performance monitoring  
✓ Learned optimization techniques  

### Production Checklist:
- [ ] Model quantization
- [ ] Request batching
- [ ] GPU optimization
- [ ] Metrics collection
- [ ] Auto-scaling
- [ ] Circuit breakers
- [ ] Rate limiting

**Next**: `06_hands_on_lab_deploy_sentiment_or_cnn.ipynb`