In [None]:
import pyterrier as pt
import pyterrier_alpha as pta
import tiktoken
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
from pyterrier.measures import *
from pyterrier_alpha import RRFusion

dataset = pt.datasets.get_dataset("msmarco_passage")
print(dataset.get_corpus())

# iter_indexer = pt.IterDictIndexer("./wordindex", meta={'docno': 20, 'text': 4096})
# indexref = iter_indexer.index(dataset.get_corpus_iter())

INDEX_DIR = '/Users/manitk/Desktop/GIR/Pyterrier/combined/wordindex/data.properties'
indexref = pt.IndexFactory.of(INDEX_DIR)
print(indexref.getCollectionStatistics())

In [None]:
encoder = "sentence-transformers/msmarco-distilbert-base-v4"
tokenizer = AutoTokenizer.from_pretrained(encoder)
model = AutoModelForMaskedLM.from_pretrained(encoder)
model.eval()

In [None]:
def expand_query_nearest_neighbor(query, top_k=3):
    tokens = query.split()
    expanded_queries = []

    # looping over every token in the query
    for i, word in enumerate(tokens):
        masked = tokens.copy()
        masked[i] = "[MASK]"
        masked_query = " ".join(masked)
        
        # tokenising and passing through the model
        inputs = tokenizer(masked_query, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # finding the mask token index and getting logits
        mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1].item()
        logits = outputs.logits[0, mask_index]
        
        # top-k predicted tokens
        topk_indices = torch.topk(logits, top_k).indices
        predicted_tokens = tokenizer.convert_ids_to_tokens(topk_indices)
        
        # replacing the mask with each predicted token and generate new queries
        for token in predicted_tokens:
            new_tokens = masked.copy()
            new_tokens[i] = token
            expanded_queries.append(" ".join(new_tokens))
    
    return [query] + expanded_queries

def nearest_neighbor_expander(row, top_k=3):
    qid = row['qid']
    query = row['query']
    expansions = expand_query_nearest_neighbor(query, top_k=top_k)

    print(f"Original query: {query}")
    print(f"Expanded queries: {expansions}")
    return pd.DataFrame([{"qid": qid, "query": exp} for exp in expansions])

def build_expanded_topics(topics_df, top_k=3, include_original=True):
    new_rows = []
    for _, row in topics_df.iterrows():
        qid = row['qid']
        query = row['query']

        # Include original query
        if include_original:
            new_rows.append({"qid": qid, "query": query})

        # Add k expanded queries
        expansions = expand_query_nearest_neighbor(query, top_k=top_k)
        for exp in expansions:
            new_rows.append({"qid": qid, "query": exp})

    return pd.DataFrame(new_rows)


topics = dataset.get_topics('test-2020')
qrels = dataset.get_qrels('test-2020')
expanded_topics = build_expanded_topics(topics, top_k=3)

# expander = pt.apply.query(lambda topics: nearest_neighbor_expander(topics, top_k=3))
# bm25 = pt.terrier.Retriever(indexref, wmodel="BM25")
# expanded_bm25 = expander >> bm25

In [51]:
for _, row in topics.iterrows():
    print(f"Query ID: {row['qid']}, Original Query: {row['query']}")


Query ID: 1030303, Original Query: who is aziz hashim
Query ID: 1037496, Original Query: who is rep scalise
Query ID: 1043135, Original Query: who killed nicholas ii of russia
Query ID: 1045109, Original Query: who owns barnhart crane
Query ID: 1049519, Original Query: who said no one can make you feel inferior
Query ID: 1051399, Original Query: who sings monk theme song
Query ID: 1056416, Original Query: who was the highest career passer rating in the nfl
Query ID: 1064670, Original Query: why do hunters pattern their shotguns
Query ID: 1065636, Original Query: why do some places on my scalp feel sore
Query ID: 1071750, Original Query: why is pete rose banned from hall of fame
Query ID: 1103153, Original Query: who is thomas m cooley
Query ID: 1103791, Original Query: definition of endorsing
Query ID: 1104501, Original Query: which hormone increases calcium levels in the blood
Query ID: 1105792, Original Query: define geon
Query ID: 1105860, Original Query: where can the amazon rainfo

In [52]:
for _, row in expanded_topics.iterrows():
    print(f"Query ID: {row['qid']}, Expanded Query: {row['query']}")

Query ID: 1030303, Expanded Query: who is aziz hashim
Query ID: 1030303, Expanded Query: who is aziz hashim
Query ID: 1030303, Expanded Query: ul is aziz hashim
Query ID: 1030303, Expanded Query: mark is aziz hashim
Query ID: 1030303, Expanded Query: italics is aziz hashim
Query ID: 1030303, Expanded Query: who ul aziz hashim
Query ID: 1030303, Expanded Query: who sloppy aziz hashim
Query ID: 1030303, Expanded Query: who rail aziz hashim
Query ID: 1030303, Expanded Query: who is lase hashim
Query ID: 1030303, Expanded Query: who is ul hashim
Query ID: 1030303, Expanded Query: who is lius hashim
Query ID: 1030303, Expanded Query: who is aziz lle
Query ID: 1030303, Expanded Query: who is aziz westwood
Query ID: 1030303, Expanded Query: who is aziz names
Query ID: 1037496, Expanded Query: who is rep scalise
Query ID: 1037496, Expanded Query: who is rep scalise
Query ID: 1037496, Expanded Query: ores is rep scalise
Query ID: 1037496, Expanded Query: sov is rep scalise
Query ID: 1037496, Ex

In [None]:
bm25 = pt.terrier.Retriever(indexref, wmodel="BM25")

expanded_topics["query"] = expanded_topics["query"].str.replace("##", "", regex=False)
# Could not parse query qid 1132943 'how long do ##tag cook artichokes for' -- Lexical error at line 1, column 14.  Encountered: "#" (35), after : "" org.terrier.querying.parser.QueryParserException

results_expanded = bm25.transform(expanded_topics)
results_baseline = bm25.transform(topics)

In [40]:
rrf = pta.fusion.rr_fusion(results_baseline, results_expanded, k=60, num_results=1000)

In [47]:
results = pt.Experiment(
    [bm25, rrf],
    topics,
    qrels,
    eval_metrics=[nDCG@10, AP(rel=2), AP(rel=3), 'map'],
    names=["bm25", "rrf"]
)

print(results[["nDCG@10", "AP(rel=2)", "AP(rel=3)", "map"]])

    nDCG@10  AP(rel=2)  AP(rel=3)       map
0  0.493627   0.292988   0.287098  0.358724
1  0.432854   0.192317   0.148353  0.208707


In [49]:
tfidf = pt.terrier.Retriever(indexref, wmodel="TF_IDF")
rm3 = tfidf >> pt.rewrite.RM3(indexref) >> tfidf

results_expanded = tfidf.transform(expanded_topics)
results_baseline = tfidf.transform(topics)

rrf = pta.fusion.rr_fusion(results_baseline, results_expanded, k=60, num_results=1000)
results = pt.Experiment(
    [tfidf, rrf],
    topics,
    qrels,
    eval_metrics=[nDCG@10, AP(rel=2), AP(rel=3), 'map'],
    names=["tfidf", "rrf"]
)
print(results[["nDCG@10", "AP(rel=2)", "AP(rel=3)", "map"]])

results_expanded = rm3.transform(expanded_topics)
results_baseline = rm3.transform(topics)

rrf = pta.fusion.rr_fusion(results_baseline, results_expanded, k=60, num_results=1000)  
results = pt.Experiment(
    [rm3, rrf],
    topics,
    qrels,
    eval_metrics=[nDCG@10, AP(rel=2), AP(rel=3), 'map'],
    names=["rm3", "rrf"]
)
print(results[["nDCG@10", "AP(rel=2)", "AP(rel=3)", "map"]])

    nDCG@10  AP(rel=2)  AP(rel=3)       map
0  0.492575   0.292548   0.285249  0.358072
1  0.435037   0.191622   0.150664  0.208082
    nDCG@10  AP(rel=2)  AP(rel=3)       map
0  0.509225   0.316460   0.305664  0.400533
1  0.509987   0.311512   0.294901  0.392797
