# Neurosearch Training on Google Colab A100

This notebook trains the Neurosearch hybrid retrieval system on the Amazon ESCI dataset.

**GPU**: A100 80GB recommended

**Tasks**:
1. Setup
2. Download ESCI Dataset
3. Build Semantic IDs (Hierarchical K-Means)
4. Fine-tune Dense Retriever (Sentence-Transformers)
5. Train Generative Retriever (T5)
6. Build FAISS Index & Evaluate
7. Download Trained Models

In [1]:
# Check GPU
!nvidia-smi

Mon Dec  1 21:28:25 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P8             12W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
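
The header recommends an A100, while the run recorded above was on a Tesla T4 with 15360 MiB. A minimal sketch of a programmatic check follows; the helper names `parse_gpu_line` and `query_gpu` are hypothetical, and the 40,000 MiB threshold is an assumption rather than a requirement of this notebook.

```python
import subprocess

def parse_gpu_line(line):
    """Parse one CSV line such as 'Tesla T4, 15360 MiB' into (name, memory_mib)."""
    name, memory = [part.strip() for part in line.split(",")]
    return name, int(memory.split()[0])

def query_gpu():
    """Ask nvidia-smi for the name and total memory of the first GPU."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_line(out.splitlines()[0])

# Example using the values from the nvidia-smi table above (no GPU needed for this line).
name, memory_mib = parse_gpu_line("Tesla T4, 15360 MiB")
if memory_mib < 40000:  # assumed threshold, not a documented requirement
    print(f"{name} has {memory_mib} MiB; consider smaller batch sizes.")
```

On a GPU runtime, `query_gpu()` can replace the hard-coded string.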
                                                

## 1. Setup

In [18]:
# Install dependencies
!pip install -q sentence-transformers faiss-cpu transformers datasets torch pandas pyarrow scikit-learn tqdm rank_bm25 matplotlib seaborn

In [9]:
# Upload neurosearch_training.zip and extract
from google.colab import files
import zipfile
import os

print("Please upload neurosearch_training.zip")
uploaded = files.upload()

# Get the uploaded filename (assuming only one file is uploaded)
if uploaded:
    uploaded_filename = list(uploaded.keys())[0]
    # Extract
    with zipfile.ZipFile(uploaded_filename, 'r') as zip_ref:
        zip_ref.extractall('.')
else:
    print("No file was uploaded.")

# Add to path
import sys
sys.path.insert(0, '/content/src')

Please upload neurosearch_training.zip


Saving neurosearch_training_final.zip to neurosearch_training_final (2).zip


## 2. Download ESCI Dataset

In [5]:
# Clone ESCI data from GitHub
!git clone --depth 1 https://github.com/amazon-science/esci-data.git /content/esci-data
!ls -lh /content/esci-data/shopping_queries_dataset/

Cloning into '/content/esci-data'...
remote: Enumerating objects: 33, done.
remote: Counting objects: 100% (33/33), done.
remote: Compressing objects: 100% (31/31), done.
remote: Total 33 (delta 7), reused 12 (delta 1), pack-reused 0 (from 0)
Receiving objects: 100% (33/33), 362.95 KiB | 24.20 MiB/s, done.
Resolving deltas: 100% (7/7), done.
Filtering content: 100% (2/2), 1.08 GiB | 16.71 MiB/s, done.
total 1.1G
-rw-r--r-- 1 root root  49M Dec  1 21:29 shopping_queries_dataset_examples.parquet
-rw-r--r-- 1 root root 1.1G Dec  1 21:30 shopping_queries_dataset_products.parquet
-rw-r--r-- 1 root root 1.7M Dec  1 21:29 shopping_queries_dataset_sources.csv


In [6]:
import pandas as pd
import numpy as np

# Load data
examples_path = "/content/esci-data/shopping_queries_dataset/shopping_queries_dataset_examples.parquet"
products_path = "/content/esci-data/shopping_queries_dataset/shopping_queries_dataset_products.parquet"

df_examples = pd.read_parquet(examples_path)
df_products = pd.read_parquet(products_path)

print(f"Examples: {df_examples.shape}")
print(f"Products: {df_products.shape}")
print(f"\nLabel distribution:\n{df_examples['esci_label'].value_counts()}")

Examples: (2621288, 9)
Products: (1814924, 7)

Label distribution:
esci_label
E    1708158
S     574313
I     263165
C      75652
Name: count, dtype: int64


In [7]:
# Filter for English locale and prepare training data
df_examples_en = df_examples[df_examples['product_locale'] == 'us'].copy()
df_products_en = df_products[df_products['product_locale'] == 'us'].copy()

# Map labels to scores
label_map = {"E": 3, "S": 2, "C": 1, "I": 0}
df_examples_en['esci_score'] = df_examples_en['esci_label'].map(label_map)

# Merge with product metadata
df_train = df_examples_en.merge(df_products_en[['product_id', 'product_title', 'product_description']],
                                  on='product_id', how='left')

# Fill missing titles
df_train['product_title'] = df_train['product_title'].fillna('')

print(f"Training data: {df_train.shape}")
df_train.head()

Training data: (1818825, 12)


| | example_id | query | query_id | product_id | product_locale | esci_label | small_version | large_version | split | esci_score | product_title | product_description |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | revent 80 cfm | 0 | B000MOO21W | us | I | 0 | 1 | train | 0 | Panasonic FV-20VQ3 WhisperCeiling 190 CFM Ceil... | |
| 1 | 1 | revent 80 cfm | 0 | B07X3Y6B1V | us | E | 0 | 1 | train | 3 | Homewerks 7141-80 Bathroom Fan Integrated LED ... | |
| 2 | 2 | revent 80 cfm | 0 | B07WDM7MQQ | us | E | 0 | 1 | train | 3 | Homewerks 7140-80 Bathroom Fan Ceiling Mount E... | |
| 3 | 3 | revent 80 cfm | 0 | B07RH6Z8KW | us | E | 0 | 1 | train | 3 | Delta Electronics RAD80L BreezRadiance 80 CFM ... | This pre-owned or refurbished product has been... |
| 4 | 4 | revent 80 cfm | 0 | B07QJ7WYFQ | us | E | 0 | 1 | train | 3 | Panasonic FV-08VRE2 Ventilation Fan with Reces... | |

## 3. Build Semantic IDs

In [10]:
from sentence_transformers import SentenceTransformer
from neurosearch.data.semantic_id_builder import SemanticIDBuilder
import torch

# Sample products for semantic ID building (use subset for speed)
unique_products = df_products_en[['product_id', 'product_title']].drop_duplicates().sample(n=50000, random_state=42)

# Encode product titles
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
model = model.to('cuda')

print("Encoding product titles...")
product_embeddings = model.encode(unique_products['product_title'].tolist(),
                                   convert_to_numpy=True,
                                   show_progress_bar=True,
                                   batch_size=256)

print(f"Embeddings shape: {product_embeddings.shape}")

Encoding product titles...
Embeddings shape: (50000, 384)


In [11]:
# Build Semantic IDs
print("Building Semantic IDs (3 levels, K=10)...")
id_builder = SemanticIDBuilder(n_levels=3, n_clusters=10, random_state=42)
semantic_ids, id_strings = id_builder.fit_transform(product_embeddings)

# Add to dataframe
unique_products['semantic_id'] = id_strings

print(f"Sample Semantic IDs:\n{unique_products[['product_id', 'product_title', 'semantic_id']].head(10)}")

Building Semantic IDs (3 levels, K=10)...
Sample Semantic IDs:
         product_id                                      product_title  \
869100   B0797JRJLG  Samsung Galaxy S9 Rugged Military Grade Protec...   
880359   B009M36Z4G                   NCAA Duke Blue Devils Badge Reel   
1241197  B00LXTORC4  Cygolite Metro– 550 Lumen Bike Light– 4 Night ...   
366008   0894557947  Can You Find Me?: Building Thinking Skills in ...   
1753321  B014WC62X0  Augusta Sportswear Moisture-Wicking Long-Sleev...   
236755   B074D6M7D8  National Tree Company Artificial Giant Christm...   
485194   B085RB2T46  Magic: The Gathering - Angelic Destiny - Myste...   
811687   B086TY57PK  IRONCK Bookshelf, Double Wide 6-Tier Open Book...   
1313832  B0007P5G8Y  Scotch-Mount Indoor Double-Sided Mounting Tape...   
489551   B0882Z4RMP  Chefmaster - Liqua-Gel Food Coloring - 12 Colo...   

        semantic_id  
869100        7 9 9  
880359        9 1 1  
1241197       6 6 6  
366008        2 5 5  
1753321     
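
The internals of `SemanticIDBuilder` are not shown in this notebook. As an illustration of the general idea only, the sketch below builds hierarchical k-means IDs in pure Python: cluster all points, then re-cluster inside each cluster, and concatenate the cluster indices. The functions `kmeans` and `build_ids`, the evenly spaced initialization, and the toy 2-D points are all assumptions made for this example.

```python
def kmeans(points, k, iters=10):
    """Plain Lloyd's k-means on a list of equal-length tuples; returns one label per point."""
    k = min(k, len(points))  # cannot have more clusters than points
    centers = [points[i * len(points) // k] for i in range(k)]  # deterministic init
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest center by squared Euclidean distance.
        labels = [
            min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centers[c])))
            for p in points
        ]
        # Update step: move each center to the mean of its members (unchanged if empty).
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centers[c] = tuple(sum(dim) / len(members) for dim in zip(*members))
    return labels

def build_ids(points, n_levels, k):
    """Assign each point a tuple of n_levels cluster indices by clustering recursively."""
    ids = [()] * len(points)

    def split(indices, level):
        if level == n_levels or not indices:
            return
        labels = kmeans([points[i] for i in indices], k)
        for i, lab in zip(indices, labels):
            ids[i] = ids[i] + (lab,)
        for c in range(k):
            split([i for i, lab in zip(indices, labels) if lab == c], level + 1)

    split(list(range(len(points))), 0)
    return ids

# Four tight groups of two points each; two levels with k=2 separate them fully.
points = [(0, 0), (1, 0), (0, 10), (1, 10), (100, 0), (101, 0), (100, 10), (101, 10)]
ids = build_ids(points, n_levels=2, k=2)
print([" ".join(map(str, sid)) for sid in ids])

# Capacity with the notebook's settings: at most 10**3 IDs for 50,000 products.
max_ids = 10 ** 3
print(max_ids, 50_000 / max_ids)  # 1000 50.0
```

With 3 levels and K=10 there are at most 1,000 distinct IDs, so the 50,000 sampled products share IDs, about 50 products per ID if the clusters were evenly filled.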

## 4. Fine-tune Dense Retriever

In [12]:
from sentence_transformers import InputExample, losses
from torch.utils.data import DataLoader

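# Prepare training pairs (query, product_title, score)
# Cap the sample size by the size of the train split, not the full merged frame
train_split = df_train[df_train['split'] == 'train']
train_df = train_split.sample(n=min(100000, len(train_split)), random_state=42)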

train_examples = []
for _, row in train_df.iterrows():
    train_examples.append(InputExample(texts=[row['query'], row['product_title']], label=float(row['esci_score'])/3.0))

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

print(f"Training examples: {len(train_examples)}")

Training examples: 100000


In [13]:
# Fine-tune
train_loss = losses.CosineSimilarityLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=1,
          warmup_steps=100,
          output_path='/content/dense_retriever_finetuned',
          show_progress_bar=True)

print("Dense retriever fine-tuned!")


| Step | Training Loss |
|---|---|
| 500 | 0.0893 |
| 1000 | 0.0725 |
| 1500 | 0.073 |
| 2000 | 0.0739 |
| 2500 | 0.0735 |
| 3000 | 0.0727 |

Dense retriever fine-tuned!
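
`CosineSimilarityLoss` compares the cosine similarity of the two sentence embeddings with the label, using mean squared error by default. That is why the ESCI scores are divided by 3 to land in [0, 1]. A worked example with made-up 3-dimensional vectors (illustrative values, not model outputs; `pair_loss` is a hypothetical helper):

```python
import math

label_map = {"E": 3, "S": 2, "C": 1, "I": 0}

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def pair_loss(query_vec, product_vec, esci_label):
    """Squared error between the cosine similarity and the scaled ESCI score."""
    target = label_map[esci_label] / 3.0
    return (cosine(query_vec, product_vec) - target) ** 2

q = [1.0, 0.0, 0.0]
p = [1.0, 1.0, 0.0]  # cosine(q, p) = 1 / sqrt(2)
for label in "ESCI":
    print(label, round(pair_loss(q, p, label), 4))
```

The batch loss is the mean of these per-pair terms, so a pair labeled `I` is pushed toward cosine 0 and a pair labeled `E` toward cosine 1.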


## 5. Train Generative Retriever (T5)

In [14]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset

# Merge semantic IDs with training data
train_with_ids = train_df.merge(unique_products[['product_id', 'semantic_id']], on='product_id', how='left')
train_with_ids = train_with_ids.dropna(subset=['semantic_id'])

# Create dataset for T5 (query -> semantic_id)
t5_data = []
for _, row in train_with_ids.iterrows():
    t5_data.append({
        'input_text': f"query: {row['query']}",
        'target_text': row['semantic_id']
    })

t5_dataset = Dataset.from_list(t5_data[:50000])  # Limit for training time
print(f"T5 training dataset: {len(t5_dataset)} examples")

T5 training dataset: 4039 examples
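
Only 4,039 of the 100,000 sampled training rows survive the merge, because semantic IDs exist only for the 50,000 products sampled in step 3. A minimal sketch of that coverage check, with toy product IDs (hypothetical values) standing in for the real frames:

```python
# Toy stand-ins: products that received a semantic ID, and product IDs of training rows.
products_with_ids = {"P1": "7 9 9", "P2": "9 1 1"}
train_product_ids = ["P1", "P3", "P2", "P4", "P1", "P5"]

# Equivalent of the left merge followed by dropna(subset=['semantic_id']).
kept = [pid for pid in train_product_ids if pid in products_with_ids]
coverage = len(kept) / len(train_product_ids)
print(f"kept {len(kept)} of {len(train_product_ids)} rows ({coverage:.1%})")

# The same ratio for the run recorded above.
print(f"{4039 / 100000:.1%}")  # 4.0%
```

Building semantic IDs for every product that appears in the training split, or sampling training rows only from products that already have IDs, would likely raise this ratio.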


In [15]:
# Initialize T5
tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

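def preprocess_function(examples):
    inputs = tokenizer(examples['input_text'], max_length=128, truncation=True, padding='max_length')
    targets = tokenizer(examples['target_text'], max_length=32, truncation=True, padding='max_length')
    # Replace padding token ids with -100 so the loss ignores padded label positions
    inputs['labels'] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in targets['input_ids']
    ]
    return inputs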

tokenized_dataset = t5_dataset.map(preprocess_function, batched=True, remove_columns=['input_text', 'target_text'])
print("Dataset tokenized.")

Dataset tokenized.


In [16]:
# Training arguments
training_args = TrainingArguments(
    output_dir='/content/t5_generative_retriever',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='/content/logs',
    logging_steps=100,
    save_steps=1000,
    fp16=True,
)

trainer = Trainer(
    model=t5_model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Train
print("Training T5 Generative Retriever...")
trainer.train()
print("T5 training complete!")

Training T5 Generative Retriever...


| Step | Training Loss |
|---|---|
| 100 | 13.6986 |
| 200 | 3.7345 |
| 300 | 0.6297 |
| 400 | 0.204 |
| 500 | 0.1381 |
| 600 | 0.1068 |
| 700 | 0.0961 |


T5 training complete!
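
The logged steps are consistent with the dataset size. Assuming a single GPU and no gradient accumulation (the `TrainingArguments` defaults), the step count works out as follows:

```python
import math

n_examples, batch_size, epochs = 4039, 16, 3
warmup_steps, logging_steps = 500, 100

# The last partial batch still counts as a step.
steps_per_epoch = math.ceil(n_examples / batch_size)
total_steps = steps_per_epoch * epochs
last_logged = (total_steps // logging_steps) * logging_steps

print(steps_per_epoch, total_steps, last_logged)  # 253 759 700
print(f"warmup covers {warmup_steps / total_steps:.0%} of training")  # 66%
```

With 759 total steps, `warmup_steps=500` keeps the learning rate ramping up for about two thirds of training, so a smaller warmup may be worth trying on a dataset this size.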


## 6. Build FAISS Index & Evaluate

In [19]:
import faiss

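# Build FAISS index over the 50,000 sampled products.
# Note: product_embeddings were computed in step 3, before fine-tuning. Re-encode the
# titles with the fine-tuned model if queries and products should share one embedding space.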
print("Building FAISS index...")
d = product_embeddings.shape[1]
index = faiss.IndexFlatIP(d)
index.add(product_embeddings.astype('float32'))

# Save
faiss.write_index(index, '/content/dense_index.faiss')
print(f"FAISS index built with {index.ntotal} vectors")

Building FAISS index...
FAISS index built with 50000 vectors
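
`IndexFlatIP` performs exact (brute-force) inner-product search, and on unit-length vectors the inner product equals cosine similarity. A pure-Python sketch of the same computation on toy 2-D vectors (the data and the `top_k_inner_product` helper are illustrative only):

```python
import math

def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def top_k_inner_product(query, vectors, k):
    """Return (score, index) pairs for the k largest inner products, best first."""
    scores = [(sum(q * x for q, x in zip(query, vec)), i) for i, vec in enumerate(vectors)]
    return sorted(scores, reverse=True)[:k]

vectors = [normalize(v) for v in [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [-1.0, 0.0]]]
query = normalize([2.0, 1.0])

for score, idx in top_k_inner_product(query, vectors, k=2):
    print(idx, round(score, 3))
```

A flat index scans every stored vector per query, which is fine for 50,000 vectors but grows linearly with the catalog size.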


In [20]:
# Test Dense Retrieval
test_query = "wireless bluetooth headphones"
q_emb = model.encode([test_query], convert_to_numpy=True, normalize_embeddings=True)
distances, indices = index.search(q_emb.astype('float32'), k=5)

print(f"Query: {test_query}\n")
for i, (idx, score) in enumerate(zip(indices[0], distances[0])):
    product = unique_products.iloc[idx]
    print(f"{i+1}. [{score:.3f}] {product['product_title']}")

Query: wireless bluetooth headphones

1. [0.709] Wireless Headphones, Snoky Bluetooth 5.0 Headphones Hi-Fi Stereo Deep Bass 26H Playtime Foldable Over Ear Headphones with Microphone Wireless Headset for Cell Phone Online Class Home Office TV
2. [0.705] Fashion Wireless Earphone, BT 4.1 Stereo Earphone Headset Wireless Magnetic in-Ear Earbuds Headphone Sports Headset (Black)
3. [0.693] Lsmaa Wireless Bluetooth Headset, Stereo Sports Earbuds Hands-Free Magnetic in-Ear Noise Reduction with Microphone, You Can Enjoy Entertainment and Sports Music (Color : Black)
4. [0.677] Bluetooth Headphones,YAMAY M20 Wireless Headphones with Microphone Hands Free Noise Cancelling Headset for iPhone Samsung Android Cell Phones (Lightweight Foldable On Ear Design Multi-Point Connect)
5. [0.674] Bluetooth 5.0 Wireless Earbuds, Wireless Bluetooth Headphones with Deep Bass HiFi Stereo Sound, Built-in Mic Earphones with Portable Charging Case for iOS and Android
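
The spot check above is qualitative. One way to quantify ranking quality against the ESCI labels is nDCG@k. The sketch below uses the notebook's `label_map` scores as gains, which is an assumption (other gain schemes are possible), and the ranked label lists are made up for illustration.

```python
import math

label_map = {"E": 3, "S": 2, "C": 1, "I": 0}

def dcg(gains):
    """Discounted cumulative gain with a log2 position discount."""
    return sum(g / math.log2(rank + 2) for rank, g in enumerate(gains))

def ndcg_at_k(ranked_labels, k):
    """nDCG@k for one query, given ESCI labels in the order the system ranked them."""
    gains = [label_map[label] for label in ranked_labels]
    ideal_dcg = dcg(sorted(gains, reverse=True)[:k])
    return dcg(gains[:k]) / ideal_dcg if ideal_dcg > 0 else 0.0

print(round(ndcg_at_k(["S", "E", "I", "C"], k=4), 4))  # imperfect ordering, below 1
print(ndcg_at_k(["E", "S", "C", "I"], k=4))  # 1.0
```

Averaging `ndcg_at_k` over the queries in the `test` split would give a single number to compare the base and fine-tuned retrievers.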


In [22]:
# Test Generative Retrieval
input_ids = tokenizer.encode(f"query: {test_query}", return_tensors="pt").to('cuda')
outputs = t5_model.generate(input_ids, max_length=32, num_return_sequences=5, num_beams=5)

print(f"\nGenerated Semantic IDs for '{test_query}':")
for i, output in enumerate(outputs):
    decoded = tokenizer.decode(output, skip_special_tokens=True)
    print(f"{i+1}. {decoded}")


Generated Semantic IDs for 'wireless bluetooth headphones':
1. 3 4 4
2. 4 2 2
3. 6 6 6
4. 2 5 5
5. 5 7 7
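
A generated semantic ID is only useful once it is mapped back to products. A minimal sketch of that lookup, assuming an inverted index from ID string to product IDs; here it is built from a toy list whose four pairs match the sample printed in step 3, and `candidates_for` is a hypothetical helper.

```python
from collections import defaultdict

# (product_id, semantic_id) pairs taken from the sample output in step 3.
pairs = [
    ("B0797JRJLG", "7 9 9"),
    ("B009M36Z4G", "9 1 1"),
    ("B00LXTORC4", "6 6 6"),
    ("0894557947", "2 5 5"),
]

# Inverted index: semantic ID string -> list of product IDs sharing that ID.
id_to_products = defaultdict(list)
for product_id, semantic_id in pairs:
    id_to_products[semantic_id].append(product_id)

def candidates_for(generated_ids):
    """Collect products for each generated ID, in beam order, skipping unknown IDs."""
    results = []
    for sid in generated_ids:
        results.extend(id_to_products.get(sid, []))
    return results

# Beam outputs from the cell above.
generated = ["3 4 4", "4 2 2", "6 6 6", "2 5 5", "5 7 7"]
print(candidates_for(generated))  # ['B00LXTORC4', '0894557947']
```

In the full setup each ID can map to many products, so the candidates would still need a second-stage ranking, for example by the dense retriever's scores.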


## 7. Download Trained Models

In [23]:
# Zip and download
!zip -r /content/neurosearch_trained_models.zip /content/dense_retriever_finetuned /content/t5_generative_retriever /content/dense_index.faiss

from google.colab import files
files.download('/content/neurosearch_trained_models.zip')

  adding: content/dense_retriever_finetuned/ (stored 0%)
  adding: content/dense_retriever_finetuned/tokenizer_config.json (deflated 73%)
  adding: content/dense_retriever_finetuned/sentence_bert_config.json (deflated 9%)
  adding: content/dense_retriever_finetuned/config_sentence_transformers.json (deflated 41%)
  adding: content/dense_retriever_finetuned/config.json (deflated 47%)
  adding: content/dense_retriever_finetuned/model.safetensors (deflated 8%)
  adding: content/dense_retriever_finetuned/tokenizer.json (deflated 71%)
  adding: content/dense_retriever_finetuned/2_Normalize/ (stored 0%)
  adding: content/dense_retriever_finetuned/modules.json (deflated 62%)
  adding: content/dense_retriever_finetuned/special_tokens_map.json (deflated 80%)
  adding: content/dense_retriever_finetuned/vocab.txt (deflated 53%)
  adding: content/dense_retriever_finetuned/1_Pooling/ (stored 0%)
  adding: content/dense_retriever_finetuned/1_Pooling/config.json (deflated 59%)
  adding: content/dense
