In [1]:
import numpy as np
from collections import defaultdict
from typing import List, Tuple, Dict
import random
import heapq

class LSH:
    """
    Locality Sensitive Hashing implementation
    """
    def __init__(self, num_hash_tables: int, num_hash_functions: int, vector_dimension: int):
        self.num_hash_tables = num_hash_tables
        self.num_hash_functions = num_hash_functions
        self.vector_dimension = vector_dimension
        self.hash_tables = [defaultdict(list) for _ in range(num_hash_tables)]
        # Random projection vectors for each hash function
        self.random_vectors = [
            [np.random.randn(vector_dimension) for _ in range(num_hash_functions)]
            for _ in range(num_hash_tables)
        ]

    def _hash_vector(self, vector: np.ndarray, random_vectors: List[np.ndarray]) -> str:
        """Generate hash for a vector using random projections"""
        hash_bits = []
        for rv in random_vectors:
            # If projection is positive, hash bit is 1; otherwise 0
            hash_bit = 1 if np.dot(vector, rv) >= 0 else 0
            hash_bits.append(str(hash_bit))
        return ''.join(hash_bits)

    def insert(self, vector: np.ndarray, label: str):
        """Insert a vector into the hash tables"""
        for i, hash_table in enumerate(self.hash_tables):
            hash_value = self._hash_vector(vector, self.random_vectors[i])
            hash_table[hash_value].append((vector, label))

    def query(self, query_vector: np.ndarray, k: int = 1) -> List[Tuple[str, float]]:
        """Find k approximate nearest neighbors"""
        candidates = set()
        
        # Collect candidates from all hash tables
        for i, hash_table in enumerate(self.hash_tables):
            hash_value = self._hash_vector(query_vector, self.random_vectors[i])
            candidates.update((label, np.linalg.norm(query_vector - vector)) 
                            for vector, label in hash_table[hash_value])
        
        # Return k nearest neighbors from candidates
        return sorted(candidates, key=lambda x: x[1])[:k]


class KDTree:
    """
    KD-Tree implementation for ANN search
    """
    class Node:
        def __init__(self, point: np.ndarray, label: str, axis: int):
            self.point = point
            self.label = label
            self.axis = axis
            self.left = None
            self.right = None

    def __init__(self, points: List[np.ndarray], labels: List[str]):
        self.dimension = len(points[0])
        self.root = self._build_tree(points, labels, 0)

    def _build_tree(self, points: List[np.ndarray], labels: List[str], depth: int) -> Node:
        if not points:
            return None

        axis = depth % self.dimension
        # Sort points by the current axis
        sorted_points_labels = sorted(zip(points, labels), key=lambda x: x[0][axis])
        median_idx = len(points) // 2
        
        node = self.Node(sorted_points_labels[median_idx][0], 
                        sorted_points_labels[median_idx][1], 
                        axis)
        
        node.left = self._build_tree([p for p, _ in sorted_points_labels[:median_idx]],
                                   [l for _, l in sorted_points_labels[:median_idx]],
                                   depth + 1)
        
        node.right = self._build_tree([p for p, _ in sorted_points_labels[median_idx + 1:]],
                                    [l for _, l in sorted_points_labels[median_idx + 1:]],
                                    depth + 1)
        
        return node

    def query(self, query_point: np.ndarray, k: int = 1) -> List[Tuple[str, float]]:
        """Find k approximate nearest neighbors using KD-tree"""
        heap = []  # min heap to store k nearest neighbors
        self._search(self.root, query_point, k, heap)
        
        # Convert heap to sorted list of (label, distance) pairs
        return sorted([(label, -dist) for dist, label in heap])

    def _search(self, node: Node, query_point: np.ndarray, k: int, heap: List[Tuple[float, str]]):
        if not node:
            return

        distance = np.linalg.norm(query_point - node.point)
        
        # If heap has less than k elements, add current point
        if len(heap) < k:
            heapq.heappush(heap, (-distance, node.label))
        # If current point is closer than the furthest point in heap
        elif -distance > heap[0][0]:
            heapq.heapreplace(heap, (-distance, node.label))

        axis = node.axis
        diff = query_point[axis] - node.point[axis]
        
        # Recursively search the closer subtree
        if diff <= 0:
            self._search(node.left, query_point, k, heap)
            # Search the other subtree if it might contain closer points
            if len(heap) < k or abs(diff) < -heap[0][0]:
                self._search(node.right, query_point, k, heap)
        else:
            self._search(node.right, query_point, k, heap)
            if len(heap) < k or abs(diff) < -heap[0][0]:
                self._search(node.left, query_point, k, heap)


# Example usage
def demo_ann_algorithms():
    # Generate sample data
    np.random.seed(42)
    num_points = 1000
    dimension = 10
    
    # Generate random points and labels
    points = [np.random.randn(dimension) for _ in range(num_points)]
    labels = [f"point_{i}" for i in range(num_points)]
    
    # Test LSH
    print("Testing LSH...")
    lsh = LSH(num_hash_tables=5, num_hash_functions=4, vector_dimension=dimension)
    for point, label in zip(points, labels):
        lsh.insert(point, label)
    
    query_point = np.random.randn(dimension)
    nearest_neighbors = lsh.query(query_point, k=5)
    print(f"LSH nearest neighbors to query point: {nearest_neighbors}\n")
    
    # Test KD-Tree
    print("Testing KD-Tree...")
    kdtree = KDTree(points, labels)
    nearest_neighbors = kdtree.query(query_point, k=5)
    print(f"KD-Tree nearest neighbors to query point: {nearest_neighbors}")

if __name__ == "__main__":
    demo_ann_algorithms()

Testing LSH...
LSH nearest neighbors to query point: [('point_869', 1.05852149715971), ('point_777', 1.7024391174604139), ('point_251', 1.9113638572205693), ('point_693', 2.1405210115580124), ('point_788', 2.153742934654818)]

Testing KD-Tree...
KD-Tree nearest neighbors to query point: [('point_113', 2.0106317102460323), ('point_251', 1.9113638572205693), ('point_393', 1.8499312365970935), ('point_777', 1.7024391174604139), ('point_869', 1.05852149715971)]


In [2]:
import numpy as np
import faiss
import scann
import time
from typing import Tuple, List
from tqdm import tqdm  # for progress bars

class FAISSIndexer:
    """
    Wrapper class for FAISS indexing and search
    """
    def __init__(self, dimension: int, index_type: str = 'flat'):
        self.dimension = dimension
        self.index_type = index_type
        self.index = self._create_index()
        
    def _create_index(self) -> faiss.Index:
        if self.index_type == 'flat':
            return faiss.IndexFlatL2(self.dimension)
        elif self.index_type == 'ivf':
            nlist = 100
            quantizer = faiss.IndexFlatL2(self.dimension)
            return faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
        elif self.index_type == 'hnsw':
            return faiss.IndexHNSWFlat(self.dimension, 32)
        else:
            raise ValueError(f"Unknown index type: {self.index_type}")

    def train(self, vectors: np.ndarray):
        if self.index_type == 'ivf':
            self.index.train(vectors)
    
    def add(self, vectors: np.ndarray):
        self.index.add(vectors.astype(np.float32))
    
    def search(self, query: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
        return self.index.search(query.astype(np.float32), k)

class ScaNNIndexer:
    """
    Wrapper class for ScaNN indexing and search with proper batch handling
    """
    def __init__(self, 
                 dimension: int,
                 num_leaves: int = 100,
                 num_leaves_to_search: int = 10,
                 training_sample_size: int = 10000):
        self.dimension = dimension
        self.num_leaves = num_leaves
        self.num_leaves_to_search = num_leaves_to_search
        self.training_sample_size = training_sample_size
        self.index = None
        
    def train(self, vectors: np.ndarray):
        """Build and train the ScaNN index"""
        self.index = (
            scann.scann_ops_pybind.builder(vectors, 10, "dot_product")
            .tree(
                num_leaves=self.num_leaves,
                num_leaves_to_search=self.num_leaves_to_search,
                training_sample_size=self.training_sample_size,
            )
            .score_ah(2, anisotropic_quantization_threshold=0.2)
            .reorder(100)
            .build()
        )
    
    def search_batch(self, 
                    queries: np.ndarray, 
                    k: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Search for k nearest neighbors for a batch of queries
        Handles queries one by one as ScaNN expects 1D queries
        """
        all_indices = []
        all_distances = []
        
        for query in queries:
            indices, distances = self.index.search(query, k)
            all_indices.append(indices)
            all_distances.append(distances)
            
        return (np.array(all_indices), np.array(all_distances))

def benchmark_ann_methods():
    """
    Benchmark different ANN methods with sample data
    """
    # Generate sample data
    num_vectors = 100000
    dimension = 128
    num_queries = 1000
    k = 10
    
    print(f"Generating {num_vectors} vectors of dimension {dimension}...")
    vectors = np.random.random((num_vectors, dimension)).astype(np.float32)
    queries = np.random.random((num_queries, dimension)).astype(np.float32)
    
    # Test FAISS with different index types
    faiss_types = ['flat', 'ivf', 'hnsw']
    for index_type in faiss_types:
        print(f"\nTesting FAISS with {index_type.upper()} index:")
        indexer = FAISSIndexer(dimension, index_type)
        
        # Train if needed
        if index_type in ['ivf']:
            start_time = time.time()
            indexer.train(vectors)
            print(f"Training time: {time.time() - start_time:.2f}s")
        
        # Add vectors
        start_time = time.time()
        indexer.add(vectors)
        print(f"Index building time: {time.time() - start_time:.2f}s")
        
        # Search
        start_time = time.time()
        distances, indices = indexer.search(queries, k)
        search_time = time.time() - start_time
        print(f"Search time for {num_queries} queries: {search_time:.2f}s")
        print(f"Average time per query: {(search_time/num_queries)*1000:.2f}ms")
        
    # Test ScaNN
    print("\nTesting ScaNN:")
    indexer = ScaNNIndexer(dimension)
    
    # Train and build index
    start_time = time.time()
    indexer.train(vectors)
    print(f"Index building time: {time.time() - start_time:.2f}s")
    
    # Search
    start_time = time.time()
    indices, distances = indexer.search_batch(queries, k)  # Using the new batch search method
    search_time = time.time() - start_time
    print(f"Search time for {num_queries} queries: {search_time:.2f}s")
    print(f"Average time per query: {(search_time/num_queries)*1000:.2f}ms")

def example_usage():
    """
    Example of how to use FAISS and ScaNN for a simple search task
    """
    # Generate sample data
    num_vectors = 10000
    dimension = 64
    vectors = np.random.random((num_vectors, dimension)).astype(np.float32)
    query = np.random.random((1, dimension)).astype(np.float32)
    
    print("\nSimple Example Usage:")
    
    # FAISS example
    print("\nFAISS (HNSW) Example:")
    faiss_indexer = FAISSIndexer(dimension, 'hnsw')
    faiss_indexer.add(vectors)
    distances, indices = faiss_indexer.search(query, k=5)
    print(f"Query results - indices: {indices[0]}")
    print(f"Distances: {distances[0]}")
    
    # ScaNN example
    print("\nScaNN Example:")
    scann_indexer = ScaNNIndexer(dimension)
    scann_indexer.train(vectors)
    # For single query, reshape to ensure it's 2D
    indices, distances = scann_indexer.search_batch(query, k=5)
    print(f"Query results - indices: {indices[0]}")
    print(f"Distances: {distances[0]}")

if __name__ == "__main__":
    print("Running ANN benchmarks...")
    benchmark_ann_methods()
    print("\nRunning example usage...")
    example_usage()

Running ANN benchmarks...
Generating 100000 vectors of dimension 128...

Testing FAISS with FLAT index:
Index building time: 0.02s
Search time for 1000 queries: 2.38s
Average time per query: 2.38ms

Testing FAISS with IVF index:
Training time: 0.71s
Index building time: 0.21s
Search time for 1000 queries: 0.02s
Average time per query: 0.02ms

Testing FAISS with HNSW index:
Index building time: 31.30s
Search time for 1000 queries: 0.07s
Average time per query: 0.07ms

Testing ScaNN:


2024-11-24 16:27:09.765365: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 10008
2024-11-24 16:27:09.799479: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:89] PartitionerFactory ran in 33.899161ms.


Index building time: 3.72s
Search time for 1000 queries: 0.12s
Average time per query: 0.12ms

Running example usage...

Simple Example Usage:

FAISS (HNSW) Example:
Query results - indices: [7958 6102  890 1396 8362]
Distances: [5.577335  5.834095  5.921546  5.9616146 6.050023 ]

ScaNN Example:


2024-11-24 16:27:14.058140: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 10000
2024-11-24 16:27:14.075021: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:89] PartitionerFactory ran in 16.790695ms.


Query results - indices: [4100 5471 7383 4148 3109]
Distances: [20.139353 19.811321 19.656528 19.62415  19.563314]
