## Understanding the RAG workflow in LangGraph: Retrieval AI Agent

In [None]:
import asyncio
from typing import List, TypedDict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver


In [None]:
# Load the source web pages and index them in a Chroma vector store.
urls = [
    "...."
]

# Each loader call yields a list of documents, so flatten the per-URL lists
# into a single list before splitting.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split the pages into chunks (chunk_size=250, no overlap) using the
# tiktoken-based splitter.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Embed the chunks with OpenAI embeddings and store them in the
# "rag-chroma" collection.
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)

# Retriever used by the `retrieve` graph node.
retriever = vectorstore.as_retriever()

#Build the RAG chain: prompt template -> chat model -> plain-string output
# NOTE(review): the template text below (missing spaces after periods, trailing
# whitespace) is sent to the model verbatim, so it is intentionally left unchanged.
model = ChatOpenAI(temperature=0, model="gpt-4o-mini")
prompt = ChatPromptTemplate.from_template("""
You are an assistant for question-answering tasks.Use the following pieces of retrieved context to answer the question.
if you don't know the answer, just say that you don't know.Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context} 
Answer:                                                                                                                                                                       
                                          
""")
# The template expects the "question" and "context" inputs supplied by `generate`.
rag_chain = prompt | model | StrOutputParser()

#Define the graph
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user question; read by the retrieve and generate nodes
        generation: LLM generation; written by the generate node
        web_search: declared string field; not read or written by the nodes in this file
        documents: list of documents; written by the retrieve node and read by the generate node
    """
    question: str
    # Spelled to match the "generation" key returned by the generate node and
    # the attribute documented above (previously misspelled "generationn",
    # which left the returned key undeclared in the state schema).
    generation: str
    web_search: str
    documents: List[str]

#Retrieve node
def retrieve(state):
    """
    Retrieve documents for the question held in the graph state.

    Args:
        state(dict): The current graph state; its "question" entry is used as the query
    Returns:
        dict: "documents" with the retriever's results, plus the "question" passed through
    """
    print("---Retrieve---")

    query = state["question"]
    # Query the module-level retriever and hand the results back as a state update
    return {"documents": retriever.invoke(query), "question": query}


#Generate node
def generate(state):
    """
    Produce an answer from the question and the retrieved documents.

    Args:
        state(dict): The current graph state; "question" and "documents" are read from it
    Returns:
        dict: "documents" and "question" passed through, plus "generation" with the chain output
    """
    print("---Generate---")

    question, documents = state["question"], state["documents"]
    # Run the RAG chain with the retrieved documents as context
    answer = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": answer}

#Define the workflow
def create_workflow():
    """
    Assemble and compile the two-step graph: START -> retrieve -> generate -> END.

    Returns:
        The compiled graph, with a MemorySaver instance attached as checkpointer
    """
    builder = StateGraph(GraphState)

    # Register the nodes in pipeline order
    for node_name, node_fn in (("retrieve", retrieve), ("generate", generate)):
        builder.add_node(node_name, node_fn)

    # Wire the nodes into a straight line from START to END
    for source, target in (
        (START, "retrieve"),
        ("retrieve", "generate"),
        ("generate", END),
    ):
        builder.add_edge(source, target)

    checkpointer = MemorySaver()
    return builder.compile(checkpointer=checkpointer)

#Run the workflow
async def run_workflow():
    """
    Build the graph and stream one hard-coded question through it.

    Every event yielded by ``app.astream`` is printed. If an event contains an
    "error" key, that error is printed and streaming stops. Any exception
    raised while streaming is caught and reported on stdout instead of
    propagating.

    Returns:
        None
    """
    app = create_workflow()
    config = {
        # The graph is compiled with a checkpointer, so a thread_id is supplied
        "configurable": {"thread_id": "1"},
        "recursion_limit": 50,
    }
    # Plain string literal: the question has no placeholders to format
    inputs = {"question": "What are flat indexes?"}

    try:
        async for event in app.astream(inputs, config=config, stream_mode="values"):
            if "error" in event:
                print(f"Error: {event['error']}")
                break
            print(event)
    except Exception as e:
        # Top-level boundary of the script: report the failure rather than crash
        print(f"Workflow execution failed: {e}")


# Script entry point: run the async workflow to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(run_workflow())
