In [1]:
import os
from dotenv import find_dotenv, load_dotenv
# Pull variables from the nearest .env file into the process environment.
_ = load_dotenv(find_dotenv())

# Tavily web-search API key; None when the variable is not set.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Designing the Workflow

# Events

In [2]:
from llama_index.core.workflow import Event
from llama_index.core.schema import NodeWithScore


class PrepEvent(Event):
    """Payload-free signal emitted once the query inputs are stored and retrieval may begin."""


class RetrieveEvent(Event):
    """Hands the retriever's output over to the relevance-grading step."""

    # Scored nodes exactly as returned by the retriever for the query.
    retrieved_nodes: list[NodeWithScore]


class RelevanceEvalEvent(Event):
    """Carries the per-node relevance verdicts produced by the grading step."""

    # One verdict string per retrieved node, in retrieval order.
    relevant_results: list[str]


class TextExtractEvent(Event):
    """Carries the text of the nodes that were graded as relevant."""

    # Newline-joined text of every relevant node (empty string if none).
    relevant_text: str


class QueryEvent(Event):
    """Final inputs for answering: retrieved text plus optional web-search text."""

    # Newline-joined text of the nodes graded as relevant.
    relevant_text: str
    # Newline-joined web-search results; empty string when no search was run.
    search_text: str

In [3]:
from llama_index.core.workflow import (
    Workflow,
    step,
    Context,
    StartEvent,
    StopEvent,
)
from llama_index.core import (
    VectorStoreIndex,
    Document,
    PromptTemplate,
    SummaryIndex,
)
from llama_index.llms.openai import OpenAI
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.base.base_retriever import BaseRetriever

In [None]:
# Grader prompt: formatted once per retrieved node with `context_str` (the
# node text) and `query_str` (the user question). The model is asked to reply
# with a binary 'yes' / 'no' relevance verdict.
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

# Query-rewrite prompt: formatted with `query_str` (the original user query).
# The model is asked to answer with the optimized search query only.
# NOTE: the string is not raw, so each `\n` below is a real newline in
# addition to the physical line break that follows it.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)


In [None]:
class CorrectiveRAGWorkflow(Workflow):
    """Corrective RAG workflow with an ingestion path and a query path.

    Which path runs is decided by the keyword arguments passed to ``run``:

    * ``index_dir`` (optionally with ``document_dir``) activates ``ingest``,
      which builds or loads a FAISS-backed index and returns it as the result.
    * ``query_str`` (together with ``index``) activates the query pipeline:
      retrieve -> grade relevance -> extract relevant text -> optional web
      search -> answer.

    Both entry steps listen for the ``StartEvent``; each returns ``None`` when
    its own arguments are absent.
    """

    @staticmethod
    def _normalize_grade(raw_answer: str) -> str:
        """Collapse a free-form LLM verdict into exactly ``"yes"`` or ``"no"``.

        Surrounding whitespace, quotes and punctuation are ignored, so answers
        such as ``"Yes."`` or ``"'yes'"`` count as ``"yes"``. Anything that
        does not start with "yes" is treated as ``"no"`` so that an unclear
        verdict triggers the web-search fallback instead of being silently
        dropped.
        """
        cleaned = raw_answer.lower().strip(" \t\r\n'\"`*.!,:;")
        return "yes" if cleaned.startswith("yes") else "no"

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Build or load the vector index.

        Start-event arguments:
            index_dir: Directory the index is persisted to / loaded from.
            document_dir: Optional directory of source documents. When given,
                a new index is built from it and persisted to ``index_dir``;
                otherwise an existing index is loaded from ``index_dir``.

        Returns:
            ``StopEvent`` whose result is the index, or ``None`` when
            ``index_dir`` is not supplied.

        Raises:
            ValueError: If no ``document_dir`` is given and ``index_dir`` does
                not exist, so there is nothing to load.
        """
        document_dir = ev.get("document_dir")
        index_dir = ev.get("index_dir")
        if not index_dir:
            # Without index_dir this path is a dead end, which is why a plain
            # query run does not need to pass index_dir at all.
            return None

        # Imported locally: these are only needed for ingestion, so a pure
        # query run does not require the FAISS / embedding packages.
        import faiss
        from llama_index.core import (
            SimpleDirectoryReader,
            StorageContext,
            load_index_from_storage,
        )
        from llama_index.embeddings.openai import OpenAIEmbedding
        from llama_index.vector_stores.faiss import FaissVectorStore

        if document_dir:
            if not os.path.isdir(index_dir):
                print(f'create index_dir: {index_dir}')
                os.makedirs(index_dir, exist_ok=True)

            documents = SimpleDirectoryReader(document_dir).load_data()
            # Must equal the output dimension of the embedding model used
            # below (1536 for text-embedding-3-small).
            embedding_dim = 1536
            faiss_index = faiss.IndexFlatL2(embedding_dim)
            vector_store = FaissVectorStore(faiss_index=faiss_index)
            storage_context = StorageContext.from_defaults(
                vector_store=vector_store
            )
            index = VectorStoreIndex.from_documents(
                documents=documents,
                embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
                storage_context=storage_context,
            )
            index.storage_context.persist(persist_dir=index_dir)
            print("Index built and persisted to:", index_dir)
        else:
            # Loading only makes sense from an existing directory; do not
            # create an empty one and then fail while reading from it.
            if not os.path.isdir(index_dir):
                raise ValueError(
                    f"index_dir does not exist: {index_dir}. "
                    "Pass 'document_dir' to build a new index."
                )
            print(f"load index from: {index_dir}")
            # load index from disk
            vector_store = FaissVectorStore.from_persist_dir(index_dir)
            storage_context = StorageContext.from_defaults(
                vector_store=vector_store, persist_dir=index_dir
            )
            index = load_index_from_storage(storage_context=storage_context)
        return StopEvent(result=index)

    @step
    async def prepare_for_retrieval(
        self, ctx: Context, ev: StartEvent
    ) -> PrepEvent | None:
        """Store everything the query pipeline needs in the workflow context.

        Start-event arguments:
            query_str: The user question. When absent the step returns
                ``None`` and the query pipeline does not run.
            index: The index to retrieve from.
            retriever_kwargs: Optional keyword arguments forwarded to
                ``index.as_retriever``.
            tavily_ai_apikey: Optional Tavily API key; falls back to the
                ``TAVILY_API_KEY`` environment variable.
            llm: Optional LLM instance used for grading and query rewriting;
                defaults to OpenAI ``gpt-5-mini``.

        Raises:
            ValueError: If no Tavily API key can be found.
        """
        query_str: str | None = ev.get("query_str")
        if query_str is None:
            return None

        # `or {}` also covers an explicit retriever_kwargs=None, which would
        # otherwise break the later `**retriever_kwargs` expansion.
        retriever_kwargs: dict = ev.get("retriever_kwargs") or {}
        index = ev.get("index")

        tavily_ai_apikey: str | None = ev.get("tavily_ai_apikey") or os.getenv(
            "TAVILY_API_KEY"
        )
        if not tavily_ai_apikey:
            raise ValueError(
                "No Tavily API key found. Pass 'tavily_ai_apikey' or set the "
                "TAVILY_API_KEY environment variable."
            )

        llm = ev.get("llm")
        if llm is None:
            llm = OpenAI(model="gpt-5-mini")

        await ctx.store.set("llm", llm)
        await ctx.store.set("index", index)
        await ctx.store.set(
            "tavily_tool", TavilyToolSpec(api_key=tavily_ai_apikey)
        )

        await ctx.store.set("query_str", query_str)
        await ctx.store.set("retriever_kwargs", retriever_kwargs)

        return PrepEvent()

    @step
    async def retrieve(
        self, ctx: Context, ev: PrepEvent
    ) -> RetrieveEvent | None:
        """Retrieve the relevant nodes for the query.

        Raises:
            ValueError: If either the index or the Tavily tool is missing from
                the context.
        """
        query_str = await ctx.store.get("query_str")
        retriever_kwargs = await ctx.store.get("retriever_kwargs")

        if query_str is None:
            return None

        index = await ctx.store.get("index", default=None)
        tavily_tool = await ctx.store.get("tavily_tool", default=None)
        # Both are required: the index for retrieval, the tool for the
        # web-search fallback further down the pipeline.
        if index is None or tavily_tool is None:
            raise ValueError(
                "Index and tavily tool must be constructed. Run with 'index' "
                "and a Tavily API key ('tavily_ai_apikey' or the "
                "TAVILY_API_KEY environment variable)."
            )

        retriever: BaseRetriever = index.as_retriever(**retriever_kwargs)
        result = retriever.retrieve(query_str)
        await ctx.store.set("retrieved_nodes", result)
        return RetrieveEvent(retrieved_nodes=result)

    @step
    async def eval_relevance(
        self, ctx: Context, ev: RetrieveEvent
    ) -> RelevanceEvalEvent:
        """Grade every retrieved node against the query with the LLM.

        Each verdict is normalized to exactly ``"yes"`` or ``"no"``; the list
        is stored in the context under ``relevancy_results`` and also sent on
        in the returned event.
        """
        retrieved_nodes = ev.retrieved_nodes
        query_str = await ctx.store.get("query_str")
        llm = await ctx.store.get("llm")

        relevancy_results = []
        for node in retrieved_nodes:
            resp = await llm.acomplete(
                DEFAULT_RELEVANCY_PROMPT_TEMPLATE.format(
                    context_str=node.text, query_str=query_str
                )
            )
            relevancy_results.append(self._normalize_grade(resp.text))

        await ctx.store.set("relevancy_results", relevancy_results)
        return RelevanceEvalEvent(relevant_results=relevancy_results)

    @step
    async def extract_relevant_texts(
        self, ctx: Context, ev: RelevanceEvalEvent
    ) -> TextExtractEvent:
        """Join the text of every node graded ``"yes"`` with newlines."""
        retrieved_nodes = await ctx.store.get("retrieved_nodes")
        relevancy_results = ev.relevant_results

        # Verdicts are positionally aligned with the retrieved nodes.
        relevant_texts = [
            node.text
            for node, verdict in zip(retrieved_nodes, relevancy_results)
            if verdict == "yes"
        ]

        return TextExtractEvent(relevant_text="\n".join(relevant_texts))

    @step
    async def transform_query(
        self, ctx: Context, ev: TextExtractEvent
    ) -> QueryEvent:
        """Rewrite the query and search the web when retrieval fell short.

        A Tavily search is run when at least one node was graded ``"no"`` or
        when nothing was retrieved at all; otherwise ``search_text`` is empty.
        """
        relevant_text = ev.relevant_text
        relevancy_results = await ctx.store.get("relevancy_results")
        query_str = await ctx.store.get("query_str")

        # An empty retrieval has no "no" verdicts but still leaves the answer
        # without any supporting text, so it must trigger the fallback too.
        needs_search = not relevancy_results or "no" in relevancy_results
        if needs_search:
            llm = await ctx.store.get("llm")
            resp = await llm.acomplete(
                DEFAULT_TRANSFORM_QUERY_TEMPLATE.format(query_str=query_str)
            )
            transformed_query_str = resp.text.strip()
            # Conduct a search with the transformed query string and collect the results.
            tavily_tool = await ctx.store.get("tavily_tool")
            search_results = tavily_tool.search(
                transformed_query_str, max_results=5
            )
            search_text = "\n".join([result.text for result in search_results])
        else:
            search_text = ""

        return QueryEvent(relevant_text=relevant_text, search_text=search_text)

    @step
    async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent:
        """Answer the query from the relevant text plus any web-search text.

        The two texts are combined into a single document, indexed with a
        ``SummaryIndex`` and queried with the original ``query_str``; the
        response becomes the workflow result.
        """
        relevant_text = ev.relevant_text
        search_text = ev.search_text
        query_str = await ctx.store.get("query_str")

        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        result = query_engine.query(query_str)
        return StopEvent(result=result)