# Efficient Embeddings Generation

This notebook demonstrates optimization techniques for generating embeddings at scale.

## Learning Objectives

By the end of this notebook, you will be able to:
- Understand the impact of batch size on throughput and latency
- Implement efficient batching strategies for embedding generation
- Build caching systems to avoid redundant computations
- Measure and optimize embedding generation performance
- Choose appropriate batch sizes for different use cases

## Optimization Strategies

We'll explore two key techniques that dramatically improve performance:

1. **Batching** - Processing multiple inputs together rather than one at a time
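   - Improves GPU/CPU utilization by running larger, parallel tensor operations
   - Amortizes per-call overhead (Python dispatch, tokenization, data transfer) across many inputs
   - Increases throughput significantly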

2. **Caching** - Storing previously generated embeddings to avoid regeneration
   - Eliminates redundant computation
   - Reduces latency for repeated content
   - Critical for production systems

These optimizations are essential for building performant AI applications that rely on embeddings.

## Setup: Install Required Libraries

In [None]:
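import os
# Have uv copy installed files instead of hard-linking them (avoids link warnings/failures
# when the uv cache and the environment live on different filesystems)
os.environ['UV_LINK_MODE'] = 'copy'

# Install the required packages (matplotlib and pandas are assumed to be installed already)
!uv pip install accelerate==1.6.0 sentence-transformers==4.0.2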

print("✓ Required libraries installed successfully!")

## Import Libraries and Load Model

In [None]:
from sentence_transformers import SentenceTransformer
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import hashlib
from functools import lru_cache
import pandas as pd

# Load lightweight model for demonstration
model_name = 'all-MiniLM-L6-v2'
model = SentenceTransformer(model_name)

# Generate example sentences for benchmarking
sentences = [
    f"This is a sample sentence for benchmarking embeddings generation {i}."
    for i in range(1000)
]

print("✓ Libraries imported and model loaded!")
print(f"  Model: {model_name}")
print(f"  Test dataset: {len(sentences):,} sentences")

## Batch Size Impact on Performance

**Batching** processes multiple inputs together instead of one-by-one, which dramatically improves efficiency. Let's measure how different batch sizes affect:

- **Throughput** - Embeddings generated per second (higher is better)
- **Latency** - Time to process each batch (important for real-time apps)

In [None]:
import math

# Benchmark different batch sizes
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
results = []

# Warm-up call so one-off costs (lazy initialization, memory allocation) don't skew the first measurement
model.encode(sentences[:32])

print("Benchmarking batch sizes...")
print("=" * 80)

for batch_size in tqdm(batch_sizes, desc="Testing batch sizes"):
    start_time = time.perf_counter()  # high-resolution clock intended for benchmarking
    
    # Process data in batches. Pass batch_size through: otherwise encode() re-splits
    # every chunk into its default internal batches of 32
    embeddings = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        batch_embeddings = model.encode(batch, batch_size=batch_size)
        embeddings.extend(batch_embeddings)
    
    # Calculate performance metrics
    total_time = time.perf_counter() - start_time
    num_batches = math.ceil(len(sentences) / batch_size)  # last batch may be partial
    throughput = len(sentences) / total_time
    avg_latency = total_time / num_batches
    
    results.append({
        'Batch Size': batch_size,
        'Total Time (s)': total_time,
        'Throughput (samples/s)': throughput,
        'Avg Batch Latency (s)': avg_latency
    })

# Display results
df = pd.DataFrame(results)
print("\n" + "=" * 80)
print("BATCH SIZE BENCHMARK RESULTS")
print("=" * 80)
print(df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))
print("=" * 80)

## Visualize Performance Metrics

Let's visualize how batch size affects throughput and latency:

In [None]:
# Create visualization
plt.figure(figsize=(15, 6))
batch_ticks = df['Batch Size'].tolist()  # label every tested batch size

# Throughput plot
plt.subplot(1, 2, 1)
plt.plot(df['Batch Size'], df['Throughput (samples/s)'], 'o-', linewidth=2,
         markersize=8, color='#2E86AB')
plt.xscale('log', base=2)  # batch sizes are powers of two, so space them evenly
plt.xticks(batch_ticks, [str(b) for b in batch_ticks])
plt.minorticks_off()
plt.xlabel('Batch Size (log scale)', fontsize=12)
plt.ylabel('Throughput (samples/second)', fontsize=12)
plt.title('Batch Size vs Throughput\n(Higher is Better)', fontsize=13, weight='bold')
plt.grid(True, alpha=0.3)

# Mark optimal batch size
max_throughput_idx = df['Throughput (samples/s)'].idxmax()
optimal_batch = int(df.loc[max_throughput_idx, 'Batch Size'])
optimal_throughput = df.loc[max_throughput_idx, 'Throughput (samples/s)']
plt.axvline(x=optimal_batch, color='red', linestyle='--', alpha=0.5, label=f'Optimal: {optimal_batch}')
plt.legend()

# Latency plot
plt.subplot(1, 2, 2)
plt.plot(df['Batch Size'], df['Avg Batch Latency (s)'], 'o-', linewidth=2,
         markersize=8, color='#A23B72')
plt.xscale('log', base=2)
plt.xticks(batch_ticks, [str(b) for b in batch_ticks])
plt.minorticks_off()
plt.xlabel('Batch Size (log scale)', fontsize=12)
plt.ylabel('Average Batch Latency (seconds)', fontsize=12)
plt.title('Batch Size vs Latency\n(Lower is Better for Real-time)', fontsize=13, weight='bold')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Performance visualization complete!")
print(f"\nOptimal batch size for throughput: {optimal_batch} ({optimal_throughput:.1f} samples/s)")

### Key Observations

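1. **Throughput generally increases with batch size** - Processing multiple inputs together is much more efficient than one-by-one, because each `encode` call's fixed overhead is shared across more inputs

2. **Diminishing returns** - Past some batch size (often around 32-64), throughput gains tend to plateau once compute or memory bandwidth is saturated; the exact point depends on your hardware, model, and sentence length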

3. **Latency trade-off** - Larger batches improve throughput but increase per-batch latency
   - **Real-time apps** → Use smaller batches (1-16) for low latency
   - **Batch processing** → Use larger batches (64-256) for high throughput

4. **Optimal batch size depends on use case:**
   - Search engines: Small batches (fast response)
   - Document indexing: Large batches (high throughput)
   - Recommendation systems: Medium batches (balanced)
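One way to turn this trade-off into a decision is to pick the highest-throughput batch size whose measured latency fits a budget. A minimal sketch, assuming a hypothetical `pick_batch_size` helper and a list of dicts shaped like the benchmark's `results` list (or `df.to_dict("records")`); the numbers in `example` are illustrative, not measured:

```python
def pick_batch_size(results, max_batch_latency_s):
    """Hypothetical helper: highest-throughput batch size whose mean batch latency fits a budget.

    `results` is a list of dicts with the same keys as the benchmark's `results` list.
    """
    candidates = [r for r in results if r["Avg Batch Latency (s)"] <= max_batch_latency_s]
    if not candidates:
        return None  # no tested batch size meets the latency budget
    best = max(candidates, key=lambda r: r["Throughput (samples/s)"])
    return best["Batch Size"]


# Illustrative numbers only, not measured results
example = [
    {"Batch Size": 8, "Throughput (samples/s)": 400.0, "Avg Batch Latency (s)": 0.02},
    {"Batch Size": 64, "Throughput (samples/s)": 900.0, "Avg Batch Latency (s)": 0.07},
    {"Batch Size": 256, "Throughput (samples/s)": 950.0, "Avg Batch Latency (s)": 0.27},
]
print(pick_batch_size(example, max_batch_latency_s=0.1))  # 64
```

A real-time service would pass a tight budget; an offline indexing job can pass a large one and simply get the throughput maximum.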

## Implementing Embedding Caches

In real-world applications, we often process the same text multiple times. **Caching** avoids redundant computation by storing previously generated embeddings.

Benefits of caching:
- **Eliminates redundant work** - No need to recompute embeddings for repeated text
- **Reduces latency** - A cache hit is a dictionary lookup, far cheaper than a model forward pass
- **Saves resources** - Less CPU/GPU usage for repeated content
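The simplest form of reuse needs no persistent cache at all: duplicates inside a single request can be encoded once and fanned back out. A minimal sketch with a hypothetical `encode_unique` helper; `encode_fn` is an assumption standing in for something like `lambda batch: model.encode(batch, batch_size=32)`:

```python
def encode_unique(texts, encode_fn):
    """Hypothetical helper: encode each distinct text once, then map results back to every position.

    `encode_fn` is assumed to take a list of texts and return one embedding per text.
    """
    unique_texts = list(dict.fromkeys(texts))  # de-duplicate while keeping first-seen order
    embeddings = encode_fn(unique_texts)       # one batched call for the distinct texts
    by_text = dict(zip(unique_texts, embeddings))
    return [by_text[t] for t in texts]         # one embedding per input position
```

`dict.fromkeys` is used instead of `set` so the encoded order stays deterministic and matches the input.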

### Simple Dictionary-Based Cache

A basic caching implementation using Python dictionaries:

In [None]:
class SimpleEmbeddingCache:
    """
    Simple embedding cache using a dictionary to store text→embedding mappings.
    Tracks cache hits and misses for performance monitoring.
    """
    def __init__(self, model):
        self.model = model
        self.cache = {}  # text_hash → embedding
        self.hits = 0
        self.misses = 0
    
    def _get_hash(self, text):
        """Generate a stable hash for text (used as cache key)"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()
    
    def encode(self, texts, batch_size=32):
        """Encode texts, reusing cached embeddings when available"""
        results = []          # (original_index, embedding) pairs
        texts_to_encode = []  # cache misses, in input order
        miss_hashes = []      # cache key for each miss (computed once)
        text_indices = []     # original position of each miss
        
        # Check cache for each text
        # Note: duplicates within one call are all counted as misses and encoded separately
        for i, text in enumerate(texts):
            text_hash = self._get_hash(text)
            if text_hash in self.cache:
                results.append((i, self.cache[text_hash]))
                self.hits += 1
            else:
                texts_to_encode.append(text)
                miss_hashes.append(text_hash)
                text_indices.append(i)
                self.misses += 1
        
        # Generate embeddings for cache misses (in batches)
        if texts_to_encode:
            new_embeddings = []
            for start in range(0, len(texts_to_encode), batch_size):
                batch = texts_to_encode[start:start + batch_size]
                # Use this cache's own model (not the global one) with the requested batch size
                batch_embeddings = self.model.encode(batch, batch_size=batch_size)
                new_embeddings.extend(batch_embeddings)
            
            # Update cache with new embeddings
            for text_hash, idx, embedding in zip(miss_hashes, text_indices, new_embeddings):
                self.cache[text_hash] = embedding
                results.append((idx, embedding))
        
        # Restore original input order and return embeddings
        results.sort(key=lambda x: x[0])
        return np.array([emb for _, emb in results])
    
    def get_stats(self):
        """Return cache performance statistics"""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "total": total,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache)
        }

print("✓ SimpleEmbeddingCache class defined successfully!")

### LRU Cache Implementation

Python's `functools.lru_cache` provides a **Least Recently Used** cache that automatically evicts old entries when full:

In [None]:
@lru_cache(maxsize=1024)
def hash_text(text):
    """Cache text hashes to avoid recomputing MD5"""
    return hashlib.md5(text.encode('utf-8')).hexdigest()

class LRUEmbeddingCache:
    """
    Embedding cache using Python's LRU cache (automatically evicts least recently used).
    Useful when memory is limited and you want automatic cache management.
    """
    def __init__(self, model, maxsize=1024):
        self.model = model
        self.encode_single = lru_cache(maxsize=maxsize)(self._encode_single)
        self.hits = 0
        self.misses = 0
        self.hash_to_text = {}
    
    def _encode_single(self, text_hash):
        """Generate embedding for single text (cached automatically by @lru_cache)"""
        self.misses += 1
        text = self.hash_to_text[text_hash]
        return self.model.encode([text])[0]
    
    def encode(self, texts, batch_size=32):
        """Encode texts using the LRU cache.

        Note: misses are encoded one text at a time, so `batch_size` is unused here.
        """
        self.hash_to_text = {}
        results = []
        
        for text in texts:
            text_hash = hash_text(text)
            self.hash_to_text[text_hash] = text
            
            # Check if cached
            cache_info_before = self.encode_single.cache_info()
            embedding = self.encode_single(text_hash)
            cache_info_after = self.encode_single.cache_info()
            
            # Update hit counter
            if cache_info_after.hits > cache_info_before.hits:
                self.hits += 1
            
            results.append(embedding)
        
        return np.array(results)
    
    def get_stats(self):
        """Return cache performance statistics"""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        cache_info = self.encode_single.cache_info()
        return {
            "hits": self.hits,
            "misses": self.misses,
            "total": total,
            "hit_rate": hit_rate,
            "cache_info": cache_info
        }

print("✓ LRUEmbeddingCache class defined successfully!")

## Evaluate Cache Performance

Let's compare performance with and without caching on a dataset with repeated content:

In [None]:
# Create test dataset with repetition (simulates real-world usage)
test_data = []
for i in range(500):
    # Every 5th position reuses one of 20 sentences (sentences[0], sentences[5], ..., sentences[95])
    test_data.append(sentences[i % 100] if i % 5 == 0 else sentences[i])

print("=" * 80)
print("CACHE PERFORMANCE BENCHMARK")
print("=" * 80)
n_unique = len(set(test_data))
n_duplicates = len(test_data) - n_unique
print(f"\nTest dataset: {len(test_data)} sentences ({n_unique} unique)")
print(f"Duplicate sentences: {n_duplicates} ({n_duplicates / len(test_data) * 100:.1f}%)")

# Benchmark 1: Without cache
print("\n" + "─" * 80)
print("1️⃣  WITHOUT CACHE (baseline)")
print("─" * 80)
start = time.time()
embeddings_no_cache = model.encode(test_data, batch_size=32)
no_cache_time = time.time() - start
print(f"✓ Processing time: {no_cache_time:.4f}s")

# Benchmark 2: First run with cache (all misses)
print("\n" + "─" * 80)
print("2️⃣  FIRST RUN WITH CACHE (cold cache)")
print("─" * 80)
cache = SimpleEmbeddingCache(model)
start = time.time()
embeddings_with_cache = cache.encode(test_data, batch_size=32)
first_run_time = time.time() - start
stats = cache.get_stats()
print(f"✓ Processing time: {first_run_time:.4f}s")
print(f"  Cache stats: {stats['hits']} hits, {stats['misses']} misses ({stats['hit_rate']*100:.1f}% hit rate)")

# Benchmark 3: Second run with populated cache
print("\n" + "─" * 80)
print("3️⃣  SECOND RUN WITH CACHE (warm cache)")
print("─" * 80)
second_cache = SimpleEmbeddingCache(model)
second_cache.cache = cache.cache  # Reuse populated cache
start = time.time()
embeddings_with_cache = second_cache.encode(test_data, batch_size=32)
second_run_time = time.time() - start
stats2 = second_cache.get_stats()
print(f"✓ Processing time: {second_run_time:.4f}s")
print(f"  Cache stats: {stats2['hits']} hits, {stats2['misses']} misses ({stats2['hit_rate']*100:.1f}% hit rate)")

# Performance summary
print("\n" + "=" * 80)
print("PERFORMANCE SUMMARY")
print("=" * 80)
print(f"  Without cache:         {no_cache_time:>8.4f}s  (baseline)")
print(f"  First run with cache:  {first_run_time:>8.4f}s  ({first_run_time/no_cache_time:>5.2f}x)")
print(f"  Second run with cache: {second_run_time:>8.4f}s  ({no_cache_time/second_run_time:>5.1f}x faster!)")
print("=" * 80)

print(f"\n✓ Caching provides {no_cache_time/second_run_time:.0f}x speedup for repeated content!")

## Summary

We've explored optimization techniques for efficient embedding generation at scale.

### Key Takeaways

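1. **Batching dramatically improves throughput** - Processing 32-64 items together is typically several times faster than one-by-one; the exact factor depends on hardware and model, so check the benchmark table above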

2. **Batch size trades off throughput vs latency**:
   - Smaller batches (1-16) → Lower latency, good for real-time apps
   - Larger batches (64-256) → Higher throughput, ideal for batch processing
   - Optimal size depends on your specific use case and hardware

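3. **Caching eliminates redundant computation** - A cache hit is a lookup rather than a forward pass, so a fully warm cache can be orders of magnitude faster than re-encoding (compare the warm-cache run above)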

4. **Cache implementation choices**:
   - **Simple dict cache** - Unlimited size, manual management, best for known dataset sizes
   - **LRU cache** - Automatic eviction, memory-bounded, good for streaming data

5. **Combined optimizations multiply benefits** - Batching + caching together provide the best performance

### Production Best Practices

**When to use batching:**
- Bulk document indexing
- Offline data preprocessing
- Scheduled batch jobs
- High-volume API endpoints

**When to use caching:**
- User-generated content with duplicates
- Recommendation systems (popular items accessed frequently)
- Search engines (common queries)
- Real-time systems with repeated inputs

**When to use both:**
- Large-scale document processing
- Production RAG systems
- Search infrastructure
- Content moderation pipelines

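### Performance Impact

Idealized estimates, assuming batching alone gives a 3-4x speedup and cache lookups cost nothing. With a cache hit rate h, only the (1 - h) fraction of inputs still needs encoding, so the combined speedup is roughly the batching speedup divided by (1 - h). Actual gains depend on hardware, model, and lookup overhead.

| Scenario | Without Optimization | With Batching | With Batching + Caching |
|----------|---------------------|---------------|-------------------------|
| Cold cache (0% hit rate) | 1x (baseline) | 3-4x faster | 3-4x faster |
| Warm cache (20% hit rate) | 1x (baseline) | 3-4x faster | ~3.75-5x faster |
| Hot cache (80% hit rate) | 1x (baseline) | 3-4x faster | ~15-20x faster |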
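Under the idealized assumptions that batching alone gives a fixed 3-4x speedup and cache lookups are free, the combined speedup for a hit rate h is the batching speedup divided by (1 - h). A sketch with a hypothetical `combined_speedup` helper:

```python
def combined_speedup(batching_speedup, hit_rate):
    """Idealized speedup of batching + caching over one-at-a-time encoding.

    Assumes cache lookups are free, so only the (1 - hit_rate) fraction of inputs
    still needs encoding, at the batched rate.
    """
    if not 0 <= hit_rate < 1:
        raise ValueError("hit_rate must be in [0, 1)")
    return batching_speedup / (1 - hit_rate)


for h in (0.0, 0.2, 0.8):
    low, high = combined_speedup(3, h), combined_speedup(4, h)
    print(f"hit rate {h:.0%}: {low:.2f}x to {high:.2f}x")
```

For the 3-4x batching range this prints 3.00x to 4.00x with no hits, 3.75x to 5.00x at a 20% hit rate, and 15.00x to 20.00x at 80%. Hashing, lookups, and memory traffic are never free in practice, so treat these as upper bounds.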

### Next Steps

These optimization techniques form the foundation for building production-grade embedding systems:
- Vector databases (ChromaDB, Pinecone, Weaviate)
- Semantic search engines
- RAG (Retrieval-Augmented Generation) systems
- Document similarity services
- Recommendation engines