In [9]:
import getpass
import os
from dotenv import load_dotenv

os.environ["LANGCHAIN_TRACING_V2"] = "true"

load_dotenv()
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")
if "LANGCHAIN_API_KEY" not in os.environ:
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangChain API key: ")


In [10]:
from langchain_core.messages.utils import (
    trim_messages,
    count_tokens_approximately
)
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.checkpoint.memory import InMemorySaver

model = init_chat_model(
    "openai:gpt-4o",
    temperature=0.7
)
summarization_model = model.bind(max_tokens=128)

def call_model(state: MessagesState):
    messages = trim_messages(
        state["messages"],
        strategy="last",
        token_counter=count_tokens_approximately,
        max_tokens=128,
        start_on="human",
        end_on=("human", "tool"),
    )
    response = model.invoke(messages)
    return {"messages": [response]}


checkpointer = InMemorySaver()
builder = StateGraph(MessagesState)
builder.add_node(call_model)
builder.add_edge(START, "call_model")
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, my name is bob"}, config)
graph.invoke({"messages": "write a short poem about cats"}, config)
graph.invoke({"messages": "now do the same but for dogs"}, config)
final_response = graph.invoke({"messages": "what's my name?"}, config)

final_response["messages"][-1].pretty_print()


I'm sorry, but I can't determine your name based on the information provided.


In [11]:
# Let's examine the MessagesState and see how state is managed
print("=== Understanding MessagesState ===")
print()

# 1. What is MessagesState?
from langgraph.graph import MessagesState
print("1. MessagesState definition:")
print(f"   MessagesState is a TypedDict with: {MessagesState.__annotations__}")
print()

# 2. How does the checkpointer work?
print("2. Checkpointer state management:")
config = {"configurable": {"thread_id": "demo"}}

# Let's create a new graph to demonstrate
demo_builder = StateGraph(MessagesState)
demo_builder.add_node("demo_node", lambda state: {"messages": [f"Response to: {state['messages'][-1].content}"]})
demo_builder.add_edge(START, "demo_node")
demo_graph = demo_builder.compile(checkpointer=InMemorySaver())

# First interaction
print("   First call:")
result1 = demo_graph.invoke({"messages": "Hello"}, config)
print(f"   Messages after call 1: {len(result1['messages'])} messages")

# Second interaction - state persists!
print("   Second call (same thread_id):")
result2 = demo_graph.invoke({"messages": "How are you?"}, config)
print(f"   Messages after call 2: {len(result2['messages'])} messages")
print("   → State automatically accumulates because of the checkpointer!")
print()

# 3. Show the actual message history
print("3. Accumulated message history:")
for i, msg in enumerate(result2['messages']):
    print(f"   Message {i+1}: {msg.content[:50]}...")
print()

print("Key insight: The 'messages' are stored in the checkpointer, not declared in your code!")

=== Understanding MessagesState ===

1. MessagesState definition:
   MessagesState is a TypedDict with: {'messages': ForwardRef('Annotated[list[AnyMessage], add_messages]', module='langgraph.graph.message')}

2. Checkpointer state management:
   First call:
   Messages after call 1: 2 messages
   Second call (same thread_id):
   Messages after call 2: 4 messages
   → State automatically accumulates because of the checkpointer!

3. Accumulated message history:
   Message 1: Hello...
   Message 2: Response to: Hello...
   Message 3: How are you?...
   Message 4: Response to: How are you?...

Key insight: The 'messages' are stored in the checkpointer, not declared in your code!


In [12]:
# Now let's understand how trim_messages works
print("=== Understanding trim_messages() ===")
print()

# Let's examine the trim_messages function in detail
from langchain_core.messages import HumanMessage, AIMessage

# Create a sample conversation
sample_messages = [
    HumanMessage("Hi, my name is Bob"),
    AIMessage("Hello Bob! Nice to meet you."),
    HumanMessage("Write a short poem about cats"),
    AIMessage("Cats are graceful, cats are sweet..."),
    HumanMessage("Now do the same but for dogs"),
    AIMessage("Dogs are loyal, dogs are true..."),
    HumanMessage("What's my name?"),
]

print("1. Original message count:", len(sample_messages))
print()

# Test trim_messages with different parameters
print("2. Trim with max_tokens=128, strategy='last':")
trimmed = trim_messages(
    sample_messages,
    strategy="last",  # Keep the most recent messages
    token_counter=count_tokens_approximately,
    max_tokens=128,
    start_on="human",  # Always start with a human message
    end_on=("human", "tool"),  # End on human or tool message
)
print(f"   Trimmed to: {len(trimmed)} messages")
for i, msg in enumerate(trimmed):
    print(f"   {i+1}. {type(msg).__name__}: {msg.content[:50]}...")
print()

print("3. What each parameter does:")
print("   - strategy='last': Keep the most recent messages within token limit")
print("   - max_tokens=128: Maximum tokens to keep")
print("   - start_on='human': Always start the trimmed conversation with a human message")
print("   - end_on=('human', 'tool'): End on human or tool message (ensures complete exchange)")
print("   - token_counter: Function to count tokens approximately")
print()

print("4. Why trimming matters in your example:")
print("   - Without trimming: All 4 exchanges would be sent to the model")
print("   - With trimming: Only recent exchanges that fit in 128 tokens")
print("   - Result: The model forgets 'Bob' because early messages get trimmed!")

=== Understanding trim_messages() ===

1. Original message count: 7

2. Trim with max_tokens=128, strategy='last':
   Trimmed to: 7 messages
   1. HumanMessage: Hi, my name is Bob...
   2. AIMessage: Hello Bob! Nice to meet you....
   3. HumanMessage: Write a short poem about cats...
   4. AIMessage: Cats are graceful, cats are sweet......
   5. HumanMessage: Now do the same but for dogs...
   6. AIMessage: Dogs are loyal, dogs are true......
   7. HumanMessage: What's my name?...

3. What each parameter does:
   - strategy='last': Keep the most recent messages within token limit
   - max_tokens=128: Maximum tokens to keep
   - start_on='human': Always start the trimmed conversation with a human message
   - end_on=('human', 'tool'): End on human or tool message (ensures complete exchange)
   - token_counter: Function to count tokens approximately

4. Why trimming matters in your example:
   - Without trimming: All 4 exchanges would be sent to the model
   - With trimming: Only recent 

# Summary: Answering Your Questions

## Where is the state getting declared as messages for this thread?

**Answer**: The state is NOT explicitly declared in your code. Here's the flow:

1. **`MessagesState`** - Pre-built TypedDict from LangGraph with `messages: Annotated[list[BaseMessage], add_messages]`
2. **`InMemorySaver()`** - Checkpointer that automatically saves/loads state by `thread_id`
3. **`config = {"configurable": {"thread_id": "1"}}`** - This tells LangGraph which conversation thread to use
4. **Each `invoke()`** - Automatically loads existing messages, adds new ones, saves back to checkpointer

## How the full flow works:

```
Call 1: graph.invoke({"messages": "hi, my name is bob"}, config)
  → Checkpointer loads: [] (empty)
  → Adds user message: [HumanMessage("hi, my name is bob")]
  → trim_messages() processes: [HumanMessage("hi, my name is bob")]
  → Model responds: [HumanMessage("hi..."), AIMessage("Hello Bob...")]
  → Checkpointer saves state for thread "1"

Call 2: graph.invoke({"messages": "write a short poem about cats"}, config)
  → Checkpointer loads: [HumanMessage("hi..."), AIMessage("Hello Bob...")]
  → Adds user message: [HumanMessage("hi..."), AIMessage("Hello Bob..."), HumanMessage("write a short poem...")]
  → trim_messages() processes and trims if needed
  → Model responds, checkpointer saves updated state

...and so on for each call
```

## Why trimming causes the name to be forgotten:

With `max_tokens=128`, the early messages (including "my name is Bob") get trimmed out, so the model loses that context!

In [13]:
# Let's see exactly what happened - why the name was forgotten
print("=== Debugging the Memory Loss ===")
print()

# Let's trace what messages were actually sent to the model in the final call
config_debug = {"configurable": {"thread_id": "debug"}}

# Create a debug version that shows us what trim_messages is doing
def debug_call_model(state: MessagesState):
    print(f"Full conversation has {len(state['messages'])} messages:")
    for i, msg in enumerate(state['messages']):
        print(f"  {i+1}. {type(msg).__name__}: {msg.content[:60]}...")
    print()
    
    # Apply the same trimming
    trimmed = trim_messages(
        state["messages"],
        strategy="last",
        token_counter=count_tokens_approximately,
        max_tokens=128,
        start_on="human",
        end_on=("human", "tool"),
    )
    
    print(f"After trimming to 128 tokens, only {len(trimmed)} messages remain:")
    for i, msg in enumerate(trimmed):
        print(f"  {i+1}. {type(msg).__name__}: {msg.content[:60]}...")
    print()
    
    # Count tokens for transparency
    token_count = count_tokens_approximately(trimmed)
    print(f"Token count of trimmed messages: ~{token_count}")
    print()
    
    response = model.invoke(trimmed)
    return {"messages": [response]}

# Build debug graph
debug_builder = StateGraph(MessagesState)
debug_builder.add_node("debug_call_model", debug_call_model)
debug_builder.add_edge(START, "debug_call_model")
debug_graph = debug_builder.compile(checkpointer=InMemorySaver())

# Recreate the same conversation
print("Simulating the same conversation with debug output...")
print()
debug_graph.invoke({"messages": "hi, my name is bob"}, config_debug)
debug_graph.invoke({"messages": "write a short poem about cats"}, config_debug)
debug_graph.invoke({"messages": "now do the same but for dogs"}, config_debug)
print("Final question - this will show what the model actually sees:")
final_debug = debug_graph.invoke({"messages": "what's my name?"}, config_debug)

final_debug["messages"][-1].pretty_print()

=== Debugging the Memory Loss ===

Simulating the same conversation with debug output...

Full conversation has 1 messages:
  1. HumanMessage: hi, my name is bob...

After trimming to 128 tokens, only 1 messages remain:
  1. HumanMessage: hi, my name is bob...

Token count of trimmed messages: ~9

Full conversation has 3 messages:
  1. HumanMessage: hi, my name is bob...
  2. AIMessage: Hello Bob! How can I assist you today?...
  3. HumanMessage: write a short poem about cats...

After trimming to 128 tokens, only 3 messages remain:
  1. HumanMessage: hi, my name is bob...
  2. AIMessage: Hello Bob! How can I assist you today?...
  3. HumanMessage: write a short poem about cats...

Token count of trimmed messages: ~36

Full conversation has 3 messages:
  1. HumanMessage: hi, my name is bob...
  2. AIMessage: Hello Bob! How can I assist you today?...
  3. HumanMessage: write a short poem about cats...

After trimming to 128 tokens, only 3 messages remain:
  1. HumanMessage: hi, my name 

In [16]:
# Let's fix the problem - Demo with larger max_tokens
print("=== FIXED VERSION - Larger Token Limit ===")
print()

def fixed_call_model(state: MessagesState):
    # Increase max_tokens to preserve more context
    messages = trim_messages(
        state["messages"],
        strategy="last",
        token_counter=count_tokens_approximately,
        max_tokens=300,  # Much larger token limit
        start_on="human",
        end_on=("human", "tool"),
    )
    response = model.invoke(messages)
    return {"messages": [response]}

# Create fixed graph
fixed_builder = StateGraph(MessagesState)
fixed_builder.add_node("fixed_call_model", fixed_call_model)
fixed_builder.add_edge(START, "fixed_call_model")
fixed_graph = fixed_builder.compile(checkpointer=InMemorySaver())

# Test the fixed version
config_fixed = {"configurable": {"thread_id": "fixed"}}
print("Testing fixed version with max_tokens=1000...")

fixed_graph.invoke({"messages": "hi, my name is bob"}, config_fixed)
fixed_graph.invoke({"messages": "write a short poem about cats"}, config_fixed)
fixed_graph.invoke({"messages": "now do the same but for dogs"}, config_fixed)
fixed_result = fixed_graph.invoke({"messages": "what's my name?"}, config_fixed)

print("Fixed result:")
fixed_result["messages"][-1].pretty_print()

print()
print("Key takeaway: The 128 token limit was too aggressive and trimmed away")
print("the important context about the user's name!")

=== FIXED VERSION - Larger Token Limit ===

Testing fixed version with max_tokens=1000...
Fixed result:

Your name is Bob.

Key takeaway: The 128 token limit was too aggressive and trimmed away
the important context about the user's name!
Fixed result:

Your name is Bob.

Key takeaway: The 128 token limit was too aggressive and trimmed away
the important context about the user's name!
