# Chapter 40: Caching Strategies for AI Systems

Run this notebook directly in Google Colab - no local Python needed!

**Full code**: [GitHub](https://github.com/eduardd76/AI_for_networking_and_security_engineers/tree/main/CODE/Volume-3-Production-Systems/Chapter-40-Caching-Strategies)

## Setup

Install dependencies and configure API keys.

In [None]:
# Install dependencies
!pip install -q redis anthropic numpy

# Import and configure API key
import os
from getpass import getpass

# Check for Colab secrets first
try:
    from google.colab import userdata
    os.environ['ANTHROPIC_API_KEY'] = userdata.get('ANTHROPIC_API_KEY')
    print('✓ Using API keys from Colab secrets')
except:
    # Fall back to manual entry
    if 'ANTHROPIC_API_KEY' not in os.environ:
        os.environ['ANTHROPIC_API_KEY'] = getpass('Enter ANTHROPIC_API_KEY: ')
    print('✓ API keys configured')

print('\n✅ Setup complete! Ready to run examples.')
print('\n⚠️  Note: Redis examples require a running Redis instance.')
print('   For testing without Redis, examples will use in-memory storage.')

## Example 1: Basic Semantic Cache

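Implement semantic caching for LLM responses using cosine-similarity matching. This demo uses hash-based embeddings, so only prompts that are identical after lowercasing and trimming whitespace will match. Production systems use a real embedding model so that paraphrased questions match too.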

In [None]:
import hashlib
import json
import time
from typing import Optional, Dict, Any, Tuple
import numpy as np

class SimpleSemanticCache:
    """
    In-memory semantic cache for LLM responses.

    Uses hash-based pseudo-embeddings for demonstration: the prompt is
    lowercased and stripped, hashed with SHA-256, and the 32 digest bytes
    become the vector. As a result only prompts that are identical after that
    normalization are similar; paraphrases are NOT matched. A real deployment
    would replace ``_get_embedding`` with an embedding model.

    Args:
        similarity_threshold: Minimum cosine similarity that counts as a hit.
        estimated_llm_latency: Seconds an LLM call is assumed to take; used to
            estimate the latency saved by a hit (default 3.0, the midpoint of
            a typical 2-4 s call).
        simulated_llm_latency: Seconds to sleep on a miss to mimic an LLM
            call (default 0.2). Set to 0 for fast tests.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.95,
        estimated_llm_latency: float = 3.0,
        simulated_llm_latency: float = 0.2,
    ):
        # {cache_key: {"prompt", "embedding", "response", "timestamp"}}
        self.cache = {}
        self.similarity_threshold = similarity_threshold
        self.estimated_llm_latency = estimated_llm_latency
        self.simulated_llm_latency = simulated_llm_latency
        self.hits = 0
        self.misses = 0
        self.total_latency_saved = 0.0

    def _get_embedding(self, text: str) -> np.ndarray:
        """Return a unit-length pseudo-embedding of *text* (demo only)."""
        normalized = text.lower().strip()
        digest = hashlib.sha256(normalized.encode()).digest()

        # Center the bytes around zero. Raw uint8 components are all
        # non-negative, which makes two unrelated digests ~0.75 cosine-similar
        # on average and leaves a real chance of clearing a 0.95 threshold
        # (returning another prompt's response). Zero-mean components make
        # unrelated prompts close to orthogonal instead.
        embedding = np.frombuffer(digest, dtype=np.uint8).astype(float) - 127.5
        return embedding / np.linalg.norm(embedding)

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is all zeros)."""
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / (norm1 * norm2))

    def _search_cache(self, embedding: np.ndarray) -> Optional[Dict[str, Any]]:
        """
        Find the most similar cached entry at or above the threshold.

        Returns:
            A copy of the best entry with an added ``"similarity"`` key, or
            ``None`` if no entry reaches ``similarity_threshold``.
        """
        best_match = None
        # Start below any possible cosine value so that thresholds <= 0 work
        # too (centered embeddings can have negative similarity).
        best_similarity = float("-inf")

        for cached_data in self.cache.values():
            # Embeddings are stored as plain lists; convert back for numpy.
            similarity = self._cosine_similarity(
                embedding, np.asarray(cached_data["embedding"])
            )
            if similarity >= self.similarity_threshold and similarity > best_similarity:
                best_similarity = similarity
                best_match = cached_data.copy()
                best_match["similarity"] = similarity

        return best_match

    def get(self, prompt: str) -> Tuple[str, Dict[str, Any]]:
        """
        Return the cached response for *prompt*, or simulate an LLM call and cache it.

        Returns:
            ``(response, metadata)``. ``metadata`` always has ``cache_hit``,
            ``latency`` (seconds), ``similarity`` and ``latency_saved``; hits
            also carry ``cached_at`` (the ``time.time()`` the entry was stored).
        """
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()

        query_embedding = self._get_embedding(prompt)
        cached = self._search_cache(query_embedding)

        if cached is not None:
            latency = time.perf_counter() - start
            self.hits += 1

            # Savings = assumed LLM latency minus what the lookup cost;
            # never report a negative saving.
            latency_saved = max(0.0, self.estimated_llm_latency - latency)
            self.total_latency_saved += latency_saved

            metadata = {
                "cache_hit": True,
                "latency": latency,
                "similarity": cached["similarity"],
                "cached_at": cached["timestamp"],
                "latency_saved": latency_saved
            }
            return cached["response"], metadata

        # Cache miss - simulate the LLM call
        self.misses += 1
        time.sleep(self.simulated_llm_latency)

        response = f"Simulated response for: {prompt[:50]}..."

        # Key on the raw prompt; lookups go through the embedding search.
        cache_key = hashlib.md5(prompt.encode()).hexdigest()
        self.cache[cache_key] = {
            "prompt": prompt,
            "embedding": query_embedding.tolist(),
            "response": response,
            "timestamp": time.time()  # wall-clock time for display
        }

        latency = time.perf_counter() - start
        metadata = {
            "cache_hit": False,
            "latency": latency,
            "similarity": 0.0,
            "latency_saved": 0.0
        }
        return response, metadata

    def get_stats(self) -> Dict[str, Any]:
        """Return hit/miss counts, hit rate (%), latency savings and cache size."""
        total_requests = self.hits + self.misses
        hit_rate = (self.hits / total_requests * 100) if total_requests > 0 else 0

        return {
            "hits": self.hits,
            "misses": self.misses,
            "total_requests": total_requests,
            "hit_rate_percent": round(hit_rate, 2),
            "total_latency_saved_seconds": round(self.total_latency_saved, 2),
            "avg_latency_saved_per_hit": round(
                self.total_latency_saved / self.hits if self.hits > 0 else 0, 3
            ),
            "cache_size": len(self.cache)
        }

# Demo: a miss, a repeat that differs only in case/whitespace, and a new topic.
cache = SimpleSemanticCache(similarity_threshold=0.95)

print("Testing Semantic Cache\n")
print("=" * 60)

# First query - the cache is empty, so this is a miss
print("\n1. First query (cache miss):")
prompt1 = "What are BGP best practices?"
response1, meta1 = cache.get(prompt1)
print(f"   Prompt: {prompt1}")
print(f"   Cache Hit: {meta1['cache_hit']}")
print(f"   Latency: {meta1['latency']:.3f}s")

# Same question with different case and surrounding whitespace. The demo's
# hash-based embeddings lowercase and strip the prompt, so this hits; a
# reworded question would miss unless a real embedding model is used.
print("\n2. Same query, different case/whitespace (should hit cache):")
prompt2 = "  what are BGP best practices?  "
response2, meta2 = cache.get(prompt2)
print(f"   Prompt: {prompt2!r}")
print(f"   Cache Hit: {meta2['cache_hit']}")
print(f"   Latency: {meta2['latency']:.3f}s")
print(f"   Similarity: {meta2.get('similarity', 0):.3f}")
print(f"   Latency Saved: {meta2.get('latency_saved', 0):.3f}s")

# Different query - cache miss
print("\n3. Different query (cache miss):")
prompt3 = "How to troubleshoot OSPF adjacencies?"
response3, meta3 = cache.get(prompt3)
print(f"   Prompt: {prompt3}")
print(f"   Cache Hit: {meta3['cache_hit']}")
print(f"   Latency: {meta3['latency']:.3f}s")

# Statistics
print("\n" + "=" * 60)
print("Cache Statistics:")
stats = cache.get_stats()
for key, value in stats.items():
    print(f"  {key}: {value}")

## Example 2: Multi-Tier Cache Architecture

Implement hot/warm/cold tiers with automatic promotion. An entry moves to a higher tier after repeated access: COLD→WARM at 5 accesses, WARM→HOT at 10.

In [None]:
from enum import Enum
from dataclasses import dataclass
from typing import Dict, Any, Optional
import time

class CacheTier(Enum):
    """Cache tiers; the TTL for each is defined in MultiTierCache.TTL_MAP."""
    HOT = "hot"      # 5 minutes - active troubleshooting
    WARM = "warm"    # 1 hour - common queries
    COLD = "cold"    # 24 hours - reference data

@dataclass
class CacheEntry:
    """A cached value plus the bookkeeping MultiTierCache uses."""
    key: str
    value: Any
    tier: CacheTier
    access_count: int = 0    # successful get() calls since the entry was set
    created_at: float = 0.0  # time.time() when the entry was stored
    expires_at: float = 0.0  # time.time() after which get() drops the entry

class MultiTierCache:
    """
    Multi-tier in-memory cache with automatic, access-count-based promotion.

    Entries are stored in a tier (WARM by default) whose TTL comes from
    ``TTL_MAP``. Every successful ``get()`` increments the entry's access
    count; when the count reaches the threshold in ``PROMOTION_RULES`` the
    entry moves to the next tier and its expiry is reset to now + that
    tier's TTL. Expired entries are removed lazily, on ``get()``.

    NOTE(review): with the default TTL_MAP, HOT has the *shortest* TTL, so
    promoting an entry shortens its remaining lifetime. Presumably intended
    (heavily queried data during troubleshooting is volatile) — confirm.
    """

    TTL_MAP = {
        CacheTier.HOT: 300,      # 5 minutes
        CacheTier.WARM: 3600,    # 1 hour
        CacheTier.COLD: 86400    # 24 hours
    }

    # current tier -> (next tier, access count that triggers the promotion)
    PROMOTION_RULES = {
        CacheTier.COLD: (CacheTier.WARM, 5),
        CacheTier.WARM: (CacheTier.HOT, 10),
    }

    def __init__(self):
        self.cache: Dict[str, CacheEntry] = {}

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """
        Return the entry for *key* with metadata, promoting it if popular.

        Returns:
            ``None`` if the key is absent or expired (an expired entry is
            deleted here). Otherwise a dict with ``value``, ``tier`` (the
            tier's string value), ``access_count`` and ``ttl_remaining``
            (seconds).
        """
        entry = self.cache.get(key)
        if entry is None:
            return None

        if time.time() > entry.expires_at:
            # Lazy expiry: entries are only removed when looked up.
            del self.cache[key]
            return None

        entry.access_count += 1

        # At most one promotion per access (COLD -> WARM -> HOT).
        rule = self.PROMOTION_RULES.get(entry.tier)
        if rule is not None:
            next_tier, threshold = rule
            if entry.access_count >= threshold:
                self._promote(entry, next_tier)
                print(f"  → Promoted {key} to {next_tier.name} tier")

        return {
            "value": entry.value,
            "tier": entry.tier.value,
            "access_count": entry.access_count,
            "ttl_remaining": entry.expires_at - time.time()
        }

    def set(self, key: str, value: Any, tier: CacheTier = CacheTier.WARM):
        """Store *value* under *key* in *tier*, replacing any existing entry (access count resets to 0)."""
        ttl = self.TTL_MAP[tier]
        now = time.time()

        self.cache[key] = CacheEntry(
            key=key,
            value=value,
            tier=tier,
            access_count=0,
            created_at=now,
            expires_at=now + ttl
        )

    def _promote(self, entry: CacheEntry, new_tier: CacheTier):
        """Move *entry* to *new_tier* and restart its TTL from now using that tier's TTL."""
        entry.tier = new_tier
        entry.expires_at = time.time() + self.TTL_MAP[new_tier]

    def get_tier_stats(self) -> Dict[str, Any]:
        """
        Return ``{tier_value: {"count", "total_accesses"}}`` for live entries.

        Entries that have expired but have not been evicted yet (eviction
        only happens in ``get()``) are skipped, since ``get()`` would no
        longer return them.
        """
        stats = {tier.value: {"count": 0, "total_accesses": 0} for tier in CacheTier}
        now = time.time()

        for entry in self.cache.values():
            if now > entry.expires_at:
                continue
            tier_stats = stats[entry.tier.value]
            tier_stats["count"] += 1
            tier_stats["total_accesses"] += entry.access_count

        return stats

# Demo: put one entry in the WARM tier, then read it repeatedly so the
# access counter crosses the promotion threshold.
cache = MultiTierCache()
demo_key = "bgp_best_practices"

print("Testing Multi-Tier Cache\n")
print("=" * 60)

print("\n1. Storing entry in WARM tier:")
cache.set(demo_key, "BGP best practices content...", CacheTier.WARM)
print(f"   Stored: {demo_key}")

print("\n2. Accessing entry multiple times:")
for attempt in range(1, 13):
    result = cache.get(demo_key)
    if not result:
        continue
    print(f"   Access {attempt}: tier={result['tier']}, count={result['access_count']}")

# Summarize how many entries ended up in each tier
print("\n" + "=" * 60)
print("Tier Statistics:")
stats = cache.get_tier_stats()
for tier, tier_stats in stats.items():
    entry_count = tier_stats['count']
    if entry_count <= 0:
        continue
    avg_accesses = tier_stats['total_accesses'] / entry_count
    print(f"  {tier.upper()}: {entry_count} entries, avg {avg_accesses:.1f} accesses")

## Example 3: Cache Key Design with Normalization

Design cache keys to maximize hit rate through parameter normalization.

In [None]:
import hashlib
import json
from typing import Dict, Any
from datetime import datetime

class CacheKeyGenerator:
    """
    Generate normalized cache keys for network queries.

    Keys have the form
    ``{version}:{query_type}:{device_id}:{param_hash}:{time_bucket}``.
    Parameters are normalized before hashing, so requests that differ only in
    key/value case, surrounding whitespace, dict order or list order map to
    the same key.

    Args:
        version: Key prefix; bump it to invalidate every existing key at once.
        param_hash_length: Hex digits of the MD5 parameter hash kept in the
            key (default 8, i.e. 32 bits). Short hashes make it more likely
            that two different parameter sets collide within the same
            query_type/device/time bucket; increase this if that matters.
    """

    # time_sensitivity -> strftime pattern for the bucket; None means the
    # bucket never rolls over ("static").
    _TIME_BUCKET_FORMATS = {
        "real-time": "%Y-%m-%d-%H-%M",    # minute-level
        "near-real-time": "%Y-%m-%d-%H",  # hour-level
        "static": None,
    }

    def __init__(self, version: str = "v1", param_hash_length: int = 8):
        self.version = version
        self.param_hash_length = param_hash_length

    def generate_key(
        self,
        query_type: str,
        device_id: str,
        parameters: Dict[str, Any],
        time_sensitivity: str = "static"
    ) -> str:
        """
        Generate a cache key with normalized parameters.

        Args:
            query_type: Type of query (e.g., 'interface_status')
            device_id: Device identifier
            parameters: Query parameters (values may be nested dicts/lists)
            time_sensitivity: 'real-time', 'near-real-time', or 'static'
                (case-insensitive)

        Returns:
            The colon-separated cache key.

        Raises:
            ValueError: if time_sensitivity is not one of the values above.
                Unknown values are rejected rather than treated as 'static',
                because a typo would otherwise cache live data indefinitely.
        """
        time_bucket = self._get_time_bucket(time_sensitivity)
        normalized_params = self._normalize_parameters(parameters)
        param_hash = self._hash_parameters(normalized_params)
        return f"{self.version}:{query_type}:{device_id}:{param_hash}:{time_bucket}"

    def _normalize_parameters(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a normalized copy of *params*.

        Keys are stringified, lowercased and stripped, and values go through
        ``_normalize_value``. If two keys differ only by case/whitespace they
        collapse into one; the one sorting last (by ``str``) wins.
        """
        normalized = {}
        for key in sorted(params or {}, key=str):
            normalized[str(key).lower().strip()] = self._normalize_value(params[key])
        return normalized

    def _normalize_value(self, value: Any) -> Any:
        """Normalize one parameter value, recursing into dicts, lists and sets."""
        if isinstance(value, str):
            return value.lower().strip()
        if isinstance(value, float):
            # Coarse rounding deliberately merges near-identical floats.
            return round(value, 2)
        if isinstance(value, dict):
            return self._normalize_parameters(value)
        if isinstance(value, (list, set, frozenset)):
            # Element order is treated as insignificant; sets also become
            # JSON-serializable lists here.
            items = [self._normalize_value(item) for item in value]
            try:
                return sorted(items)
            except TypeError:
                # Mixed/unorderable element types: use a deterministic order.
                return sorted(items, key=lambda item: (type(item).__name__, repr(item)))
        # ints, bools, None, tuples (order kept) and anything else pass through.
        return value

    def _hash_parameters(self, params: Dict[str, Any]) -> str:
        """Return a short MD5 hex digest of *params* (key derivation, not security)."""
        param_str = json.dumps(params, sort_keys=True)
        return hashlib.md5(param_str.encode()).hexdigest()[:self.param_hash_length]

    def _get_time_bucket(self, sensitivity: str) -> str:
        """
        Return the time-bucket component for *sensitivity*.

        Buckets are computed in UTC so they never repeat across a DST
        fall-back and agree between hosts in different time zones.
        """
        from datetime import timezone

        mode = str(sensitivity).lower().strip()
        if mode not in self._TIME_BUCKET_FORMATS:
            raise ValueError(
                f"Unknown time_sensitivity {sensitivity!r}; "
                f"expected one of {sorted(self._TIME_BUCKET_FORMATS)}"
            )

        bucket_format = self._TIME_BUCKET_FORMATS[mode]
        if bucket_format is None:
            return "static"
        return datetime.now(timezone.utc).strftime(bucket_format)

# Demo: generate keys for three query styles, then check normalization.
key_gen = CacheKeyGenerator(version="v1")

print("Testing Cache Key Generation\n")
print("=" * 60)

# (heading, query_type, device_id, parameters, time_sensitivity)
demo_queries = [
    (
        "\n1. Real-time interface status:",
        "interface_status",
        "rtr-001",
        {
            "interfaces": ["GigabitEthernet0/0", "GigabitEthernet0/1"],
            "include_stats": True
        },
        "real-time",
    ),
    (
        "\n2. Near-real-time BGP summary:",
        "bgp_summary",
        "site-nyc",
        {
            "peer_type": "ebgp",
            "state": "established"
        },
        "near-real-time",
    ),
    (
        "\n3. Static reference query:",
        "best_practices",
        "none",
        {
            "topic": "OSPF Design",
            "protocol": "ospf"
        },
        "static",
    ),
]

demo_keys = []
for heading, q_type, dev_id, q_params, sensitivity in demo_queries:
    print(heading)
    generated = key_gen.generate_key(
        query_type=q_type,
        device_id=dev_id,
        parameters=q_params,
        time_sensitivity=sensitivity
    )
    print(f"   Key: {generated}")
    demo_keys.append(generated)
key1, key2, key3 = demo_keys

# The two parameter sets below differ only in key order and letter case,
# so they are expected to hash to the same key.
print("\n" + "=" * 60)
print("Parameter Normalization Test:")

params_a = {"Interface": "Gi0/0", "Status": "UP"}
params_b = {"status": "up", "interface": "gi0/0"}  # Different order/case

key_a, key_b = (
    key_gen.generate_key("query", "rtr-001", candidate, "static")
    for candidate in (params_a, params_b)
)

print(f"\nKey A (original): {key_a}")
print(f"Key B (normalized): {key_b}")
keys_equal = key_a == key_b
if keys_equal:
    print(f"\nKeys match: {keys_equal} ✓")
else:
    print("Keys don't match ✗")

## Example 4: Cache Performance Metrics

Measure cache effectiveness (hit rate, latency, and estimated cost savings) using simulated traffic.

In [None]:
from dataclasses import dataclass
from typing import Dict, Any, List
import time
import random

@dataclass
class CacheMetrics:
    """Snapshot of cache performance produced by CacheMonitor.get_metrics()."""
    total_requests: int
    cache_hits: int
    cache_misses: int
    hit_rate_percent: float              # 0-100, rounded to 2 decimals
    avg_latency_hit_ms: float
    avg_latency_miss_ms: float
    total_latency_saved_seconds: float
    estimated_cost_saved_dollars: float

class CacheMonitor:
    """
    Track cache hits and misses and estimate latency and cost savings.

    Args:
        cost_per_llm_call: Dollars charged for one LLM call (i.e. a miss).
        cost_per_cache_hit: Dollars it costs to serve one hit from the cache.
    """

    def __init__(
        self,
        cost_per_llm_call: float = 0.10,
        cost_per_cache_hit: float = 0.0001
    ):
        self.cost_per_llm_call = cost_per_llm_call
        self.cost_per_cache_hit = cost_per_cache_hit

        # Tracking
        self.hits = 0
        self.misses = 0
        self.hit_latencies: List[float] = []   # ms, one per recorded hit
        self.miss_latencies: List[float] = []  # ms, one per recorded miss
        self.latency_saved = 0.0               # seconds, summed over hits

    def record_hit(self, latency_ms: float, latency_saved_ms: float):
        """Record a cache hit that took *latency_ms* and saved *latency_saved_ms*."""
        self.hits += 1
        self.hit_latencies.append(latency_ms)
        self.latency_saved += latency_saved_ms / 1000  # ms -> s

    def record_miss(self, latency_ms: float):
        """Record a cache miss that took *latency_ms*."""
        self.misses += 1
        self.miss_latencies.append(latency_ms)

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Arithmetic mean of *values*, or 0 when empty."""
        return sum(values) / len(values) if values else 0

    def get_metrics(self) -> CacheMetrics:
        """Return a CacheMetrics snapshot of everything recorded so far."""
        total = self.hits + self.misses
        hit_rate = (self.hits / total * 100) if total > 0 else 0

        # Every hit avoids one LLM call but still pays the cache-hit cost.
        cost_saved = self.hits * (self.cost_per_llm_call - self.cost_per_cache_hit)

        return CacheMetrics(
            total_requests=total,
            cache_hits=self.hits,
            cache_misses=self.misses,
            hit_rate_percent=round(hit_rate, 2),
            avg_latency_hit_ms=round(self._mean(self.hit_latencies), 2),
            avg_latency_miss_ms=round(self._mean(self.miss_latencies), 2),
            total_latency_saved_seconds=round(self.latency_saved, 2),
            estimated_cost_saved_dollars=round(cost_saved, 2)
        )

    def generate_report(self) -> str:
        """Return a multi-line, human-readable performance report."""
        metrics = self.get_metrics()
        divider = "=" * 60

        report = [
            divider,
            "CACHE PERFORMANCE REPORT",
            divider,
            "",
            "Overall Metrics:",
            f"  Total Requests: {metrics.total_requests:,}",
            f"  Cache Hits: {metrics.cache_hits:,}",
            f"  Cache Misses: {metrics.cache_misses:,}",
            f"  Hit Rate: {metrics.hit_rate_percent}%",
            "",
            "Performance:",
            f"  Avg Hit Latency: {metrics.avg_latency_hit_ms}ms",
            f"  Avg Miss Latency: {metrics.avg_latency_miss_ms}ms",
            f"  Latency Saved: {metrics.total_latency_saved_seconds}s",
            "",
            "Cost Savings:",
            # Always show currency with two decimals ("$7.00", not "$7.0").
            f"  Estimated Savings: ${metrics.estimated_cost_saved_dollars:,.2f}",
            f"  Cost per LLM call: ${self.cost_per_llm_call}",
            f"  Cost per cache hit: ${self.cost_per_cache_hit}",
            "",
            divider,
        ]

        return "\n".join(report)

# Simulate cache traffic
monitor = CacheMonitor(cost_per_llm_call=0.10, cost_per_cache_hit=0.0001)

print("Simulating cache traffic...\n")

# Simulate 100 requests with a ~70% hit rate (latencies in milliseconds)
for i in range(100):
    if random.random() < 0.70:  # cache hit
        monitor.record_hit(
            latency_ms=random.uniform(20, 40),
            latency_saved_ms=random.uniform(2900, 3100)
        )
    else:  # cache miss
        monitor.record_miss(
            latency_ms=random.uniform(2800, 3200)
        )

# Generate report
print(monitor.generate_report())

# Project monthly cost from the observed hit rate. Use the monitor's own
# per-call prices so the projection cannot drift from the report above.
metrics = monitor.get_metrics()
monthly_requests = 100_000
projected_hit_rate = metrics.hit_rate_percent / 100

cost_without_cache = monthly_requests * monitor.cost_per_llm_call
cost_with_cache = (
    (monthly_requests * projected_hit_rate * monitor.cost_per_cache_hit) +
    (monthly_requests * (1 - projected_hit_rate) * monitor.cost_per_llm_call)
)
monthly_savings = cost_without_cache - cost_with_cache

print(f"\nProjected Monthly Savings ({monthly_requests:,} requests):")
print(f"  Without cache: ${cost_without_cache:,.2f}")
print(f"  With cache: ${cost_with_cache:,.2f}")
print(f"  Monthly savings: ${monthly_savings:,.2f}")
print(f"  Annual savings: ${monthly_savings * 12:,.2f}")

## Next Steps

- Full code: [Chapter 40 on GitHub](https://github.com/eduardd76/AI_for_networking_and_security_engineers/tree/main/CODE/Volume-3-Production-Systems/Chapter-40-Caching-Strategies)
- Learn more: [vExpertAI.com](https://vexpertai.com)
- Author: Eduard Dulharu ([@eduardd76](https://github.com/eduardd76))

**Production Implementation:**
- Deploy Redis for distributed caching
- Use pgvector for semantic similarity search
- Implement cache warming strategies
- Monitor hit rates and cost savings
- Configure TTLs based on data volatility