In [1]:
from pyserini.search.lucene import LuceneSearcher
from datasets import load_dataset
from tqdm import tqdm
from dotenv import load_dotenv

import os
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
import pickle
import copy
from utils.metrics import get_recall_at_100, get_nDCG_at_10

from main import get_query_expansion_dataset, run_search

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Configure the chat model handed to get_query_expansion_dataset in the cells below.
# The OpenAI API key is expected in a local .env file; override=True lets the
# .env value replace any OPENAI_API_KEY already present in the environment.
load_dotenv(override=True)
openai_api_key = os.environ.get("OPENAI_API_KEY")
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=openai_api_key,
)

In [3]:
# Fetch the dev split of the English MIRACL queries and flatten it, via pandas,
# into a list of per-row dicts -- the form every run_search /
# get_query_expansion_dataset call in this notebook receives.
dataset = (
    load_dataset("Cohere/miracl-en-queries-22-12", split="dev")
    .to_pandas()
    .to_dict(orient='records')
)

## Experiment 1: BM25 baseline

In [13]:
# Baseline: search the prebuilt MIRACL English index with the original,
# unexpanded queries and keep the two scores returned by run_search.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
recall_baseline, ndcg_baseline = run_search(searcher, dataset)

# Single print call; sep='\n' emits the two metric lines followed by the same
# trailing blank lines that a separate print("\n") would produce.
print(
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
    sep='\n',
)

Searching: 100%|██████████| 799/799 [01:24<00:00,  9.47it/s]

BM25 Recall@100: 0.8190
BM25 nDCG@10: 0.3506







## Experiment 2: Zero-shot prompting

In [5]:
# Build the query set with the 'zero-shot' option of get_query_expansion_dataset,
# then score it against the same prebuilt index used for the baseline.
# Tip: pass a slice such as dataset[:10] (together with `chat`) for a quick
# smoke test before paying for the full run.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_zero_shot = get_query_expansion_dataset(dataset, chat, option='zero-shot')
recall_zs, ndcg_zs = run_search(searcher, data_expanded_zero_shot)

# Single print call; sep='\n' keeps the output identical to three separate prints.
print(
    f'BM25 Recall@100: {recall_zs:.4f}',
    f'BM25 nDCG@10: {ndcg_zs:.4f}',
    "\n",
    sep='\n',
)

Expanding queries:   0%|          | 1/799 [00:10<2:20:22, 10.55s/it]


KeyboardInterrupt: 

## Experiment 3: One-shot prompting

In [16]:
# Build the query set with the 'one-shot' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_one_shot = get_query_expansion_dataset(dataset, chat, option='one-shot')
recall_os, ndcg_os = run_search(searcher, data_expanded_one_shot)

# Single print call; sep='\n' keeps the output identical to three separate prints.
print(
    f'BM25 Recall@100: {recall_os:.4f}',
    f'BM25 nDCG@10: {ndcg_os:.4f}',
    "\n",
    sep='\n',
)

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [02:15<00:00,  5.88it/s]

BM25 Recall@100: 0.7485
BM25 nDCG@10: 0.3198







## Experiment 4: Multi-shot prompting

In [17]:
# Build the query set with the 'multi-shot' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_multi_shot = get_query_expansion_dataset(dataset, chat, option='multi-shot')
recall_ms, ndcg_ms = run_search(searcher, data_expanded_multi_shot)

# Single print call; sep='\n' keeps the output identical to three separate prints.
print(
    f'BM25 Recall@100: {recall_ms:.4f}',
    f'BM25 nDCG@10: {ndcg_ms:.4f}',
    "\n",
    sep='\n',
)

Searching: 100%|██████████| 799/799 [02:03<00:00,  6.45it/s]

BM25 Recall@100: 0.7825
BM25 nDCG@10: 0.3486







## Experiment 5: Answer prompting

In [18]:
# Build the query set with the 'answer' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_answer = get_query_expansion_dataset(dataset, chat, option='answer')
recall_ans, ndcg_ans = run_search(searcher, data_expanded_answer)

# Single print call; sep='\n' keeps the output identical to three separate prints.
print(
    f'BM25 Recall@100: {recall_ans:.4f}',
    f'BM25 nDCG@10: {ndcg_ans:.4f}',
    "\n",
    sep='\n',
)

Searching: 100%|██████████| 799/799 [07:07<00:00,  1.87it/s]

BM25 Recall@100: 0.7361
BM25 nDCG@10: 0.3960







## Query2Doc Zero Shot Prompting

In [10]:
# Build the query set with the 'q2d-zs' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
# Results live in experiment-specific names (like recall_zs / recall_ans in the
# earlier experiments) so running this cell does not overwrite another
# experiment's dataset or scores.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_q2d_zs = get_query_expansion_dataset(dataset, chat, option='q2d-zs')

recall_q2d_zs, ndcg_q2d_zs = run_search(searcher, data_expanded_q2d_zs)
print(f'BM25 Recall@100: {recall_q2d_zs:.4f}')
print(f'BM25 nDCG@10: {ndcg_q2d_zs:.4f}')
print("\n")

Expanding queries:  68%|██████▊   | 542/799 [5:21:36<2:32:29, 35.60s/it]


KeyboardInterrupt: 

## Q2D zero shot prompting with pseudo-relevance feedback

In [None]:
# Build the query set with the 'q2d-zs-prf' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
# Results live in experiment-specific names so running this cell does not
# overwrite another experiment's dataset or scores.
# NOTE(review): the output saved with this cell reports 5 queries rather than
# the 799 seen elsewhere, so its numbers appear to come from a small smoke run
# -- re-run on the full set before comparing against the other experiments.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_q2d_zs_prf = get_query_expansion_dataset(dataset, chat, option='q2d-zs-prf')

recall_q2d_zs_prf, ndcg_q2d_zs_prf = run_search(searcher, data_expanded_q2d_zs_prf)
print(f'BM25 Recall@100: {recall_q2d_zs_prf:.4f}')
print(f'BM25 nDCG@10: {ndcg_q2d_zs_prf:.4f}')
print("\n")

Getting prf documents...


Searching: 100%|██████████| 5/5 [00:00<00:00, 72.71it/s]


Done.


Expanding queries: 100%|██████████| 5/5 [00:52<00:00, 10.54s/it]
Searching: 100%|██████████| 5/5 [00:08<00:00,  1.60s/it]

BM25 Recall@100: 0.9000
BM25 nDCG@10: 0.3924







## Keywords with zero-shot prompting

In [8]:
# Build the query set with the 'q2e-zs' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
# Results live in experiment-specific names so running this cell does not
# overwrite another experiment's dataset or scores.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_q2e_zs = get_query_expansion_dataset(dataset, chat, option='q2e-zs')

recall_q2e_zs, ndcg_q2e_zs = run_search(searcher, data_expanded_q2e_zs)
print(f'BM25 Recall@100: {recall_q2e_zs:.4f}')
print(f'BM25 nDCG@10: {ndcg_q2e_zs:.4f}')
print("\n")

Expanding queries:   0%|          | 0/799 [00:00<?, ?it/s]

Expanding queries: 100%|██████████| 799/799 [54:50<00:00,  4.12s/it]  
Searching: 100%|██████████| 799/799 [01:25<00:00,  9.38it/s]

BM25 Recall@100: 0.8433
BM25 nDCG@10: 0.3745







## Keywords with pseudo-relevance feedback

In [9]:
# Build the query set with the 'q2e-zs-prf' option of get_query_expansion_dataset,
# then score it against the prebuilt MIRACL English index.
# Results live in experiment-specific names so running this cell does not
# overwrite another experiment's dataset or scores.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_q2e_zs_prf = get_query_expansion_dataset(dataset, chat, option='q2e-zs-prf')

recall_q2e_zs_prf, ndcg_q2e_zs_prf = run_search(searcher, data_expanded_q2e_zs_prf)
print(f'BM25 Recall@100: {recall_q2e_zs_prf:.4f}')
print(f'BM25 nDCG@10: {ndcg_q2e_zs_prf:.4f}')
print("\n")

Getting prf documents...


Searching:  11%|█         | 89/799 [00:02<00:28, 25.03it/s]

Searching: 100%|██████████| 799/799 [00:26<00:00, 30.04it/s]


Done.


Expanding queries: 100%|██████████| 799/799 [1:02:36<00:00,  4.70s/it]
Searching: 100%|██████████| 799/799 [02:56<00:00,  4.53it/s]

BM25 Recall@100: 0.8337
BM25 nDCG@10: 0.3722







## Chain of thought prompting

In [11]:
# Build the query set with the 'chain-of-thought' option of
# get_query_expansion_dataset, then score it against the prebuilt MIRACL
# English index.
# Results live in experiment-specific names so running this cell does not
# overwrite another experiment's dataset or scores.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_cot = get_query_expansion_dataset(dataset, chat, option='chain-of-thought')

recall_cot, ndcg_cot = run_search(searcher, data_expanded_cot)
print(f'BM25 Recall@100: {recall_cot:.4f}')
print(f'BM25 nDCG@10: {ndcg_cot:.4f}')
print("\n")

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [30:35<00:00,  2.30s/it]

BM25 Recall@100: 0.8866
BM25 nDCG@10: 0.4877







## Chain of thought prompting with pseudo-relevance feedback

In [7]:
# Build the query set with the 'chain-of-thought-prf' option of
# get_query_expansion_dataset, then score it against the prebuilt MIRACL
# English index.
# Results live in experiment-specific names so running this cell does not
# overwrite another experiment's dataset or scores.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_cot_prf = get_query_expansion_dataset(dataset, chat, option='chain-of-thought-prf')

recall_cot_prf, ndcg_cot_prf = run_search(searcher, data_expanded_cot_prf)
print(f'BM25 Recall@100: {recall_cot_prf:.4f}')
print(f'BM25 nDCG@10: {ndcg_cot_prf:.4f}')
print("\n")

Getting prf documents...


Searching: 100%|██████████| 799/799 [00:30<00:00, 25.81it/s]


Done.


Expanding queries: 100%|██████████| 799/799 [1:41:54<00:00,  7.65s/it]  
Searching: 100%|██████████| 799/799 [09:45<00:00,  1.36it/s]

BM25 Recall@100: 0.8486
BM25 nDCG@10: 0.4343





