# Corrective RAG as a Service

<a href="https://colab.research.google.com/github/run-llama/llama-agents/blob/main/examples/corrective_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this guide we show you how to use llama-agents to build [CRAG (Corrective RAG)](https://arxiv.org/abs/2401.15884) by Yan et al. as a service.

![CRAG Diagram](https://github.com/run-llama/llama-agents/blob/main/examples/assets/corrective_rag.png?raw=1)
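Before wiring up services, it may help to see the control flow that the rest of this guide distributes across them. The following is a minimal, framework-free sketch. `retrieve`, `is_relevant`, `web_search`, and `summarize` are hypothetical stand-ins for the retriever, the LLM grader, Tavily, and the summarizer. It mirrors the simplified policy used in this notebook (fall back to web search if any chunk is graded irrelevant). It does not implement the full confidence-based retrieval evaluator described in the paper.

```python
from typing import Callable, List


def corrective_rag(
    query: str,
    retrieve: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],
    web_search: Callable[[str], str],
    summarize: Callable[[str, str], str],
) -> str:
    """Framework-free sketch of the simplified corrective-RAG control flow."""
    docs = retrieve(query)
    # Grade every retrieved chunk against the query.
    grades = [is_relevant(doc, query) for doc in docs]
    relevant_text = "\n".join(doc for doc, ok in zip(docs, grades) if ok)
    # Fall back to web search only if at least one chunk was graded irrelevant.
    search_text = "" if all(grades) else web_search(query)
    return summarize(relevant_text, search_text)


# Usage with stub callables standing in for the retriever, grader, search tool, and LLM.
answer = corrective_rag(
    "pretraining data",
    retrieve=lambda q: ["pretraining data mix", "unrelated table"],
    is_relevant=lambda doc, q: "pretraining" in doc,
    web_search=lambda q: "web results for " + q,
    summarize=lambda retrieved, searched: retrieved + " | " + searched,
)
print(answer)
```

In the services below, each of these callables becomes its own component, and the "web search or not" branch becomes a conditional link in the pipeline.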

In [None]:
!pip install llama-index llama-agents llama-parse tavily-python llama-index-tools-tavily-research

In [None]:
import nest_asyncio

nest_asyncio.apply()

In [None]:
import os

os.environ["OPENAI_API_KEY"] = "sk-proj-..."
os.environ["LLAMA_CLOUD_API_KEY"] = "llx-..."

## Setup Data, Indexes, and Tools

We do the following:
- Download the Gemini paper as an example document to build RAG over. **NOTE**: We use LlamaParse to parse the document, which requires a [LlamaCloud account](https://cloud.llamaindex.ai/). If you prefer a purely open-source reader, you can use that instead.
- Set up a vector index over this paper
- Set up a web search tool (powered by Tavily)

In [None]:
!mkdir -p 'data/'
!curl 'https://arxiv.org/pdf/2312.11805' -o "data/gemini.pdf"


In [None]:
# OPTION: use LlamaParse
from llama_parse import LlamaParse

parser = LlamaParse(result_type="text")
docs = parser.load_data("data/gemini.pdf")

# # OPTION: use SimpleDirectoryReader (uses open-source PyPDF)
# from llama_index.core import SimpleDirectoryReader

# reader = SimpleDirectoryReader(input_files=["data/gemini.pdf"])
# docs = reader.load_data()

Started parsing the file under job_id cac11eca-d6ce-4b2f-a83d-4973106715a5


In [None]:
import os
from llama_index.core import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)

if not os.path.exists("storage_gemini"):
    index = VectorStoreIndex.from_documents(docs)
    # save index to disk
    index.set_index_id("vector_index")
    index.storage_context.persist("./storage_gemini")
else:
    # rebuild storage context
    storage_context = StorageContext.from_defaults(persist_dir="storage_gemini")
    # load index
    index = load_index_from_storage(storage_context, index_id="vector_index")

retriever = index.as_retriever()
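The cell above embeds the paper only once and reloads the index from `storage_gemini` on later runs. The same build-once, load-later pattern can be sketched without LlamaIndex. In this sketch `load_or_build` and `fake_build` are hypothetical helpers, `fake_build` stands in for the expensive embedding step, and JSON is used purely for illustration (LlamaIndex persists its own storage format via `StorageContext`).

```python
import json
import os
import tempfile
from typing import Callable, Dict


def load_or_build(persist_dir: str, build_fn: Callable[[], Dict]) -> Dict:
    """Return a cached artifact from persist_dir, building and saving it on first use."""
    path = os.path.join(persist_dir, "index.json")
    if os.path.exists(path):
        # Cache hit: skip the expensive build step entirely.
        with open(path) as f:
            return json.load(f)
    index = build_fn()
    os.makedirs(persist_dir, exist_ok=True)
    with open(path, "w") as f:
        json.dump(index, f)
    return index


calls = []


def fake_build() -> Dict:
    calls.append(1)  # count how often the "expensive" build actually runs
    return {"chunks": ["chunk-0"]}


with tempfile.TemporaryDirectory() as tmp:
    first = load_or_build(os.path.join(tmp, "storage"), fake_build)
    second = load_or_build(os.path.join(tmp, "storage"), fake_build)
```

One design note: the notebook checks whether the directory exists, so a partially written `storage_gemini` folder would be treated as a valid cache. Checking for a specific file, as the sketch does, is slightly more defensive.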

In [None]:
# Setup Tavily web search
from llama_index.tools.tavily_research.base import TavilyToolSpec

tavily_api_key = "tvly-..."  # replace with your Tavily API key
tavily_tool = TavilyToolSpec(api_key=tavily_api_key)

In [None]:
# Setup LLM
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4o")

## Define Agent Services

Here we define three services:
- An initial RAG service that retrieves nodes and returns the relevant text, along with a flag indicating whether any retrieved node was judged irrelevant to the question.
- A separate web search service that is triggered if any retrieved node is irrelevant. It rewrites the query and then performs a web search.
- A final summarization service that takes in the collected text and returns a final answer.

In [None]:
from llama_agents import ComponentService, ServiceComponent, SimpleMessageQueue

message_queue = SimpleMessageQueue()


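def to_service_component(component, message_queue, description, service_name):
    """Wrap a component as a ComponentService and return (ServiceComponent, service)."""
    # argument order matches the calls below: description first, then service_name
    server = ComponentService(
        component=component,
        message_queue=message_queue,
        description=description,
        service_name=service_name,
    )
    # the ServiceComponent is what gets placed into the orchestrating QueryPipeline
    service_component = ServiceComponent.from_component_service(server)
    return service_component, server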
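Each `ComponentService` registers itself as a consumer on the shared message queue (the log output at the end of this notebook shows these registrations). The toy queue below is a conceptual sketch of that idea only: messages are published to a named consumer and delivered in FIFO order. `ToyMessageQueue` is hypothetical and is not how llama-agents implements `SimpleMessageQueue`.

```python
from collections import deque
from typing import Any, Callable, Deque, Dict, List, Tuple


class ToyMessageQueue:
    """In-process queue: messages are routed to registered consumers by name."""

    def __init__(self) -> None:
        self._consumers: Dict[str, Callable[[Any], Any]] = {}
        self._pending: Deque[Tuple[str, Any]] = deque()

    def register(self, name: str, handler: Callable[[Any], Any]) -> None:
        self._consumers[name] = handler

    def publish(self, name: str, payload: Any) -> None:
        # Refuse messages for consumers that were never registered.
        if name not in self._consumers:
            raise KeyError(f"no consumer registered as {name!r}")
        self._pending.append((name, payload))

    def drain(self) -> List[Any]:
        """Deliver queued messages in FIFO order and collect handler results."""
        results = []
        while self._pending:
            name, payload = self._pending.popleft()
            results.append(self._consumers[name](payload))
        return results


queue = ToyMessageQueue()
queue.register("retrieval_service", lambda q: f"retrieved for {q!r}")
queue.register("web_search_service", lambda q: f"searched for {q!r}")
queue.publish("retrieval_service", "pretraining data")
queue.publish("web_search_service", "pretraining data")
results = queue.drain()
```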

#### Setup Initial RAG Service

Runs retrieval, then uses the LLM to grade each retrieved node for relevance to the question.

In [None]:
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_pipeline import QueryPipeline

relevancy_prompt_tmpl = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)
relevancy_qp = QueryPipeline(chain=[relevancy_prompt_tmpl, llm])
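The grader prompt asks for a bare `yes` or `no`, but models sometimes add punctuation, casing, or a short justification, and an exact string comparison then silently misclassifies the node. A minimal sketch of a more tolerant parser follows. `parse_binary_grade` is a hypothetical helper, and its rule (read the first word of the reply) is an assumption rather than part of llama-index.

```python
import re
from typing import Optional


def parse_binary_grade(raw: str) -> Optional[bool]:
    """Map a grader reply to True ('yes'), False ('no'), or None if unparseable."""
    # Take the first alphabetic word so replies like "Yes." or "No - off topic" still parse.
    match = re.search(r"[a-z]+", raw.lower())
    if match is None:
        return None
    word = match.group(0)
    if word == "yes":
        return True
    if word == "no":
        return False
    return None


print(parse_binary_grade("Yes."), parse_binary_grade("No - off topic"), parse_binary_grade("maybe"))
```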

In [None]:
# define RAG agent
from llama_index.core.query_pipeline import FnComponent
from typing import Dict


def run_retrieval(input_str: str) -> Dict:
    """Run Retrieval."""
    # retrieves a set of nodes
    retrieved_nodes = retriever.retrieve(input_str)

    # runs a relevancy check of each node against the current input query
    relevancy_results = []
    for node in retrieved_nodes:
        relevancy = relevancy_qp.run(context_str=node.text, query_str=input_str)
        relevancy_results.append(relevancy.message.content.lower().strip())
    # tolerate replies such as "No." or "yes." instead of requiring an exact match
    contains_irrelevant = any(result.startswith("no") for result in relevancy_results)

    # get relevant texts
    relevant_texts = [
        retrieved_nodes[i].text
        for i, result in enumerate(relevancy_results)
        if result.startswith("yes")
    ]
    relevant_text = "\n".join(relevant_texts)

    # returns a dictionary of items
    return {
        "relevant_text": relevant_text,
        "contains_irrelevant": contains_irrelevant,
        "input_str": input_str,
    }


retrieval_component = FnComponent(fn=run_retrieval)
retrieval_component_s, retrieval_server = to_service_component(
    retrieval_component,
    message_queue,
    "Runs a retrieval + relevancy check",
    "retrieval_service",
)
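The retrieval service reduces the per-node grades to two outputs: the joined `relevant_text` and a single `contains_irrelevant` flag that later decides whether web search runs. The sketch below isolates that reduction as a pure function so it can be tested without an LLM. `summarize_grades` is hypothetical, and treating an unparseable grade (`None`) as "irrelevant" is a deliberate, conservative assumption that differs from the notebook, which only reacts to an explicit `no`.

```python
from typing import Dict, List, Optional, Union


def summarize_grades(
    texts: List[str], grades: List[Optional[bool]]
) -> Dict[str, Union[str, bool]]:
    """Keep texts graded relevant; flag the batch if any grade is not a clear 'yes'."""
    if len(texts) != len(grades):
        raise ValueError("texts and grades must be the same length")
    relevant = [text for text, grade in zip(texts, grades) if grade is True]
    return {
        "relevant_text": "\n".join(relevant),
        # Conservative choice: an unparseable grade (None) also triggers web search.
        "contains_irrelevant": any(grade is not True for grade in grades),
    }


out = summarize_grades(["chunk A", "chunk B", "chunk C"], [True, False, True])
print(out)
```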

#### Setup Web Search Service

Wrap Tavily into a component that performs query transformation and web search.

In [None]:
from llama_index.core.prompts import PromptTemplate

query_transform_tmpl = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)
query_transform_qp = QueryPipeline(chain=[query_transform_tmpl, llm])
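Query transformation is just "render a prompt, call the LLM, use the reply as the new query". The sketch below shows that step with plain `str.format` and a stub LLM, plus two small safeguards: stripping quotes the model may wrap around its answer, and falling back to the original query if the reply is empty. `REWRITE_TMPL`, `rewrite_query`, and `fake_llm` are hypothetical; the template is a shortened version of the one above.

```python
from typing import Callable

# Shortened, hypothetical version of the rewrite prompt above.
REWRITE_TMPL = (
    "Your task is to refine a query so it retrieves relevant search results.\n"
    "Original Query:\n-------\n{query_str}\n-------\n"
    "Respond with the optimized query only:"
)


def rewrite_query(query: str, complete: Callable[[str], str]) -> str:
    """Render the rewrite prompt, call the LLM, and clean up its reply."""
    prompt = REWRITE_TMPL.format(query_str=query)
    reply = complete(prompt).strip()
    # Models sometimes wrap the answer in quotes; strip them before searching.
    reply = reply.strip('"').strip("'").strip()
    # Fall back to the original query if the model returned nothing usable.
    return reply or query


fake_llm = lambda prompt: ' "Gemini model pretraining dataset composition" '
rewritten = rewrite_query("Tell me about the pretraining datasets used.", fake_llm)
print(rewritten)
```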

In [None]:
def run_web_search(input_str: str) -> str:
    """Run Web Search."""

    transformed_query_str = query_transform_qp.run(query_str=input_str).message.content
    # Conduct a search with the transformed query string and collect the results.
    search_results = tavily_tool.search(query_str, max_results=5)
    return "\n".join([result.text for result in search_results])


web_search_component = FnComponent(fn=run_web_search)
web_search_component_s, web_server = to_service_component(
    web_search_component, message_queue, "Runs web search", "web_search_service"
)
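`run_web_search` joins the `.text` of each search result into one string for the summarizer. If you want that step to be a little more defensive, a minimal sketch follows: it skips blank hits and exact duplicates and caps the number of results. `SearchResult` and `join_search_results` are hypothetical; `SearchResult` only models the `.text` attribute the notebook reads.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class SearchResult:
    """Hypothetical stand-in for one search hit; only .text is used, as in run_web_search."""
    text: str


def join_search_results(results: Iterable[SearchResult], max_results: int = 5) -> str:
    """Join up to max_results non-empty, de-duplicated result texts."""
    if max_results <= 0:
        return ""
    seen = set()
    kept: List[str] = []
    for result in results:
        text = result.text.strip()
        if not text or text in seen:
            continue  # skip blanks and exact duplicates
        seen.add(text)
        kept.append(text)
        if len(kept) == max_results:
            break
    return "\n".join(kept)


hits = [SearchResult("a"), SearchResult(""), SearchResult("a"), SearchResult("b")]
joined = join_search_results(hits, max_results=5)
```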

#### Setup Summarization Service

Finally, set up a summarization service that takes in the relevant retrieved text (plus any web search results) and synthesizes a final answer.

In [None]:
from llama_index.core.schema import Document
from llama_index.core import SummaryIndex
from typing import Optional


def run_summarization(retrieved_text: str, search_text: Optional[str] = None) -> str:
    """Run summarization."""
    # use a summary index to summarize the combined retrieved + search text
    search_text = search_text or ""
    documents = [Document(text=retrieved_text + "\n" + search_text)]
    index = SummaryIndex.from_documents(documents)
    query_engine = index.as_query_engine()
    # NOTE: the query is not passed through the pipeline; this reads the global
    # `query_str` defined in the final cell, so set it before launching.
    return str(query_engine.query(query_str))


summary_component = FnComponent(fn=run_summarization)
summary_component_s, summary_server = to_service_component(
    summary_component, message_queue, "Run summarization", "summarization_service"
)
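`run_summarization` concatenates the two inputs and reads the question from a global variable. A minimal sketch of an alternative that takes the question as an explicit argument and labels each non-empty source is shown below. `build_context` and `build_summary_prompt` are hypothetical helpers; using them in the pipeline would also require forwarding the query to this service, which the notebook does not currently do.

```python
from typing import Optional


def build_context(retrieved_text: str, search_text: Optional[str] = None) -> str:
    """Combine retrieval and web-search text, labeling each non-empty source."""
    sections = []
    if retrieved_text.strip():
        sections.append("From the paper:\n" + retrieved_text.strip())
    if search_text and search_text.strip():
        sections.append("From web search:\n" + search_text.strip())
    return "\n\n".join(sections)


def build_summary_prompt(
    query: str, retrieved_text: str, search_text: Optional[str] = None
) -> str:
    """Pass the query explicitly instead of reading a global variable."""
    context = build_context(retrieved_text, search_text)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"


prompt = build_summary_prompt("What data was used?", "chunk A", None)
print(prompt)
```

Labeling sources also lets the final answer distinguish what came from the paper and what came from the web.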

## Launch Agent Services

Now that we've set up the main components, we can wire them together in a `QueryPipeline` and run it through a `PipelineOrchestrator` attached to the control plane.

In [None]:
from llama_agents import (
    ControlPlaneServer,
    PipelineOrchestrator,
)
from llama_index.core.query_pipeline import Link, InputComponent

pipeline = QueryPipeline(
    module_dict={
        "input": InputComponent(),
        "retrieval_server": retrieval_component_s,
        "web_server": web_search_component_s,
        # TODO: clean this interface up
        "summary_server_no_web": summary_component_s,
        "summary_server_web": summary_component_s,
    }
)
pipeline.add_link("input", "retrieval_server")
pipeline.add_link(
    "retrieval_server",
    "web_server",
    condition_fn=lambda x: x["contains_irrelevant"],
    input_fn=lambda x: x["input_str"],
)
# if web search is called
pipeline.add_link(
    "retrieval_server",
    "summary_server_web",
    dest_key="retrieved_text",
    condition_fn=lambda x: x["contains_irrelevant"],
    input_fn=lambda x: x["relevant_text"],
)
pipeline.add_link("web_server", "summary_server_web", dest_key="search_text")

# if web search is not called
pipeline.add_link(
    "retrieval_server",
    "summary_server_no_web",
    dest_key="retrieved_text",
    condition_fn=lambda x: not x["contains_irrelevant"],
    input_fn=lambda x: x["relevant_text"],
)

pipeline_orchestrator = PipelineOrchestrator(pipeline)

control_plane = ControlPlaneServer(
    message_queue=message_queue,
    orchestrator=pipeline_orchestrator,
)
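The three conditional links decide which downstream services receive the retrieval output. The toy simulation below models only the `condition_fn` / `input_fn` / `dest_key` idea for links leaving `retrieval_server`, so you can check the branching logic without launching anything. `LinkSpec` and `route` are hypothetical and do not reflect how `QueryPipeline` or `PipelineOrchestrator` execute internally.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# (source, dest, dest_key, condition_fn, input_fn), mirroring the add_link calls above.
LinkSpec = Tuple[
    str, str, Optional[str], Callable[[Dict[str, Any]], bool], Callable[[Dict[str, Any]], Any]
]

LINKS: List[LinkSpec] = [
    ("retrieval_server", "web_server", None,
     lambda x: x["contains_irrelevant"], lambda x: x["input_str"]),
    ("retrieval_server", "summary_server_web", "retrieved_text",
     lambda x: x["contains_irrelevant"], lambda x: x["relevant_text"]),
    ("retrieval_server", "summary_server_no_web", "retrieved_text",
     lambda x: not x["contains_irrelevant"], lambda x: x["relevant_text"]),
]


def route(source: str, output: Dict[str, Any], links: List[LinkSpec]) -> Dict[str, Any]:
    """Return {dest: input} for every link out of `source` whose condition holds."""
    fired: Dict[str, Any] = {}
    for src, dest, dest_key, condition_fn, input_fn in links:
        if src == source and condition_fn(output):
            value = input_fn(output)
            # With a dest_key the value is bound to that argument; otherwise it is passed as-is.
            fired[dest] = {dest_key: value} if dest_key else value
    return fired


needs_web = {"relevant_text": "chunk A", "contains_irrelevant": True, "input_str": "q"}
all_good = {"relevant_text": "chunk A", "contains_irrelevant": False, "input_str": "q"}
print(route("retrieval_server", needs_web, LINKS))
print(route("retrieval_server", all_good, LINKS))
```

Because `summary_server_web` and `summary_server_no_web` have mutually exclusive conditions, exactly one summarization branch fires for any retrieval result.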

In [None]:
from llama_agents.launchers import LocalLauncher

## Define Launcher
launcher = LocalLauncher(
    [retrieval_server, web_server, summary_server],
    control_plane,
    message_queue,
)

In [None]:
query_str = "Tell me about the pretraining datasets used."
result = launcher.launch_single(query_str)
print(str(result))

INFO:llama_agents.message_queues.simple - Consumer ComponentService-2671543e-633b-492d-a871-0712fbe4ee91: Runs a retrieval + relevancy check has been registered.
INFO:llama_agents.message_queues.simple - Consumer ComponentService-e68532cc-d06c-4d77-bd2a-d18cf13df618: Runs web search has been registered.
INFO:llama_agents.message_queues.simple - Consumer ComponentService-c68e28ba-73d6-46c1-801b-e60f8ce9fe8e: Run summarization has been registered.
INFO:llama_agents.message_queues.simple - Consumer c541b9d2-92b6-4fe5-81d4-2d99a3cb81fa: human has been registered.
INFO:llama_agents.message_queues.simple - Consumer ControlPlaneServer-b205f278-fa5b-44cc-8150-e9f07b168250: control_plane has been registered.
INFO:llama_agents.services.component - Runs a retrieval + relevancy check launch_local
INFO:llama_agents.services.component - Runs web search launch_local
INFO:llama_agents.services.component - Run summarization launch_local

The pre-training datasets used for Gemini models are multimodal and multilingual, incorporating data from web documents, books, and code. These datasets include image, audio, and video data. The SentencePiece tokenizer is utilized, and training the tokenizer on a large sample of the entire training corpus enhances the inferred vocabulary and subsequently boosts model performance. The models are trained on a diverse range of data sources to improve performance and efficiency across various domains.
