# Context and State

In [None]:
from typing import Literal

from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain.messages import HumanMessage
from langchain.tools import ToolRuntime, tool
from pydantic import Field
from pydantic.dataclasses import dataclass as pydantic_dataclass

from chain_reaction.config import APIKeys, ModelBehavior, ModelName
from chain_reaction.utils import get_last_message

# Load the API keys used by this notebook.
# NOTE(review): presumably APIKeys reads them from the environment or a settings file — confirm in chain_reaction.config.
api_keys = APIKeys()

# Build the chat model shared by the agent defined later in this notebook.
chat_model = init_chat_model(
    model=ModelName.CLAUDE_HAIKU,
    timeout=None,  # no explicit request timeout is configured here
    max_retries=2,
    api_key=api_keys.anthropic,
    # Spread the remaining model settings from the project's "deterministic" preset.
    # NOTE(review): presumably this pins sampling parameters such as temperature — confirm in ModelBehavior.
    **ModelBehavior.deterministic().model_dump(),
)

## Context Schema

`context_schema` allows you to define additional parameters that get injected into a tool call at runtime, separate from the tool's parameters.

These context parameters:

- Provide runtime context like user IDs, session data, or configuration
- Are automatically injected when the tool is invoked via the `ToolRuntime`
  - Are NOT passed to the LLM directly
  - Don't appear in the tool's description sent to the LLM
  - Allow you to reduce the amount of context the LLM has to deal with


### Why Use `context_schema`?

Without `context_schema`, you'd need to either:

- Include sensitive data in prompts (security risk)
- Use global variables (not thread-safe)
- Create tool instances per request (inefficient)

With a `context_schema`, the agent is given tools that can read the appropriate fields of each context instance, but the model itself has no direct access to these fields and does not see them when the agent is invoked. The context is injected only when a tool is called.


### Common Use Cases

- Authentication context: User IDs, permissions, API keys
- Environment config: Database connections, feature flags
- Session data: Shopping cart, conversation history


In [None]:
# Define a context schema for storing user preferences
# TimeOfDay is the closed set of day periods: it keys the per-period drink
# mapping and is also the type of the `time_of_day` tool argument.
type TimeOfDay = Literal["morning", "afternoon", "evening"]


@pydantic_dataclass
class UserPreferences:
    """Per-user preferences supplied to the agent as runtime context.

    Attributes:
        favorite_drink_by_time: The user's favorite drink for each time of day;
            defaults to coffee in the morning, tea in the afternoon, and beer
            in the evening.
        preferred_language: The language the user prefers; defaults to "English".
    """

    favorite_drink_by_time: dict[TimeOfDay, str] = Field(
        # A factory (not a shared literal) so every instance gets its own dict.
        default_factory=lambda: dict(morning="coffee", afternoon="tea", evening="beer"),
        description="The user's favorite drink by time of day",
    )
    preferred_language: str = Field(
        description="The user's preferred language",
        default="English",
    )


# Create tools for accessing user preferences
@tool
def get_preferred_drink(time_of_day: TimeOfDay, runtime: ToolRuntime[UserPreferences]) -> str:
    """Get the user's favorite drink by time of day."""
    # NOTE: the docstring above is left untouched on purpose; `@tool` uses it as
    # the tool description shown to the model, so editing it changes the prompt.
    #
    # `time_of_day` is the only argument the model supplies; `runtime` carries
    # the UserPreferences instance passed as `context=` when the agent is invoked.
    #
    # A caller-supplied UserPreferences may define drinks for only some periods
    # (the mapping type does not require all three keys), so look the value up
    # defensively and report the gap as text instead of raising KeyError.
    drink = runtime.context.favorite_drink_by_time.get(time_of_day)
    if drink is None:
        return f"No favorite drink is recorded for the {time_of_day}."
    return drink


@tool
def get_preferred_language(runtime: ToolRuntime[UserPreferences]) -> str:
    """Get the user's preferred language."""
    # The tool takes no model-supplied arguments: the answer comes entirely from
    # the UserPreferences instance made available through `runtime`.
    preferences = runtime.context
    return preferences.preferred_language


# Create the agent with context schema and tools
agent = create_agent(
    model=chat_model,
    system_prompt="Please respond to the user's queries based on their preferences.",
    # Both tools read their answers from the runtime context, not from the prompt.
    tools=[get_preferred_drink, get_preferred_language],
    # Type of the object passed as `context=` on each `agent.invoke(...)` call below.
    context_schema=UserPreferences,
)

In [None]:
# Start from the schema defaults (coffee / tea / beer, English)
user_context = UserPreferences()

# Run the agent; the preferences are passed as runtime context, not inside the prompt
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="Good morning! What's my favorite drink?")]),
    context=user_context,
)

# Show the final message of the run, if there is one
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)

In [None]:
# Override every preference with pirate-flavored values
user_context = UserPreferences(
    preferred_language="Pirate English",
    favorite_drink_by_time=dict(morning="grog", afternoon="ale", evening="rum"),
)

# Run the same agent again; only the runtime context differs from the previous cell
response = agent.invoke(
    input=dict(messages=[HumanMessage(content="Arr! The Moon is high! What's me favorite drink?")]),
    context=user_context,
)

# Show the final message of the run, if there is one
if last_message := get_last_message(response):
    print("Agent Response:", last_message.content)