# Context and State

In [None]:
from typing import Annotated, Literal

from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain.messages import HumanMessage, ToolMessage
from langchain.tools import ToolRuntime, tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import MessagesState
from langgraph.types import Command
from pydantic import BaseModel, Field

from chain_reaction.config import APIKeys, ModelBehavior, ModelName
from chain_reaction.utils import get_last_message

# Provider credentials, loaded through the project's APIKeys config object.
api_keys = APIKeys()

# Shared chat model for every agent in this notebook. Sampling settings come
# from ModelBehavior.deterministic(); failed requests are retried up to twice.
chat_model = init_chat_model(
    model=ModelName.CLAUDE_HAIKU,
    api_key=api_keys.anthropic,
    max_retries=2,
    timeout=None,
    **ModelBehavior.deterministic().model_dump(),
)

## Context Schema

`context_schema` allows you to define additional parameters that get injected into a tool call at runtime, separate from the tool's parameters.

These context parameters:

- Provide runtime context like user IDs, session data, or configuration
- Are automatically injected when the tool is invoked via the `ToolRuntime`
  - Are NOT passed to the LLM directly
  - Don't appear in the tool's description sent to the LLM
  - Reduce the amount of context the LLM has to process


### Why Use `context_schema`?

Without `context_schema`, you'd need to either:

- Include sensitive data in prompts (security risk)
- Use global variables (not thread-safe)
- Create tool instances per request (inefficient)

With a `context_schema`, the agent is given tools that read the relevant fields of the current context instance, but the model never sees those fields directly: they are not in its prompt and not in the tool schemas. The values are injected only at the moment a tool is called.


### Common Use Cases

- Authentication context: User IDs, permissions, API keys
- Environment config: Database connections, feature flags
- Session data: Shopping cart, conversation history


In [None]:
# Define a context schema for storing user preferences
type TimeOfDay = Literal["morning", "afternoon", "evening"]


class UserPreferences(BaseModel):
    """Context schema for user preferences.

    Attributes:
        favorite_drink_by_time (dict[TimeOfDay, str]): A mapping of time of day to the user's favorite drink.
        preferred_language (str): The user's preferred language.
    """

    favorite_drink_by_time: dict[TimeOfDay, str] = Field(
        description="The user's favorite drink by time of day",
        default_factory=lambda: {
            "morning": "coffee",
            "afternoon": "tea",
            "evening": "beer",
        },
    )
    preferred_language: str = Field(default="English", description="The user's preferred language")


# Create tools for accessing user preferences
@tool
def get_preferred_drink_from_context(time_of_day: TimeOfDay, runtime: ToolRuntime[UserPreferences]) -> str | None:
    """Get the user's favorite drink by time of day."""
    # The docstring above is the tool description sent to the model, so it is
    # left untouched. `runtime` is injected at call time and is not part of the
    # tool's argument schema.
    #
    # A custom UserPreferences may map only some times of day. .get() returns
    # None for a missing entry (matching get_preferred_drink_from_state) instead
    # of raising KeyError out of the tool.
    return runtime.context.favorite_drink_by_time.get(time_of_day)


@tool
def get_preferred_language_from_context(runtime: ToolRuntime[UserPreferences]) -> str:
    """Get the user's preferred language."""
    # The docstring doubles as the tool description. The model calls this with
    # no arguments; `runtime` carries the UserPreferences passed as `context=`.
    preferences: UserPreferences = runtime.context
    return preferences.preferred_language


# Context-aware agent: both tools read from the UserPreferences instance that
# each invoke() call supplies via `context=`.
agent = create_agent(
    model=chat_model,
    tools=[get_preferred_drink_from_context, get_preferred_language_from_context],
    context_schema=UserPreferences,
    system_prompt="Please respond to the user's queries based on their preferences.",
)

In [None]:
# Every UserPreferences field has a default, so this context uses the default
# drinks and language.
user_context = UserPreferences()

# Ask about the morning drink. The context travels separately from the messages.
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="Good morning! What's my favorite drink?")]),
    context=user_context,
)

# Only print when the helper returned a (truthy) final message
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)

In [None]:
# Override every preference: a pirate's drinks and language
user_context = UserPreferences(
    preferred_language="Pirate English",
    favorite_drink_by_time=dict(morning="grog", afternoon="ale", evening="rum"),
)

# Same agent, different context: the answer should follow the overridden values
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="Arr! The Moon is high! What's me favorite drink?")]),
    context=user_context,
)

# Only print when the helper returned a (truthy) final message
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)

# Custom State

- The custom context we just defined is read-only from the agent's point of view: its tools can read it, but nothing the agent does can change it.
- If we want the agent to store information it learns in a structured format, we can use its state.
- Checkpointers keep track of historical messages, and we can also add custom structured fields to a state class (an `AgentState`-style schema; below we subclass `MessagesState` and pass it as `state_schema`) for the agent to keep track of.

In [None]:
# Define an agent state schema for learning user preferences over time
def _merge_drink_preferences(
    current: dict[TimeOfDay, str] | None,
    update: dict[TimeOfDay, str] | None,
) -> dict[TimeOfDay, str]:
    """State reducer for ``favorite_drink_by_time``: merge ``update`` into ``current``.

    Keys present in ``update`` overwrite the stored value; keys it omits keep
    their current value. A new dict is returned and neither input is modified.
    Either side may be None, which is treated as an empty mapping.
    """
    return {**(current or {}), **(update or {})}


class UserPreferencesState(MessagesState):
    """Agent state to learn and store user preferences over time.

    Attributes:
        favorite_drink_by_time (dict[TimeOfDay, str]): A mapping of time of day to the user's favorite drink.
            Writes go through ``_merge_drink_preferences``. Several tool calls
            in the same step (e.g. parallel updates for morning and afternoon)
            are therefore combined rather than rejected as multiple writes to a
            single-value key.
        preferred_language (str): The user's preferred language.
    """

    favorite_drink_by_time: Annotated[dict[TimeOfDay, str], _merge_drink_preferences]
    preferred_language: str


# Create tools to update and retrieve user preferences from the agent state
@tool
def update_favorite_drink_state(
    time_of_day: TimeOfDay,
    drink: str,
    runtime: ToolRuntime[UserPreferencesState],
) -> Command:
    """Update the user's favorite drink by time of day.

    Args:
        time_of_day (TimeOfDay): The time of day to update.
        drink (str): The new favorite drink.
        runtime (ToolRuntime[UserPreferencesState]): The tool runtime with access to agent state.

    Returns:
        Command: A command to update the agent state.
    """
    # The docstring above is the tool description the model sees; edit with care.
    # NOTE(review): ToolRuntime's first type parameter is presumably the
    # *context* type. If so, the state type belongs in the second slot
    # (ToolRuntime[None, UserPreferencesState]). Confirm before changing.

    # Mapping currently in state (empty if nothing has been learned yet)
    current: dict[TimeOfDay, str] = runtime.state.get("favorite_drink_by_time", {})

    # Build a new dict rather than calling .update() on `current`. That object
    # belongs to the graph state handed to this tool, so changes should reach
    # state only through the Command returned below.
    favorite_drink_by_time: dict[TimeOfDay, str] = {**current, time_of_day: drink}

    # Write the new mapping to state. The ToolMessage answers this specific
    # tool call (matched by tool_call_id) and is added to `messages`.
    return Command(
        update={
            "favorite_drink_by_time": favorite_drink_by_time,
            "messages": [
                ToolMessage(
                    content=f"Updated favorite drink for {time_of_day} to {drink}.",
                    tool_call_id=runtime.tool_call_id,
                )
            ],
        }
    )


@tool
def get_preferred_drink_from_state(time_of_day: TimeOfDay, runtime: ToolRuntime[UserPreferencesState]) -> str | None:
    """Get the user's favorite drink by time of day.

    Args:
        time_of_day (TimeOfDay): The time of day to retrieve.
        runtime (ToolRuntime[UserPreferencesState]): The tool runtime with access to agent state.

    Returns:
        str | None: The user's favorite drink for the specified time of day, or None if not set.
    """
    # The docstring doubles as the tool description, so it is kept verbatim.
    # Nothing learned yet means an empty mapping, which yields None for every
    # time of day; that None becomes the tool's result.
    learned = runtime.state.get("favorite_drink_by_time", {})
    return learned.get(time_of_day)


# Stateful agent. The checkpointer persists each thread's state (messages plus
# the learned preference fields) between invoke() calls that share a thread_id.
agent = create_agent(
    model=chat_model,
    tools=[get_preferred_drink_from_state, update_favorite_drink_state],
    state_schema=UserPreferencesState,
    checkpointer=InMemorySaver(),
    system_prompt="Please respond to the user's queries based on their learned preferences.",
)

In [None]:
# Thread "1" identifies this conversation to the checkpointer. The following
# cells reuse user_config, so they continue the same thread and its state.
user_config = dict(configurable=dict(thread_id="1"))

# Teach a morning preference (expected to prompt a call to update_favorite_drink_state)
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="My favorite drink in the morning is coffee.")]),
    config=user_config,
)

# Only print when the helper returned a (truthy) final message
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)

In [None]:
# Same thread as before, so the morning preference learned above is in state
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="It's 7am what should I drink?")]),
    config=user_config,
)

# Only print when the helper returned a (truthy) final message
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)

In [None]:
# Ask about a time of day with no learned preference; the state lookup returns None
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="I'm headed home after work, what should I drink?")]),
    config=user_config,
)

# Only print when the helper returned a (truthy) final message
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)

In [None]:
# Add an afternoon preference alongside the morning one already in state
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="Oh right, I like water in the afternoon")]),
    config=user_config,
)

# Only print when the helper returned a (truthy) final message
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)