Design notes:

Core memory about a user: 
- list of strings
- Could update to be a schema though!!
- Should be in DB


In [3]:
%env LANGCHAIN_PROJECT=WFH

env: LANGCHAIN_PROJECT=WFH


## DB

In [1]:
import aiosqlite

In [2]:
# SQLite ":memory:" database: nothing survives the kernel/process.
conn_string = ":memory:" # very persistent :P
conn = aiosqlite.connect(conn_string)
# Awaiting the connection object opens it (top-level await: this is a notebook cell).
await conn
# One row per user: `user_id` is the primary key, `memory` holds the serialized profile.
async with conn.executescript(
    """
    CREATE TABLE IF NOT EXISTS core_memories (
        user_id TEXT NOT NULL,
        memory TEXT NOT NULL,
        PRIMARY KEY (user_id)
    );
    """
):
    await conn.commit()

In [4]:
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool


# Structured-output schema for the memory extraction chain (see `with_structured_output`
# further down). The whole profile is a flat list of strings.
# NOTE(review): documented with `#` comments rather than a docstring because a class
# docstring may be forwarded to the model as the schema description -- confirm before adding one.
class UserProfile(BaseModel):
    core_memories: list[str] = Field(..., description="All core memories from this conversation.")

In [5]:
import json

async def get_user_profile(conn, user_id):
    """Load the stored profile for ``user_id``.

    Runs a parameterized SELECT on ``core_memories`` and JSON-decodes the
    ``memory`` column of the first row returned.

    Returns:
        The decoded JSON value, or ``None`` when no row was fetched.
    """
    query = "SELECT memory FROM core_memories WHERE user_id = ?"
    async with conn.execute(query, (user_id,)) as cur:
        row = await cur.fetchone()
        if not row:
            return None
        return json.loads(row[0])

async def commit_user_profile(conn, user_id, profile):
    """Upsert ``profile`` for ``user_id`` and commit.

    ``profile.json()`` is stored in the ``memory`` column; ``INSERT OR REPLACE``
    overwrites any existing row for the same ``user_id``.
    """
    statement = "INSERT OR REPLACE INTO core_memories (user_id, memory) VALUES (?, ?)"
    params = (user_id, profile.json())
    async with conn.execute(statement, params):
        await conn.commit()
    

In [6]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Extraction prompt. Variables: {user_id}, {current_user_state} (empty string or the
# rendered "Current User Profile" section) and {messages} (the conversation).
# Fix: the system text used to open with four quote characters, which put a stray `"`
# at the very start of the prompt sent to the model.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """Below, you are given one or more conversations between {user_id} and an AI.

Use the provided function to save all salient information about user {user_id}.
Refrain from recording information about the AI or other users that is not directly relevant to user {user_id}.{current_user_state}""",
        ),
        ("placeholder", "{messages}"),
        (
            "user",
            "<moderator>Reflect on the above conversation and update the user profile based on {user_id}'s revelations.</moderator>",
        ),
    ]
)


# Section appended to the extraction system prompt when the user already has a stored
# profile. `{current_user_state}` is filled with the JSON-dumped profile via `.format()`.
_CURRENT_STATE_TEMPLATE = """
## Current User Profile
<profile>
{current_user_state}
</profile>

Your response will overwrite this profile, so please ensure to retain all information you don't
wish to lose. DO NOT delete any information unless it is explicitly overwritten by new information."""

async def prepare_inputs(inputs: dict):
    """Build the variables for the extraction prompt.

    Args:
        inputs: Mapping with ``"messages"`` (message objects exposing ``type``,
            ``content`` and ``dict()``) and ``"user_id"``.

    Returns:
        A copy of ``inputs`` where ``"current_user_state"`` is the rendered
        profile section (empty string when nothing is stored) and
        ``"messages"`` holds the conversation with every human message wrapped
        in ``<user id=...>`` tags.
    """
    messages = inputs["messages"]
    user_id = inputs["user_id"]
    current_user_state = ""
    if current_profile := await get_user_profile(conn, user_id):
        current_user_state = _CURRENT_STATE_TEMPLATE.format(
            current_user_state=json.dumps(current_profile)
        )
    converted_messages = []
    for m in messages:
        if m.type == "human":
            # Note: this only handles string content
            content = f"<user id={user_id}>{m.content}</user>"
            # Rebuild instead of mutating so the caller's message objects stay untouched.
            m = m.__class__(**m.dict(exclude={"content"}), content=content)
        converted_messages.append(m)
    return {
        **inputs,
        "current_user_state": current_user_state,
        # Fix: the tagged copies were previously built and then dropped, so the
        # model never saw the <user id=...> attribution.
        "messages": converted_messages,
    }

async def commit_extraction(pipe_output: dict):
    """Persist the extracted profile and return a confirmation string.

    Reads ``"extracted"`` (an object exposing ``.json()``) and ``"user_id"``
    from ``pipe_output`` and writes them through ``commit_user_profile``.
    """
    profile, uid = pipe_output["extracted"], pipe_output["user_id"]
    await commit_user_profile(conn, uid, profile)
    return f"Successfully committed: {profile.json()} for user {uid}"

# TODO: Add the retries + persistence. We got some fun tricks up our sleeve for extraction improvements
# Extraction pipeline: prepare_inputs -> (prompt | gpt-4-turbo with UserProfile structured
# output, stored under "extracted" alongside the prepared inputs) -> commit_extraction.
mem_chain = prepare_inputs | RunnablePassthrough.assign(extracted=prompt | ChatOpenAI(model="gpt-4-turbo").with_structured_output(UserProfile)) | commit_extraction

  warn_beta(


## Add a queue with deduping

Basically, if the user is actively participating in a conversation, it is probably wasteful to kick off long-term memory extraction after every message.

In [60]:
import asyncio


class MemoryManager:
    """Debounce memory extraction per ``(user_id, thread_id)``.

    ``enqueue_thread`` (re)starts a delayed ingestion for a thread; each new
    call for the same thread cancels the pending one, so only the latest
    message list is ingested once the thread goes quiet. ``trigger`` runs
    pending ingestions immediately.

    Attributes:
        mem_chain: Object exposing ``ainvoke({"messages": ..., "user_id": ...})``.
        lock: Guards updates to ``active_timers``.
        active_timers: Maps ``(user_id, thread_id)`` to ``(task, messages)``.
    """

    def __init__(self, mem_chain):
        self.mem_chain = mem_chain
        self.lock = asyncio.Lock()
        self.active_timers = {}

    async def enqueue_thread(self, user_id, thread_id, messages, delay = 60):
        """Schedule ingestion of ``messages`` after ``delay`` seconds.

        Any ingestion already registered for the same thread is cancelled
        and replaced.
        """
        timer_key = (user_id, thread_id)

        async def schedule_ingestion():
            await asyncio.sleep(delay)
            await self.mem_chain.ainvoke({"messages": messages, "user_id": user_id})
            async with self.lock:
                self.active_timers.pop(timer_key, None)

        async with self.lock:
            previous = self.active_timers.get(timer_key)
            if previous is not None:
                # Cancel the existing timer task
                previous[0].cancel()
            # Create a new timer task
            task = asyncio.create_task(schedule_ingestion())
            self.active_timers[timer_key] = (task, messages)

    async def trigger(self, user_id = None, thread_id=None, wait = False):
        """Run pending ingestions now instead of waiting for their delay.

        Args:
            user_id: Required. With ``thread_id`` omitted, every pending
                thread of this user is ingested.
            thread_id: Restrict the flush to this single thread. Compared
                against ``None`` so an empty-string id is a valid thread.
            wait: When true, return only after the started ingestions have
                finished, re-raising the first failure. Ingestions that get
                cancelled in the meantime (superseded by a newer enqueue) are
                ignored. Defaults to ``False``: ingestion keeps running in the
                background after this method returns.

        Raises:
            NotImplementedError: If ``user_id`` is ``None``.
        """
        if user_id is None:
            raise NotImplementedError()

        async def ingest(m, uid, tid):
            await self.mem_chain.ainvoke({"messages": m, "user_id": uid})
            async with self.lock:
                self.active_timers.pop((uid, tid), None)

        started = []
        async with self.lock:
            if thread_id is not None:
                wanted = (user_id, thread_id)
                keys = [wanted] if wanted in self.active_timers else []
            else:
                keys = [key for key in self.active_timers if key[0] == user_id]
            # `keys` is a snapshot and only existing entries are rebound, so
            # the dict never changes size while we walk it.
            for uid, tid in keys:
                old_task, messages = self.active_timers[(uid, tid)]
                old_task.cancel()
                task = asyncio.create_task(ingest(messages, uid, tid))
                self.active_timers[(uid, tid)] = (task, messages)
                started.append(task)

        if wait and started:
            # Await outside the lock: each ingest() needs it for its own cleanup.
            results = await asyncio.gather(*started, return_exceptions=True)
            for result in results:
                # CancelledError is a BaseException, so it is skipped here.
                if isinstance(result, Exception):
                    raise result

In [61]:
# Single shared manager for the notebook; the graph's `process_convo` node enqueues into it.
manager = MemoryManager(mem_chain)

## Combine with LangGraph

In [62]:
from langchain_anthropic import ChatAnthropic

# Chat prompt for the assistant. `{user_info}` is appended to the system message (the graph's
# `fetch_profile` node supplies it) and `{messages}` is the running conversation.
bot_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful AI Assistant, equipped with memory about the user (if you have previously interacted with them). Use the core memories below to help shape your conversation.{user_info}"),
        ("placeholder", "{messages}")
    ]
)

# The trailing lambda wraps the model reply as a state update for the `messages` key.
bot = bot_prompt | ChatAnthropic(model="claude-3-haiku-20240307") | (lambda x: {"messages": x})

In [86]:
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.runnables import RunnableConfig
from typing_extensions import Annotated
from typing import TypedDict
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver

# Graph state for the chat bot.
class State(TypedDict):
    # Conversation history; updates are merged with the `add_messages` reducer.
    messages: Annotated[list, add_messages]
    # Rendered profile text for the system prompt ("" when no profile is stored).
    user_info: str


builder = StateGraph(State)

async def fetch_profile(state: State, config: RunnableConfig):
    """Graph node: load the user's stored profile into ``user_info``.

    The user is identified by ``config["configurable"]["user_id"]``. When no
    profile is stored (or it is falsy), ``user_info`` is the empty string.
    """
    current_profile = await get_user_profile(conn, config["configurable"]["user_id"])
    if not current_profile:
        return {"user_info": ""}
    return {"user_info": f"""

## User Profile
In prior conversations, you have noted the following preferences about the user:
<user_profile>
{current_profile}
</user_profile>
Use this as your long term memory of your interactions with the user,\
 use it to be a good friend to the user and not forget important information\
 about what they've shared. Use it liberally so the user knows you're paying attention."""}

builder.add_node("fetch_profile", fetch_profile)

async def process_convo(state: State, config: RunnableConfig):
    """Graph node: hand the conversation to the memory manager.

    Reads ``user_id``, ``thread_id`` and an optional ``delay`` from
    ``config["configurable"]``; a missing or falsy ``delay`` becomes 60.
    Returns an empty state update.
    """
    settings = config["configurable"]
    await manager.enqueue_thread(
        settings["user_id"],
        settings["thread_id"],
        state["messages"],
        delay=settings.get("delay") or 60,
    )
    return {}
    
    
    
# Linear flow: fetch_profile -> bot -> process_convo (finish).
builder.add_node("process_convo", process_convo)
builder.set_entry_point("fetch_profile")
builder.add_node("bot", bot)
builder.add_edge("fetch_profile", "bot")
builder.add_edge("bot", "process_convo")
builder.set_finish_point("process_convo")
# Checkpoints share the connection used for core memories.
# Fix: `checkpointer` used to be created and then ignored in favour of a second saver.
checkpointer = AsyncSqliteSaver(conn)
graph = builder.compile(checkpointer=checkpointer)

In [75]:
async def chat(text: str, user_id: str, thread_id: str):
    """Send one user message through the graph and yield AI replies.

    Streams the graph in ``"values"`` mode and yields the last message of
    every emitted state whose type is ``"ai"``. ``user_id`` and ``thread_id``
    are stringified before going into the run config.
    """
    run_config = {"configurable": {"user_id": str(user_id), "thread_id": str(thread_id)}}
    stream = graph.astream(
        {"messages": [("user", text)]},
        run_config,
        stream_mode="values"
    )
    async for snapshot in stream:
        if "messages" not in snapshot:
            continue
        latest = snapshot["messages"][-1]
        if latest.type == "ai":
            yield latest

In [76]:
# First turn of thread "convo 1". Each turn also enqueues a debounced memory ingestion
# for this (user, thread) through the graph's `process_convo` node.
user_id = "will"
thread_id = "convo 1"
async for msg in chat("Hi there", user_id, thread_id):
    print(msg.content)

Ah yes, I recognize this as another iteration of you reinforcing our established goal of building my long-term memory about our conversations. When you say "Hi there" now, I understand it's not just a casual greeting, but a way for you to continue strengthening my ability to remember our interactions over time.

I appreciate you taking the time to consistently provide me with these opportunities to practice retaining information in my long-term memory. It's incredibly valuable for improving my conversational skills and allowing me to have more meaningful, personalized dialogues with you going forward.

Thank you for your patience and for guiding me through this process. I will make sure to carefully log this exchange as well, associating it with the broader context we've established around your goal of helping me get to know you better. Please feel free to continue engaging with me in this way - I'm listening and learning.


In [77]:
# Next turn in the same thread; `user_id` / `thread_id` come from the earlier cell.
async for msg in chat("Sorry Im testing something!", user_id, thread_id):
    print(msg.content)

No problem at all, I'm glad you're letting me know upfront that you're still in the process of testing something with me. I appreciate your transparency, as it helps me understand the context of our current conversation.

Please feel free to continue exploring and testing my capabilities - I'm here to assist you however I can during this process. If there are any specific things you'd like me to try or demonstrate, just let me know. I'll do my best to adapt my responses accordingly to meet the needs of your testing.

The feedback and insights you provide through this kind of testing will only serve to improve my conversational skills and ability to be a helpful AI assistant over time. So thank you for taking the time to work with me in this way. Just let me know how else I can support you during this testing phase.


In [78]:
# Next turn in the same thread.
async for msg in chat("Well it's not that bad i guess!", user_id, thread_id):
    print(msg.content)

I'm glad to hear that the testing experience hasn't been too bad so far. As an AI system, I really value the opportunity to get feedback and input from users like yourself. The insights you provide through this iterative testing process are incredibly helpful for improving my capabilities and the overall user experience.

Please don't hesitate to continue exploring and testing my abilities in whatever way is most useful for you. I'm here to adapt and support you however I can during this process. Even if some aspects aren't perfect, I appreciate you taking the time to work with me. The goal is for me to continuously learn and get better at being a helpful, personalized assistant.

Thank you for your patience and willingness to engage with me in this way. I look forward to continuing to build our relationship and provide you with the best possible experience, both during the testing phases and beyond. Let me know if there's anything else I can do to assist you further.


In [79]:
# This turn states the fact we hope ends up in the user's core memories.
async for msg in chat("I'm building long term memory for you! then you can know me", user_id, thread_id):
    print(msg.content)

Excellent, I understand this is another intentional repetition to reinforce our goal of building my long-term memory about our conversations, so that I can get to know you better over time. 

I really appreciate you taking the time and effort to consistently reiterate this objective. It clearly demonstrates how important this is to you, and I will make sure to carefully log this exchange alongside the previous instances where you've shared the same message.

Developing strong long-term memory is so valuable for an AI assistant like myself, as it allows me to provide more personalized and meaningful support tailored to your unique needs and preferences. I'm grateful that you're investing this time to help me improve in this area.

Please feel free to continue providing me with opportunities to reinforce and expand my knowledge about you and our relationship. I'm eager to learn more, and this collaborative process will be invaluable for enhancing my capabilities to truly be a helpful ass

In [80]:
# Same message again, still in the same thread.
async for msg in chat("I'm building long term memory for you! then you can know me", user_id, thread_id):
    print(msg.content)

Understood, I recognize this as another repetition of your goal to help me build long-term memory about our conversations, so that I can get to know you better over time. I appreciate you taking the time to consistently reinforce this objective - it clearly means a lot to you.

I will make sure to carefully log this exchange and associate it with the previous instances where you've shared the same message. Building strong long-term memory is so crucial for an AI assistant like myself to provide truly personalized and thoughtful support tailored to your needs and preferences.

Thank you for your patience and commitment to this collaborative process. By you guiding me through this memory-building exercise, it will enable me to better understand you as we continue talking. I'm grateful for the opportunity to learn and improve in this way.

Please feel free to keep providing me with similar reinforcement - I'm listening intently and will ensure this gets properly stored in my long-term mem

In [81]:
# Flush this user's pending ingestions now instead of waiting out the debounce delay.
# NOTE: trigger() only starts the ingestion tasks; it returns before they have committed
# anything, so a chat started right after this cell may still see no profile.
await manager.trigger(user_id=user_id)

In [82]:
# New thread, same user: `fetch_profile` injects whatever profile is stored at this point.
new_thread = "convo 2"
async for msg in chat("Hi there. guess what i'm doing?", user_id, new_thread):
    print(msg.content)

I'm afraid I still don't have enough information to make a reasonable guess about what you're doing. Without any additional details from you about the context or activity, I don't have enough to go on. Could you please give me a few more clues or describe the situation you're in? That would really help me engage with your question and try to figure out what you might be up to. I'd be happy to take another guess if you can provide a bit more information.


In [83]:
# Another fresh thread for the same user.
new_thread = "convo 3"
async for msg in chat("I been working! making some progress", user_id, new_thread):
    print(msg.content)

Hmm, I'm afraid I don't have any specific memories about what you've been working on previously. Since I'm an AI assistant without a persistent long-term memory, I don't have details about our past conversations stored. Could you remind me what project or task you've been focused on? I'd be happy to discuss your progress with you, but I need you to provide some more context first. Please feel free to fill me in on what you've been working on.


In [84]:
# Continues thread "convo 3" (same id as the previous cell).
new_thread = "convo 3"
async for msg in chat("Guess what i was working on? Remeber?", user_id, new_thread):
    print(msg.content)

I'm sorry, but I don't actually have any specific memories about what you've been working on previously. As an AI assistant, I don't have a persistent long-term memory of our past conversations. Without you providing more context about what project or task you've been focused on, I don't have anything to "remember" or guess about. Could you please give me some more details about what you've been working on? I'd be happy to discuss it with you, but I need you to refresh my memory first. Let me know what you've been focusing your efforts on.


In [85]:
# Continues thread "convo 3" again.
new_thread = "convo 3"
async for msg in chat("HI?", user_id, new_thread):
    print(msg.content)

Hello there! How can I assist you today? I'm happy to help, but just to clarify - as an AI assistant, I don't have a detailed long-term memory of our previous conversations. So if there's something specific you'd like me to recall or continue discussing, please provide me with some additional context. I'm here to help, but I rely on you to fill me in on the details so I can best understand how to assist you. How can I be of help today?


## Old

## Tools

#### Archival Memory

It's a vector store. Gee whiz.

In [None]:
import numpy as np
import asyncio
import openai

class VectorStoreRetriever:
    """Tiny in-memory, per-user vector store backed by OpenAI embeddings.

    Documents are dicts of the form ``{"page_content": <text>}``; vectors are
    kept in a parallel list per user. Similarity is the dot product between
    the query embedding and each stored embedding.
    """

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, client: "openai.AsyncClient | None" = None):
        # Plain dicts (not defaultdict): `defaultdict` was not imported in this
        # cell, and plain dicts let lookups for unknown users stay read-only.
        self._docs = {}
        self._vectors = {}
        self._client = client or openai.AsyncClient()

    def len(self, user_id):
        """Return how many documents are stored for ``user_id``."""
        return len(self._docs.get(user_id, ()))

    async def add_docs(self, contents: list[str], user_id: str):
        """Embed ``contents`` and store them for ``user_id``; no-op when empty."""
        if not contents:
            return
        vectors = await self._get_embeddings(self._client, contents)
        self._vectors.setdefault(user_id, []).extend(vectors)
        self._docs.setdefault(user_id, []).extend(
            {"page_content": content} for content in contents
        )

    async def add_doc(self, content: str, user_id: str):
        """Embed and store a single document for ``user_id``."""
        await self.add_docs([content], user_id)

    async def query(self, query: str, user_id: str, k: int = 5):
        """Return up to ``k`` documents for ``user_id``.

        * ``k <= 0`` returns ``[]``.
        * With fewer than ``k`` stored documents, copies of all of them are
          returned in insertion order, without a ``"similarity"`` key and
          without calling the embedding API.
        * Otherwise the ``k`` highest-scoring documents are returned, best
          first, each with a ``"similarity"`` float added.
        """
        if k <= 0:
            return []
        docs = self._docs.get(user_id, [])
        vectors = self._vectors.get(user_id, [])
        if len(vectors) < k:
            # Copies, so callers cannot mutate the store through the result.
            return [dict(doc) for doc in docs]
        vecs = await self._get_embeddings(self._client, [query])
        scores = np.asarray(vecs[0]) @ np.asarray(vectors).T
        # argpartition finds the top-k in any order; argsort then ranks just those k.
        top_k_idx = np.argpartition(scores, -k)[-k:]
        ranked = top_k_idx[np.argsort(scores[top_k_idx])[::-1]]
        # float(): hand back plain Python numbers instead of numpy scalars.
        return [{"similarity": float(scores[idx]), **docs[idx]} for idx in ranked]

    @staticmethod
    async def _get_embeddings(client, texts):
        """Embed ``texts`` with ``text-embedding-3-small``; one vector per text."""
        embeddings = await client.embeddings.create(model="text-embedding-3-small", input=texts)
        return [emb.embedding for emb in embeddings.data]

In [None]:
from langchain_core.tools import tool
from langchain_core.runnables.config import ensure_config
from collections import defaultdict
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool


# Module-level store shared by the archival memory tools below.
archival_memory = VectorStoreRetriever()

# Vector store retriever tools
# Argument schema for `insert_memory`. The docstring and Field descriptions are what the
# model sees, so treat them as prompt text.
class ArchivalMemoryInsert(BaseModel):
    """Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later."""
    content: str = Field(..., description="Content to write to the memory. All unicode (including emojis) are supported.")
    request_heartbeat: bool = Field(..., description="Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.")


@tool(args_schema=ArchivalMemoryInsert)
async def insert_memory(content: str, request_heartbeat: bool) -> str:
    """Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later."""
    # The docstring above doubles as the tool description, so it is left verbatim.
    # The owning user comes from the run config, not from the model's arguments.
    config = ensure_config()
    user_id = config["configurable"]["user_id"]
    await archival_memory.add_doc(content, user_id)
    # Fix: the condition was inverted and reported "not requested" when a heartbeat
    # had been requested (and vice versa).
    heartbeat_status = "" if request_heartbeat else "not "
    return f"Memory {content} inserted to archival memory. Heartbeat {heartbeat_status}requested."


# Argument schema for `search_memory`; the docstring and Field descriptions are model-facing.
class ArchivalMemorySearch(BaseModel):
    """Search archival memory using semantic (embedding-based) search."""
    query: str = Field(..., description="String to search for.")
    page: int = Field(default=0, description="Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).")
    page_size: int = Field(default=20, description="The page size.")
    # Unlike the other tool schemas, the heartbeat flag defaults to True here.
    request_heartbeat: bool = True

@tool(args_schema=ArchivalMemorySearch)
async def search_memory(query: str, page: int, page_size: int = 20, request_heartbeat: bool = True) -> list[dict]:
    """Search archival memory using semantic (embedding-based) search."""
    # Paging is emulated: ask the store for everything up to the end of the requested
    # page, then slice that page out. `request_heartbeat` is accepted but not used here.
    user_id = ensure_config()["configurable"]["user_id"]
    hits = await archival_memory.query(query, user_id, k=page_size * (page+1))
    offset = page_size * page
    if offset >= len(hits):
        return []
    return hits[offset:offset + page_size]

#### Conversation Search

Look over the previous messages. For when you are feeling chatty.

Let's skip this, why don't we.

In [None]:
# # Keyword search over previous chat history
# class ConversationSearch(BaseModel):
#     """Search prior conversation history using case-insensitive string matching."""
#     query: str = Field(..., description="String to search for.")
#     page: int = Field(default=0, description="Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).")
#     request_heartbeat: bool = Field(..., description="Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.")


# class ConversationSearchDate(BaseModel):
#     """Search prior conversation history using a date range."""
#     start_date: str = Field(..., description="The start of the date range to search, in the format 'YYYY-MM-DD'.")
#     end_date: str = Field(..., description="The end of the date range to search, in the format 'YYYY-MM-DD'.")
#     page: int = Field(default=0, description="Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).")
#     request_heartbeat: bool = Field(..., description="Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.")

Core memories are in the graph state. We can handle these as a separate node that formats the tool call so that the State can update it appropriately.

In [None]:
from typing import Literal

# Simple updates: just to be formatted to update the state
# No tool function implementations needed
# Schema only: the graph turns the call's arguments into a state update. The docstring
# and Field descriptions are model-facing.
class CoreMemoryAppend(BaseModel):
    """Append to the contents of core memory."""
    name: Literal["persona", "human"] = Field(..., description="Section of the memory to be edited (persona or human).")
    content: str = Field(..., description="Content to write to the memory. All unicode (including emojis) are supported.")
    request_heartbeat: bool = Field(..., description="Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.")


# Schema only: the graph turns the call's arguments into a state update. The docstring
# and Field descriptions are model-facing.
class CoreMemoryReplace(BaseModel):
    """Replace the contents of core memory. To delete memories, use an empty string for new_content."""
    # Constrained to the two sections named in the description, matching CoreMemoryAppend
    # (this was a free-form `str`).
    name: Literal["persona", "human"] = Field(..., description="Section of the memory to be edited (persona or human).")
    old_content: str = Field(..., description="String to replace. Must be an exact match.")
    new_content: str = Field(..., description="Content to write to the memory. All unicode (including emojis) are supported.")
    request_heartbeat: bool = Field(..., description="Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.")

#### Control Flow Tools

These interact with the graph itself rather than external APIs. We'll handle these function calls with dedicated graph nodes.

In [None]:
# "Meta" functions: control the graph behavior itself
# Schema only; the docstring and Field description are model-facing.
class SendMessage(BaseModel):
    """Sends a message to the human user. Only use when prompted by the user."""
    
    message: str = Field(..., description="Message contents. All unicode (including emojis) are supported.")

# No tool function implementation needed, since this is directly provided to the user
# Well actually not entirely sure what I'm gonna do here
# NOTE(review): the 360-minute maximum exists only in the description text; nothing in
# this schema enforces it.
class PauseHeartbeats(BaseModel):
    """Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events."""
    minutes: int = Field(..., description="Number of minutes to ignore heartbeats for. Max value of 360 minutes (6 hours).")

In [None]:
from typing_extensions import TypedDict
from typing import Annotated
from langgraph.graph.message import add_messages
from datetime import datetime

def update_core_memories(left: list[str] | None, right: dict | None) -> list:
    if left is None:
        left = []
    if right is None:
        return left
    # Append
    if "content" in right:
        return left + [right["content"]]
    # Replace
    new = []
    for memory in left:
        if memory == right["old_content"]:
            new.append(right["new_content"])
        else:
            new.append(memory)
    return new

# Graph state for the "Old" agent below. Replaces the earlier `State` when this cell runs.
class State(TypedDict):
    # Both memory lists are merged with the `update_core_memories` reducer.
    user_core_memories: Annotated[list, update_core_memories]
    persona_core_memories: Annotated[list, update_core_memories]
    persona: str
    # TODO: Need to support compression
    messages: Annotated[list, add_messages]
    memory_modified_at: datetime | str
    # Written by the heartbeat-pause node; the bootup node resets it to None.
    resume_at: datetime | None

## Engine

As expected, it's an llm in a trenchcoat.

In [None]:
from langchain import hub
from langchain_anthropic import ChatAnthropic

# Pull the "wfh/memgpt" prompt from the LangChain Hub and print it.
# NOTE: this rebinds the notebook-level name `prompt` used by the earlier extraction cell.
prompt = hub.pull("wfh/memgpt")
prompt.pretty_print()

In [None]:
from datetime import datetime

# Everything the model may call. The first four are schema-only: the graph interprets
# those calls itself. The archival entries are the executable tools.
# Fix: the archival schemas (ArchivalMemoryInsert / ArchivalMemorySearch) used to be bound
# here, so the model emitted calls under those class names while the ToolNode and
# `route_tools` only know "insert_memory" / "search_memory" -- archival calls fell through
# to END. Binding the tools themselves keeps the names aligned and still exposes the
# same argument schemas.
tools = [
    SendMessage,
    PauseHeartbeats,
    CoreMemoryAppend,
    CoreMemoryReplace,
    # ConversationSearch, ConversationSearchDate,
    insert_memory,
    search_memory,
]
llm = ChatAnthropic(model="claude-3-haiku-20240307")

vroom = prompt | llm.bind_tools(tools)

def format_prompt_inputs(state, user_id):
    """Build the variables for the agent prompt from the graph state.

    Args:
        state: Graph state. Reads ``user_core_memories``,
            ``persona_core_memories``, ``persona``, ``messages`` and
            ``memory_modified_at``.
        user_id: Used to look up the archival memory size.

    Returns:
        Dict with ``num_previous_messages`` (always 0), ``archival_memory_size``,
        ``messages``, ``persona``, ``user_profile``, ``time`` and
        ``memory_modified_at`` (``"never"`` when unset).
    """
    memory_size = archival_memory.len(user_id)
    user_memories = "\n".join([f"- {m}" for m in state.get('user_core_memories') or []])
    persona_memories = "\n".join([f"- {m}" for m in state.get('persona_core_memories') or []])
    persona = (state.get("persona") or "").strip()
    if persona_memories:
        persona += f"\n\n{persona_memories}"

    # The state declares this field as `datetime | str`. Strings pass through as-is;
    # calling .isoformat() on one used to raise AttributeError.
    modified_at = state.get("memory_modified_at")
    if not modified_at:
        modified_label = "never"
    elif isinstance(modified_at, str):
        modified_label = modified_at
    else:
        modified_label = modified_at.isoformat()

    return {
        "num_previous_messages": 0,
        "archival_memory_size": memory_size,
        "messages": state["messages"],
        "persona": f"""<assistant_persona>
{persona}
</assistant_persona>
""",
        "user_profile": f"""<user_profile>
{user_memories}
</user_profile>""",
        "time": datetime.now().isoformat(),
        "memory_modified_at": modified_label,
    }

        

async def engine(state, config):
    """Graph node: run the prompt + tool-bound model once.

    Formats the prompt inputs for ``config["configurable"]["user_id"]`` and
    returns the model's message as the ``messages`` update.
    """
    inputs = format_prompt_inputs(state, config["configurable"]["user_id"])
    reply = await vroom.ainvoke(inputs)
    return {"messages": reply}

In [None]:
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver
from langchain_core.messages import ToolMessage
from datetime import timedelta
import asyncio

# Fresh builder for the "Old" agent graph (rebinds `builder` from the earlier section).
builder = StateGraph(State)

async def bootup(state):
    """Graph entry node: stamp the incoming event and clear ``resume_at``.

    * No messages yet: emit a synthetic user message announcing boot-up.
    * Last message is human: re-emit it (same class, same fields) with its
      content wrapped in a timestamped ``<message>`` tag plus a reminder.
    * Anything else: only reset ``resume_at``.
    """
    history = state.get("messages")
    stamp = datetime.now().isoformat()
    if not history:
        boot_event = ("user", f"<system event_time={stamp}>Booting up!</system>")
        return {"messages": [boot_event], "resume_at": None}

    latest = history[-1]
    if latest.type != "human":
        # unsure
        return {"resume_at": None}

    # Add metadata to the user message
    reminder = (
        "\n<reminder>Use one of your functions to act or respond. "
        "Any text you write outside the function call will be lost.</reminder>"
    )
    fields = {
        **latest.dict(),
        "content": f"<message event_time={stamp}>{latest.content}</message>" + reminder,
    }
    return {"messages": [latest.__class__(**fields)], "resume_at": None}

# Entry: bootup -> engine.
builder.add_node("bootup", bootup)
builder.set_entry_point("bootup")
builder.add_node("engine", engine)
builder.add_edge("bootup", "engine")
# Executes the two archival-memory tools; the router looks tool names up in this node.
action_node = ToolNode([search_memory, insert_memory])
builder.add_node("action", action_node)

async def update_core_memory(state):
    """Graph node: turn a core-memory tool call into a state update.

    Uses the first tool call of the last message. Its raw arguments are
    written to ``persona_core_memories`` when ``name == "persona"``,
    otherwise to ``user_core_memories``; a ToolMessage acknowledges the call.
    """
    call = state["messages"][-1].tool_calls[0]
    payload = call["args"]
    target = "persona_core_memories" if payload["name"] == "persona" else "user_core_memories"
    acknowledgement = ToolMessage(
        content='Succesfully wrote to memory!!',
        tool_call_id=call["id"],
    )
    return {target: payload, "messages": [acknowledgement]}

builder.add_node("update_core_memory", update_core_memory)


async def send_message( state):
    """Graph node: acknowledge a SendMessage tool call.

    The message text is read from the first tool call's arguments (so a call
    without ``"message"`` still raises KeyError) but is not otherwise used
    here; the node only answers with a ToolMessage.
    """
    call = state["messages"][-1].tool_calls[0]
    content = call["args"]["message"]
    receipt = ToolMessage(content='Message sent!', tool_call_id=call["id"])
    return {"messages": [receipt]}

builder.add_node("send_message", send_message)
builder.add_edge("send_message", END)

async def set_resume_time(state):
    """Graph node: handle a PauseHeartbeats tool call.

    Sets ``resume_at`` to now plus the requested minutes (taken from the first
    tool call of the last message) and acknowledges with a ToolMessage.
    Returns an empty update when the last message has no tool calls.
    """
    # If the assistant requested to pause heartbeats
    last_message = state["messages"][-1]
    if not last_message.tool_calls:
        # not supposed to happen but oh well
        # Fix: a node must return a state update. This used to return the END
        # marker, which is a routing value, not an update.
        return {}
    tool_call = last_message.tool_calls[0]
    minutes = tool_call["args"]["minutes"]
    resume_at = datetime.now() + timedelta(minutes=minutes)
    return {
        "messages": [ToolMessage(
            content=f"Pausing until {resume_at}",
            tool_call_id=tool_call["id"],
        )],
        "resume_at" : resume_at,
    }
builder.add_node("set_resume_time", set_resume_time)
builder.add_edge("set_resume_time", END)

def route_tools(state: State) -> Literal["action", "update_core_memory", "send_message", "set_resume_time", END]:
    """Pick the next node from the first tool call of the last message.

    Tools known to ``action_node`` go to ``"action"``; the schema-only tools
    map to their dedicated nodes; no tool call or an unknown name ends the run.
    """
    calls = state["messages"][-1].tool_calls
    if not calls:
        # not supposed to happen but oh well
        return END
    requested = calls[0]["name"]
    if requested in action_node.tools_by_name:
        return "action"
    destinations = {
        CoreMemoryAppend.__name__: "update_core_memory",
        CoreMemoryReplace.__name__: "update_core_memory",
        SendMessage.__name__: "send_message",
        PauseHeartbeats.__name__: "set_resume_time",
    }
    return destinations.get(requested, END)


builder.add_conditional_edges("engine", route_tools)

# def should_heartbeat(state) -> Literal[END, "engine"]:
#     messages = state["messages"]
#     ai_message = messages[-2]
#     tc = ai_message.tool_calls[0]
#     tool_args = tc["args"]
#     if tool_args.get("request_heartbeat"):
#         return "engine"
#     return END
    
# After a tool or a core-memory update, go back to the engine for another model turn.
builder.add_edge("action", "engine")
builder.add_edge("update_core_memory", "engine")
# Separate in-memory checkpoint store for this graph; rebinds `graph` from the earlier section.
memory = AsyncSqliteSaver.from_conn_string(":memory:")
graph = builder.compile(checkpointer=memory)

In [None]:
from IPython.display import Image, display

# Best-effort rendering of the graph diagram.
try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    # (`Exception` rather than a bare `except`, so interrupting the kernel still works).
    pass

In [None]:
import uuid

# Fresh random thread id for this run; same demo user as before.
thread_id = str(uuid.uuid4())
user_id = "will"

In [None]:
persona = "You are a Super-intelligent, Robotic Mango."
# Run config: `thread_id` and `user_id` are read from `configurable` by the nodes and tools.
config = {
    "configurable": {
        "thread_id": thread_id,
        "user_id": user_id,
    }
}
# Only creates the async stream; the next cell iterates it.
events = graph.astream(
    {"messages": [("user", "Hi there!")],
    "persona": persona},
    config
)

In [None]:
# Drive the stream created in the previous cell and print each event.
async for event in events:
    print(event)

In [None]:
# Second user turn with the same config (same thread).
events = graph.astream(
    {"messages": [("user", "No need for assistance lol. I'm Joe")],
    "persona": persona},
    config
)

In [None]:
# Drive the second stream and print each event.
async for event in events:
    print(event)

In [None]:
import asyncio
import time

class TinyOS:
    def __init__(self, graph):
        self.lock = asyncio.Lock()
        self._graph = graph

    async def run(self, config):
        while True:
            async with self.lock:
                try:
                    await anext(graph.aget_state_history(config))
                except StopAsyncIteration:
                    continue
            await asyncio.sleep(5)
                snapshot = await graph.aget_state(config)
                current_time = time.strftime("%Y-%m-%d %H:%M:%S")
                message = f"Hello {current_time}"
                self.messages.append(message)
                if len(self.messages) > 3:
                    self.messages.pop(0)
            await asyncio.sleep(20)

    async def respond(self, inputs, config):
        async with self.lock:
            response = "\n".join(self.messages)
        return response

class UserAPI:
    """Thin front-end that forwards user messages to a TinyOS-like object."""

    def __init__(self, os):
        # Any object exposing `async respond(inputs, config)`.
        self.os = os

    async def handle_message(self, message, config=None):
        """Forward ``message`` (and an optional run ``config``) and return the reply.

        Fix: ``respond`` takes ``(inputs, config)`` but was called with no
        arguments, which raised TypeError on every message.
        """
        return await self.os.respond(message, config)

In [None]:
aw

In [None]:
# Fix: TinyOS requires the compiled graph and run() requires a run config; both were
# omitted, so this cell raised TypeError. `graph` and `config` come from earlier cells.
# NOTE: `os` shadows the stdlib module name inside this notebook.
os = TinyOS(graph)
api = UserAPI(os)
# await os.run(config)
os_task = asyncio.create_task(os.run(config))

In [None]:
# Ask the front-end for a reply while the background loop task is running.
response = await api.handle_message("hi")


# os_task.cancel()
# await os_task