In [1]:
import getpass
import os
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from dotenv import load_dotenv

load_dotenv()
from langchain_pinecone import PineconeVectorStore
from langgraph.graph import MessagesState, StateGraph
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import END
from langgraph.prebuilt import ToolNode, tools_condition
from typing import Literal
from pydantic import BaseModel, Field
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

from langgraph.checkpoint.memory import MemorySaver


  from tqdm.autonotebook import tqdm


In [2]:
# Set up environment variables
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
os.environ["PINECONE_API_KEY"] = os.getenv('PINECONE_API_KEY')
os.environ["LANGSMITH_API_KEY"] = os.getenv('LANGSMITH_API_KEY')
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'

llm = ChatOpenAI(model="gpt-4o", streaming = True)
embeddings = OpenAIEmbeddings(model = "text-embedding-3-small")

MAX_RETRIEVAL_ATTEMPTS = 4

class ExtState(MessagesState):
    retrieval_attempts: int


In [3]:
# Set up vector store
vs = PineconeVectorStore(
    index_name = 'langchain-index',
    embedding = embeddings)


In [4]:
from langchain_core.messages import AIMessage, HumanMessage
chat_history = [HumanMessage(content="Tell me something about Joe Biden's foreign policy in Ukraine"), HumanMessage(content="How has that impacted Russia?")]

from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, add_messages

In [13]:
# loads chats up to x tokens for context for each inquiry, 
def history(state):
    # Prompt for history contextualizing
    print("---ADDING HISTORY---")
    print(f"{chat_history}")
    prompt = PromptTemplate(
        template="""Given a chat history and the latest user question 
        which might reference context in the chat history,
        formulate a standalone question which can be understood 
        without the chat history. Do NOT answer the question,
        just reformulate it if needed and otherwise return it as is.
        Here is the chat history: \n\n {chat_history} \n\n
        Here is the latest user question: {question} \n\n
        """,
        input_variables=["chat_history", "query"],
    )
    chain = prompt | llm

    messages = state["messages"]
    last_message = messages[-1]

    question = messages[0].content
    docs = last_message.content

    response = chain.invoke({"chat_history": chat_history, "question": question})
    return {"messages": [response]}
    
# Node 1 - Retriever Tool 
# possible switch to vector_store.as_retriever
@tool(response_format ="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    
    print("---RETRIEVE 6 RELEVANT DOCS---")
    retrieved_docs = vs.similarity_search(query, k=6)
    
    serialized = "\n\n".join(
        (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
        for doc in retrieved_docs
    )
    return "", ""

# Node 2 - grader
def grade_documents(state) -> Literal["rewrite", "generate"]:
    print("---CHECK RELEVANCE---")
    state["retrieval_attempts"] += 1
    print(type(state))
    print(f"Retrieval attempts: {state.get("retrieval_attempts")}")

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = llm
    
    # LLM with tool and validation
    llm_with_tool = model.with_structured_output(grade)

    # Prompt for binary grading
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool

    messages = state["messages"]
    last_message = messages[-1]

    question = messages[1].content
    docs = last_message.content

    scored_result = chain.invoke({"question": question, "context": docs})

    score = scored_result.binary_score
    

    if score == "yes":
        print("---DECISION: DOCS RELEVANT---")
        return state, "generate"
    # max attempts
    elif score == "no" and state["retrieval_attempts"] >= MAX_RETRIEVAL_ATTEMPTS:
        print("---DECISION: DOCS NOT RELEVANT & MAX RETRIEVAL ATTEMPTS REACHED---")
        return state, "direct_response"
    else:
        print("---DECISION: DOCS NOT RELEVANT---")
        print(score)
        return state, "rewrite"

# Node 2 - Query or Respond agent --> decides whether to retrieve or respond
# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def agent(state):
    """Generate tool call for retrieval or respond."""
    #print(f"Test original message: {state["messages"][1]}")
    print("---CALL AGENT---")
    llm_with_tools = llm.bind_tools([retrieve])
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

def rewrite(state):

    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    question = messages[1].content

    msg = [
        HumanMessage(
            content=f""" \n 
    Look at the input and try to reason about the underlying semantic intent / meaning. \n 
    Here is the initial question:
    \n ------- \n
    {question} 
    \n ------- \n
    Formulate a new question that is specific and deatiled to help with document retrieval. 
    Focus on key terms and concepts that might appear in relevant documents: """,
        )
    ]

    # Grader
    response = llm.invoke(msg)
    return {"messages": [response]}

# Node 3 - Generate Answer
def generate(state):
    """Generate answer."""
    print("---GENERATE ANSWER---")
    # Get generated ToolMessages
    messages = state["messages"]
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]

    # Format into prompt --> adds former messages from graph
    docs_content = "\n\n".join(doc.content for doc in tool_messages)
    question = messages[1].content 
    
    prompt = PromptTemplate(template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, say that you don't know. Use three sentences maximum and keep the answer concise. \n 
        Here is the original question: \n\n {question} \n\n
        Here is the context: {context} \n\n
        Give a one line summary of the context metadata used at the end of your answer.""",
        input_variables=["context", "question"],
        )
    gen_chain = prompt | llm | StrOutputParser()

    response = gen_chain.invoke({"context": docs_content, "question": question})
    return {"messages": [response]}


def direct_response(state):
    """Generate a highly constrained response when not using tools."""
    print("---DIRECT RESPONSE---")
    
    prompt = PromptTemplate(
        template="""You are a focused assistant that only provides brief, targeted responses.
        Provide a two sentence response that politely explains you can only answer questions about the specific content in our knowledge base which provides information about American policy .
        Do not provide any other information or engage in general conversation.
        
        User question: {question}
        
        One sentence response:""",
        input_variables=["question"]
    )
    
    messages = state["messages"]
    question = messages[0].content
    
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"question": question})
    
    return {"messages": [response]}

In [14]:
tools = ToolNode([retrieve])


In [16]:
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode

# Define a new graph
workflow = StateGraph(ExtState)

# Define the nodes we will cycle between
workflow.add_node("history", history)
workflow.add_node("agent", agent)  # agent
workflow.add_node("retrieve", tools)  # retrieval
workflow.add_node("rewrite", rewrite)  # Re-writing the question
workflow.add_node("generate", generate) # generate answer
workflow.add_node("direct_response", direct_response) # direct response for irrelevant question
# Call agent node to decide to retrieve or not
workflow.add_edge(START, "history")
workflow.add_edge("history", "agent")

# Decide whether to retrieve
workflow.add_conditional_edges(
    "agent",
    # Assess agent decision
    tools_condition,
    {
        # Translate the condition outputs to nodes in our graph
        "tools": "retrieve",
        END: "direct_response",
    },
)
workflow.add_conditional_edges(
    "retrieve",
    lambda x: grade_documents(x)[1],
    {
        "rewrite": "rewrite",
        "generate": "generate",
        "direct_response": "direct_response",
    }
)


workflow.add_edge("rewrite", "agent")
workflow.add_edge("generate", END)
workflow.add_edge("direct_response", END)


# Compile
#memory = MemorySaver()
graph = workflow.compile()

In [None]:
from langchain_core.messages import HumanMessage

inputs = {"messages": [HumanMessage(content="Tell me more about the first thing I asked about")], "retrieval_attempts": 0}
for msg, metadata in graph.stream(inputs, stream_mode="messages"):
    #print(f"Node: {metadata['langgraph_node']}")
    if (
        msg.content
        #and not isinstance(msg, HumanMessage)
        and metadata["langgraph_node"] == "generate"
    ):
        print(msg.content, flush=True)

---ADDING HISTORY---
[HumanMessage(content="Tell me something about Joe Biden's foreign policy in Ukraine"), HumanMessage(content='How has that impacted Russia?')]
---CALL AGENT---
---RETRIEVE 6 RELEVANT DOCS---
---CHECK RELEVANCE---
<class 'dict'>
Retrieval attempts: 1
---DECISION: DOCS NOT RELEVANT---
no
---TRANSFORM QUERY---
---CALL AGENT---
---RETRIEVE 6 RELEVANT DOCS------RETRIEVE 6 RELEVANT DOCS---

---RETRIEVE 6 RELEVANT DOCS---
---CHECK RELEVANCE---
<class 'dict'>
Retrieval attempts: 1
---DECISION: DOCS NOT RELEVANT---
no
---TRANSFORM QUERY---
---CALL AGENT---
---RETRIEVE 6 RELEVANT DOCS------RETRIEVE 6 RELEVANT DOCS---

---RETRIEVE 6 RELEVANT DOCS---
---CHECK RELEVANCE---
<class 'dict'>
Retrieval attempts: 1
---DECISION: DOCS NOT RELEVANT---
no
---TRANSFORM QUERY---
---CALL AGENT---


In [None]:
from langchain_core.messages import HumanMessage

inputs = {"messages": [HumanMessage(content="Tell me about the meaning of life")], "retrieval_attempts": 0}
for msg, metadata in graph.stream(inputs, stream_mode="messages"):
    #print(f"Node: {metadata['langgraph_node']}")
    if (
        msg.content
        #and not isinstance(msg, HumanMessage)
        and metadata["langgraph_node"] == "generate" or metadata["langgraph_node"] == "direct_response"
    ):
        print(msg.content, flush=True)

In [None]:
from IPython.display import Image, display

try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    pass
