# Corrective RAG (CRAG)
自己修正的 RAG は RAG を強化し、質の低い検索や生成の修正を可能にする。

最近のいくつかの論文がこのテーマに焦点を当てているが、アイデアを実装するのは難しい。

ここでは、LangGraph を使って Corrective RAG (CRAG) [論文](https://arxiv.org/pdf/2401.15884.pdf)のアイデアを実装する方法を紹介する。

ベース Notebook:
https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_crag.ipynb

# CRAGの詳細
Corrective-RAG (CRAG)は、自己修正的 RAG のための興味深いアプローチを紹介した最近の論文である。

このフレームワークは、検索された各文書が質問にどの程度関連しているかを評価し、その結果に応じて処理を次のように分岐させる：

1. 正しい (Correct) 文書
   - 少なくとも一つの文書が関連性の閾値を超えた場合、生成に進む。
   - 生成の前に、知識の洗練 (knowledge refinement) を行う。
     - 文書を「知識ストリップ (knowledge strip)」と呼ばれる小さな単位に分割する。
     - 各ストリップを評価し、無関係なものを除外する。

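2. 曖昧 (Ambiguous) または不正確 (Incorrect) な文書
   - すべての文書が関連性の閾値を下回る場合、あるいは採点者 (grader) が確信を持てない場合、フレームワークは追加のデータソースを探す。
   - 検索を補うためにウェブ検索を使う。
   - 論文の図を見ると、クエリの書き直しもここで使われているようだ。

なお、この Notebook の実装は論文を簡略化したものである。知識ストリップによる洗練は行わず、検索された文書のうち 1 つでも関連性なしと評価されればクエリを書き換えてウェブ検索 (Bing Search API) を行い、その結果を文書に追加してから生成する。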

### Bing Search API のテスト

In [None]:
import os

# Placeholder credentials for the Bing Search API. ``setdefault`` keeps any
# value already present in the environment, so real keys exported before the
# notebook was started are not overwritten by these placeholders.
os.environ.setdefault("BING_SUBSCRIPTION_KEY", "YOUR_SUBSCRIPTION_KEY")
os.environ.setdefault("BING_SEARCH_URL", "https://api.bing.microsoft.com/v7.0/search")

In [None]:
from langchain_community.utilities import BingSearchAPIWrapper
# Smoke test: ask Bing for the top 3 hits for a sample Japanese question.
# The call is the last expression of the cell so the notebook displays it.
search = BingSearchAPIWrapper()
search.results(
    "2023年の大河ドラマのタイトルと藤原道長を演じているのは何ですか？",
    3,
)

## LangGraph のインストール

In [None]:
!pip install langchain_community tiktoken langchain-openai langchainhub langchain langgraph

In [None]:
from typing import Any, Dict, TypedDict

from langchain_core.messages import BaseMessage


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Every node reads its inputs from ``keys`` and returns a new ``keys``
    dict. The entries used by the nodes in this notebook are ``question``,
    ``documents``, ``generation`` and ``run_web_search``.

    Attributes:
        keys: A dictionary mapping string keys to arbitrary values.
    """

    # ``typing.Any`` -- not the builtin function ``any``, which is not a type.
    keys: Dict[str, Any]

## Retrieve from Azure AI Search
以下のコードを参考に Azure AI Search にデータを挿入します。

https://github.com/nohanaga/busho-index

In [None]:
import os
import json
import operator
from typing import Annotated, Sequence, TypedDict

from langchain import hub
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_openai import AzureChatOpenAI

from langchain.retrievers import AzureCognitiveSearchRetriever

# Placeholder credentials for Azure OpenAI. ``setdefault`` keeps any value
# already present in the environment, so real keys exported before the
# notebook was started are not overwritten by these placeholders.
os.environ.setdefault("AZURE_OPENAI_API_KEY", "YOUR_API_KEY")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://<YOUR_ENDPOINT>.openai.azure.com/")

### Nodes ###
def retrieve(state):
    """
    Retrieve documents from Azure AI Search for the current question.

    Args:
        state (dict): The current graph state. ``state["keys"]["question"]``
            is used as the search query.

    Returns:
        state (dict): New ``keys`` dict containing ``question`` (unchanged)
        and ``documents`` (the documents returned by the retriever).
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    print("retrieve: ", question)
    retriever = AzureCognitiveSearchRetriever(
        service_name="",  # Your Azure AI Search service name
        index_name="",  # Your Azure AI Search index name
        api_key="YOUR_API_KEY",  # Your Azure AI Search API key
        content_key="content",
        top_k=3,
    )

    # Retrievers are Runnables; ``invoke`` replaces the deprecated
    # ``get_relevant_documents``.
    documents = retriever.invoke(question)
    # Debug output only: a document lacking a ``sourcepage`` metadata field
    # (or having it set to None) must not abort the whole graph run.
    doc_results = "\n".join(
        "■■■" + (d.metadata.get("sourcepage") or "") + d.page_content
        for d in documents
    )
    print("documents: ", doc_results)
    return {"keys": {"documents": documents, "question": question}}


In [None]:
def generate(state):
    """
    Generate an answer from the current documents.

    Args:
        state (dict): The current graph state. Reads ``question`` and
            ``documents`` from ``state["keys"]``.

    Returns:
        state (dict): New ``keys`` dict with ``documents`` and ``question``
        passed through and ``generation`` holding the LLM answer string.
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM (non-Azure alternative:
    # ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True))
    llm = AzureChatOpenAI(
        openai_api_version="2023-05-15",
        azure_deployment="gpt-4-0125-preview",
    )

    # Post-processing: join the page contents so the prompt receives plain
    # text. Passing the list itself would insert the repr of the Document
    # objects (metadata included) into the prompt.
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }



In [None]:

def grade_documents(state):
    """
    Determine whether the retrieved documents are relevant to the question.

    Each document is graded separately by an LLM that is forced to call the
    ``grade`` tool, which carries a binary ``yes``/``no`` score.

    Args:
        state (dict): The current graph state. Reads ``question`` and
            ``documents`` from ``state["keys"]``.

    Returns:
        state (dict): New ``keys`` dict with ``documents`` replaced by the
        relevant subset and ``question`` unchanged. ``run_web_search`` is
        ``"Yes"`` when any document was judged irrelevant or when no relevant
        document remains (including when nothing was retrieved); otherwise it
        is ``"No"``.
    """

    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # LLM (non-Azure alternative:
    # ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True))
    model = AzureChatOpenAI(
        openai_api_version="2023-05-15",
        temperature=0,
        azure_deployment="gpt-4-0125-preview",
        streaming=True,
    )

    # Tool
    grade_tool_oai = convert_to_openai_tool(grade)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])

    # Prompt
    prompt = PromptTemplate(
        template="""あなたは、検索された文書とユーザーの質問との関連性を評価する採点者です。 \n 
        以下は検索された文書である。: \n\n {context} \n\n
        以下はユーザーからの質問です。: {question} \n
        文書にユーザーの質問に関連するキーワードまたは意味的(semantic)な意味が含まれている場合、関連性があると評価します。 \n
        その文書が質問に関連しているかどうかを示すために、'yes' か 'no' の二値スコアを与える。""",
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool | parser_tool

    # Score
    filtered_docs = []
    search = "No"  # Default: do not supplement retrieval with web search
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        print("Score: ", score)
        # ``score`` is the list of parsed tool calls; an empty list (no tool
        # call) counts as "not relevant" instead of raising IndexError. The
        # result gets its own name so it does not shadow the ``grade`` class.
        binary_score = score[0].binary_score if score else ""
        if binary_score.strip().lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search

    # Nothing relevant survived, which also covers an empty retrieval.
    # Generating from an empty context would be pointless, so fall back to
    # web search.
    if not filtered_docs:
        search = "Yes"

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }



In [None]:

def transform_query(state):
    """
    Rewrite the user's question into a form better suited to web search.

    Args:
        state (dict): The current graph state. Reads ``question`` and
            ``documents`` from ``state["keys"]``.

    Returns:
        state (dict): New ``keys`` dict whose ``question`` is the
        LLM-rewritten question; ``documents`` is passed through unchanged.
    """

    print("---TRANSFORM QUERY---")
    current = state["keys"]
    original_question = current["question"]
    docs = current["documents"]

    # Rewriter prompt: the model infers the underlying intent of the question
    # and answers in Japanese.
    rewrite_prompt = PromptTemplate(
        template="""あなたは検索に最適化された質問を生成しています。\n 
        入力を見て、根底にある意味的な意図/意味を推論します。\n 
        これが最初の質問です:
        \n ------- \n
        {question} 
        \n ------- \n
        日本語で出力してください。
        改善された質問: """,
        input_variables=["question"],
    )

    # Rewriter model, temperature 0. Non-Azure alternative:
    # ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    rewriter = AzureChatOpenAI(
        openai_api_version="2023-05-15",
        temperature=0,
        azure_deployment="gpt-4-0125-preview",
        streaming=True,
    )

    rewrite_chain = rewrite_prompt | rewriter | StrOutputParser()
    rewritten = rewrite_chain.invoke({"question": original_question})

    return {"keys": {"documents": docs, "question": rewritten}}



In [None]:
from langchain_community.utilities import BingSearchAPIWrapper

def web_search(state):
    """
    Web search based on the re-phrased question using the Bing Search API.

    The snippets of the top 5 results are joined into a single Document,
    which is appended after the current documents.

    Args:
        state (dict): The current graph state. Reads ``question`` and
            ``documents`` from ``state["keys"]``.

    Returns:
        state (dict): New ``keys`` dict with ``question`` unchanged. Its
        ``documents`` are the previous documents plus one Document holding
        the web snippets.
    """

    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    search = BingSearchAPIWrapper()
    docs = search.results(question, 5)
    print(question)
    # When Bing finds nothing, the wrapper returns a placeholder dict with no
    # "snippet" key. Skip such entries instead of failing with a KeyError.
    web_results = "\n".join(d["snippet"] for d in docs if "snippet" in d)
    web_results = Document(page_content=web_results)
    # Build a new list so the list held by the incoming state is not mutated.
    documents = [*documents, web_results]
    print("Web Results: ", web_results)
    return {"keys": {"documents": documents, "question": question}}



In [None]:

### Edges
def decide_to_generate(state):
    """
    Decide whether to generate an answer or rewrite the question for web search.

    Args:
        state (dict): The current state of the agent, including all keys.
            Only ``run_web_search`` (set by ``grade_documents``) is read
            from ``state["keys"]``; a missing flag is treated as ``"No"``.

    Returns:
        str: Next node to call. This is ``"transform_query"`` when
        ``run_web_search`` is ``"Yes"``, otherwise ``"generate"``.
    """

    print("---DECIDE TO GENERATE---")
    search = state["keys"].get("run_web_search", "No")

    if search == "Yes":
        # The grader flagged the retrieved context as insufficient, so
        # rewrite the query and supplement it with a web search.
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"

    # We have relevant documents, so generate the answer directly.
    print("---DECISION: GENERATE---")
    return "generate"

In [None]:
import pprint

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Register the nodes. Each node name matches the function that implements it.
for _name, _fn in (
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("generate", generate),
    ("transform_query", transform_query),
    ("web_search", web_search),
):
    workflow.add_node(_name, _fn)

# Wiring:
#   retrieve -> grade_documents -> generate -> END
#                               \-> transform_query -> web_search -> generate
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
for _src, _dst in (
    ("transform_query", "web_search"),
    ("web_search", "generate"),
    ("generate", END),
):
    workflow.add_edge(_src, _dst)

# Compile the graph into a runnable app.
app = workflow.compile()

In [None]:
# Run the graph on a sample question and print each node's name as it finishes.
inputs = {"keys": {"question": "源範頼ゆかりの地にあるカフェを提案してください"}}

for output in app.stream(inputs):
    for node_name, value in output.items():
        pprint.pprint(f"Node '{node_name}':")
        # Optional: dump the full state produced by this node.
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# ``value`` still holds the last streamed state; print its generated answer.
pprint.pprint(value["keys"]["generation"])