In [1]:
from pyserini.search.lucene import LuceneSearcher
from datasets import load_dataset
from tqdm import tqdm
from dotenv import load_dotenv

import os
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
import pickle
import copy
from utils.metrics import get_recall_at_100, get_nDCG_at_10

from main import get_query_expansion_dataset, run_search

  return torch._C._cuda_getDeviceCount() > 0


In [8]:
# Configure the chat LLM. The OpenAI API key is expected in a local .env file;
# with override=True, values from .env replace variables already present in the
# process environment.
load_dotenv(override=True)
openai_api_key = os.environ.get("OPENAI_API_KEY")
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=openai_api_key,
)

In [12]:
# Load the English MIRACL queries (dev split) and flatten them into a list of
# plain dicts, one record per row, via an intermediate pandas DataFrame.
dataset = (
    load_dataset("Cohere/miracl-en-queries-22-12", split="dev")
    .to_pandas()
    .to_dict(orient='records')
)

## Experiment 1: BM25 baseline

In [13]:
# Baseline: BM25 retrieval over the prebuilt MIRACL index using the original,
# unexpanded queries.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
recall_baseline, ndcg_baseline = run_search(searcher, dataset)
# One print call with sep="\n" emits the same text as three separate prints.
print(
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
    sep="\n",
)

Searching: 100%|██████████| 799/799 [01:24<00:00,  9.47it/s]

BM25 Recall@100: 0.8190
BM25 nDCG@10: 0.3506







## Experiment 2: Zero-shot prompting

In [14]:
# Experiment 2: expand every query with the 'zero-shot' strategy of
# main.get_query_expansion_dataset (prompt lives in main.py), then run the same
# BM25 retrieval as the baseline on the expanded queries.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_zero_shot = get_query_expansion_dataset(dataset, option='zero-shot')

recall_zs, ndcg_zs = run_search(searcher, data_expanded_zero_shot)
# Label the metrics with the expansion strategy so they cannot be confused with
# the plain-BM25 baseline or with the other expansion experiments.
print(f'BM25 + zero-shot expansion Recall@100: {recall_zs:.4f}')
print(f'BM25 + zero-shot expansion nDCG@10: {ndcg_zs:.4f}')
print("\n")

Searching: 100%|██████████| 799/799 [05:35<00:00,  2.38it/s]

BM25 Recall@100: 0.7777
BM25 nDCG@10: 0.3480







## Experiment 3: One-shot prompting

In [16]:
# Experiment 3: expand every query with the 'one-shot' strategy of
# main.get_query_expansion_dataset (prompt lives in main.py), then run the same
# BM25 retrieval as the baseline on the expanded queries.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_one_shot = get_query_expansion_dataset(dataset, option='one-shot')

recall_os, ndcg_os = run_search(searcher, data_expanded_one_shot)
# Label the metrics with the expansion strategy so they cannot be confused with
# the plain-BM25 baseline or with the other expansion experiments.
print(f'BM25 + one-shot expansion Recall@100: {recall_os:.4f}')
print(f'BM25 + one-shot expansion nDCG@10: {ndcg_os:.4f}')
print("\n")

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [02:15<00:00,  5.88it/s]

BM25 Recall@100: 0.7485
BM25 nDCG@10: 0.3198







## Experiment 4: Multi-shot prompting

In [17]:
# Experiment 4: expand every query with the 'multi-shot' strategy of
# main.get_query_expansion_dataset (prompt lives in main.py), then run the same
# BM25 retrieval as the baseline on the expanded queries.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_multi_shot = get_query_expansion_dataset(dataset, option='multi-shot')

recall_ms, ndcg_ms = run_search(searcher, data_expanded_multi_shot)
# Label the metrics with the expansion strategy so they cannot be confused with
# the plain-BM25 baseline or with the other expansion experiments.
print(f'BM25 + multi-shot expansion Recall@100: {recall_ms:.4f}')
print(f'BM25 + multi-shot expansion nDCG@10: {ndcg_ms:.4f}')
print("\n")

Searching: 100%|██████████| 799/799 [02:03<00:00,  6.45it/s]

BM25 Recall@100: 0.7825
BM25 nDCG@10: 0.3486







## Experiment 5: Answer prompting

In [18]:
# Experiment 5: expand every query with the 'answer' strategy of
# main.get_query_expansion_dataset (presumably the LLM drafts an answer that is
# used as the expansion — confirm against main.py), then run the same BM25
# retrieval as the baseline on the expanded queries.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')
data_expanded_answer = get_query_expansion_dataset(dataset, option='answer')

recall_ans, ndcg_ans = run_search(searcher, data_expanded_answer)
# Label the metrics with the expansion strategy so they cannot be confused with
# the plain-BM25 baseline or with the other expansion experiments.
print(f'BM25 + answer expansion Recall@100: {recall_ans:.4f}')
print(f'BM25 + answer expansion nDCG@10: {ndcg_ans:.4f}')
print("\n")

Searching: 100%|██████████| 799/799 [07:07<00:00,  1.87it/s]

BM25 Recall@100: 0.7361
BM25 nDCG@10: 0.3960





