In [None]:
import os
import sys

# Make the project root (the parent of the notebook's working directory)
# importable, so local packages such as `models` (imported below) resolve.
current_working_directory = os.getcwd()
project_root = os.path.abspath(os.path.join(current_working_directory, ".."))
print(project_root)
# Guard against piling up duplicate entries when this cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import pprint

### Defining the Graph state

In [3]:
from typing import TypedDict, Annotated, List, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage, HumanMessage
import operator
from IPython.display import Image, display

In [4]:
class AgentState(TypedDict):
    """State shared by the nodes of the research graph."""

    # The user's query for the current turn.
    input: str
    # Output of the node that ran last. It has no reducer, so each update replaces it:
    # the research node writes a one-element list holding the model's reply, and
    # the tool node writes a list of ToolMessages.
    agent_outcome: Union[List[BaseMessage], None]
    # History fed to the research prompt's MessagesPlaceholder. operator.add makes
    # every update append to the list instead of replacing it.
    intermediate_step: Annotated[list, operator.add]

In [5]:
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.pubmed.tool import PubmedQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults

# Search backends available to the research agent. Each one is also bound to
# its own name so it can be inspected or invoked directly.
TAVILY_MAX_RESULTS = 5

pubmed_search = PubmedQueryRun()
arxiv_search = ArxivQueryRun()
tavily_tool = TavilySearchResults(max_results=TAVILY_MAX_RESULTS)

# To restrict the agent to academic sources, drop `tavily_tool` (and
# optionally `pubmed_search`) from this list.
tools = [arxiv_search, pubmed_search, tavily_tool]

In [6]:
from models.llm import LLM

# Chat model behind the research agent, loaded through the project's LLM wrapper.
MODEL_NAME = "gpt-4o"
model = LLM(MODEL_NAME)
llm = model.load_model()

In [7]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

def research_agent(data):
    """Graph node: ask the tool-enabled LLM for the next step on the user's query.

    Args:
        data: the current graph state. ``data["input"]`` fills the human
            message and ``data["intermediate_step"]`` (accumulated history)
            fills the messages placeholder; a missing history is treated as
            empty.

    Returns:
        A state update whose ``agent_outcome`` is a one-element list holding
        the model's reply; the router checks that reply's ``tool_calls`` to
        decide between running tools and finishing.
    """
    print("----research node----")
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI research assistant chatbot."
                " Use the appropriate search tools and chat history to progress towards finding the relevant results."
                " Once you have the relevant search results, summarise them to answer the user query."
                "\nYou have access to the following search tools: {tool_names}.",
            ),
            (
                "human",
                "\nUser Query: {input}",
            ),
            # optional=True lets the node run on a state that has no history key yet.
            MessagesPlaceholder(variable_name="intermediate_step", optional=True),
        ]
    )
    prompt = prompt.partial(tool_names=", ".join(tool.name for tool in tools))
    agent = prompt | llm.bind_tools(tools)
    result = agent.invoke(data)
    return {"agent_outcome": [result]}

In [None]:
from langgraph.graph import END, StateGraph
workflow = StateGraph(AgentState)

# "research" is the graph's first node; the tool loop later returns to it.
research_node = "research"
workflow.add_node(research_node, research_agent)
workflow.set_entry_point(research_node)

In [None]:
import json
from langchain_core.messages import ToolMessage

class BasicToolNode:
    """Graph node that runs the tool calls requested by the latest AI message.

    The message is read from ``inputs["agent_outcome"][-1]``. Each entry of
    its ``tool_calls`` is dispatched to the tool with the matching ``name``,
    and each call is answered with exactly one ``ToolMessage``.
    """

    def __init__(self, tools: list) -> None:
        # Map tool name -> tool so each tool call can be dispatched by name.
        self.tools_by_name = {tool.name: tool for tool in tools}

    @staticmethod
    def _serialize(result) -> str:
        """Turn a tool result into ToolMessage content.

        Strings are returned unchanged: ``json.dumps`` would wrap them in
        quotes and escape every newline. Other values are JSON-encoded,
        falling back to ``str`` for objects JSON cannot represent.
        """
        if isinstance(result, str):
            return result
        return json.dumps(result, default=str)

    def _run_tool_call(self, tool_call: dict) -> str:
        """Execute one tool call and return its content; failures become error text."""
        tool = self.tools_by_name.get(tool_call["name"])
        if tool is None:
            # The model asked for a tool that does not exist; tell it so instead of
            # crashing with a KeyError.
            return (
                f"Error: unknown tool '{tool_call['name']}'. "
                f"Available tools: {', '.join(self.tools_by_name)}."
            )
        try:
            return self._serialize(tool.invoke(tool_call["args"]))
        except Exception as exc:
            # Report the failure to the model (and the notebook) rather than
            # aborting the whole graph run.
            print(f"---- {tool_call['name']} failed: {exc!r} ----")
            return f"Error: {tool_call['name']} failed with {type(exc).__name__}: {exc}"

    def __call__(self, inputs: dict):
        """Run every tool call on the latest AI message.

        Args:
            inputs: the graph state; ``inputs["agent_outcome"][-1]`` must be a
                message with a ``tool_calls`` list.

        Returns:
            A state update with the ToolMessages as ``agent_outcome`` and, for
            ``intermediate_step``, the AI message followed by its ToolMessages.
        """
        print("----tool calling----")
        message = inputs["agent_outcome"][-1]

        outputs = []
        for tool_call in message.tool_calls:
            print(f"---- Calling {tool_call['name']} with args: {tool_call['args']} ----")
            outputs.append(
                ToolMessage(
                    content=self._run_tool_call(tool_call),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )

        return {
            "agent_outcome": outputs,
            # Keep real messages (not their str() repr): chat-model APIs expect
            # each ToolMessage to follow the assistant message that issued its
            # tool_call_id, so the request is recorded together with its results.
            "intermediate_step": [message, *outputs],
        }

# Register the tool executor under "tools", the name the router returns.
tool_node = BasicToolNode(tools)
workflow.add_node("tools", tool_node)

In [10]:
def route_tools(
    state: AgentState,
):
    """Decide where the graph goes after the research node.

    Looks at the most recent message and returns ``"tools"`` if it carries at
    least one tool call, ``END`` otherwise. The most recent message is the
    last item of ``state`` when a bare list is passed, and otherwise the last
    item of ``state["agent_outcome"]``.

    Raises:
        ValueError: when ``state`` is a mapping with no (or an empty)
            ``"agent_outcome"``.
    """
    print("----router----")
    if isinstance(state, list):
        messages = state
    else:
        messages = state.get("agent_outcome", [])
        if not messages:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")

    latest = messages[-1]
    if not hasattr(latest, "tool_calls"):
        return END
    return "tools" if len(latest.tool_calls) > 0 else END

In [None]:
# After "research", route_tools picks the next hop: "tools" when the model
# requested tool calls, END otherwise.
research_routes = {"tools": "tools", END: END}
workflow.add_conditional_edges("research", route_tools, research_routes)

In [None]:
# Tool results always flow back to the research node, which either answers or
# requests further tool calls.
workflow.add_edge("tools", "research")

#### Adding memory to the chatbot

LangGraph persists state (or messages) through checkpointing. If you provide a checkpointer when compiling the graph and a `thread_id` when calling it, LangGraph automatically saves the state after each step. When you invoke the graph again with the same `thread_id`, it loads the saved state, so the chatbot can pick up where it left off.

Reference: [MemorySaver Checkpointer](https://langchain-ai.github.io/langgraph/reference/checkpoints/#langgraph.checkpoint.memory.MemorySaver)

In [13]:
from langgraph.checkpoint.memory import MemorySaver
# In-memory checkpointer: per-thread state is kept only for the life of this kernel.
memory = MemorySaver()

In [None]:
# Compile with the checkpointer so every step is saved under the caller's thread_id.
app = workflow.compile(checkpointer=memory)
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    # Rendering the diagram is optional (it may need extra dependencies or
    # network access), so report the problem instead of failing the notebook.
    print(f"Graph diagram unavailable: {exc}")

In [None]:
inputs = {
    "input": "What are the recent papers on Small Language Models?",
}

config = {
    "configurable": {
        "thread_id": "123",
    }
}

state = AgentState(**inputs)
# Show the question, then only what each node returns during this turn.
# (With stream_mode="values", the first event of a follow-up turn would
# re-print the previous turn's answer, because `agent_outcome` is restored
# from the checkpoint before any node runs.)
HumanMessage(state["input"]).pretty_print()
print("-----"*20)
events = app.stream(input=state, config=config, stream_mode="updates")
for event in events:
    # Each event maps the node that just ran to the state update it returned.
    for node_update in event.values():
        node_update["agent_outcome"][-1].pretty_print()
    print("-----"*20)

##### Follow-up in conversation 1 (thread `123`), then a new conversation 2 (thread `456`)

In [None]:
inputs = {
    "input": "Summarise the results into a single paragraph.",
}

state = AgentState(**inputs)
# Print the new question explicitly and stream per-node updates. With
# stream_mode="values", the first event would still carry the previous
# turn's `agent_outcome` restored from the checkpoint.
HumanMessage(state["input"]).pretty_print()
print("-----"*20)
events = app.stream(input=state, config=config, stream_mode="updates")
for event in events:
    # Each event maps the node that just ran to the state update it returned.
    for node_update in event.values():
        node_update["agent_outcome"][-1].pretty_print()
    print("-----"*20)

In [None]:
inputs = {
    "input": "What are the recent papers on LLM Agents?",
}

config_new = {
    "configurable": {
        "thread_id": "456",
    }
}

state = AgentState(**inputs)
# Show the question, then only what each node returns during this turn
# (stream_mode="values" would re-print a restored `agent_outcome` on follow-ups).
HumanMessage(state["input"]).pretty_print()
print("-----"*20)
events = app.stream(input=state, config=config_new, stream_mode="updates")
for event in events:
    # Each event maps the node that just ran to the state update it returned.
    for node_update in event.values():
        node_update["agent_outcome"][-1].pretty_print()
    print("-----"*20)

##### Continuing conversation 2 (thread `456`)

In [None]:
inputs = {
    "input": "Can you find LLM Agents for report or article generation?",
}

state = AgentState(**inputs)
# Print the new question explicitly and stream per-node updates. With
# stream_mode="values", the first event would still carry the previous
# turn's `agent_outcome` restored from the checkpoint.
HumanMessage(state["input"]).pretty_print()
print("-----"*20)
events = app.stream(input=state, config=config_new, stream_mode="updates")
for event in events:
    # Each event maps the node that just ran to the state update it returned.
    for node_update in event.values():
        node_update["agent_outcome"][-1].pretty_print()
    print("-----"*20)

In [None]:
inputs = {
    "input": "Summarise all the papers like survey report.",
}

state = AgentState(**inputs)
# Print the new question explicitly and stream per-node updates. With
# stream_mode="values", the first event would still carry the previous
# turn's `agent_outcome` restored from the checkpoint.
HumanMessage(state["input"]).pretty_print()
print("-----"*20)
events = app.stream(input=state, config=config_new, stream_mode="updates")
for event in events:
    # Each event maps the node that just ran to the state update it returned.
    for node_update in event.values():
        node_update["agent_outcome"][-1].pretty_print()
    print("-----"*20)

##### Switching back to conversation 1 (thread `123`)

In [None]:
inputs = {
    "input": "What is the research topic?",
}

state = AgentState(**inputs)
# Print the new question explicitly and stream per-node updates. With
# stream_mode="values", the first event would still carry this thread's
# last `agent_outcome` restored from the checkpoint.
HumanMessage(state["input"]).pretty_print()
print("-----"*20)
events = app.stream(input=state, config=config, stream_mode="updates")
for event in events:
    # Each event maps the node that just ran to the state update it returned.
    for node_update in event.values():
        node_update["agent_outcome"][-1].pretty_print()
    print("-----"*20)

In [None]:
# Inspect the checkpointed state of conversation 1 (thread "123").
snapshot = app.get_state(config)
snapshot

In [None]:
# Inspect the checkpointed state of conversation 2 (thread "456").
snapshot = app.get_state(config_new)
snapshot