In [1]:
from pyserini.search.lucene import LuceneSearcher
from datasets import load_dataset
from dotenv import load_dotenv
import os
from langchain.chat_models import ChatOpenAI, ChatCohere, ChatAnyscale

from main import get_query_expansion_dataset, run_search

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up the chat LLMs used for query expansion.
# API keys are read from a .env file (OPENAI_API_KEY, COHERE_API_KEY, ANYSCALE_API_KEY);
# override=True lets values in .env take precedence over variables already in the environment.
# A model whose client raises during construction is skipped, so `chats` only contains
# clients that were constructed successfully (this does not prove the key is valid).
load_dotenv(override=True)
chats = {}

# OpenAI
try:
    openai_api_key = os.getenv("OPENAI_API_KEY")
    chat_openai = ChatOpenAI(openai_api_key=openai_api_key, model="gpt-3.5-turbo")
    chats["OpenAI"] = chat_openai
except Exception as exc:  # not a bare except: keep Ctrl-C / SystemExit working, and show the cause
    print(f"Could not initialize OpenAI chatbot ({exc!r}). Please store a valid API key in a .env file.")

# Cohere
try:
    cohere_api_key = os.getenv("COHERE_API_KEY")
    chat_cohere = ChatCohere(cohere_api_key=cohere_api_key)
    chats["Cohere"] = chat_cohere
except Exception as exc:
    print(f"Could not initialize Cohere chatbot ({exc!r}). Please store a valid API key in a .env file.")

# Llama 2 (Anyscale)
try:
    anyscale_api_key = os.getenv("ANYSCALE_API_KEY")
    chat_llama = ChatAnyscale(model_name="meta-llama/Llama-2-7b-chat-hf", anyscale_api_key=anyscale_api_key)
    chats["Llama 2"] = chat_llama
except Exception as exc:
    print(f"Could not initialize Anyscale chatbot ({exc!r}). Please store a valid API key in a .env file.")

# Experiments Set 1: All prompts, English language, all models

In [3]:
# MIRACL English dev-split queries, converted (via pandas) into a list of
# per-row dicts so downstream helpers can iterate over plain records.
dataset = (
    load_dataset("Cohere/miracl-en-queries-22-12", split="dev")
    .to_pandas()
    .to_dict(orient='records')
)

# Searcher over the prebuilt MIRACL v1.0 English Lucene index.
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')

## 1) Baseline (BM25 with the original, unexpanded queries)

In [4]:
# Baseline: search with the original, unexpanded queries.
baseline_metrics = run_search(searcher, dataset)
recall_baseline, ndcg_baseline = baseline_metrics

print(f'BM25 Recall@100: {recall_baseline:.4f}')
print(f'BM25 nDCG@10: {ndcg_baseline:.4f}')
print("\n")  # spacer before the next section's output

## 2) Query2Doc zero-shot prompting

In [6]:
# For each configured LLM: expand every English query with the Query2Doc
# zero-shot prompt ('q2d-zs'), search with the expanded queries and report
# the two metrics returned by run_search.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name,
        lang='en',
        prompt='q2d-zs',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Expanding queries:   0%|          | 0/5 [00:00<?, ?it/s]

Expanding queries: 100%|██████████| 5/5 [01:10<00:00, 14.11s/it]
Searching: 100%|██████████| 5/5 [00:14<00:00,  2.83s/it]


OpenAI Recall@100: 0.8333
OpenAI nDCG@10: 0.2921




Expanding queries: 100%|██████████| 5/5 [01:06<00:00, 13.30s/it]
Searching: 100%|██████████| 5/5 [00:17<00:00,  3.51s/it]


Cohere Recall@100: 0.7667
Cohere nDCG@10: 0.3049




Expanding queries: 100%|██████████| 5/5 [00:27<00:00,  5.52s/it]
Searching: 100%|██████████| 5/5 [00:08<00:00,  1.77s/it]

Llama 2 Recall@100: 0.7667
Llama 2 nDCG@10: 0.3563







## 3) Query2Doc zero-shot prompting with pseudo-relevance feedback

In [None]:
# Same evaluation loop, but with the Query2Doc zero-shot prompt that also
# receives pseudo-relevance-feedback documents ('q2d-zs-prf').
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name,
        lang='en',
        prompt='q2d-zs-prf',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Getting prf documents...


Searching: 100%|██████████| 5/5 [00:00<00:00, 72.71it/s]


Done.


Expanding queries: 100%|██████████| 5/5 [00:52<00:00, 10.54s/it]
Searching: 100%|██████████| 5/5 [00:08<00:00,  1.60s/it]

BM25 Recall@100: 0.9000
BM25 nDCG@10: 0.3924







## 4) Keyword expansion with zero-shot prompting

In [8]:
# Keyword-style expansion with a zero-shot prompt ('q2e-zs'), evaluated
# per LLM with the same search/metric pipeline.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name,
        lang='en',
        prompt='q2e-zs',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Expanding queries:   0%|          | 0/799 [00:00<?, ?it/s]

Expanding queries: 100%|██████████| 799/799 [54:50<00:00,  4.12s/it]  
Searching: 100%|██████████| 799/799 [01:25<00:00,  9.38it/s]

BM25 Recall@100: 0.8433
BM25 nDCG@10: 0.3745







## 5) Keyword expansion with zero-shot prompting and pseudo-relevance feedback

In [9]:
# Keyword-style expansion with a zero-shot prompt plus pseudo-relevance
# feedback documents ('q2e-zs-prf'), evaluated per LLM.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name,
        lang='en',
        prompt='q2e-zs-prf',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Getting prf documents...


Searching:  11%|█         | 89/799 [00:02<00:28, 25.03it/s]

Searching: 100%|██████████| 799/799 [00:26<00:00, 30.04it/s]


Done.


Expanding queries: 100%|██████████| 799/799 [1:02:36<00:00,  4.70s/it]
Searching: 100%|██████████| 799/799 [02:56<00:00,  4.53it/s]

BM25 Recall@100: 0.8337
BM25 nDCG@10: 0.3722







## 6) Chain-of-thought prompting

In [11]:
# Chain-of-thought expansion prompt ('chain-of-thought'), evaluated per LLM.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name,
        lang='en',
        prompt='chain-of-thought',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [30:35<00:00,  2.30s/it]

BM25 Recall@100: 0.8866
BM25 nDCG@10: 0.4877







## 7) Chain-of-thought prompting with pseudo-relevance feedback

In [7]:
# Chain-of-thought expansion prompt with pseudo-relevance feedback
# documents ('chain-of-thought-prf'), evaluated per LLM.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name,
        lang='en',
        prompt='chain-of-thought-prf',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Getting prf documents...


Searching: 100%|██████████| 799/799 [00:30<00:00, 25.81it/s]


Done.


Expanding queries: 100%|██████████| 799/799 [1:41:54<00:00,  7.65s/it]  
Searching: 100%|██████████| 799/799 [09:45<00:00,  1.36it/s]

BM25 Recall@100: 0.8486
BM25 nDCG@10: 0.4343







# Experiments Set 2: Chain-of-thought prompt, various languages, all models

## 1) French

In [3]:
# Load the MIRACL dataset (French version): dev-split queries converted (via
# pandas) into a list of per-row dicts. This replaces the English `dataset`.
dataset = (
    load_dataset("Cohere/miracl-fr-queries-22-12", split="dev")
    .to_pandas()
    .to_dict(orient='records')
)

# Set up the searcher over the prebuilt MIRACL v1.0 French Lucene index
# (replaces the English `searcher`).
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-fr')

In [5]:
# Chain-of-thought expansion on French queries ('chain-of-thought-fr'), per LLM.
# Only the first N_QUERIES_FR queries are evaluated (presumably to keep LLM
# runtime manageable). Metrics over this subset are NOT directly comparable to
# the full-dev English results above.
N_QUERIES_FR = 10  # number of French queries to evaluate; set to None to use all of them
fr_queries = dataset[:N_QUERIES_FR]  # slicing with None keeps the whole list

for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        fr_queries, chat, chat_name,
        lang='fr',
        prompt='chain-of-thought-fr',
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    print(f'{chat_name} Recall@100: {recall:.4f}')
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")  # spacer between models

Expanding queries: 100%|██████████| 10/10 [05:21<00:00, 32.13s/it]
Searching: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]


OpenAI Recall@100: 0.4667
OpenAI nDCG@10: 0.1103




Expanding queries: 100%|██████████| 10/10 [02:48<00:00, 16.85s/it]
Searching: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Cohere Recall@100: 0.2967
Cohere nDCG@10: 0.0219




Expanding queries: 100%|██████████| 10/10 [01:25<00:00,  8.51s/it]
Searching: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]

Llama 2 Recall@100: 0.2583
Llama 2 nDCG@10: 0.0429





