In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import pyterrier as pt
from tqdm import tqdm

model_name = 'all-MiniLM-L6-v2'
model = SentenceTransformer(model_name)
tokenizer = model.tokenizer

vocab = tokenizer.get_vocab()
token_list = list(vocab.keys())
print(vocab)



In [None]:
batch_size = 1000
token_embeddings = []

for i in tqdm(range(0, len(token_list), batch_size)):
    batch = token_list[i:i+batch_size]
    embeddings = model.encode(batch, show_progress_bar=True, convert_to_numpy=True)
    token_embeddings.extend(embeddings)

token_embeddings = np.array(token_embeddings)

n_neighbors = 50
nn = NearestNeighbors(n_neighbors=n_neighbors+1, metric='cosine', n_jobs=-1)
nn.fit(token_embeddings)

Batches: 100%|██████████| 32/32 [00:00<00:00, 70.75it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 75.63it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 103.68it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 94.29it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 93.75it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 80.71it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 82.78it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 105.61it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 79.33it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 110.18it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 67.95it/s] 
Batches: 100%|██████████| 32/32 [00:00<00:00, 101.75it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 99.87it/s] 
Batches: 100%|██████████| 32/32 [00:00<00:00, 83.59it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 101.24it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 80.38it/s]
Batches: 100%|██████████| 32/32 [00:00<00:00, 102.07it/s]
Batches: 100%|█████████

In [27]:
import json

distances, indices = nn.kneighbors(token_embeddings)
token_to_neighbors = {}

for i, idx_list in enumerate(indices):
    token = token_list[i]
    neighbors = [token_list[j] for j in idx_list if j != i][:n_neighbors]  # Exclude self, keep top 50
    token_to_neighbors[token] = neighbors

# Save to JSON
output_json_file = "token_nearest_neighbors.json"
with open(output_json_file, 'w', encoding='utf-8') as f:
    json.dump(token_to_neighbors, f, ensure_ascii=False, indent=2)

In [28]:
def expand_query_msmarco(query, nn_model, token_list, tokenizer, n_neighbors=3):
    tokens = tokenizer.tokenize(query.lower())
    expanded_tokens = []
    
    for token in tokens:
        try:
            token_idx = token_list.index(token)
            token_embedding = token_embeddings[token_idx].reshape(1, -1)
            
            _, indices = nn_model.kneighbors(token_embedding)
            
            neighbors = [token_list[i] for i in indices[0][1:n_neighbors+1]]
            
            expanded_tokens.append(token)
            expanded_tokens.extend(neighbors)
            
        except ValueError:
            expanded_tokens.append(token) # not in vocabulary - just keep original
    
    expanded_query = ' '.join(expanded_tokens).replace("##", "")
    return expanded_query

In [29]:
query_file = "/Users/manitk/Desktop/GIR/Pyterrier/combined/msmarco_passage_test2019_queries-sbert.tsv"
queries = pd.read_csv(query_file, sep='\t', names=["qid", "query"], dtype={"qid": str, "query": str})

tqdm.pandas()
expanded_queries = queries.copy()
expanded_queries['query'] = queries['query'].progress_apply(
    lambda x: expand_query_msmarco(x, nn, token_list, tokenizer)
)

expanded_query_file = "/Users/manitk/Desktop/GIR/Pyterrier/combined/msmarco_passage_test2019_queries-sbert-expanded-msmarco.tsv"
expanded_queries.to_csv(expanded_query_file, sep='\t', header=False, index=False)

100%|██████████| 201/201 [00:19<00:00, 10.15it/s]


In [30]:
import re
from tqdm import tqdm

def clean_text(text):
    """
    Comprehensive text cleaning for MS MARCO:
    1. Remove special characters
    2. Handle whitespace
    3. Clean WordPiece tokens
    4. Remove control tokens
    """
    if not isinstance(text, str):
        return ""
    
    # Remove special characters (keep alphanumeric, spaces, and basic punctuation)
    text = re.sub(r'[^\w\s.,?]', ' ', text)
    
    # Normalize whitespace
    text = ' '.join(text.split())
    
    # Lowercase
    text = text.lower()
    
    return text

def msmarco_tokenize_clean(text):
    """Safe tokenization with cleaning"""
    cleaned = clean_text(text)
    tokens = tokenizer.tokenize(cleaned)
    
    # Filter out special tokens and clean WordPiece tokens
    filtered_tokens = [
        token.replace("##", "") 
        for token in tokens
        if token not in ['[CLS]', '[SEP]', '[PAD]', '[UNK]', '[MASK]']
    ]
    
    return ' '.join(filtered_tokens)

In [32]:
from pyterrier.measures import *
import pandas as pd

expanded_query_file = "/Users/manitk/Desktop/GIR/Pyterrier/combined/msmarco_passage_test2019_queries-sbert-expanded-msmarco.tsv"
queries = pd.read_csv(expanded_query_file, sep='\t', names=["qid", "query"], dtype={"qid": str, "query": str})
queries['query'] = queries['query'].progress_apply(msmarco_tokenize_clean)

index_ref = pt.IndexFactory.of("./sbert_index")

bm25 = pt.terrier.Retriever(index_ref, wmodel="BM25", num_results=100)
tfidf = pt.terrier.Retriever(index_ref, wmodel="TF_IDF", num_results=100)
rm3 = tfidf >> pt.rewrite.RM3(index_ref, fb_terms=10, fb_docs=10) >> tfidf

results = pt.Experiment(
    [bm25, tfidf, rm3],
    queries,
    pt.datasets.get_dataset("msmarco_passage").get_qrels("test-2019"),
    eval_metrics=[nDCG@10, AP(rel=2), AP(rel=3), 'map'],
    names=["bm25","tfidf","rm3"]
)

print(results[["nDCG@10", "AP(rel=2)", "AP(rel=3)", "map"]])

100%|██████████| 201/201 [00:00<00:00, 9797.16it/s]


16:03:04.195 [main] WARN org.terrier.structures.BaseCompressingMetaIndex -- Structure meta reading data file directly from disk (SLOW) - try index.meta.data-source=fileinmem in the index properties file. 1.9 GiB of memory would be required.
    nDCG@10  AP(rel=2)  AP(rel=3)       map
0  0.258888   0.127955   0.068067  0.130872
1  0.268339   0.128001   0.072621  0.130665
2  0.268039   0.148951   0.079571  0.161135
