Merge branch 'cpacker:main' into key_roller
palandovalex committed Mar 10, 2024
2 parents 06d32f6 + ff986ad commit 49cf259
Showing 7 changed files with 169 additions and 8 deletions.
11 changes: 9 additions & 2 deletions memgpt/functions/functions.py
@@ -2,6 +2,7 @@
import inspect
import os
import sys
from types import ModuleType


from memgpt.functions.schema_generator import generate_schema
@@ -12,7 +13,7 @@
sys.path.append(USER_FUNCTIONS_DIR)


def load_function_set(module):
def load_function_set(module: ModuleType) -> dict:
"""Load the functions and generate schema for them, given a module object"""
function_dict = {}

@@ -36,7 +37,7 @@ def load_function_set(module):
return function_dict


def load_all_function_sets(merge=True):
def load_all_function_sets(merge: bool = True) -> dict:
# functions/examples/*.py
scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script
function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory
@@ -59,6 +60,7 @@ def load_all_function_sets(merge=True):
schemas_and_functions = {}
for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]:
for file in module_files:
tags = []
module_name = file[:-3] # Remove '.py' from filename
if dir_path == USER_FUNCTIONS_DIR:
# For user scripts, adjust the module name appropriately
@@ -86,6 +88,7 @@ def load_all_function_sets(merge=True):
else:
# For built-in scripts, use the existing method
full_module_name = f"memgpt.functions.function_sets.{module_name}"
tags.append(f"memgpt-{module_name}")
try:
module = importlib.import_module(full_module_name)
except Exception as e:
@@ -96,6 +99,10 @@ def load_all_function_sets(merge=True):
try:
# Load the function set
function_set = load_function_set(module)
# Add the metadata tags
for k, v in function_set.items():
# print(function_set)
v["tags"] = tags
schemas_and_functions[module_name] = function_set
except ValueError as e:
print(f"Error loading function set '{module_name}': {e}")
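For context, a hedged sketch of how the new tags metadata could be inspected after this change; the shape of the merged mapping is an assumption based on how metadata.py consumes it further down in this commit:

from memgpt.functions.functions import load_all_function_sets

# Assuming merge=True flattens the per-module sets into one mapping of
# function name -> {"json_schema": ..., "python_function": ..., "tags": ...}
all_functions = load_all_function_sets(merge=True)
for name, entry in all_functions.items():
    # Built-in function sets pick up a "memgpt-<module>" tag per the visible hunk.
    print(name, entry["tags"])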
15 changes: 12 additions & 3 deletions memgpt/memory.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
import datetime
from typing import Optional, List, Tuple
import uuid
from typing import Optional, List, Tuple, Union

from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC
from memgpt.utils import get_local_time, printd, count_tokens, validate_date_format, extract_date_from_timestamp
@@ -388,7 +389,7 @@ def save(self):
"""Save the index to disk"""
self.storage.save()

def insert(self, memory_string):
def insert(self, memory_string, return_ids=False) -> Union[bool, List[uuid.UUID]]:
"""Embed and save memory string"""

if not isinstance(memory_string, str):
@@ -412,9 +413,17 @@ def insert(self, memory_string):
)
passages.append(self.create_passage(text, embedding))

# grab the return IDs before the list gets modified
ids = [str(p.id) for p in passages]

# insert passages
self.storage.insert_many(passages)
return True

if return_ids:
return ids
else:
return True

except Exception as e:
print("Archival insert error", e)
raise e
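A hedged usage sketch of the updated insert signature; the archival_memory handle below is assumed to come from an agent's persistence manager, as it does elsewhere in this commit:

# Default behavior is unchanged: a successful insert still returns True.
ok = archival_memory.insert("User prefers metric units.")

# New in this commit: ask for the IDs of the passages created by the insert.
passage_ids = archival_memory.insert("User prefers metric units.", return_ids=True)
print(passage_ids)  # list of UUID strings, one per chunked passage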
1 change: 1 addition & 0 deletions memgpt/metadata.py
@@ -528,6 +528,7 @@ def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
ToolModel(
name=k,
json_schema=v["json_schema"],
tags=v["tags"],
source_type="python",
source_code=python_inspect.getsource(v["python_function"]),
)
6 changes: 5 additions & 1 deletion memgpt/models/pydantic_models.py
@@ -1,5 +1,5 @@
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field, Json
from pydantic import BaseModel, Field, Json, ConfigDict
import uuid
from datetime import datetime
from sqlmodel import Field, SQLModel
@@ -15,6 +15,9 @@ class LLMConfigModel(BaseModel):
model_wrapper: Optional[str] = None
context_window: Optional[int] = None

# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())


class EmbeddingConfigModel(BaseModel):
embedding_endpoint_type: Optional[str] = "openai"
@@ -40,6 +43,7 @@ class ToolModel(BaseModel):
# TODO move into database
name: str = Field(..., description="The name of the function.")
json_schema: dict = Field(..., description="The JSON schema of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")

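The ConfigDict(protected_namespaces=()) line silences Pydantic v2's warning about field names that start with model_ (e.g. model_wrapper above). For the ToolModel change, a hedged sketch of constructing the model with the now-required tags field; every value below is illustrative, not taken from the diff:

from memgpt.models.pydantic_models import ToolModel

tool = ToolModel(
    name="my_tool",                                     # illustrative function name
    json_schema={"name": "my_tool", "parameters": {}},  # illustrative schema stub
    tags=["memgpt-base"],                               # new required field in this commit
    source_type="python",
    source_code="def my_tool():\n    ...",
)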
98 changes: 97 additions & 1 deletion memgpt/server/rest_api/agents/memory.py
@@ -1,7 +1,9 @@
import uuid
from functools import partial
from typing import List, Optional

from fastapi import APIRouter, Depends, Body
from fastapi import APIRouter, Depends, Body, HTTPException, status, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from memgpt.server.rest_api.auth_token import get_current_user
@@ -34,6 +36,30 @@ class UpdateAgentMemoryResponse(BaseModel):
new_core_memory: CoreMemory = Field(..., description="The updated state of the agent's core memory.")


class ArchivalMemoryObject(BaseModel):
# TODO move to models/pydantic_models, or inherent from data_types Record
id: str = Field(..., description="Unique identifier for the memory object inside the archival memory store.")
contents: str = Field(..., description="The memory contents.")


class GetAgentArchivalMemoryResponse(BaseModel):
archival_memory: List[ArchivalMemoryObject] = Field(..., description="A list of all memory objects in archival memory.")


class InsertAgentArchivalMemoryRequest(BaseModel):
content: str = Field(None, description="The memory contents to insert into archival memory.")


class InsertAgentArchivalMemoryResponse(BaseModel):
ids: List[str] = Field(
..., description="Unique identifier for the new archival memory object. May return multiple ids if insert contents are chunked."
)


class DeleteAgentArchivalMemoryRequest(BaseModel):
id: str = Field(..., description="Unique identifier for the new archival memory object.")


def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@@ -70,4 +96,74 @@ def update_agent_memory(
response = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=new_memory_contents)
return UpdateAgentMemoryResponse(**response)

@router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory_all(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
"""
interface.clear()
archival_memories = server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
print("archival_memories:", archival_memories)
return GetAgentArchivalMemoryResponse(archival_memory=archival_memories)

@router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory(
agent_id: uuid.UUID,
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
interface.clear()
# TODO need to add support for non-postgres here
# chroma will throw:
# raise ValueError("Cannot run get_all_cursor with chroma")
_, archival_json_records = server.get_agent_archival_cursor(
user_id=user_id,
agent_id=agent_id,
after=after,
before=before,
limit=limit,
)
print(archival_json_records)
return GetAgentArchivalMemoryResponse(archival_json_records)

@router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=InsertAgentArchivalMemoryResponse)
def insert_agent_archival_memory(
agent_id: uuid.UUID,
request: InsertAgentArchivalMemoryRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Insert a memory into an agent's archival memory store.
"""
interface.clear()
memory_ids = server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.content)
return InsertAgentArchivalMemoryResponse(ids=memory_ids)

@router.delete("/agents/{agent_id}/archival", tags=["agents"])
def delete_agent_archival_memory(
agent_id: uuid.UUID,
id: str = Query(..., description="Unique ID of the memory to be deleted."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete a memory from an agent's archival memory store.
"""
interface.clear()
try:
memory_id = uuid.UUID(id)
server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

return router
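
For reference, a hedged sketch of how a client might exercise the new archival endpoints; the base URL, port, auth scheme, and agent ID below are assumptions, not part of this diff:

import requests

base = "http://localhost:8283/api"                 # assumed server address and prefix
headers = {"Authorization": "Bearer <token>"}      # assumed auth scheme
agent_id = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID

# Insert a memory; the response carries the IDs of the created passages.
resp = requests.post(f"{base}/agents/{agent_id}/archival",
                     headers=headers, json={"content": "User prefers metric units."})
ids = resp.json()["ids"]

# Paginated read of the archival store.
page = requests.get(f"{base}/agents/{agent_id}/archival",
                    headers=headers, params={"limit": 10}).json()

# Delete one of the inserted memories by ID.
requests.delete(f"{base}/agents/{agent_id}/archival",
                headers=headers, params={"id": ids[0]})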
2 changes: 1 addition & 1 deletion memgpt/server/rest_api/tools/index.py
@@ -21,7 +21,7 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_tools(
async def list_all_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
44 changes: 44 additions & 0 deletions memgpt/server/server.py
@@ -829,6 +829,50 @@ def get_agent_archival_cursor(
json_records = [vars(record) for record in records]
return cursor, json_records

def get_all_archival_memories(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> list:
# TODO deprecate (not safe to be returning an unbounded list)
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")

# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

# Assume passages
records = memgpt_agent.persistence_manager.archival_memory.storage.get_all()
print("records:", records)

return [dict(id=str(r.id), contents=r.text) for r in records]

def insert_archival_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, memory_contents: str) -> uuid.UUID:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")

# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

# Insert into archival memory
# memory_id = uuid.uuid4()
passage_ids = memgpt_agent.persistence_manager.archival_memory.insert(memory_string=memory_contents, return_ids=True)

return [str(p_id) for p_id in passage_ids]

def delete_archival_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, memory_id: uuid.UUID):
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")

# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

# Delete by ID
# TODO check if it exists first, and throw error if not
memgpt_agent.persistence_manager.archival_memory.storage.delete({"id": memory_id})

def get_agent_recall_cursor(
self,
user_id: uuid.UUID,
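A hedged sketch of the three new SyncServer methods in use; the server instance and IDs are placeholders, and only the method names and keyword arguments come from this commit:

import uuid

# Assuming `server` is an already-constructed SyncServer and the IDs are valid.
ids = server.insert_archival_memory(
    user_id=user_id, agent_id=agent_id, memory_contents="User prefers metric units."
)

# Unbounded listing of every archival entry (the commit itself flags this for deprecation).
all_memories = server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)

# Delete by UUID; the REST handler converts the string ID before calling this.
server.delete_archival_memory(
    user_id=user_id, agent_id=agent_id, memory_id=uuid.UUID(ids[0])
)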
