# LangGraph - RAG Exercise! 🚀

<a target="_blank" href="https://colab.research.google.com/github/IT-HUSET/ai-workshop-250121/blob/main/lab/5-langgraph-rag-exercise.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a><br/>

Let's apply what we've learned and build a RAG application with LangGraph and LangChain!

Below you will find a partially completed notebook with some code snippets. Your task is to complete the code snippets and run the graph to retrieve relevant documents and generate an answer to a user question. More specifically...

### ...your task is to:
* Load a set of documents
* Split the documents into smaller chunks
* Ingest the chunks into a vector database
* Build a simple graph that retrieves relevant documents for a user question and generates an answer from them

### Additional tasks:
* Add more documents, refine the prompts, and experiment with different chunking and retrieval settings
* Add initial routing based on question type
* Add grading of retrieved documents for relevance (Corrective RAG)


## Setup

### Install dependencies

In [None]:
%pip install httpx~=0.28.1 openai~=1.57 --upgrade --quiet
%pip install python-dotenv~=1.0 docarray~=0.40.0 pypdf~=5.1 --upgrade --quiet
%pip install chromadb~=0.5.18 lark~=1.2 --upgrade --quiet
%pip install langchain~=0.3.10 langchain_openai~=0.2.11 langchain_community~=0.3.10 langchain-chroma~=0.1.4 --upgrade --quiet
%pip install langgraph~=0.2.56 --upgrade --quiet

# If running locally, you can do this instead:
#%pip install -r ../requirements.txt

### Load environment variables

In [None]:
import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# If running in Google Colab, you can use this code instead:
# from google.colab import userdata
# os.environ["AZURE_OPENAI_API_KEY"] = userdata.get("AZURE_OPENAI_API_KEY")
# os.environ["AZURE_OPENAI_ENDPOINT"] = userdata.get("AZURE_OPENAI_ENDPOINT")
# os.environ["ANTHROPIC_API_KEY"] = userdata.get("ANTHROPIC_API_KEY")
# os.environ["LANGCHAIN_API_KEY"] = userdata.get("LANGCHAIN_API_KEY")

### Optional - Setup LangSmith tracing for this notebook

In [None]:
#import os

# API keys etc. are in the .env file
# my_name = "Totoro"
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_PROJECT"] = f"tokyo24-langgraph-{my_name}"
# os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"

### Setup Chat Model

In [None]:
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
api_version = "2024-10-01-preview"
llm = AzureChatOpenAI(deployment_name="gpt-4o-mini", temperature=0.0, openai_api_version=api_version)
embedding_model = AzureOpenAIEmbeddings(model="text-embedding-3-large", openai_api_version=api_version)

## Setup ingestion / retrieval pipeline

### Setup vector DB (Chroma)

In [None]:
from langchain_chroma import Chroma
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

persist_directory = './db/rag_exercise/'

# Optionally remove the directory and all files in it recursively if it exists
# import shutil
# import os
# if os.path.exists(persist_directory):
#     shutil.rmtree(persist_directory)

vectordb: Chroma = Chroma(
    collection_name="rag_exercise",
    embedding_function=embedding_model,
    persist_directory=persist_directory # Optionally persist the database
)

### Setup a text splitter

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 1000,
    chunk_overlap = 80
)
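
`RecursiveCharacterTextSplitter` first tries to split on natural boundaries (by default paragraphs, then lines, then words, then single characters) and then merges the pieces into chunks of at most `chunk_size` characters, carrying roughly `chunk_overlap` characters of shared context between neighbouring chunks. The fixed-window chunker below is a simplified, hypothetical helper, not LangChain's recursive algorithm; it only illustrates what the two numbers mean.

```python
def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split text into fixed-size character windows; neighbours share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window moves forward
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):  # last window reached the end of the text
            break
    return chunks


chunks = sliding_window_chunks("abcdefghijklmnopqrstuvwxyz", chunk_size=10, chunk_overlap=3)
print(chunks)  # ['abcdefghij', 'hijklmnopq', 'opqrstuvwx', 'vwxyz']
```

A larger overlap makes it less likely that a sentence answering the question is cut in half at a chunk boundary, at the cost of storing (and embedding) more redundant text.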

### Setup documents to load

In [None]:
# Documents to load

# TODO: Add your own documents here (or begin with these)
documents_to_load = [
    "https://data.riksdagen.se/fil/CDA05163-DE71-448D-807D-747C997E8F3A", # The significance of AI for the future labour market and schools
    "https://data.riksdagen.se/fil/61B7540B-EEDD-4922-B61B-FC0A9F3AE4E2", # 2024/25:263 AI, other new technology and human rights
    "https://data.riksdagen.se/fil/0D43150B-5B31-43A4-89CD-4FE0478EC6C7" # 2024/25:263 AI, other new technology and human rights (answer)
]

### Ingest - split and add to vector index

In [None]:
from langchain_community.document_loaders import PyPDFLoader
import time

def ingest_document(doc: str):
    '''Helper function to ingest a document into the vector database'''

    # Load document
    print(f"Loading document {doc}...")
    loader = PyPDFLoader(doc)
    pages = loader.load()

    # Split
    doc_splits = text_splitter.split_documents(pages)

    # Add to index
    print(f"Adding document {doc} to index...")

    # Add to index in batches, with delay, to avoid rate limiting
    batch_size = 10
    for i in range(0, len(doc_splits), batch_size):
        batch = doc_splits[i:i + batch_size]
        vectordb.add_documents(documents=batch)
        print(f"Added splits {i} to {i + len(batch) - 1}")
        time.sleep(0.1)

    print(f"Added document {doc} ({len(pages)} pages) - {len(doc_splits)} splits")


for doc in documents_to_load:
    ingest_document(doc)
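
Because the database is persisted to disk, re-running the ingestion cell adds the same chunks again under new IDs, and the fixed `time.sleep(0.1)` does not actually recover from a rate-limit error. A minimal sketch of both fixes, under stated assumptions: `chunk_id` derives a deterministic ID from the chunk (via `hashlib`), and `add_in_batches` retries a failed batch with exponential backoff. Both helpers are hypothetical. Passing such IDs to the vector store assumes that `Chroma.add_documents` accepts `ids=...` and upserts them, so repeated IDs overwrite instead of duplicating; verify this against your installed `langchain-chroma` version.

```python
import hashlib
import time
from typing import Callable, Sequence


def chunk_id(source: str, page: int, index: int, content: str) -> str:
    """Deterministic ID: the same chunk (source, page, position, text) always gets the same ID."""
    key = f"{source}|{page}|{index}|{content}"
    return hashlib.sha256(key.encode("utf-8")).hexdigest()[:32]


def add_in_batches(items: Sequence, add_batch: Callable[[list], None],
                   batch_size: int = 10, max_retries: int = 3, base_delay: float = 1.0) -> int:
    """Pass consecutive slices of `items` to `add_batch`, retrying a failed slice with exponential backoff."""
    added = 0
    for start in range(0, len(items), batch_size):
        batch = list(items[start:start + batch_size])
        for attempt in range(max_retries + 1):
            try:
                add_batch(batch)
                break
            except Exception:  # in the notebook, prefer catching openai.RateLimitError
                if attempt == max_retries:
                    raise
                time.sleep(base_delay * 2 ** attempt)  # waits base_delay, 2x, 4x, ...
        added += len(batch)
        print(f"Added items {start} to {start + len(batch) - 1}")
    return added


# Demo with a stand-in store that is "rate limited" on its very first call
attempts = {"n": 0}
stored = []


def flaky_add(batch: list) -> None:
    attempts["n"] += 1
    if attempts["n"] == 1:
        raise RuntimeError("simulated 429 Too Many Requests")
    stored.extend(batch)


fake_splits = [("doc.pdf", page, f"text on page {page}") for page in range(25)]
ids = [chunk_id(src, page, i, text) for i, (src, page, text) in enumerate(fake_splits)]
total = add_in_batches(list(zip(ids, fake_splits)), flaky_add, batch_size=10, base_delay=0.01)

# In the notebook (assumption: Chroma's add_documents accepts ids=... and upserts them):
# ids = [chunk_id(d.metadata.get("source", ""), d.metadata.get("page", 0), i, d.page_content)
#        for i, d in enumerate(doc_splits)]
# add_in_batches(list(zip(ids, doc_splits)),
#                lambda b: vectordb.add_documents(documents=[d for _, d in b], ids=[i for i, _ in b]))
```

The chunk's position is part of the key so that two identical snippets on the same page still get distinct IDs within one run.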

## Setup query graph / pipeline

### Graph state

In [None]:
from typing import List

from langchain_core.documents import Document
from langgraph.graph import MessagesState


class GraphState(MessagesState):
    question: str
    documents: List[Document]
    answer: str


### Nodes

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.output_parsers import StrOutputParser

#### Retrieval (Vector Store similarity search)

In [None]:
class RetrievalNode:
    retriever: VectorStoreRetriever

    def __init__(self):
        self.retriever = vectordb.as_retriever(search_kwargs={"k": 3})

    def __call__(self, state: GraphState):
        print("---RETRIEVE---")
        question = state["question"]

        # Retrieval
        documents = self.retriever.invoke(question)

        print(f"---RETRIEVED {len(documents)} DOCS---")
        #print(f"{documents}")

        return {"documents": documents}
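
Conceptually, `self.retriever.invoke(question)` embeds the question and returns the `k` stored chunks whose embeddings are closest to it (`search_kwargs={"k": 3}` above). The brute-force cosine-similarity sketch below, with made-up 3-dimensional vectors, illustrates that ranking. It is not how Chroma works internally: Chroma uses an approximate (HNSW) index and, by default, L2 distance. For unit-length embeddings, such as those returned by OpenAI's embedding models, L2 and cosine produce the same ranking.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec: list[float], doc_vecs: dict[str, list[float]], k: int = 3) -> list[tuple[str, float]]:
    """Score every document against the query and keep the k best (exact, brute-force search)."""
    scored = [(doc_id, cosine_similarity(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy embeddings; real ones have thousands of dimensions
docs = {
    "ai_risks": [0.9, 0.1, 0.0],
    "ai_school": [0.7, 0.6, 0.1],
    "budget": [0.0, 0.2, 0.9],
    "human_rights": [0.6, 0.0, 0.4],
}
query = [1.0, 0.0, 0.0]
ranked = top_k(query, docs, k=3)
print([doc_id for doc_id, _ in ranked])  # ['ai_risks', 'human_rights', 'ai_school']
```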

#### RAG Generation (LLM call with factual/grounded context)

In [None]:
class RAGNode:
    system_template = """You are a helpful assistant, expert in answering questions based on provided sources (snippets from documents) and citing the sources used to generate the answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible.
    ALWAYS respond in the SAME language as the original question.

    ** Context (snippets from documents): **

    {context}
    """

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            ("human", "{question}"),
        ]
    )

    chain: Runnable

    def __init__(self):
        self.chain = self.prompt | llm | StrOutputParser()

    def __call__(self, state: GraphState):
        print("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

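        # RAG generation - set up context (relevant document snippets, each labelled with
        # its source and page so that the model can actually cite them)
        context = "\n\n".join(
            f"[Source: {doc.metadata.get('source', 'unknown')}, page {doc.metadata.get('page', '?')}]\n{doc.page_content}"
            for doc in documents
        )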

        # RAG generation - generate answer
        answer = self.chain.invoke({"question": question, "context": context})
        #print(f"---GENERATE - ANSWER: \n{answer}")

        return {"documents": documents, "answer": answer}

### Build Graph

In [None]:
#### Graph ####
from langgraph.graph import END, StateGraph, START
from IPython.display import Image, display

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", RetrievalNode())  # retrieve
workflow.add_node("generate", RAGNode())  # generate

workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile
graph = workflow.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))
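
Each node returns only the keys it changes, and LangGraph merges that partial update into the shared state. `MessagesState` already defines a `messages` key with the `add_messages` reducer; the keys added in `GraphState` (`question`, `documents`, `answer`) have no reducer, so an update simply overwrites them. The plain-Python sketch below mimics that behaviour for the linear START → retrieve → generate → END graph, with stub nodes standing in for `RetrievalNode` and `RAGNode`. It is a mental model only, not LangGraph's implementation (it ignores reducers, branching and checkpointing).

```python
from typing import Callable

State = dict
Node = Callable[[State], dict]


def run_linear(nodes: list[tuple[str, Node]], inputs: State) -> State:
    """Run nodes in order, merging each partial update into the state (last write wins)."""
    state = dict(inputs)
    for name, node in nodes:
        update = node(state)
        print(f"---{name.upper()}--- updated keys: {sorted(update)}")
        state.update(update)
    return state


def fake_retrieve(state: State) -> dict:
    # Stand-in for RetrievalNode: pretend the vector store returned two snippets
    return {"documents": [f"snippet about {state['question']}", "another snippet"]}


def fake_generate(state: State) -> dict:
    # Stand-in for RAGNode: builds an "answer" from the retrieved snippets
    return {"answer": f"Answer based on {len(state['documents'])} snippets"}


final_state = run_linear([("retrieve", fake_retrieve), ("generate", fake_generate)],
                         {"question": "AI risks"})
print(final_state["answer"])  # Answer based on 2 snippets
```

This is also why `result['answer']` is available after `graph.invoke(inputs)`: the returned value is the final merged state, not just the last node's output.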


## Use Graph

In [None]:
# Run
inputs = {
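    # "Has there been any discussion in the Riksdag about the risks of using artificial intelligence (AI)?"
    "question": "Har det i riksdagen diskuterats något om risker kring användningen av artificiell intelligens (AI)?"
    #"question": "Vilka är nobelpristagarna 2024?" # "Who are the 2024 Nobel laureates?" - should result in web search once routing is added (see below)
    #"question": "Vad innebär vitboken om artificiell intelligens?" # "What does the white paper on artificial intelligence entail?" - should NOT result in web search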
}

# Execute graph
result = graph.invoke(inputs)

print(f"--- ANSWER: ---\n{result['answer']}")

<br/>

-----

## Refinement - add initial routing based on question type

See the [LangGraph Adaptive RAG tutorial](https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_adaptive_rag) for more inspiration and guidance.

### Data model for routing (structured output)

In [None]:
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate

from pydantic import BaseModel, Field

# Data model
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )
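
`llm.with_structured_output(RouteQuery)` asks the model to reply in the shape of `RouteQuery`, and the `Literal` annotation restricts `datasource` to exactly two strings; Pydantic rejects anything else. The stdlib sketch below shows that constraint on its own: it parses a JSON reply and validates the value against the `Literal`. The helper `parse_route` is hypothetical and for illustration only; in the notebook, LangChain and Pydantic do this for you.

```python
import json
from typing import Literal, get_args

Datasource = Literal["vectorstore", "web_search"]
ALLOWED = get_args(Datasource)  # ('vectorstore', 'web_search')


def parse_route(raw_json: str) -> str:
    """Parse a model reply like '{"datasource": "web_search"}' and validate the value."""
    value = json.loads(raw_json).get("datasource")
    if value not in ALLOWED:
        raise ValueError(f"datasource must be one of {ALLOWED}, got {value!r}")
    return value


print(parse_route('{"datasource": "web_search"}'))  # web_search
try:
    parse_route('{"datasource": "google"}')
except ValueError as err:
    print(err)
```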

### Setup a chain for structured LLM output

For simplicity, we'll make the LLM call directly in the conditional edge below. We could have introduced a separate node for this, but it would also mean we'd have to add a state field for the data source. Feel free to refactor with this improvement if you'd like.

In [None]:
structured_llm_router = llm.with_structured_output(RouteQuery)

# Prompt
routing_system_prompt = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to the Swedish government system, the Swedish Riksdag and politics.
Use the vectorstore for questions on these topics. Otherwise, use web-search."""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", routing_system_prompt),
        ("human", "{question}"),
    ]
)

routing_chain = route_prompt | structured_llm_router

### Conditional edge for routing

In [None]:
def route_question(state: GraphState) -> str:
    print("---ROUTE QUESTION---")
    datasource: str = routing_chain.invoke({"question": state["question"]}).datasource

    if datasource == "web_search":
        print("---DECISION: WEB SEARCH---")
        return "web_search"
    else:
        print("---DECISION: RAG---")
        return "vectorstore"
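
The refactor suggested above could look roughly like this: a router node calls the classifier once and writes the decision into a state field (the field name `datasource` is a hypothetical addition to the state), and the conditional edge becomes a pure function that only reads it. The sketch runs in plain Python; `keyword_classifier` is a stand-in for `routing_chain.invoke({"question": ...}).datasource`, and the LangGraph wiring in the trailing comments is an untested assumption.

```python
from typing import Callable, Literal, TypedDict

Datasource = Literal["vectorstore", "web_search"]


class RoutedState(TypedDict, total=False):
    question: str
    datasource: Datasource  # hypothetical extra state field


def make_router_node(classify: Callable[[str], Datasource]) -> Callable[[RoutedState], dict]:
    """Build a node that classifies the question once and stores the decision in the state."""
    def router_node(state: RoutedState) -> dict:
        return {"datasource": classify(state["question"])}
    return router_node


def route_by_datasource(state: RoutedState) -> str:
    """Conditional edge: no LLM call, it only reads the stored decision."""
    return "web_search" if state["datasource"] == "web_search" else "vectorstore"


def keyword_classifier(question: str) -> Datasource:
    # Stand-in for: routing_chain.invoke({"question": question}).datasource
    return "vectorstore" if "riksdag" in question.lower() else "web_search"


router_node = make_router_node(keyword_classifier)
demo_state: RoutedState = {"question": "What has the Riksdag said about AI?"}
demo_state.update(router_node(demo_state))
print(route_by_datasource(demo_state))  # vectorstore

# Possible LangGraph wiring (assumption: GraphState gains a `datasource: str` field):
# workflow.add_node("route", router_node)
# workflow.add_edge(START, "route")
# workflow.add_conditional_edges("route", route_by_datasource,
#                                {"web_search": "web_search", "vectorstore": "retrieve"})
```

Storing the decision in the state also makes it visible in LangSmith traces and in the final result, which helps when debugging misrouted questions.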

### Create a fake web search node

In [None]:
class FakeWebSearchNode:
    system_template = """You are a helpful and cheerful assistant."""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            ("human", "{question}"),
        ]
    )

    chain: Runnable

    def __init__(self):
        self.chain = self.prompt | llm.bind(temperature=1.0) | StrOutputParser()

    def __call__(self, state: GraphState):
        print("---FAKE WEB SEARCH---")
        question = state["question"]

        web_results = self.chain.invoke({"question": question})

        print(f"---FAKE WEB SEARCH RESULT: \n{web_results}")

        web_results = [Document(page_content=web_results)]

        return {"documents": web_results, "question": question}

### Build Graph with routing

In [None]:
#### Graph ####
from langgraph.graph import END, StateGraph, START
from IPython.display import Image, display

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", RetrievalNode())  # retrieve
workflow.add_node("generate", RAGNode())  # generate
workflow.add_node("web_search", FakeWebSearchNode())  # web search

workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
graph = workflow.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))


## Use Graph with routing

In [None]:
# Run
inputs = {
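    # "Has there been any discussion in the Riksdag about the risks of using artificial intelligence (AI)?"
    "question": "Har det i riksdagen diskuterats något om risker kring användningen av artificiell intelligens (AI)?"
    #"question": "Vilka var nobelpristagarna 2023?" # "Who were the 2023 Nobel laureates?" - should result in web search
    #"question": "Vad innebär vitboken om artificiell intelligens?" # "What does the white paper on artificial intelligence entail?" - should NOT result in web search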
}

# Execute graph
result = graph.invoke(inputs)

print(f"--- ANSWER: ---\n{result['answer']}")

<br/>

-----

## Going even further - adding grading of retrieved documents for relevance (Corrective RAG)

#### Look at **`simple-rag-agent-demo.ipynb`** for inspiration - and try to implement a similar setup here.
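
A minimal sketch of the grading step, assuming the grader is any function that answers yes/no for a (question, document) pair. In the notebook, that function would typically wrap an LLM with `with_structured_output` and a yes/no score model, similar to LangGraph's Corrective RAG tutorial; here a keyword stub stands in so the control flow runs on its own. The names `grade_documents`, `decide_next_step` and `keyword_grader` are hypothetical.

```python
from typing import Callable

Grader = Callable[[str, str], bool]  # (question, document) -> relevant?


def grade_documents(question: str, documents: list[str], is_relevant: Grader) -> tuple[list[str], bool]:
    """Keep only the documents the grader marks relevant; also report whether any were dropped."""
    kept = [doc for doc in documents if is_relevant(question, doc)]
    return kept, len(kept) < len(documents)


def decide_next_step(relevant_docs: list[str]) -> str:
    """Corrective step: generate if anything relevant survived grading, else fall back to web search."""
    return "generate" if relevant_docs else "web_search"


def keyword_grader(question: str, document: str) -> bool:
    # Stand-in for an LLM grader that returns a binary "yes"/"no" relevance score
    keywords = {w.strip("?.,").lower() for w in question.split() if len(w) > 4}
    return any(k in document.lower() for k in keywords)


question = "What risks of artificial intelligence were discussed?"
retrieved = [
    "The debate covered risks of artificial intelligence in schools.",
    "Minutes on the agricultural budget.",
]
relevant, some_dropped = grade_documents(question, retrieved, keyword_grader)
print(relevant, some_dropped, decide_next_step(relevant))

# Possible LangGraph wiring (assumption): a "grade_documents" node between "retrieve" and
# "generate" that overwrites state["documents"] with the relevant ones, followed by
# workflow.add_conditional_edges("grade_documents", lambda s: decide_next_step(s["documents"]),
#                                {"generate": "generate", "web_search": "web_search"})
```

A variant is to fall back to web search whenever *any* document was dropped (the second return value), rather than only when none survived; which works better depends on how noisy retrieval is for your documents.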