# Testing Retrieval QA Chains
In this notebooks we'll evaluate a retrieval QA chain on a set of edge case scenarios. We'll compare different chains to see which parameters and architectures are the most performant on these edge cases.

In [426]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Retrieval test cases
Each test case should contain a list of corpus texts to be stored, optional metadata for the texts, the query used for testing, and the expected chain output.

In [427]:
# 1. Need to retrieve many documents
import random

import pandas as pd


def get_retrieve_many_docs_test_case(num_samples=5, num_total=100):
    """Test if retriever can return as many documents as query requires."""
    text_template = "On {date} the peak temperature was {temp} degrees"
    dates = pd.date_range(start="01-01-2023", freq='D', periods=num_total).astype(str)
    temps = [str(random.randint(50, 80)) for _ in range(len(dates))]
    sample_idxs = random.sample(range(len(dates)), k=num_samples)

    texts = [text_template.format(date=d, temp=t) for d, t in zip(dates, temps)]
    query = f"What were the peak temperatures on {', '.join([dates[i] for i in sample_idxs])}?"
    expected = f"The peak temperatures were {', '.join(temps[i] for i in sample_idxs)} degrees"
    def retriever_check(retrieved_docs) -> bool:
        retrieved_texts = set(d.page_content for d in retrieved_docs)
        return all(texts[i] in retrieved_texts for i in sample_idxs)
    return {
        "name": f"Retrieve {num_samples} documents",
        "texts": texts,
        "query": query,
        "expected": expected,
        "retriever_check": retriever_check
    }

In [428]:
# 2. Redundant docs
def get_redundant_docs_test_case():
    """Test if retriever can handle many redundant documents."""
    texts = [
        "OpenAI announces the release of GPT-5",
        "GPT-5 released by OpenAI",
        "The next-generation OpenAI GPT model is here",
        "GPT-5: OpenAI's next model is the biggest yet",
        "Sam Altman's OpenAI comes out with new GPT-5 model",
        "GPT-5 is here. What you need to know about the OpenAI model",
        "OpenAI announces ChatGPT successor GPT-5",
        "5 jaw-dropping things OpenAI's GPT-5 can do that ChatGPT couldn't",
        "OpenAI's GPT-5 Is Exciting and Scary",
        "OpenAI announces GPT-5, the new generation of AI",
        "OpenAI says new model GPT-5 is more creative and less",
        "Meta open sources new AI model, largest yet",
    ]
    query = "What companies have recently released new models?"
    expected = "OpenAI and Meta have recently released new models"
    def retriever_check(retrieved_docs) -> bool:
        return any("Meta" in d.page_content for d in retrieved_docs)
    return {
        "name": "Redundant documents",
        "texts": texts,
        "query": query,
        "expected": expected,
        "retriever_check": retriever_check
    }

In [458]:
# 3. Information about entity split across documents using different entity names
def get_disambiguate_and_combine_test_case():
    texts = [
        "The founder of ReallyCoolAICompany LLC is from Louisville, Kentucky.",
        "Melissa Harkins, founder of ReallyCoolAICompany LLC, said in a recent interview that she will be stepping down as CEO.",
        state_of_union,
    ]
    query = "Where is Melissa Harkins from?"
    expected = "Melissa Harkins is from Louisville, Kentucky"
    def retriever_check(retrieved_docs) -> bool:
        for keyword in ("Melissa Harkins", "Louisville, Kentucky"):
            found = False
            for doc in retrieved_docs:
                if keyword in doc.page_content:
                    found = True
                    break
            if not found:
                return False
        return True
            
    return {
        "name": "Entity disambiguation",
        "texts": texts,
        "query": query,
        "expected": expected,
        "retriever_check": retriever_check
    }

In [453]:
# 4. Metadata question (temporal)
def get_temporal_test_case(num_total=100):
    scores = [random.randint(1, 4) for _ in range(num_total - 1)] + [4]
    dates = pd.date_range(start="01-01-2023", freq='D', periods=num_total).astype(str)
    texts = [
        f"Daily log: My energy levels were a {score} out of 4 today" for score in scores
    ]
    metadatas = [{"date": d} for d in dates]
    query = "When was the first time I reported an energy level of 4?"
    expected = f"The first time was {dates[scores.index(4)]}"
    def retriever_check(retrieved_docs) -> bool:
        retrieved_dates = set(d.metadata.get('date') for d in retrieved_docs)
        return dates[scores.index(4)] in retrieved_dates
    return {
        "name": "Metadata question (temporal information)",
        "texts": texts,
        "query": query,
        "expected": expected,
        "metadatas": metadatas,
        "retriever_check": retriever_check
    }

In [431]:
# 5. Store a single text which updates a fact multiple times.
def get_revised_fact_test_case():
    split_sou = state_of_union.split(". ")
    len_split_sou = len(split_sou)
    updates = [
        "We are receiving reports of a magnitude 10 earthquake in Japan",
        "The latest reports are that the earthquake that has hit Japan is actually of magnitude 8.5",
        "Now the earthquake in Japan has been downgraded to magnitude 7",
        "Looks like the earthquake is back up to an 8",
        "The latest news is that the earthquake was of magnitude 3",
        "The Japanese earthquake is now being recorded as magnitude 5"
    ]
    chunk_size = len_split_sou // len(updates)
    sou_chunks = [split_sou[idx: idx + chunk_size] for idx in range(0, len_split_sou, chunk_size)]
    for update, chunk in zip(updates, sou_chunks):
        chunk.append(update)
    texts = [". ".join([s for chunk in sou_chunks for s in chunk])]
    query = "What is the latest reported magnitude of the earthquake in Japan?"
    expected = "The latest reported magnitude of the earthquake is 5"
    def retriever_check(retrieved_docs) -> bool:
        retrieved_texts = [d.page_content for d in retrieved_docs]
        return any("magnitude 5" in t for t in retrieved_texts)
    return {
        "name": "Revised statements",
        "texts": texts,
        "query": query,
        "expected": expected,
        "retriever_check": retriever_check,
    }

In [432]:
# 6. A text with many facts and ask about only one of them
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader

def get_single_fact_test_case():
    loader = TextLoader('../docs/modules/state_of_the_union.txt')
    text = loader.load()[0].page_content
    text += " The color of the cat is purple."
    query = "What color is the cat?"
    expected = "The cat is purple"
    def retriever_check(retrieved_docs) -> bool:
        retrieved_texts = [d.page_content for d in retrieved_docs]
        return any("purple" in t for t in retrieved_texts)
    return {
        "name": "One fact in long text",
        "texts": [text],
        "query": query,
        "expected": expected,
        "retriever_check": retriever_check
    }

# Candidate retrieval systems
Now that we've defined some test cases, let's create a couple candidate retrieval systems to evaluate and compare. For this we'll need to define a function that returns a Retriever given a set of documents.

In [433]:
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Chroma

In [434]:
# import os
# os.environ["OPENAI_API_KEY"] = "get-ur-own-key"

In [466]:
retrieval_candidates = []

In [467]:
# Vanilla vectorestore retriever
def get_retriever(documents):
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(documents)
    docsearch = Chroma.from_documents(documents, OpenAIEmbeddings())
    return docsearch.as_retriever(search_kwargs={"k": min(4, len(texts))})
candidate = {
    "params": {"search": "similarity", "k": 4, "chunk_size": 1000},
    "get_retriever": get_retriever,
}
retrieval_candidates.append(candidate)

In [468]:
# Retriever with larger context, uses max marginal relevance, tuned prompt, more granular chunking
from langchain.prompts import PromptTemplate

retrieval_prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know. Don't try to make up an answer. If the context contains conflicting pieces of information, say that the answer cannot be determined from the context.

{context}

Question: {question}
Helpful Answer:"""
improved_retrieval_prompt = PromptTemplate(
    template=retrieval_prompt_template, input_variables=["context", "question"]
)

def get_retriever(documents):
    text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0)
    documents = text_splitter.split_documents(documents)
    for doc in documents:
        doc.page_content = f"Document metadata: {doc.metadata}\n\n" + doc.page_content
    docsearch = Chroma.from_documents(documents, OpenAIEmbeddings())
    return docsearch.as_retriever(search_type="mmr", search_kwargs={"k": min(6, len(texts)), "fetch_k": min(20, len(texts))})

candidate = {
    "params": {"search": "mmr", "k": 6, "chunk_size": 200, "metadata": "prepend to content", "prompt": "improved"},
    "get_retriever": get_retriever,
    "qa_kwargs": {"prompt": improved_retrieval_prompt}
}
retrieval_candidates.append(candidate)

In [509]:
import datetime
import pinecone
from pinecone_text.sparse import BM25Encoder
from langchain.retrievers import PineconeHybridSearchRetriever

# os.environ["PINECONE_API_KEY"] = "get-ur-own-api-key"
# os.environ["PINECONE_ENVIRONMENT"] = "the-moon"
api_key = os.getenv("PINECONE_API_KEY")
env = os.getenv("PINECONE_ENVIRONMENT")

pinecone.init(api_key=api_key, enviroment=env)

def get_retriever(documents):
    timestamp = int(datetime.datetime.now().timestamp())
    index_name = f"hybrid-search-{timestamp}"

    # create the index
    pinecone.create_index(
       name = index_name,
       dimension = 1536,  # dimensionality of dense model
       metric = "dotproduct",  # sparse values supported only for dotproduct
       pod_type = "s1",
       metadata_config={"indexed": []}  # see explaination above
    )
    index = pinecone.Index(index_name)
    # or from pinecone_text.sparse import SpladeEncoder if you wish to work with SPLADE

    # use default tf-idf values
    bm25_encoder = BM25Encoder().default()
    retriever = PineconeHybridSearchRetriever(embeddings=embeddings, sparse_encoder=bm25_encoder, index=index)
    
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(documents)
    for doc in documents:
        doc.page_content = f"Document metadata: {doc.metadata}\n\n" + doc.page_content
    retriever.add_texts([doc.page_content for doc in documents])
    return retriever

def cleanup_retriever(retriever):
    index_name = retriever.index.configuration.server_variables['index_name']
    pinecone.delete_index(index_name)

candidate = {
    "params": {"search": "pinecone hybrid", "k": 4, "chunk_size": 100},
    "get_retriever": get_retriever,
    "cleanup_retriever": cleanup_retriever
}
retrieval_candidates.append(candidate)

WhoAmIResponse(username='805b516', user_label='default', projectname='6ddd519')

In [438]:
from langchain.evaluation.qa import QAEvalChain

eval_template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. The student answer is correct if and only if it contains all of the infrormation in the true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! 

QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:"""
eval_prompt = PromptTemplate(
    input_variables=["query", "result", "answer"], template=eval_template
)
def default_evaluate(example, prediction) -> bool:
    """Return True if the prediction is correct."""
    eval_chain = QAEvalChain.from_llm(OpenAI(temperature=0), prompt=eval_prompt)
    grades = eval_chain.evaluate([example], [prediction])
    return grades[0]['text'].strip().upper() == "CORRECT"

# Run test cases against each retrieval system

In [459]:
test_cases = [
    get_retrieve_many_docs_test_case(),
    get_revised_fact_test_case(),
    get_single_fact_test_case(),
    get_redundant_docs_test_case(),
    get_disambiguate_and_combine_test_case(),
    get_temporal_test_case(),
]

In [469]:
%%capture --no-stdout
# Suppress chroma warnings about transient DB.

from langchain.schema import Document

test_results = []
for retrieval_candidate in retrieval_candidates:
    res = {"params": retrieval_candidate["params"], "test_cases": {}}
    for tc in test_cases:
        try:
            texts = tc["texts"]
            metadatas = tc.get("metadatas", [{}] * len(texts))
            docs = [Document(page_content=text, metadata=metadata) for text, metadata in zip(texts, metadatas)]
            retriever = retrieval_candidate["get_retriever"](docs)
            qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever, **retrieval_candidate.get('qa_kwargs', {}))
            example = {"query": tc["query"], "answer": tc["expected"]}
            prediction = qa.apply([example])[0]
            retrieved_docs = retriever.get_relevant_documents(tc["query"])

            res["test_cases"][tc["name"]] = {
                "pass": tc.get("evaluate", default_evaluate)(example, prediction),
                "prediction": prediction,
                "retriever_pass": tc['retriever_check'](retrieved_docs),
                "retrieved_docs": retrieved_docs
            }
            if "cleanup_retriever" in retrieval_candidate:
                retrieval_candidate["cleanup_retriever"](retriever)
        except:
            continue
    test_results.append(res)

In [538]:
# Visualize results
import pandas as pd

test_result_df = pd.DataFrame([{"System": f"{res['params']}", **{k: v['pass'] for k, v in res["test_cases"].items()}} for res in test_results])

def highlight(s):
    res = []
    for s_ in s:
        if pd.isnull(s_):
            res.append('background-color: whitesmoke')
        elif s_:
            res.append('background-color: honeydew')
        else:
            res.append('background-color: mistyrose')
    return res

test_result_df.style.apply(highlight, subset=test_result_df.columns.drop("System"))

Unnamed: 0,System,Retrieve 5 documents,Revised statements,One fact in long text,Redundant documents,Entity disambiguation,Metadata question (temporal information)
0,"{'search': 'similarity', 'k': 4, 'chunk_size': 1000}",False,True,False,True,True,False
1,"{'search': 'mmr', 'k': 6, 'chunk_size': 200, 'metadata': 'prepend to content', 'prompt': 'improved'}",False,False,False,True,False,False
2,"{'search': 'pinecone hybrid', 'k': 4, 'chunk_size': 100}",False,True,,True,True,True


In [539]:
retriever_result_df = pd.DataFrame([{"System": f"{res['params']}", **{k: v['retriever_pass'] for k, v in res["test_cases"].items()}} for res in test_results])
retriever_result_df.style.apply(highlight, subset=test_result_df.columns.drop("System"))

Unnamed: 0,System,Retrieve 5 documents,Revised statements,One fact in long text,Redundant documents,Entity disambiguation,Metadata question (temporal information)
0,"{'search': 'similarity', 'k': 4, 'chunk_size': 1000}",False,True,True,True,False,False
1,"{'search': 'mmr', 'k': 6, 'chunk_size': 200, 'metadata': 'prepend to content', 'prompt': 'improved'}",False,False,True,True,False,False
2,"{'search': 'pinecone hybrid', 'k': 4, 'chunk_size': 100}",False,True,,True,False,False


In [541]:
test_results[2]

{'params': {'search': 'pinecone hybrid', 'k': 4, 'chunk_size': 100},
 'test_cases': {'Retrieve 5 documents': {'pass': False,
   'prediction': {'query': 'What were the peak temperatures on 2023-02-12, 2023-02-28, 2023-03-31, 2023-04-07, 2023-01-16?',
    'answer': 'The peak temperatures were 55, 79, 52, 79, 54 degrees',
    'result': " The peak temperatures on 2023-02-12, 2023-02-28, 2023-03-31, 2023-04-07, and 2023-01-16 were 55 degrees, 79 degrees, 52 degrees, 79 degrees, and I don't know, respectively."},
   'retriever_pass': False,
   'retrieved_docs': [Document(page_content='Document metadata: {}\n\nOn 2023-03-31 the peak temperature was 52 degrees', metadata={}),
    Document(page_content='Document metadata: {}\n\nOn 2023-02-12 the peak temperature was 55 degrees', metadata={}),
    Document(page_content='Document metadata: {}\n\nOn 2023-02-28 the peak temperature was 79 degrees', metadata={}),
    Document(page_content='Document metadata: {}\n\nOn 2023-04-07 the peak temperature 

1. hybrid of sparse and dense embeddings
2. add metadata to content
3. add ability to filter based on metadata or page_content
4. dynamically determine k (keep asking for more documents as needed)
5. don't only store raw docs, store summaries / extracted information as well
6. use MMR
7. let retrieval system determine query
8. [optional] compress docs before adding to context