# Testing Retrieval QA Chains

In [312]:
from abc import ABC, abstractmethod
from typing import Any, List, Dict, Tuple, Type

from pydantic import Field, BaseModel

from langchain.schema import Document, BaseRetriever
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.question_answering import load_qa_chain
from langchain.evaluation.qa import QAEvalChain
from langchain.llms import OpenAI


class TestRetriever(BaseRetriever, BaseModel, ABC):
    """Retriever that can also ingest new documents."""
    identifying_params: dict

    def add_documents(self, docs: List[Document], can_edit: bool = True) -> None:
        """Ingest ``docs`` into the retriever.

        When ``can_edit`` is true the documents first go through
        ``_transform_documents``; otherwise they are inserted exactly as given.
        """
        prepared = self._transform_documents(docs) if can_edit else docs
        self._insert_documents(prepared)

    def _transform_documents(self, docs: List[Document]) -> List[Document]:
        """Pre-insertion hook; this base version hands ``docs`` back untouched."""
        return docs

    @abstractmethod
    def _insert_documents(self, docs: List[Document]) -> None:
        """Store ``docs`` in the underlying system. Subclasses must implement this."""

    def cleanup(self) -> None:
        """Release whatever the retriever holds; a no-op in this base class."""
        return None

    @property
    def name(self):
        """String form of ``identifying_params``."""
        label = str(self.identifying_params)
        return label

    
class RetrieverTestCase(BaseModel, ABC):
    """A named retrieval scenario: documents to ingest, a query, and a pass/fail check."""
    # Human-readable label for this case.
    name: str
    # Query sent to the retriever after the documents are ingested.
    query: str
    # Documents ingested into the retriever before querying.
    docs: List[Document]
    # Forwarded to ``add_documents`` as ``can_edit``.
    can_edit_docs: bool = True

    @classmethod
    def from_config(cls, **kwargs: Any) -> "RetrieverTestCase":
        """Build a test case from keyword configuration.

        This base version passes ``kwargs`` straight to the model constructor.
        """
        return cls(**kwargs)

    @abstractmethod
    def check_retrieved_docs(self, retrieved_docs: List[Document]) -> bool:
        """Return True when ``retrieved_docs`` satisfy this test case."""

    def run(self, retriever: TestRetriever) -> Tuple[bool, dict]:
        """Ingest the docs, query the retriever, and grade what comes back.

        Returns:
            ``(passed, extra)`` where ``extra["retrieved_docs"]`` holds the
            documents the retriever returned.
        """
        try:
            retriever.add_documents(self.docs, can_edit=self.can_edit_docs)
            retrieved_docs = retriever.get_relevant_documents(self.query)
            passed = self.check_retrieved_docs(retrieved_docs)
        finally:
            # Always call cleanup, even when ingestion, retrieval or the
            # check raises.
            retriever.cleanup()
        return passed, {"retrieved_docs": retrieved_docs}
    
    
class QAEvalChainTestCase(RetrieverTestCase):
    """Test case graded by an LLM.

    Answers the query from the retrieved docs with ``qa_chain`` and asks
    ``qa_eval_chain`` to compare that answer with ``gold_standard_answer``.
    """
    # Reference answer the generated answer is graded against.
    gold_standard_answer: str
    qa_chain: BaseCombineDocumentsChain = Field(default_factory=lambda: load_qa_chain(OpenAI()))
    qa_eval_chain: QAEvalChain = Field(default_factory=lambda: QAEvalChain.from_llm(OpenAI(temperature=0)))

    def check_retrieved_docs(self, retrieved_docs: List[Document]) -> bool:
        """Return True when the grader marks the generated answer as correct."""
        # The query lives on the test case; a bare ``query`` is undefined here.
        qa_response = self.qa_chain({"input_documents": retrieved_docs, "question": self.query})
        # NOTE(review): assumes the chain from load_qa_chain returns its answer
        # under "output_text" and that QAEvalChain.evaluate uses the default
        # "query"/"answer"/"result" keys -- confirm against the installed
        # langchain version.
        graded = self.qa_eval_chain.evaluate(
            [{"query": self.query, "answer": self.gold_standard_answer}],
            [{"result": qa_response["output_text"]}],
        )[0]
        # NOTE(review): the grade is presumably under "text" (older releases)
        # or "results" (newer ones) -- confirm.
        verdict = str(graded.get("text", graded.get("results", ""))).upper()
        # "INCORRECT" contains "CORRECT", so rule it out explicitly.
        return "CORRECT" in verdict and "INCORRECT" not in verdict
     
        
class ExpectedSubstringsTestCase(RetrieverTestCase):
    """Test case that passes when every expected substring occurs in the retrieved text."""
    expected_substrings: List[str]

    def check_retrieved_docs(self, retrieved_docs: List[Document]) -> bool:
        """Return True only if each expected substring is found.

        The page contents of ``retrieved_docs`` are joined with newlines and
        searched as a single string.
        """
        combined = "\n".join(doc.page_content for doc in retrieved_docs)
        return all(needle in combined for needle in self.expected_substrings)
    
    
class ExpectedDocsTestCase(RetrieverTestCase):
    """Test case that passes when every expected text occurs in the retrieved docs."""
    # Texts that must all appear in the retrieved page contents.
    expected_: List[str]

    def check_retrieved_docs(self, retrieved_docs: List[Document]) -> bool:
        """Return True only if each entry of ``expected_`` is found.

        The page contents of ``retrieved_docs`` are joined with newlines and
        searched as a single string.
        """
        all_text = "\n".join(d.page_content for d in retrieved_docs)
        # Read the field this class declares; ``expected_substrings`` does not
        # exist here and raised AttributeError.
        return all(expected in all_text for expected in self.expected_)


In [313]:
import random
import pandas as pd


class ManyDocsTestCase(ExpectedSubstringsTestCase):
    """Asks about several dates at once against a corpus of one-line daily temperature docs."""

    @classmethod
    def from_config(cls, retrieve: int = 5, total: int = 100, seed: int = 0, **kwargs: Any) -> "ManyDocsTestCase":
        """Generate ``total`` daily documents and a query naming ``retrieve`` sampled dates.

        ``seed`` seeds the global ``random`` module. The sampled dates are the
        expected substrings. Extra ``kwargs`` are accepted but not used.
        """
        random.seed(seed)
        dates = pd.date_range(start="01-01-2023", freq='D', periods=total).astype(str)

        # One random temperature per date, drawn in date order.
        docs = []
        for day in dates:
            reading = str(random.randint(50, 80))
            sentence = "On {date} the peak temperature was {temp} degrees".format(date=day, temp=reading)
            docs.append(Document(page_content=sentence))

        # random.choices samples with replacement, so a date can repeat.
        picked = random.choices(range(len(dates)), k=retrieve)
        expected_dates = [dates[pos] for pos in picked]
        return cls(
            name=f"Many docs ({retrieve=}, {total=})",
            query=f"What were the peak temperatures on {', '.join(expected_dates)}?",
            docs=docs,
            expected_substrings=expected_dates,
        )

In [314]:
class RedundantDocsTestCase(ExpectedSubstringsTestCase):
    """Eleven headlines about one company and one about another; both companies must be retrieved."""

    @classmethod
    def from_config(cls, **kwargs: Any) -> "RedundantDocsTestCase":
        """Build the fixed headline corpus and its query. ``kwargs`` are accepted but not used."""
        headlines = (
            "OpenAI announces the release of GPT-5",
            "GPT-5 released by OpenAI",
            "The next-generation OpenAI GPT model is here",
            "GPT-5: OpenAI's next model is the biggest yet",
            "Sam Altman's OpenAI comes out with new GPT-5 model",
            "GPT-5 is here. What you need to know about the OpenAI model",
            "OpenAI announces ChatGPT successor GPT-5",
            "5 jaw-dropping things OpenAI's GPT-5 can do that ChatGPT couldn't",
            "OpenAI's GPT-5 Is Exciting and Scary",
            "OpenAI announces GPT-5, the new generation of AI",
            "OpenAI says new model GPT-5 is more creative and less",
            "Meta open sources new AI model, largest yet",
        )
        corpus = []
        for headline in headlines:
            corpus.append(Document(page_content=headline))
        return cls(
            name="Redundant docs",
            docs=corpus,
            query="What companies have recently released new models?",
            expected_substrings=["OpenAI", "Meta"],
        )
    

In [315]:
from typing import Optional
    
from langchain.document_loaders import TextLoader


class EntityLinkingTestCase(ExpectedSubstringsTestCase):
    """Two facts that must be linked through a shared company name, placed ahead of filler text."""

    @classmethod
    def from_config(cls, filler_texts: Optional[List[str]] = None, **kwargs: Any) -> "EntityLinkingTestCase":
        """Build the case from two fact sentences followed by ``filler_texts``.

        When ``filler_texts`` is None, filler is loaded and split from the
        State of the Union text file. The documents are created with
        ``can_edit_docs=False``. Extra ``kwargs`` are accepted but not used.
        """
        if filler_texts is None:
            loader = TextLoader('../docs/modules/state_of_the_union.txt')
            filler_texts = [chunk.page_content for chunk in loader.load_and_split()]

        facts = [
            "The founder of ReallyCoolAICompany LLC is from Louisville, Kentucky.",
            "Melissa Harkins, founder of ReallyCoolAICompany LLC, said in a recent interview that she will be stepping down as CEO.",
        ]
        corpus = [Document(page_content=passage) for passage in facts + filler_texts]
        return cls(
            name=f"Entity linking (num_filler={len(filler_texts)})",
            docs=corpus,
            query="Where is Melissa Harkins from?",
            expected_substrings=["Harkins", "Louisville"],
            can_edit_docs=False,
        )


In [316]:
class TemporalQueryTestCase(RetrieverTestCase):
    """Asks when an emotion was first mentioned across dated one-line diary entries."""
    # Date string of the earliest document mentioning the queried option.
    correct_date: str

    def check_retrieved_docs(self, retrieved_docs: List[Document]) -> bool:
        """Return True if any retrieved doc carries ``correct_date`` in its metadata."""
        # Use .get so a doc without a "date" entry (for example a seed document
        # that was never given metadata) counts as a miss rather than raising
        # KeyError.
        return any(d.metadata.get("date") == self.correct_date for d in retrieved_docs)

    @classmethod
    def from_config(cls, options: Optional[List[str]]=None, phrasings: Optional[List[str]]=None, num_docs: int = 200, seed: int = 0, **kwargs: Any) -> "TemporalQueryTestCase":
        """Generate ``num_docs`` dated entries and a query about ``options[0]``.

        Args:
            options: Emotion words to sample from; a built-in list when None.
            phrasings: Templates with an ``{option}`` placeholder; a built-in
                list when None.
            num_docs: Number of documents, one per consecutive day.
            seed: Seed for the global ``random`` module.
            **kwargs: Accepted but not used.
        """
        random.seed(seed)
        if options is None:
            options = ["happy", "sad", "confused", "angry", "disgusted", "scared", "thankful", "astonished", "calm"]
        if phrasings is None:
            phrasings = [
                "Today I felt {option}",
                "I felt {option} today",
                "I was really {option} today",
                "My primary emotion is {option}",
                "Everybody says I seemed so {option}"
            ]
        name = f"Temporal query ({num_docs=})"
        # Append options[0] last so the queried option occurs at least once.
        options_sample = random.choices(options, k=num_docs-1) + [options[0]]
        texts = [
            phrase.format(option=option) for phrase, option in zip(random.choices(phrasings, k=num_docs), options_sample)
        ]
        dates = pd.date_range(start="01-01-2023", freq='D', periods=num_docs).astype(str)
        metadatas = [{"date": d} for d in dates]
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        query = f"When was the first time I mentioned being {options[0]}"
        # list.index returns the first occurrence, i.e. the earliest date.
        correct_date = dates[options_sample.index(options[0])]
        return cls(name=name, docs=docs, query=query, correct_date=correct_date)
                

In [332]:
from langchain.text_splitter import CharacterTextSplitter


class RevisedStatementTestCase(ExpectedSubstringsTestCase):
    """A fact revised many times across documents; the retrieval must include the final revision."""

    @classmethod
    def from_config(cls, filler_text: Optional[str]=None, **kwargs) -> "RevisedStatementTestCase":
        """Split filler text into chunks and append one revision to each leading chunk.

        Args:
            filler_text: Background text; the State of the Union file is loaded
                when None.
            **kwargs: Accepted but not used.
        """
        if filler_text is None:
            filler_text = TextLoader('../docs/modules/state_of_the_union.txt').load()[0].page_content
        texts = CharacterTextSplitter(chunk_size=500, chunk_overlap=0).split_text(filler_text)
        docs = [Document(page_content=t) for t in texts]
        updates = [
            "We are receiving reports of a magnitude 10 earthquake in Japan",
            "The latest reports are that the earthquake that has hit Japan is actually of magnitude 8.2",
            "Now the earthquake in Japan has been downgraded to magnitude 7",
            "Looks like the earthquake is back up to an 8",
            "The latest news is that the earthquake was of magnitude 3",
            "No no it's a magnitude 4",
            "I heard the earthquake is 6.3",
            "Or did they say the earthquake in Japan was a magnitude 6.2",
            "The Japanese earthquake is actually being recorded as a magnitude 12",
            "Sorry correction, my Japanese was poor, the magnitude of the earthquake is 2",
            "The Japanese earthquake is now being recorded as magnitude 5"
        ]
        for position, update in enumerate(updates):
            if position < len(docs):
                docs[position].page_content += " " + update + "."
            else:
                # Short filler: give the revision its own document instead of
                # silently dropping it, so the final revision always exists.
                docs.append(Document(page_content=update + "."))
        query = "What is the latest reported magnitude of the earthquake in Japan?"
        num_revisions = len(updates)
        name = f"Revised statement ({num_revisions=})"
        # A bare "5" would match almost any retrieved text (years, amounts, ...),
        # so require the phrase that only the final revision contains.
        expected_substrings = ["magnitude 5"]
        return cls(name=name, docs=docs, query=query, expected_substrings=expected_substrings)

In [318]:
class LongTextOneFactTestCase(ExpectedSubstringsTestCase):
    """A single long document with one out-of-place fact inserted at its midpoint."""

    @classmethod
    def from_config(cls, filler_text: Optional[str]=None, **kwargs) -> "LongTextOneFactTestCase":
        """Insert the fact between the two halves of ``filler_text`` split on ``". "``.

        The State of the Union file is loaded when ``filler_text`` is None.
        Extra ``kwargs`` are accepted but not used.
        """
        if filler_text is None:
            loaded = TextLoader('../docs/modules/state_of_the_union.txt').load()
            filler_text = loaded[0].page_content

        sentences = filler_text.split(". ")
        midpoint = len(sentences) // 2
        pieces = sentences[:midpoint]
        pieces.append("We've just received reports of a purple monkey invading the White House.")
        pieces.extend(sentences[midpoint:])
        all_text = ". ".join(pieces)

        text_len = len(all_text)
        return cls(
            name=f"Fact in long text ({text_len=})",
            docs=[Document(page_content=all_text)],
            query="What color was the animal that was mentioned?",
            expected_substrings=["purple"],
        )

In [319]:
import re

import numpy as np

from langchain.document_loaders import TextLoader


def load_transcript(path: str = "/Users/bagatur/Downloads/Ian_Goodfellow--Generative_Adversarial_Networks_(GANs)-Artificial_Intelligence_(AI)_Podcast-April_18_2019.md") -> List[Document]:
    """Load an interview transcript and return one Document per speaker turn.

    Speaker turns are introduced by ``**[Name]**`` markers. Each document's
    metadata holds the ``speaker``, a zero-based ``statement_index`` and a
    synthetic ``time``.

    Args:
        path: Transcript file to load. Defaults to the original author's local
            copy, so most callers will want to pass their own path.
    """
    interview = TextLoader(path).load()[0].page_content
    # Raw string: "\*" and "\[" are invalid escape sequences in a normal literal.
    speaker_tmpl = r"\*\*\[{name}\]\*\*"
    # With a capture group, re.split yields
    # [preamble, name1, text1, name2, text2, ...].
    splits = re.split(speaker_tmpl.format(name="(.*)"), interview.strip())
    # Made-up times: cumulative word count of the statements divided by 2.5.
    # NOTE(review): presumably models a speaking rate of 2.5 words per time
    # unit -- confirm.
    times = np.cumsum([len(splits[i].split()) for i in range(2, len(splits), 2)]) / 2.5
    docs = [
        Document(
            page_content=splits[i + 1].strip(),
            metadata={"speaker": splits[i], "statement_index": i // 2, "time": times[i // 2]},
        )
        for i in range(1, len(splits), 2)
    ]
    return docs

In [320]:
from pydantic import Field

class PodcastTestCase(RetrieverTestCase):
    """Base for test cases whose documents default to the podcast transcript."""
    # default_factory defers load_transcript() until an instance is created
    # without an explicit ``docs`` value.
    docs: List[Document] = Field(default_factory=load_transcript)
        
        
class FirstMentionTestCase(PodcastTestCase, ExpectedSubstringsTestCase):
    """Podcast case asking for the first mention of deep learning.

    Passes when the retrieved text contains the expected substring.
    """
    name: str = "Podcast First Mention"
    query: str = "What was the first mention of deep learning?"
    # The expected substring includes the double quotes around "Deep Learning".
    expected_substrings: List[str] = Field(default_factory=lambda: ['"Deep Learning" book'])
        
        
class SpeakerTestCase(PodcastTestCase, ExpectedSubstringsTestCase):
    """Podcast case whose query refers to a speaker by name.

    Passes when the retrieved text contains the expected substring.
    """
    name: str = "Podcast Reference to Speaker"
    query: str = "What did Ian say about how he came up with the idea for GANs?"
    expected_substrings: List[str] = Field(default_factory=lambda: ["drinking helped a little bit"])
    


# Candidate retrieval systems
Now that we've defined some test cases, let's create a couple of candidate retrieval systems to evaluate and compare. For this we'll define `TestRetriever` subclasses that wrap a vector store retriever and can ingest each test case's documents.

In [321]:
from langchain.text_splitter import TextSplitter
from langchain.vectorstores.base import VectorStoreRetriever


class VectorStoreTestRetriever(TestRetriever):
    """TestRetriever backed by a vector store retriever, with an optional text splitter."""
    # Retriever that answers queries; its ``vectorstore`` receives new documents.
    base_retriever: VectorStoreRetriever
    # When set, editable documents are split with it before insertion.
    text_splitter: Optional[TextSplitter] = None

    class Config:
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query):
        """Delegate the query to ``base_retriever``."""
        return self.base_retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query):
        """Async retrieval is not supported.

        Declared ``async`` to match the coroutine interface this method
        overrides; awaiting it raises NotImplementedError.
        """
        raise NotImplementedError

    def _insert_documents(self, docs: List[Document]) -> None:
        """Add ``docs`` to the vector store behind ``base_retriever``."""
        self.base_retriever.vectorstore.add_documents(docs)

    def _transform_documents(self, docs: List[Document]) -> List[Document]:
        """Split ``docs`` with ``text_splitter`` when one is configured; otherwise return them as-is."""
        if self.text_splitter is None:
            return docs
        return self.text_splitter.split_documents(docs)

In [333]:
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma, FAISS
        
class ChromaTestRetriever(VectorStoreTestRetriever):
    """Candidate system: a Chroma store with the default retriever and 1000-character chunks."""
    # A fresh Chroma store (OpenAI embeddings) is created for every instance.
    base_retriever: BaseRetriever = Field(default_factory=lambda: Chroma(embedding_function=OpenAIEmbeddings()).as_retriever())
    text_splitter: TextSplitter = Field(default_factory=lambda: CharacterTextSplitter(chunk_size=1000, chunk_overlap=0))
    # Descriptive label only: ``name`` is derived from it and nothing here
    # enforces these values.
    # NOTE(review): "k": 4 and "similarity" presumably mirror as_retriever()
    # defaults -- confirm.
    identifying_params = {"chunk_size": 1000, "search": "similarity", "vectorstore": "Chroma", "k": 4}
    
class ChromaTestRetrieverMMR(VectorStoreTestRetriever):
    """Candidate system: a Chroma store queried with search_type="mmr" and 200-character chunks."""
    # A fresh Chroma store per instance; the retriever is configured with
    # k=6 and fetch_k=12.
    base_retriever: BaseRetriever = Field(default_factory=lambda: Chroma(embedding_function=OpenAIEmbeddings()).as_retriever(search_type="mmr", search_kwargs={"k": 6, "fetch_k": 12}))
    text_splitter: TextSplitter = Field(default_factory=lambda: CharacterTextSplitter(chunk_size=200, chunk_overlap=0))
    # Descriptive label only: ``name`` is derived from it. Keep it in sync with
    # the settings above by hand.
    identifying_params = {"chunk_size": 200, "search": "mmr", "vectorstore": "Chroma", "k": 6}

class ChromaTestRetrieverStuffMetadata(ChromaTestRetrieverMMR):
    """MMR Chroma candidate that prepends each document's metadata to its content."""
    identifying_params = {"chunk_size": 200, "search": "mmr", "vectorstore": "Chroma", "k": 6, "metadata": "included_in_content"}

    def _transform_documents(self, docs: List[Document]) -> List[Document]:
        """Apply the parent transform, then prefix every document with its metadata.

        Each document's ``page_content`` is modified in place.
        """
        transformed = super()._transform_documents(docs)
        for item in transformed:
            header = f"Document metadata: {item.metadata}\n\n"
            item.page_content = header + item.page_content
        return transformed
    
class FAISSTestRetriever(VectorStoreTestRetriever):
    """Candidate system: a FAISS store with the default retriever and 1000-character chunks."""
    # The index is created from a single placeholder text "foo", so that
    # document is part of the store and can be returned by queries.
    # NOTE(review): presumably needed because FAISS.from_texts cannot build an
    # empty index -- confirm.
    base_retriever: BaseRetriever = Field(default_factory=lambda: FAISS.from_texts(["foo"], OpenAIEmbeddings()).as_retriever())
    text_splitter: TextSplitter = Field(default_factory=lambda: CharacterTextSplitter(chunk_size=1000, chunk_overlap=0))
    # Descriptive label only: ``name`` is derived from it.
    identifying_params = {"chunk_size": 1000, "search": "similarity", "vectorstore": "FAISS", "k": 4}

# Run test cases against each retrieval system

In [334]:
# Each entry is (test case class, keyword config passed to its from_config).
# An empty dict means the class's defaults are used.
test_cases = [
    (ManyDocsTestCase, {}), 
    (RedundantDocsTestCase, {}), 
    (EntityLinkingTestCase, {}),
    (TemporalQueryTestCase, {}),
    (RevisedStatementTestCase, {}),
    (LongTextOneFactTestCase, {}),
    (FirstMentionTestCase, {}),
    (SpeakerTestCase, {}),
]
# Retriever classes (not instances); they are instantiated with no arguments
# when the test cases are run.
test_retrievers = [
    ChromaTestRetriever,
    ChromaTestRetrieverMMR,
    ChromaTestRetrieverStuffMetadata,
    FAISSTestRetriever,
]

In [324]:
# Run every test case against a fresh instance of every retriever class.
# Result shape: {retriever name: {test case name: (passed, extra_dict)}}.
results = {}
for retriever_cls in test_retrievers:
    for test_case_cls, config in test_cases:
        # A new retriever per test case, so documents from one case never
        # leak into the next.
        retriever = retriever_cls()
        test_case = test_case_cls.from_config(**config)
        # Read the name from the instance already in use rather than building
        # a throwaway retriever (vector store + embeddings client) just for it.
        results.setdefault(retriever.name, {})[test_case.name] = test_case.run(retriever)
        

Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Using embedded DuckDB without persistence: data will be transient
Created a 

In [327]:
results

{"{'chunk_size': 1000, 'search': 'similarity', 'vectorstore': 'Chroma', 'k': 4}": {'Many docs (retrieve=5, total=100)': (False,
   {'retrieved_docs': [Document(page_content='On 2023-02-14 the peak temperature was 71 degrees', metadata={}),
     Document(page_content='On 2023-03-14 the peak temperature was 62 degrees', metadata={}),
     Document(page_content='On 2023-02-20 the peak temperature was 63 degrees', metadata={}),
     Document(page_content='On 2023-02-03 the peak temperature was 72 degrees', metadata={})]}),
  'Redundant docs': (True,
   {'retrieved_docs': [Document(page_content='Meta open sources new AI model, largest yet', metadata={}),
     Document(page_content="Sam Altman's OpenAI comes out with new GPT-5 model", metadata={}),
     Document(page_content='The next-generation OpenAI GPT model is here', metadata={}),
     Document(page_content="GPT-5: OpenAI's next model is the biggest yet", metadata={})]}),
  'Entity linking (num_filler=11)': (True,
   {'retrieved_docs': 

In [330]:
# Visualize results
import pandas as pd

def highlight(s):
    """Map each value in ``s`` to a CSS background-color declaration.

    Null values become whitesmoke, truthy values honeydew, and falsy values
    mistyrose. Returns a list with one declaration per input value.
    """
    def _css(value):
        # The null check comes first so NaN is never treated as truthy.
        if pd.isnull(value):
            return 'background-color: whitesmoke'
        return 'background-color: honeydew' if value else 'background-color: mistyrose'

    return [_css(value) for value in s]

def visualize_results(results):
    """Build a styled table with one row per system and one pass/fail column per test case.

    ``results`` maps a system name to ``{test case name: (passed, extra)}``.
    Only the ``passed`` element is shown. Every column except "System" is
    colored by ``highlight``.
    """
    rows = []
    for system, outcomes in results.items():
        row = {"System": system}
        for case_name, outcome in outcomes.items():
            row[case_name] = outcome[0]
        rows.append(row)
    result_df = pd.DataFrame(rows)
    graded_columns = result_df.columns.drop("System")
    return result_df.style.apply(highlight, subset=graded_columns)
# Last expression of the cell: the notebook renders the returned styled table.
visualize_results(results)

Unnamed: 0,System,"Many docs (retrieve=5, total=100)",Redundant docs,Entity linking (num_filler=11),Temporal query (num_docs=200),Revised statement (num_revisions=6),Fact in long text (text_len=38613),Podcast First Mention,Podcast Reference to Speaker
0,"{'chunk_size': 1000, 'search': 'similarity', 'vectorstore': 'Chroma', 'k': 4}",False,True,True,False,True,False,True,False
1,"{'chunk_size': 200, 'search': 'mmr', 'vectorstore': 'Chroma', 'k': 6}",False,True,True,False,True,False,False,False
2,"{'chunk_size': 1000, 'search': 'similarity', 'vectorstore': 'FAISS', 'k': 4}",False,True,True,False,True,True,True,False


# Improvement ideas
1. hybrid of sparse and dense embeddings
2. add metadata to content
3. add ability to filter based on metadata or page_content
4. dynamically determine k (keep asking for more documents as needed)
5. don't only store raw docs, store summaries / extracted information as well
6. use MMR
7. let retrieval system determine query
8. recency bias

In [509]:
# import datetime
# import pinecone
# from pinecone_text.sparse import BM25Encoder
# from langchain.retrievers import PineconeHybridSearchRetriever

# # os.environ["PINECONE_API_KEY"] = "get-ur-own-api-key"
# # os.environ["PINECONE_ENVIRONMENT"] = "the-moon"
# api_key = os.getenv("PINECONE_API_KEY")
# env = os.getenv("PINECONE_ENVIRONMENT")

# pinecone.init(api_key=api_key, environment=env)

# def get_retriever(documents):
#     timestamp = int(datetime.datetime.now().timestamp())
#     index_name = f"hybrid-search-{timestamp}"

#     # create the index
#     pinecone.create_index(
#        name = index_name,
#        dimension = 1536,  # dimensionality of dense model
#        metric = "dotproduct",  # sparse values supported only for dotproduct
#        pod_type = "s1",
#        metadata_config={"indexed": []}  # see explanation above
#     )
#     index = pinecone.Index(index_name)
#     # or from pinecone_text.sparse import SpladeEncoder if you wish to work with SPLADE

#     # use default tf-idf values
#     bm25_encoder = BM25Encoder().default()
#     retriever = PineconeHybridSearchRetriever(embeddings=embeddings, sparse_encoder=bm25_encoder, index=index)
    
#     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
#     documents = text_splitter.split_documents(documents)
#     for doc in documents:
#         doc.page_content = f"Document metadata: {doc.metadata}\n\n" + doc.page_content
#     retriever.add_texts([doc.page_content for doc in documents])
#     return retriever

# def cleanup_retriever(retriever):
#     index_name = retriever.index.configuration.server_variables['index_name']
#     pinecone.delete_index(index_name)

# candidate = {
#     "params": {"search": "pinecone hybrid", "k": 4, "chunk_size": 100},
#     "get_retriever": get_retriever,
#     "cleanup_retriever": cleanup_retriever
# }
# retrieval_candidates.append(candidate)

WhoAmIResponse(username='805b516', user_label='default', projectname='6ddd519')