<a href="https://colab.research.google.com/github/nackerboss/SCANNN/blob/main/REAL_MAIN_ScaNN_Embedding_Search_Index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install scann sentence-transformers datasets






In [3]:
!pip uninstall -y tensorflow


In [4]:
# This script demonstrates the full workflow using a public dataset:
# 1. Install necessary libraries (scann, sentence-transformers, datasets).
# 2. Load a public text dataset (Hugging Face ag_news) for both training (index) and testing (queries).
# 3. Generate and normalize vector embeddings for both sets.
# 4. Build a high-performance ScaNN index on the training set.
# 5. Build a Brute-Force searcher and compute recall using the test set queries.
# 6. Run a sample similarity query.

import torch
import numpy as np

# --- 1. Installation (Run this in a separate Colab cell first!) ---
# Note: You now need 'datasets' installed.
# !pip install scann sentence-transformers datasets

try:
    import scann
    from sentence_transformers import SentenceTransformer
    from datasets import load_dataset # New import for public dataset
except ImportError:
    print("----------------------------------------------------------------------")
    print("🚨 ERROR: Please run the following command in a separate Colab cell ")
    print("and restart the runtime before running this code:")
    print("!pip install scann sentence-transformers datasets")
    print("----------------------------------------------------------------------")
    # raise SystemExit stops only this cell; exit() can shut down the notebook kernel
    raise SystemExit(1)

# --- Utility Function for Recall Calculation ---

def compute_recall(neighbors, true_neighbors):
    """
    Computes recall @k by comparing the results of the approximate search
    (neighbors) against the exact search (true_neighbors).
    """
    total = 0
    # Iterate through query results, comparing the approximate set against the true set
    for gt_row, row in zip(true_neighbors, neighbors):
        # Count the number of common elements (true positives)
        total += np.intersect1d(gt_row, row).shape[0]

    # Recall is (True Positives) / (Total True Neighbors)
    return total / true_neighbors.size

# --- 2. Setup and Data Loading ---

MODEL_NAME = 'all-MiniLM-L6-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

# Load a public dataset (ag_news) and take a manageable subset for demonstration
print("Loading public dataset (ag_news) subset...")
try:
    # Load the training split (used for building the ScaNN index)
    ag_news_dataset_train = load_dataset('ag_news', split='train[:5000]')
    dataset = ag_news_dataset_train['text']

    # Load the test split (used for generating test queries for recall calculation)
    ag_news_dataset_test = load_dataset('ag_news', split='test[:20]')
    test_dataset_text = ag_news_dataset_test['text']

except Exception as e:
    print(f"Error loading ag_news dataset: {e}")
    # Fallback to the original small dataset if loading fails
    dataset = [
        "The sun rises in the east every morning.",
        "A computer uses a central processing unit for core tasks.",
        "Cats and dogs are common household pets.",
        "A feline companion enjoying a nap on the sofa.",
        "The central processing unit is the brain of any modern machine.",
        "Tomorrow's forecast predicts clear skies and warm weather."
    ]
    test_dataset_text = dataset # Use the same small data for queries if primary fails


# The queries we will use to search the dataset
query_text_1 = "The main component of a PC is the CPU."
query_text_2 = "What is the weather like at dawn?"
query_text_3 = "Football match results from the weekend."

def generate_and_normalize(data):
    """Generates embeddings and performs L2 normalization."""
    print(f"Generating embeddings for {len(data)} items...")

    # 3.1 Generate embeddings (returns a numpy array)
    embeddings = embedding_model.encode(
        data,
        convert_to_tensor=False,
        show_progress_bar=True
    )

    # 3.2 L2 Normalization (Crucial for ScaNN dot product or angular similarity)
    print("Normalizing embeddings...")
    normalized_embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    return normalized_embeddings, embeddings.shape[1]

normalized_dataset_embeddings, embedding_dim = generate_and_normalize(dataset)

normalized_test_embeddings, _ = generate_and_normalize(test_dataset_text)

print(f"\nDataset Ready. Shape: {normalized_dataset_embeddings.shape}")
print(f"Test Query Set Shape: {normalized_test_embeddings.shape}")
print(f"First dataset entry (Index Training Data): {dataset[0]}")


# --- 4. Building the ScaNN Index (Optimized for 5000 vectors) ---

print("\n--- 4. Building ScaNN Optimized Searcher (Trained on 5000 examples) ---")

# The maximum number of neighbors to retrieve (top-k)
K_NEIGHBORS = 5
REORDER_NEIGHBORS = 50 # Reduced reorder candidates for speedier demo

# 4.1. Initialize the ScaNN builder
# Arguments: (dataset, k, distance_metric)
builder = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
)

# 4.2. Configure the Tree (Partitioning) stage
tree_configured = builder.tree(
    num_leaves=500,
    num_leaves_to_search=50,
    training_sample_size=4000
)

# 4.3. Configure Asymmetric Hashing (AH) for scoring
ah_configured = tree_configured.score_ah(
    8, # Number of dimensions per subvector
    anisotropic_quantization_threshold=0.2
)

# 4.4. Configure the Reordering (Refinement) stage
reorder_configured = ah_configured.reorder(REORDER_NEIGHBORS)

# 4.5. Finalize and build the searcher
searcher = reorder_configured.build()

print("ScaNN optimized index built successfully.")




Loading public dataset (ag_news) subset...
Generating embeddings for 5000 items...


Batches: 100%|██████████| 157/157 [00:33<00:00,  4.65it/s]


Normalizing embeddings...
Generating embeddings for 20 items...


Batches: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]
I0000 00:00:1762237923.870735   13532 partitioner_factory_base.cc:58] Size of sampled dataset for training partition: 4000
I0000 00:00:1762237923.951447   13532 kmeans_tree_partitioner_utils.h:90] PartitionerFactory ran in 80.666928ms.


Normalizing embeddings...

Dataset Ready. Shape: (5000, 384)
Test Query Set Shape: (20, 384)
First dataset entry (Index Training Data): Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

--- 4. Building ScaNN Optimized Searcher (Trained on 5000 examples) ---
ScaNN optimized index built successfully.
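The embeddings are L2-normalized before building a `"dot_product"` index because, for unit-length vectors, the dot product equals cosine similarity, so ranking by one is the same as ranking by the other. The quick NumPy check below uses random vectors as synthetic stand-ins for the 384-dimensional MiniLM embeddings; the `eps` guard against zero-length rows is an addition and is not part of the notebook's `generate_and_normalize`.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Row-wise L2 normalization; eps avoids dividing by zero for all-zero rows."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 384))  # synthetic stand-ins for sentence embeddings
b = rng.normal(size=(3, 384))

# Cosine similarity computed directly from the raw (unnormalized) vectors
cosine = (a @ b.T) / (
    np.linalg.norm(a, axis=1)[:, None] * np.linalg.norm(b, axis=1)[None, :]
)

# Plain dot product of the normalized vectors
dot_of_normalized = l2_normalize(a) @ l2_normalize(b).T

print(np.allclose(cosine, dot_of_normalized))  # True
```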


In [5]:
# --- 5. Computing Recall (ScaNN vs. Brute Force) ---

print("\n--- 5. Computing Recall (ScaNN vs. Brute Force) ---")

# 5.1. Create a Brute-Force ScaNN searcher (no tree, no quantization)
# This will find the mathematically exact nearest neighbors.
bruteforce_searcher = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
).score_brute_force().build()

# 5.2. Define Test Queries (using a subset of the official test split as queries)
# Limit the number of test queries for faster recall computation
MAX_TEST_QUERIES = 500
NUM_RECALL_QUERIES = min(MAX_TEST_QUERIES, len(normalized_test_embeddings))

# Use the dedicated test set embeddings for recall calculation
recall_test_queries = normalized_test_embeddings[:NUM_RECALL_QUERIES]

print(f"1. Running Brute-Force search on {NUM_RECALL_QUERIES} test queries...")
# .search_batched() is much faster for multiple queries
true_neighbors, _ = bruteforce_searcher.search_batched(recall_test_queries)

print("2. Running Optimized ScaNN search...")
scann_neighbors, _ = searcher.search_batched(recall_test_queries)

# 5.3. Calculate and Print Recall
recall_value = compute_recall(scann_neighbors, true_neighbors)
print(f"\n✅ Recall @{K_NEIGHBORS} for {NUM_RECALL_QUERIES} queries from the TEST split: {recall_value * 100:.2f}%")
print("This value indicates the percentage of exact nearest neighbors found by the approximate searcher.")


# --- 6. Running a Sample Query ---

def run_query(query, search_index, original_dataset):
    """Embeds a query, normalizes it, and searches the ScaNN index."""
    print(f"\nSearching with query: '{query}'")

    # 6.1 Embed and Normalize the query
    query_embedding = embedding_model.encode([query])[0]
    normalized_query = query_embedding / np.linalg.norm(query_embedding)

    # 6.2 Perform the search
    # The 'k' parameter is configured during the builder step, so we omit it here.
    indices, distances = search_index.search(normalized_query)

    print(f"\nTop {len(indices)} results found:")
    for rank, (idx, distance) in enumerate(zip(indices, distances)):
        print(f"  Rank {rank+1}:")
        print(idx)
        print(f"    Text: {original_dataset[idx.item() ]}")
        # For unit vectors the dot product is cosine similarity: 1.0 = same direction, 0.0 = orthogonal
        print(f"    Similarity (Dot Product): {distance:.4f}")
        print(f"    Dataset Index: {idx}")

# Run Query 1: Find sentences about computers
run_query(query_text_1, searcher, dataset)

# Run Query 2: Find sentences about weather/time
run_query(query_text_2, searcher, dataset)

# Run Query 3: Find relevant news articles
run_query(query_text_3, searcher, dataset)


--- 5. Computing Recall (ScaNN vs. Brute Force) ---
1. Running Brute-Force search on 20 test queries...
2. Running Optimized ScaNN search...

✅ Recall @5 for 20 queries from the TEST split: 93.00%
This value indicates the percentage of exact nearest neighbors found by the approximate searcher.

Searching with query: 'The main component of a PC is the CPU.'

Top 5 results found:
  Rank 1:
3926
    Text: Intel Chips In for New Gateway PCs Desktops will be available at several retailers, including CompUSA.
    Similarity (Dot Product): 0.4512
    Dataset Index: 3926
  Rank 2:
4704
    Text: AMD #39;s new budget processors AMD #39;s new Sempron range of desktop and notebook CPUs is targeted squarely at Intel #39;s competing Celeron family. 
    Similarity (Dot Product): 0.4472
    Dataset Index: 4704
  Rank 3:
2698
    Text: New PC Is Created Just for Teenagers This isn't your typical, humdrum, slate-colored computer. Not only is the PC known as the hip-e almost all white, but its scree

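A worked example of `compute_recall`: recall@k counts how many of each query's exact top-k ids also appear in the approximate top-k, ignoring their order. The 93.00% above therefore means ScaNN returned 93 of the 20 × 5 = 100 exact neighbors. The neighbor ids below are made up for illustration.

```python
import numpy as np


def compute_recall(neighbors, true_neighbors):
    """Recall@k: fraction of the exact neighbors that the approximate search also returned."""
    total = 0
    for gt_row, row in zip(true_neighbors, neighbors):
        total += np.intersect1d(gt_row, row).shape[0]  # shared ids for this query
    return total / true_neighbors.size


# Two queries, k = 3 (hypothetical neighbor ids)
true_neighbors = np.array([[0, 1, 2], [3, 4, 5]])
approx_neighbors = np.array([[2, 0, 9], [5, 4, 3]])  # order within a row does not matter

print(compute_recall(approx_neighbors, true_neighbors))  # 5 of 6 found -> 0.8333333333333334
```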
# Benchmarking Section

This second section runs both ScaNN's built-in brute-force searcher and the optimized searcher (tree partitioning, asymmetric hashing, reordering) at a larger scale, and compares their recall and query time.
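Conceptually, a brute-force dot-product search scores every indexed vector against every query and keeps the k best, so its cost grows linearly with the number of indexed vectors (120,000 × 384 in the benchmark below). The NumPy sketch below shows that computation; it illustrates the idea and is not ScaNN's `score_brute_force()` implementation, and `exact_top_k` is a hypothetical helper name.

```python
import numpy as np


def exact_top_k(queries, database, k):
    """Exact top-k by dot product for every query (brute force)."""
    scores = queries @ database.T  # shape: (num_queries, num_database)
    # argpartition moves the k largest scores per row to the front in linear time...
    part = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    part_scores = np.take_along_axis(scores, part, axis=1)
    # ...then only those k are sorted, best first
    order = np.argsort(-part_scores, axis=1)
    idx = np.take_along_axis(part, order, axis=1)
    return idx, np.take_along_axis(part_scores, order, axis=1)


db = np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8], [-1.0, 0.0]])
q = np.array([[1.0, 0.0], [0.0, 1.0]])
idx, sc = exact_top_k(q, db, 2)
print(idx.tolist())  # [[0, 2], [1, 2]]
```

With many queries against 120,000 vectors the score matrix gets large (1,000 × 120,000 float32 values is about 480 MB), so processing queries in chunks would be a sensible variation.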

In [None]:
# -------------- Development script used to generate the embeddings file -----------
import torch
import numpy as np

# --- 1. Installation (Run this in a separate Colab cell first!) ---
# Note: You now need 'datasets' installed.
# !pip install scann sentence-transformers datasets

try:
    import scann
    from sentence_transformers import SentenceTransformer
    from datasets import load_dataset # New import for public dataset
except ImportError:
    print("----------------------------------------------------------------------")
    print("🚨 ERROR: Please run the following command in a separate Colab cell ")
    print("and restart the runtime before running this code:")
    print("!pip install scann sentence-transformers datasets")
    print("----------------------------------------------------------------------")
    # raise SystemExit stops only this cell; exit() can shut down the notebook kernel
    raise SystemExit(1)

print("This Colab script loads the ag_news training split and embeds and normalizes it; the next cell saves the result to agnews_embeddings.h5.")

try:
    num_headlines = int(input("Enter the number of news headlines to embed, normalize, and save (1-120000): "))
except ValueError:
    num_headlines = 0  # non-numeric input falls through to the default below

if (num_headlines > 120000) or (num_headlines < 1):
  print('Invalid input. num_headlines is set back to 5000.')
  num_headlines = 5000

# --- 2. Load the embedding model and the dataset ---

MODEL_NAME = 'all-MiniLM-L6-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(MODEL_NAME, device=device)

print("Loading public dataset (ag_news) subset...")
try:
    ag_news_dataset_train = load_dataset('ag_news', split=f'train[:{num_headlines}]') # Loads the first num_headlines training entries
    dataset = ag_news_dataset_train['text']

except Exception as e:
    # Benchmarking needs the real dataset, so stop here instead of falling back to the hardcoded sample.
    raise SystemExit(f"Error loading ag_news dataset: {e}")

#------- Helper function (redefined here so this cell can run on its own) -------
def generate_and_normalize(data):
    """Generates embeddings and performs L2 normalization."""
    print(f"Generating embeddings for {len(data)} items...")

    embeddings = embedding_model.encode(
        data,
        convert_to_tensor=False,
        show_progress_bar=True
    )

    print("Normalizing embeddings...")
    normalized_embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    return normalized_embeddings, embeddings.shape[1]
#------- End of helper function -------

normalized_dataset_embeddings, embedding_dim = generate_and_normalize(dataset)

print(f"\nDataset generated: {normalized_dataset_embeddings.shape}")
print(f"First dataset entry (Index Training Data): {dataset[0]}")

# This script has finished generating and normalizing. The next cell saves them.



In [None]:
import h5py

filename = 'agnews_embeddings.h5'
dataset_name = 'agnews'

# 2. Save the embeddings to the H5 file
try:
    with h5py.File(filename, 'w') as f:
        # Create a dataset to hold the embedding array
        dset = f.create_dataset(dataset_name, data=normalized_dataset_embeddings)

        # Optionally, you can add metadata as attributes
        dset.attrs['description'] = 'Embeddings for my project'
        dset.attrs['dimension'] = normalized_dataset_embeddings.shape[1]

    print(f"Embeddings successfully saved to {filename} under the dataset '{dataset_name}'.")

except Exception as e:
    print(f"An error occurred: {e}")


# Benchmark section
This section uses pre-generated embeddings loaded from agnews_embeddings.h5 (for example, the file produced by the two cells above). Run from this point onward to see the timing results.
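The benchmark cell below derives `num_leaves` (about √n, at least 100) and `num_leaves_to_search` (about 10% of the leaves) from the dataset size. To show what these two knobs control, here is a toy partitioned search in NumPy: a few rounds of spherical k-means split the data into leaves, and a query scores only the points in its best-matching leaves. This is a sketch of the general idea behind ScaNN's `tree()` stage, not its implementation (ScaNN additionally scores candidates with asymmetric hashing and then reorders them); `build_partitions` and `partitioned_search` are hypothetical names, and the data is synthetic.

```python
import numpy as np


def build_partitions(data, num_leaves, iters=5, seed=0):
    """Toy spherical k-means: returns unit-norm centroids and the point ids in each leaf."""
    rng = np.random.default_rng(seed)
    # Initialize centroids from randomly chosen data points
    centroids = data[rng.choice(len(data), num_leaves, replace=False)].copy()
    for _ in range(iters):
        # Assign every point to the centroid with the highest dot product
        assign = np.argmax(data @ centroids.T, axis=1)
        for c in range(num_leaves):
            members = data[assign == c]
            if len(members):  # keep the old centroid if a leaf ends up empty
                mean = members.mean(axis=0)
                centroids[c] = mean / max(np.linalg.norm(mean), 1e-12)
    assign = np.argmax(data @ centroids.T, axis=1)
    leaves = [np.flatnonzero(assign == c) for c in range(num_leaves)]
    return centroids, leaves


def partitioned_search(query, data, centroids, leaves, k, num_leaves_to_search):
    """Exact dot-product search restricted to the best-matching leaves."""
    leaf_order = np.argsort(-(centroids @ query))[:num_leaves_to_search]
    candidates = np.concatenate([leaves[c] for c in leaf_order])
    scores = data[candidates] @ query
    best = np.argsort(-scores)[:k]
    return candidates[best], scores[best]


# Synthetic unit-norm data standing in for the ag_news embeddings
rng = np.random.default_rng(1)
data = rng.normal(size=(2000, 32))
data /= np.linalg.norm(data, axis=1, keepdims=True)

centroids, leaves = build_partitions(data, num_leaves=40)
query = data[0]
ids, scores = partitioned_search(query, data, centroids, leaves, k=5, num_leaves_to_search=4)
print(ids, scores)
```

Searching all 40 leaves reproduces exact search; searching only 4 scores roughly a tenth of the points, trading some recall for speed, which is the trade-off `num_leaves_to_search` tunes.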

In [7]:
!pip install h5py



Collecting h5py
  Downloading h5py-3.15.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Downloading h5py-3.15.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (5.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.1/5.1 MB 1.6 MB/s  0:00:03
Installing collected packages: h5py
Successfully installed h5py-3.15.1


In [8]:
# To load the embeddings back
import h5py

filename = 'agnews_embeddings.h5'
dataset_name = 'agnews'

try:
    with h5py.File(filename, 'r') as f:
        # Access the dataset
        loaded_embeddings = f[dataset_name][:]

        print("\nEmbeddings loaded successfully.")
        print("Shape:", loaded_embeddings.shape)
        print("Metadata description:", f[dataset_name].attrs['description'])

except Exception as e:
    print(f"An error occurred during loading: {e}")


Embeddings loaded successfully.
Shape: (120000, 384)
Metadata description: Embeddings for my project
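Since the index is built with `"dot_product"`, its scores are cosine similarities only if the stored vectors are unit-length. A cheap sanity check after loading is sketched below; the helper name `check_unit_norm` and the `1e-3` tolerance are assumptions, and the demo uses synthetic vectors.

```python
import numpy as np


def check_unit_norm(embeddings, atol=1e-3):
    """Return (all rows unit-length within atol, largest deviation of a row norm from 1)."""
    norms = np.linalg.norm(embeddings, axis=1)
    max_dev = float(np.max(np.abs(norms - 1.0)))
    return max_dev <= atol, max_dev


# Synthetic example; in the notebook, call check_unit_norm(loaded_embeddings) instead.
x = np.random.default_rng(0).normal(size=(10, 384)).astype(np.float32)
x /= np.linalg.norm(x, axis=1, keepdims=True)
print(check_unit_norm(x))      # first element is True for normalized rows
print(check_unit_norm(2 * x))  # row norms are 2, so the first element is False
```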


In [9]:
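import time
import numpy as np
import scann
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Reuse the embedding model from earlier cells if it exists; otherwise load the same
# model that produced agnews_embeddings.h5 so query and index embeddings match.
if 'embedding_model' not in globals():
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')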

normalized_dataset_embeddings = loaded_embeddings
num_headlines = loaded_embeddings.shape[0]

K_NEIGHBORS = input("\nEnter k (the number of nearest neighbors to find): ")
try:
    K_NEIGHBORS = int(K_NEIGHBORS)
    if K_NEIGHBORS < 1 or K_NEIGHBORS > 100:
        print("Invalid k. Setting to default: 5")
        K_NEIGHBORS = 5
except ValueError:
    print("Invalid k. Setting to default: 5")
    K_NEIGHBORS = 5

REORDER_NEIGHBORS = input("Enter the number of reorder candidates (recommended: 10*k): ")
try:
    REORDER_NEIGHBORS = int(REORDER_NEIGHBORS)
    if REORDER_NEIGHBORS < K_NEIGHBORS:
        print(f"Reorder candidates must be >= k. Setting to {K_NEIGHBORS * 10}")
        REORDER_NEIGHBORS = K_NEIGHBORS * 10
except ValueError:
    print(f"Invalid input. Setting to {K_NEIGHBORS * 10}")
    REORDER_NEIGHBORS = K_NEIGHBORS * 10

# ------ Section: Normalized TEST Embeddings ----

try:
    num_tests = int(input("Enter the number of test queries to generate: "))
except ValueError:
    num_tests = 0  # non-numeric input falls through to the default below
if (num_tests > 1000) or (num_tests < 1):
  print('Invalid input. num_tests is set back to 100.')
  num_tests = 100

print("Loading public dataset (ag_news) subset...")

try:
    # Load the test split (used for generating test queries for recall calculation)
    ag_news_dataset_test = load_dataset('ag_news', split=f'test[:{num_tests}]')
    test_dataset_text = ag_news_dataset_test['text']

except Exception as e:
    # Without test queries there is nothing to benchmark, so stop this cell.
    raise SystemExit(f"Error loading ag_news dataset: {e}")

#------- Helper function (redefined here so this cell can run on its own) -------
def generate_and_normalize(data):
    """Generates embeddings and performs L2 normalization."""
    print(f"Generating embeddings for {len(data)} items...")

    embeddings = embedding_model.encode(
        data,
        convert_to_tensor=False,
        show_progress_bar=True
    )

    print("Normalizing embeddings...")
    normalized_embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    return normalized_embeddings, embeddings.shape[1]
#------- End of helper function -------

normalized_test_embeddings, _ = generate_and_normalize(test_dataset_text)

# --- 5. Dynamic ScaNN Parameters Based on Dataset Size ---
print("\n--- Building ScaNN Index with Dynamic Parameters ---")

# Calculate optimal parameters based on dataset size
num_leaves = max(int(np.sqrt(num_headlines)), 100)  # starting point: about sqrt(dataset size) leaves
num_leaves_to_search = max(int(num_leaves * 0.1), 10)  # search ~10% of leaves; tune for the recall/speed trade-off
training_sample_size = min(int(num_headlines * 0.69), num_headlines - 1)  # ~69% of the dataset trains the k-means partitioner

print(f"Dataset size: {num_headlines}")
print(f"Number of leaves (clusters): {num_leaves}")
print(f"Leaves to search: {num_leaves_to_search}")
print(f"Training sample size: {training_sample_size}")
print(f"K neighbors: {K_NEIGHBORS}")
print(f"Reorder candidates: {REORDER_NEIGHBORS}")

# --- 6. Build ScaNN Index ---
builder = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
)

tree_configured = builder.tree(
    num_leaves=num_leaves,
    num_leaves_to_search=num_leaves_to_search,
    training_sample_size=training_sample_size
)

ah_configured = tree_configured.score_ah(
    8,  # Number of dimensions per subvector
    anisotropic_quantization_threshold=0.2
)

reorder_configured = ah_configured.reorder(REORDER_NEIGHBORS)
searcher = reorder_configured.build()

print("ScaNN optimized index built successfully.")

# -----------

print("\n--- Computing Recall (ScaNN vs. Brute Force) ---")

def compute_recall(neighbors, true_neighbors):
    """Computes recall @k."""
    total = 0
    for gt_row, row in zip(true_neighbors, neighbors):
        total += np.intersect1d(gt_row, row).shape[0]
    return total / true_neighbors.size

# Build brute-force searcher
bruteforce_searcher = scann.scann_ops_pybind.builder(
    normalized_dataset_embeddings,
    K_NEIGHBORS,
    "dot_product"
).score_brute_force().build()

test_query_input = input("\nEnter number of test queries for recall evaluation (or 'all' for complete test): ")

if test_query_input.lower() == 'all':
    NUM_RECALL_QUERIES = len(normalized_test_embeddings)
    print(f"Testing on ALL {NUM_RECALL_QUERIES} test queries (most accurate, takes longer)")
else:
    try:
        requested_queries = int(test_query_input)
        NUM_RECALL_QUERIES = min(requested_queries, len(normalized_test_embeddings))
        print(f"Testing on {NUM_RECALL_QUERIES} test queries")
    except ValueError:
        NUM_RECALL_QUERIES = min(1000, len(normalized_test_embeddings))
        print(f"Invalid input. Using default: {NUM_RECALL_QUERIES} test queries")

recall_test_queries = normalized_test_embeddings[:NUM_RECALL_QUERIES]

brute_force_time_start = time.perf_counter()

true_neighbors, _ = bruteforce_searcher.search_batched(recall_test_queries) # Brute-force searches

brute_force_time_end = time.perf_counter()

print("Running Optimized ScaNN search...")
scann_time_start = time.perf_counter()  # start timing after the print so only the search is measured

scann_neighbors, _ = searcher.search_batched(recall_test_queries)

scann_time_end = time.perf_counter()

recall_value = compute_recall(scann_neighbors, true_neighbors)
print(f"\nRecall @{K_NEIGHBORS}: {recall_value * 100:.2f}%")
print("(Percentage of exact nearest neighbors found by ScaNN)")
print(f"Done. Brute-force time: {brute_force_time_end - brute_force_time_start}")
print(f"ScaNN time: {scann_time_end - scann_time_start}")


Enter k (the number of nearest neighbors to find):  10
Enter the number of reorder candidates (recommended: 10*k):  500
Enter the number of test queries to generate:  20


Loading public dataset (ag_news) subset...
Generating embeddings for 20 items...


Batches: 100%|██████████| 1/1 [00:00<00:00,  2.56it/s]


Normalizing embeddings...

--- Building ScaNN Index with Dynamic Parameters ---
Dataset size: 120000
Number of leaves (clusters): 346
Leaves to search: 34
Training sample size: 82800
K neighbors: 10
Reorder candidates: 500


I0000 00:00:1762238254.415163   13532 partitioner_factory_base.cc:58] Size of sampled dataset for training partition: 82800
I0000 00:00:1762238255.120589   13532 kmeans_tree_partitioner_utils.h:90] PartitionerFactory ran in 705.352746ms.


ScaNN optimized index built successfully.

--- Computing Recall (ScaNN vs. Brute Force) ---



Enter number of test queries for recall evaluation (or 'all' for complete test):  all


Testing on ALL 20 test queries (most accurate, takes longer)
Running Optimized ScaNN search...

Recall @10: 92.50%
(Percentage of exact nearest neighbors found by ScaNN)
Done. Brute-force time: 0.037970810997649096
ScaNN time: 0.0024119370063999668
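The timings above come from a single `search_batched` call over 20 queries, so they include one-off costs and run-to-run noise; the single-run ratio here is roughly 0.0380 s / 0.0024 s ≈ 15.7×. A steadier comparison can take the median over several runs after a warm-up call. The helper `time_call` below is a hypothetical sketch using only the standard library.

```python
import statistics
import time


def time_call(fn, *args, repeats=5, warmup=1):
    """Call fn(*args) `warmup` times untimed, then `repeats` times timed; return the median seconds."""
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Usage in the notebook (searchers and queries come from the cell above):
# bf_s = time_call(bruteforce_searcher.search_batched, recall_test_queries)
# scann_s = time_call(searcher.search_batched, recall_test_queries)
# n = len(recall_test_queries)
# print(f"Brute force: {bf_s / n * 1e3:.3f} ms/query, ScaNN: {scann_s / n * 1e3:.3f} ms/query, "
#       f"speed-up: {bf_s / scann_s:.1f}x")

print(time_call(sorted, list(range(10_000, 0, -1))))  # median seconds for a toy workload
```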

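The benchmark used 500 reorder candidates for k = 10. In ScaNN, `reorder(n)` takes the top `n` candidates from the quantized (asymmetric hashing) scores and recomputes their exact dot products before returning the final top-k. The NumPy sketch below simulates that two-step idea; the quantization error is faked by adding Gaussian noise to the exact scores, which is an assumption for illustration and not how AH errors actually behave. `approx_then_reorder` is a hypothetical name and the data is synthetic.

```python
import numpy as np


def approx_then_reorder(query, database, approx_scores, k, reorder_n):
    """Take the reorder_n best candidates by approximate score, rescore them exactly, keep top-k."""
    reorder_n = min(reorder_n, len(database))
    candidates = np.argpartition(-approx_scores, reorder_n - 1)[:reorder_n]
    exact = database[candidates] @ query  # exact dot products, candidates only
    best = np.argsort(-exact)[:k]
    return candidates[best]


rng = np.random.default_rng(0)
database = rng.normal(size=(5000, 64))
database /= np.linalg.norm(database, axis=1, keepdims=True)
query = database[123] + 0.1 * rng.normal(size=64)
query /= np.linalg.norm(query)

true_scores = database @ query
# Assumption: quantization error simulated as Gaussian noise on the exact scores.
approx_scores = true_scores + 0.05 * rng.normal(size=len(database))

k = 10
exact_top = set(np.argsort(-true_scores)[:k].tolist())
for reorder_n in (10, 50, 500):
    found = approx_then_reorder(query, database, approx_scores, k, reorder_n)
    recall = len(exact_top & set(found.tolist())) / k
    print(f"reorder_n={reorder_n}: recall@{k} = {recall:.1f}")
```

Any true top-k vector that makes it into the candidate set is guaranteed to stay in the final top-k after exact rescoring, so recall can only stay the same or rise as `reorder_n` grows; the cost is `reorder_n` exact dot products per query.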

In [13]:
print("\n----------- After benchmarking ---------")

manual_query = input("Search an article...")

# Embed, normalize, and search the query with run_query() (defined below).

def run_query(query, search_index, original_dataset):
    """Embeds a query, normalizes it, and searches the ScaNN index."""
    print(f"\nSearching with query: '{query}'")

    # 6.1 Embed and Normalize the query
    query_embedding = embedding_model.encode([query])[0]
    normalized_query = query_embedding / np.linalg.norm(query_embedding)

    # 6.2 Perform the search
    # The 'k' parameter is configured during the builder step, so we omit it here.
    indices, distances = search_index.search(normalized_query)
    print(f"\nTop {len(indices)} results found:")
    for rank, (idx, distance) in enumerate(zip(indices, distances)):
        print(f"  Rank {rank+1}:")
        print(idx)
        print(f"    Text: {original_dataset[idx.item() ]}")
        # For unit vectors the dot product is cosine similarity: 1.0 = same direction, 0.0 = orthogonal
        print(f"    Similarity (Dot Product): {distance:.4f}")
        print(f"    Dataset Index: {idx}")

from datasets import load_dataset
# Load the full training split so result indices map to the right text
# (assumes the .h5 file holds all 120,000 training embeddings in their original order)
ag_news_dataset_full = load_dataset('ag_news', split='train')
full_dataset_text = ag_news_dataset_full['text']

run_query(manual_query, searcher, full_dataset_text)


----------- After benchmarking ---------


Search an article... baseball



Searching with query: 'baseball'

Top 10 results found:
  Rank 1:
69481
    Text: Baseball and its fans recover from 1994 strike Ten years after the World Series was canceled and fans left in droves, Major League Baseball will tell you it has never been healthier.
    Similarity (Dot Product): 0.6311
    Dataset Index: 69481
  Rank 2:
11745
    Text: Baseball Today * Abraham Nunez, Royals, hit his second grand slam in two weeks to lead Kansas City over Seattle 7-3. * Aaron Harang, Reds, limited St.
    Similarity (Dot Product): 0.5888
    Dataset Index: 11745
  Rank 3:
65440
    Text: Red Sox-Yanks an instant classic Heroics and heartbreaks, diving catches and basepath blunders, hot bats and bizarre slumps, rainy days and endless nights. Already an instant classic, a short story transformed into a great American League 
    Similarity (Dot Product): 0.5738
    Dataset Index: 65440
  Rank 4:
28122
    Text: BBO Baseball Today (AP) AP - Texas at Oakland (10:05 p.m. EDT). Mark Mulder (17