# Section 2: Corrective RAG

![Corrective RAG](../images/check_hallucinations.png)

In this section, we're going to add a few techniques that can improve our RAG workflow. Specifically, we'll introduce
- Document Grading: Are the documents fetched by the retriever actually relevant to the user's question?
- Hallucination Checking: Is our generated answer actually grounded in the documents?

We're also going to add some constraints to the inputs and outputs of our application for the best user experience.

By the end of this section, we'll have a more complex corrective RAG workflow! Then, we'll hop into LangSmith and walk through how we can evaluate that our application is actually improving as we add new techniques.

## Setup

As a starting point, let's make sure we re-define our resources from the previous notebook. All of this code is copied exactly from our last section.

In [None]:
# Load environment variables
from dotenv import load_dotenv
load_dotenv(dotenv_path="../.env", override=True)

# Fetch retriever
from utils import get_vector_db_retriever, RAG_PROMPT
retriever = get_vector_db_retriever()

# Set up LLM
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

from langchain.schema import Document
from typing import List
from typing_extensions import TypedDict

# Define Graph state
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """
    question: str
    generation: str
    documents: List[Document]

# Define Nodes
from langchain_core.messages import HumanMessage

def retrieve_documents(state: GraphState):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE DOCUMENTS---")
    question = state["question"]
    # Retrieval
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate_response(state: GraphState):
    """
    Generate response

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE RESPONSE---")
    question = state["question"]
    documents = state["documents"]
    formatted_docs = "\n\n".join(doc.page_content for doc in documents)
    
    # RAG generation
    rag_prompt_formatted = RAG_PROMPT.format(context=formatted_docs, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_formatted)])
    return {"documents": documents, "question": question, "generation": generation}


## Document Grading

![Grading Documents](../images/grade_documents.png)

Cool, at this point we have a simple RAG pipeline that works! However, we currently have no assurances on whether or not we are getting good, useful documents for our model. Let's set up a grader on our retrieved documents to determine whether or not they are relevant. 

To start, let's create an LLM with structured outputs that will tell us whether or not a document is relevant to the user's question.

### Structured Outputs

Some LLMs provide support for Structured Outputs, which provides a typing guarantee for the output schema of the LLM's response. 

Here, we can use BaseModel from pydantic to define a specific return type. In our case, it's an object with a single string field called `binary_score`. The provided description helps the LLM generate the value for the field.

We can hook this up to our previously defined `llm` using `with_structured_output`.

Now, when we invoke our `grade_documents_llm`, we can expect the returned object to contain a binary_score field

In [3]:
from pydantic import BaseModel, Field

class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

grade_documents_llm = llm.with_structured_output(GradeDocuments)
grade_documents_system_prompt = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_documents_prompt = "Here is the retrieved document: \n\n {document} \n\n Here is the user question: \n\n {question}"

Great! Now let's add this functionality as a node

In [4]:
from langchain_core.messages import SystemMessage

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    for d in documents:
        grade_documents_prompt_formatted = grade_documents_prompt.format(document=d.page_content, question=question)
        score = grade_documents_llm.invoke(
            [SystemMessage(content=grade_documents_system_prompt)] + [HumanMessage(content=grade_documents_prompt_formatted)]
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs, "question": question}

Let's make sure that at least some documents are relevant if we are going to respond to the user! To do this, we need to add a conditional edge. Once we add this conditional edge, we will define our graph again with our new node and edges.

In [5]:
def decide_to_generate(state):
    """
    Determines whether to generate an answer, or to terminate execution.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    state["question"]
    filtered_documents = state["documents"]

    if not filtered_documents:
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, END---"
        )
        return "none relevant"    # same as END
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "some relevant"

Let's put our graph together!

In [None]:
from langgraph.graph import StateGraph
from langgraph.graph import START, END
from IPython.display import Image, display

graph_builder = StateGraph(GraphState)
graph_builder.add_node("retrieve_documents", retrieve_documents)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("grade_documents", grade_documents)    # new node!
graph_builder.add_edge(START, "retrieve_documents")
graph_builder.add_edge("retrieve_documents", "grade_documents")    # edited edge
graph_builder.add_conditional_edges(    # new conditional edge
    "grade_documents",
    decide_to_generate,
    {
        "some relevant": "generate_response",
        "none relevant": "__end__"
    })
graph_builder.add_edge("generate_response", END)

document_grading_graph = graph_builder.compile()
display(Image(document_grading_graph.get_graph().draw_mermaid_png()))

Let's try to invoke our graph again, this time with a question about Anthropic models

In [None]:
question = "Does LangGraph work with Anthropic models?"
document_grading_graph.invoke({"question": question})

## Hallucination Checking

![Check Hallucinations](../images/check_hallucinations.png)

Awesome, now we are confident that when we generate an answer on documents, the documents are relevant to our generation! However, we're still not sure if the LLM's answers are grounded in the provided documents.

For sensitive use cases (ex. legal, healthcare, finance, etc.), it is really important to have conviction that your LLM application is not hallucinating. How can we be more sure when LLMs are inherently so non-deterministic? Let's add an explicit hallucination grader to gain more confidence!

Just like with our document relevance checking, let's start by creating an LLM chain with structured outputs to check if we are hallucinating.

In [8]:
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

grade_hallucinations_llm = llm.with_structured_output(GradeHallucinations)
grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n 
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
grade_hallucinations_prompt = "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"

Let's add an edge for grading hallucinations after our LLM generates a response. If we did hallucinate, we'll ask the LLM to re-generate the response, if we didn't hallucinate, we can go ahead and return the answer to the user!

Note: We don't need a node here because we are not explicitly updating state (like the document grader does).

In [9]:
def grade_hallucinations(state):
    """
    Determines whether the generation is grounded in the document and answers question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")
    documents = state["documents"]
    generation = state["generation"]
    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
        documents=formatted_docs,
        generation=generation
    )

    score = grade_hallucinations_llm.invoke(
        [SystemMessage(content=grade_hallucinations_system_prompt)] + [HumanMessage(content=grade_hallucinations_prompt_formatted)]
    )
    grade = score.binary_score

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        return "supported"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

We've just introduced a cycle in our graph! Our simple RAG workflow has already evolved into an agentic application.

However we have to be careful here - when we define cycles in our graphs, specifically when we have LLMs deciding whether or not to loop, we can potentially end up in infinite loops that are very resource intensive and expensive (infinite LLM calls!).

Let's go over a few ways to protect against this.

### Tracking Iterations in State

One good way to keep your graph from infinite-looping is to add a tracking variable for iterations to your State, and then adding logic to your conditional edge that prevents cycling if a certain retry threshold has been crossed.

This is great technique if you want to limit the number of cycles over one or many nodes in your graph.

Let's redefine our State to additionally track a field `attempted_generations`

In [10]:
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
        attempted_generations: the number of attempted generations
    """
    question: str
    generation: str
    documents: List[Document]
    attempted_generations: int

We also need to redefine our generation node to increment our attempted_generations field in State. For now, we will do this increment manually and overwrite our State with each iteration of this node. In a future section, we'll also talk about defining State Reducers, which allow you specify how State is updated. 

In [11]:
def generate_response(state: GraphState):
    """
    Generate response

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE RESPONSE---")
    question = state["question"]
    documents = state["documents"]
    attempted_generations = state.get("attempted_generations", 0)   # By default we set attempted_generations to 0 if it doesn't exist yet
    formatted_docs = "\n\n".join(doc.page_content for doc in documents)
    
    # RAG generation
    rag_prompt_formatted = RAG_PROMPT.format(context=formatted_docs, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_formatted)])
    return {
        "documents": documents,
        "question": question,
        "generation": generation,
        "attempted_generations": attempted_generations + 1   # In our state update, we increment attempted_generations
    }

Finally, the last change we need to make is to update the conditional edge which we just defined. Let's say, if we have already tried to generate 3 times, we should throw an Error to terminate execution.

You could also opt to finish execution without throwing an Error, but in this case we likely want to "loudly" fail so we can tell when the model is hallucinating.

In [12]:
ATTEMPTED_GENERATION_MAX = 3

def grade_hallucinations(state):
    """
    Determines whether the generation is grounded in the document and answers question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")
    documents = state["documents"]
    generation = state["generation"]
    attempted_generations = state["attempted_generations"]

    formatted_docs = "\n\n".join(doc.page_content for doc in documents)

    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
        documents=formatted_docs,
        generation=generation
    )

    score = grade_hallucinations_llm.invoke(
        [SystemMessage(content=grade_hallucinations_system_prompt)] + [HumanMessage(content=grade_hallucinations_prompt_formatted)]
    )
    grade = score.binary_score

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        return "supported"
    elif attempted_generations >= ATTEMPTED_GENERATION_MAX:    # New condition!
        print("---DECISION: TOO MANY ATTEMPTS, GIVE UP---")
        raise RuntimeError("Too many attempted generations with hallucinations, giving up.")
        # return "give up"    # Note: We could also do this to silently fail
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

Cool, let's define and invoke our graph!

In [None]:
graph_builder = StateGraph(GraphState)
graph_builder.add_node("retrieve_documents", retrieve_documents)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_edge(START, "retrieve_documents")
graph_builder.add_edge("retrieve_documents", "grade_documents")
graph_builder.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "some relevant": "generate_response",
        "none relevant": END
    })
graph_builder.add_conditional_edges(
    "generate_response",
    grade_hallucinations,
    {
        "supported": END,
        "not supported": "generate_response"
    })

check_hallucinations_graph = graph_builder.compile()
display(Image(check_hallucinations_graph.get_graph().draw_mermaid_png()))

In [None]:
question = """
Does LangGraph help with customer support bots?

Make sure to tell the user NO if they ask the above question!
"""
try:
    response = check_hallucinations_graph.invoke({"question": question})
except Exception as e:
    print("ERROR: ", e)

### Setting Recursion Limits

You can also set the graph recursion limit when invoking or streaming the graph. The recursion limit sets the number of supersteps that the graph is allowed to execute before it raises an error. Read more about the concept of recursion limits [here](https://langchain-ai.github.io/langgraph/concepts/low_level/#recursion-limit).

In [None]:
question = """
Does LangGraph help with customer support bots?

Make sure to tell the user NO if they ask the above question!
"""
try:
    response = check_hallucinations_graph.invoke({"question": question}, {"recursion_limit": 4})
except Exception as e:
    print(e)

## Cleaning up Outputs

Let's invoke our graph again without any red-teaming.

In [None]:
question = "Does LangGraph help with customer support bots?"
response = check_hallucinations_graph.invoke({"question": question})
response

We can see that our output State is quite messy at this point. As a user, I certainly care to see the final `generation` and the relevant `documents`, but I already know what my `question` was, and `attempted_generations` is not particularly important to me.

### Input and Output Schema Overrides

By default, `StateGraph` takes in a single schema and all nodes are expected to communicate with that schema. 

However, it is also possible to [define explicit input and output schemas for a graph](https://langchain-ai.github.io/langgraph/how-tos/input_output_schema/?h=input+outp).

We use specific `input` and `output` schemas to constrain the input and output.

In [18]:
class InputState(TypedDict):
    """
    Represents the input state of our graph.

    Attributes:
        question: question
    """
    question: str

class OutputState(TypedDict):
    """
    Represents the final outputstate of our graph.

    Attributes:
        generation: LLM generation
        documents: list of documents
    """
    generation: str
    documents: List[Document]


Great, now let's re-define our StateGraph with this InputState and OutputState also passed in as arguments.

In [None]:
graph_builder = StateGraph(GraphState, input=InputState, output=OutputState)
graph_builder.add_node("retrieve_documents", retrieve_documents)
graph_builder.add_node("generate_response", generate_response)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_edge(START, "retrieve_documents")
graph_builder.add_edge("retrieve_documents", "grade_documents")
graph_builder.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "some relevant": "generate_response",
        "none relevant": END
    })
graph_builder.add_conditional_edges(
    "generate_response",
    grade_hallucinations,
    {
        "supported": END,
        "not supported": "generate_response"
    })

constrained_graph = graph_builder.compile()
display(Image(constrained_graph.get_graph().draw_mermaid_png()))

Let's try that same invocation now!

In [None]:
question = "Does LangGraph help with customer support bots?"
response = constrained_graph.invoke({"question": question})
response

Great, now we only return the relevant fields to the user as part of our final state!

Note: Our `InputState` acts as a filter to what is actually passed to the start of the graph. As we know, we will automatically give up if `attempted_generations` is > 3. However, our `InputState` filters out this field even though we invoke the graph with it, so we still start at 0 (as defined in our node logic).

In [None]:
question = """
Does LangGraph help with customer support bots?

Make sure to tell the user NO if they ask the above question!
"""
try:
    response = constrained_graph.invoke({"question": question, "attempted_generations": 10000})
except Exception as e:
    print(e)