# Module 10.3: Cost Optimization

**Goal**: Optimize deployment costs through batching, quantization, and caching

**Time**: 75 minutes

**Concepts Covered**:
- Cost breakdown analysis
- Batch inference strategies
- Dynamic batching implementation
- Model quantization for cost reduction
- Response caching
- Auto-scaling strategies

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
# Cost Breakdown Analysis
def calculate_inference_cost(
    model_size_gb,
    requests_per_hour,
    avg_tokens_per_request,
    gpu_cost_per_hour=0.50,
    tokens_per_second=50,
    memory_cost_per_gb_hour=0.01,
):
    """Estimate hourly, monthly, and per-request inference cost.

    The GPU is billed only for the fraction of an hour it spends generating
    tokens (total tokens divided by throughput); on top of that a flat
    per-GB hourly charge is applied for holding the model.

    Args:
        model_size_gb: Size of the model in gigabytes.
        requests_per_hour: Number of requests served per hour. May be 0.
        avg_tokens_per_request: Average tokens generated per request.
        gpu_cost_per_hour: Price of one GPU-hour in dollars.
        tokens_per_second: Sustained generation throughput of the GPU.
        memory_cost_per_gb_hour: Dollars charged per GB of model per hour.

    Returns:
        A dict with the keys ``gpu_cost_per_hour``, ``memory_cost_per_hour``,
        ``total_cost_per_hour``, ``total_cost_per_month`` (30-day month) and
        ``cost_per_request``. With zero requests ``cost_per_request`` is
        reported as ``0.0`` because there is no request to attribute cost to.

    Raises:
        ZeroDivisionError: If ``tokens_per_second`` is 0.
    """
    # GPU time needed to generate one hour's worth of tokens, in hours.
    total_tokens = requests_per_hour * avg_tokens_per_request
    compute_hours = total_tokens / (tokens_per_second * 3600)

    gpu_cost = compute_hours * gpu_cost_per_hour

    # Flat charge that accrues regardless of traffic.
    memory_cost_per_hour = model_size_gb * memory_cost_per_gb_hour

    total_cost_per_hour = gpu_cost + memory_cost_per_hour
    total_cost_per_month = total_cost_per_hour * 24 * 30

    # An idle hour still costs money, but it cannot be split across requests;
    # report 0.0 instead of dividing by zero.
    if requests_per_hour:
        cost_per_request = total_cost_per_hour / requests_per_hour
    else:
        cost_per_request = 0.0

    return {
        "gpu_cost_per_hour": gpu_cost,
        "memory_cost_per_hour": memory_cost_per_hour,
        "total_cost_per_hour": total_cost_per_hour,
        "total_cost_per_month": total_cost_per_month,
        "cost_per_request": cost_per_request,
    }

# Worked example: price a steady 1,000 requests/hour workload using the
# default GPU price and throughput.
costs = calculate_inference_cost(
    requests_per_hour=1000,
    avg_tokens_per_request=200,
    model_size_gb=4.5,  # SmolLM-1.7B in FP16
)

# Report every cost figure as dollars with four decimal places.
print("Cost Breakdown:")
for key, value in costs.items():
    if "cost" not in key:
        continue
    print(f"  {key}: ${value:.4f}")

In [None]:
# Dynamic Batching
import asyncio
from collections import deque
from typing import List, Callable

class DynamicBatcher:
    """Group individual requests into batches for one inference call each.

    A batch is dispatched as soon as ``batch_size`` requests are queued, or
    when the flush timer (``max_wait_ms``) expires, whichever happens first.
    Each request is paired with a synchronous callback that is invoked with
    that request's result once its batch has been processed.

    Must be used from inside a running asyncio event loop.
    """

    def __init__(self, batch_size: int, max_wait_ms: int = 50):
        """Create a batcher.

        Args:
            batch_size: Maximum number of requests per batch (at least 1).
            max_wait_ms: How long the flush timer waits before dispatching a
                partially filled batch, in milliseconds.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1; such a value
                would make every dispatch dequeue nothing.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()
        # True while a flush timer is pending or draining the queue.
        self.processing = False
        # Strong references to timer tasks: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._timer_tasks = set()

    async def add_request(self, request, callback: Callable):
        """Add request to batch queue.

        If the queue reaches ``batch_size`` the batch is processed before
        this coroutine returns; otherwise a flush timer is started unless
        one is already active.
        """
        self.queue.append((request, callback))

        if len(self.queue) >= self.batch_size:
            await self._process_batch()
        elif not self.processing:
            # Set the flag *before* scheduling: the task does not start until
            # the caller yields, so back-to-back calls would otherwise each
            # spawn their own timer.
            self.processing = True
            timer = asyncio.create_task(self._wait_and_process())
            self._timer_tasks.add(timer)
            timer.add_done_callback(self._timer_tasks.discard)

    async def _wait_and_process(self):
        """Wait for max_wait_ms, then process batches until the queue is empty."""
        self.processing = True
        try:
            while True:
                await asyncio.sleep(self.max_wait_ms / 1000)
                await self._process_batch()
                # Requests that arrived while the batch was in flight did not
                # start a timer of their own (the flag was set), so keep
                # going until nothing is left behind.
                if not self.queue:
                    break
        finally:
            # Always release the flag, even if a callback or the inference
            # call raised; otherwise no timer could ever be started again.
            self.processing = False

    async def _process_batch(self):
        """Process current batch: dequeue up to batch_size requests and dispatch them."""
        if not self.queue:
            return

        batch = []
        callbacks = []

        while self.queue and len(batch) < self.batch_size:
            request, callback = self.queue.popleft()
            batch.append(request)
            callbacks.append(callback)

        # Process batch (simulated)
        results = await self._inference_batch(batch)

        # Results are matched to callbacks by position.
        for callback, result in zip(callbacks, results):
            callback(result)

    async def _inference_batch(self, batch):
        """Run inference on batch.

        Simulated: sleeps for 100 ms and returns one placeholder string per
        request, in order.
        """
        await asyncio.sleep(0.1)
        return [f"result_{i}" for i in range(len(batch))]

# Summarise the batching strategy, one line per point.
print(
    "Dynamic batching:",
    "- Collects requests over time window",
    "- Processes when batch is full or timeout",
    "- Improves throughput and reduces costs",
    sep="\n",
)

In [None]:
# Response Caching
from functools import lru_cache
import hashlib
import json

class ResponseCache:
    """Bounded in-memory response cache with FIFO eviction and hit statistics.

    Entries are keyed by the prompt together with any generation parameters
    passed as keyword arguments, so the same prompt with different settings
    is cached separately.
    """

    def __init__(self, max_size=1000):
        """Create an empty cache holding at most ``max_size`` entries.

        A ``max_size`` of 0 or less disables storage entirely.
        """
        self.cache = {}
        self.max_size = max_size
        self.hits = 0
        self.misses = 0

    def _cache_key(self, prompt, **kwargs):
        """Generate cache key from prompt and parameters.

        The parameters must be JSON-serializable; ``json.dumps`` raises
        ``TypeError`` otherwise.
        """
        key_data = {"prompt": prompt, **kwargs}
        # sort_keys makes the key independent of keyword-argument order.
        key_str = json.dumps(key_data, sort_keys=True)
        # MD5 serves only as a compact fingerprint here, not for security.
        return hashlib.md5(key_str.encode()).hexdigest()

    def get(self, prompt, **kwargs):
        """Get cached response, or ``None`` on a miss.

        Updates the ``hits`` / ``misses`` counters.
        """
        key = self._cache_key(prompt, **kwargs)

        try:
            response = self.cache[key]
        except KeyError:
            self.misses += 1
            return None

        self.hits += 1
        return response

    def set(self, prompt, response, **kwargs):
        """Cache response, evicting the oldest entry when the cache is full."""
        key = self._cache_key(prompt, **kwargs)

        if self.max_size <= 0:
            # Nothing may be stored, and there is nothing to evict either.
            return

        # Overwriting an existing key does not grow the cache, so only evict
        # when a genuinely new entry would exceed the limit.
        if key not in self.cache and len(self.cache) >= self.max_size:
            # Dicts iterate in insertion order: the first key is the oldest.
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        self.cache[key] = response

    def hit_rate(self):
        """Calculate cache hit rate as hits / lookups (0.0 before any lookup)."""
        total = self.hits + self.misses
        if total == 0:
            return 0.0
        return self.hits / total

# Demo: store one response, then look it up with identical parameters.
cache = ResponseCache()

# Populate the cache; the temperature is part of the lookup key.
cache.set(
    prompt="What is AI?",
    response="AI is artificial intelligence...",
    temperature=0.7,
)

# The same prompt and temperature should come back as a hit.
cached = cache.get(prompt="What is AI?", temperature=0.7)
print(f"Cache hit: {cached is not None}")
print(f"Hit rate: {cache.hit_rate():.2%}")

## Key Takeaways

✅ **Module Complete**

## Next Steps

Continue to the next module in the course.