# Chapter 39: API Gateway & Load Balancing

Run this notebook directly in Google Colab - no local Python needed!

**Full code**: [GitHub](https://github.com/eduardd76/AI_for_networking_and_security_engineers/tree/main/CODE/Volume-3-Production-Systems/Chapter-39-API-Gateway-Load-Balancing)

## Setup

Install dependencies and configure API keys.

In [None]:
# Install dependencies
!pip install -q requests flask anthropic pybreaker

# Import and configure API key
import os
from getpass import getpass

# Check for Colab secrets first
try:
    from google.colab import userdata
    os.environ['ANTHROPIC_API_KEY'] = userdata.get('ANTHROPIC_API_KEY')
    try:
        os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
    except:
        pass
    print('✓ Using API keys from Colab secrets')
except:
    # Fall back to manual entry
    if 'ANTHROPIC_API_KEY' not in os.environ:
        os.environ['ANTHROPIC_API_KEY'] = getpass('Enter ANTHROPIC_API_KEY: ')
    print('✓ API keys configured')

print('\n✅ Setup complete! Ready to run examples.')

## Example 1: Simple Load Balancer

Implement round-robin load balancing for API requests.

In [None]:
import requests
from typing import List, Dict, Any
import time

class SimpleLoadBalancer:
    """
    Round-robin load balancer for API requests.

    Requests are assigned to ``backends`` in list order, wrapping around after
    the last entry. ``send_request`` only simulates the backend call.
    """

    def __init__(self, backends: List[str]):
        # Copy the list so later mutation of the caller's list cannot put a
        # backend into rotation that has no entry in ``request_count``.
        self.backends = list(backends)
        self.current_index = 0
        self.request_count = {backend: 0 for backend in self.backends}

    def get_next_backend(self) -> str:
        """
        Return the next backend in rotation and count the request against it.

        Raises:
            IndexError: if the balancer has no backends.
        """
        if not self.backends:
            raise IndexError("SimpleLoadBalancer has no backends configured")
        backend = self.backends[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.backends)
        self.request_count[backend] += 1
        return backend

    def send_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Route one (simulated) request to the next backend.

        Args:
            endpoint: Path a real implementation would append to the backend URL.
            data: Request payload, echoed back on success.

        Returns:
            Dict with ``backend``, ``status`` and ``latency_ms``, plus ``data``
            on success or ``error`` on failure.
        """
        backend = self.get_next_backend()
        # perf_counter is monotonic, so latency can't go negative or jump
        # if the wall clock is adjusted mid-request.
        start_time = time.perf_counter()

        try:
            # Simulate backend call to f"{backend}{endpoint}"
            time.sleep(0.1)  # Simulate network latency

            return {
                "backend": backend,
                "status": "success",
                "data": data,
                "latency_ms": (time.perf_counter() - start_time) * 1000
            }
        except Exception as e:
            return {
                "backend": backend,
                "status": "error",
                "error": str(e),
                "latency_ms": (time.perf_counter() - start_time) * 1000
            }

    def get_stats(self) -> Dict[str, Any]:
        """Return total requests, backend count and per-backend request share (%)."""
        total_requests = sum(self.request_count.values())
        return {
            "total_requests": total_requests,
            "backends": len(self.backends),
            "distribution": {
                backend: {
                    "requests": count,
                    # Guard against division by zero before any traffic.
                    "percentage": (count / total_requests * 100) if total_requests > 0 else 0
                }
                for backend, count in self.request_count.items()
            }
        }

# Drive the round-robin balancer against three simulated backends.
backends = [
    "http://backend-1:8000",
    "http://backend-2:8000",
    "http://backend-3:8000"
]

lb = SimpleLoadBalancer(backends)

print("Testing round-robin load balancing...\n")

# Nine requests, numbered from 1 for display; the payload uses a 0-based id.
for request_no in range(1, 10):
    result = lb.send_request("/v1/chat/completions", {"query": f"request_{request_no - 1}"})
    print(f"Request {request_no}: routed to {result['backend']} ({result['latency_ms']:.1f}ms)")

# Summarise how the traffic was spread across backends.
print("\nLoad Balancer Statistics:")
stats = lb.get_stats()
header_lines = [
    f"Total requests: {stats['total_requests']}",
    f"Backends: {stats['backends']}",
    "\nDistribution:",
]
print("\n".join(header_lines))
for url, share in stats['distribution'].items():
    print(f"  {url}: {share['requests']} requests ({share['percentage']:.1f}%)")

## Example 2: Circuit Breaker Pattern

Implement a circuit breaker that fails fast once a backend keeps erroring, preventing cascading failures.

In [None]:
from pybreaker import CircuitBreaker
import time
import random

# Configure circuit breaker.
# pybreaker's keyword for the open-state duration is `reset_timeout`;
# `timeout_duration` is not an accepted argument and raises TypeError.
breaker = CircuitBreaker(
    fail_max=5,        # Open after 5 failures
    reset_timeout=30,  # Stay open for 30 seconds before allowing a trial call
    name='api_breaker'
)

@breaker
def call_api_with_breaker(backend_id: int, request_id: int) -> Dict[str, Any]:
    """
    Simulated backend call guarded by the module-level ``breaker``.

    Requests with ``request_id`` below 5 raise a simulated "Service
    unavailable" error; every later request returns a success payload.
    """
    if request_id >= 5:
        # Healthy path: echo which request was handled.
        return {
            "backend_id": backend_id,
            "status": "success",
            "response": f"Request {request_id} processed successfully",
        }
    # Early requests simulate an outage so the breaker records failures.
    raise Exception(f"Backend {backend_id} error: Service unavailable")

from pybreaker import CircuitBreakerError

print("Testing circuit breaker pattern...\n")

# pybreaker does not expose when the breaker opened, so remember the time of
# the first fast-fail ourselves to report when a trial call will be allowed.
opened_at = None

# Test circuit breaker
for i in range(10):
    try:
        result = call_api_with_breaker(1, i)
        print(f"Request {i}: SUCCESS - {result['response']}")
    except CircuitBreakerError:
        # pybreaker raises this while the breaker is open (and, by default,
        # on the call that trips it).
        if opened_at is None:
            opened_at = time.time()
        print(f"Request {i}: CIRCUIT BREAKER OPEN (fast fail)")
    except Exception as e:
        print(f"Request {i}: FAILED - {str(e)}")

    time.sleep(0.2)

print("\nCircuit Breaker State:")
print(f"  State: {breaker.current_state}")
print(f"  Failure count: {breaker.fail_counter}")
print(f"  Reset timeout: {breaker.reset_timeout}s")
if opened_at is not None:
    half_open_at = time.strftime('%H:%M:%S', time.localtime(opened_at + breaker.reset_timeout))
else:
    half_open_at = 'N/A'
print(f"  Expected half-open time: {half_open_at}")

## Example 3: Retry Logic with Exponential Backoff

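Implement an intelligent retry strategy for transient failures: exponential backoff with jitter for server errors and timeouts, and a fixed server-specified wait for rate limits (429).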

In [None]:
import time
import random
from typing import Dict, Any, Optional

def call_api_with_retry(
    url: str,
    data: Dict[str, Any],
    max_retries: int = 5,
    base_delay: float = 1.0
) -> Dict[str, Any]:
    """
    Call API with exponential backoff retry logic.
    
    Args:
        url: API endpoint
        data: Request data
        max_retries: Maximum retry attempts
        base_delay: Initial delay in seconds
    """
    for attempt in range(max_retries):
        try:
            # Simulate API call
            # Fail first 3 attempts, succeed on 4th
            if attempt < 3:
                # Simulate different error types
                error_type = random.choice([503, 429, "timeout"])
                
                if error_type == 429:
                    # Rate limited - use Retry-After header
                    retry_after = 2.0
                    wait_time = retry_after
                    print(f"Attempt {attempt + 1}: Rate limited (429)")
                    print(f"  Waiting {wait_time:.2f}s before retry")
                elif error_type == 503:
                    # Server error - exponential backoff
                    wait_time = base_delay * (2 ** attempt)
                    jitter = random.uniform(0, wait_time * 0.1)
                    total_wait = wait_time + jitter
                    print(f"Attempt {attempt + 1}: Server error (503)")
                    print(f"  Waiting {total_wait:.2f}s before retry")
                    wait_time = total_wait
                else:
                    # Timeout - exponential backoff
                    wait_time = base_delay * (2 ** attempt)
                    jitter = random.uniform(0, wait_time * 0.1)
                    total_wait = wait_time + jitter
                    print(f"Attempt {attempt + 1}: Timeout")
                    print(f"  Waiting {total_wait:.2f}s before retry")
                    wait_time = total_wait
                
                time.sleep(wait_time)
                continue
            
            # Success on 4th attempt
            print(f"Attempt {attempt + 1}: Success")
            return {
                "status": "success",
                "data": data,
                "attempts": attempt + 1
            }
        
        except Exception as e:
            print(f"Attempt {attempt + 1}: Exception - {str(e)}")
            if attempt < max_retries - 1:
                wait_time = base_delay * (2 ** attempt)
                print(f"  Waiting {wait_time:.2f}s before retry")
                time.sleep(wait_time)
                continue
    
    # All retries exhausted
    return {
        "status": "failed",
        "error": f"Failed after {max_retries} attempts",
        "attempts": max_retries
    }

print("Testing exponential backoff retry logic...\n")

# Time the whole retry sequence, including every backoff sleep.
started = time.time()
result = call_api_with_retry(
    url="https://api.example.com/v1/chat",
    data={"query": "test"},
    max_retries=5,
    base_delay=1.0,
)
total_time = time.time() - started

print("\nFinal Result:")
for label, value in (("Status", result['status']), ("Attempts", result['attempts'])):
    print(f"  {label}: {value}")
print(f"  Total time: {total_time:.2f}s")

# The echoed payload is only present on success.
if result['status'] == 'success':
    print(f"  Data: {result['data']}")

## Example 4: Request Routing and Health Checks

Implement backend health monitoring with automatic failover.

In [None]:
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
import time
import random

@dataclass
class Backend:
    """Backend server configuration plus its mutable health and latency state."""
    url: str                               # base URL; also the key in load-balancer statistics
    weight: int = 1                        # relative weight (not used by HealthAwareLoadBalancer's round-robin)
    healthy: bool = True                   # set False once fail_count reaches max_fails
    fail_count: int = 0                    # consecutive failures since the last success
    max_fails: int = 3                     # consecutive failures that mark the backend unhealthy
    last_check: Optional[datetime] = None  # when the most recent health probe ran
    response_time_ms: float = 0.0          # latency of the most recent successful request


class HealthAwareLoadBalancer:
    """
    Load balancer with health checks and automatic failover.

    A backend is marked unhealthy after ``max_fails`` consecutive failures
    (from requests or health probes); any success resets its counter.
    Unhealthy backends are skipped by ``get_next_backend`` until a passing
    ``check_backend_health`` probe marks them healthy again.
    """

    def __init__(self, backends: List[Backend], health_check_interval: int = 10):
        self.backends = backends
        # NOTE(review): stored but not consulted by any method in this class;
        # periodic probing has to be scheduled by the caller.
        self.health_check_interval = health_check_interval
        self.current_index = 0
        # Counts every selection of a backend, including failed attempts.
        self.request_count = {backend.url: 0 for backend in backends}
        # Running totals used to report a true per-backend average latency.
        self._latency_sum_ms = {backend.url: 0.0 for backend in backends}
        self._success_count = {backend.url: 0 for backend in backends}

    def _record_failure(self, backend: Backend) -> bool:
        """Count one failure; return True if it just pushed ``backend`` to unhealthy."""
        backend.fail_count += 1
        if backend.healthy and backend.fail_count >= backend.max_fails:
            backend.healthy = False
            return True
        return False

    def _record_success(self, backend: Backend, response_time_ms: float) -> None:
        """Reset ``backend``'s consecutive-failure count and record a latency sample."""
        backend.fail_count = 0
        backend.response_time_ms = response_time_ms
        self._latency_sum_ms[backend.url] += response_time_ms
        self._success_count[backend.url] += 1

    def check_backend_health(self, backend: Backend) -> bool:
        """
        Probe ``backend`` (simulated: fails 20% of the time) and update its state.

        Returns:
            True if the probe passed (the backend is marked healthy), else False.
        """
        # Record every probe, not only the passing ones.
        backend.last_check = datetime.now()
        try:
            # Simulate health check
            probe_ok = random.random() >= 0.2
        except Exception:
            probe_ok = False

        if not probe_ok:
            self._record_failure(backend)
            return False
        backend.fail_count = 0
        backend.healthy = True
        return True

    def get_healthy_backends(self) -> List[Backend]:
        """Get list of healthy backends"""
        return [b for b in self.backends if b.healthy]

    def get_next_backend(self) -> Optional[Backend]:
        """
        Return the next healthy backend in plain round-robin order.

        ``Backend.weight`` is not consulted. The rotation index is shared
        across changes to the healthy set, so the order shifts when backends
        drop out or recover.

        Returns:
            The chosen backend, or None if no backend is healthy.
        """
        healthy_backends = self.get_healthy_backends()

        if not healthy_backends:
            return None

        backend = healthy_backends[self.current_index % len(healthy_backends)]
        self.current_index += 1
        self.request_count[backend.url] += 1

        return backend

    def send_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Send a (simulated) request, failing over to the next healthy backend.

        Makes at most ``len(self.backends)`` attempts. Each simulated call
        sleeps 50 ms and fails 10% of the time.

        Returns:
            On success: ``status``, ``backend``, ``data``, ``response_time_ms``
            and ``attempt``. Otherwise ``status == "error"`` with an ``error`` message.
        """
        max_attempts = len(self.backends)

        for attempt in range(max_attempts):
            backend = self.get_next_backend()

            if backend is None:
                return {
                    "status": "error",
                    "error": "No healthy backends available"
                }

            start_time = time.perf_counter()

            try:
                # Simulate backend call
                time.sleep(0.05)  # Simulate network latency

                # Simulate occasional failures
                if random.random() < 0.1:  # 10% failure rate
                    raise Exception("Backend error")
            except Exception:
                if self._record_failure(backend):
                    print(f"  Backend {backend.url} marked unhealthy")
                if attempt < max_attempts - 1:
                    print(f"  Attempt {attempt + 1} failed, trying next backend...")
                continue

            response_time = (time.perf_counter() - start_time) * 1000
            # A success breaks the failure streak, so the count stays "consecutive".
            self._record_success(backend, response_time)

            return {
                "status": "success",
                "backend": backend.url,
                "data": data,
                "response_time_ms": response_time,
                "attempt": attempt + 1
            }

        return {
            "status": "error",
            "error": "All backends failed"
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return health counts plus per-backend requests, failures and mean latency."""
        healthy = self.get_healthy_backends()

        return {
            "total_backends": len(self.backends),
            "healthy_backends": len(healthy),
            "unhealthy_backends": len(self.backends) - len(healthy),
            "backends": [
                {
                    "url": b.url,
                    "healthy": b.healthy,
                    "fail_count": b.fail_count,
                    "requests": self.request_count[b.url],
                    # Mean over successful requests; 0.0 before the first success.
                    "avg_response_ms": (
                        self._latency_sum_ms[b.url] / self._success_count[b.url]
                        if self._success_count[b.url] else 0.0
                    ),
                }
                for b in self.backends
            ]
        }

# Exercise the health-aware balancer with three backends of descending weight.
num_requests = 15
backends = [
    Backend(url="http://backend-1:8000", weight=3),
    Backend(url="http://backend-2:8000", weight=2),
    Backend(url="http://backend-3:8000", weight=1)
]

lb = HealthAwareLoadBalancer(backends)

print("Testing health-aware load balancing...\n")

success_count = 0
for n in range(num_requests):
    result = lb.send_request("/v1/chat", {"query": f"request_{n}"})

    if result['status'] != 'success':
        print(f"Request {n+1}: FAILED - {result['error']}")
        continue

    print(f"Request {n+1}: SUCCESS via {result['backend']} ({result['response_time_ms']:.1f}ms)")
    success_count += 1

# Summarise balancer health and per-backend details.
print("\nLoad Balancer Statistics:")
stats = lb.get_stats()
print(f"Total backends: {stats['total_backends']}")
print(f"Healthy backends: {stats['healthy_backends']}")
print(f"Unhealthy backends: {stats['unhealthy_backends']}")
print(f"Success rate: {success_count}/{num_requests} ({success_count/num_requests*100:.1f}%)")
print("\nBackend Details:")
for entry in stats['backends']:
    label = "✓ HEALTHY" if entry['healthy'] else "✗ UNHEALTHY"
    print(f"  {entry['url']}: {label}")
    print(f"    Requests: {entry['requests']}")
    print(f"    Fail count: {entry['fail_count']}")
    print(f"    Avg response: {entry['avg_response_ms']:.1f}ms")

## Next Steps

- Full code: [Chapter 39 on GitHub](https://github.com/eduardd76/AI_for_networking_and_security_engineers/tree/main/CODE/Volume-3-Production-Systems/Chapter-39-API-Gateway-Load-Balancing)
- Learn more: [vExpertAI.com](https://vexpertai.com)
- Author: Eduard Dulharu ([@eduardd76](https://github.com/eduardd76))

**Production Deployment:**
- Use Nginx for production load balancing
- Implement SSL/TLS termination
- Configure health check probes
- Set up monitoring with Prometheus
- Deploy with Docker Swarm or Kubernetes