In [None]:
!pip install --quiet langgraph langchain langchain-ollama

In [None]:
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from typing import List, Literal, Annotated
from langgraph.graph.message import add_messages
import uuid
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_ollama import ChatOllama
import os
from getpass import getpass
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
import json

In [None]:
# System prompt for the booking assistant. get_messages_info prepends it (as a
# SystemMessage) to the conversation on every LLM call made by info_chain.
# It asks the model to collect four fields and then call the bound tool
# (book_flight). The text is sent verbatim to the model, so edits here change
# the agent's behavior.
template = """Your job is to book a flight for a user based on user input.

You should get the following information from them:

1. What the departure city is
2. What the arrival city is
3. What the date of travel is 
4. What the legal name of the user is

If you are not able to discern this info, ask them to clarify! Do not attempt to wildly guess.

After you are able to discern all the information, call the relevant tool."""

In [None]:
def get_messages_info(messages):
    """Build the LLM prompt: the booking system instructions, then the chat history.

    Args:
        messages: the conversation so far (the graph state's message list).

    Returns:
        A new list whose first element is a SystemMessage holding ``template``;
        ``messages`` itself is not modified.
    """
    system_prompt = SystemMessage(content=template)
    return [system_prompt] + messages

In [None]:
def info_chain(state):
    """Graph node: run one turn of the tool-enabled LLM.

    The system prompt is prepended to the conversation stored in
    ``state["messages"]``, and the model's single reply is returned as a
    partial state update. LangGraph merges it into the history through the
    ``add_messages`` reducer declared on ``State``.
    """
    reply = llm_with_tool.invoke(get_messages_info(state["messages"]))
    return {"messages": [reply]}

In [None]:
class State(TypedDict):
    """Graph state shared by every node: the running chat history."""

    # add_messages is the reducer for this key: a node returning
    # {"messages": [...]} has its messages merged into the existing list
    # instead of overwriting it.
    messages: Annotated[list, add_messages]

In [None]:
@tool
def book_flight(from_city: str, to_city: str, travel_date: str, passenger_name: str) -> str:
    """Book a flight for the customer. Call this whenever you need to book a flight, for example when a customer asks 'I want to book a flight from Los Angeles to New York'
    Args:
        from_city: The departure city
        to_city: The arrival city
        travel_date: The date of travel
        passenger_name: The passenger's legal name.
    """
    # The docstring above is used by langchain's @tool as the tool description
    # shown to the model, so its wording is deliberately left unchanged.
    # This is a stub: nothing is actually booked. The collected fields are
    # echoed back as a JSON object (key order: from_city, to_city,
    # travel_date, passenger_name).
    booking = dict(
        from_city=from_city,
        to_city=to_city,
        travel_date=travel_date,
        passenger_name=passenger_name,
    )
    print("book_flight is called")
    return json.dumps(booking)

In [None]:
def route_tools(state) -> Literal["tools", "end", "info"]:
    """
    Decide which step follows the "info" (LLM) node.

    The returned key is looked up in the path map passed to
    ``add_conditional_edges``.

    Returns:
        "tools" -- the last message is an AIMessage that requested tool calls,
                   so the ToolNode should execute them.
        "info"  -- the last message is a HumanMessage, so the LLM should reply.
        "end"   -- anything else (a plain AI reply, a tool result, or an empty
                   history): the turn is over and control returns to the user.
    """
    messages = state["messages"]
    if not messages:
        # Nothing to route on; end the turn instead of raising IndexError.
        return "end"
    last_message = messages[-1]

    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tools"
    if isinstance(last_message, HumanMessage):
        return "info"
    return "end"

In [None]:
# Ollama connection settings. Set OLLAMA_BASE_URL / OLLAMA_MODEL in the
# environment instead of editing the notebook; the defaults reproduce the
# original setup. "gemma3:4b" was the previously commented-out alternative.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://10.8.4.240:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "llama3.1:8b")

llm = ChatOllama(model=OLLAMA_MODEL, base_url=OLLAMA_BASE_URL)
tools = [book_flight]
# Bind the tool schemas so the model can emit tool calls for book_flight.
llm_with_tool = llm.bind_tools(tools)

In [None]:
# Wire up the agent graph. The LLM node ("info") always runs first;
# route_tools then picks the next step, and a tool run ends the turn.
memory = InMemorySaver()  # in-process checkpointer; conversations are keyed by thread_id
workflow = StateGraph(State)

node_table = {"info": info_chain, "tools": ToolNode(tools)}
for node_name, node_impl in node_table.items():
    workflow.add_node(node_name, node_impl)

# Maps each route_tools return value to the node (or END) that runs next.
next_step_after_info = {"tools": "tools", "end": END, "info": "info"}
workflow.add_edge(START, "info")
workflow.add_conditional_edges("info", route_tools, next_step_after_info)
workflow.add_edge("tools", END)

graph = workflow.compile(checkpointer=memory)

In [None]:
def _message_text(message):
    """Return the printable text of a message streamed from the graph.

    String content is returned as-is; list content (content blocks) is reduced
    to its "text" parts joined by spaces; objects without a ``content``
    attribute fall back to ``str(message)``.
    """
    if not hasattr(message, "content"):
        return str(message)
    content = message.content
    if isinstance(content, list):
        return " ".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and "text" in part
        )
    return content


# Scripted conversation replayed when the kernel has no interactive stdin
# (e.g. headless execution). It ends with "q" so the loop terminates.
cached_human_responses = [
    "hi!",
    "I want to fly from Los Angeles to New York",
    "On 14 March 2025",
    "My legal name is Jane Doe",
    "q",
]
cached_response_index = 0
# A fresh thread_id gives each run of this cell its own checkpointed history.
config = {"configurable": {"thread_id": str(uuid.uuid4())}}

while True:
    try:
        user = input("User (q/Q to quit): ")
    except (EOFError, NotImplementedError):
        # EOFError: stdin is closed. NotImplementedError: IPython raises
        # StdinNotImplementedError (a NotImplementedError subclass) when the
        # frontend cannot serve input(). Unlike a bare except, this lets
        # KeyboardInterrupt stop the loop.
        if cached_response_index < len(cached_human_responses):
            user = cached_human_responses[cached_response_index]
            cached_response_index += 1
        else:
            user = "q"  # scripted conversation exhausted
    user = user.strip()
    if user in {"q", "Q"}:
        print("AI: Bye bye")
        break
    if not user:
        continue  # don't send empty turns to the model

    for update in graph.stream(
        {"messages": [HumanMessage(content=user)]},
        config=config,
        stream_mode="updates",
    ):
        # In "updates" mode each chunk maps a node name to the state delta that
        # node returned; print the newest message from each delta.
        for node_update in update.values():
            new_messages = node_update.get("messages") if isinstance(node_update, dict) else None
            if not new_messages:
                continue
            text = _message_text(new_messages[-1])
            # Tool-calling AI messages usually have empty content; skip them
            # instead of printing a blank "AI:" line.
            if text:
                print("AI:", text)