In [1]:
import re
import string
import time

import numpy as np
import pandas as pd
import pyterrier as pt
import torch
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from pyterrier.measures import *

In [2]:
# Locations of the index and of the training topics / relevance judgements.
BASE_IDX = "indexes/stopwords_removed"
QUERIES = "data/train_queries.csv"
QRELS   = "data/train_qrels.csv"

# Both files are tab-separated; the header row of the queries file is
# replaced by the column names qid / query.
qrels = pd.read_csv(QRELS, sep="\t")
qs = pd.read_csv(QUERIES, sep="\t", names=["qid", "query"], header=0)

# Swap every ASCII punctuation character in the query text for a space.
punctuation_pattern = rf"[{re.escape(string.punctuation)}]"
qs['query'] = qs['query'].str.replace(punctuation_pattern, " ", regex=True)

# Cast the qid column to str in both frames.
for frame in (qs, qrels):
    frame['qid'] = frame['qid'].astype(str)

# Shuffle and hold out 20% of the queries for validation (fixed seed).
train_qs, val_qs = train_test_split(qs, test_size=0.2, random_state=42, shuffle=True)

# Split the judgements along the same qid boundaries as the queries.
train_qids, val_qids = set(train_qs['qid']), set(val_qs['qid'])
train_qrels = qrels.loc[qrels['qid'].isin(train_qids)]
val_qrels = qrels.loc[qrels['qid'].isin(val_qids)]

In [3]:
# BM25 retriever over the index stored at ./indexes/stopwords_removed.
bm25_controls = {"bm25.k_1": 3.5, "bm25.b": 0.75}
stopwords_idx = pt.IndexFactory.of("./indexes/stopwords_removed")
bm25 = pt.terrier.Retriever(stopwords_idx, wmodel="BM25", controls=bm25_controls)

# DirichletLM retriever over the index stored at ./indexes/stopwords_and_stemming.
lmd_controls = {"dirichletlm.mu": 100}
stopwords_stemmed_idx = pt.IndexFactory.of("./indexes/stopwords_and_stemming")
lmd = pt.terrier.Retriever(stopwords_stemmed_idx, wmodel="DirichletLM", controls=lmd_controls)

Java started (triggered by IndexFactory.of) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]


In [107]:
from abc import abstractmethod

class LLMExpander(pt.transformer.Transformer):
    """Base transformer that expands each query with text generated by an LLM.

    Subclasses provide the prompt through ``build_prompt``. For every topic,
    ``transform`` generates text for that prompt, replaces punctuation in the
    generated text with spaces, and appends it to ``query_repeats`` copies of
    the original query.

    Args:
        model: Object exposing ``to(device)`` and ``generate(**inputs, max_new_tokens=...)``.
        tokenizer: Callable returning model inputs for a prompt, exposing ``decode``.
        device: Device string; when ``None``, "cuda" is used if available, else "cpu".
        max_new_tokens: Generation budget passed to ``model.generate``.
        query_repeats: How many copies of the original query precede the
            generated text in the expanded query.
    """

    # Compiled once instead of rebuilding the pattern for every query.
    _PUNCTUATION_RE = re.compile(rf"[{re.escape(string.punctuation)}]")

    def __init__(self, model, tokenizer, device: str | None = None,
                 max_new_tokens: int = 128, query_repeats: int = 5):
        self.tokenizer = tokenizer
        self.model = model
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model.to(self.device)
        self.max_new_tokens = max_new_tokens
        self.query_repeats = query_repeats

    @abstractmethod
    def build_prompt(self, query: str) -> str:
        """Return the prompt sent to the model for ``query``.

        Raises:
            NotImplementedError: Always, in the base class. Fail loudly here
                rather than handing ``None`` to the tokenizer.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement build_prompt()"
        )

    def transform(self, topics: pd.DataFrame) -> pd.DataFrame:
        """Expand every query in ``topics``.

        Args:
            topics: Frame with ``qid`` and ``query`` columns.

        Returns:
            A new frame with exactly the columns ``qid`` and ``query`` (also
            when ``topics`` is empty), where ``query`` is the expanded text.
        """
        expanded_records = []
        for qid, original_query in zip(topics["qid"], topics["query"]):
            prompt = self.build_prompt(original_query)
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            # NOTE(review): the whole generated sequence is decoded; if the
            # model echoes the prompt in its output, the prompt text ends up
            # in the expanded query as well -- confirm for the model in use.
            llm_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            llm_output = self._PUNCTUATION_RE.sub(" ", llm_output)
            # Repeating the original query keeps its terms dominant over the
            # generated expansion text.
            expanded_query = " ".join([original_query] * self.query_repeats + [llm_output])
            expanded_records.append({"qid": qid, "query": expanded_query})
        # Explicit columns keep the output schema stable for empty input.
        return pd.DataFrame(expanded_records, columns=["qid", "query"])
    

class Q2DZSExpander(LLMExpander):
    """Expander whose prompt asks the model for a passage answering the query."""

    def build_prompt(self, query: str) -> str:
        """Return the passage-writing prompt for ``query``."""
        instruction = "Write a passage that answers the following query: "
        return instruction + f"{query}"
    
class Q2EZSExpander(LLMExpander):
    """Expander whose prompt asks the model for a list of keywords."""

    def build_prompt(self, query: str) -> str:
        """Return the keyword-list prompt for ``query``."""
        instruction = "Write a list of keywords for the following query: "
        return instruction + f"{query}"
    
class Q2EFSExpander(LLMExpander):
    """Expander whose prompt shows worked query/keyword examples before the real query."""

    def build_prompt(self, query: str) -> str:
        """Return the example-driven keyword prompt for ``query``."""
        # (query, keywords) demonstrations, in the order they appear in the prompt.
        examples = (
            (
                "where was jaws filmed amity",
                "martha's vineyard filming locations movie set shooting site jaws location beach massachusetts island movie scene harbor ocean town",
            ),
            (
                "what is the mass of a beta",
                "beta particle electron mass subatomic particle physics neutron decay radiation energy charge",
            ),
            (
                "what is beef burgundy",
                "beef bourguignon recipe wine stew french dish ingredients cooking red wine braised meat",
            ),
            (
                "difference between affiliate and subsidiary",
                "business structure ownership control corporation company legal entity parent company partnership relationship",
            ),
            (
                "does candida cause anxiety",
                "candida overgrowth gut brain axis mental health yeast infection microbiome symptoms mood depression",
            ),
        )
        header = (
            "\n"
            "Write a list of additional keywords for a search query\n"
            "Write the list as one string with spaces between. Don't include duplicates of keywords. Don't include the query itself.\n"
            "\n"
            "Here are some examples:\n"
        )
        # Every demonstration is followed by one blank line.
        shots = "".join(
            f"Query: {example_query}\nKeywords: {example_keywords}\n\n"
            for example_query, example_keywords in examples
        )
        # The prompt ends with an open "Keywords: " line for the model to complete.
        return f"{header}{shots}Now it is your turn:\nQuery: {query}\nKeywords: \n"

class CoTExpander(LLMExpander):
    """Expander whose prompt asks for an answer preceded by its rationale."""

    def build_prompt(self, query: str) -> str:
        """Return the answer-with-rationale prompt for ``query``."""
        prompt_lines = (
            "",
            f"Answer the following query: {query}",
            "Give the rationale before answering",
            "",
        )
        return "\n".join(prompt_lines)

# Quick check: run the rationale-style expander on the first five validation queries.
expander = CoTExpander(model=model, tokenizer=tokenizer)
sample_topics = val_qs.iloc[:5]

expander.transform(sample_topics)

Unnamed: 0,qid,query
0,10270,why are noble gases unreactive or inert why a...
1,5180,what is raid bbu what is raid bbu what is raid...
2,8388,define crying define crying define crying d...
3,3625,what percent of salary is good to save for ret...
4,1707,salaries of texas appraiser salaries of texas ...


In [108]:
# One expander per prompting strategy, all sharing the same model and tokenizer.
q2dzs_expander, cot_expander, q2ezs_expander, q2efs_expander = [
    expander_cls(model, tokenizer)
    for expander_cls in (Q2DZSExpander, CoTExpander, Q2EZSExpander, Q2EFSExpander)
]

In [59]:
# Number of queries (taken from the top of `qs`) that each expander is timed on.
N_TIMING_QUERIES = 1000

expanders = {
    "Q2DZSExpander": q2dzs_expander,
    "Q2EZSExpander": q2ezs_expander,
    "Q2EFSExpander": q2efs_expander
}

timing_qs = qs[:N_TIMING_QUERIES]
mean_timings = []

for name, expander in tqdm(expanders.items(), desc="timing expanders"):
    times = []
    for _, row in timing_qs.iterrows():
        # Built outside the timed region so only the expansion is measured.
        single_df = pd.DataFrame([row])
        # perf_counter is the clock meant for measuring short intervals;
        # time.time() follows the system clock, which can be adjusted mid-run.
        start = time.perf_counter()
        expander.transform(single_df)
        elapsed = time.perf_counter() - start
        times.append(elapsed)
    # Avoid a ZeroDivisionError when there are no queries to time.
    mean_time = sum(times) / len(times) if times else float("nan")
    mean_timings.append({"expander": name, "mean_time_sec": mean_time})

timings_df = pd.DataFrame(mean_timings)
timings_df

  0%|          | 0/3 [00:00<?, ?it/s]

Unnamed: 0,expander,mean_time_sec
0,Q2DZSExpander,0.204803
1,Q2EZSExpander,0.088314
2,CustomLLMExpander,0.206493


In [110]:
# Plain BM25 against BM25 fed with the rationale-style expanded queries,
# evaluated on the validation split.
cot_pipeline = cot_expander >> bm25
pt.Experiment(
    [bm25, cot_pipeline],
    val_qs,
    val_qrels,
    names=["BM25", "BM25 + CoT"],
    eval_metrics=["recip_rank", "ndcg_cut.20", "P.20", "recall.20"],
    filter_by_qrels=True,
    verbose=True,
)

pt.Experiment: 100%|██████████| 2/2 [10:00<00:00, 300.20s/system]


Unnamed: 0,name,recip_rank,ndcg_cut.20,P.20,recall.20
0,BM25,0.488983,0.595117,0.047012,0.940248
1,BM25 + CoT,0.443124,0.545671,0.04425,0.885006


In [109]:
# Plain BM25 against BM25 fed with the two keyword-style expansions,
# evaluated on the validation split.
q2ezs_pipeline = q2ezs_expander >> bm25
q2efs_pipeline = q2efs_expander >> bm25
pt.Experiment(
    [bm25, q2ezs_pipeline, q2efs_pipeline],
    val_qs,
    val_qrels,
    names=["BM25", "BM25 + Q2E/ZS", "BM25 + Q2E/FS"],
    eval_metrics=["recip_rank", "ndcg_cut.20", "P.20", "recall.20"],
    filter_by_qrels=True,
    verbose=True,
)

pt.Experiment: 100%|██████████| 3/3 [04:02<00:00, 80.75s/system]


Unnamed: 0,name,recip_rank,ndcg_cut.20,P.20,recall.20
0,BM25,0.488983,0.595117,0.047012,0.940248
1,BM25 + Q2E/ZS,0.49161,0.59702,0.046956,0.939121
2,BM25 + Q2E/FS,0.489807,0.595482,0.046956,0.939121


In [111]:
# DirichletLM retrieval over the keyword-expanded queries, evaluated on the
# validation split.
lmd_q2ezs_pipeline = q2ezs_expander >> lmd
pt.Experiment(
    [lmd_q2ezs_pipeline],
    val_qs,
    val_qrels,
    names=["LMD + Q2E/ZS"],
    eval_metrics=["recip_rank", "ndcg_cut.20", "P.20", "recall.20"],
    filter_by_qrels=True,
    verbose=True,
)

pt.Experiment: 100%|██████████| 1/1 [01:54<00:00, 114.10s/system]


Unnamed: 0,name,recip_rank,ndcg_cut.20,P.20,recall.20
0,LMD + Q2E/ZS,0.458897,0.572721,0.047351,0.947012
