<a href="https://colab.research.google.com/github/dimitarpg13/rag_architectures_and_concepts/blob/main/src/examples/ann/comparisons/approximate_nearest_neighbor_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Approximate Nearest Neighbor (ANN) Algorithms: Implementation and Analysis

This notebook provides a comprehensive analysis of Approximate Nearest Neighbor algorithms, comparing:
- **Exact KNN** (baseline using brute force)
- **Annoy** (Approximate Nearest Neighbors Oh Yeah - Spotify's library)
- **FAISS** (Facebook AI Similarity Search)
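- **KD-Tree** (scikit-learn's tree-based search; exact rather than approximate, included to show how space-partitioning trees behave in high dimensions)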

We'll measure:
- Recall@K accuracy
- Query latency
- Index build time
- Memory footprint of the raw vectors (index sizes are not measured)

In [None]:
# Install required packages
!pip install annoy faiss-cpu scikit-learn matplotlib numpy pandas seaborn -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from time import time
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# ANN libraries
from annoy import AnnoyIndex
import faiss
from sklearn.neighbors import NearestNeighbors, KDTree

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ All imports successful")

## 1. Data Generation

Generate synthetic high-dimensional data to simulate vector embeddings (common in NLP, computer vision, recommendation systems).
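
Because both the database and the query vectors are scaled to unit length, squared Euclidean distance and cosine similarity are tied together by ||a - b||² = 2 - 2·(a·b), so ranking neighbors by L2 distance gives the same order as ranking by cosine similarity. The sketch below checks that identity in pure Python; `unit_vector` is a hypothetical helper written for this illustration, not part of the notebook.

```python
import math
import random
from typing import List

def unit_vector(dim: int, rng: random.Random) -> List[float]:
    """Draw a Gaussian vector and scale it to unit L2 norm (illustrative helper)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

rng = random.Random(42)
a = unit_vector(128, rng)
b = unit_vector(128, rng)

squared_l2 = sum((x - y) ** 2 for x, y in zip(a, b))
cosine = sum(x * y for x, y in zip(a, b))

# For unit vectors: ||a - b||^2 == 2 - 2 * cos(a, b), up to rounding error
print(abs(squared_l2 - (2.0 - 2.0 * cosine)) < 1e-9)
```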

In [None]:
def generate_data(n_samples: int = 10000,
                  n_dimensions: int = 128,
                  n_queries: int = 100,
                  random_state: int = 42) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic high-dimensional data for ANN testing.

    Args:
        n_samples: Number of vectors in the database
        n_dimensions: Dimensionality of each vector
        n_queries: Number of query vectors
        random_state: Random seed for reproducibility

    Returns:
        Tuple of (database_vectors, query_vectors)
    """
    np.random.seed(random_state)

    # Generate database vectors (normalized)
    database = np.random.randn(n_samples, n_dimensions).astype('float32')
    database = database / np.linalg.norm(database, axis=1, keepdims=True)

    # Generate query vectors (normalized)
    queries = np.random.randn(n_queries, n_dimensions).astype('float32')
    queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)

    return database, queries

# Generate data
N_SAMPLES = 50000
N_DIMENSIONS = 128
N_QUERIES = 200
K = 10  # Number of neighbors to retrieve

database_vectors, query_vectors = generate_data(
    n_samples=N_SAMPLES,
    n_dimensions=N_DIMENSIONS,
    n_queries=N_QUERIES
)

print(f"Database shape: {database_vectors.shape}")
print(f"Query shape: {query_vectors.shape}")
print(f"Memory usage: {(database_vectors.nbytes + query_vectors.nbytes) / 1024**2:.2f} MB")

## 2. Exact Nearest Neighbor (Baseline)

Brute-force search to establish ground truth for accuracy measurements.
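
As a rough picture of what brute force means here, the sketch below scores every database vector against the query and keeps the `k` smallest distances. It is a dependency-free illustration, not scikit-learn's implementation; `brute_force_knn` is a name made up for this example.

```python
import heapq
import math
from typing import List, Sequence

def brute_force_knn(database: Sequence[Sequence[float]],
                    query: Sequence[float],
                    k: int) -> List[int]:
    """Return indices of the k database vectors closest to `query` (L2)."""
    # Every vector is scored, so one query costs O(n * d) distance work.
    return heapq.nsmallest(
        k, range(len(database)), key=lambda i: math.dist(database[i], query)
    )

points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 3.0)]
print(brute_force_knn(points, (0.9, 0.1), k=2))  # [1, 0]
```

The cost grows linearly with the database size, which is the overhead the approximate indices in the following sections try to avoid.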

In [None]:
def exact_knn(database: np.ndarray,
              queries: np.ndarray,
              k: int) -> Tuple[np.ndarray, float, float]:
    """
    Exact K-Nearest Neighbors using brute force.

    Returns:
        Tuple of (indices of k nearest neighbors, build time, query time)
    """
    start_time = time()

    # Using sklearn's brute force implementation
    nn = NearestNeighbors(n_neighbors=k, algorithm='brute', metric='euclidean')
    nn.fit(database)

    build_time = time() - start_time

    query_start = time()
    distances, indices = nn.kneighbors(queries)
    query_time = time() - query_start

    return indices, build_time, query_time

# Compute ground truth
print("Computing exact KNN (ground truth)...")
ground_truth, exact_build_time, exact_query_time = exact_knn(database_vectors, query_vectors, K)

print(f"Build time: {exact_build_time:.4f}s")
print(f"Query time: {exact_query_time:.4f}s")
print(f"Avg time per query: {exact_query_time/N_QUERIES*1000:.2f}ms")

## 3. Annoy (Approximate Nearest Neighbors Oh Yeah)

Spotify's library using random projection trees. Fast and memory-efficient.
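
To make "random projection trees" concrete: the Annoy README describes each tree node as splitting its points with a hyperplane equidistant from two randomly sampled points. The sketch below performs one such split in pure Python. It is only an illustration of the idea under that description; `random_hyperplane_split` is a hypothetical helper, and Annoy's actual C++ implementation differs in detail.

```python
import random
from typing import List, Sequence, Tuple

Vector = Sequence[float]

def random_hyperplane_split(points: List[Vector],
                            rng: random.Random) -> Tuple[List[int], List[int]]:
    """Split point indices by the hyperplane equidistant from two random points."""
    i, j = rng.sample(range(len(points)), 2)
    a, b = points[i], points[j]
    normal = [x - y for x, y in zip(a, b)]
    midpoint = [(x + y) / 2.0 for x, y in zip(a, b)]

    left, right = [], []
    for idx, p in enumerate(points):
        # Signed offset from the hyperplane: positive means "on a's side".
        side = sum(n * (x - m) for n, x, m in zip(normal, p, midpoint))
        (left if side > 0 else right).append(idx)
    return left, right

rng = random.Random(0)
points = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(20)]
left, right = random_hyperplane_split(points, rng)
print(len(left), len(right))
```

Applying the split recursively yields one tree; building several trees with different random splits is what the `n_trees` parameter below controls.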

In [None]:
def build_annoy_index(database: np.ndarray,
                      n_trees: int = 10,
                      metric: str = 'euclidean') -> Tuple[AnnoyIndex, float]:
    """
    Build Annoy index.

    Args:
        database: Database vectors
        n_trees: Number of trees (more trees = better accuracy, slower build)
        metric: Distance metric ('euclidean' or 'angular')

    Returns:
        Tuple of (Annoy index, build time)
    """
    start_time = time()

    # Initialize index
    index = AnnoyIndex(database.shape[1], metric)

    # Add all vectors
    for i, vector in enumerate(database):
        index.add_item(i, vector)

    # Build index
    index.build(n_trees)

    build_time = time() - start_time
    return index, build_time

def query_annoy(index: AnnoyIndex,
                queries: np.ndarray,
                k: int,
                search_k: int = -1) -> Tuple[np.ndarray, float]:
    """
    Query Annoy index.

    Args:
        search_k: Number of nodes to search (-1 = n_trees * k)

    Returns:
        Tuple of (indices, query time)
    """
    start_time = time()

    results = []
    for query in queries:
        indices = index.get_nns_by_vector(query, k, search_k=search_k)
        results.append(indices)

    query_time = time() - start_time
    return np.array(results), query_time

# Test with different numbers of trees
print("Building Annoy indices...")
annoy_results = {}

for n_trees in [5, 10, 20, 50]:
    index, build_time = build_annoy_index(database_vectors, n_trees=n_trees)
    indices, query_time = query_annoy(index, query_vectors, K)

    annoy_results[n_trees] = {
        'index': index,
        'indices': indices,
        'build_time': build_time,
        'query_time': query_time
    }

    print(f"n_trees={n_trees}: Build={build_time:.4f}s, Query={query_time:.4f}s")

## 4. FAISS (Facebook AI Similarity Search)

Industry-standard library with multiple index types. We'll test a Flat (exact) index as a reference, plus IVF (Inverted File) and HNSW as approximate indices.
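
The inverted-file idea can be sketched without FAISS: assign every vector to its nearest centroid, then at query time scan only the lists of the `n_probe` closest centroids. The toy below uses hand-picked centroids in place of the k-means step that FAISS runs in `train()`; all function names are hypothetical and exist only for this illustration.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence

Vector = Sequence[float]

def nearest_ids(items: Sequence[Vector], query: Vector, n: int) -> List[int]:
    """Indices of the n vectors in `items` closest to `query` (L2)."""
    return sorted(range(len(items)), key=lambda i: math.dist(items[i], query))[:n]

def build_inverted_lists(database: Sequence[Vector],
                         centroids: Sequence[Vector]) -> Dict[int, List[int]]:
    """Assign every database vector to the list of its nearest centroid."""
    lists: Dict[int, List[int]] = defaultdict(list)
    for idx, vec in enumerate(database):
        lists[nearest_ids(centroids, vec, 1)[0]].append(idx)
    return lists

def ivf_search(database, centroids, lists, query, k, n_probe):
    """Scan only the n_probe lists whose centroids are closest to the query."""
    candidates = [idx
                  for cell in nearest_ids(centroids, query, n_probe)
                  for idx in lists.get(cell, [])]
    candidates.sort(key=lambda idx: math.dist(database[idx], query))
    return candidates[:k]

# Hand-picked centroids (an assumption of this sketch, not learned by k-means).
centroids = [(0.0, 0.0), (10.0, 0.0)]
database = [(1.0, 0.0), (4.0, 0.0), (6.0, 0.0), (9.0, 0.0)]
lists = build_inverted_lists(database, centroids)

query = (5.2, 0.0)
print(ivf_search(database, centroids, lists, query, k=2, n_probe=1))  # [2, 3]
print(ivf_search(database, centroids, lists, query, k=2, n_probe=2))  # [2, 1]
```

With `n_probe=1` the search misses vector 1 because it sits in the other cell, just across the boundary; probing both cells recovers the true top 2. That boundary effect is why recall rises with `n_probe`.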

In [None]:
def build_faiss_flat(database: np.ndarray) -> Tuple[faiss.Index, float]:
    """
    Build FAISS Flat (exact) index for comparison.
    """
    start_time = time()
    index = faiss.IndexFlatL2(database.shape[1])
    index.add(database)
    build_time = time() - start_time
    return index, build_time

def build_faiss_ivf(database: np.ndarray,
                    n_lists: int = 100) -> Tuple[faiss.Index, float]:
    """
    Build FAISS IVF (Inverted File) index.

    Args:
        n_lists: Number of Voronoi cells (clusters)
    """
    start_time = time()

    # Create quantizer and IVF index
    quantizer = faiss.IndexFlatL2(database.shape[1])
    index = faiss.IndexIVFFlat(quantizer, database.shape[1], n_lists)

    # Train on database
    index.train(database)
    index.add(database)

    build_time = time() - start_time
    return index, build_time

def build_faiss_hnsw(database: np.ndarray,
                     M: int = 32) -> Tuple[faiss.Index, float]:
    """
    Build FAISS HNSW (Hierarchical Navigable Small World) index.

    Args:
        M: Number of connections per node
    """
    start_time = time()
    index = faiss.IndexHNSWFlat(database.shape[1], M)
    index.add(database)
    build_time = time() - start_time
    return index, build_time

def query_faiss(index: faiss.Index,
                queries: np.ndarray,
                k: int,
                n_probe: int = 1) -> Tuple[np.ndarray, float]:
    """
    Query FAISS index.

    Args:
        n_probe: Number of clusters to visit (for IVF indices)
    """
    # Set n_probe for IVF indices
    if hasattr(index, 'nprobe'):
        index.nprobe = n_probe

    start_time = time()
    distances, indices = index.search(queries, k)
    query_time = time() - start_time

    return indices, query_time

# Build FAISS indices
print("Building FAISS indices...")
faiss_results = {}

# Flat (exact)
index_flat, build_time = build_faiss_flat(database_vectors)
indices, query_time = query_faiss(index_flat, query_vectors, K)
faiss_results['Flat'] = {
    'indices': indices,
    'build_time': build_time,
    'query_time': query_time
}
print(f"Flat: Build={build_time:.4f}s, Query={query_time:.4f}s")

# IVF: build the index once, then sweep n_probe (a query-time setting)
index_ivf, ivf_build_time = build_faiss_ivf(database_vectors, n_lists=100)
for n_probe in [1, 5, 10, 20]:
    indices, query_time = query_faiss(index_ivf, query_vectors, K, n_probe=n_probe)
    faiss_results[f'IVF_probe{n_probe}'] = {
        'indices': indices,
        'build_time': ivf_build_time,
        'query_time': query_time
    }
    print(f"IVF (n_probe={n_probe}): Build={ivf_build_time:.4f}s, Query={query_time:.4f}s")

# HNSW
index_hnsw, build_time = build_faiss_hnsw(database_vectors, M=32)
indices, query_time = query_faiss(index_hnsw, query_vectors, K)
faiss_results['HNSW'] = {
    'indices': indices,
    'build_time': build_time,
    'query_time': query_time
}
print(f"HNSW: Build={build_time:.4f}s, Query={query_time:.4f}s")

## 5. KD-Tree (sklearn)

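Tree-based approach from scikit-learn. Unlike the other indices, KD-Tree search is exact: it works well in low dimensions, but in high dimensions (such as the 128 used here) pruning becomes ineffective and query time approaches brute force.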
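
The pruning rule behind KD-Tree search can be sketched in a few lines. The toy below builds a median-split tree and runs an exact 1-nearest-neighbor query, visiting the far side of a split only when the splitting plane is closer than the best match so far. It is a simplified illustration with hypothetical names (`Node`, `build`, `nearest`); scikit-learn's `KDTree` uses leaf buckets and a different internal layout.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Point = Sequence[float]

class Node:
    def __init__(self, point, axis, left, right):
        self.point, self.axis, self.left, self.right = point, axis, left, right

def build(points: List[Point], depth: int = 0) -> Optional[Node]:
    """Recursively split on the median along axes cycled by depth."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    return Node(points[mid], axis,
                build(points[:mid], depth + 1),
                build(points[mid + 1:], depth + 1))

def nearest(node: Optional[Node], query: Point,
            best: Tuple[float, Optional[Point]] = (math.inf, None)):
    """Exact 1-NN search: descend toward the query, backtrack only if needed."""
    if node is None:
        return best
    dist = math.dist(node.point, query)
    if dist < best[0]:
        best = (dist, node.point)
    gap = query[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if gap < 0 else (node.right, node.left)
    best = nearest(near, query, best)
    # The far side can hold a closer point only if the splitting plane
    # is nearer than the best distance found so far.
    if abs(gap) < best[0]:
        best = nearest(far, query, best)
    return best

rng = random.Random(0)
pts = [tuple(rng.random() for _ in range(3)) for _ in range(200)]
tree = build(pts)
q = (0.5, 0.5, 0.5)
dist, point = nearest(tree, q)
print(dist == min(math.dist(p, q) for p in pts))  # True
```

In three dimensions the `abs(gap) < best[0]` test prunes most subtrees. As dimensionality grows, the gap along any single axis is usually smaller than the best distance, so the search ends up visiting most of the tree.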

In [None]:
def build_kdtree(database: np.ndarray,
                 leaf_size: int = 30) -> Tuple[KDTree, float]:
    """
    Build KD-Tree index.
    """
    start_time = time()
    tree = KDTree(database, leaf_size=leaf_size)
    build_time = time() - start_time
    return tree, build_time

def query_kdtree(tree: KDTree,
                 queries: np.ndarray,
                 k: int) -> Tuple[np.ndarray, float]:
    """
    Query KD-Tree.
    """
    start_time = time()
    distances, indices = tree.query(queries, k=k)
    query_time = time() - start_time
    return indices, query_time

# Build KD-Tree
print("Building KD-Tree...")
kdtree, build_time = build_kdtree(database_vectors)
kdtree_indices, query_time = query_kdtree(kdtree, query_vectors, K)

print(f"Build time: {build_time:.4f}s")
print(f"Query time: {query_time:.4f}s")

## 6. Accuracy Metrics

Calculate Recall@K: the fraction of true nearest neighbors found by the approximate method.
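
A small worked example, using made-up neighbor ids, shows how the metric behaves. Note that only set membership matters; the order in which neighbors are returned is ignored.

```python
true_neighbors = [3, 7, 1, 9, 4]    # ids returned by exact search
approx_neighbors = [7, 3, 8, 4, 2]  # ids returned by an approximate index

hits = set(true_neighbors) & set(approx_neighbors)
recall = len(hits) / len(true_neighbors)
print(sorted(hits), recall)  # [3, 4, 7] 0.6
```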

In [None]:
def calculate_recall_at_k(ground_truth: np.ndarray,
                          predictions: np.ndarray) -> float:
    """
    Calculate Recall@K.

    Args:
        ground_truth: True nearest neighbors [n_queries, k]
        predictions: Predicted nearest neighbors [n_queries, k]

    Returns:
        Average recall across all queries
    """
    recalls = []

    for true_neighbors, pred_neighbors in zip(ground_truth, predictions):
        # Count how many predicted neighbors are in the true set
        intersection = len(set(true_neighbors) & set(pred_neighbors))
        recall = intersection / len(true_neighbors)
        recalls.append(recall)

    return np.mean(recalls)

# Calculate recall for all methods
results_df = []

# Exact KNN (should be 1.0)
results_df.append({
    'Method': 'Exact KNN',
    'Recall@K': 1.0,
    'Build Time (s)': exact_build_time,
    'Query Time (s)': exact_query_time,
    'Avg Query (ms)': exact_query_time / N_QUERIES * 1000
})

# Annoy
for n_trees, data in annoy_results.items():
    recall = calculate_recall_at_k(ground_truth, data['indices'])
    results_df.append({
        'Method': f'Annoy (trees={n_trees})',
        'Recall@K': recall,
        'Build Time (s)': data['build_time'],
        'Query Time (s)': data['query_time'],
        'Avg Query (ms)': data['query_time'] / N_QUERIES * 1000
    })

# FAISS
for name, data in faiss_results.items():
    recall = calculate_recall_at_k(ground_truth, data['indices'])
    results_df.append({
        'Method': f'FAISS {name}',
        'Recall@K': recall,
        'Build Time (s)': data['build_time'],
        'Query Time (s)': data['query_time'],
        'Avg Query (ms)': data['query_time'] / N_QUERIES * 1000
    })

# KD-Tree
kdtree_recall = calculate_recall_at_k(ground_truth, kdtree_indices)
results_df.append({
    'Method': 'KD-Tree',
    'Recall@K': kdtree_recall,
    'Build Time (s)': build_time,
    'Query Time (s)': query_time,
    'Avg Query (ms)': query_time / N_QUERIES * 1000
})

results_df = pd.DataFrame(results_df)
results_df = results_df.sort_values('Recall@K', ascending=False)

print("\n" + "="*80)
print("PERFORMANCE COMPARISON")
print("="*80)
print(results_df.to_string(index=False))
print("="*80)

## 7. Visualizations

In [None]:
# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Recall vs Query Time (Speed-Accuracy Tradeoff)
ax1 = axes[0, 0]
for _, row in results_df.iterrows():
    color = 'red' if 'Exact' in row['Method'] else 'blue' if 'Annoy' in row['Method'] else 'green' if 'FAISS' in row['Method'] else 'orange'
    marker = 'o' if 'Exact' in row['Method'] else 's' if 'Annoy' in row['Method'] else '^' if 'FAISS' in row['Method'] else 'd'
    ax1.scatter(row['Avg Query (ms)'], row['Recall@K'],
               s=150, alpha=0.7, color=color, marker=marker,
               label=row['Method'])

ax1.set_xlabel('Average Query Time (ms)', fontsize=12, fontweight='bold')
ax1.set_ylabel('Recall@K', fontsize=12, fontweight='bold')
ax1.set_title('Speed-Accuracy Tradeoff', fontsize=14, fontweight='bold')
ax1.set_xscale('log')
ax1.grid(True, alpha=0.3)
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

# 2. Build Time Comparison
ax2 = axes[0, 1]
methods = results_df['Method'].str.replace('FAISS ', '', regex=False).str.replace('Annoy ', 'A-', regex=False)
# Color by library, keyed on the full method name (same scheme as the scatter plot)
colors = ['red' if 'Exact' in m else 'blue' if 'Annoy' in m else 'green' if 'FAISS' in m else 'orange'
          for m in results_df['Method']]
bars = ax2.barh(range(len(methods)), results_df['Build Time (s)'], color=colors, alpha=0.7)
ax2.set_yticks(range(len(methods)))
ax2.set_yticklabels(methods, fontsize=9)
ax2.set_xlabel('Build Time (seconds)', fontsize=12, fontweight='bold')
ax2.set_title('Index Build Time Comparison', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='x')

# Add value labels
for i, (bar, val) in enumerate(zip(bars, results_df['Build Time (s)'])):
    ax2.text(val, bar.get_y() + bar.get_height()/2, f'{val:.3f}s',
            va='center', ha='left', fontsize=8, fontweight='bold')

# 3. Query Time Comparison
ax3 = axes[1, 0]
bars = ax3.barh(range(len(methods)), results_df['Avg Query (ms)'], color=colors, alpha=0.7)
ax3.set_yticks(range(len(methods)))
ax3.set_yticklabels(methods, fontsize=9)
ax3.set_xlabel('Average Query Time (ms)', fontsize=12, fontweight='bold')
ax3.set_title('Query Performance Comparison', fontsize=14, fontweight='bold')
ax3.set_xscale('log')
ax3.grid(True, alpha=0.3, axis='x')

# Add value labels
for i, (bar, val) in enumerate(zip(bars, results_df['Avg Query (ms)'])):
    ax3.text(val * 1.1, bar.get_y() + bar.get_height()/2, f'{val:.2f}ms',
            va='center', ha='left', fontsize=8, fontweight='bold')

# 4. Recall Comparison
ax4 = axes[1, 1]
bars = ax4.barh(range(len(methods)), results_df['Recall@K'], color=colors, alpha=0.7)
ax4.set_yticks(range(len(methods)))
ax4.set_yticklabels(methods, fontsize=9)
ax4.set_xlabel('Recall@K', fontsize=12, fontweight='bold')
ax4.set_title(f'Accuracy Comparison (Recall@{K})', fontsize=14, fontweight='bold')
ax4.set_xlim([0, 1.05])
ax4.grid(True, alpha=0.3, axis='x')

# Add value labels
for i, (bar, val) in enumerate(zip(bars, results_df['Recall@K'])):
    ax4.text(val, bar.get_y() + bar.get_height()/2, f'{val:.3f}',
            va='center', ha='right' if val > 0.5 else 'left',
            fontsize=8, fontweight='bold',
            color='white' if val > 0.5 else 'black')

plt.tight_layout()
plt.savefig('ann_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Visualizations saved to 'ann_comparison.png'")

## 8. Detailed Analysis: Pareto Frontier

Identify methods on the Pareto frontier (best speed-accuracy tradeoffs).
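
A method is on the frontier if no other method dominates it, where "dominates" means no worse on both recall and latency and strictly better on at least one. The sketch below applies that rule to plain tuples; the numbers are invented for illustration and are not benchmark results.

```python
from typing import List, Tuple

Point = Tuple[str, float, float]  # (name, recall, latency_ms)

def dominates(a: Point, b: Point) -> bool:
    """a dominates b if it is no worse on both axes and strictly better on one."""
    no_worse = a[1] >= b[1] and a[2] <= b[2]
    strictly_better = a[1] > b[1] or a[2] < b[2]
    return no_worse and strictly_better

def pareto_front(points: List[Point]) -> List[Point]:
    """Keep the points that no other point dominates."""
    return [p for p in points if not any(dominates(q, p) for q in points)]

# Made-up numbers for illustration only.
candidates = [("A", 1.00, 5.0), ("B", 0.95, 0.5), ("C", 0.90, 0.8), ("D", 0.70, 0.1)]
print([name for name, _, _ in pareto_front(candidates)])  # ['A', 'B', 'D']
```

Here C drops out because B is both more accurate and faster, while A, B and D each win on at least one axis.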

In [None]:
def find_pareto_frontier(df: pd.DataFrame) -> pd.DataFrame:
    """
    Find Pareto-optimal points (maximize recall, minimize query time).
    """
    pareto_points = []

    for i, row in df.iterrows():
        is_pareto = True
        for j, other_row in df.iterrows():
            if i != j:
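                # The other point dominates if it is no worse on both axes
                # and strictly better on at least one
                no_worse = (other_row['Recall@K'] >= row['Recall@K'] and
                            other_row['Avg Query (ms)'] <= row['Avg Query (ms)'])
                strictly_better = (other_row['Recall@K'] > row['Recall@K'] or
                                   other_row['Avg Query (ms)'] < row['Avg Query (ms)'])
                if no_worse and strictly_better:
                    is_pareto = False
                    break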

        if is_pareto:
            pareto_points.append(row)

    return pd.DataFrame(pareto_points)

# Find and plot Pareto frontier
pareto_df = find_pareto_frontier(results_df)
pareto_df = pareto_df.sort_values('Avg Query (ms)')

plt.figure(figsize=(12, 7))

# Plot all points
for _, row in results_df.iterrows():
    is_pareto = row['Method'] in pareto_df['Method'].values
    color = 'red' if 'Exact' in row['Method'] else 'blue' if 'Annoy' in row['Method'] else 'green' if 'FAISS' in row['Method'] else 'orange'
    marker = 'o' if 'Exact' in row['Method'] else 's' if 'Annoy' in row['Method'] else '^' if 'FAISS' in row['Method'] else 'd'

    plt.scatter(row['Avg Query (ms)'], row['Recall@K'],
               s=300 if is_pareto else 150,
               alpha=1.0 if is_pareto else 0.4,
               color=color,
               marker=marker,
               edgecolors='black' if is_pareto else 'none',
               linewidths=2 if is_pareto else 0,
               label=row['Method'] if is_pareto else '',
               zorder=10 if is_pareto else 5)

    # Add labels for Pareto points
    if is_pareto:
        plt.annotate(row['Method'],
                    (row['Avg Query (ms)'], row['Recall@K']),
                    xytext=(10, 10), textcoords='offset points',
                    fontsize=9, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor=color, alpha=0.3))

# Draw Pareto frontier line
plt.plot(pareto_df['Avg Query (ms)'], pareto_df['Recall@K'],
         'k--', linewidth=2, alpha=0.5, label='Pareto Frontier', zorder=8)

plt.xlabel('Average Query Time (ms)', fontsize=14, fontweight='bold')
plt.ylabel('Recall@K', fontsize=14, fontweight='bold')
plt.title('Pareto Frontier: Speed-Accuracy Tradeoff\n(Larger points with black borders are Pareto-optimal)',
         fontsize=15, fontweight='bold')
plt.xscale('log')
plt.grid(True, alpha=0.3)
plt.legend(fontsize=10, loc='lower left')
plt.tight_layout()
plt.savefig('pareto_frontier.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nPareto-Optimal Methods:")
print(pareto_df[['Method', 'Recall@K', 'Avg Query (ms)', 'Build Time (s)']].to_string(index=False))

## 9. Distance Distribution Analysis

Visualize how approximate methods compare to exact search in terms of actual distances.
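
Recall counts every missed neighbor equally, whereas distances show how far off the substitutes are. One possible single-number summary, offered here as an illustration rather than something the notebook computes, is the ratio of mean retrieved distance to mean exact distance: 1.0 means the approximate neighbors are as close as the true ones, and larger values mean they are farther away.

```python
import math

query = (0.0, 0.0)
exact_neighbors = [(1.0, 0.0), (0.0, 2.0)]   # true top 2
approx_neighbors = [(1.0, 0.0), (3.0, 0.0)]  # approximate top 2 (one miss)

def mean_distance(neighbors):
    """Average L2 distance from the query to a set of neighbors."""
    return sum(math.dist(query, n) for n in neighbors) / len(neighbors)

ratio = mean_distance(approx_neighbors) / mean_distance(exact_neighbors)
print(round(ratio, 3))  # 1.333
```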

In [None]:
def compute_distances(database: np.ndarray,
                     queries: np.ndarray,
                     indices: np.ndarray) -> np.ndarray:
    """
    Compute actual L2 distances for retrieved neighbors.
    """
    distances = []
    for query_idx, neighbor_indices in enumerate(indices):
        query = queries[query_idx]
        neighbors = database[neighbor_indices]
        dists = np.linalg.norm(neighbors - query, axis=1)
        distances.append(dists)
    return np.array(distances)

# Compute distances for select methods
exact_distances = compute_distances(database_vectors, query_vectors, ground_truth)
annoy_10_distances = compute_distances(database_vectors, query_vectors, annoy_results[10]['indices'])
hnsw_distances = compute_distances(database_vectors, query_vectors, faiss_results['HNSW']['indices'])

# Plot distance distributions
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

methods_to_plot = [
    ('Exact KNN', exact_distances, 'red'),
    ('Annoy (10 trees)', annoy_10_distances, 'blue'),
    ('FAISS HNSW', hnsw_distances, 'green')
]

for ax, (method, distances, color) in zip(axes, methods_to_plot):
    # Flatten all distances
    flat_distances = distances.flatten()

    ax.hist(flat_distances, bins=50, alpha=0.7, color=color, edgecolor='black')
    ax.axvline(flat_distances.mean(), color='darkred', linestyle='--', linewidth=2,
              label=f'Mean: {flat_distances.mean():.4f}')
    ax.set_xlabel('L2 Distance', fontsize=12, fontweight='bold')
    ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
    ax.set_title(f'{method}\nDistance Distribution', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('distance_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

## 10. Summary and Recommendations

### Key Findings:

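1. **Speed vs Accuracy Tradeoff**: Approximate methods can reach 10-100x speedups at 90-99% recall when tuned well; check the results table above for the numbers from this run

2. **Best Overall**: FAISS HNSW typically offers the best balance (high recall, fast queries), at the cost of a larger in-memory index

3. **Memory Constrained**: Annoy indices can be saved to disk and memory-mapped, which suits memory-limited scenarios and lets several processes share one index

4. **High Dimensions**: KD-Tree search is exact, so its recall stays at 1.0, but its query time degrades toward brute force in high dimensions (curse of dimensionality)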

### Production Recommendations:

- **<100k vectors**: FAISS HNSW or Annoy
- **100k-10M vectors**: FAISS IVF or HNSW
- **>10M vectors**: FAISS with GPU support or distributed solutions
- **Real-time requirements**: HNSW (best query latency)
- **Batch processing**: IVF (good throughput, acceptable latency)

### Tuning Guidelines:

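- **Annoy**: Increase `n_trees` for better recall (build time and index size grow roughly linearly); raising `search_k` at query time also improves recall at the cost of latency
- **FAISS IVF**: Increase `n_probe` for better recall (query time grows roughly linearly with the number of cells scanned)
- **FAISS HNSW**: Increase `M` for better recall (higher memory usage); `efSearch` trades query time for recall without rebuilding the index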
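
The query-time knobs on FAISS indices are plain attributes: IVF indices expose `nprobe`, and HNSW indices expose `efSearch` on their `hnsw` member. The sketch below wraps both behind one hypothetical helper, `set_query_effort`, and uses stand-in objects so it runs without FAISS installed. Treat it as an illustration of where the settings live, not as a tuning recipe.

```python
from types import SimpleNamespace

def set_query_effort(index, effort: int) -> str:
    """Raise the query-time search budget on a FAISS-style index (hypothetical helper)."""
    if hasattr(index, "hnsw"):
        # HNSW indices hold the search-time candidate list size in index.hnsw.efSearch.
        index.hnsw.efSearch = effort
        return "efSearch"
    if hasattr(index, "nprobe"):
        # IVF indices hold the number of cells to scan in index.nprobe.
        index.nprobe = effort
        return "nprobe"
    return "none"

# Stand-in objects (placeholder values) so the sketch runs without FAISS.
fake_hnsw = SimpleNamespace(hnsw=SimpleNamespace(efSearch=16))
fake_ivf = SimpleNamespace(nprobe=1)
print(set_query_effort(fake_hnsw, 64), fake_hnsw.hnsw.efSearch)  # efSearch 64
print(set_query_effort(fake_ivf, 10), fake_ivf.nprobe)           # nprobe 10
```

Because the notebook's `query_faiss` only sets `nprobe`, its HNSW measurement reflects whatever `efSearch` value the index starts with; sweeping `efSearch` would give HNSW a recall-latency curve comparable to the IVF sweep.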

### Further Reading:

- [FAISS Documentation](https://github.com/facebookresearch/faiss/wiki)
- [Annoy GitHub](https://github.com/spotify/annoy)
- [ANN Benchmarks](http://ann-benchmarks.com/)

In [None]:
# Export results to CSV
results_df.to_csv('ann_benchmark_results.csv', index=False)
print("\n✓ Results exported to 'ann_benchmark_results.csv'")
print("\n✓ Notebook execution complete!")
print("\n" + "="*80)
print("FINAL SUMMARY")
print("="*80)
print(f"Database size: {N_SAMPLES:,} vectors × {N_DIMENSIONS} dimensions")
print(f"Number of queries: {N_QUERIES}")
print(f"K (neighbors to retrieve): {K}")
print(f"\nTop 3 methods by recall:")
print(results_df.nlargest(3, 'Recall@K')[['Method', 'Recall@K', 'Avg Query (ms)']].to_string(index=False))
print(f"\nTop 3 methods by speed:")
print(results_df.nsmallest(3, 'Avg Query (ms)')[['Method', 'Recall@K', 'Avg Query (ms)']].to_string(index=False))
print("="*80)