<a href="https://colab.research.google.com/github/nackerboss/SCANN/blob/main/ScaNN_Embedding_Search_Index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A. Program Initialization

This section installs the required libraries and defines the helper functions used by all of the following cells.

1. Installs the relevant libraries.
2. Removes TensorFlow
3. Imports libraries and implements the common functions.

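To start, run all the cells in this section. The later sections are for demos, benchmarking and searching; they do not need to be run in any particular order relative to each other, but the cells within a section depend on one another and should be run from top to bottom.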

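Note: you should use a GPU runtime type! The embedding model runs on CUDA when it is available and otherwise falls back to the CPU, so a TPU runtime gives no speed-up for the embedding step.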

In [None]:
!pip install scann sentence-transformers datasets


In [None]:
!pip uninstall --yes tensorflow

In [None]:
import torch
import numpy as np

try:
    import scann
    from sentence_transformers import SentenceTransformer
    from datasets import load_dataset  # Loader for the public demo dataset (ag_news)
except ImportError:
    print("----------------------------------------------------------------------")
    print("🚨 ERROR: Please run the following command in a separate Colab cell ")
    print("and rerun this cell before running this code:")
    print("!pip install scann sentence-transformers datasets")
    print("----------------------------------------------------------------------")
    # Re-raise instead of calling exit(): inside a notebook kernel exit() may not
    # stop the rest of this cell, which would then fail later with a confusing
    # NameError on the missing packages.
    raise

# --------------- Model used and device used (cpu/gpu/tpu) ----------------------

MODEL_NAME = 'all-MiniLM-L6-v2'

# Prefer the CUDA GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Shared sentence encoder used by the embedding and query helpers below.
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

# --------------- Vector converter and normalization ----------------------------

def l2_normalize_rows(matrix):
    """Return *matrix* with every row scaled to unit L2 norm.

    Rows whose norm is zero are returned as all-zero rows instead of the
    NaNs a plain ``matrix / norm`` division would produce.
    """
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # norms is a fresh array, so patching it in place is safe; dividing a zero
    # row by 1 keeps it zero (and keeps the input's float dtype).
    norms[norms == 0] = 1
    return matrix / norms


def generate_and_normalize(data):
    """Generate embeddings for *data* and L2-normalize them.

    Args:
        data: Non-empty sequence of texts to embed with ``embedding_model``.

    Returns:
        Tuple ``(normalized_embeddings, embedding_dim)``: the row-normalized
        embedding matrix and its number of columns.

    Raises:
        ValueError: If *data* is empty (no embedding matrix can be built).
    """
    if len(data) == 0:
        raise ValueError("generate_and_normalize() needs at least one item to embed.")

    print(f"Generating embeddings for {len(data)} items...")

    # Generate embeddings (returns a numpy array)
    embeddings = embedding_model.encode(
        data,
        convert_to_tensor=False,
        show_progress_bar=True
    )

    # Unit-length rows make ScaNN's "dot_product" score equal to cosine similarity.
    print("Normalizing embeddings...")
    normalized_embeddings = l2_normalize_rows(embeddings)

    return normalized_embeddings, embeddings.shape[1]

# --------------------- Runs one text query -------------------------------------

def run_query(query, search_index, original_dataset):
    """Embed *query*, L2-normalize it, and print the top hits from *search_index*.

    Args:
        query: Free-text query string.
        search_index: A built ScaNN searcher. No ``k`` is passed at search time,
            so the number of results is the one configured in the builder.
        original_dataset: Indexable collection of the texts the index was built
            from; used to print the matched text for each result.
    """
    print(f"\nSearching with query: '{query}'")

    # Embed and normalize the query the same way as the indexed data.
    query_embedding = embedding_model.encode([query])[0]
    normalized_query = query_embedding / np.linalg.norm(query_embedding)

    # The 'k' parameter is configured during the builder step, so we omit it here.
    indices, distances = search_index.search(normalized_query)

    print(f"\nTop {len(indices)} results found:")
    for rank, (idx, distance) in enumerate(zip(indices, distances), start=1):
        dataset_index = int(idx)
        print(f"  Rank {rank}:")
        print(f"    Text: {original_dataset[dataset_index]}")
        # Dot product distance is 1.0 for perfect match, 0.0 for orthogonal
        print(f"    Similarity (Dot Product): {distance:.4f}")
        print(f"    Dataset Index: {dataset_index}")

# --------------- Computes recall (correctness of the ScaNN result) -------------

def compute_recall(neighbors, true_neighbors):
    """Compute recall@k of an approximate search against exact ground truth.

    Args:
        neighbors: Neighbor ids returned by the approximate search, one row per
            query (array-like; a single 1-D row is treated as one query).
        true_neighbors: Neighbor ids returned by the exact search, same layout.

    Returns:
        Fraction (0.0 to 1.0) of the exact neighbors that also appear in the
        approximate result for the same query.

    Raises:
        ValueError: If the two inputs hold a different number of queries, or if
            there are no ground-truth neighbors (recall would be undefined).
    """
    approx = np.atleast_2d(np.asarray(neighbors))
    exact = np.atleast_2d(np.asarray(true_neighbors))

    # zip() would silently drop the extra rows, so reject mismatched inputs.
    if approx.shape[0] != exact.shape[0]:
        raise ValueError(
            f"Query count mismatch: {approx.shape[0]} approximate rows vs "
            f"{exact.shape[0]} exact rows."
        )
    if exact.size == 0:
        raise ValueError("true_neighbors is empty; recall is undefined.")

    # True positives: ids found by both searches for the same query.
    hits = sum(np.intersect1d(gt_row, row).size for gt_row, row in zip(exact, approx))

    # Recall is (True Positives) / (Total True Neighbors)
    return hits / exact.size

# Confirms that the setup cell ran to the end.
print("Importing libraries and implementing functions successful.", end="\n")

# B. Demonstration of ScaNN Embedding Search Index

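This section demonstrates how ScaNN works. It builds a search index over a 2,000-item subset of the AG News dataset and queries it with a few predefined texts. It then measures recall against an exact brute-force search and finally offers an interactive prompt for arbitrary queries.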

In [None]:

# Pull a small slice of the public ag_news corpus; fall back to a tiny
# hand-written corpus if the download is not possible.
print("Loading public dataset (ag_news) subset...")
try:
    # Index corpus: first 2,000 training headlines.
    ag_news_dataset_train = load_dataset('ag_news', split='train[:2000]')
    dataset = ag_news_dataset_train['text']

    # Held-out queries for the recall check: first 300 test headlines.
    ag_news_dataset_test = load_dataset('ag_news', split='test[:300]')
    test_dataset_text = ag_news_dataset_test['text']
except Exception as e:
    print(f"Error loading ag_news dataset: {e}")
    # Offline fallback corpus. It also serves as its own query set below.
    dataset = [
        "The sun rises in the east every morning.",
        "A computer uses a central processing unit for core tasks.",
        "Cats and dogs are common household pets.",
        "A feline companion enjoying a nap on the sofa.",
        "The central processing unit is the brain of any modern machine.",
        "Tomorrow's forecast predicts clear skies and warm weather."
    ]
    test_dataset_text = dataset


# Hand-picked demo queries: computers, weather/time, and sports news.
query_text_1, query_text_2, query_text_3 = (
    "The main component of a PC is the CPU.",
    "What is the weather like at dawn?",
    "Football match results from the weekend.",
)

# Embed and normalize the index corpus (also records the vector size) and the
# held-out query set (its dimension is identical, so it is discarded).
normalized_dataset_embeddings, embedding_dim = generate_and_normalize(dataset)
normalized_test_embeddings, _ = generate_and_normalize(test_dataset_text)

# Quick sanity report on what was produced.
print(f"\nDataset Ready. Shape: {normalized_dataset_embeddings.shape}")
print(f"Test Query Set Shape: {normalized_test_embeddings.shape}")
print(f"First dataset entry (Index Training Data): {dataset[0]}")


# --- 4. Building the ScaNN Index (Optimized for 2000 vectors) ---

num_dataset_points = normalized_dataset_embeddings.shape[0]

print(f"\n--- 4. Building ScaNN Optimized Searcher (Trained on {num_dataset_points} examples) ---")

# The maximum number of neighbors to retrieve (top-k)
K_NEIGHBORS = 5
REORDER_NEIGHBORS = 100 # Reduced reorder candidates for speedier demo

# Tree sizes were tuned for the 2,000-headline subset (500 leaves, search 50).
# Cap them by the dataset size so the partitioner is never asked for more leaves
# than there are vectors (e.g. the 6-sentence fallback corpus); for 2,000
# vectors these still evaluate to 500 and 50.
NUM_LEAVES = max(1, min(500, num_dataset_points // 4))
NUM_LEAVES_TO_SEARCH = min(50, NUM_LEAVES)
# NOTE(review): score_ah below may still need more vectors than the tiny
# fallback corpus provides; confirm that path builds if it matters.

# 4.1. Initialize the ScaNN builder
# Arguments: (dataset, k, distance_metric)
builder = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
)

# 4.2. Configure the Tree (Partitioning) stage
tree_configured = builder.tree(
    num_leaves=NUM_LEAVES,
    num_leaves_to_search=NUM_LEAVES_TO_SEARCH,
    training_sample_size=4000
)

# 4.3. Configure Asymmetric Hashing (AH) for scoring
ah_configured = tree_configured.score_ah(
    2, # Number of dimensions per subvector
    anisotropic_quantization_threshold=0.2
)

# 4.4. Configure the Reordering (Refinement) stage
reorder_configured = ah_configured.reorder(REORDER_NEIGHBORS)

# 4.5. Finalize and build the searcher
searcher = reorder_configured.build()

print("ScaNN optimized index built successfully.")




In [None]:
# --- Computing Recall (ScaNN vs. Brute Force) ---

print("\n--- 5. Computing Recall (ScaNN vs. Brute Force) ---")

# 5.1. Create a Brute-Force ScaNN searcher (no tree, no quantization)
# This will find the mathematically exact nearest neighbors.
bruteforce_searcher = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
).score_brute_force().build()

# 5.2. Define Test Queries (using a subset of the official test split as queries)
# Limit the number of test queries for faster recall computation
MAX_TEST_QUERIES = 500
NUM_RECALL_QUERIES = min(MAX_TEST_QUERIES, len(normalized_test_embeddings))

# Use the dedicated test set embeddings for recall calculation
recall_test_queries = normalized_test_embeddings[:NUM_RECALL_QUERIES]

print(f"1. Running Brute-Force search on {NUM_RECALL_QUERIES} test queries...")
# .search_batched() is much faster for multiple queries
true_neighbors, _ = bruteforce_searcher.search_batched(recall_test_queries)

print("2. Running Optimized ScaNN search...")
scann_neighbors, _ = searcher.search_batched(recall_test_queries)

# 5.3. Calculate and Print Recall
recall_value = compute_recall(scann_neighbors, true_neighbors)
print(f"\n✅ Recall @{K_NEIGHBORS} for {NUM_RECALL_QUERIES} queries from the TEST split: {recall_value * 100:.2f}%")
print("This value indicates the percentage of exact nearest neighbors found by the approximate searcher.")

# Demo searches against the optimized index, one per predefined query:
# computers, weather/time of day, and sports news.
run_query(query_text_1, search_index=searcher, original_dataset=dataset)
run_query(query_text_2, search_index=searcher, original_dataset=dataset)
run_query(query_text_3, search_index=searcher, original_dataset=dataset)

In [None]:
def run_arbitrary_query(query, search_index, original_dataset, k_neighbors):
    """Embed *query*, normalize it, and print its top ``k_neighbors`` matches.

    The result count is chosen per call, and the search-time ScaNN settings
    are scaled with it: ``leaves_to_search = 20 * k_neighbors`` and
    ``pre_reorder_num_neighbors = 5 * k_neighbors``.

    Args:
        query: Free-text query string.
        search_index: A built ScaNN searcher.
        original_dataset: Indexable collection of the texts the index was built
            from; used to print the matched text for each result.
        k_neighbors: Number of results to return; must be a positive integer.

    Raises:
        ValueError: If *k_neighbors* is not positive.
    """
    # Validate before any work: a non-positive k would turn every search knob to 0.
    if k_neighbors <= 0:
        raise ValueError(f"k_neighbors must be a positive integer, got {k_neighbors}.")

    print(f"\nSearching with query: '{query}'")

    query_embedding = embedding_model.encode([query])[0]
    normalized_query = query_embedding / np.linalg.norm(query_embedding)

    indices, distances = search_index.search(
        normalized_query,
        leaves_to_search=k_neighbors * 20,
        final_num_neighbors=k_neighbors,
        pre_reorder_num_neighbors=k_neighbors * 5,
    )

    print(f"\nTop {len(indices)} results found:")
    for rank, (idx, distance) in enumerate(zip(indices, distances), start=1):
        dataset_index = int(idx)
        print(f"  Rank {rank}:")
        print(f"    Text: {original_dataset[dataset_index]}")
        # Dot product distance is 1.0 for perfect match, 0.0 for orthogonal
        print(f"    Similarity (Dot Product): {distance:.4f}")
        print(f"    Dataset Index: {dataset_index}")

print("-------------- Arbitrary Query ---------------")

# Interactive loop: keep answering queries until the user types 'quit' or 0.
while True:
    query_text = input("Enter a query... ('quit' to exit) ")
    if query_text.strip() == "quit":
        break

    k_text = input("Enter how many neighbors are needed (0 to exit) ")
    # input() always returns a string, so parse before comparing with 0.
    try:
        k = int(k_text)
    except ValueError:
        print(f"'{k_text}' is not a whole number; please try again.")
        continue

    if k == 0:
        break
    if k < 0:
        print("The number of neighbors must be positive; please try again.")
        continue

    # Reuse the index built earlier: the per-query knobs are passed at search
    # time, so rebuilding (retraining) the index on every iteration is wasted work.
    run_arbitrary_query(query_text, searcher, dataset, k)

    print("\n\n-------------------------------------------\n\n")