In [None]:
# Reload edited Python modules automatically before each cell runs.
# NOTE(review): this notebook imports no local modules, so autoreload likely
# has no effect here — confirm it is still needed.
%load_ext autoreload
%autoreload 2

In [None]:
# nest_asyncio.apply() patches asyncio so an event loop can be re-entered,
# e.g. from inside Jupyter's already-running loop.
# NOTE(review): no cell below calls asyncio directly; presumably a library
# call (e.g. graph rendering or model invocation) needs it — confirm.
import nest_asyncio

nest_asyncio.apply()

# Imports

In [None]:
from dotenv import load_dotenv
from IPython.display import Image, display
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    RemoveMessage,
    SystemMessage,
)
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph

# Populate os.environ from a local .env file without overwriting variables
# that are already set (presumably supplies OPENAI_API_KEY for ChatOpenAI).
load_dotenv(override=False)

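# Memory

Compiling the graph with a `MemorySaver` checkpointer stores the conversation state per thread, keyed by the `thread_id` passed in `config`. That state can be inspected afterwards with `agent.get_state`.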

In [None]:
# Checkpointer for the first graph; it stores conversation state per thread_id.
memory = MemorySaver()
# Non-zero temperature, presumably so the "random numbers" prompt below varies.
model = ChatOpenAI(temperature=1, model="gpt-4.1-mini")


def chat_model_node(state: MessagesState):
    """Send the whole message history to ``model`` and return its reply as a state update."""
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}


# Build graph: a single chat_model node between START and END
graph = StateGraph(MessagesState)
graph.add_node("chat_model", chat_model_node)
graph.add_edge(START, "chat_model")
graph.add_edge("chat_model", END)

# Attach the checkpointer so invocations sharing a thread_id share state
agent = graph.compile(checkpointer=memory)

# View
display(Image(agent.get_graph().draw_mermaid_png()))

In [None]:
# Every invocation with this config reads and writes thread "1" in the checkpointer.
config = {"configurable": {"thread_id": "1"}}
messages = [HumanMessage(content="Get 3 random numbers")]
# Keep the returned graph state under its own name rather than rebinding
# `messages` (a list of messages) to a state dict.
result = agent.invoke({"messages": messages}, config=config)

for m in result["messages"]:
    m.pretty_print()

In [None]:
# Inspect the checkpointed state stored for thread "1"
agent.get_state(config)

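## Remove messages

The `filter` node returns a `RemoveMessage` for every message except the last two. Those older messages are deleted from the graph state before `chat_model` runs.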

In [None]:
# Replace the temperature=1 model with a temperature=0 one; the nodes below
# look up the global `model` each time they run.
model = ChatOpenAI(temperature=0, model="gpt-4.1-mini")


# Nodes
def filter_messages(state: MessagesState):
    """Delete every message except the two most recent from the graph state.

    Returns one ``RemoveMessage`` per older message, identified by its id.
    """
    older = state["messages"][:-2]
    return {"messages": [RemoveMessage(id=msg.id) for msg in older]}


def chat_model_node(state: MessagesState):
    """Invoke ``model`` on the current messages and return its reply.

    In this graph it runs after ``filter``, so it sees at most the two most
    recent messages.
    """
    latest = state["messages"]
    return {"messages": [model.invoke(latest)]}


# Build graph: START -> filter -> chat_model -> END (no checkpointer)
graph = StateGraph(MessagesState)
graph.add_node("filter", filter_messages)
graph.add_node("chat_model", chat_model_node)
graph.add_edge(START, "filter")
graph.add_edge("filter", "chat_model")
graph.add_edge("chat_model", END)
agent = graph.compile()

# View
display(Image(agent.get_graph().draw_mermaid_png()))

In [None]:
# Message list with a preamble: a greeting exchange before the real question.
# The ids are arbitrary string labels (not an ordering); RemoveMessage in the
# filter node matches messages by id.
messages = [
    HumanMessage(content="Hi.", name="Dylan", id="2"),
    AIMessage(content="Hi. How can I help you today?", name="Bot", id="1"),
    HumanMessage(
        content="Can you write a detailed explanation of the theory of relativity?",
        name="Dylan",
        id="4",
    ),
]

output = agent.invoke({"messages": messages})
for msg in output["messages"]:
    msg.pretty_print()

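## Trim messages

Instead of deleting messages from the state, `trim_messages` shortens only the list sent to the model. It keeps the most recent messages that fit within `max_tokens`.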

In [None]:
from langchain_core.messages import trim_messages


# Token budget for the history passed to the model; presumably kept small so
# that even the short demo conversation below gets trimmed.
MAX_HISTORY_TOKENS = 25

# Used only as the token counter for trim_messages. Built once here instead of
# constructing a new ChatOpenAI client on every node invocation.
token_counter = ChatOpenAI(model="gpt-4o")


# Node
def chat_model_node(state: MessagesState):
    """Call the model on only the most recent messages that fit the token budget.

    Trimming affects only the list sent to ``model``: the node returns just the
    reply (no ``RemoveMessage``), so the stored history is left untouched.

    Returns:
        A state update containing the model's reply.
    """
    messages = trim_messages(
        state["messages"],
        max_tokens=MAX_HISTORY_TOKENS,
        strategy="last",
        token_counter=token_counter,
        allow_partial=True,
    )
    return {"messages": [model.invoke(messages)]}


# Build graph: the trimming chat_model node is the only step
graph = StateGraph(MessagesState)
graph.add_node("chat_model", chat_model_node)
graph.add_edge(START, "chat_model")
graph.add_edge("chat_model", END)
agent = graph.compile()  # no checkpointer: each invoke starts from its input

# View
display(Image(agent.get_graph().draw_mermaid_png()))

In [None]:
# Preview what chat_model_node actually sends to the model. These arguments
# must match the node's trim_messages call, including allow_partial=True;
# otherwise the preview shows a different history than the model receives.
trim_messages(
    messages,
    max_tokens=25,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
    allow_partial=True,
)

In [None]:
# Run the trimming graph. The returned state still holds the full input
# history plus the reply, because the node trims only what it sends.
messages_out_trim = agent.invoke({"messages": messages})

for msg in messages_out_trim["messages"]:
    msg.pretty_print()

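## Summarize messages

After each reply, `should_continue` checks the message count. Once it exceeds six, `summarize_conversation` stores a running `summary` and removes all but the last two messages. On later turns, `call_model` prepends that summary to the conversation as a system message.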

In [None]:
class State(MessagesState):
    """Message state extended with a running conversation summary."""

    summary: str


def call_model(state: State):
    """Answer the conversation, prefixing any stored summary as context.

    When ``summary`` is non-empty it is sent as a leading SystemMessage, so the
    model keeps context from messages that summarization already removed. The
    SystemMessage is only sent to the model; it is not written back to state.

    Returns:
        A state update containing the model's reply.
    """
    summary = state.get("summary", "")
    if summary:
        system_message = f"Summary of conversation earlier: {summary}"
        messages = [SystemMessage(content=system_message)] + state["messages"]
    else:
        messages = state["messages"]
    response = model.invoke(messages)
    # Wrap in a list to match the update shape used by the other nodes.
    return {"messages": [response]}


def summarize_conversation(state: State):
    """Fold the conversation into ``summary`` and prune old messages.

    Asks the model to create a summary, or to extend the existing one, from
    the current messages. Then it deletes all but the two most recent messages
    from the state.
    """
    existing = state.get("summary", "")
    if not existing:
        summary_message = "Create a summary of the conversation above:"
    else:
        summary_message = (
            f"This is summary of the conversation to date: {existing}\n\n"
            "Extend the summary by taking into account the new messages above:"
        )

    prompt = [*state["messages"], HumanMessage(content=summary_message)]
    new_summary = model.invoke(prompt).content

    stale = state["messages"][:-2]
    return {
        "summary": new_summary,
        "messages": [RemoveMessage(id=msg.id) for msg in stale],
    }


def should_continue(state: State):
    """Route to summarization once the history exceeds six messages; otherwise finish."""
    return "summarize_conversation" if len(state["messages"]) > 6 else END

In [None]:
# Graph: conversation -> (summarize_conversation | END)
graph = StateGraph(State)
graph.add_node("conversation", call_model)
# Registered under its function name, "summarize_conversation"
graph.add_node(summarize_conversation)

graph.add_edge(START, "conversation")
# should_continue returns either "summarize_conversation" or END; each return
# value maps to the node of the same name.
graph.add_conditional_edges(
    "conversation",
    should_continue,
    {"summarize_conversation": "summarize_conversation", END: END},
)
graph.add_edge("summarize_conversation", END)

# Compile with a fresh checkpointer so turns on a thread accumulate
memory = MemorySaver()
agent = graph.compile(checkpointer=memory)
display(Image(agent.get_graph().draw_mermaid_png()))

In [None]:
# Create a thread: every turn below reads and writes thread "1" in the checkpointer
config = {"configurable": {"thread_id": "1"}}

# Start conversation. Each turn prints only its own exchange (the new human
# message and the model's reply) via [-2:]. Hard-coded offsets like [2:]/[4:]
# only work on a fresh thread; they print the wrong messages if this cell is
# re-run on the same thread or after summarization shortens the history.
for text in ["hi! I'm Dylan", "what's my name?", "i like the 49ers!"]:
    input_message = HumanMessage(content=text)
    output = agent.invoke({"messages": [input_message]}, config)
    for m in output["messages"][-2:]:
        m.pretty_print()

In [None]:
# Running summary for this thread ("" until summarize_conversation has run)
agent.get_state(config=config).values.get("summary", "")

In [None]:
# On a fresh thread this is the fourth turn: 8 messages > 6, so summarization
# runs. Printing the whole state shows which messages remain afterwards.
input_message = HumanMessage(
    content="i like Nick Bosa, isn't he the highest paid defensive player?"
)
output = agent.invoke(input={"messages": [input_message]}, config=config)
for m in output["messages"]:
    m.pretty_print()

In [None]:
# Summary after the previous turn; non-empty once summarization has run
agent.get_state(config=config).values.get("summary", "")

In [None]:
# Next turn on the same thread. If a summary is stored, call_model sends it to
# the model as a leading system message.
input_message = HumanMessage(
    content="i like the 49ers, isn't it true that they are the best team in the NFL?"
)
output = agent.invoke(input={"messages": [input_message]}, config=config)
for m in output["messages"]:
    m.pretty_print()