In [None]:
from textwrap import dedent
from langchain import hub
from pprint import pprint
from langchain.schema import Document
from typing import TypedDict
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain_community.chat_models.openai import ChatOpenAI
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langgraph.graph import END, StateGraph

In [None]:
# Load environment variables (e.g. API keys) from the repo-level .env file.
load_dotenv(dotenv_path="../.env")

In [None]:
# Build the retrieval index: fetch one blog post, cut it into overlapping
# chunks, embed the chunks into a Chroma collection named "rag-chroma" and
# expose that collection as a retriever for the cells below.
url = "https://lilianweng.github.io/posts/2023-06-23-agent/"
loader = WebBaseLoader(url)
docs = loader.load()

# 1000-character chunks with 100 characters of overlap between neighbours.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
all_split_docs = text_splitter.split_documents(docs)

# NOTE(review): presumably relies on OPENAI_API_KEY loaded from ../.env.
embedding = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(
    collection_name="rag-chroma",
    documents=all_split_docs,
    embedding=embedding,
)
retriever = vectorstore.as_retriever()

In [None]:
llm = ChatOpenAI()

# Grader prompt: asks the LLM for a yes/no relevance verdict formatted as JSON
# so that JsonOutputParser can parse it into a dict (expected shape:
# {"score": "yes"} or {"score": "no"}).
# Named distinctly so it is not confused with the RAG `prompt` pulled later.
retrieval_grade_prompt = PromptTemplate(
    template=dedent(
        """
    You are a grader assessing relevance of a retrieved document to a user
    question.\n
    Here is the retrieved document:\n\n {context} \n\n
    Here is the user question: {question} \n
    If the document contains keywords related to the user question, grade it as
    relevant. \n
    It does not need to be a stringent test. The goal is to filter out
    erroneous retrievals. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is
    relevant to the question. \n
    Provide the binary score as JSON with a single key 'score' and no preamble
    or explanation.
    """
    ),
    input_variables=["question", "context"],
)

retrieval_grader = retrieval_grade_prompt | llm | JsonOutputParser()

# Smoke-test the grader on the top retrieved chunk for a sample question.
question = "Explain how the different types of agent memory work?"
docs = retriever.get_relevant_documents(question)
score = retrieval_grader.invoke(
    {"question": question, "context": docs[0].page_content}
)
score

In [None]:
# Pull the community RAG prompt from LangChain Hub and show each message's
# template text so the expected input variables are visible.
prompt = hub.pull("rlm/rag-prompt")
templates = [msg.prompt.template for msg in prompt.messages]
for template in templates:
    print(template)

In [None]:
def _format_rag_inputs(inputs):
    """Render ``inputs["context"]`` as plain text for the RAG prompt.

    A sequence of retrieved documents is joined with blank lines using each
    item's ``page_content`` (falling back to ``str(item)``). A context that is
    already a string is passed through unchanged. All other keys are forwarded
    as-is.
    """
    context = inputs["context"]
    if not isinstance(context, str):
        context = "\n\n".join(
            getattr(doc, "page_content", str(doc)) for doc in context
        )
    return {**inputs, "context": context}


# Without this step the list of Document objects would be interpolated into
# the prompt via its repr (including "page_content=" / metadata noise).
# The plain function is coerced into a runnable when piped into `prompt`.
rag_chain = _format_rag_inputs | prompt | llm | StrOutputParser()
generation = rag_chain.invoke({"context": docs, "question": question})
generation

In [None]:
# Prompt that asks the LLM to rephrase a question for better vector-store
# retrieval; the reply is parsed as plain text.
re_write_prompt = PromptTemplate(
    template=dedent(
        """
        You are a question re-writer that converts an input question to a
        better version that is optimized for vectorstore retrieval. Look at the
        initial question and formulate an improved question.\n
        Here is the initial question: {question} \n
        Improved question with no preamble:\n
        """
    ),
    input_variables=["question"],
)

# NOTE: the (misspelled) name `question_rewritter` is kept because
# transform_query refers to it.
question_rewritter = re_write_prompt | llm | StrOutputParser()
question_rewritter.invoke({"question": question})

In [None]:
class GraphState(TypedDict):
    """State schema for the StateGraph; nodes return partial dicts of these keys."""

    question: str  # current (possibly rewritten) user question
    generation: str  # answer text produced by the generate node
    # Must be an annotation: the previous `web_search = str` was a plain class
    # attribute, so it never became a TypedDict key / graph state field.
    web_search: str  # "Yes"/"No" flag written after grading documents
    # NOTE(review): at runtime these are the retriever's Document objects
    # (nodes read `.page_content`), not plain strings; annotation left as-is.
    documents: list[str]

In [None]:
def retrieve(state):
    """Fetch documents for the current question from the vector-store retriever.

    Returns a partial state update holding the retrieved ``documents`` and the
    unchanged ``question``.
    """
    print("--- RETRIEVE ---")
    query = state["question"]
    return {
        "documents": retriever.get_relevant_documents(query),
        "question": query,
    }

In [None]:
def generate(state):
    """Answer the question with the RAG chain over the current documents.

    Returns the generated answer together with the ``question`` and
    ``documents`` it was produced from.
    """
    print("--- GENERATE ---")
    q, context_docs = state["question"], state["documents"]
    answer = rag_chain.invoke({"context": context_docs, "question": q})
    return dict(generation=answer, question=q, documents=context_docs)

In [None]:
def grade_documents(state):
    """Keep only the documents the LLM grader judges relevant to the question.

    Args:
        state: graph state; reads ``question`` and ``documents`` (each item
            must expose ``page_content``).

    Returns:
        Partial state update with the relevant ``documents``, the unchanged
        ``question``, and ``web_search`` set to ``"Yes"`` if at least one
        document was graded irrelevant (``"No"`` otherwise, including when
        there are no documents).
    """
    print("--- GRADE: CHECK DOCUMENT RELEVANCE ---")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    needs_web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "context": d.page_content}
        )
        # The verdict comes from an LLM through a JSON parser, so normalise it:
        # a missing/odd "score" counts as irrelevant instead of raising
        # KeyError, and "Yes" / " yes " are accepted like "yes".
        grade = score.get("score", "") if isinstance(score, dict) else ""
        if str(grade).strip().lower() == "yes":
            print("--- GRADE: DOCUMENT RELEVANT ---")
            filtered_docs.append(d)
        else:
            print("--- GRADE: DOCUMENT IRRELEVANT ---")
            needs_web_search = "Yes"
    return {
        "documents": filtered_docs,
        "question": question,
        "web_search": needs_web_search,
    }

In [None]:
def transform_query(state):
    """Rewrite the question into a form better suited to vector retrieval.

    The documents pass through untouched; only ``question`` is replaced by the
    rewriter's output.
    """
    print("---TRANSFORM QUERY---")
    original_question = state["question"]
    docs_so_far = state["documents"]
    return {
        "documents": docs_so_far,
        "question": question_rewritter.invoke({"question": original_question}),
    }

In [None]:
def web_search(state):
    """Append web-search results for the question as one extra Document.

    Requires a module-level ``web_search_tool`` (not defined in the visible
    cells). Its ``invoke({"query": ...})`` is assumed to return a list of
    dicts with a ``"content"`` key, such as a Tavily search tool. TODO:
    confirm once a tool is wired in.

    Returns:
        Partial state update with the extended ``documents`` (a new list; the
        incoming state's list is not mutated) and the unchanged ``question``.

    Raises:
        NotImplementedError: if ``web_search_tool`` has not been defined.
    """
    search_tool = globals().get("web_search_tool")
    if search_tool is None:
        raise NotImplementedError(
            "web_search requires a module-level `web_search_tool`; define one "
            "before running the transform_query -> web_search branch."
        )
    print("---WEB SEARCH---")
    question = state["question"]
    # Copy rather than appending to the list held by the incoming state.
    documents = list(state["documents"])

    results = search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join(r["content"] for r in results))
    documents.append(web_results)
    return {"documents": documents, "question": question}

In [None]:
def decide_to_generate(state):
    """Route after grading: rewrite the query or go straight to generation.

    Reads the ``web_search`` flag written by ``grade_documents``, where
    ``"Yes"`` means at least one retrieved document was graded irrelevant.
    The ``"transform_query"`` route continues to the ``web_search`` node in
    the graph.

    Returns:
        ``"transform_query"`` when the flag is ``"Yes"``; otherwise
        ``"generate"`` (also when the flag is absent from the state).
    """
    print("--- ASSESS GRADED DOCUMENTS ---")
    # Previously hard-coded to "No", which made the transform branch unreachable.
    needs_web_search = state.get("web_search", "No")

    if needs_web_search == "Yes":
        print(
            dedent(
                """
                ---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION
                TRANSFORM QUERY---
                """
            )
        )
        return "transform_query"
    else:
        print(
            dedent(
                """
                ---DECISION: ALL DOCUMENTS ARE RELEVANT TO QUESTION---
                """
            )
        )
        return "generate"

In [None]:
# Wire up the corrective-RAG graph:
#   retrieve -> grade_documents -> generate -> END
#                              \-> transform_query -> web_search -> generate
workflow = StateGraph(GraphState)

for _node_name, _node_fn in (
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("generate", generate),
    ("transform_query", transform_query),
    ("web_search", web_search),
):
    workflow.add_node(_node_name, _node_fn)

workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
# decide_to_generate returns the name of the next node directly.
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {_branch: _branch for _branch in ("transform_query", "generate")},
)
for _src, _dst in (
    ("transform_query", "web_search"),
    ("web_search", "generate"),
    ("generate", END),
):
    workflow.add_edge(_src, _dst)

app = workflow.compile()

In [None]:
# Run the compiled graph on a sample question, echoing each node as its update
# arrives, then print the answer from the last update received.
inputs = {"question": "what are the types of agent memory?"}

final_update = None
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
        # Track the latest update explicitly instead of relying on the loop
        # variable leaking out; the graph's last node is `generate`.
        final_update = value

if not isinstance(final_update, dict) or "generation" not in final_update:
    raise RuntimeError("graph finished without producing a generation")
pprint(final_update["generation"])