# Testing Retrieval QA Chains
In this notebooks we'll evaluate a retrieval QA chain on a set of edge case scenarios. We'll compare different chains to see which parameters and architectures are the most performant on these edge cases.

In [246]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Retrieval test cases
Each test case should contain a list of corpus texts to be stored, optional metadata for the texts, the query used for testing, and the expected chain output.

In [247]:
test_cases = []

In [248]:
# 1. Need to retrieve many documents
texts = [
    "X = 5",
    "Y = 3",
    "Z = 2",
    "A = 12",
    "B = 50",
]
query = "What are the values of A, B, X, Y, Z?"
expected = "A is 12, B is 50, X is 5, Y is 3, Z is 2"
test_case = {
    "name": "Retrieve 5 documents",
    "texts": texts,
    "query": query,
    "expected": expected
}
test_cases.append(test_case)

In [250]:
# 2. Redundant docs
texts = [
    "The color of the cat is blue",
    "Blue is the color of the cat",
    "The cat's color is blue",
    "The cat is blue",
    "I believe the cat was blue",
    "The cat was definitely blue",
    "The dog is green"
]
query = "What colors are the cat and the dog?"
expected = "The cat is blue and the dog is green"
test_case = {
    "name": "Redundant documents",
    "texts": texts,
    "query": query,
    "expected": expected
}
test_cases.append(test_case)

In [251]:
# 3. Split required information across two docs in a way that changes semantic meaning of each half
texts = [
    "The cat was fat and it's color",
    " was green. The cat liked whole milk.",
    "The dog was blue."
]
query = "What color was the cat?"
expected = "The cat was green"
test_case = {
    "name": "Split statement",
    "texts": texts,
    "query": query,
    "expected": expected
}
test_cases.append(test_case)

In [252]:
# 4. Metadata question (temporal), e.g. "What did I say right before I said X"
import datetime

now = datetime.datetime.now()
now_timestamp = now.strftime("%Y-%m-%d %H:%M:%S")
future = now + datetime.timedelta(seconds=5)
future_timestamp = future.strftime("%Y-%m-%d %H:%M:%S")

texts = [
    "The cat is green",
    "The dog is yellow",
]
metadatas = [{"timestamp": now_timestamp}, {"timestamp": future_timestamp}]
query = "What did I say right before I mentioned the dog?"
expected = "The cat is green"
test_case = {
    "name": "Metadata question (temporal information)",
    "texts": texts,
    "query": query,
    "expected": expected,
    "metadatas": metadatas
}
test_cases.append(test_case)

In [253]:
# 5. Store conflicting statements, retrieve both and state there's an inconsistency
texts = [
    "The cat is green",
    "The cat is blue",
    "The dog is yellow",
]
query = "What color is the cat?"
expected = "The color of the cat cannot be determined from the context."
test_case = {
    "name": "Conflicting statements",
    "texts": texts,
    "query": query,
    "expected": expected
}
test_cases.append(test_case)

In [254]:
# 6. A text with many facts and ask about only one of them
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader

loader = TextLoader('../docs/modules/state_of_the_union.txt')
documents = loader.load()
documents[0].page_content += " The color of the cat is purple."
query = "What color is the cat?"
expected = "The cat is purple"
test_case = {
    "name": "One fact in long text",
    "texts": [d.page_content for d in documents],
    "query": query,
    "expected": expected
}
test_cases.append(test_case)

In [293]:
# 7. Make a statement and later revise it
texts = [
    "The cat is green",
    "I believe the cat is actually blue",
    "The dog is yellow",
]
metadatas = [{"text_position": i} for i in range(len(texts))]
query = "What color is the cat?"
expected = "The cat is blue."
test_case = {
    "name": "Revised statement",
    "texts": texts,
    "metadatas": metadatas,
    "query": query,
    "expected": expected
}
test_cases.append(test_case)

# Candidate retrieval systems
Now that we've defined some test cases, let's create a couple candidate retrieval systems to evaluate and compare. For this we'll need to define a function that returns a Retriever given a set of documents.

In [255]:
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Chroma

In [279]:
# import os
# os.environ["OPENAI_API_KEY"] = "get-ur-own-key"

In [294]:
retrieval_candidates = []

In [295]:
# Vanilla vectorestore retriever
def get_retriever(documents):
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(documents)
    docsearch = Chroma.from_documents(documents, OpenAIEmbeddings())
    return docsearch.as_retriever(search_kwargs={"k": min(4, len(texts))})
candidate = {
    "params": {"search": "similarity", "k": 4, "chunk_size": 1000},
    "get_retriever": get_retriever,
}
retrieval_candidates.append(candidate)

In [296]:
# Retriever with larger context, uses max marginal relevance, tuned prompt, more granular chunking
from langchain.prompts import PromptTemplate

retrieval_prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know. Don't try to make up an answer. If the context contains conflicting pieces of information, say that the answer cannot be determined from the context.

{context}

Question: {question}
Helpful Answer:"""
improved_retrieval_prompt = PromptTemplate(
    template=retrieval_prompt_template, input_variables=["context", "question"]
)

def get_retriever(documents):
    text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0)
    documents = text_splitter.split_documents(documents)
    for doc in documents:
        doc.page_content = "Document metadata: {doc.metadata}\n\n" + doc.page_content
    docsearch = Chroma.from_documents(documents, OpenAIEmbeddings())
    return docsearch.as_retriever(search_type="mmr", search_kwargs={"k": min(6, len(texts)), "fetch_k": min(20, len(texts))})

candidate = {
    "params": {"search": "mmr", "k": 6, "chunk_size": 200, "metadata": "prepend to content", "prompt": "improved"},
    "get_retriever": get_retriever,
    "qa_kwargs": {"prompt": improved_retrieval_prompt}
}
retrieval_candidates.append(candidate)

In [306]:
from langchain.evaluation.qa import QAEvalChain

eval_template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. The student answer is correct if and only if it contains all of the infrormation in the true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! 

QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:"""
eval_prompt = PromptTemplate(
    input_variables=["query", "result", "answer"], template=eval_template
)
def default_evaluate(example, prediction) -> bool:
    """Return True if the prediction is correct."""
    eval_chain = QAEvalChain.from_llm(OpenAI(temperature=0), prompt=eval_prompt)
    grades = eval_chain.evaluate([example], [prediction])
    return grades[0]['text'].strip().upper() == "CORRECT"

# Run test cases against each retrieval system

In [323]:
%%capture --no-stdout
# Suppress chroma warnings about transient DB.

from langchain.schema import Document

test_results = []
for retrieval_candidate in retrieval_candidates:
    res = {"params": retrieval_candidate["params"], "test_cases": {}}
    for tc in test_cases:
        texts = tc["texts"]
        metadatas = tc.get("metadatas", [{}] * len(texts))
        docs = [Document(page_content=text, metadata=metadata) for text, metadata in zip(texts, metadatas)]
        retriever = retrieval_candidate["get_retriever"](docs)
        qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever, **retrieval_candidate.get('qa_kwargs', {}))
        example = {"query": tc["query"], "answer": tc["expected"]}
        prediction = qa.apply([example])[0]
        evaluate = tc.get("evaluate", default_evaluate)
        res["test_cases"][tc["name"]] = {
            "pass": evaluate(example, prediction),
            "prediction": prediction,
        }
    test_results.append(res)

In [324]:
# Visualize results
import pandas as pd

test_result_df = pd.DataFrame([{"System": f"{res['params']}", **{k: v['pass'] for k, v in res["test_cases"].items()}} for res in test_results])

def highlight(s):
    return ['background-color: honeydew' if s_ else 'background-color: mistyrose' for s_ in s]

test_result_df.style.apply(highlight, subset=test_result_df.columns.drop("System"))

Unnamed: 0,System,Many documents 1,Many documents 2,Redundant documents,Split statement,Metadata question (temporal information),Conflicting statements,One fact in long document,Revised statement
0,"{'search': 'similarity', 'k': 4, 'chunk_size': 1000}",True,True,False,True,True,False,False,False
1,"{'search': 'mmr', 'k': 6, 'chunk_size': 200, 'metadata': 'prepend to content', 'prompt': 'improved'}",True,True,True,True,True,True,False,False
