<a href="https://colab.research.google.com/github/nackerboss/SCANN/blob/main/ScaNN_Embedding_Search_Index_Demo_and_Benchmarking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A. Program Initialization

This section installs the libraries, declares and implements the functions needed to run all of the following cells.

1. Installs the relevant libraries.
2. Removes TensorFlow
3. Imports libraries and implements the common functions.

To start, run all the cells in this section. The later sections (demos, benchmarking and searching) are independent of each other and can be run in any order.

In [None]:
# Install ScaNN, sentence-transformers (embedding model) and Hugging Face datasets (ag_news loader).
!pip install scann sentence-transformers datasets


In [None]:
# NOTE(review): the reason for removing TensorFlow is not stated in this notebook;
# presumably it avoids a conflict with the packages installed above — confirm.
!pip uninstall --yes tensorflow

In [None]:
import torch
import numpy as np

try:
    import scann
    from sentence_transformers import SentenceTransformer
    from datasets import load_dataset  # Hugging Face loader for the public ag_news dataset
except ImportError:
    print("----------------------------------------------------------------------")
    print("🚨 ERROR: Please run the following command in a separate Colab cell ")
    print("and rerun this cell before running this code:")
    print("!pip install scann sentence-transformers datasets")
    print("----------------------------------------------------------------------")
    # Re-raise instead of calling exit(): inside a notebook kernel exit() is not a
    # reliable way to stop the cell (and may shut the kernel down), whereas
    # re-raising halts the cell immediately with the real cause.
    raise

# --------------- Model used and device used (cpu/gpu) ----------------------

MODEL_NAME = 'all-MiniLM-L6-v2'
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

# --------------- Vector converter and normalization ----------------------------

def generate_and_normalize(data):
    """Embed ``data`` with the global ``embedding_model`` and L2-normalize each row.

    Args:
        data: Sequence of strings to embed.

    Returns:
        A tuple ``(normalized_embeddings, embedding_dim)``.
        ``normalized_embeddings`` has one unit-length row per input item; an
        all-zero row is left as zeros instead of becoming NaN.
        ``embedding_dim`` is the number of columns of the embedding matrix.
    """
    print(f"Generating embeddings for {len(data)} items...")

    # convert_to_tensor=False so that encode() returns a NumPy array.
    embeddings = embedding_model.encode(
        data,
        convert_to_tensor=False,
        show_progress_bar=True
    )

    # L2 normalization: for unit-length vectors the dot product used by ScaNN
    # equals cosine (angular) similarity.
    print("Normalizing embeddings...")
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Dividing an all-zero row by 0 would produce NaN and poison the index.
    norms[norms == 0] = 1.0
    normalized_embeddings = embeddings / norms

    return normalized_embeddings, embeddings.shape[1]

# --------------------- Runs one text query -------------------------------------

def run_query(query, search_index, original_dataset, final_num_neighbors=None):
    """Embed ``query``, L2-normalize it, search ``search_index`` and print the hits.

    Args:
        query: Query text.
        search_index: A built ScaNN searcher over L2-normalized dataset embeddings.
        original_dataset: Indexable collection of the texts the index was built
            from; result indices are looked up in it.
        final_num_neighbors: Optional number of results to return. When None
            (the default), the k configured on the searcher at build time is used.
    """
    print(f"\nSearching with query: '{query}'")

    # Embed and normalize the query the same way as the indexed vectors.
    query_embedding = embedding_model.encode([query])[0]
    normalized_query = query_embedding / np.linalg.norm(query_embedding)

    # Without an explicit count, ScaNN returns the k given to the builder.
    if final_num_neighbors is None:
        indices, distances = search_index.search(normalized_query)
    else:
        indices, distances = search_index.search(
            normalized_query, final_num_neighbors=final_num_neighbors
        )

    print(f"\nTop {len(indices)} results found:")
    for rank, (idx, distance) in enumerate(zip(indices, distances), start=1):
        dataset_index = int(idx)
        print(f"  Rank {rank}:")
        print(f"    Text: {original_dataset[dataset_index]}")
        # For unit vectors the dot product is 1.0 for a perfect match, 0.0 for orthogonal.
        print(f"    Similarity (Dot Product): {distance:.4f}")
        print(f"    Dataset Index: {dataset_index}")

# --------------- Computes recall (correctness of the ScaNN result) -------------

def compute_recall(neighbors, true_neighbors):
    """
    Compute recall@k of an approximate search against exact search results.

    Args:
        neighbors: 2-D array-like of neighbor ids returned by the approximate
            searcher, one row per query.
        true_neighbors: 2-D array-like of exact neighbor ids (e.g. from a
            brute-force searcher), one row per query.

    Returns:
        Fraction (0.0-1.0) of exact neighbor ids that also appear in the
        approximate result row for the same query.

    Raises:
        ValueError: if ``true_neighbors`` is empty, since recall is undefined.
    """
    neighbors = np.asarray(neighbors)
    true_neighbors = np.asarray(true_neighbors)
    if true_neighbors.size == 0:
        raise ValueError("true_neighbors is empty; recall is undefined.")

    # True positives: ids shared by each approximate row and its exact row.
    # np.intersect1d de-duplicates, so a repeated id counts once.
    total = sum(
        np.intersect1d(gt_row, row).shape[0]
        for gt_row, row in zip(true_neighbors, neighbors)
    )

    # Recall = true positives / total number of exact neighbors.
    return total / true_neighbors.size

# Confirmation message printed once the setup cell has run through to its end.
print("Importing libraries and implementing functions successful.")

# B. Demonstration of ScaNN Embedding Search Index

This section demonstrates how ScaNN works in general. It builds an index over 2,000 ag_news headlines, searches it with predetermined query texts, and then lets you enter your own queries.

In [None]:

# Load a public dataset (ag_news) and take a manageable subset for demonstration
print("Loading public dataset (ag_news) subset...")
try:
    # Training split: the texts indexed by ScaNN.
    ag_news_dataset_train = load_dataset('ag_news', split='train[:2000]')
    dataset = ag_news_dataset_train['text']

    # Test split: held-out texts used as queries for the recall calculation.
    ag_news_dataset_test = load_dataset('ag_news', split='test[:300]')
    test_dataset_text = ag_news_dataset_test['text']

except Exception as e:
    print(f"Error loading ag_news dataset: {e}")
    # Fallback to a tiny hard-coded dataset if loading fails
    dataset = [
        "The sun rises in the east every morning.",
        "A computer uses a central processing unit for core tasks.",
        "Cats and dogs are common household pets.",
        "A feline companion enjoying a nap on the sofa.",
        "The central processing unit is the brain of any modern machine.",
        "Tomorrow's forecast predicts clear skies and warm weather."
    ]
    test_dataset_text = dataset  # Use the same small data for queries if primary fails


# The queries we will use to search the dataset
query_text_1 = "The main component of a PC is the CPU."
query_text_2 = "What is the weather like at dawn?"
query_text_3 = "Football match results from the weekend."

# Embed and L2-normalize both the indexed texts and the held-out query texts.
normalized_dataset_embeddings, embedding_dim = generate_and_normalize(dataset)
normalized_test_embeddings, _ = generate_and_normalize(test_dataset_text)

print(f"\nDataset Ready. Shape: {normalized_dataset_embeddings.shape}")
print(f"Test Query Set Shape: {normalized_test_embeddings.shape}")
print(f"First dataset entry (Index Training Data): {dataset[0]}")


# --- 4. Building the ScaNN Index (parameters tuned for the 2000-headline subset) ---

print(f"\n--- 4. Building ScaNN Optimized Searcher (Trained on {len(dataset)} examples) ---")

# The maximum number of neighbors to retrieve (top-k)
K_NEIGHBORS = 5
REORDER_NEIGHBORS = 100  # Reduced reorder candidates for speedier demo
NUM_LEAVES = 500  # Number of tree partitions

# 4.1. Initialize the ScaNN builder. Arguments: (dataset, k, distance_metric)
builder = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
)

if len(dataset) >= NUM_LEAVES:
    # 4.2. Configure the Tree (Partitioning) stage
    tree_configured = builder.tree(
        num_leaves=NUM_LEAVES,
        num_leaves_to_search=50,
        training_sample_size=4000
    )

    # 4.3. Configure Asymmetric Hashing (AH) for scoring
    ah_configured = tree_configured.score_ah(
        2,  # Number of dimensions per subvector
        anisotropic_quantization_threshold=0.2
    )

    # 4.4. Configure the Reordering (Refinement) stage
    reorder_configured = ah_configured.reorder(REORDER_NEIGHBORS)

    # 4.5. Finalize and build the searcher
    searcher = reorder_configured.build()

    print("ScaNN optimized index built successfully.")
else:
    # The hard-coded fallback has fewer vectors than tree leaves, so a
    # NUM_LEAVES-way partitioning is meaningless; search it exhaustively instead.
    print(f"Only {len(dataset)} vectors (< {NUM_LEAVES} leaves): building a brute-force searcher instead.")
    searcher = builder.score_brute_force().build()




In [None]:
# --- Computing Recall (ScaNN vs. Brute Force) ---

print("\n--- 5. Computing Recall (ScaNN vs. Brute Force) ---")

# 5.1. Brute-force ScaNN searcher (no tree, no quantization): it finds the
# exact top-K_NEIGHBORS by dot product and serves as the ground truth.
bruteforce_searcher = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
).score_brute_force().build()

# 5.2. Test queries: a prefix of the held-out test split, capped for speed.
MAX_TEST_QUERIES = 500
NUM_RECALL_QUERIES = min(MAX_TEST_QUERIES, len(normalized_test_embeddings))
recall_test_queries = normalized_test_embeddings[:NUM_RECALL_QUERIES]

print(f"1. Running Brute-Force search on {NUM_RECALL_QUERIES} test queries...")
# search_batched() searches all queries in a single call
true_neighbors, _ = bruteforce_searcher.search_batched(recall_test_queries)

print("2. Running Optimized ScaNN search...")
scann_neighbors, _ = searcher.search_batched(recall_test_queries)

# 5.3. Calculate and print recall
recall_value = compute_recall(scann_neighbors, true_neighbors)
print(f"\n✅ Recall @{K_NEIGHBORS} for {NUM_RECALL_QUERIES} queries from the TEST split: {recall_value * 100:.2f}%")
print("This value indicates the percentage of exact nearest neighbors found by the approximate searcher.")

# Run the three predetermined demo queries against the optimized searcher,
# in order: computers, weather/time of day, then sports news.
for demo_query in (query_text_1, query_text_2, query_text_3):
    run_query(demo_query, searcher, dataset)

print("-------------- Arbitrary Query ---------------")

# Interactive loop. A ScaNN searcher's k is fixed when it is built, so a new
# searcher is built for the k requested in each iteration.
while True:
    query_text = input("Enter a query... ('quit' to exit) ")
    if query_text == "quit":
        break

    k_text = input("Enter how much neighbors is needed (0 to exit) ")
    # input() returns a string; convert it before comparing or passing it to ScaNN.
    try:
        k = int(k_text)
    except ValueError:
        print(f"'{k_text}' is not a whole number; please try again.")
        continue
    if k == 0:
        break
    if not 1 <= k <= len(dataset):
        print(f"k must be between 1 and {len(dataset)}; please try again.")
        continue

    builder = scann.scann_ops_pybind.builder(
        normalized_dataset_embeddings,
        k,
        "dot_product"
    )

    if len(dataset) >= 500:
        tree_configured = builder.tree(
            num_leaves=500,
            num_leaves_to_search=50,
            training_sample_size=4000
        )

        ah_configured = tree_configured.score_ah(
            2,  # Number of dimensions per subvector
            anisotropic_quantization_threshold=0.2
        )

        # Reorder at least k candidates: reordering fewer than k cannot yield k final results.
        reorder_configured = ah_configured.reorder(max(REORDER_NEIGHBORS, k))

        arbitrary_searcher = reorder_configured.build()
    else:
        # Fewer vectors than tree leaves (e.g. the fallback dataset): search exhaustively.
        arbitrary_searcher = builder.score_brute_force().build()

    # Search with the searcher built for this k, not the fixed-k demo searcher.
    run_query(query_text, arbitrary_searcher, dataset)

# C. Benchmarking Section

This section runs both ScaNN's built-in brute-force search and its approximate (tree + asymmetric hashing) search at a larger scale, and compares them.

The first cell pre-generates embeddings for a subset (or all) of the ag_news headlines, so later runs can reuse them without regenerating the vectors. The second cell saves them to a file. If you want to use an external file instead (for example one from ann-benchmarks, linked below), skip the next two cells.

In [None]:
# ------- Development script: embed ag_news and build the embeddings file -------

print("This is a Colab script to load the dataset, embed, then save it into agnews_embeddings.h5.")


def _read_int_in_range(prompt, low, high, default, name):
    """Prompt for an integer in [low, high] and return ``default`` for anything else.

    Non-numeric input is handled like an out-of-range number (same message)
    instead of crashing with ValueError from ``int()``.
    """
    raw = input(prompt)
    try:
        value = int(raw)
    except ValueError:
        value = None
    if value is None or not low <= value <= high:
        print(f'Invalid input. {name} is set back to {default}.')
        return default
    return value


# The upper bounds are the ag_news train/test split sizes assumed by this script.
num_headlines = _read_int_in_range(
    "Enter the number of news headlines to convert, normalize, and be saved to the embeddings file. ",
    1, 120000, 5000, "num_headlines",
)

num_tests = _read_int_in_range(
    "Enter the number of text queries to be converted, normalized, and be saved. ",
    1, 7600, 1000, "num_tests",
)

# Same embedding model as the setup cell, recreated so this cell is self-contained.
MODEL_NAME = 'all-MiniLM-L6-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

print("Loading public dataset (ag_news) subset...")
try:
    # The first num_headlines training texts are indexed; the first num_tests test texts are queries.
    ag_news_dataset_train = load_dataset('ag_news', split=f'train[:{num_headlines}]')
    dataset_train = ag_news_dataset_train['text']

    ag_news_dataset_test = load_dataset('ag_news', split=f'test[:{num_tests}]')
    dataset_test = ag_news_dataset_test['text']

except Exception as e:
    print(f"Error loading ag_news dataset. {e}. Since this is used for benchmarking, the hardcoded dataset is ignored and the program is cancelled.")
    # Re-raise rather than exit(): exit() is not a reliable way to stop a notebook
    # cell, and the code below must not run with missing (or stale) data.
    raise

normalized_dataset_train_embeddings, embedding_dim = generate_and_normalize(dataset_train)
normalized_dataset_test_embeddings, _ = generate_and_normalize(dataset_test)

print(f"\nDataset generated: {normalized_dataset_train_embeddings.shape}")
print(f"First dataset training entry (Index Training Data): {dataset_train[0]}")

# Exact neighbors stored per test query. No search can return more neighbors
# than there are indexed vectors, so clamp to the training-set size.
GT_K_NEIGHBORS = min(100, normalized_dataset_train_embeddings.shape[0])

print("Generating ground truth neighbors via brute-force search...")

# Brute-force ScaNN searcher: finds the exact nearest neighbors by dot product.
bruteforce_searcher_gt = scann.scann_ops_pybind.builder(
    normalized_dataset_train_embeddings,
    GT_K_NEIGHBORS,
    "dot_product"
).score_brute_force().build()

# Exact top-GT_K_NEIGHBORS neighbor ids for every test embedding.
ground_truth_neighbors, _ = bruteforce_searcher_gt.search_batched(normalized_dataset_test_embeddings)

print(f"Ground truth neighbors generated. Shape: {ground_truth_neighbors.shape}")

# This script has finished generating and normalizing. The next cell saves them.

# This script has finished generating and normalizing. The next cell saves them.




In [None]:
import h5py

filename = 'agnews_embeddings.h5'
# NOTE: this rebinds dataset_train / dataset_test, which the previous cell used
# for the raw text lists; from here on they hold HDF5 key names.
dataset_train = 'train'
dataset_test = 'test'
dataset_truth = 'neighbors'

# Save the embeddings and ground truth under 'train' / 'test' / 'neighbors',
# the same keys the HDF5 loading cell reads.
try:
    with h5py.File(filename, 'w') as f:
        f.create_dataset(dataset_train, data=normalized_dataset_train_embeddings)
        f.create_dataset(dataset_test, data=normalized_dataset_test_embeddings)
        f.create_dataset(dataset_truth, data=ground_truth_neighbors)  # exact neighbor ids

    print(f"Embeddings successfully saved to {filename} under the data train set '{dataset_train}', data test set '{dataset_test}', and ground truth neighbors '{dataset_truth}'.")

except Exception as e:
    # Best-effort save: report the problem instead of raising.
    print(f"An error occurred: {e}")

The following cells use a pre-generated embedding dataset stored in a file. Run them first to load the vectors into variables.

The first cell downloads an external test file for benchmarking. You can replace the link with any other link to fetch a different .hdf5 file.

To use an arbitrary .hdf5 (.h5) file, upload it first, skip the first cell, and set the `filename` variable in the second cell. The default is GloVe-50 (50 dimensions, ~1.2M training vectors, 10k test vectors), available [here](https://github.com/erikbern/ann-benchmarks?tab=readme-ov-file). These files contain 'train', 'test' and 'neighbors' objects, so those names are hard-coded in the cell. The other datasets at the same link follow the same layout.

Use http://ann-benchmarks.com/glove-200-angular.hdf5 only for the brute-force vs. ScaNN comparison. The k-value benchmark rebuilds ScaNN's tree many times, which is infeasible on this dataset, so do not use it for k-value benchmarking.

Be warned: ScaNN is used here with dot-product similarity. Use only angular or dot-product datasets (and their ground truths)!

In [None]:
# Download the GloVe-50 angular benchmark file (train / test / neighbors) from ann-benchmarks.com.
!wget http://ann-benchmarks.com/glove-50-angular.hdf5

Remember to set the `filename` string in the next cell to the file you want to load, otherwise the script will not work!

In [None]:
# To load the embeddings back
import h5py

filename = 'glove-50-angular.hdf5'
dataset_train = 'train'
dataset_test = 'test'
dataset_truth = 'neighbors'

try:
    with h5py.File(filename, 'r') as f:
        # Access the dataset, then normalizes the embedding vectors. Be warned, this can only work for Angular/Dot products datasets!
        loaded_train_embeddings_raw = f[dataset_train][:]
        loaded_train_embeddings = loaded_train_embeddings_raw / np.linalg.norm(loaded_train_embeddings_raw, axis=1, keepdims=True)
        loaded_test_embeddings_raw = f[dataset_test][:]
        loaded_test_embeddings = loaded_test_embeddings_raw / np.linalg.norm(loaded_test_embeddings_raw, axis=1, keepdims=True)
        loaded_truth_embedding_raw = f[dataset_truth][:]
        loaded_truth_embedding = loaded_truth_embedding_raw / np.linalg.norm(loaded_truth_embedding_raw, axis=1, keepdims=True)

        print("\nEmbeddings loaded successfully.")
        print("Train shape:", loaded_train_embeddings.shape)
        print("Test shape:", loaded_test_embeddings.shape)
        print("Truth shape:", loaded_truth_embedding.shape)

 #       print("Metadata description:", f[dataset_name].attrs['description'])

except Exception as e:
    print(f"An error occurred during loading: {e}")

The next cell benchmarks the runtime of the ScaNN algorithm against ScaNN's built-in brute-force algorithm. K_NEIGHBORS and TEST_SIZES are hard-coded. TEST_SIZES is a list of query counts, so several runs can be compared and trends observed.

In [None]:
import os
import sys
import time
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
import scann
import psutil

# Number of neighbors (k) both searchers in this cell are built with; recall is
# reported as recall@K_NEIGHBORS.
K_NEIGHBORS = 10

# Query-batch sizes to benchmark, each taken from the start of the loaded test vectors.
TEST_SIZES = [100, 300, 1000, 3000, 10000]

def get_process_memory_mb():
    """Return the resident set size (RSS) of the current process, in megabytes."""
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_mb


def sizeof(obj):
    """Estimate the memory footprint of ``obj`` in megabytes.

    For NumPy arrays this is the size of the data buffer (``nbytes``). For any
    other object it is ``sys.getsizeof``, which is shallow: objects referenced
    by ``obj`` are not counted.

    Returns:
        Size in MB, or -1 if the size cannot be determined.
    """
    if isinstance(obj, np.ndarray):
        return obj.nbytes / (1024 * 1024)
    try:
        return sys.getsizeof(obj) / (1024 * 1024)
    except Exception:  # not a bare except, so Ctrl-C / SystemExit still propagate
        return -1



def benchmark_instance_speed(num_queries, dataset_embeddings, searcher, test_embeddings):
  """Time brute-force vs. ScaNN search on the first ``num_queries`` test vectors.

  Args:
    num_queries: Number of leading rows of ``test_embeddings`` used as queries.
    dataset_embeddings: Indexed vectors; a brute-force searcher is built over them.
    searcher: Pre-built ScaNN searcher over ``dataset_embeddings``.
    test_embeddings: Query vectors.

  Returns:
    dict with the query count, ScaNN's recall@K_NEIGHBORS against brute force,
    both wall-clock search times in seconds, and the speed-up ratio.

  Raises:
    ValueError: if ``num_queries`` exceeds the number of test vectors.
  """
  print(f"\n=== Benchmarking with {num_queries} Queries ===")

  available = test_embeddings.shape[0]
  if num_queries > available:
    raise ValueError(f"Not enough queries in the test embeddings! {num_queries} > {available}")

  queries = test_embeddings[:num_queries]

  exact_searcher = (
      scann.scann_ops_pybind.builder(dataset_embeddings, K_NEIGHBORS, "dot_product")
      .score_brute_force()
      .build()
  )

  def _timed_search(index):
    """Run one batched search and return (neighbor ids, elapsed seconds)."""
    started = time.perf_counter()
    ids, _ = index.search_batched(queries)
    return ids, time.perf_counter() - started

  # Brute force first, then ScaNN, exactly as before.
  true_neighbors, bf_time = _timed_search(exact_searcher)
  scann_neighbors, sc_time = _timed_search(searcher)

  speedup = bf_time / sc_time if sc_time > 0 else np.inf
  return {
    "num_queries": num_queries,
    f"recall@{K_NEIGHBORS}": compute_recall(scann_neighbors, true_neighbors),
    "brute_force_time_sec": bf_time,
    "scann_time_sec": sc_time,
    "speedup": speedup,
  }

def sub_benchmark_speed():
  """Benchmark one tree + AH + reorder ScaNN index against brute force.

  Reads the globals ``loaded_train_embeddings`` (vectors to index) and
  ``loaded_test_embeddings`` (query vectors) set by the HDF5 loading cell.
  Builds a single ScaNN index and calls ``benchmark_instance_speed`` for every
  entry of ``TEST_SIZES`` that fits in the loaded test set. Writes the table to
  ``benchmark_results.csv`` and a time-comparison plot to
  ``time_comparison_plot.png``.

  Returns:
    pandas.DataFrame with one row per benchmarked query count.

  Raises:
    ValueError: if no entry of ``TEST_SIZES`` fits in the loaded test set.
  """
  num_headlines = loaded_train_embeddings.shape[0]

  print(f"\nDataset size: {num_headlines}")
  print("Loaded dataset embeddings.")
  raw_embedding_size_mb = sizeof(loaded_train_embeddings)
  print(f"Raw embedding memory usage: {raw_embedding_size_mb:.2f} MB")

  # ===================================
  # Build ScaNN index
  # ===================================
  print("Building ScaNN index...")

  mem_before_index = get_process_memory_mb()

  num_leaves = max(int(np.sqrt(num_headlines)), 100)  # rule of thumb: ~sqrt(dataset size)
  num_leaves_to_search = max(int(num_leaves * 0.05), 10)  # requires tuning in order to get the desired recall value
  training_sample_size = min(int(num_headlines * 0.3), num_headlines - 1)  # ~30% of dataset
  reorder_num_neighbors = K_NEIGHBORS * 30  # candidates passed to the reordering stage

  searcher = scann.scann_ops_pybind.builder(
      loaded_train_embeddings,
      K_NEIGHBORS,
      "dot_product"
  ).tree(
      num_leaves=num_leaves,
      num_leaves_to_search=num_leaves_to_search,
      training_sample_size=training_sample_size
  ).score_ah(
      dimensions_per_block=2,
      anisotropic_quantization_threshold=0.2
  ).reorder(reorder_num_neighbors).build()

  mem_after_index = get_process_memory_mb()
  # The RSS delta is only a rough proxy for the index size (other allocations
  # and allocator caching affect it); when it is <= 0 the ratio prints as 0.
  scann_index_memory_mb = mem_after_index - mem_before_index

  print("ScaNN index built successfully.")
  print(f"ScaNN index memory usage: {scann_index_memory_mb:.2f} MB")
  compression_ratio = raw_embedding_size_mb / scann_index_memory_mb if scann_index_memory_mb > 0 else 0
  print(f"Compression ratio: {compression_ratio:.3f}x")

  # ===================================
  # Run benchmarks
  # ===================================
  num_available_queries = loaded_test_embeddings.shape[0]
  results = []

  for size in TEST_SIZES:
    # Skip sizes larger than the loaded test set instead of aborting the whole
    # run (e.g. the 10000-query entry when a smaller file is loaded).
    if size > num_available_queries:
      print(f"Skipping {size} queries: only {num_available_queries} test vectors are loaded.")
      continue
    stats = benchmark_instance_speed(size, loaded_train_embeddings, searcher, loaded_test_embeddings)
    results.append(stats)

  if not results:
    raise ValueError("No entry of TEST_SIZES fits in the loaded test set; nothing to benchmark.")

  df = pd.DataFrame(results)

  print("\n=== Final Benchmark Table (per query count only) ===")
  print(df)

  df.to_csv("benchmark_results.csv", index=False)
  print("Saved CSV -> benchmark_results.csv")

  # ===================================
  # Plotting: search time vs. number of queries
  # ===================================
  plt.figure()
  plt.plot(df["num_queries"], df["brute_force_time_sec"], label="Brute-force", marker='o')
  plt.plot(df["num_queries"], df["scann_time_sec"], label="ScaNN", marker='s')
  plt.xlabel("num. queries")
  plt.ylabel("Time (seconds)")
  plt.title("Searching time comparison")
  plt.legend()
  plt.grid(True)
  plt.savefig("time_comparison_plot.png")
  print("Saved plot -> time_comparison_plot.png")

  return df

# Run the speed benchmark. It reads loaded_train_embeddings / loaded_test_embeddings,
# so the HDF5 loading cell must have been run first.
sub_benchmark_speed()


This part of the code benchmarks ScaNN query times across different k values. The number of queries (TEST_QUERIES) is hard-coded, and K_NEIGHBORS is a list of k values. As noted above, if you loaded GloVe-200, switch back to a lighter dataset first, because this cell rebuilds the ScaNN tree once for every k value.

In [None]:
import os
import sys
import time
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
import scann

# k values to benchmark. NOTE: this rebinds K_NEIGHBORS (an int in the previous
# benchmark cell, used there as the searchers' k) to a list; re-run that cell
# before calling its functions again.
K_NEIGHBORS = [5, 10, 100, 500, 1000, 2000, 4000, 8000, 10000]
# Number of leading test vectors used as queries for every k.
TEST_QUERIES = 100

def truth_recall(neighbors, true_neighbors):
    """
    Compute recall@k of approximate results against a reference set of neighbor ids.

    In this cell the reference is a brute-force search built for the same k,
    not the ground truth stored in the dataset file.

    Args:
        neighbors: 2-D array-like of neighbor ids from the approximate search,
            one row per query.
        true_neighbors: 2-D array-like of reference (exact) neighbor ids,
            one row per query.

    Returns:
        Fraction (0.0-1.0) of reference ids that also appear in the
        approximate result row for the same query.

    Raises:
        ValueError: if ``true_neighbors`` is empty, since recall is undefined.
    """
    neighbors = np.asarray(neighbors)
    true_neighbors = np.asarray(true_neighbors)
    if true_neighbors.size == 0:
        raise ValueError("true_neighbors is empty; recall is undefined.")

    # True positives: ids shared by each approximate row and its reference row
    # (np.intersect1d de-duplicates, so a repeated id counts once).
    total = sum(
        np.intersect1d(gt_row, row).shape[0]
        for gt_row, row in zip(true_neighbors, neighbors)
    )

    # Recall = true positives / total number of reference neighbors.
    return total / true_neighbors.size


def benchmark_instance_speed_k(num_queries, k_neighbors, dataset_embeddings, test_embeddings, dataset_truth):
  """Build a ScaNN index for ``k_neighbors`` and compare it with brute force.

  Args:
    num_queries: Number of leading rows of ``test_embeddings`` used as queries.
    k_neighbors: Number of neighbors (k) both searchers return.
    dataset_embeddings: Indexed (normalized) vectors.
    test_embeddings: Query vectors.
    dataset_truth: Ground-truth neighbor ids from the dataset file. Currently
      unused: recall is measured against a brute-force search built for the same
      k, because the stored ground truth may hold fewer than k ids per query.

  Returns:
    dict with ScaNN's recall@k against brute force, both search times in
    seconds, and the speed-up ratio.

  Raises:
    ValueError: if ``num_queries`` exceeds the number of test vectors. This is
      checked before the (potentially slow) index build.
  """
  print(f"Running benchmark with {num_queries} queries...")

  # Fail fast, before spending time on the index build.
  if num_queries > test_embeddings.shape[0]:
    raise ValueError(f"Not enough queries in the test embeddings! {num_queries} > {test_embeddings.shape[0]}")

  print(f"Building ScaNN index for {k_neighbors} neighbors...")

  num_train = dataset_embeddings.shape[0]

  num_leaves = max(int(np.sqrt(num_train) * 4), 100)  # rule of thumb used here: 4 * sqrt(dataset size)
  num_leaves_to_search = max(int(num_leaves * 0.05), 10)  # requires tuning in order to get the desired recall value
  training_sample_size = min(int(num_train * 0.3), num_train - 1)  # ~30% of dataset

  searcher = scann.scann_ops_pybind.builder(
      dataset_embeddings,
      k_neighbors,
      "dot_product"
  ).tree(
      num_leaves=num_leaves,
      num_leaves_to_search=num_leaves_to_search,
      training_sample_size=training_sample_size
  ).score_ah(
      dimensions_per_block=2,
      anisotropic_quantization_threshold=0.2
  ).reorder(k_neighbors * 10).build()

  query_batch = test_embeddings[:num_queries]

  # ---- ScaNN ----
  sc_start = time.perf_counter()
  scann_neighbors, _ = searcher.search_batched(query_batch)
  sc_time = time.perf_counter() - sc_start

  # ---- Brute force (exact reference for recall) ----
  brute = scann.scann_ops_pybind.builder(
      dataset_embeddings,
      k_neighbors,
      "dot_product"
  ).score_brute_force().build()

  bf_start = time.perf_counter()
  true_neighbors, _ = brute.search_batched(query_batch)
  bf_time = time.perf_counter() - bf_start

  # ---- Recall ----
  recall_value = truth_recall(scann_neighbors, true_neighbors)

  return {
    "recall@k": recall_value,
    "scann_time_sec": sc_time,
    "brute_force_time_sec": bf_time,
    "speedup": bf_time / sc_time if sc_time > 0 else np.inf
  }

def benchmark_k():
  """Run ``benchmark_instance_speed_k`` for every k in ``K_NEIGHBORS`` and plot query time vs. k.

  Reads the globals ``loaded_train_embeddings``, ``loaded_test_embeddings`` and
  ``loaded_truth_embedding`` set by the HDF5 loading cell. k values larger than
  the number of indexed vectors are skipped. Saves the table to
  ``benchmark_k_results.csv`` and the plot to ``query_time_vs_k.png``.

  Returns:
    pandas.DataFrame with one row per benchmarked k (column "k" plus the
    recall/timing statistics).

  Raises:
    ValueError: if no k in ``K_NEIGHBORS`` fits the loaded dataset.
  """
  dataset_train_local = loaded_train_embeddings
  dataset_test_local = loaded_test_embeddings
  dataset_truth_local = loaded_truth_embedding

  num_train = dataset_train_local.shape[0]

  print(f"Dataset size: {num_train}")
  print("Loaded dataset embeddings.")

  # ===================================
  # Run benchmarks
  # ===================================
  results = []

  for k in K_NEIGHBORS:
    # More neighbors than indexed vectors cannot be returned.
    if k > num_train:
      print(f"Skipping k={k}: larger than the {num_train} indexed vectors.")
      continue
    stats = benchmark_instance_speed_k(TEST_QUERIES, k, dataset_train_local, dataset_test_local, dataset_truth_local)
    results.append({"k": k, **stats})

  if not results:
    raise ValueError("No k in K_NEIGHBORS fits the loaded dataset; nothing to benchmark.")

  df = pd.DataFrame(results)

  print("\n=== Benchmark Results ===")
  print(df)

  # Save to CSV
  df.to_csv("benchmark_k_results.csv", index=False)
  print("Saved CSV -> benchmark_k_results.csv")

  # ===================================
  # Plotting: query time vs k for both searchers
  # ===================================
  plt.figure(figsize=(10, 6))
  plt.plot(df["k"], df["scann_time_sec"], label="ScaNN", marker='o', linewidth=2)
  plt.plot(df["k"], df["brute_force_time_sec"], label="Brute-force", marker='o')
  plt.xlabel("Number of neighbors (K)")
  plt.ylabel("Query time (seconds)")
  plt.title(f"ScaNN query time vs K neighbors ({TEST_QUERIES} queries)")
  plt.legend()
  plt.grid(True)
  plt.savefig("query_time_vs_k.png")
  print("Saved plot -> query_time_vs_k.png")
  plt.show()

  return df

# Run the k-value benchmark. It reads the loaded_* globals, so the HDF5 loading
# cell must have been run first.
benchmark_k()
