Cosine Similarity

In [1]:
# --- Imports ---
# math      → used for computing vector norms
# typing    → helps annotate vectors for clarity 

import math
from typing import List


In [2]:
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """
    Compute cosine similarity between vectors a and b.
    
    This implementation is intentionally:
    - Defensive: handles bad inputs and zero vectors
    - Minimal: easy to explain
    - Correct: matches standard cosine similarity definition

    Returns:
        float similarity value in [-1, 1]
    """

    # Validate inputs
    if not isinstance(a, list) or not isinstance(b, list):
        raise TypeError("Vectors must be lists.")

    if len(a) == 0 or len(b) == 0:
        # No comparison possible
        return 0.0

    # Compute dot product (zip handles mismatched lengths safely)
    dot = sum(x * y for x, y in zip(a, b))

    # Compute magnitudes
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))

    # Avoid divide-by-zero
    if norm_a == 0 or norm_b == 0:
        return 0.0

    # Return cosine similarity
    return dot / (norm_a * norm_b)


In [None]:
# --- Unit Tests for Cosine Similarity ---

def run_similarity_tests():
    print("Running tests...")

    # 1) Basic identical vectors
    assert cosine_similarity([1, 0, 1], [1, 0, 1]) == 1.0

    # 2) Orthogonal vectors → similarity = 0
    assert cosine_similarity([1, 0], [0, 1]) == 0.0

    # 3) Opposite vectors → similarity = -1
    assert cosine_similarity([1, 0], [-1, 0]) == -1.0

    # 4) Zero vector on one side
    assert cosine_similarity([0, 0, 0], [1, 2, 3]) == 0.0

    # 5) Empty vector
    assert cosine_similarity([], [1, 2, 3]) == 0.0

    # 6) Mismatched lengths (zip truncates) — valid for interviews
    sim = cosine_similarity([1, 2, 3], [1, 2])
    assert isinstance(sim, float)

    # 7) Typical near-match comparison
    sim2 = cosine_similarity([1, 1, 1], [0.9, 1.0, 1.1])
    assert sim2 > 0.9  # crude check

    print("All tests passed ✔️")

# Run tests
run_similarity_tests()


Running tests...


AssertionError: 