In [None]:
# Import LangGraph and LangChain components
from langchain.chat_models import init_chat_model
from langchain.tools import tool
from langchain.agents import create_agent

In [None]:
# Import the AgentCoreMemorySaver that we will use as a checkpointer
import os
import logging

from langgraph_checkpoint_aws import AgentCoreMemorySaver
from bedrock_agentcore.memory import MemoryClient

from dotenv import load_dotenv
load_dotenv()

# AWS region shared by every AgentCore / Bedrock client in this notebook;
# override it with the AWS_REGION environment variable (or the .env file).
region = os.getenv('AWS_REGION', 'us-west-2')

# NOTE(review): no handler is attached to "math-agent" here and the logger is not
# used in the cells shown, so this level change likely has no visible effect alone.
logging.getLogger("math-agent").setLevel(logging.DEBUG)

# Look up the AgentCore memory resource by name, creating it on the first run.
memory_name = "MathLanggraphAgent"
client = MemoryClient(region_name=region)
memory = client.create_or_get_memory(name=memory_name)

# The checkpointer below needs only the resource's id.
memory_id = memory['id']


In [None]:
# Bedrock model id used by the agent.
MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0"

# Chat model served through the Bedrock Converse provider, in the same region
# as the memory resource.
llm = init_chat_model(
    MODEL_ID,
    model_provider="bedrock_converse",
    region_name=region,
)

# Checkpointer that persists LangGraph state into the AgentCore memory resource
# identified by `memory_id`.
checkpointer = AgentCoreMemorySaver(memory_id, region_name=region)

In [None]:
# With @tool, each docstring is used as the tool description the model sees,
# so the docstring text must stay exactly as written.
@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the result"""
    total = a + b
    return total


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result"""
    product = a * b
    return product


# Tools exposed to the agent.
tools = [add, multiply]

In [None]:
# Assemble the tool-calling agent. Passing the AgentCore checkpointer makes the
# agent's state persist, keyed by the thread/actor supplied in each run's config.
graph = create_agent(
    model=llm,
    system_prompt="You are a helpful assistant",
    tools=tools,
    checkpointer=checkpointer,
)

# Display the compiled graph (the cell's last expression is rendered).
graph

In [None]:
import uuid

# Identify this conversation: a fresh random session id plus a fixed actor.
thread_id = str(uuid.uuid4())
actor_id = "user"

# Both keys are REQUIRED:
#   thread_id -> Bedrock AgentCore session_id under the hood
#   actor_id  -> Bedrock AgentCore actor_id under the hood
config = {"configurable": {"thread_id": thread_id, "actor_id": actor_id}}

# First user turn: a two-step arithmetic request (multiply, then add).
inputs = {
    "messages": [
        {
            "role": "user",
            "content": "What is 1337 times 515321? Then add 412 and return the value to me.",
        }
    ]
}

In [None]:
# Run the first turn, printing each state update as the graph produces it.
for update in graph.stream(inputs, config=config, stream_mode="updates"):
    print(update)

In [None]:
# Print the conversation stored in the latest checkpoint for this thread.
# Default to an empty list so the cell doesn't raise TypeError (iterating None)
# when the thread has no saved "messages" yet.
for message in graph.get_state(config).values.get("messages", []):
    print(f"{message.type}: {message.text}")
    print("=========================================")

In [None]:
# List every saved checkpoint for this thread with its message count.
# The earliest (input-only) checkpoint has no "messages" key in its values, so
# default to an empty list instead of crashing on len(None).
for checkpoint in graph.get_state_history(config):
    state_messages = checkpoint.values.get('messages', [])
    print(
        f"(Checkpoint ID: {checkpoint.config['configurable']['checkpoint_id']}) # of messages in state: {len(state_messages)}"
    )

In [None]:
# Follow-up on the SAME thread: answering it relies on the history the
# checkpointer saved during the first turn.
inputs = {
    "messages": [
        {"role": "user", "content": "What were the first calculations I asked you to do?"}
    ]
}

for update in graph.stream(inputs, config=config, stream_mode="updates"):
    print(update)

In [None]:
# Start a brand-new session for the SAME actor as the first conversation.
# The checkpointer stores state per thread_id (session), so the agent is
# expected NOT to recall the earlier calculations here.
thread_id = str(uuid.uuid4())
actor_id = "user"  # same actor as before; only the session changes

config = {
    "configurable": {
        "thread_id": thread_id, # New session ID
        "actor_id": actor_id, # Same Actor ID
    }
}

inputs = {"messages": [{"role": "user", "content": "What values did I ask you to multiply and add?"}]}
for chunk in graph.stream(inputs, stream_mode="updates", config=config):
    print(chunk)