# Customize State

- Add additional fields to the state
- The chatbot can access tools and forward the results to a human for review.
- This requires a relatively large LLM.

In [1]:
# Built-in library
import asyncio
import json
import logging
import re
import warnings
from pathlib import Path
from pprint import pprint
from typing import (
    Annotated,
    Any,
    Generator,
    Iterable,
    Literal,
    Optional,
    TypedDict,
    Union,
)

# Standard imports
import nest_asyncio
import numpy as np
import numpy.typing as npt
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

# Rich console with a custom colour theme; the style names registered here
# ("white", "info", "warning", "error", "success", "highlight") can be used as
# rich styles, e.g. console.print("done", style="success").
custom_theme = Theme(
    {
        "white": "#FFFFFF",  # Bright white
        "info": "#00FF00",  # Bright green
        "warning": "#FFD700",  # Bright gold
        "error": "#FF1493",  # Deep pink
        "success": "#00FFFF",  # Cyan
        "highlight": "#FF4500",  # Orange-red
    }
)
console = Console(theme=custom_theme)

# Visualization (optional)
# import matplotlib.pyplot as plt

# NumPy settings: print floating-point values with 4 digits of precision.
np.set_printoptions(precision=4)

# Pandas settings: display up to 1,000 rows / columns and up to 600 characters
# per cell before truncating.
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings: display up to 1,000 characters per string value and up to
# 1,000 table columns.
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)

# Silence all Python warnings to keep notebook output clean.
warnings.filterwarnings("ignore")

# Black code formatter (Optional) -- IPython magic, only valid inside a notebook.
%load_ext lab_black

# Auto-reload imported modules before executing each cell (mode 2 = all modules).
%load_ext autoreload
%autoreload 2

In [2]:
def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """Prepend an ancestor of the current working directory to ``sys.path``.

    This lets a notebook nested inside a project import the project's
    top-level modules. The resolved directory is printed.

    Params:
    -------
    go_up: int, default=1
        Number of directory levels to climb from the current working
        directory. ``0`` adds the current working directory itself.

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``go_up`` is negative.
    """
    import os
    import sys

    if go_up < 0:
        raise ValueError(f"go_up must be non-negative, got {go_up}")

    # Resolve relative to the current working directory explicitly: notebooks
    # have no `__file__`, and `__name__` is a module name, not a path.
    target_dir = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * go_up)))

    # Move the directory to the front of the search path without leaving
    # duplicates behind, so re-running the cell does not keep growing sys.path.
    while target_dir in sys.path:
        sys.path.remove(target_dir)
    sys.path.insert(0, target_dir)
    print(target_dir)

In [3]:
# Put the directory two levels above the current working directory at the
# front of sys.path so the project-local modules below become importable.
go_up_from_current_directory(go_up=2)


# E402 ("module level import not at top of file") is silenced on purpose:
# these imports only resolve after the sys.path change above.
from schemas import ModelEnum  # noqa: E402
from settings import refresh_settings  # noqa: E402
from utilities.client_utils import check_rate_limit  # noqa: E402

# Project settings object; later cells read secrets from it
# (e.g. settings.OPENROUTER_API_KEY). NOTE(review): where refresh_settings
# loads values from (env vars / .env file) is not visible here — confirm.
settings = refresh_settings()

/Users/neidu/Desktop/Projects/Personal/My_Projects/AI-Tutorials


In [4]:
from langgraph.graph.message import add_messages


class State(TypedDict):
    """Graph state shared by every node.

    Extends the basic chat state with ``name`` and ``birthday`` fields, which
    the ``human_assistance`` tool writes after human review.
    """

    # Conversation history. `add_messages` is the reducer, so message updates
    # returned by nodes are merged into the list instead of replacing it.
    messages: Annotated[list, add_messages]
    # Name confirmed or corrected through human review.
    name: str
    # Birthday (here: LangGraph's release date) confirmed or corrected through
    # human review.
    birthday: str

### Update The States Inside The Tool

In [None]:
from IPython.display import Image, display
from langchain.chat_models import init_chat_model
from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langchain_tavily import TavilySearch
from langfuse.callback import CallbackHandler
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langchain_litellm import ChatLiteLLM
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command, interrupt


# Note: Because we're generating a ToolMessage for a state update, we generally require
# the ID of the corresponding tool call. We can use LangChain's InjectedToolCallId to
# signal that this argument should not be revealed to the model in the tool's schema.
@tool
def human_assistance(
    name: str, birthday: str, tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """Request assistance from a human."""
    # NOTE: the docstring above is sent to the model as the tool description,
    # so it is deliberately kept short and unchanged.
    #
    # Pause the graph and show the model's proposal to a human. When the graph
    # is resumed with `Command(resume=...)`, `interrupt` returns that payload.
    human_response = interrupt(
        {
            "question": "Is this correct?",
            "name": name,
            "birthday": birthday,
        }
    )
    # Accept a dict payload ({"correct": "yes"} or {"name": ..., "birthday": ...})
    # as well as a bare answer such as "yes".
    if isinstance(human_response, dict):
        review = human_response
    else:
        review = {"correct": human_response}

    # `or ""` guards against an explicit `None` for "correct".
    confirmation = str(review.get("correct") or "")
    if confirmation.strip().lower().startswith("y"):
        # The human confirmed the model's proposal.
        verified_name = name
        verified_birthday = birthday
        response = "Correct"
    else:
        # Apply the human's corrections. Fields the human did not supply keep
        # the model's proposed value instead of being blanked out.
        verified_name = review.get("name", name)
        verified_birthday = review.get("birthday", birthday)
        response = f"Made a correction: {human_response}"

    # Update the custom state fields and record the tool result in the message
    # history; the ToolMessage must reference the originating tool call.
    state_update = {
        "name": verified_name,
        "birthday": verified_birthday,
        "messages": [ToolMessage(response, tool_call_id=tool_call_id)],
    }
    return Command(update=state_update)

In [None]:
# from langchain_litellm import ChatLiteLLM

# LLM = ChatLiteLLM(
#     model=f"openrouter/{ModelEnum.QWEN_3p0_32B_REMOTE_FREE.value}",
#     temperature=0.1,
#     openrouter_api_key=settings.OPENROUTER_API_KEY.get_secret_value(),
# )

# LLM.invoke("Who are you?").content

"Hello! I am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with a wide range of tasks, including answering questions, creating text (like writing stories, official documents, emails, and scripts), performing logical reasoning, coding, and more. I support multiple languages and am designed to be helpful, harmless, and honest. How can I assist you today?"

In [None]:
async def chatbot(state: State) -> dict[str, Any]:
    """Run the tool-enabled LLM on the conversation and return its reply.

    Parameters
    ----------
    state : State
        Current graph state; only ``state["messages"]`` is read.

    Returns
    -------
    dict[str, Any]
        ``{"messages": [message]}`` where ``message`` is the LLM response. The
        ``add_messages`` reducer on ``State.messages`` merges it into history.

    Raises
    ------
    RuntimeError
        If the LLM requested more than one tool call in a single message.

    Notes
    -----
    This node does not disable parallel tool calling at the model level; it
    rejects responses containing several tool calls. Because a tool may
    interrupt the graph (human-in-the-loop), multiple tool calls in one step
    could be re-executed when the graph is resumed.
    """
    # `llm_with_tools` is a module-level global (`llm.bind_tools(tools)`).
    message = await llm_with_tools.ainvoke(state["messages"])
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    num_tool_calls = len(message.tool_calls)
    if num_tool_calls > 1:
        raise RuntimeError(
            f"Expected at most one tool call per message, got {num_tool_calls}."
        )
    return {"messages": [message]}


# A simple in-memory checkpointer for this tutorial. In production,
# it's recommended to use SqliteSaver or PostgresSaver.
memory = MemorySaver()

model_str: str = "mistralai:mistral-large-latest"  # "mistralai:ministral-8b-latest"
llm = init_chat_model(model_str)
tavily_search = TavilySearch(max_results=2)
tools = [tavily_search, human_assistance]
llm_with_tools = llm.bind_tools(tools)

# Add nodes
graph_builder = StateGraph(State)
tool_node = ToolNode(tools=tools)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)

# Connect nodes. `tools_condition` routes to "tools" when the last AI message
# contains tool calls and to END otherwise, so no explicit chatbot -> END edge
# is needed.
graph_builder.add_edge(START, "chatbot")
graph_builder.add_conditional_edges("chatbot", tools_condition)
graph_builder.add_edge("tools", "chatbot")

# Compile with the checkpointer created above and attach Langfuse tracing.
langfuse_handler = CallbackHandler()
graph = graph_builder.compile(checkpointer=memory).with_config(
    {"callbacks": [langfuse_handler]}
)

### Prompt The Chatbot

- Prompt the chatbot to look up when LangGraph was released (its `birthday`) and direct it to call the `human_assistance` tool once it has the required info.
- Because `name` and `birthday` are required arguments of the tool, the chatbot is forced to propose values for these fields when it calls the tool.

In [10]:
user_input: str = (
    "Can you look up when LangGraph was released? When you "
    "have the answer, use the human_assistance tool for review."
)
config = {"configurable": {"thread_id": "1"}}

events = graph.astream(
    {"messages": [{"role": "user", "content": user_input}]},
    config=config,
    stream_mode="values",
)
async for event in events:
    if "messages" in event:
        event["messages"][-1].pretty_print()


Can you look up when LangGraph was released? When you have the answer, use the human_assistance tool for review.
Tool Calls:
  tavily_search (4W56yVM4L)
 Call ID: 4W56yVM4L
  Args:
    query: LangGraph release date
Name: tavily_search

{"query": "LangGraph release date", "follow_up_questions": null, "answer": null, "images": [], "results": [{"title": "Announcing LangGraph v0.1 & LangGraph Cloud: Running agents at scale ...", "url": "https://blog.langchain.dev/langgraph-cloud/", "content": "Announcing LangGraph v0.1 & LangGraph Cloud: Running agents at scale, reliably Our new infrastructure for running agents at scale, LangGraph Cloud, is available in beta. Separate from the langchain package, LangGraph’s core design philosophy is to help developers add better precision and control into agentic workflows, suitable for the complexity of real-world systems. LangGraph Cloud, currently in closed beta, is infrastructure for deploying your LangGraph agents in a scalable, fault tolerant way. 

### Add Human Assistance

- The chatbot did not propose the correct values (see the `human_assistance` arguments in the output below).
- We'll supply the correct information to it.

In [11]:
human_command = Command(
    resume={
        "name": "LangGraph",
        "birthday": "Jan 17, 2024",
    }
)

events = graph.astream(human_command, config, stream_mode="values")
async for event in events:
    if "messages" in event:
        event["messages"][-1].pretty_print()

Tool Calls:
  human_assistance (FUCAd4Bf5)
 Call ID: FUCAd4Bf5
  Args:
    name: Joshua
    birthday: 1990-01-01
Name: human_assistance

Made a correction: {'name': 'LangGraph', 'birthday': 'Jan 17, 2024'}

LangGraph was released on January 17, 2024.


In [12]:
# The state has been updated to reflect the human's input: show only the
# custom fields written by the human_assistance tool.
wanted_keys = ("name", "birthday")
snapshot = graph.get_state(config)
{key: value for key, value in snapshot.values.items() if key in wanted_keys}

{'name': 'LangGraph', 'birthday': 'Jan 17, 2024'}

## Manually Update The State

In [13]:
# Manually overwrite the `name` field for thread "1"; the call returns the
# config of the newly written checkpoint (displayed as the cell output).
graph.update_state(config, {"name": "LangGraph (library)"})

{'configurable': {'thread_id': '1',
  'checkpoint_ns': '',
  'checkpoint_id': '1f038a71-4a83-6d0c-8006-c92bc4c6881b'}}

### View The New State

In [14]:
# Re-read the latest checkpoint to confirm the manual update took effect.
snapshot = graph.get_state(config)
{
    field: snapshot.values[field]
    for field in snapshot.values
    if field in ("name", "birthday")
}

{'name': 'LangGraph (library)', 'birthday': 'Jan 17, 2024'}

<br>

## Time Travel

- Chatbot workflows typically involve users interacting with a bot to complete a task.
- Memory and human-in-the-loop features help manage the bot's state and responses.
- LangGraph's time travel functionality allows users to restart from previous points, explore different outcomes, or correct mistakes.

- This continues from the previous section (customizing the state); the graph is rebuilt below with a fresh in-memory checkpointer.

In [None]:
# Rebuild the graph with a fresh in-memory checkpointer so the time-travel
# demo starts from an empty history.
memory = MemorySaver()

model_str: str = "mistralai:mistral-large-latest"  # "mistralai:ministral-8b-latest"
llm = init_chat_model(model_str)
tavily_search = TavilySearch(max_results=2)
tools = [tavily_search, human_assistance]
llm_with_tools = llm.bind_tools(tools)

# Add nodes
graph_builder = StateGraph(State)
tool_node = ToolNode(tools=tools)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)

# Connect nodes. `tools_condition` routes to "tools" when the last AI message
# contains tool calls and to END otherwise, so no explicit chatbot -> END edge
# is needed.
graph_builder.add_edge(START, "chatbot")
graph_builder.add_conditional_edges("chatbot", tools_condition)
graph_builder.add_edge("tools", "chatbot")

# Compile with the checkpointer created above and attach Langfuse tracing.
langfuse_handler = CallbackHandler()
graph = graph_builder.compile(checkpointer=memory).with_config(
    {"callbacks": [langfuse_handler]}
)

### Add Steps

In [None]:
user_input: str = "I'm learning LangGraph. Could you do some research on it for me?"
config = {"configurable": {"thread_id": "1"}}

events = graph.astream(
    {"messages": [{"role": "user", "content": user_input}]},
    config=config,
    stream_mode="values",
)
async for event in events:
    if "messages" in event:
        # event["messages"][-1].pretty_print()
        console.log(event["messages"][-1])

In [None]:
user_input: str = "Yes, that's very helpful. Maybe I'll build an autonomous agent with it."
config = {"configurable": {"thread_id": "1"}}

events = graph.astream(
    {"messages": [{"role": "user", "content": user_input}]},
    config=config,
    stream_mode="values",
)
async for event in events:
    if "messages" in event:
        # event["messages"][-1].pretty_print()
        console.log(event["messages"][-1])

<br>

### Replay The Full State History

- You can `replay` the full state history to see everything that occurred in the graph.

In [31]:
# Checkpoint to replay later; chosen by message count below.
to_replay = None

for state in graph.get_state_history(config):
    num_messages = len(state.values.get("messages", []))
    print(f"Num Messages: {num_messages}, Next: {state.next}")
    print("-" * 80)
    # Select an arbitrary state to replay. History is listed newest-first, so
    # when several checkpoints match, the oldest one is kept.
    if num_messages == 6:
        to_replay = state

# Fail here with a clear message instead of an AttributeError on `None` in the
# next cell; the message count depends on what the LLM actually produced.
if to_replay is None:
    raise ValueError(
        "No checkpoint with 6 messages found in this thread's history; "
        "choose a message count from the list printed above."
    )

Num Messages: 4, Next: ()
--------------------------------------------------------------------------------
Num Messages: 3, Next: ('chatbot',)
--------------------------------------------------------------------------------
Num Messages: 6, Next: ()
--------------------------------------------------------------------------------
Num Messages: 5, Next: ('chatbot',)
--------------------------------------------------------------------------------
Num Messages: 4, Next: ('__start__',)
--------------------------------------------------------------------------------
Num Messages: 4, Next: ()
--------------------------------------------------------------------------------
Num Messages: 3, Next: ('chatbot',)
--------------------------------------------------------------------------------
Num Messages: 2, Next: ('tools',)
--------------------------------------------------------------------------------
Num Messages: 1, Next: ('chatbot',)
----------------------------------------------------------

<br>

### Resume From A Checkpoint

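- Resuming from a checkpoint runs the node(s) listed in that checkpoint's `next` attribute; if `next` is empty, the run had already finished at that checkpoint and nothing new is executed.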

In [32]:
# Show which node(s) would run next and the config pinning this checkpoint.
for attribute in ("next", "config"):
    print(getattr(to_replay, attribute))

()
{'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1f038a8d-1d4c-66e2-8006-24cbf3f0d6c9'}}


<br>

### Load A State From A Moment In Time

- `to_replay.config` contains a `checkpoint_id`, a time-ordered identifier for that checkpoint.
- Passing this config tells LangGraph's checkpointer to load the state as it was at that moment and continue from there.

In [33]:
async for event in graph.astream(None, to_replay.config, stream_mode="values"):
    if "messages" in event:
        # event["messages"][-1].pretty_print()
        console.log(event["messages"][-1].content)