<a href="https://colab.research.google.com/github/nackerboss/SCANNN/blob/main/Scann_Benchmark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A. Program Initialization

This section installs the required libraries and defines the helper functions used by all of the following cells.

1. Installs the relevant libraries.
2. Removes TensorFlow.
3. Imports libraries and implements the common functions.

To start, run all the cells in this section. The cells in the later sections are for demos, benchmarking and searching, and those sections do not need to be run in a particular order.
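
Before moving on, it can help to confirm that the installed packages are importable in the current runtime. The following is a minimal sketch that uses only the standard library; the helper name `missing_packages` is hypothetical and not part of this notebook.

```python
import importlib.util

def missing_packages(module_names):
    """Return the module names that cannot be found in this environment."""
    return [name for name in module_names
            if importlib.util.find_spec(name) is None]

# Import names (not pip names) of the libraries this notebook relies on.
required = ["scann", "sentence_transformers", "datasets"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

If anything is reported as missing, rerun the install cell and restart the runtime before continuing.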

In [None]:
!pip install scann sentence-transformers datasets


Collecting scann
  Downloading scann-1.4.2-cp312-cp312-manylinux_2_27_x86_64.whl.metadata (5.8 kB)
Downloading scann-1.4.2-cp312-cp312-manylinux_2_27_x86_64.whl (11.6 MB)
Installing collected packages: scann
Successfully installed scann-1.4.2


In [None]:
!pip uninstall tensorflow

Found existing installation: tensorflow 2.19.0
Uninstalling tensorflow-2.19.0:
  Would remove:
    /usr/local/bin/import_pb_to_tensorboard
    /usr/local/bin/saved_model_cli
    /usr/local/bin/tensorboard
    /usr/local/bin/tf_upgrade_v2
    /usr/local/bin/tflite_convert
    /usr/local/bin/toco
    /usr/local/lib/python3.12/dist-packages/tensorflow-2.19.0.dist-info/*
    /usr/local/lib/python3.12/dist-packages/tensorflow/*
Proceed (Y/n)? Y
  Successfully uninstalled tensorflow-2.19.0


In [None]:
import torch
import numpy as np

try:
    import scann
    from sentence_transformers import SentenceTransformer
    from datasets import load_dataset # New import for public dataset
except ImportError:
    print("----------------------------------------------------------------------")
    print("🚨 ERROR: Please run the following command in a separate Colab cell ")
    print("and restart the runtime before running this code:")
    print("!pip install scann sentence-transformers datasets")
    print("----------------------------------------------------------------------")
    exit()

# --------------- Model used and device used (cpu/gpu/tpu) ----------------------

MODEL_NAME = 'all-MiniLM-L6-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

# --------------- Vector converter and normalization ----------------------------

def generate_and_normalize(data):
    """Generates embeddings and performs L2 normalization."""
    print(f"Generating embeddings for {len(data)} items...")

    # Generate embeddings (returns a numpy array)
    embeddings = embedding_model.encode(
        data,
        convert_to_tensor=False,
        show_progress_bar=True
    )

    # L2 Normalization (Needed for ScaNN dot product or angular similarity)

    print("Normalizing embeddings...")
    normalized_embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    return normalized_embeddings, embeddings.shape[1]

# --------------------- Runs one text query -------------------------------------

def run_query(query, search_index, original_dataset):
    """Embeds a query, normalizes it, and searches the ScaNN index."""
    print(f"\nSearching with query: '{query}'")

    # 6.1 Embed and Normalize the query
    query_embedding = embedding_model.encode([query])[0]
    normalized_query = query_embedding / np.linalg.norm(query_embedding)

    # 6.2 Perform the search
    # The 'k' parameter is configured during the builder step, so we omit it here.
    indices, distances = search_index.search(normalized_query)

    print(f"\nTop {len(indices)} results found:")
    for rank, (idx, distance) in enumerate(zip(indices, distances)):
        print(f"  Rank {rank+1}:")
        print(f"    Text: {original_dataset[int(idx)]}")
        # With L2-normalized vectors the dot product is 1.0 for identical directions and 0.0 for orthogonal ones
        print(f"    Similarity (Dot Product): {distance:.4f}")
        print(f"    Dataset Index: {idx}")

# --------------- Computes recall (correctness of the ScaNN result) -------------

def compute_recall(neighbors, true_neighbors):
    """
    Computes recall @k by comparing the results of the approximate search
    (neighbors) against the exact search (true_neighbors).
    """
    total = 0
    # Iterate through query results, comparing the approximate set against the true set
    for gt_row, row in zip(true_neighbors, neighbors):
        # Count the number of common elements (true positives)
        total += np.intersect1d(gt_row, row).shape[0]

    # Recall is (True Positives) / (Total True Neighbors)
    return total / true_neighbors.size

print("Importing libraries and implementing functions successful.")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Importing libraries and implementing functions successful.
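
To make the recall definition used by `compute_recall` concrete, here is a small worked example with made-up neighbor IDs. It is a plain-Python sketch under two assumptions: the helper `recall_at_k` is hypothetical (it mirrors `compute_recall` with sets instead of `np.intersect1d`), and every row holds unique IDs.

```python
def recall_at_k(approx_rows, exact_rows):
    """Fraction of exact neighbors that also appear in the approximate results."""
    hits = 0
    total = 0
    for approx, exact in zip(approx_rows, exact_rows):
        hits += len(set(approx) & set(exact))  # true positives for this query
        total += len(exact)                    # number of true neighbors
    return hits / total

# Two queries, k = 3. The IDs are invented for illustration only.
exact_rows = [[0, 1, 2], [3, 4, 5]]
approx_rows = [[0, 2, 9], [3, 4, 5]]   # the first query misses neighbor 1

recall = recall_at_k(approx_rows, exact_rows)
print(f"recall@3 = {recall:.3f}")
```

The first query recovers 2 of 3 true neighbors and the second recovers all 3, so 5 of the 6 true neighbors are found in total.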


# B. Demonstration of ScaNN Embedding Search Index

This section demonstrates the general workings of ScaNN by building a search index over 5,000 news headlines from the ag_news dataset and querying it with a few predetermined texts.
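
As a reference point for what the index approximates, the sketch below performs an exact top-k search by dot product over L2-normalized vectors. It is written in plain Python on toy 2-D vectors so it runs without the notebook's dependencies; the helpers `l2_normalize`, `dot` and `top_k` are hypothetical names, and the vectors are invented for illustration.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def top_k(query, vectors, k):
    """Exact search: score every vector, then keep the k highest scores."""
    q = l2_normalize(query)
    scored = [(dot(q, l2_normalize(v)), i) for i, v in enumerate(vectors)]
    scored.sort(reverse=True)
    return scored[:k]

toy_vectors = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0], [-1.0, 0.0]]
results = top_k([2.0, 0.1], toy_vectors, k=2)
for score, index in results:
    print(f"index={index} similarity={score:.4f}")
```

Because every vector has unit length after normalization, the dot product equals the cosine similarity, which is why the notebook normalizes embeddings before building a `"dot_product"` index. An approximate index trades a small amount of this exactness for speed by scoring only part of the dataset.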

In [None]:

# Load a public dataset (ag_news) and take a manageable subset for demonstration
print("Loading public dataset (ag_news) subset...")
try:
    # Load the training split (used for building the ScaNN index)
    ag_news_dataset_train = load_dataset('ag_news', split='train[:5000]')
    dataset = ag_news_dataset_train['text']

    # Load the test split (used for generating test queries for recall calculation)
    ag_news_dataset_test = load_dataset('ag_news', split='test[:300]')
    test_dataset_text = ag_news_dataset_test['text']

except Exception as e:
    print(f"Error loading ag_news dataset: {e}")
    # Fallback to the original small dataset if loading fails
    dataset = [
        "The sun rises in the east every morning.",
        "A computer uses a central processing unit for core tasks.",
        "Cats and dogs are common household pets.",
        "A feline companion enjoying a nap on the sofa.",
        "The central processing unit is the brain of any modern machine.",
        "Tomorrow's forecast predicts clear skies and warm weather."
    ]
    test_dataset_text = dataset # Use the same small data for queries if primary fails


# The queries we will use to search the dataset
query_text_1 = "The main component of a PC is the CPU."
query_text_2 = "What is the weather like at dawn?"
query_text_3 = "Football match results from the weekend."



normalized_dataset_embeddings, embedding_dim = generate_and_normalize(dataset)

normalized_test_embeddings, _ = generate_and_normalize(test_dataset_text)

print(f"\nDataset Ready. Shape: {normalized_dataset_embeddings.shape}")
print(f"Test Query Set Shape: {normalized_test_embeddings.shape}")
print(f"First dataset entry (Index Training Data): {dataset[0]}")


# --- 4. Building the ScaNN Index (5,000 vectors, 4,000 used for training) ---

print("\n--- 4. Building ScaNN Optimized Searcher (trained on 4000 of 5000 examples) ---")

# The maximum number of neighbors to retrieve (top-k)
K_NEIGHBORS = 5
REORDER_NEIGHBORS = 100 # Reduced reorder candidates for speedier demo

# 4.1. Initialize the ScaNN builder
# Arguments: (dataset, k, distance_metric)
builder = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
)

# 4.2. Configure the Tree (Partitioning) stage
tree_configured = builder.tree(
    num_leaves=500,
    num_leaves_to_search=50,
    training_sample_size=4000
)

# 4.3. Configure Asymmetric Hashing (AH) for scoring
ah_configured = tree_configured.score_ah(
    2, # Number of dimensions per subvector
    anisotropic_quantization_threshold=0.2
)

# 4.4. Configure the Reordering (Refinement) stage
reorder_configured = ah_configured.reorder(REORDER_NEIGHBORS)

# 4.5. Finalize and build the searcher
searcher = reorder_configured.build()

print("ScaNN optimized index built successfully.")




In [None]:
# --- Computing Recall (ScaNN vs. Brute Force) ---

print("\n--- 5. Computing Recall (ScaNN vs. Brute Force) ---")

# Create a Brute-Force ScaNN searcher (no tree, no quantization)
# This will find the mathematically exact nearest neighbors.
bruteforce_searcher = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
).score_brute_force().build()

# 5.2. Define Test Queries (using a subset of the official test split as queries)
# Limit the number of test queries for faster recall computation
MAX_TEST_QUERIES = 500
NUM_RECALL_QUERIES = min(MAX_TEST_QUERIES, len(normalized_test_embeddings))

# Use the dedicated test set embeddings for recall calculation
recall_test_queries = normalized_test_embeddings[:NUM_RECALL_QUERIES]

print(f"1. Running Brute-Force search on {NUM_RECALL_QUERIES} test queries...")
# .search_batched() is much faster for multiple queries
true_neighbors, _ = bruteforce_searcher.search_batched(recall_test_queries)

print("2. Running Optimized ScaNN search...")
scann_neighbors, _ = searcher.search_batched(recall_test_queries)

# 5.3. Calculate and Print Recall
recall_value = compute_recall(scann_neighbors, true_neighbors)
print(f"\n✅ Recall @{K_NEIGHBORS} for {NUM_RECALL_QUERIES} queries from the TEST split: {recall_value * 100:.2f}%")
print("This value indicates the percentage of exact nearest neighbors found by the approximate searcher.")

# Run Query 1: Find sentences about computers
run_query(query_text_1, searcher, dataset)

# Run Query 2: Find sentences about weather/time
run_query(query_text_2, searcher, dataset)

# Run Query 3: Find relevant news articles
run_query(query_text_3, searcher, dataset)

# C. Benchmarking Section

This section benchmarks ScaNN's built-in brute-force search against its approximate search on a larger scale.
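
The comparison comes down to timing two search strategies on the same queries and reporting the ratio. The sketch below shows that pattern with `time.perf_counter` on stand-in workloads (a full list scan versus a set lookup); these stand-ins are not nearest-neighbor searches, and the helper names are hypothetical. Taking the best of several runs is an assumption made here to reduce noise, whereas the benchmark cell in this section times each search once.

```python
import time

def best_time(fn, *args, repeats=3):
    """Run fn(*args) several times and return the fastest wall-clock time in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

def exhaustive_lookup(items, targets):
    """Stand-in for the exact search: scans the whole list for every target."""
    return [t in items for t in targets]

def indexed_lookup(index, targets):
    """Stand-in for the indexed search: uses a prebuilt set."""
    return [t in index for t in targets]

items = list(range(20_000))
index = set(items)
targets = list(range(0, 20_000, 100))  # 200 lookups

slow = best_time(exhaustive_lookup, items, targets)
fast = best_time(indexed_lookup, index, targets)
speedup = slow / fast if fast > 0 else float("inf")
print(f"exhaustive: {slow:.6f}s  indexed: {fast:.6f}s  speedup: {speedup:.1f}x")
```

Timings vary between runs and machines, so the numbers are only meaningful relative to each other on the same runtime.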

The first cell pre-generates embeddings for a subset (or all) of the ag_news headlines so that later cells can reuse the vectors instead of regenerating them.

In [None]:
# ------ development script used for generating the embeddings to be saved ------

print("This is a Colab script to load the dataset, embed, then save it into agnews_embeddings.h5.")

num_headlines = int(input("Enter the number of news headlines to convert, normalize, and be saved to the embeddings file."))

if (num_headlines > 120000) or (num_headlines < 1):
  print('Invalid input. num_headlines is set back to 5000.')
  num_headlines = 5000

num_tests = int(input("Enter the number of text queries to be converted, normalized, and be saved."))

if (num_tests > 1200) or (num_tests < 1):
  print('Invalid input. num_tests is set back to 1000.')
  num_tests = 1000

# Reload the embedding model (same model and device selection as in section A)

MODEL_NAME = 'all-MiniLM-L6-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

print("Loading public dataset (ag_news) subset...")
try:
    ag_news_dataset_train = load_dataset('ag_news', split=f'train[:{num_headlines}]') # Loads the first num_headlines training entries
    dataset_train = ag_news_dataset_train['text']

    ag_news_dataset_test = load_dataset('ag_news', split=f'test[:{num_tests}]') # Loads the first num_tests test entries
    dataset_test = ag_news_dataset_test['text']

except Exception as e:
    print(f"Error loading ag_news dataset. {e}. Since this is used for benchmarking, the hardcoded dataset is ignored and the program is cancelled.")
    exit()

normalized_dataset_train_embeddings, embedding_dim = generate_and_normalize(dataset_train)
normalized_dataset_test_embeddings, embedding_dim = generate_and_normalize(dataset_test)

print(f"\nDataset generated: {normalized_dataset_train_embeddings.shape}")
print(f"First dataset training entry (Index Training Data): {dataset_train[0]}")

# This script has finished generating and normalizing. The next cell saves them.



This is a Colab script to load the dataset, embed, then save it into agnews_embeddings.h5.
Enter the number of news headlines to convert, normalize, and be saved to the embeddings file.5000
Enter the number of text queries to be converted, normalized, and be saved.300
Loading public dataset (ag_news) subset...
Generating embeddings for 5000 items...


Normalizing embeddings...
Generating embeddings for 300 items...


Normalizing embeddings...

Dataset generated: (5000, 384)
First dataset training entry (Index Training Data): Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


In [None]:
import h5py

filename = 'agnews_embeddings.h5'
# Names of the HDF5 datasets. They are kept separate from the dataset_train and
# dataset_test text lists created in the previous cell so those are not overwritten.
train_key = 'train'
test_key = 'test'

# Save the embeddings to the H5 file
try:
    with h5py.File(filename, 'w') as f:
        # Create one HDF5 dataset per embedding array
        f.create_dataset(train_key, data=normalized_dataset_train_embeddings)
        f.create_dataset(test_key, data=normalized_dataset_test_embeddings)

    print(f"Embeddings successfully saved to {filename} under the data train set '{train_key}' and data test set '{test_key}'.")

except Exception as e:
    print(f"An error occurred: {e}")


Embeddings successfully saved to agnews_embeddings.h5 under the data train set 'train' and data test set 'test'.


The benchmark uses a pre-generated embedding dataset stored in a file. Run the next two cells first to load the vectors into variables.

The first cell downloads an external test file for benchmarking.

To benchmark with any other .hdf5 (.h5) file, modify the variables in the second cell. The default is GloVe with 200 dimensions, about 1.18 million training vectors and 10,000 test vectors, available [here](https://github.com/erikbern/ann-benchmarks?tab=readme-ov-file). Because this file stores its data under the 'train' and 'test' objects, these names are hardcoded in the cell. The other dataset files listed at that link follow the same convention.

In [None]:
!wget http://ann-benchmarks.com/glove-200-angular.hdf5

In [None]:
# To load the embeddings back
import h5py

filename = 'glove-200-angular.hdf5'
dataset_train = 'train'
dataset_test = 'test'

try:
    with h5py.File(filename, 'r') as f:
        # Access the dataset
        loaded_train_embeddings_raw = f[dataset_train][:]
        loaded_train_embeddings = loaded_train_embeddings_raw / np.linalg.norm(loaded_train_embeddings_raw, axis=1, keepdims=True)
        loaded_test_embeddings_raw = f[dataset_test][:]
        loaded_test_embeddings = loaded_test_embeddings_raw / np.linalg.norm(loaded_test_embeddings_raw, axis=1, keepdims=True)

        print("\nEmbeddings loaded successfully.")
        print("Train shape:", loaded_train_embeddings.shape)
        print("Test shape:", loaded_test_embeddings.shape)

 #       print("Metadata description:", f[dataset_name].attrs['description'])

except Exception as e:
    print(f"An error occurred during loading: {e}")


Embeddings loaded successfully.
Train shape: (1183514, 200)
Test shape: (10000, 200)
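
The shapes printed above are enough to estimate how much memory the raw vectors occupy. The sketch below does that arithmetic in plain Python; the helper `raw_size_mb` is hypothetical, and 4 bytes per value (float32) is an assumption about the stored dtype rather than something the loading cell prints.

```python
def raw_size_mb(num_vectors, dim, bytes_per_value=4):
    """Size of a dense array in MB, with 1 MB = 1024 * 1024 bytes."""
    return num_vectors * dim * bytes_per_value / (1024 * 1024)

# Shapes taken from the loading cell output; float32 (4 bytes) is an assumption.
train_mb = raw_size_mb(1_183_514, 200)
test_mb = raw_size_mb(10_000, 200)
print(f"train: {train_mb:.2f} MB, test: {test_mb:.2f} MB")
```

Under that assumption the training set comes to about 902.95 MB, which agrees with the "Raw embedding memory usage" figure reported by the benchmark output further down.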


The next cell benchmarks the runtime of the ScaNN algorithm against the built-in brute-force algorithm. K_NEIGHBORS and TEST_SIZES are hardcoded; TEST_SIZES is a list of query counts, so the benchmark runs once per count and makes trends visible.
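
Besides these constants, the benchmark cell derives the tree (partitioning) parameters from the dataset size. The sketch below reproduces that arithmetic on its own so the resulting values can be inspected without building an index; the function name `tree_parameters` is hypothetical, and it uses `math.sqrt` in place of `np.sqrt` to stay dependency-free.

```python
import math

def tree_parameters(num_vectors):
    """Mirror the arithmetic used in sub_benchmark for the partitioning stage."""
    num_leaves = max(int(math.sqrt(num_vectors) * 4), 100)
    num_leaves_to_search = max(int(num_leaves * 0.05), 10)
    training_sample_size = min(int(num_vectors * 0.3), num_vectors - 1)
    return num_leaves, num_leaves_to_search, training_sample_size

# 5,000 and 120,000 are ag_news subset sizes accepted earlier in this notebook;
# 1,183,514 is the GloVe training set size printed by the loading cell.
for n in (5_000, 120_000, 1_183_514):
    leaves, to_search, sample = tree_parameters(n)
    print(f"n={n:>9,}: num_leaves={leaves}, "
          f"num_leaves_to_search={to_search}, training_sample_size={sample}")
```

Raising `num_leaves_to_search` generally improves recall at the cost of search time, so the 5% ratio is a starting point to tune rather than a fixed rule.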

In [None]:
import os
import sys
import time
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
import scann

K_NEIGHBORS = 10

TEST_SIZES = [100, 300, 1000, 3000, 10000]

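def sizeof(obj):
    """Rough estimate of Python in-memory object size (MB)."""
    try:
        # NumPy arrays report their data buffer size via .nbytes; this also
        # works for views, where sys.getsizeof only counts the array header.
        size_bytes = obj.nbytes if hasattr(obj, "nbytes") else sys.getsizeof(obj)
        return size_bytes / (1024 * 1024)
    except Exception:
        return -1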


def save_and_measure_index(searcher):
    """Serialize ScaNN index into a temporary directory and measure size."""
    import tempfile
    import shutil

    temp_dir = tempfile.mkdtemp()   # <-- directory, not file

    try:
        searcher.serialize(temp_dir)

        # compute total directory size
        total_size = 0
        for root, dirs, files in os.walk(temp_dir):
            for f in files:
                fp = os.path.join(root, f)
                total_size += os.path.getsize(fp)

        return total_size / (1024 * 1024)  # MB

    finally:
        # clean up directory
        shutil.rmtree(temp_dir)



def benchmark_instance(num_queries, dataset_embeddings, searcher, test_embeddings):

  print(f"\n=== Benchmarking with {num_queries} Queries ===")

  if num_queries > test_embeddings.shape[0]:
    raise ValueError(f"Not enough queries in the test embeddings! {num_queries} > {test_embeddings.shape[0]}")

  query_batch = test_embeddings[:num_queries]

  brute = scann.scann_ops_pybind.builder(
      dataset_embeddings,
      K_NEIGHBORS,
      "dot_product"
  ).score_brute_force().build()

  # ---- Brute force ----
  bf_start = time.perf_counter()
  true_neighbors, _ = brute.search_batched(query_batch)
  bf_end = time.perf_counter()
  bf_time = bf_end - bf_start

  # ---- ScaNN ----
  sc_start = time.perf_counter()
  scann_neighbors, _ = searcher.search_batched(query_batch)
  sc_end = time.perf_counter()
  sc_time = sc_end - sc_start

  # ---- Recall ----
  recall_value = compute_recall(scann_neighbors, true_neighbors)

  return {
    "num_queries": num_queries,
    "recall@{}".format(K_NEIGHBORS): recall_value,
    "brute_force_time_sec": bf_time,
    "scann_time_sec": sc_time,
    "speedup": bf_time / sc_time if sc_time > 0 else np.inf
  }

def sub_benchmark():


  dataset_embeddings = loaded_train_embeddings
  num_headlines = loaded_train_embeddings.shape[0]

  print(f"\nDataset size: {num_headlines}")

  print("Loaded dataset embeddings.")
  raw_embedding_size_mb = sizeof(dataset_embeddings)
  print(f"Raw embedding memory usage: {raw_embedding_size_mb:.2f} MB")

  # ===================================
  # Build ScaNN index
  # ===================================
  print("Building ScaNN index...")

  num_leaves = max(int(np.sqrt(num_headlines) * 4), 100)  # heuristic: about 4 * sqrt(dataset size), at least 100
  num_leaves_to_search = max(int(num_leaves * 0.05), 10)  # 5% of the leaves, at least 10; tune to reach the desired recall
  training_sample_size = min(int(num_headlines * 0.3), num_headlines - 1)  # ~30% of dataset

  searcher = scann.scann_ops_pybind.builder(
      dataset_embeddings,
      K_NEIGHBORS,
      "dot_product"
  ).tree(
      num_leaves=num_leaves,
      num_leaves_to_search=num_leaves_to_search,
      training_sample_size=training_sample_size
  ).score_ah(
      dimensions_per_block=2,
      anisotropic_quantization_threshold=0.2
  ).reorder(K_NEIGHBORS * 10).build()

  # Measure index size
  scann_index_size_mb = save_and_measure_index(searcher)
  print(f"ScaNN index size: {scann_index_size_mb:.2f} MB")
  compression_ratio = raw_embedding_size_mb / scann_index_size_mb

  # ===================================
  # Run benchmarks
  # ===================================
  results = []

  for size in TEST_SIZES:
      stats = benchmark_instance(size, dataset_embeddings, searcher, loaded_test_embeddings)
      results.append(stats)

  # Convert to table
  df = pd.DataFrame(results)

  print("\n=== Index Statistics ===")
  print(f"Raw embedding memory usage: {raw_embedding_size_mb:.2f} MB")
  print(f"ScaNN index size: {scann_index_size_mb:.2f} MB")
  print(f"Compression ratio: {compression_ratio:.3f}x")

  print("\n=== Final Benchmark Table (per query count only) ===")
  print(df)

  df.to_csv("benchmark_results.csv", index=False)
  print("Saved CSV -> benchmark_results.csv")

  # ===================================
  # Plotting
  # ===================================

  # ---- Search Time Plot ----
  plt.figure()
  plt.plot(df["num_queries"], df["brute_force_time_sec"], label="Brute-force", marker='o')
  plt.plot(df["num_queries"], df["scann_time_sec"], label="ScaNN", marker='s')
  plt.xlabel("num. queries")
  plt.ylabel("Time (seconds)")
  plt.title("Searching time comparison")
  plt.legend()
  plt.grid(True)
  plt.savefig("time_comparison_plot.png")
  print("Saved plot -> time_comparison_plot.png")

  return df

sub_benchmark()



Dataset size: 1183514
Loaded dataset embeddings.
Raw embedding memory usage: 902.95 MB
Building ScaNN index...
ScaNN index size: 1027.08 MB

=== Benchmarking with 100 Queries ===

=== Benchmarking with 300 Queries ===

=== Benchmarking with 1000 Queries ===

=== Benchmarking with 3000 Queries ===

=== Benchmarking with 10000 Queries ===
