# Lecture 83 – Serving Models with FastAPI

## Learning Objectives
- Build a production-ready REST API for model serving
- Use FastAPI for high-performance inference endpoints
- Implement request validation with Pydantic
- Add health checks and metadata endpoints
- Test API with cURL and Python requests

## Expected Runtime
~3-5 minutes (the API server runs separately, in a terminal or background process)

## Prerequisites
- Completed notebook 01 (saved model artifacts)
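- FastAPI, Uvicorn, Pydantic
- Python requests library for testing
- TensorFlow (used here only to load Fashion-MNIST test images)
- pytest and httpx for the unit tests in Section 9 (FastAPI's `TestClient` requires httpx)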

---

## Setup and Environment Check

In [None]:
# Install required packages
# !pip install fastapi==0.104.1 "uvicorn[standard]==0.24.0" pydantic==2.5.0 python-multipart requests pytest httpx

In [None]:
import sys
import json
import numpy as np
from pathlib import Path

try:
    import fastapi
    import uvicorn
    import pydantic
    import requests
    print(f"✓ FastAPI version: {fastapi.__version__}")
    print(f"✓ Uvicorn version: {uvicorn.__version__}")
    print(f"✓ Pydantic version: {pydantic.__version__}")
except ImportError as e:
    print(f"Missing package: {e}")
    print("Run: pip install fastapi 'uvicorn[standard]' pydantic requests")

## 1. Review the FastAPI Application Structure

We've created a complete FastAPI application in `apps/fastapi_app/`. Let's examine it.

In [None]:
# Check if the FastAPI app exists
app_path = Path('../apps/fastapi_app/app.py')

if app_path.exists():
    print(f"✓ FastAPI app found at: {app_path}")
    print("\nApp structure (first 50 lines):")
    with open(app_path, 'r') as f:
        lines = f.readlines()[:50]
        print(''.join(lines))
else:
    print(f"✗ App not found at {app_path}. The path is relative to this notebook; check your working directory.")

## 2. Understanding the API Endpoints

Our FastAPI application provides:

1. **GET /ping** - Health check endpoint
2. **POST /predict** - Main inference endpoint
3. **GET /metadata** - Model information
4. **GET /** - Root endpoint with API documentation

### Key Features:
- **Pydantic Models**: Type-safe request/response validation
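- **Async Support**: `async def` endpoints don't block the event loop while waiting on I/O; CPU-bound model inference belongs in a plain `def` endpoint (which FastAPI runs in a thread pool) or a separate worker, so it doesn't stall other requests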
- **Error Handling**: Proper HTTP status codes and error messages
- **Auto Documentation**: Swagger UI at `/docs`
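
Pydantic checks types (for example, that `instances` is a list of lists of numbers), but value-level rules such as "every image is 28x28 with pixels in [0, 255]" usually need a custom validator or a check inside the handler. We don't know how `app.py` implements this; the sketch below is a hypothetical, dependency-free version of such a check. `validate_instances` and its `max_batch` limit are illustrative names (the limit mirrors the `MAX_BATCH_SIZE` example in Section 8.2). A FastAPI handler could raise `HTTPException(status_code=400, detail=errors)` when the list is non-empty.

```python
def validate_instances(instances, height=28, width=28, max_batch=32):
    """Hypothetical shape/range check for a /predict batch.

    Returns a list of human-readable errors; an empty list means the batch is valid.
    """
    if not isinstance(instances, list) or len(instances) == 0:
        return ["'instances' must be a non-empty list of images"]
    if len(instances) > max_batch:
        return [f"batch size {len(instances)} exceeds limit of {max_batch}"]

    errors = []
    for i, image in enumerate(instances):
        # Each image must be a list of `height` rows
        if not isinstance(image, list) or len(image) != height:
            errors.append(f"instance {i}: expected {height} rows")
            continue
        for r, row in enumerate(image):
            # Each row must be a list of `width` pixels
            if not isinstance(row, list) or len(row) != width:
                errors.append(f"instance {i}, row {r}: expected {width} pixels")
                break
            # Pixels must be real numbers (bools excluded) in the uint8 range
            if any(not isinstance(p, (int, float)) or isinstance(p, bool) or not 0 <= p <= 255
                   for p in row):
                errors.append(f"instance {i}, row {r}: pixels must be numbers in [0, 255]")
                break
    return errors


blank = [[0] * 28 for _ in range(28)]
print(validate_instances([blank]))      # → []
print(validate_instances([[1, 2, 3]]))  # → ['instance 0: expected 28 rows']
```

Returning all errors at once, rather than failing on the first, gives clients a complete picture in one response.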

## 3. Starting the FastAPI Server

**Note**: Start the server yourself in a separate terminal (or as a background process) before running the cells below; the next cell only checks whether it is reachable.

In [None]:
# Create a simple test to check if server is running
import requests
import time

def check_server_running(url="http://localhost:8000", max_retries=3):
    """Return True if GET {url}/ping answers 200 within max_retries attempts."""
    for i in range(max_retries):
        try:
            response = requests.get(f"{url}/ping", timeout=2)
            if response.status_code == 200:
                print(f"✓ Server is running at {url}")
                return True
        except requests.exceptions.RequestException:
            pass  # connection refused or timed out: server not up yet
        # Wait before retrying, whether the request failed or returned a non-200 status
        if i < max_retries - 1:
            print(f"Waiting for server... (attempt {i+1}/{max_retries})")
            time.sleep(2)
    
    print(f"✗ Server not running at {url}")
    print("\nTo start the server, run in a terminal (from the repository root):")
    print("  cd apps/fastapi_app")
    print("  uvicorn app:app --host 0.0.0.0 --port 8000 --reload   # --reload is for development only")
    return False

SERVER_URL = "http://localhost:8000"
server_running = check_server_running(SERVER_URL)

## 4. Testing API Endpoints

### 4.1 Health Check

In [None]:
if server_running:
    response = requests.get(f"{SERVER_URL}/ping")
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
else:
    print("Server not running. Example response:")
    print('{"status": "healthy", "model_loaded": true}')

### 4.2 Get Model Metadata

In [None]:
if server_running:
    response = requests.get(f"{SERVER_URL}/metadata")
    print(f"Status Code: {response.status_code}")
    print(f"\nModel Metadata:")
    print(json.dumps(response.json(), indent=2))
else:
    print("Server not running. Example metadata:")
    example_metadata = {
        "model_name": "fashion_mnist_cnn",
        "version": "20241102_120000",
        "input_shape": [28, 28, 1],
        "output_classes": 10,
        "class_names": ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                       "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
    }
    print(json.dumps(example_metadata, indent=2))

### 4.3 Make Predictions

Let's load a few Fashion-MNIST test images and send them to the API.

In [None]:
# Load a test image from Fashion-MNIST
import tensorflow as tf

(_, _), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Get a few test images (raw, not normalized)
test_images = x_test[:3].tolist()  # Convert to list for JSON serialization

print(f"Loaded {len(test_images)} test images")
print(f"Image shape: {np.array(test_images[0]).shape}")
print(f"Actual labels: {y_test[:3]}")

In [None]:
# Prepare the request payload
payload = {
    "instances": test_images
}

print(f"Payload contains {len(payload['instances'])} images")
print(f"Payload size: ~{len(json.dumps(payload)) / 1024:.2f} KB")
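
The `.tolist()` call when loading the images is what makes the payload serializable: Python's `json` module does not know how to encode NumPy arrays. This small sketch uses an all-zero stand-in image (hypothetical, same dtype as `x_test`) to show the failure and the fix; the names `demo_image` and `demo_payload` are illustrative and avoid overwriting the notebook's `payload`.

```python
import json
import numpy as np

# Stand-in 28x28 uint8 image with the same dtype as Fashion-MNIST's x_test
demo_image = np.zeros((28, 28), dtype=np.uint8)

# json.dumps raises TypeError for NumPy arrays
try:
    json.dumps({"instances": [demo_image]})
    serializable = True
except TypeError:
    serializable = False

# .tolist() converts to nested lists of built-in Python ints, which json can encode
demo_payload = {"instances": [demo_image.tolist()]}
body = json.dumps(demo_payload)

print(serializable)                             # → False
print(type(demo_payload["instances"][0][0][0]))  # → <class 'int'>
print(f"{len(body)} bytes of JSON vs {demo_image.nbytes} raw bytes")
```

Text encoding inflates the payload: each pixel costs its digits plus a separator, so the JSON body is several times larger than the 784 raw bytes of the image.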

In [None]:
CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

if server_running:
    # Send prediction request; requests sets Content-Type: application/json automatically with json=
    response = requests.post(f"{SERVER_URL}/predict", json=payload, timeout=10)
    
    print(f"Status Code: {response.status_code}")
    
    if response.status_code == 200:
        result = response.json()
        print("\nPredictions:")
        for i, pred in enumerate(result['predictions']):
            print(f"  Image {i+1}: {pred['class_name']} (confidence: {pred['confidence']:.4f})")
            print(f"           Actual: {CLASS_NAMES[y_test[i]]}")
    else:
        print(f"Error: {response.text}")
else:
    print("Server not running. Example prediction response:")
    example_response = {
        "predictions": [
            {"class_id": 9, "class_name": "Ankle boot", "confidence": 0.9876},
            {"class_id": 2, "class_name": "Pullover", "confidence": 0.8543},
            {"class_id": 1, "class_name": "Trouser", "confidence": 0.9234}
        ],
        "model_version": "20241102_120000"
    }
    print(json.dumps(example_response, indent=2))
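
Each entry in the response pairs the highest-scoring class with its probability. The sketch below shows one way a server could turn a model's output vector into that `{"class_id", "class_name", "confidence"}` shape. It assumes the output may be raw logits, so it applies a numerically stable softmax first; if the Keras model already ends in a softmax layer, `format_prediction` can be called on its output directly. The function names and the logits are hypothetical.

```python
import math

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def format_prediction(probs, class_names=CLASS_NAMES):
    """Turn one probability vector into the dict shape of the example response."""
    class_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return {
        "class_id": class_id,
        "class_name": class_names[class_id],
        "confidence": round(probs[class_id], 4),
    }


# Hypothetical logits for one image where class 9 ("Ankle boot") scores highest
logits = [0.1, -1.0, 0.3, 0.0, -0.5, 1.2, 0.2, 2.0, -0.3, 5.0]
print(format_prediction(softmax(logits)))
```

Note that "confidence" here is the top softmax probability; it is not guaranteed to be calibrated.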

## 5. Testing with cURL

You can also test the API using cURL commands from the terminal:

In [None]:
# Generate example cURL commands
curl_commands = f"""
# Health check
curl -X GET "{SERVER_URL}/ping"

# Get metadata
curl -X GET "{SERVER_URL}/metadata"

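# Make prediction (single image). The "..." values are placeholders: the body must be
# valid JSON (no comments) holding one 28x28 nested list of pixel values
curl -X POST "{SERVER_URL}/predict" \\
  -H "Content-Type: application/json" \\
  -d '{{"instances": [[[0, 0, ..., 0], ..., [0, 0, ..., 0]]]}}'

# View interactive documentation (`open` is macOS; use xdg-open on Linux)
open {SERVER_URL}/docs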
"""

print("Example cURL commands:")
print(curl_commands)
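
Typing 784 pixel values inline is impractical, so a real cURL test usually reads the body from a file with `-d @<file>`. This sketch writes an all-zero image, nested as 28x28 like the Python example, to a hypothetical `payload.json` and prints the matching command; it assumes the server is on `localhost:8000`.

```python
import json
from pathlib import Path

server_url = "http://localhost:8000"  # assumed local server address

# All-zero 28x28 image as a nested list: the same shape the Python example sends
blank_image = [[0] * 28 for _ in range(28)]
payload_file = Path("payload.json")
payload_file.write_text(json.dumps({"instances": [blank_image]}))

# curl's -d @<file> reads the request body from the file
print(f'curl -X POST "{server_url}/predict" '
      f'-H "Content-Type: application/json" -d @{payload_file}')
```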

## 6. Performance Testing

Let's measure single-client latency and the sequential throughput it implies.

In [None]:
import time

def benchmark_api(url, num_requests=10, batch_size=1, warmup=2):
    """Measure sequential, single-client latency of POST /predict."""
    if not server_running:
        print("Server not running. Skipping benchmark.")
        return
    
    # Build the payload once so x_test is not re-serialized on every request
    bench_payload = {"instances": x_test[:batch_size].tolist()}
    session = requests.Session()  # reuse one TCP connection across requests
    
    # Warm-up requests absorb one-time costs (connection setup, lazy model initialization)
    for _ in range(warmup):
        session.post(f"{url}/predict", json=bench_payload, timeout=30)
    
    latencies = []
    print(f"Running {num_requests} requests with batch_size={batch_size}...")
    
    for i in range(num_requests):
        start = time.perf_counter()  # monotonic, high-resolution timer
        response = session.post(f"{url}/predict", json=bench_payload, timeout=30)
        latency = (time.perf_counter() - start) * 1000  # Convert to ms
        
        if response.status_code == 200:
            latencies.append(latency)
        else:
            print(f"Request {i+1} failed: {response.status_code}")
    
    if latencies:
        mean_ms = np.mean(latencies)
        print("\nPerformance Metrics:")
        print(f"  Successful requests: {len(latencies)}/{num_requests}")
        print(f"  Mean latency: {mean_ms:.2f} ms")
        print(f"  Median latency: {np.median(latencies):.2f} ms")
        print(f"  95th percentile: {np.percentile(latencies, 95):.2f} ms")
        print(f"  Min latency: {np.min(latencies):.2f} ms")
        print(f"  Max latency: {np.max(latencies):.2f} ms")
        # Requests are sequential, so throughput is just 1 / mean latency for one client
        print(f"  Throughput: {1000 / mean_ms:.2f} requests/sec "
              f"({batch_size * 1000 / mean_ms:.2f} images/sec)")

benchmark_api(SERVER_URL, num_requests=20, batch_size=1)
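
`benchmark_api` sends one request at a time, so its throughput figure says nothing about how the server behaves under concurrent load (for example, with several Gunicorn workers). A minimal sketch of a concurrent load generator follows; `run_concurrent_load` is a hypothetical helper that takes any zero-argument `send_request` callable, so it can be exercised without a server. Threads are adequate here because blocking socket I/O releases the GIL.

```python
import statistics
import time
from concurrent.futures import ThreadPoolExecutor


def run_concurrent_load(send_request, num_requests=100, concurrency=8):
    """Call send_request() num_requests times across `concurrency` threads.

    send_request must return True on success (e.g. a wrapper around requests.post).
    """
    def timed_call(_):
        start = time.perf_counter()
        ok = send_request()
        return ok, (time.perf_counter() - start) * 1000  # latency in ms

    wall_start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        results = list(pool.map(timed_call, range(num_requests)))
    wall_seconds = time.perf_counter() - wall_start

    latencies = [ms for ok, ms in results if ok]
    return {
        "successes": len(latencies),
        "failures": num_requests - len(latencies),
        "median_ms": statistics.median(latencies) if latencies else None,
        # Throughput = completed requests / wall-clock time, not 1000 / mean latency
        "throughput_rps": len(latencies) / wall_seconds,
    }


# Stand-in request that takes ~50 ms; 8 threads should far exceed the sequential ~20 req/s
stats = run_concurrent_load(lambda: time.sleep(0.05) or True, num_requests=40, concurrency=8)
print(stats)
```

Against the notebook's server, `send_request` could be `lambda: requests.post(f"{SERVER_URL}/predict", json=payload, timeout=10).status_code == 200`. Plain `requests.post` is used there rather than a shared `Session`, since sharing a `Session` across threads is not documented as thread-safe.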

## 7. Error Handling and Validation

Let's test how the API handles invalid requests.

In [None]:
if server_running:
    # Test 1: Empty request
    print("Test 1: Empty request")
    response = requests.post(f"{SERVER_URL}/predict", json={})
    print(f"  Status: {response.status_code}")
    print(f"  Response: {response.json()}\n")
    
    # Test 2: Wrong shape
    print("Test 2: Wrong image shape")
    response = requests.post(f"{SERVER_URL}/predict", 
                            json={"instances": [[1, 2, 3]]})  # Too small
    print(f"  Status: {response.status_code}")
    print(f"  Response: {response.json()}\n")
    
    # Test 3: Invalid data type
    print("Test 3: Invalid data type")
    response = requests.post(f"{SERVER_URL}/predict", 
                            json={"instances": "not a list"})
    print(f"  Status: {response.status_code}")
    print(f"  Response: {response.json()}")
else:
    print("Server not running. Example error response (FastAPI's default 422 format, abridged):")
    print('422 Unprocessable Entity: {"detail": [{"type": "missing", "loc": ["body", "instances"], "msg": "Field required", ...}]}')
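
FastAPI's automatic validation errors arrive as a list under `"detail"`, each item carrying a `"loc"` (where in the request the problem is) and a `"msg"`; errors raised by the app via `HTTPException(detail="...")` carry a plain string instead. A small, hypothetical helper like the one below makes either form readable in client logs. The example body is abridged and illustrative.

```python
def summarize_validation_errors(body):
    """Flatten a FastAPI error body into 'location: message' strings."""
    detail = body.get("detail")
    if isinstance(detail, str):  # e.g. HTTPException(status_code=400, detail="...")
        return [detail]
    lines = []
    for err in detail or []:  # default 422 format: a list of error objects
        loc = ".".join(str(part) for part in err.get("loc", []))
        lines.append(f"{loc}: {err.get('msg', 'invalid value')}")
    return lines


# Abridged example of what an empty JSON body ({}) typically produces under Pydantic v2
example_422 = {"detail": [{"type": "missing", "loc": ["body", "instances"],
                           "msg": "Field required", "input": {}}]}
print(summarize_validation_errors(example_422))  # → ['body.instances: Field required']
```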

## 8. Production Deployment Notes

### 8.1 Running with Gunicorn (Production Server)

```bash
# Install gunicorn
pip install gunicorn

# Run with multiple workers from apps/fastapi_app (each worker loads its own copy of the model)
gunicorn app:app \
  --workers 4 \
  --worker-class uvicorn.workers.UvicornWorker \
  --bind 0.0.0.0:8000 \
  --timeout 60 \
  --access-logfile - \
  --error-logfile -
```
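
The same settings can live in a Gunicorn config file, which keeps the command short and lets the worker count be computed. Gunicorn's documentation suggests `(2 x CPU cores) + 1` workers as a starting point for typical web apps; for model serving, each worker holds a full model copy, so memory often caps the count first. This sketch assumes a hypothetical `MAX_WORKERS` environment variable as that cap; load it with `gunicorn -c gunicorn.conf.py app:app`.

```python
# gunicorn.conf.py: load with `gunicorn -c gunicorn.conf.py app:app`
import multiprocessing
import os

# Gunicorn's suggested starting point for general web workloads
cpu_based = 2 * multiprocessing.cpu_count() + 1
# MAX_WORKERS is a hypothetical cap, since every worker loads its own model copy
workers = min(cpu_based, int(os.environ.get("MAX_WORKERS", "4")))

worker_class = "uvicorn.workers.UvicornWorker"
bind = "0.0.0.0:8000"
timeout = 60       # seconds a silent worker may run before it is killed and restarted
accesslog = "-"    # "-" sends access logs to stdout
errorlog = "-"     # "-" sends error logs to stderr
```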

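### 8.2 Environment Variables

Keep paths and limits out of the code. These variables only take effect if `app.py` reads them (for example via `os.environ`):
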
```bash
export MODEL_PATH=/path/to/saved/model
export PREPROCESSING_PATH=/path/to/preprocessing.pkl
export LOG_LEVEL=INFO
export MAX_BATCH_SIZE=32
```
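
One way for the app to consume these variables is to read them once at startup into a typed, immutable config and fail fast on anything missing. The sketch below is hypothetical (`ServingConfig` and `load_config` are illustrative names, and the example paths are placeholders); Pydantic users may prefer `BaseSettings` from the separate `pydantic-settings` package, which does similar work in Pydantic v2.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class ServingConfig:
    model_path: str
    preprocessing_path: str
    log_level: str = "INFO"
    max_batch_size: int = 32


def load_config(env=os.environ):
    """Read settings from environment variables, failing fast on missing required ones."""
    missing = [k for k in ("MODEL_PATH", "PREPROCESSING_PATH") if not env.get(k)]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
    max_batch = int(env.get("MAX_BATCH_SIZE", "32"))
    if max_batch < 1:
        raise ValueError("MAX_BATCH_SIZE must be >= 1")
    return ServingConfig(
        model_path=env["MODEL_PATH"],
        preprocessing_path=env["PREPROCESSING_PATH"],
        log_level=env.get("LOG_LEVEL", "INFO").upper(),
        max_batch_size=max_batch,
    )


# Usage with an explicit dict (placeholder paths) instead of the real environment
cfg = load_config({"MODEL_PATH": "/models/fmnist", "PREPROCESSING_PATH": "/models/prep.pkl"})
print(cfg.max_batch_size)  # → 32
```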

### 8.3 Monitoring and Logging

- Use structured logging (JSON format)
- Add request IDs for tracing
- Implement metrics endpoint for Prometheus
- Set up alerts for error rates and latency
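
Structured logs and request IDs work well together: if every log line is a JSON object that includes the current request's ID, one request can be traced across all its log entries. The sketch below uses only the standard library; the logger name `model_server` is hypothetical. In a FastAPI app, an `@app.middleware("http")` function would typically set `request_id_var` once per incoming request; `contextvars` gives each asyncio task its own copy, so concurrent requests do not overwrite each other's IDs.

```python
import contextvars
import json
import logging
import uuid

# Holds the current request's ID; asyncio tasks each get their own copy of the context
request_id_var = contextvars.ContextVar("request_id", default="-")


class JsonFormatter(logging.Formatter):
    """Emit one JSON object per log line, tagged with the current request ID."""

    def format(self, record):
        entry = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "request_id": request_id_var.get(),
        }
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        return json.dumps(entry)


logger = logging.getLogger("model_server")  # hypothetical logger name
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False  # avoid duplicate lines if the root logger also has handlers

# In a real app, middleware would set this once per incoming request
request_id_var.set(uuid.uuid4().hex)
logger.info("prediction served")
```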

### 8.4 Security Considerations

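- Add API key authentication (FastAPI provides `fastapi.security.APIKeyHeader` for reading a key from a request header)
- Implement rate limiting, returning HTTP 429 Too Many Requests when a client exceeds its quota
- Use HTTPS in production
- Validate and sanitize all inputs (shape, dtype, value range, batch size)
- Set maximum request size limits, for example with `client_max_body_size` in an nginx reverse proxy in front of the app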
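
A token bucket is a common rate-limiting algorithm: each client has a bucket holding up to `capacity` tokens that refills at `rate` tokens per second, and each request spends one token. The sketch below is an in-process version keyed by API key; `TokenBucket` and `is_allowed` are hypothetical names. Because the state lives in a plain dict, each Gunicorn worker keeps its own buckets, so with 4 workers a client could get up to about 4x the intended limit; a shared store (such as Redis) is needed for a global limit.

```python
import threading
import time


class TokenBucket:
    """Allow bursts of up to `capacity` requests, refilling at `rate` tokens per second."""

    def __init__(self, rate, capacity, clock=time.monotonic):
        self.rate = rate                # tokens added per second
        self.capacity = capacity        # maximum burst size
        self.clock = clock              # injectable so tests can control time
        self.tokens = float(capacity)
        self.last = clock()
        self._lock = threading.Lock()   # sync endpoints run in a thread pool

    def allow(self):
        with self._lock:
            now = self.clock()
            # Refill in proportion to elapsed time, never above capacity
            self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False  # caller should respond with HTTP 429


_buckets = {}
_buckets_lock = threading.Lock()


def is_allowed(api_key, rate=5.0, capacity=10):
    """Per-key check; buckets live in this process only."""
    with _buckets_lock:
        bucket = _buckets.get(api_key)
        if bucket is None:
            bucket = _buckets[api_key] = TokenBucket(rate, capacity)
    return bucket.allow()
```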

## 9. Writing Unit Tests

Here's an example test suite using pytest:

In [None]:
# Create a test file
test_code = '''
"""Unit tests for FastAPI model serving."""
import numpy as np
import pytest
from fastapi.testclient import TestClient

# Assumes app.py is in the same directory (run pytest from apps/fastapi_app)
from app import app


@pytest.fixture(scope="module")
def client():
    """Yield a TestClient with the app's startup/lifespan handlers executed."""
    # Entering the context manager runs startup events (e.g. model loading);
    # a bare TestClient(app) would skip them
    with TestClient(app) as test_client:
        yield test_client


def test_ping(client):
    """Test health check endpoint."""
    response = client.get("/ping")
    assert response.status_code == 200
    assert response.json()["status"] == "healthy"


def test_metadata(client):
    """Test metadata endpoint."""
    response = client.get("/metadata")
    assert response.status_code == 200
    data = response.json()
    assert "model_name" in data
    assert "input_shape" in data


def test_predict_valid(client):
    """Test prediction with valid input."""
    # Random 28x28 image with pixels in [0, 255] (randint's upper bound is exclusive)
    image = np.random.randint(0, 256, (28, 28)).tolist()
    
    response = client.post("/predict", json={"instances": [image]})
    assert response.status_code == 200
    
    data = response.json()
    assert "predictions" in data
    assert len(data["predictions"]) == 1
    assert "class_name" in data["predictions"][0]


# The two tests below assume app.py rejects bad shapes and empty batches with 400;
# if it enforces these through Pydantic constraints instead, expect 422
def test_predict_invalid_shape(client):
    """Test prediction with invalid input shape."""
    response = client.post("/predict", json={"instances": [[1, 2, 3]]})
    assert response.status_code == 400


def test_predict_empty(client):
    """Test prediction with empty input."""
    response = client.post("/predict", json={"instances": []})
    assert response.status_code == 400
'''

print("Example pytest test suite:")
print(test_code)

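# Save next to app.py so `from app import app` resolves
test_file = Path('../apps/fastapi_app/test_app.py')
if test_file.parent.exists():
    test_file.write_text(test_code)
    print(f"\n✓ Test file saved to: {test_file}")
    print("\nRun tests (from this notebook's directory) with: cd ../apps/fastapi_app && pytest test_app.py -v")
else:
    print(f"\n✗ Directory {test_file.parent} not found; test file not saved.")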

## Summary

In this notebook, we:

1. ✓ Created a production-ready FastAPI application
2. ✓ Implemented health, metadata, and prediction endpoints
3. ✓ Added request validation with Pydantic
4. ✓ Tested the API with Python and cURL
5. ✓ Benchmarked API performance
6. ✓ Tested how the API rejects invalid requests
7. ✓ Wrote unit tests with pytest

### Production Checklist:

- [ ] Use Gunicorn with multiple workers
- [ ] Add authentication (API keys, OAuth)
- [ ] Implement rate limiting
- [ ] Set up monitoring (Prometheus + Grafana)
- [ ] Add logging (structured JSON logs)
- [ ] Use HTTPS with valid certificates
- [ ] Set up CI/CD pipeline
- [ ] Implement load balancing
- [ ] Add caching for repeated requests
- [ ] Document API with OpenAPI/Swagger

---

## Extension Ideas

1. **Batch Optimization**: Add batching logic to process multiple requests together
2. **Model Versioning**: Support multiple model versions with A/B testing
3. **Async Processing**: Use background tasks for long-running predictions
4. **Caching**: Implement Redis caching for frequent predictions
5. **Metrics**: Add Prometheus metrics endpoint
6. **WebSocket**: Real-time streaming predictions
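
As a starting point for the caching idea, the sketch below swaps Redis for an in-process LRU cache (a plain `OrderedDict`), keyed by a SHA-256 hash of the input image plus the model version, so a model update never serves stale predictions. `PredictionCache` is a hypothetical name; the same key scheme could be reused as a Redis key. Note that `[0]` and `[0.0]` serialize differently, so clients must send pixels in a consistent numeric form for hits to occur.

```python
import hashlib
import json
from collections import OrderedDict


class PredictionCache:
    """Small in-process LRU cache keyed by model version + hash of the input image."""

    def __init__(self, model_version, max_entries=1024):
        self.model_version = model_version
        self.max_entries = max_entries
        self._store = OrderedDict()

    def key_for(self, image):
        # Compact, deterministic JSON -> SHA-256, prefixed by the model version
        raw = json.dumps(image, separators=(",", ":")).encode()
        return f"{self.model_version}:{hashlib.sha256(raw).hexdigest()}"

    def get(self, image):
        key = self.key_for(image)
        if key in self._store:
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        return None

    def put(self, image, prediction):
        key = self.key_for(image)
        self._store[key] = prediction
        self._store.move_to_end(key)
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # evict the least recently used entry
```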

---

**Next**: `03_rag_langchain_gradio.ipynb` - Build a RAG system with Gradio UI