Top-K retrieval module

In [1]:
# --- Imports for Top-K Retrieval ---
# heapq   → efficient min-heap for O(n log k) top-k retrieval
# typing  → clean type annotations (good for readability during interviews)

import heapq
from typing import List, Tuple


Dependency on Cosine Similarity

In [6]:
# --- Cosine Similarity (required dependency for Top-K) ---
import math
from typing import List

def cosine_similarity(a: List[float], b: List[float]) -> float:
    """
    Compute the cosine similarity between vectors a and b.

    Args:
        a: First vector, as a list of numbers.
        b: Second vector, as a list of numbers.

    Returns:
        dot(a, b) / (|a| * |b|). Returns 0.0 when either vector is empty
        or has zero norm, since the cosine is undefined in those cases.

    Raises:
        TypeError: if either argument is not a list.
        ValueError: if both vectors are non-empty but differ in length.
    """
    if not isinstance(a, list) or not isinstance(b, list):
        raise TypeError("Vectors must be lists.")

    if len(a) == 0 or len(b) == 0:
        return 0.0

    # zip() stops at the shorter vector, but the norms below cover every
    # component, so mismatched lengths would yield a meaningless score.
    if len(a) != len(b):
        raise ValueError(
            f"Vectors must have the same length (got {len(a)} and {len(b)})."
        )

    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))

    # Guard against division by zero for all-zero vectors.
    if norm_a == 0 or norm_b == 0:
        return 0.0

    return dot / (norm_a * norm_b)


Top-K Retrieval (sort + heap)

In [7]:
def top_k_sort(query: List[float],
               docs: List[Tuple[str, List[float]]],
               k: int = 3,
               similarity_fn=None) -> List[Tuple[float, str]]:
    """
    Sort-based Top-K retrieval.
    Complexity: O(n log n)

    Args:
        query: Query vector, passed as the first argument to similarity_fn.
        docs: (doc_id, vector) pairs to rank.
        k: Maximum number of results to return.
        similarity_fn: Callable (query, vector) -> score. Required.

    Returns:
        Up to k (score, doc_id) tuples, highest score first. Documents with
        equal scores keep their input order (the sort is stable). Returns
        [] when k <= 0.

    Raises:
        ValueError: if similarity_fn is None.
    """
    if similarity_fn is None:
        raise ValueError("similarity_fn must be provided")
    # A negative k would otherwise reach the slice below and return
    # "all but the last |k|" results instead of nothing.
    if k <= 0:
        return []

    scores = [(similarity_fn(query, vec), doc_id) for doc_id, vec in docs]

    # Sort on the score only so doc ids are never compared.
    scores.sort(key=lambda x: x[0], reverse=True)
    return scores[:k]


def top_k_heap(query: List[float],
               docs: List[Tuple[str, List[float]]],
               k: int = 3,
               similarity_fn=None) -> List[Tuple[float, str]]:
    """
    Heap-based Top-K retrieval.
    Complexity: O(n log k)

    Keeps a min-heap of at most k candidates; the root is always the
    weakest candidate kept so far and is evicted when a strictly better
    score arrives.

    Args:
        query: Query vector, passed as the first argument to similarity_fn.
        docs: (doc_id, vector) pairs to rank.
        k: Maximum number of results to return.
        similarity_fn: Callable (query, vector) -> score. Required.

    Returns:
        Up to k (score, doc_id) tuples, highest score first. Documents with
        equal scores are ranked by input order (earlier first), matching
        the result of a stable descending sort. Returns [] when k <= 0.

    Raises:
        ValueError: if similarity_fn is None.
    """
    if similarity_fn is None:
        raise ValueError("similarity_fn must be provided")
    if k <= 0:
        return []

    # Entries are (score, -position, doc_id). The negated position breaks
    # score ties so that the latest-seen document is the "smallest" and is
    # evicted first, and because positions are unique the doc_id itself is
    # never compared.
    heap: List[Tuple[float, int, str]] = []

    for position, (doc_id, vec) in enumerate(docs):
        score = similarity_fn(query, vec)
        entry = (score, -position, doc_id)

        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif score > heap[0][0]:
            # Strictly better than the weakest kept candidate: swap it in.
            heapq.heapreplace(heap, entry)

    # Sort final results descending by score, then by earliest position.
    heap.sort(key=lambda entry: (entry[0], entry[1]), reverse=True)
    return [(score, doc_id) for score, _, doc_id in heap]


In [8]:
# --- Unit Tests for Top-K Retrieval ---

def run_topk_tests():
    """
    Smoke-test top_k_sort and top_k_heap on a small hand-checked corpus.

    Prints a start banner and a success banner. Raises AssertionError at
    the first check that fails.
    """
    print("Running Top-K tests...")
    eps = 1e-6

    from math import isclose

    # basic test vectors
    query = [1.0, 0.0, 1.0]
    docs = [
        ("d1", [1.0, 0.0, 1.0]),
        ("d2", [0.0, 0.0, 0.0]),
        ("d3", [1.0, 0.0, 0.5]),
        ("d4", [0.5, 0.1, 0.2]),
    ]
    # Cosine against the query: d1 ~ 1.000, d3 ~ 0.949, d4 ~ 0.904,
    # d2 = 0.0 (zero vector), so the expected ranking is d1, d3, d4, d2.

    # --- Sort-based retrieval ---
    res_sort = top_k_sort(query, docs, k=2, similarity_fn=cosine_similarity)

    assert len(res_sort) == 2
    assert res_sort[0][1] == "d1"  # highest match
    assert res_sort[1][1] == "d3"  # runner-up
    assert isclose(res_sort[0][0], 1.0, abs_tol=eps)

    # --- Heap-based retrieval ---
    res_heap = top_k_heap(query, docs, k=2, similarity_fn=cosine_similarity)

    assert len(res_heap) == 2
    assert res_heap[0][1] == "d1"
    assert res_heap[1][1] == "d3"

    # Both strategies must agree on this tie-free corpus.
    assert res_heap == res_sort

    # --- k larger than the corpus returns every document ---
    assert len(top_k_sort(query, docs, k=10, similarity_fn=cosine_similarity)) == len(docs)
    assert len(top_k_heap(query, docs, k=10, similarity_fn=cosine_similarity)) == len(docs)

    # --- Non-positive k ---
    assert top_k_heap(query, docs, k=0, similarity_fn=cosine_similarity) == []
    assert top_k_sort(query, docs, k=0, similarity_fn=cosine_similarity) == []

    # --- Zero-vector / edge cases ---
    res_zero = top_k_sort(query, [("dz", [0, 0, 0])], k=1, similarity_fn=cosine_similarity)
    assert isclose(res_zero[0][0], 0.0, abs_tol=eps)

    print("All Top-K tests passed ✔️")


# Run the test suite when this cell executes; a failed check raises AssertionError here.
run_topk_tests()


Running Top-K tests...
All Top-K tests passed ✔️
