<a href="https://colab.research.google.com/github/jaycrossler/ai-training/blob/main/Using%20AI%20Agents%20to%20build%20a%20Knowledge%20Graph%20and%20verify%20results.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Agentic RAG
Agentic Retrieval-Augmented Generation (RAG) uses agents to verify the source of knowledge, hopefully increasing accuracy.

Note, you will need an **OPENAI_API_KEY** loaded as a key (on the left if in Colab) and shared as secrets.

This will also **build a hosted web app** (using Gradio) that others can use.

## Blog

For a detailed explanation of agentic rag, check out  [blog post on Medium](https://aksdesai1998.medium.com/662bac582da9) and original code from https://github.com/lancedb/vectordb-recipes/tree/main/tutorials/Agentic_RAG.


In [1]:
# install the required dependencies
%%capture --no-stderr
%pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub lancedb  langchain langgraph langchain-text-splitters langchain_openai gradio

In [2]:
import os
import getpass
import gradio as gr
from typing import Annotated, Literal, Sequence, TypedDict
from langchain import hub
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import LanceDB
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from langgraph.graph.message import add_messages
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolExecutor, ToolNode, tools_condition


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


In [3]:
#Loads the URLs and Splits text into chunks - takes about 1 minute

# Function to set environment variables securely
def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


from google.colab import userdata
os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')

# (Optional) For tracing
os.environ["LANGCHAIN_TRACING_V2"] = "False"
#_set_env("LANGCHAIN_API_KEY")


# upload urls to sources based on your use case
urls = [
    "https://en.wikipedia.org/wiki/United_States_Air_Force", # Do we need this? The model most likely knows this
    "https://www.doctrine.af.mil/Portals/61/documents/AFDP_1/AFDP-1.pdf",
    "https://www.doctrine.af.mil/Portals/61/documents/AFDP_3-12/3-12-AFDP-CYBERSPACE-OPS.pdf",
]


docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)

In [4]:
# Add the split content to lancedb as knowledge vectors - takes about 3.5 minutes

vectorstore = LanceDB.from_documents(
    documents=doc_splits,
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()


In [26]:
# Check that the knowledge store has content

table_vectors = vectorstore.get_table()
print(table_vectors.count_rows())
print(table_vectors.to_pandas())

36087
                                              vector  \
0  [-0.0070624077, -0.013837619, -0.00607354, -0....   
1  [0.011888713, 0.005399457, -0.005541461, -0.02...   
2  [0.015158966, 0.00059414015, 0.0043268176, -0....   
3  [0.0219406, -0.0118596, -0.0015316671, -0.0137...   
4  [0.002864139, -0.012553102, -0.015122923, -0.0...   
5  [-0.010749912, -0.0048866277, -0.0021539961, -...   
6  [-0.0022778655, -0.013432504, 0.014647366, -0....   
7  [-0.0070134318, -0.008831483, 0.008698778, -0....   
8  [-0.010623703, -0.005616577, -0.014177768, -0....   
9  [-0.015056208, -0.013588921, -0.021031104, -0....   

                                     id  \
0  e2e9bf92-9ab9-47d6-adf5-514b2e1d615e   
1  f00910d1-3f1e-4474-9b15-d4450c7fcd3f   
2  f78c5df5-f374-4c6f-a66c-e82997e97397   
3  ffa19483-13bc-4917-bb83-cd77585da608   
4  0dd067d4-28d8-4b3d-a624-df9330e33986   
5  8e8bb5cc-78b1-4631-9b68-6e41bbed83d8   
6  fa7d77df-bb17-4bd2-a52c-ab28387cadad   
7  8cc31d1c-05dd-4845-8877-debf67

In [27]:
# create the tools
retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_blog_posts",
    "Search and return information about Air Force Cybersecurity doctrine",
)

tools = [retriever_tool]
tool_executor = ToolExecutor(tools)


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]


def grade_documents(state) -> Literal["generate", "rewrite"]:
    class grade(BaseModel):
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    llm_with_tool = model.with_structured_output(grade)
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )
    chain = prompt | llm_with_tool

    messages = state["messages"]
    last_message = messages[-1]
    question = messages[0].content
    docs = last_message.content

    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score

    return "generate" if score == "yes" else "rewrite"


def agent(state):
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    return {"messages": [response]}


def rewrite(state):
    messages = state["messages"]
    question = messages[0].content
    msg = [
        HumanMessage(
            content=f""" \n
            Look at the input and try to reason about the underlying semantic intent / meaning. \n
            Here is the initial question:
            \n ------- \n
            {question}
            \n ------- \n
            Formulate an improved question: """,
        )
    ]
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    response = model.invoke(msg)
    return {"messages": [response]}


def generate(state):
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]
    docs = last_message.content

    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}


workflow = StateGraph(AgentState)
workflow.add_node("agent", agent)
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent", tools_condition, {"tools": "retrieve", END: END}
)
workflow.add_conditional_edges("retrieve", grade_documents)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")
graph = workflow.compile()


def process_message(user_message):
    inputs = {"messages": [("user", user_message)]}
    content_output = None
    for output in graph.stream(inputs):
        print(f"Debug output: {output}")  # Debugging line to print the output
        if "agent" in output and "messages" in output["agent"]:
            messages = output["agent"]["messages"]
            if messages and hasattr(messages[0], "content"):
                content_output = messages[0].content  # Accessing attribute directly
                print(f"Extracted content: {content_output}")  # Print extracted content
    return content_output if content_output else "No relevant output found."


# Define example questions to guide the user
example_questions = [
    "Summarize what the rules are for defending Air Force cyberspace and attacks on AI systems?",
    "What MAJCOM is at Wright-Patterson?",
    "Which Operation did Joint Task Force Ares conduct against ISIS?",
    "What does the 616 OC do?"
]

  tool_executor = ToolExecutor(tools)


In [None]:
# Create a Gradio interface
iface = gr.Interface(
    fn=process_message,
    inputs="text",
    outputs="text",
    title="Searching PDFs and rewriting results using agents",
    description="Enter a message to query related to Air Force cyberspace and attacks on AI systems.",
    examples=example_questions,
)

# Launch the Gradio app
iface.launch(debug=True)

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://200ba03cd1d6d337dc.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
