In [None]:
import os
from langchain_community.retrievers import WikipediaRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
import torch
import random
import numpy as np
import pandas as pd
import ast
from datasets import load_dataset
import random
import sys
import time

os.environ["OPENAI_API_KEY"] = "your api key"




In [2]:
# This part is for normal RAG pipeline

'''
from langchain.embeddings import OpenAIEmbeddings
from scipy.spatial.distance import cosine
import random

class CustomWikipediaRetriever:
    def __init__(self, k=3, top_m=15, embedding_model=None, threshold=0.7):
        self.retriever = WikipediaRetriever(top_k_results=top_m, doc_content_chars_max=2500)
        self.embedding_model = embedding_model or OpenAIEmbeddings()
        self.k = k # Number of documents to sample 
        self.docs = None
        self.doc_embeddings = None
        self.threshold = threshold

    def retrieve(self, query):
        docs = self.retriever.get_relevant_documents(query)
        doc_embeddings = [self.embedding_model.embed_documents([doc.page_content])[0] for doc in docs]
        doc_embeddings = [emb / np.linalg.norm(emb) for emb in doc_embeddings]
        
        selected_idx = [random.choice(range(len(docs)))]
        selected_embeddings = [doc_embeddings[selected_idx[0]]]

        remaining_indices = list(set(range(len(docs))) - {selected_idx[0]})
        while len(selected_idx) < self.k and remaining_indices:
            candidate_idx = random.choice(remaining_indices)
            candidate_embedding = doc_embeddings[candidate_idx]

            similarity_ok = all((1+cosine(candidate_embedding, emb))/2.0 < self.threshold for emb in selected_embeddings)

            if similarity_ok:
                selected_idx.append(candidate_idx)
                selected_embeddings.append(candidate_embedding)

            remaining_indices.remove(candidate_idx)

        selected_docs = [docs[idx] for idx in selected_idx]
        return selected_docs


class QAChain:
    def __init__(self, k: int = 3, top_m: int = 15, threshold: float = 0.7, embedding_model=None):
        self.retriever = CustomWikipediaRetriever(top_m=top_m, k=k, threshold=threshold, embedding_model=embedding_model)
        self.prompt = ChatPromptTemplate.from_template(
            """Answer the question based only on the context provided as short as possible.

            Context: {context}

            Question: {question}"""
        )
        self.llm = ChatOpenAI(model="gpt-3.5-turbo")
        
        self.chain = (
            {"context": self.retrieve_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        ) 

    def retrieve_docs(self, query):
        
        # for normal RAG pipeline
        docs = self.retriever.retrieve(query)
        return "\n\n".join(doc.page_content for doc in docs)
    
    
    def answer(self, question: str):
        return self.chain.invoke(question)

qa_chain = QAChain(k=3)
query = "What is the capital of France?"
answer = qa_chain.answer(query)
print(answer)
'''


'\nfrom langchain.embeddings import OpenAIEmbeddings\nfrom scipy.spatial.distance import cosine\nimport random\n\nclass CustomWikipediaRetriever:\n    def __init__(self, k=3, top_m=15, embedding_model=None, threshold=0.7):\n        self.retriever = WikipediaRetriever(top_k_results=top_m, doc_content_chars_max=2500)\n        self.embedding_model = embedding_model or OpenAIEmbeddings()\n        self.k = k # Number of documents to sample \n        self.docs = None\n        self.doc_embeddings = None\n        self.threshold = threshold\n\n    def retrieve(self, query):\n        docs = self.retriever.get_relevant_documents(query)\n        doc_embeddings = [self.embedding_model.embed_documents([doc.page_content])[0] for doc in docs]\n        doc_embeddings = [emb / np.linalg.norm(emb) for emb in doc_embeddings]\n        \n        selected_idx = [random.choice(range(len(docs)))]\n        selected_embeddings = [doc_embeddings[selected_idx[0]]]\n\n        remaining_indices = list(set(range(len(

In [3]:
# For normal RAG pipeline
'''

qa_chain.answer("who is the founder of quantum physics")
'''

'\n\nqa_chain.answer("who is the founder of quantum physics")\n'

Using similarity threshold as parameter:
similarity_ok = all((1+cosine(candidate_embedding, emb))/2.0 < self.similarity_threshold for emb in selected_embeddings) 

Using distance threshold as parameter:
distance_ok = all((1-cosine(candidate_embedding, emb))/2.0 >= self.distance_threshold for emb in selected_embeddings) 

The above two definitions are equivalent:
distance_threshold =  1 - similarity_threshold

In [None]:
# This part is for evaluation only

from langchain.embeddings import OpenAIEmbeddings
from scipy.spatial.distance import cosine
import random

class CustomWikipediaRetriever:
    def __init__(self, k=3, top_m=15, embedding_model=None):
        self.retriever = WikipediaRetriever(top_k_results=top_m, doc_content_chars_max=2500)
        self.embedding_model = embedding_model or OpenAIEmbeddings()
        self.top_m = top_m
        self.k = k # Number of documents to sample 
        self.docs = None
        self.doc_embeddings = None
    
    def retrieve_with_embeddings(self, query):
        docs = self.retriever.get_relevant_documents(query)
        
        doc_embeddings = [self.embedding_model.embed_documents([doc.page_content])[0] for doc in docs]
        doc_embeddings = [emb/np.linalg.norm(emb) for emb in doc_embeddings]
        
        self.docs = docs
        self.doc_embeddings = doc_embeddings
        if not self.docs or len(self.docs)<self.top_m:
            return False ## To deal with no sufficient articles in corpus
        return True

    def retrieve(self, threshold):
        if not self.docs:
            return []
        selected_idx = [random.choice(range(len(self.docs)))]
        selected_embeddings = [self.doc_embeddings[selected_idx[0]]]
        
        remaining_indices = list(set(range(len(self.docs))) - {selected_idx[0]})
        while len(selected_idx) < self.k and remaining_indices:
            candidate_idx = random.choice(remaining_indices)
            candidate_embedding = self.doc_embeddings[candidate_idx]

            similarity_ok = all((1-cosine(candidate_embedding, emb)/2.0) < threshold for emb in selected_embeddings)
            
            if similarity_ok:
                selected_idx.append(candidate_idx)
                selected_embeddings.append(candidate_embedding)

            remaining_indices.remove(candidate_idx)

        selected_docs = [self.docs[idx] for idx in selected_idx]
        return selected_docs


class QAChain:
    def __init__(self, docs=None):
        self.prompt = ChatPromptTemplate.from_template(
            """Answer the question based only on the context provided as short as possible.

            Context: {context}

            Question: {question}"""
        )
        self.llm = ChatOpenAI(model="gpt-3.5-turbo")
        self.docs = docs # for evaluation only
        
        self.chain = (
            {"context": self.retrieve_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        ) 


    def retrieve_docs(self, query):
        if not self.docs:
            return ''
        return "\n\n".join(doc.page_content for doc in self.docs)

    
    def answer(self, question: str):
        return self.chain.invoke(question)



In [5]:
th_test = [0, 0.3, 0.5, 0.6, 0.7, 0.75, 0.8, 0.83, 0.85, 0.88, 0.9, 0.92, 0.95, 0.98, 1]
num_sample = 1000
num_times = 3
rng = 42

In [6]:
from get_dataset import get_nq, get_tqa, get_squad, get_asqa

nq = get_nq()
tqa = get_tqa()
squad = get_squad()
asqa = get_asqa()

datasets = ["NQ", "TriviaQA", "SQuAD", "ASQA"]

name_to_ds = {"NQ": nq, "TriviaQA": tqa, "SQuAD": squad, "ASQA": asqa}


In [7]:
indices = {}
random.seed(rng)
    
for name in datasets:
    ds = name_to_ds[name]
    print(name)
    selected_idx = random.sample(range(ds.shape[0]), num_sample)
    indices[name] = selected_idx
    print("selected indices are: ", selected_idx)
    print("\n")
        

NQ
selected indices are:  [1309, 228, 51, 1518, 563, 501, 457, 285, 1508, 209, 1385, 1516, 1116, 178, 1209, 864, 65, 61, 191, 447, 476, 1034, 1232, 54, 1149, 407, 1466, 1330, 1436, 1787, 859, 451, 919, 1206, 569, 1657, 13, 1554, 1650, 326, 1429, 865, 696, 1765, 318, 440, 1563, 689, 1790, 189, 778, 198, 735, 1735, 704, 1236, 541, 1652, 88, 1494, 940, 1098, 255, 775, 161, 1130, 600, 1698, 1287, 1266, 740, 1182, 393, 1442, 142, 93, 1354, 466, 1583, 592, 163, 1779, 206, 1749, 1756, 928, 1301, 1708, 747, 333, 758, 727, 429, 1372, 546, 1437, 1399, 1327, 146, 1247, 1300, 350, 1093, 1493, 1794, 334, 946, 777, 552, 1310, 1409, 1140, 449, 1402, 664, 1573, 1589, 114, 469, 1783, 1648, 646, 821, 548, 135, 432, 1161, 1470, 644, 435, 1342, 1022, 810, 1316, 939, 292, 542, 1792, 505, 1525, 1775, 1103, 538, 1529, 1197, 877, 1195, 817, 741, 1687, 283, 1043, 1010, 186, 1547, 96, 224, 313, 1285, 327, 1622, 1393, 1784, 1221, 130, 788, 781, 1220, 958, 1083, 514, 1133, 23, 1638, 1476, 234, 1396, 1099, 1537, 1

In [None]:

def evaluate(name, selected_idx):
    
    ds = name_to_ds[name]
    os.makedirs(f"threshold_results/{name}", exist_ok=True)

    print(name)
    print("evaluation progress: ", end="")
    
    n = len(th_test)

    candidates = [[] for _ in range(n)]
    references = []

    with open(f'threshold_results/{name}/references.jsonl', 'a') as ref_file:
        
        # Open separate candidate files for each j
        cand_files = [open(f'threshold_results/{name}/cand_{j}.jsonl', 'a') for j in range(n)]
        
        for idx in selected_idx:
            print(idx, end=", ")
            q = ds.loc[idx, "question"]
            a = ds.loc[idx, "answer"]
            try:
                retriever = CustomWikipediaRetriever()
                if not retriever.retrieve_with_embeddings(q):
                    continue
                for j in range(n):
                    th = th_test[j]
                    c = []
                    for _ in range(num_times):
                        docs = retriever.retrieve(th)
                        qa_chain = QAChain(docs=docs)
                        answer = qa_chain.answer(q)
                        c.append(answer)
                        time.sleep(0.03)
                        
                    candidates[j].append(c) 
                    cand_files[j].write(f"{c}\n")
                    cand_files[j].flush()
                    
                references.append(a)
                ref_file.write(f"{a}\n")
                ref_file.flush()
            except: ## To deal with fetching API too frequently
                time.sleep(5)
                try:
                    retriever = CustomWikipediaRetriever()
                    if not retriever.retrieve_with_embeddings(q): 
                        continue
                    for j in range(n):
                        th = th_test[j]
                        c = []
                        for _ in range(num_times):
                            docs = retriever.retrieve(th)
                            qa_chain = QAChain(docs=docs)
                            answer = qa_chain.answer(q)
                            c.append(answer)
                            time.sleep(0.03)

                        candidates[j].append(c) 
                        cand_files[j].write(f"{c}\n")
                        cand_files[j].flush()
                        
                    references.append(a)
                    ref_file.write(f"{a}\n")
                    ref_file.flush()
                except:
                    continue
        
        for file in cand_files:
            file.close()
        ref_file.close()      

    

In [None]:
for ds in datasets:
    evaluate(ds, indices[ds])
    print("\n")
    

NQ
evaluation progress: 1309, 

  self.embedding_model = embedding_model or OpenAIEmbeddings()
  docs = self.retriever.get_relevant_documents(query)


228, 51, 1518, 563, 501, 457, 285, 1508, 209, 1385, 1516, 1116, 178, 1209, 864, 65, 61, 191, 447, 476, 1034, 1232, 54, 1149, 407, 1466, 1330, 1436, 1787, 859, 451, 919, 1206, 569, 1657, 13, 1554, 1650, 326, 1429, 865, 696, 1765, 318, 440, 1563, 689, 1790, 189, 778, 198, 735, 1735, 704, 1236, 541, 1652, 88, 1494, 940, 



  lis = BeautifulSoup(html).find_all('li')


1098, 255, 775, 161, 1130, 600, 1698, 1287, 1266, 740, 1182, 393, 1442, 142, 93, 1354, 466, 1583, 592, 163, 1779, 206, 1749, 1756, 928, 1301, 1708, 747, 333, 758, 727, 429, 1372, 546, 1437, 1399, 1327, 146, 1247, 1300, 350, 1093, 1493, 1794, 334, 946, 777, 



  lis = BeautifulSoup(html).find_all('li')


552, 1310, 1409, 1140, 449, 1402, 664, 1573, 1589, 114, 469, 1783, 1648, 646, 821, 548, 135, 432, 1161, 1470, 644, 435, 1342, 1022, 810, 1316, 939, 292, 542, 1792, 505, 1525, 1775, 1103, 538, 1529, 1197, 877, 1195, 817, 741, 1687, 283, 1043, 1010, 186, 1547, 96, 224, 313, 1285, 327, 1622, 1393, 1784, 1221, 130, 788, 781, 1220, 958, 1083, 514, 1133, 23, 1638, 1476, 234, 1396, 1099, 1537, 1705, 1574, 1312, 1757, 1798, 601, 890, 323, 929, 6, 1478, 1473, 539, 1025, 1560, 365, 1039, 217, 1280, 611, 1308, 

