In [None]:
from pprint import pprint
from typing import List

from langchain_core.documents import Document
from typing_extensions import TypedDict

from langgraph.graph import END, StateGraph
from langchain_openai import ChatOpenAI

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.document_loaders import WebBaseLoader
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from tavily import TavilyClient

import os


# API keys: paste values here for local use only (never commit real keys), or
# leave them empty to use whatever is already exported in your shell. Empty
# values are skipped so they never overwrite an existing environment variable.
_API_KEYS = {
    "OPENAI_API_KEY": "",
    "LANGCHAIN_API_KEY": "",
    # Correct spelling of the variable the Tavily client looks for.
    "TAVILY_API_KEY": "",
}
for _name, _value in _API_KEYS.items():
    if _value:
        os.environ[_name] = _value

# LangSmith tracing: only switch it on when a LangSmith key is available,
# since traces cannot be uploaded without one.
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
if os.environ.get("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Shared chat model used by every grader / generator chain below.
llm = ChatOpenAI(model="gpt-4o-mini", temperature = 0)

### State


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """

    question: str
    documents: List[str]
    relevant_documents: List[str]
    source_info : str
    TavilyChecked: bool
    HallucinationChecked: bool
    answer: str


### Nodes

# Web pages indexed into the vector store.
_SOURCE_URLS = (
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
)

# Retriever built on first use and reused by later graph runs, so the pages are
# downloaded, split and embedded once instead of on every `retrieve` call.
_retriever = None


def _get_retriever():
    """Return the Chroma retriever over ``_SOURCE_URLS``, building it on first call."""
    global _retriever
    if _retriever is None:
        # Load every page and flatten the per-URL document lists, preserving URL order.
        docs_list = [doc for url in _SOURCE_URLS for doc in WebBaseLoader(url).load()]

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=250, chunk_overlap=0
        )
        doc_splits = text_splitter.split_documents(docs_list)

        # Add to vectorDB
        vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name="rag-chroma",
            embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
        )
        _retriever = vectorstore.as_retriever()
    return _retriever


def retrieve(state):
    """
    Retrieve documents from the vector store for the current question.

    The vector store is built lazily on the first call and reused afterwards.

    Args:
        state (dict): The current graph state; reads ``question``.

    Returns:
        state (dict): ``documents`` set to the retrieved chunks, plus the
        per-run fields reset: ``relevant_documents`` emptied and the
        ``TavilyChecked`` / ``HallucinationChecked`` flags cleared.
    """
    print("---RETRIEVE---")
    question = state["question"]
    documents = _get_retriever().invoke(question)
    print(question)
    print(documents)
    return {"documents": documents, "relevant_documents": [], "question": question, "TavilyChecked": False, "HallucinationChecked": False}

def _format_docs(docs):
    """Join document texts into one context string for an LLM prompt.

    Uses ``page_content`` for LangChain ``Document`` objects and falls back to
    ``str(doc)`` for anything else.
    """
    return "\n\n".join(getattr(doc, "page_content", str(doc)) for doc in docs)


def _is_yes(result):
    """Return True if a grader's parsed JSON output has ``score`` equal to "yes".

    Case and surrounding whitespace are ignored; a missing ``score`` key or a
    non-dict result counts as "no".
    """
    return isinstance(result, dict) and str(result.get("score", "")).strip().lower() == "yes"


def grade_documents(state):
    """
    Graph node: keep only the documents an LLM grader judges relevant.

    Grading lives in a node (not in a routing function) because LangGraph only
    persists state updates that nodes return; writes a router makes to
    ``state`` are discarded.

    Args:
        state (dict): The current graph state; reads ``question`` and ``documents``.

    Returns:
        state (dict): ``relevant_documents`` set to the relevant subset of
        ``documents``, in their original order (possibly empty).
    """
    print("---GRADE DOCUMENTS---")
    question = state["question"]

    system = """You are a grader assessing relevance
        of a retrieved document to a user question. If the document contains keywords related to the user question,
        grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
        """

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "question: {question}\ndocument: {retrieved_chunk}"),
        ]
    )
    relevance_grader = prompt | llm | JsonOutputParser()

    relevant_docs = [
        doc
        for doc in state["documents"]
        if _is_yes(
            relevance_grader.invoke(
                {"question": question, "retrieved_chunk": getattr(doc, "page_content", str(doc))}
            )
        )
    ]
    return {"relevant_documents": relevant_docs}


def relevance_checker(state):
    """
    Routing function: choose the next step from the graded documents.

    Read-only; the grading itself is done by ``grade_documents``.

    Returns:
        str: "Yes" if any document is relevant (generate an answer), "No" if
        none is relevant and the web search has not run yet, "Failed" if none
        is relevant even after the web search.
    """
    print("---RELEVANCE CHECKER---")
    if state.get("relevant_documents"):
        return "Yes"
    if state.get("TavilyChecked"):
        print("---TAVILY CHECKED---")
        print("failed: not relevant")
        return "Failed"
    print("---NOT RELEVANT. TAVILY SEARCH NEEDED---")
    return "No"


def tavily_searcher(state):
    """
    Search Tavily based on the question

    Args:
        state (dict): The current graph state; reads ``question``.

    Returns:
        state (dict): ``documents`` replaced by the web results (as ``Document``
        objects with the result URL in ``metadata["source"]``) and
        ``TavilyChecked`` set to True.
    """
    print("---TAVILY SEARCH---")

    # The key comes from the environment; never hard-code secrets in the notebook.
    tavily = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))

    response = tavily.search(query=state["question"], max_results=5)
    # Wrap results as Documents so the relevance grader and prompts can treat
    # them exactly like vector-store chunks.
    documents = [
        Document(page_content=obj["content"], metadata={"source": obj.get("url", "")})
        for obj in response.get("results", [])
    ]

    return {"documents": documents, "TavilyChecked": True}


def generate_answer(state):
    """
    Generate answer using RAG on the relevant documents

    Args:
        state (dict): The current graph state; reads ``question``,
            ``relevant_documents`` and (if present) a previous ``answer``.

    Returns:
        state (dict): ``answer`` set to the LLM generation, and
        ``HallucinationChecked`` set to True when this call is a regeneration.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["relevant_documents"]

    system = """You are an assistant for question-answering tasks.
        Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
        Use three sentences maximum and keep the answer concise"""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "question: {question}\n\n context: {context} "),
        ]
    )

    # Chain
    rag_chain = prompt | llm | StrOutputParser()
    answer = rag_chain.invoke({"context": _format_docs(documents), "question": question})

    return {
        "relevant_documents": documents,
        "question": question,
        "answer": answer,
        # A previous answer can only exist if hallucination_checker rejected it
        # and routed back here, so this generation is the one allowed retry.
        # NOTE: with temperature=0 the retry may well reproduce the same answer.
        "HallucinationChecked": state.get("answer") is not None,
    }


### Conditional edge


def hallucination_checker(state):
    """
    Routing function: check that the answer is grounded in the relevant documents.

    Returns:
        str: "Yes" if grounded, "No" to regenerate once after the first
        ungrounded answer, "Failed" if the regenerated answer is still ungrounded.
    """
    print("---CHECK HALLUCINATIONS---")

    system = """You are a grader assessing whether
    an answer is grounded in / supported by a set of facts. Give a binary 'yes' or 'no' score to indicate
    whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
    single key 'score' and no preamble or explanation."""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "documents: {documents}\n\nanswer: {answer}"),
        ]
    )

    hallucination_grader = prompt | llm | JsonOutputParser()

    score = hallucination_grader.invoke(
        {"documents": _format_docs(state["relevant_documents"]), "answer": state["answer"]}
    )

    if _is_yes(score):
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        return "Yes"
    # HallucinationChecked is set by generate_answer when the current answer is
    # already a regeneration: retry once, then give up.
    if state.get("HallucinationChecked"):
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, ABORT---")
        return "Failed"
    print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
    return "No"


workflow = StateGraph(GraphState)

# Define the nodes. hallucination_checker and relevance_checker are routing
# functions (they return a label, not a state update), so they are wired as
# conditional edges rather than registered as nodes.
workflow.add_node("docs_retrieval_node", retrieve)
workflow.add_node("grade_documents_node", grade_documents)
workflow.add_node("search_tavily_node", tavily_searcher)
workflow.add_node("generate_answer_node", generate_answer)


# Build graph
workflow.set_entry_point("docs_retrieval_node")
# Both the vector-store results and the web-search fallback go through grading.
workflow.add_edge("docs_retrieval_node", "grade_documents_node")
workflow.add_edge("search_tavily_node", "grade_documents_node")
workflow.add_conditional_edges(
    "grade_documents_node",
    relevance_checker,
    {
        "Yes": "generate_answer_node",
        "No": "search_tavily_node",
        "Failed": END,
    },
)
workflow.add_conditional_edges(
    "generate_answer_node",
    hallucination_checker,
    {
        "Yes": END,
        "No": "generate_answer_node",
        "Failed": END,
    },
)

# Compile
app = workflow.compile()

# Test

inputs = {"question": "What is prompt?"}
final_answer = None
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")
        # Only the answer-generation node's update carries an "answer" key.
        if isinstance(value, dict) and "answer" in value:
            final_answer = value["answer"]

# The run can end without generating anything (no relevant documents even after
# the web-search fallback), so don't assume the last update has an answer.
pprint(final_answer if final_answer is not None else "No answer: no relevant documents were found.")

<langgraph.graph.state.StateGraph at 0x1136c13a0>