In [None]:
# Install the dependencies listed in requirements.txt into the kernel's
# environment (%pip, unlike !pip, targets the running kernel).
%pip install -r requirements.txt

In [None]:
# Load variables from a local .env file into the process environment
# (presumably the Tavily and Google API keys used below — confirm).
from dotenv import load_dotenv
_ = load_dotenv()  # assigned to `_` so the cell does not echo the return value

In [None]:
#Adding typing module for typed python
# Documentation for python typing annotation https://docs.python.org/3/library/typing.html
from typing import TypedDict, Annotated
import operator

In [None]:
#langgraph module and tavily search tool
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage, AIMessage
from langchain_tavily import TavilySearch
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from uuid import uuid4
from langgraph.checkpoint.memory import MemorySaver

# Checkpointer shared by every thread below: it stores each thread's graph
# state so a run can be paused at an interrupt, inspected, edited and resumed.
# MemorySaver keeps checkpoints in process memory, so they do not survive a
# kernel restart.
memory = MemorySaver()


In [None]:
# Google Gemini is used for this example: https://python.langchain.com/docs/integrations/providers/google/
from langchain.chat_models import init_chat_model

# Model id ("provider:model") and search depth, named so they are easy to tune.
MODEL_NAME = "google_genai:gemini-2.5-pro"
MAX_SEARCH_RESULTS = 2

tool = TavilySearch(max_results=MAX_SEARCH_RESULTS)
# Bind the search tool so the model can answer with tool calls for it.
model = init_chat_model(MODEL_NAME).bind_tools([tool])

In [None]:
"""
In previous examples we've annotated the `messages` state key 
with the default `operator.add` or `+` reducer, which always 
appends new messages to the end of the existing messages array.

Now, to support replacing existing messages, we annotate the
`messages` key with a customer reducer function, which replaces
messages with the same `id`, and appends them otherwise.
"""
def reduce_messages(left: list[AnyMessage], right: list[AnyMessage]) -> list [AnyMessage]:
    # assign ids to messages that don't have them
    for message in right:
        if not message.id:
            message.id = str(uuid4())
    # merge the new messages with the existing messages
    merged = left.copy()
    for message in right:
        for i, existing in enumerate(merged):
            # replace any existing messages with the same id
            if existing.id == message.id:
                merged[i] = message
                break
        else:
            # append any new messages to the end
            merged.append(message)
    return merged

class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], reduce_messages]

In [None]:
class Agent:
    """Minimal tool-calling agent built on a LangGraph ``StateGraph``.

    The graph alternates between an ``"llm"`` node, which calls the model,
    and an ``"action"`` node, which runs the tool calls the model requested.
    It ends as soon as the model replies without tool calls. Because the
    graph is compiled with ``interrupt_before=["action"]``, execution pauses
    before every tool run. A human can then inspect or edit the pending
    tool calls through the checkpointed state before resuming.

    Args:
        model: Chat model with the tools already bound; ``invoke(messages)``
            must return a message whose ``tool_calls`` drive the routing.
        tools: Tools the agent may execute, looked up by their ``name``.
        checkpointer: LangGraph checkpointer that persists per-thread state
            between ``stream`` calls, which is what makes resuming possible.
        system: Optional system prompt prepended to every model call.
    """

    def __init__(self, model, tools, checkpointer, system=""):
        self.system = system
        # The nodes read these at run time; set them before building the graph.
        self.tools = {t.name: t for t in tools}
        self.model = model

        graph = StateGraph(AgentState)
        graph.add_node("llm", self.call_openai)
        graph.add_node("action", self.take_action)
        graph.add_conditional_edges(
            "llm",
            self.exists_action,
            {True: "action", False: END}
        )
        graph.add_edge("action", "llm")
        graph.set_entry_point("llm")
        # Interrupting before "action" stops the run before any tool is
        # executed, so the pending tool calls can be reviewed or modified.
        self.graph = graph.compile(checkpointer=checkpointer,
                                   interrupt_before=["action"])

    def exists_action(self, state: AgentState):
        """Route after ``"llm"``: ``True`` iff the last message requests tool calls."""
        last_message = state['messages'][-1]
        # Messages without a ``tool_calls`` attribute cannot request a tool,
        # so treat them as "no action" instead of raising AttributeError.
        return bool(getattr(last_message, "tool_calls", None))

    def call_openai(self, state: AgentState):
        """Invoke the bound model on the conversation so far.

        Despite its historical name, this works with any chat model; the
        notebook passes a Gemini model. The system prompt, if set, is
        prepended for this call only and is not written into the state.
        """
        messages = state['messages']
        if self.system:
            messages = [SystemMessage(content=self.system)] + messages
        message = self.model.invoke(messages)
        return {'messages': [message]}

    def take_action(self, state: AgentState):
        """Run every tool call requested by the last message.

        Returns one ``ToolMessage`` per tool call, linked by ``tool_call_id``
        so the model can match results to requests. An unknown tool name
        does not raise; the model is told to retry instead.
        """
        tool_calls = state['messages'][-1].tool_calls
        results = []
        for t in tool_calls:
            if t['name'] not in self.tools:  # the LLM asked for a tool we don't have
                print("\n ....bad tool name....")
                result = "bad tool name, retry"  # instruct LLM to retry if bad
            else:
                result = self.tools[t['name']].invoke(t['args'])
            results.append(ToolMessage(tool_call_id=t['id'], name=t['name'], content=str(result)))
        return {'messages': results}

In [None]:
prompt = """You are a smart research assistant. Use the search engine to look up information.
You are allowed to make multiple calls (either together or in sequence).
Only look up information when you are sure of what you want.
If you need to look up some information before asking a follow un question, you are allowed to do that!"""

abot = Agent(model, [tool], system=prompt, checkpointer=memory)

In [None]:
# Stream the first turn. Because of interrupt_before=["action"], the run
# halts right before the Tavily search API would be called.
thread = {"configurable": {"thread_id": "1"}}
messages = [HumanMessage(content="How is the weather in SF?")]
for event in abot.graph.stream({"messages": messages}, thread):
    for node_update in event.values():
        print(node_update)

In [None]:
# Snapshot of the paused thread: its accumulated messages (`.values`) and the node(s) to run next (`.next`).
abot.graph.get_state(thread)

In [None]:
# Node(s) scheduled to run next; while paused by interrupt_before this is the "action" node.
abot.graph.get_state(thread).next

## Continue after the stop

In [None]:
# Passing None instead of new input resumes the paused thread from its last
# checkpoint, so the pending "action" node runs next.
for event in abot.graph.stream(None, thread):
    for node_update in event.values():
        print(node_update)

In [None]:
# After resuming, the messages now also include the tool result(s) and the model's next reply.
abot.graph.get_state(thread)

In [None]:
# Empty once the run has finished; non-empty if the model requested another tool call and the graph paused again.
abot.graph.get_state(thread).next

In [None]:
# Human-in-the-loop: stream until the interrupt, then ask for approval before
# each pending tool run. Answering "y"/"yes" (any case, surrounding spaces
# ignored) continues; anything else aborts.
messages = [HumanMessage("Whats the weather in Los Angeles?")]
thread = {"configurable": {"thread_id": "2"}}
for event in abot.graph.stream({"messages": messages}, thread):
    for v in event.values():
        print(v)

snapshot = abot.graph.get_state(thread)
while snapshot.next:  # empty once the graph has reached END
    print("\n", snapshot, "\n")
    _input = input("proceed?")
    if _input.strip().lower() not in {"y", "yes"}:
        print("aborting")
        break
    for event in abot.graph.stream(None, thread):
        for v in event.values():
            print(v)
    # Refresh after resuming: the model may have requested another tool call.
    snapshot = abot.graph.get_state(thread)

## Modify State

Run until the interrupt, then modify the state before resuming.

In [None]:
# Fresh thread: stream until the graph pauses before the tool call.
thread = {"configurable": {"thread_id": "3"}}
messages = [HumanMessage("Whats the weather in Los Angeles?")]
for event in abot.graph.stream({"messages": messages}, thread):
    for node_update in event.values():
        print(node_update)

In [None]:
# Snapshot of thread 3 while it waits before the "action" node.
abot.graph.get_state(thread)

In [None]:
# Keep the snapshot so its messages can be edited and then written back.
current_values = abot.graph.get_state(thread)

In [None]:
# Last message: the model's reply carrying the pending tool call(s).
current_values.values['messages'][-1]

In [None]:
# The pending tool call(s), each with a 'name', 'args' and 'id'.
current_values.values['messages'][-1].tool_calls

In [None]:
# Rewrite the pending tool call in place, keeping its id so the eventual
# ToolMessage still matches it. Use the bound tool's actual name: a
# hard-coded name such as 'tavily_search_results_json' does not match
# TavilySearch and would only hit the "bad tool name" branch of take_action.
_id = current_values.values['messages'][-1].tool_calls[0]['id']
current_values.values['messages'][-1].tool_calls = [
    {'name': tool.name,
     'args': {'query': 'current weather in Louisiana'},
     'id': _id}
]

In [None]:
# Write the edited messages back. The reducer matches messages by id, so the
# edited model reply replaces the original instead of being appended.
abot.graph.update_state(thread, current_values.values)

In [None]:
# Confirm the pending tool call now carries the edited query.
abot.graph.get_state(thread)

In [None]:
# Resume thread 3; the "action" node now runs the edited tool call.
for event in abot.graph.stream(None, thread):
    for node_update in event.values():
        print(node_update)

## Time Travel

In [None]:
# Collect every checkpoint of the thread (yielded newest first), printing
# each one as it arrives.
states = []
for snapshot in abot.graph.get_state_history(thread):
    print(snapshot)
    print('--')
    states.append(snapshot)

To fetch the same state as shown in the video, the offset below is -3 rather than -1. Newer versions of the library also store two extra checkpoints at the oldest end of the history: the initial `__start__` state and the first state after it. The checkpoint to replay is therefore now third from the end.

In [None]:
# Third-oldest checkpoint: the model's first reply, paused before "action".
to_replay = states[-3]
# The offset depends on how many checkpoints the library version stores, so
# fail loudly if it no longer lands on a checkpoint waiting to run "action".
assert to_replay.next == ("action",), f"unexpected replay checkpoint, next={to_replay.next}"

In [None]:
# Display the checkpoint chosen for replay.
to_replay

In [None]:
# Streaming from a past checkpoint's config re-executes the graph from that
# point onward.
for event in abot.graph.stream(None, to_replay.config):
    for node_update in event.values():
        print(node_update)

## Go back in time and edit

In [None]:
# The checkpoint that will be edited to create an alternative branch.
to_replay

In [None]:
# Edit the past checkpoint's tool call in place: keep its id so the tool
# result still matches, and use the bound tool's real name so take_action
# actually runs it instead of reporting a bad tool name.
_id = to_replay.values['messages'][-1].tool_calls[0]['id']
to_replay.values['messages'][-1].tool_calls = [
    {'name': tool.name,
     'args': {'query': 'current weather in LA, accuweather'},
     'id': _id}
]

In [None]:
# Write the edited values against the past checkpoint's config. The returned
# config points at the newly written checkpoint (a branch off `to_replay`),
# which the next cell resumes from.
branch_state = abot.graph.update_state(to_replay.config, to_replay.values)

In [None]:
# Resume from the branched checkpoint, skipping any "__end__" entry.
for event in abot.graph.stream(None, branch_state):
    for node_name, node_update in event.items():
        if node_name == "__end__":
            continue
        print(node_update)

## Add a message to the state at a given point in time

In [None]:
# Display the replay checkpoint again (its tool call was mutated in place by the edit cell above).
to_replay

In [None]:
# id of the pending tool call; the injected ToolMessage must reference it.
_id = to_replay.values['messages'][-1].tool_calls[0]['id']

In [None]:
state_update = {"messages": [ToolMessage(
    tool_call_id=_id,
    name="tavily_search_results_json",
    content="54 degree celcius",
)]}

In [None]:
# as_node="action" records the update as if the action node had produced it.
# Resuming therefore continues with the node that follows "action" ("llm"),
# and the search tool is never actually called.
branch_and_add = abot.graph.update_state(
    to_replay.config, 
    state_update, 
    as_node="action")

In [None]:
# Resume from the checkpoint that now holds the injected tool result.
for event in abot.graph.stream(None, branch_and_add):
    for node_update in event.values():
        print(node_update)