In [2]:
from pyserini.search.lucene import LuceneSearcher
from datasets import load_dataset
from dotenv import load_dotenv
import os
from langchain.chat_models import ChatOpenAI, ChatCohere, ChatAnyscale

from main import get_query_expansion_dataset, run_search

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Set up the chat LLMs used for query expansion. API keys are read from a .env file;
# override=True lets values in .env replace variables already present in the environment.
load_dotenv(override=True)
chats = {}

# Each provider is initialized independently, so one failure does not block the others.
# `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt/SystemExit through.
# The exception type is printed so that a missing package (e.g. ImportError) can be told
# apart from a bad or missing key. Only the type name is shown because some validation
# errors may echo constructor inputs, which could include the API key.

# OpenAI
try:
    openai_api_key = os.getenv("OPENAI_API_KEY")
    chat_openai = ChatOpenAI(openai_api_key=openai_api_key, model="gpt-3.5-turbo")
    chats["OpenAI"] = chat_openai
except Exception as exc:
    print(f"Could not initialize OpenAI chatbot ({type(exc).__name__}). Please store a valid API key in a .env file.")

# Cohere
try:
    cohere_api_key = os.getenv("COHERE_API_KEY")
    chat_cohere = ChatCohere(cohere_api_key=cohere_api_key)
    chats["Cohere"] = chat_cohere
except Exception as exc:
    print(f"Could not initialize Cohere chatbot ({type(exc).__name__}). Please store a valid API key in a .env file.")

# Llama 2 (Anyscale)
try:
    anyscale_api_key = os.getenv("ANYSCALE_API_KEY")
    chat_llama = ChatAnyscale(model_name="meta-llama/Llama-2-7b-chat-hf", anyscale_api_key=anyscale_api_key)
    chats["Llama 2"] = chat_llama
except Exception as exc:
    print(f"Could not initialize Anyscale chatbot ({type(exc).__name__}). Please store a valid API key in a .env file.")

Could not initialize Cohere chatbot. Please store a valid API key in a .env file.


## Experiments Set 1: All prompts, English language, all models

In [4]:
# Load the MIRACL dev queries (English version), converted to a list with one dict per row.
dataset = load_dataset("Cohere/miracl-en-queries-22-12", split="dev").to_pandas().to_dict(orient='records')

# BM25 searcher over the prebuilt English MIRACL index
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-en')

## 1) Baseline

In [5]:
# BM25 baseline: retrieve with the original, un-expanded English queries.
recall_baseline, ndcg_baseline = run_search(searcher, dataset)

# Report both metrics, followed by a blank separator.
for report_line in (
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
):
    print(report_line)
print("\n")

Searching: 100%|██████████| 799/799 [01:37<00:00,  8.23it/s]

BM25 Recall@100: 0.8190
BM25 nDCG@10: 0.3506







## 2) Answer prompting

In [5]:
# For every initialized chat model: get the query-expanded dataset using the 'answer'
# prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='en', prompt='answer'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [06:56<00:00,  1.92it/s]


OpenAI Recall@100: 0.8984
OpenAI nDCG@10: 0.5122




Searching: 100%|██████████| 799/799 [55:57<00:00,  4.20s/it]  

Llama 2 Recall@100: 0.8424
Llama 2 nDCG@10: 0.4219







## 3) Keywords

In [7]:
# For every initialized chat model: get the query-expanded dataset using the 'keywords'
# prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='en', prompt='keywords'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [01:27<00:00,  9.16it/s]


OpenAI Recall@100: 0.8433
OpenAI nDCG@10: 0.3745




Expanding queries: 100%|██████████| 799/799 [36:00<00:00,  2.70s/it]
Searching: 100%|██████████| 799/799 [19:37<00:00,  1.47s/it]

Llama 2 Recall@100: 0.8668
Llama 2 nDCG@10: 0.4314







## 5) Keywords with pseudo-relevance feedback

In [8]:
# For every initialized chat model: get the query-expanded dataset using the
# 'keywords-prf' prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='en', prompt='keywords-prf'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

Searching:   1%|          | 5/799 [00:01<03:27,  3.83it/s]

Searching: 100%|██████████| 799/799 [02:41<00:00,  4.94it/s]


OpenAI Recall@100: 0.8337
OpenAI nDCG@10: 0.3722


Getting prf documents...


Searching: 100%|██████████| 799/799 [00:26<00:00, 29.96it/s]


Done.


Expanding queries: 100%|██████████| 799/799 [45:29<00:00,  3.42s/it]
Searching: 100%|██████████| 799/799 [29:49<00:00,  2.24s/it]  

Llama 2 Recall@100: 0.8035
Llama 2 nDCG@10: 0.3635







## 6) Chain of thought prompting

In [9]:
# For every initialized chat model: get the query-expanded dataset using the
# 'chain-of-thought' prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='en', prompt='chain-of-thought'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [24:42<00:00,  1.86s/it]


OpenAI Recall@100: 0.8866
OpenAI nDCG@10: 0.4877




Searching: 100%|██████████| 799/799 [1:03:32<00:00,  4.77s/it]

Llama 2 Recall@100: 0.8334
Llama 2 nDCG@10: 0.4340







## 7) Chain of thought prompting with pseudo-relevance feedback

In [10]:
# For every initialized chat model: get the query-expanded dataset using the
# 'chain-of-thought-prf' prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='en', prompt='chain-of-thought-prf'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

Searching:   0%|          | 0/799 [00:00<?, ?it/s]

Searching: 100%|██████████| 799/799 [09:02<00:00,  1.47it/s]


OpenAI Recall@100: 0.8486
OpenAI nDCG@10: 0.4343




Searching: 100%|██████████| 799/799 [22:02<00:00,  1.66s/it]

Llama 2 Recall@100: 0.8285
Llama 2 nDCG@10: 0.4265







## 8) Chain of thought prompting with short answer

In [5]:
# For every initialized chat model: get the query-expanded dataset using the
# 'chain-of-thought-short' prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='en', prompt='chain-of-thought-short'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

# openai time: 2:04:00 h
# llama2 time: 0:27:47 h

Searching: 100%|██████████| 799/799 [08:47<00:00,  1.51it/s]


OpenAI Recall@100: 0.9027
OpenAI nDCG@10: 0.5154




Expanding queries: 100%|██████████| 799/799 [27:47<00:00,  2.09s/it]
Searching: 100%|██████████| 799/799 [10:15<00:00,  1.30it/s]

Llama 2 Recall@100: 0.8890
Llama 2 nDCG@10: 0.4775







# Experiments Set 2: Chain-of-thought prompt, various languages, all models except Cohere

## 1) French

In [6]:
# Load the MIRACL dev queries (French version), converted to a list with one dict per row.
dataset = load_dataset("Cohere/miracl-fr-queries-22-12", split="dev").to_pandas().to_dict(orient='records')

# BM25 searcher over the prebuilt French MIRACL index, with its language set to French
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-fr')
searcher.set_language('fr')

# exclude cohere model as it does not support French
chats_fr = {k: v for k, v in chats.items() if k != "Cohere"}

In [7]:
# BM25 baseline: retrieve with the original, un-expanded French queries.
recall_baseline, ndcg_baseline = run_search(searcher, dataset)

# Report both metrics, followed by a blank separator.
for report_line in (
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
):
    print(report_line)

Searching: 100%|██████████| 343/343 [00:27<00:00, 12.26it/s]

BM25 Recall@100: 0.6528
BM25 nDCG@10: 0.1832







In [8]:
# For every non-Cohere chat model: get the query-expanded dataset using the French
# chain-of-thought prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats_fr.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='fr', prompt='chain-of-thought-fr'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

# openai time: 18:35 min
# llama 2 time: 16:49 min

Expanding queries:   0%|          | 0/343 [00:00<?, ?it/s]

Expanding queries: 100%|██████████| 343/343 [18:35<00:00,  3.25s/it]
Searching: 100%|██████████| 343/343 [00:45<00:00,  7.57it/s]


OpenAI Recall@100: 0.7643
OpenAI nDCG@10: 0.3016




Expanding queries: 100%|██████████| 343/343 [16:49<00:00,  2.94s/it]
Searching: 100%|██████████| 343/343 [01:12<00:00,  4.74it/s]

Llama 2 Recall@100: 0.6089
Llama 2 nDCG@10: 0.2549







## 2) German

In [3]:
# Load the MIRACL dev queries (German version), converted to a list with one dict per row.
dataset = load_dataset("Cohere/miracl-de-queries-22-12", split="dev").to_pandas().to_dict(orient='records')

# BM25 searcher over the prebuilt German MIRACL index, with its language set to German
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-de')
searcher.set_language('de')

# exclude cohere model as it does not support German
chats_de = {k: v for k, v in chats.items() if k != "Cohere"}

In [4]:
# BM25 baseline: retrieve with the original, un-expanded German queries.
recall_baseline, ndcg_baseline = run_search(searcher, dataset)

# Report both metrics, followed by a blank separator.
for report_line in (
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
):
    print(report_line)

Searching:   0%|          | 0/305 [00:00<?, ?it/s]

Searching: 100%|██████████| 305/305 [00:27<00:00, 11.11it/s]

BM25 Recall@100: 0.5724
BM25 nDCG@10: 0.2262







In [5]:
# For every non-Cohere chat model: get the query-expanded dataset using the German
# chain-of-thought prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats_de.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='de', prompt='chain-of-thought-de'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

# openai time: 15:05 min
# llama 2 time: 11:01 min

Expanding queries:   0%|          | 0/305 [00:00<?, ?it/s]

Expanding queries: 100%|██████████| 305/305 [15:05<00:00,  2.97s/it]
Searching: 100%|██████████| 305/305 [00:39<00:00,  7.64it/s]


OpenAI Recall@100: 0.7225
OpenAI nDCG@10: 0.3442




Expanding queries: 100%|██████████| 305/305 [11:01<00:00,  2.17s/it]
Searching: 100%|██████████| 305/305 [00:53<00:00,  5.74it/s]

Llama 2 Recall@100: 0.6099
Llama 2 nDCG@10: 0.2864







## 3) Chinese

In [3]:
# Load the MIRACL dev queries (Chinese version), converted to a list with one dict per row.
dataset = load_dataset("Cohere/miracl-zh-queries-22-12", split="dev").to_pandas().to_dict(orient='records')

# BM25 searcher over the prebuilt Chinese MIRACL index, with its language set to Chinese
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-zh')
searcher.set_language('zh')

# exclude cohere model as it does not support Chinese
chats_zh = {k: v for k, v in chats.items() if k != "Cohere"}

In [4]:
# BM25 baseline: retrieve with the original, un-expanded Chinese queries.
recall_baseline, ndcg_baseline = run_search(searcher, dataset)

# Report both metrics, followed by a blank separator.
for report_line in (
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
):
    print(report_line)

Searching: 100%|██████████| 393/393 [00:18<00:00, 21.11it/s]

BM25 Recall@100: 0.5599
BM25 nDCG@10: 0.1801







In [5]:
# For every non-Cohere chat model: get the query-expanded dataset using the Chinese
# chain-of-thought prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats_zh.items():
    expanded_dataset = get_query_expansion_dataset(
        dataset, chat, chat_name, lang='zh', prompt='chain-of-thought-zh'
    )
    recall, ndcg = run_search(searcher, expanded_dataset)

    for report_line in (
        f'{chat_name} Recall@100: {recall:.4f}',
        f'{chat_name} nDCG@10: {ndcg:.4f}',
        "\n",
    ):
        print(report_line)

# openai time: 27:57 min
# llama 2 time: 26:40 min

Expanding queries: 100%|██████████| 393/393 [27:57<00:00,  4.27s/it]
Searching: 100%|██████████| 393/393 [00:46<00:00,  8.36it/s]


OpenAI Recall@100: 0.6929
OpenAI nDCG@10: 0.2872




Expanding queries: 100%|██████████| 393/393 [26:40<00:00,  4.07s/it]  
Searching: 100%|██████████| 393/393 [00:22<00:00, 17.38it/s]

Llama 2 Recall@100: 0.5007
Llama 2 nDCG@10: 0.1451







## 4) Spanish

In [12]:
# Load the MIRACL dev queries (Spanish version), converted to a list with one dict per row.
dataset = load_dataset("Cohere/miracl-es-queries-22-12", split="dev").to_pandas().to_dict(orient='records')

# BM25 searcher over the prebuilt Spanish MIRACL index, with its language set to Spanish
searcher = LuceneSearcher.from_prebuilt_index('miracl-v1.0-es')
searcher.set_language('es')

# exclude cohere model as it does not support Spanish
chats_es = {k: v for k, v in chats.items() if k != "Cohere"}

In [13]:
# BM25 baseline: retrieve with the original, un-expanded Spanish queries.
recall_baseline, ndcg_baseline = run_search(searcher, dataset)

# Report both metrics, followed by a blank separator.
for report_line in (
    f'BM25 Recall@100: {recall_baseline:.4f}',
    f'BM25 nDCG@10: {ndcg_baseline:.4f}',
    "\n",
):
    print(report_line)

Searching: 100%|██████████| 648/648 [00:20<00:00, 31.34it/s]

BM25 Recall@100: 0.7018
BM25 nDCG@10: 0.3193







In [14]:
# For every non-Cohere chat model: get the query-expanded dataset using the Spanish
# chain-of-thought prompt, run BM25 on it, and report the metrics for that model.
for chat_name, chat in chats_es.items():
    expanded_dataset = get_query_expansion_dataset(dataset, chat, chat_name, lang='es', prompt='chain-of-thought-es')

    recall, ndcg = run_search(searcher, expanded_dataset)
    print(f'{chat_name} Recall@100: {recall:.4f}') 
    print(f'{chat_name} nDCG@10: {ndcg:.4f}')
    print("\n")

# TODO: record the OpenAI / Llama 2 query-expansion times for Spanish. The values previously
# noted here (27:57 / 26:40 min) were copied from the Chinese run, and this cell's output
# shows no "Expanding queries" timing to confirm them.

Searching: 100%|██████████| 648/648 [00:42<00:00, 15.38it/s]


OpenAI Recall@100: 0.7770
OpenAI nDCG@10: 0.4034




Searching: 100%|██████████| 648/648 [01:10<00:00,  9.25it/s]

Llama 2 Recall@100: 0.7409
Llama 2 nDCG@10: 0.3646





