<a href="https://colab.research.google.com/github/navneetkrc/llm-rag-with-reranker-demo/blob/main/ANCE_Query_Reformulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# samsung_ance_colab.ipynb (Python script version for Colab)

# =========================================
# 📦 1. Install Dependencies
# =========================================
!pip install sentence-transformers faiss-cpu pandas datasets

In [None]:
# =========================================
# 🧪 2. Generate Synthetic Samsung E-commerce Data
# =========================================
# !pip install -q sentence-transformers faiss-cpu pandas datasets

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from datasets import Dataset
import faiss
import pandas as pd
import numpy as np
import os


os.makedirs("data", exist_ok=True)

queries = [
    {"query_id": i, "query": q} for i, q in enumerate([
        "galaxy s23 ultra case",
        "65 inch smart tv",
        "wireless earbuds for running",
        "black french door refrigerator",
        "tablet with s pen",
        "bespoke washing machine",
        "ultra hd monitor",
        "gaming laptop with rtx",
        "curved screen monitor",
        "portable air conditioner"
    ])
]

products = [
    {"product_id": i, "title": t, "description": d}
    for i, (t, d) in enumerate([
        ("Galaxy S23 Ultra Protective Case", "Durable case for S23 Ultra, shockproof and slim design."),
        ("Samsung 65-Inch QLED Smart TV", "4K UHD, HDR, Smart Hub, Voice Assistant Enabled."),
        ("Galaxy Buds2 Pro", "Noise canceling wireless earbuds with long battery life."),
        ("Bespoke Black French Door Refrigerator", "Customizable design, energy efficient, large capacity."),
        ("Galaxy Tab S8 with S Pen", "Powerful tablet with stylus, ideal for creators."),
        ("Bespoke Washing Machine AI Control", "Smart washing with AI wash cycles and eco-bubble technology."),
        ("Samsung 32 inch UHD Monitor", "Ultra HD display with vibrant color and fast refresh rate."),
        ("Galaxy Book Pro Gaming Laptop", "Intel i7, RTX 3060, AMOLED display."),
        ("Samsung Curved Monitor 27-Inch", "Immersive curve, eye comfort, full HD."),
        ("Samsung Portable AC 1.5 Ton", "Powerful portable cooling with inverter tech.")
    ])
]

co_purchases = [
    {"query_id": i, "product_id": i} for i in range(len(queries))
]

pd.DataFrame(queries).to_csv("data/queries.csv", index=False)
pd.DataFrame(products).to_csv("data/products.csv", index=False)
pd.DataFrame(co_purchases).to_csv("data/co_purchase_log.csv", index=False)

print("✅ Synthetic data generated!")

✅ Synthetic data generated!


In [None]:
# =========================================
# 🧠 3. ANCE Mining and Bi-Encoder Training
# =========================================
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import faiss
import numpy as np
from datasets import Dataset


queries = pd.read_csv("data/queries.csv")
products = pd.read_csv("data/products.csv")
co_purchases = pd.read_csv("data/co_purchase_log.csv")

products["text"] = products["title"] + " " + products["description"]
product_texts = products["text"].tolist()

biencoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
product_embeddings = biencoder.encode(product_texts, show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(product_embeddings)
index = faiss.IndexFlatIP(product_embeddings.shape[1])
index.add(product_embeddings)

triplets = []
top_k = 20
query_embeddings = biencoder.encode(queries["query"].tolist(), show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(query_embeddings)
D, I = index.search(query_embeddings, top_k)

for i, (scores, indices) in enumerate(zip(D, I)):
    qid = queries.iloc[i]['query_id']
    qtxt = queries.iloc[i]['query']
    positive_ids = co_purchases[co_purchases['query_id'] == qid]['product_id'].tolist()
    if not positive_ids:
        continue
    positive_text = products[products['product_id'] == positive_ids[0]]['text'].values[0]
    for pid in indices:
        neg_id = products.iloc[pid]['product_id']
        if neg_id not in positive_ids:
            negative_text = products.iloc[pid]['text']
            triplets.append(InputExample(texts=[qtxt, positive_text, negative_text]))
            break

print(f"✅ {len(triplets)} training triplets ready!")

train_dataloader = DataLoader(triplets, shuffle=True, batch_size=4)
train_loss = losses.TripletLoss(model=biencoder)

biencoder.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=1,
              warmup_steps=10,
              show_progress_bar=True)

biencoder.save("models/biencoder")
print("✅ Bi-encoder model saved to /models/biencoder")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

✅ 10 training triplets ready!


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnavneetkrch[0m ([33mnavneetkrch-samsung-research[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


✅ Bi-encoder model saved to /models/biencoder


In [None]:
# =========================================
# 📌 4. Test Retrieval with Trained Bi-Encoder
# =========================================
def test_query_retrieval(user_query, top_k=3):
    q_embed = biencoder.encode([user_query], convert_to_numpy=True)
    faiss.normalize_L2(q_embed)
    D, I = index.search(q_embed, top_k)
    print(f"\n🔍 Results for: '{user_query}'")
    for rank, idx in enumerate(I[0]):
        p = products.iloc[idx]
        print(f"[{rank+1}] {p['title']} - {p['description']}")

# Run test
sample_query = "best smart tv with voice assistant"
test_query_retrieval(sample_query)


🔍 Results for: 'best smart tv with voice assistant'
[1] Samsung 65-Inch QLED Smart TV - 4K UHD, HDR, Smart Hub, Voice Assistant Enabled.
[2] Samsung 32 inch UHD Monitor - Ultra HD display with vibrant color and fast refresh rate.
[3] Galaxy Book Pro Gaming Laptop - Intel i7, RTX 3060, AMOLED display.


In [None]:
# =========================================
# 🧠 5. Fine-tune Cross-Encoder on ANCE Mined Data
# =========================================
from sentence_transformers import CrossEncoder

pairwise_data = [
    InputExample(texts=[triplet.texts[0], triplet.texts[1]], label=1.0) for triplet in triplets
] + [
    InputExample(texts=[triplet.texts[0], triplet.texts[2]], label=0.0) for triplet in triplets
]

train_dataloader = DataLoader(pairwise_data, shuffle=True, batch_size=4)

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", num_labels=1)

cross_encoder.fit(
    train_dataloader=train_dataloader,
    epochs=1,
    warmup_steps=10,
    show_progress_bar=True
)

cross_encoder.save("models/cross_encoder")
print("✅ Cross-encoder model saved to /models/cross_encoder")

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/5 [00:00<?, ?it/s]

✅ Cross-encoder model saved to /models/cross_encoder


In [None]:
# =========================================
# 🔍 6. Compare Bi-Encoder vs Cross-Encoder Ranking (w/ MRR@3)
# =========================================
def evaluate_ranking(top_k=3):
    mrr_bi, mrr_cross = 0, 0

    for _, row in queries.iterrows():
        query = row['query']
        qid = row['query_id']
        gt_pid = co_purchases[co_purchases['query_id'] == qid]['product_id'].values[0]

        # Bi-encoder Retrieval
        q_embed = biencoder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_embed)
        D, I = index.search(q_embed, top_k)
        retrieved_pids = [products.iloc[i]['product_id'] for i in I[0]]

        if gt_pid in retrieved_pids:
            rank = retrieved_pids.index(gt_pid) + 1
            mrr_bi += 1.0 / rank

        # Cross-encoder Re-ranking
        candidate_texts = [products.iloc[i]['text'] for i in I[0]]
        input_pairs = [[query, p] for p in candidate_texts]
        scores = cross_encoder.predict(input_pairs)
        reranked = sorted(zip(I[0], scores), key=lambda x: x[1], reverse=True)
        reranked_pids = [products.iloc[i]['product_id'] for i, _ in reranked]

        if gt_pid in reranked_pids:
            rank = reranked_pids.index(gt_pid) + 1
            mrr_cross += 1.0 / rank

    mrr_bi /= len(queries)
    mrr_cross /= len(queries)

    print(f"\n📊 MRR@{top_k} Bi-Encoder: {mrr_bi:.3f}")
    print(f"📊 MRR@{top_k} Cross-Encoder: {mrr_cross:.3f}")

evaluate_ranking(top_k=3)


📊 MRR@3 Bi-Encoder: 1.000
📊 MRR@3 Cross-Encoder: 1.000


##on AMazon data

In [None]:
!pip install --upgrade gcsfs fsspec

In [None]:
!pip install -q sentence-transformers faiss-cpu pandas datasets

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from datasets import Dataset
import faiss
import pandas as pd
import numpy as np
import os


In [None]:
# Bi-Encoder and Cross-Encoder Evaluation on Amazon ESCI Dataset

# ✅ 1. Load Dataset (Amazon ESCI from HuggingFace)
from datasets import load_dataset

# data = load_dataset("milistu/amazon-esci-data", "products")
# data = load_dataset("milistu/amazon-esci-data", "queries")
# print("Loaded the 'queries' configuration.")
# print(data)

In [None]:
from datasets import load_dataset
data_queries = load_dataset("milistu/amazon-esci-data", name="queries")
print("Loaded the 'queries' configuration.")
print(data_queries)

README.md:   0%|          | 0.00/5.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/47.9M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/15.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1983272 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/638016 [00:00<?, ? examples/s]

Loaded the 'queries' configuration.
DatasetDict({
    train: Dataset({
        features: ['example_id', 'query', 'query_id', 'product_id', 'product_locale', 'esci_label', 'small_version', 'large_version', 'split', '__index_level_0__'],
        num_rows: 1983272
    })
    test: Dataset({
        features: ['example_id', 'query', 'query_id', 'product_id', 'product_locale', 'esci_label', 'small_version', 'large_version', 'split', '__index_level_0__'],
        num_rows: 638016
    })
})


In [None]:
from datasets import load_dataset
data_sources = load_dataset("milistu/amazon-esci-data", name="sources")
print("Loaded the 'sources' configuration.")
print(data_sources)

train-00000-of-00001.parquet:   0%|          | 0.00/1.15M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/360k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/99683 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/30969 [00:00<?, ? examples/s]

Loaded the 'sources' configuration.
DatasetDict({
    train: Dataset({
        features: ['query_id', 'source', 'split', '__index_level_0__'],
        num_rows: 99683
    })
    test: Dataset({
        features: ['query_id', 'source', 'split', '__index_level_0__'],
        num_rows: 30969
    })
})


In [None]:
from datasets import load_dataset
data_products = load_dataset("milistu/amazon-esci-data", name="products")
print("Loaded the 'products' configuration.")
print(data_products)

train-00000-of-00004.parquet:   0%|          | 0.00/222M [00:00<?, ?B/s]

train-00001-of-00004.parquet:   0%|          | 0.00/220M [00:00<?, ?B/s]

train-00002-of-00004.parquet:   0%|          | 0.00/203M [00:00<?, ?B/s]

train-00003-of-00004.parquet:   0%|          | 0.00/218M [00:00<?, ?B/s]

test-00000-of-00002.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

test-00001-of-00002.parquet:   0%|          | 0.00/139M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1371823 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/443101 [00:00<?, ? examples/s]

Loaded the 'products' configuration.
DatasetDict({
    train: Dataset({
        features: ['product_id', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_locale', 'split', '__index_level_0__'],
        num_rows: 1371823
    })
    test: Dataset({
        features: ['product_id', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_locale', 'split', '__index_level_0__'],
        num_rows: 443101
    })
})


In [None]:
## Full dataset takes a lot of time
# from datasets import load_dataset
# import pandas as pd

# # Load both configs
# queries = load_dataset("milistu/amazon-esci-data", "queries")["train"]
# products = load_dataset("milistu/amazon-esci-data", "products")["train"]

# # Convert to pandas for easier merging
# queries_df = queries.to_pandas()
# products_df = products.to_pandas()

# # Merge on product_id
# merged_df = pd.merge(queries_df, products_df, on=["product_id", "product_locale"])

# # Filter relevant ESCI samples (label 'E' for exact match)
# exact_matches = merged_df[merged_df["esci_label"] == "E"]

# # Create triplets for training
# import random
# from sentence_transformers import InputExample

# def generate_triplets(df):
#     triplets = []
#     for _, row in df.iterrows():
#         query = row["query"]
#         pos = row["product_title"]
#         # Sample hard negatives (same query, different product)
#         negs = df[df["query"] == query]["product_title"].tolist()
#         negs = [n for n in negs if n != pos]
#         if negs:
#             neg = random.choice(negs)
#             triplets.append(InputExample(texts=[query, pos, neg]))
#     return triplets

# train_triplets = generate_triplets(exact_matches)
# print(f"✅ {len(train_triplets)} training triplets ready!")


In [9]:
from datasets import load_dataset
import pandas as pd
import random
from sentence_transformers import InputExample

# Load both configs
queries = load_dataset("milistu/amazon-esci-data", "queries")["train"]
products = load_dataset("milistu/amazon-esci-data", "products")["train"]

# Convert to pandas
queries_df = queries.to_pandas()
products_df = products.to_pandas()

# Merge on product_id + locale
merged_df = pd.merge(queries_df, products_df, on=["product_id", "product_locale"])

# Filter to exact matches
exact_matches = merged_df[merged_df["esci_label"] == "E"]

# ✅ Take only 10,000 samples for faster processing
exact_matches_sampled = exact_matches.sample(n=10_000, random_state=42)

# Triplet generation
def generate_triplets(df):
    triplets = []
    grouped = df.groupby("query")

    for query, group in grouped:
        titles = group["product_title"].tolist()
        if len(titles) < 2:
            continue  # skip if not enough for pos/neg
        for i in range(len(titles) - 1):
            pos = titles[i]
            neg_candidates = [t for t in titles if t != pos]
            if neg_candidates:
                neg = random.choice(neg_candidates)
                triplets.append(InputExample(texts=[query, pos, neg]))
    return triplets

train_triplets = generate_triplets(exact_matches_sampled)
print(f"✅ {len(train_triplets)} training triplets ready (from 10K sample)!")


✅ 587 training triplets ready (from 10K sample)!


In [None]:
train_triplets

In [None]:
# # ✅ 2. Prepare Triplets for Bi-Encoder Training
# from sentence_transformers import InputExample
# from torch.utils.data import DataLoader
# import random

# def generate_triplets(data_split):
#     triplets = []
#     for row in data_split:
#         if row['esci_label'] == 'E':
#             query = row['query']
#             pos = row['product_title']
#             # Sample hard negatives
#             negs = [r['product_title'] for r in random.sample(list(data_split), 10)
#                     if r['esci_label'] == 'I']
#             if negs:
#                 triplets.append(InputExample(texts=[query, pos, negs[0]]))
#     return triplets

# train_triplets = generate_triplets(data['train'])
# train_dataloader = DataLoader(train_triplets, shuffle=True, batch_size=16)

In [None]:
# from datasets import load_dataset
# from torch.utils.data import DataLoader # Assuming you need this later

# # Load the 'sources' configuration which contains query-product pairs and labels
# data = load_dataset("milistu/amazon-esci-data", name="sources")
# print("Loaded the 'sources' configuration.")
# print("Dataset structure:", data)
# print("\nFeatures in the 'train' split:", data['train'].features)
# # Print the first row to see the actual column names
# if len(data['train']) > 0:
#     print("\nFirst row example:", data['train'][0])

In [None]:
# # Your generate_triplets function (keep as is for now, but see note below)
# def generate_triplets(data_split):
#     # NOTE: This logic might need revision depending on your goal
#     # It currently doesn't guarantee finding a true negative for the *same* query
#     triplets = []
#     last_query = None
#     last_pos = None
#     for row in data_split:
#         # Check if the required keys exist before accessing
#         if 'esci_label' not in row or 'query' not in row or 'product_title' not in row:
#             print(f"Skipping row due to missing keys: {row}")
#             continue

#         if row['esci_label'] == 'E': # Assuming 'E' means Exact/Positive
#             last_query = row['query']
#             last_pos = row['product_title']
#         # WARNING: This 'else' assumes any non-'E' is a negative for the *last seen* positive.
#         # This is likely incorrect. A better approach groups by query.
#         elif last_query is not None and row['query'] == last_query: # Simple check if it's the same query
#              # You might want to check for 'I' (Irrelevant) specifically
#              # if row['esci_label'] == 'I':
#              neg = row['product_title']
#              triplets.append({'query': last_query, 'positive': last_pos, 'negative': neg})
#              # Reset last_query to avoid generating multiple negatives for one positive using this simple logic
#              last_query = None
#              last_pos = None
#     return triplets

# # --- Make sure you loaded 'sources' before this ---
# # data = load_dataset("milistu/amazon-esci-data", name="sources")

# if 'train' in data:
#      train_triplets = generate_triplets(data['train'])
#      if train_triplets: # Check if any triplets were generated
#          train_dataloader = DataLoader(train_triplets, shuffle=True, batch_size=16)
#          print(f"\nGenerated {len(train_triplets)} triplets.")
#          # print("First few triplets:", train_triplets[:3]) # Optional: print some examples
#      else:
#          print("\nNo triplets were generated. Check the logic in generate_triplets and the data.")
# else:
#     print("\n'train' split not found in the loaded data.")

In [13]:
# ✅ 3. Train Bi-Encoder

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses

# Initialize your bi-encoder model
biencoder = SentenceTransformer("all-MiniLM-L6-v2")

# Create a DataLoader
train_dataloader = DataLoader(train_triplets, shuffle=True, batch_size=16)

# Define a Triplet loss
train_loss = losses.TripletLoss(model=biencoder)

# Train the model (for 1 epoch as a test)
biencoder.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=1,
              warmup_steps=10,
              show_progress_bar=True)


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnavneetkrch[0m ([33mnavneetkrch-samsung-research[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


In [17]:
# ✅ 4. Cross-Encoder Fine-tuning
import torch
from torch.utils.data import DataLoader, Dataset
from sentence_transformers import CrossEncoder
from sentence_transformers import InputExample

# Create a PyTorch Dataset
class CrossEncoderDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Subset for training
subset_df = merged_df[merged_df["esci_label"].isin(["E", "I"])].sample(n=10000, random_state=42)



# Build training examples using InputExample
crossencoder_data = [
    InputExample(texts=[row["query"], row["product_title"]], label=1.0 if row["esci_label"] == "E" else 0.0)
    for _, row in subset_df.iterrows()
]

class CrossEncoderDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = CrossEncoderDataset(crossencoder_data)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

crossencoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', num_labels=1)

crossencoder.fit(
    train_dataloader=train_dataloader,
    epochs=1,
    warmup_steps=10,
    show_progress_bar=True
)




Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/625 [00:00<?, ?it/s]

In [19]:
# ✅ 5. Evaluation: MRR@3
from datasets import load_dataset
import pandas as pd

# Load the test splits of queries and products
queries_test = load_dataset("milistu/amazon-esci-data", "queries")["test"]
products_test = load_dataset("milistu/amazon-esci-data", "products")["test"]

# Convert to pandas
queries_df = queries_test.to_pandas()
products_df = products_test.to_pandas()

# Merge on product_id and locale
merged_test_df = pd.merge(queries_df, products_df, on=["product_id", "product_locale"])

# Optionally filter to only certain locales if needed (e.g., 'us')
# merged_test_df = merged_test_df[merged_test_df['product_locale'] == 'us']

# Optional: For faster testing, sample only 500-1000 query-product examples
test_sample = merged_test_df.sample(n=1000, random_state=42)


In [21]:
from sklearn.metrics.pairwise import cosine_similarity

def evaluate_mrr(test_data, biencoder, crossencoder, k=3):
    mrr_scores = []
    queries = defaultdict(list)

    for row in test_data:
        queries[row['query']].append((row['product_title'], 1 if row['esci_label'] == 'E' else 0))

    for query, products in queries.items():
        texts = [p[0] for p in products]
        labels = [p[1] for p in products]

        # Encode query and product titles separately
        query_emb = biencoder.encode(query, convert_to_tensor=True)
        text_embs = biencoder.encode(texts, convert_to_tensor=True)

        bi_scores = cosine_similarity([query_emb.cpu().numpy()], text_embs.cpu().numpy())[0]

        # Cross-encoder predicts directly
        cross_scores = crossencoder.predict(list(zip([query]*len(texts), texts)))

        # Sort by scores (you can switch between bi_scores and cross_scores to compare)
        ranked = sorted(zip(cross_scores, labels), key=lambda x: x[0], reverse=True)

        for rank, (_, label) in enumerate(ranked[:k]):
            if label == 1:
                mrr_scores.append(1 / (rank + 1))
                break
        else:
            mrr_scores.append(0)

    return np.mean(mrr_scores)


In [22]:
mrr_score = evaluate_mrr(test_sample.to_dict("records"), biencoder, crossencoder)
print("✅ MRR@3:", mrr_score)


✅ MRR@3: 0.6127551020408163


In [23]:
# ✅ 6. Save Models
biencoder.save("biencoder-amazon-esci")
crossencoder.save("crossencoder-amazon-esci")

In [35]:
!pip install -q streamlit pyngrok sentence-transformers

In [36]:
%%writefile app.py
import streamlit as st
from sentence_transformers import SentenceTransformer, CrossEncoder
import pandas as pd
import numpy as np

# Load pretrained models
bi_model = SentenceTransformer("all-MiniLM-L6-v2")
cross_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Sample queries and product titles
sample_data = pd.DataFrame({
    "query": [
        "wireless bluetooth earbuds",
        "gaming laptop",
        "smartwatch with heart rate monitor"
    ],
    "product_titles": [
        [
            "Bluetooth Earbuds Wireless with Charging Case",
            "Wired In-Ear Headphones",
            "Noise Cancelling Over-Ear Headphones"
        ],
        [
            "High Performance Gaming Laptop - RTX 3060",
            "Budget Laptop with Intel UHD Graphics",
            "Portable Chromebook for Students"
        ],
        [
            "Smartwatch with ECG and Heart Rate Sensor",
            "Digital Watch for Kids",
            "Fitness Tracker with Step Counter"
        ]
    ]
})

st.title("Bi-Encoder vs Cross-Encoder Ranking Comparison")

query_idx = st.selectbox("Choose a sample query:", sample_data.index)
query = sample_data.loc[query_idx, "query"]
candidates = sample_data.loc[query_idx, "product_titles"]

# Bi-encoder ranking
query_embedding = bi_model.encode(query, convert_to_tensor=True)
candidate_embeddings = bi_model.encode(candidates, convert_to_tensor=True)
scores_bi = np.dot(candidate_embeddings.cpu().numpy(), query_embedding.cpu().numpy())
ranked_bi = sorted(zip(candidates, scores_bi), key=lambda x: x[1], reverse=True)

# Cross-encoder ranking
cross_inputs = [[query, c] for c in candidates]
scores_cross = cross_model.predict(cross_inputs)
ranked_cross = sorted(zip(candidates, scores_cross), key=lambda x: x[1], reverse=True)

st.subheader("Query")
st.markdown(f"> {query}")

col1, col2 = st.columns(2)

with col1:
    st.markdown("### Bi-Encoder Ranking")
    for title, score in ranked_bi:
        st.markdown(f"- **{title}** — `{score:.4f}`")

with col2:
    st.markdown("### Cross-Encoder Ranking")
    for title, score in ranked_cross:
        st.markdown(f"- **{title}** — `{score:.4f}`")


Overwriting app.py


In [37]:
import os
from google.colab import userdata
from pyngrok import ngrok

# Retrieve the ngrok token from Google Colab secrets
ngrok_token = userdata.get('NGROK_AUTH_TOKEN')

# Set the ngrok authentication token
ngrok.set_auth_token(ngrok_token)

# Start ngrok tunnel (specify port 8501 for Streamlit)
public_url = ngrok.connect(8501, "http")  # Ensure ngrok is using HTTP and port 8501
print(f"🌍 Your Streamlit app is live at: {public_url}")

# Run Streamlit app
!streamlit run app.py &


🌍 Your Streamlit app is live at: NgrokTunnel: "https://a4c7-34-125-211-77.ngrok-free.app" -> "http://localhost:8501"

Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.125.211.77:8501[0m
[0m




[34m  Stopping...[0m


In [None]:
!pip install pyngrok
!pip install streamlit
!pip install sentence-transformers


In [46]:
from datasets import load_dataset
import pandas as pd

# Load the test splits of queries and products
queries_test = load_dataset("milistu/amazon-esci-data", "queries")["test"]
products_test = load_dataset("milistu/amazon-esci-data", "products")["test"]

# Convert to pandas
queries_df = queries_test.to_pandas()
products_df = products_test.to_pandas()

# Merge on product_id and locale
merged_test_df = pd.merge(queries_df, products_df, on=["product_id", "product_locale"])

# Sample 1000 query-product examples for testing
test_sample = merged_test_df.sample(n=1000, random_state=42)


In [84]:
%%writefile app.py
import streamlit as st
from sentence_transformers import SentenceTransformer, CrossEncoder
import pandas as pd
import numpy as np
from datasets import load_dataset

@st.cache_data
def load_data():
    # Load full test splits
    queries = load_dataset("milistu/amazon-esci-data", "queries")["test"].to_pandas()
    products = load_dataset("milistu/amazon-esci-data", "products")["test"].to_pandas()

    # Merge full dataset
    merged = pd.merge(queries, products, on=["product_id", "product_locale"])

    # Group by query to get product lists
    grouped = (
        merged.groupby("query")["product_title"]
        .apply(lambda x: list(set(x)))  # deduplicate
        .reset_index()
    )

    # Keep only queries with at least 5 products
    grouped["product_count"] = grouped["product_title"].apply(len)
    filtered = grouped[grouped["product_count"] >= 5].reset_index(drop=True)

    # For performance, sample only 200 queries
    return filtered.sample(n=200, random_state=42) if len(filtered) > 200 else filtered

@st.cache_resource
def load_models():
    bi = SentenceTransformer("all-MiniLM-L6-v2")
    cross = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return bi, cross

# Load everything
valid_queries = load_data()
bi_model, cross_model = load_models()

st.title("🔍 Bi-Encoder vs Cross-Encoder Ranking")

if not valid_queries.empty:
    query_text = st.selectbox("Choose a sample query:", valid_queries["query"].tolist())
    query_row = valid_queries[valid_queries["query"] == query_text]

    if not query_row.empty:
        candidates = query_row.iloc[0]["product_title"][:10]  # get top 10 candidates

        # Bi-encoder ranking
        query_embedding = bi_model.encode(query_text, convert_to_tensor=True)
        candidate_embeddings = bi_model.encode(candidates, convert_to_tensor=True)
        scores_bi = np.dot(candidate_embeddings.cpu().numpy(), query_embedding.cpu().numpy())
        ranked_bi = sorted(zip(candidates, scores_bi), key=lambda x: x[1], reverse=True)[:5]

        # Cross-encoder ranking
        cross_inputs = [[query_text, c] for c in candidates]
        scores_cross = cross_model.predict(cross_inputs)
        ranked_cross = sorted(zip(candidates, scores_cross), key=lambda x: x[1], reverse=True)[:5]

        st.subheader("Query")
        st.markdown(f"> **{query_text}**")

        col1, col2 = st.columns(2)

        with col1:
            st.markdown("### 🔹 Bi-Encoder Top 5")
            for title, score in ranked_bi:
                st.markdown(f"- **{title}** — `{score:.4f}`")

        with col2:
            st.markdown("### 🔸 Cross-Encoder Top 5")
            for title, score in ranked_cross:
                st.markdown(f"- **{title}** — `{score:.4f}`")
    else:
        st.error("No products found for the selected query.")
else:
    st.error("No valid queries with enough product titles were found.")


Overwriting app.py


In [87]:
from pyngrok import ngrok

# Set up ngrok with the required auth token if needed
# Uncomment the next line if you haven't set up ngrok authentication yet
# ngrok.set_auth_token("your_ngrok_token")

# Open a tunnel on port 8501 (default Streamlit port)
public_url = ngrok.connect(8501)
print(f"🌍 Your Streamlit app is live at: {public_url}")


🌍 Your Streamlit app is live at: NgrokTunnel: "https://f1b0-34-125-211-77.ngrok-free.app" -> "http://localhost:8501"


In [88]:
!streamlit run app.py &



Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.125.211.77:8501[0m
[0m
2025-04-16 22:01:22.520528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744840882.778485   30967 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744840882.848665   30967 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-16 22:01:48.334 Examining the path of torch.classes raised:
Traceback (most recent call la