In [1]:
import os
import sys 

while "src" not in os.listdir(): os.chdir("..")

if "./src" not in sys.path: sys.path.append("./src")

In [2]:
from os import getenv

from uuid import uuid4

from langchain_openai import ChatOpenAI

from dotenv import load_dotenv

load_dotenv()

llm = ChatOpenAI(
  api_key=getenv("OPENAI_API_KEY"),
  base_url="https://openrouter.ai/api/v1",
  model="gpt-4o-mini",
)

In [3]:
import json

from uuid import uuid4

from typing import Annotated, Literal, List, Dict, Optional, Union, TypedDict

from langchain.tools import InjectedToolCallId, tool, ToolRuntime

from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage, HumanMessage, SystemMessage
from langchain_core.messages.tool import tool_call

from langgraph.checkpoint.memory import InMemorySaver

from langgraph.types import Command, interrupt, Send, Interrupt
from langgraph.graph import add_messages, StateGraph, START, END
from langgraph.prebuilt import ToolNode

In [4]:
def create_ai_message(msg:str) -> AIMessage:
    try:
        content = json.loads(msg)
        return AIMessage(content="calling a tool ... ", tool_calls=[tool_call(**tc) for tc in content])
    except:
        return AIMessage(content=msg)

class FakeMessagesListChatModelWithTools(FakeMessagesListChatModel):
    def bind_tools(self, tools:list):
        return self

In [7]:
ca_responses = [AIMessage(content=f"{i} Yay yay, nay nay!") for i in list(range(100))]
ca_verbose = True 

class ChatAgentState(TypedDict):
    messages: Annotated[List[AnyMessage], add_messages]
    human_wants_out: bool

ca_llm = FakeMessagesListChatModelWithTools(
    responses=ca_responses
)

def ca_llm_node(state: ChatAgentState):
    result = ca_llm.invoke(input=state["messages"])
    if ca_verbose: result.pretty_print()
    return {"messages": [result]}

def ca_interrupt_node(state: ChatAgentState):
    result = interrupt({"prompt": str(state["messages"][-1].content)})
    human_input_as_str = str(result["human_input"])
    human_message = HumanMessage(human_input_as_str)
    human_wants_out = human_input_as_str.strip().lower() in ("exit", "quit")
    if ca_verbose: human_message.pretty_print()
    return {
            "messages": [human_message], 
            "human_wants_out": human_wants_out
        }

def ca_loop_breaker(state: ChatAgentState):
    return END if state["human_wants_out"] else "ca_llm_node"

ca_builder = StateGraph(ChatAgentState)
ca_builder.add_node(ca_llm_node, "ca_llm_node")
ca_builder.add_node(ca_interrupt_node, "ca_interrupt_node")
ca_builder.add_edge(START, "ca_llm_node")
ca_builder.add_edge("ca_llm_node", "ca_interrupt_node")
ca_builder.add_conditional_edges("ca_interrupt_node", ca_loop_breaker)

ca = ca_builder.compile(InMemorySaver())

ca_config = {"configurable": {"thread_id": str(uuid4())}}
result = {"messages": [SystemMessage("You say yay or nay.")], "human_wants_out": False}
input_iter = iter(["foo", "bar", "exit"])
while True:
    result = ca.invoke(input=result, config=ca_config)
    if "__interrupt__" not in result: break 
    human_input = next(input_iter) # input(result["__interrupt__"][0].value["prompt"])
    result = Command(resume={"human_input": human_input})
        


0 Yay yay, nay nay!

foo

1 Yay yay, nay nay!

bar

2 Yay yay, nay nay!

exit


In [11]:
from copy import deepcopy
ca_verbose = False 

class MasterAgentState(TypedDict):
    messages: Annotated[List[AnyMessage], add_messages]
    ca_says_over: bool

def chat_llm_node(state: MasterAgentState):
    tc = tool_call(id=str(uuid4()), name="chat_agent_tool", args={"user_message": str(state["messages"][-1].content)})
    ai_message = AIMessage(content="calling tool ... ", tool_calls=[tc])
    return {"messages": [ai_message], "ca_says_over": False}

@tool
def chat_agent_tool(user_message:str="", resume_data: Union[Dict, str]="", runtime: ToolRuntime=None):
    """This call the specialized chat agent with user's query. 
    Put an empty string as resume_data."""
    if resume_data:
        # this resume_data must be superset of resume expected by chat agent
        # for fields like "config"
        human_or_system_message = HumanMessage(resume_data["human_input"])
        result = Command(resume={k: v for k, v in resume_data.items() if k != "config"}) # subsequent ones
        _config = resume_data["config"]
    else: 
        result = {"messages": runtime.state["messages"], "human_wants_out": False} # first one
        _config = {"configurable": {"thread_id": str(uuid4())}}
        human_or_system_message = result["messages"][-1]

    result = ca.invoke(result, config=_config)
    
    human_or_system_message.pretty_print()
    tool_message = ToolMessage("Success", tool_call_id=runtime.tool_call_id)
    messages = [tool_message, human_or_system_message]
    
    # this indicates the chat agent is done with chatting (no AIMessage)
    if result["human_wants_out"]: return Command(update={"messages": messages, "ca_says_over": True})

    # two possibilities here, interrupt or no interrupt
    # since chat agent may not done yet, we need to interrupt regardless
    if "__interrupt__" in result.keys():
        # in this case, there will be no AIMessage as well
        interrupt_data = {"__interrupt__": result["__interrupt__"][0]}
    else:
        ai_message = result["messages"][-1]
        ai_message.pretty_print()
        messages.append(ai_message)
        interrupt_data = {"__interrupt__": Interrupt({"prompt": str(ai_message.content)}), "id": str(uuid4())}
    
    update = {"messages": messages}
    interrupt_data["config"] = _config

    return Command(goto=Send("chat_agent_interrupt_node", arg={"interrupt_data": interrupt_data} | update))

chat_agent_tool_node = ToolNode(tools=[chat_agent_tool])

def chat_agent_interrupt_node(state: MasterAgentState):
    interrupt_data = state["interrupt_data"]
    human_input = interrupt(interrupt_data["__interrupt__"])
    resume_data = {
            "human_input": human_input,
            "config": interrupt_data["config"]
        }
    tc = tool_call(id = str(uuid4()), name="chat_agent_tool", args={"user_message": "", "resume_data": resume_data})
    ai_message = AIMessage(content="calling tool ... ", tool_calls=[tc]) 
    # return Command(goto=Send("chat_agent_tool_node", arg={"state": {"messages": [ai_message]}}), update={"messages": [ai_message]})
    return Command(goto="chat_agent_tool_node", update={"messages": [ai_message]})

ma_builder = StateGraph(MasterAgentState)
ma_builder.add_node("chat_llm_node", chat_llm_node)
ma_builder.add_node("chat_agent_tool_node", chat_agent_tool_node)
ma_builder.add_node("chat_agent_interrupt_node", chat_agent_interrupt_node)
ma_builder.add_edge(START, "chat_llm_node")
ma_builder.add_edge("chat_llm_node", "chat_agent_tool_node")
ma_builder.add_edge("chat_agent_tool_node", END)
ma = ma_builder.compile(InMemorySaver())


ma_config = {"configurable": {"thread_id": str(uuid4())}}
result = {"messages": [SystemMessage("MA System.")], "ca_says_over": False}
input_iter = iter(["foo foo", "bar bar", "exit"])
while True:
    result = ma.invoke(input=result, config=ma_config)
    if result["ca_says_over"]: break 
    human_input = next(input_iter) # input(result["__interrupt__"][0].value["prompt"])
    result = Command(resume=human_input)
        


calling tool ...
Tool Calls:
  chat_agent_tool (38f7a22b-a996-4b3e-ab18-56fff1b3e190)
 Call ID: 38f7a22b-a996-4b3e-ab18-56fff1b3e190
  Args:
    user_message: MA System.

foo foo

bar bar

exit


In [None]:
ma_config = {"configurable": {"thread_id": str(uuid4())}}
result = ma.invoke(input={"messages": [SystemMessage("User want to chat.")]}, config=ma_config)
result


calling tool ...
Tool Calls:
  chat_agent_tool (21de8d0e-0457-40ef-b9f8-0d15ad54c111)
 Call ID: 21de8d0e-0457-40ef-b9f8-0d15ad54c111
  Args:
    user_message: User want to chat.


{'messages': [SystemMessage(content='User want to chat.', additional_kwargs={}, response_metadata={}, id='557ae29a-b835-4923-b5ce-c10667f5c96a'),
  AIMessage(content='calling tool ... ', additional_kwargs={}, response_metadata={}, id='892b6dcd-8d8a-4e1d-bcbf-eb32c1b1ef18', tool_calls=[{'name': 'chat_agent_tool', 'args': {'user_message': 'User want to chat.'}, 'id': '21de8d0e-0457-40ef-b9f8-0d15ad54c111', 'type': 'tool_call'}]),
  ToolMessage(content='Success', name='chat_agent_tool', id='a74d6f77-837e-44e3-8350-43f4dd332c99', tool_call_id='21de8d0e-0457-40ef-b9f8-0d15ad54c111')],
 'ca_says_over': False,
 '__interrupt__': [Interrupt(value=Interrupt(value={'prompt': '5 Yay yay, nay nay!'}, id='0dbd6d4fe7e089550a7db8dfae7de88f'), id='c732f8ae3966c968ea963dfebeb2a272')]}

In [None]:
result["__interrupt__"][0].value

Interrupt(value={'prompt': '11 Yay yay, nay nay!'}, id='1c42bff0584b89abc348a13f5ddd323c')

In [None]:
example_conversation = """
A: Hi, I am a ticketing agent. How can I help you?
Q: Hello, I want to go to Bali. Can you help me book a flight?
A: [{"id": "tkt_1", "name": "ask_for_help_tkt", "args": {"question": "From where, Human? Round-trip or one-way?"}}]
Q: From Sydney, Australia. Round-trip. 
A: [{"id": "tkt_2", "name": "check_ticket_price", "args": {"origin": "Sydney", "destination": "Bali", "round_trip": true}}, {"id": "tkt_3", "name": "ask_for_help_tkt", "args": {"question": "How many heads, Human?"}}]
Q: Just one. How much would it be?
A: [{"id": "tkt_4", "name": "calculate_total", "args": {"quantity": 1, "unit_price": 1000.99}}]
A: It will be $1000.99, ok to proceed?
Q: Yes, please book the flight.
A: Great! Your flight from Sydney to Bali has been booked. Safe travels!
Q: exit
"""

human_questions = [qim[3:] for qim in example_conversation.split("\n") if qim.startswith("Q: ") or len(qim.strip())==0]
ai_responses = [aim[3:] for aim in example_conversation.split("\n") if aim.startswith("A: ")]

responses=[create_ai_message(msg) for msg in ai_responses]

fake_model = FakeMessagesListChatModelWithTools(
    responses=responses
)

# ticketing_agent = create_ticketing_agent(fake_model, InMemorySaver(), agent_tools=[])

# input_iter = iter(human_questions)
# config={"configurable": {"thread_id": str(uuid4()) }}
# while True:    
#     content = next(input_iter) # input("Your response (or 'exit' to quit): ")
#     human_message = HumanMessage(content=content)
#     human_message.pretty_print()
#     if content.lower() in ['exit', 'quit']:
#         AIMessage(content="Goodbye!").pretty_print()
#         break
#     response = ticketing_agent.invoke({"messages": [human_message], "loop_counter": 0}, config=config)
#     while "__interrupt__" in response:
#         print ("in interrupt")
#         response["messages"][-1].pretty_print()
#         human_response = next(input_iter) # input (str(response["messages"][-1].content) + "\nYour response: ")
#         HumanMessage(content=human_response).pretty_print()
#         response = ticketing_agent.invoke(Command(resume=response), config=config)

#     response["messages"][-1].pretty_print()

In [None]:
import json 
from uuid import uuid4
from typing import Annotated
from langchain.tools import InjectedToolCallId
# from langchain_core.tools import tool
from langchain.tools import ToolRuntime, tool
from langchain_core.messages import HumanMessage, AnyMessage, SystemMessage, ToolMessage
from langgraph.types import Command, interrupt, Send
from langgraph.graph import add_messages
from typing import Literal

from langgraph.prebuilt import ToolNode