Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: various breaking bugs with local LLM implementation and postgres docker. #1355

Merged
merged 11 commits into from
May 12, 2024
Empty file modified db/run_postgres.sh
100644 → 100755
Empty file.
33 changes: 31 additions & 2 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import os
import uuid
from datetime import datetime
from typing import Dict, Iterator, List, Optional

import numpy as np
Expand Down Expand Up @@ -379,29 +380,33 @@ def list_data_sources(self):
unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
return unique_data_sources

def query_date(self, start_date, end_date, offset=0, limit=None):
def query_date(self, start_date, end_date, limit=None, offset=0):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small question - is there any reason why this line is a diff? (args got swapped?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes what I reference in 2.1 above. The order of arguments is different from that of the callers, unless I'm misinterpreting the intent of what is supposed to be returned. Maybe a previous version used keywords, so the order didn't matter, but the current callers of these functions don't, so order does matter.

filters = self.get_filters({})
with self.session_maker() as session:
query = (
session.query(self.db_model)
.filter(*filters)
.filter(self.db_model.created_at >= start_date)
.filter(self.db_model.created_at <= end_date)
.filter(self.db_model.role != "system")
.filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
query = query.limit(limit)
results = query.all()
return [result.to_record() for result in results]

def query_text(self, query, offset=0, limit=None):
def query_text(self, query, limit=None, offset=0):
# todo: make fuzzy https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
filters = self.get_filters({})
with self.session_maker() as session:
query = (
session.query(self.db_model)
.filter(*filters)
.filter(func.lower(self.db_model.text).contains(func.lower(query)))
.filter(self.db_model.role != "system")
.filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
Expand Down Expand Up @@ -527,6 +532,30 @@ def update(self, record: RecordType):
# Commit the changes to the database
session.commit()

def str_to_datetime(self, str_date):
    """Convert a dash-separated ``YYYY-MM-DD`` string into a ``datetime``.

    Args:
        str_date: A date string whose first three dash-separated fields are
            year, month and day (zero padding is optional; any further
            fields are ignored), or an existing ``datetime`` instance.

    Returns:
        A naive ``datetime`` at 00:00:00 on the given day. A ``datetime``
        passed in is returned unchanged.

    Raises:
        ValueError: If fewer than three fields are present, a field is not
            an integer, or the fields do not form a valid calendar date.
    """
    # Pass already-parsed values straight through so callers that hold
    # datetime objects do not crash on ``.split``.
    if isinstance(str_date, datetime):
        return str_date

    fields = str_date.split("-")
    if len(fields) < 3:
        # Fail with a descriptive error rather than an opaque IndexError.
        raise ValueError(f"expected a date formatted as YYYY-MM-DD, got {str_date!r}")

    year, month, day = (int(field) for field in fields[:3])
    return datetime(year, month, day)

def query_date(self, start_date, end_date, limit=None, offset=0):
    """Return records whose ``created_at`` lies between two dates.

    Args:
        start_date: Lower bound, converted with ``self.str_to_datetime``.
        end_date: Upper bound, converted with ``self.str_to_datetime``.
        limit: Maximum number of rows; a falsy value (``None`` or ``0``)
            applies no limit.
        offset: Number of matching rows to skip.

    Returns:
        A list built by calling ``to_record()`` on every matching row. Rows
        whose role is ``"system"`` or ``"tool"`` are excluded.

    NOTE(review): both bounds are compared inclusively against the values
    returned by ``str_to_datetime``; if that helper yields midnight, rows
    created later on ``end_date`` are not matched - confirm this is intended.
    """
    base_filters = self.get_filters({})
    lower_bound = self.str_to_datetime(start_date)
    upper_bound = self.str_to_datetime(end_date)

    with self.session_maker() as session:
        stmt = session.query(self.db_model).filter(*base_filters)
        stmt = stmt.filter(self.db_model.created_at >= lower_bound)
        stmt = stmt.filter(self.db_model.created_at <= upper_bound)
        # Search results never include system or tool messages.
        stmt = stmt.filter(self.db_model.role != "system")
        stmt = stmt.filter(self.db_model.role != "tool")
        stmt = stmt.offset(offset)
        if limit:
            stmt = stmt.limit(limit)
        rows = stmt.all()

    return [row.to_record() for row in rows]


class SQLLiteStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
Expand Down
1 change: 0 additions & 1 deletion memgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def load(cls) -> "MemGPTConfig":
"config_path": config_path,
"memgpt_version": get_field(config, "version", "memgpt_version"),
}

# Don't include null values
config_dict = {k: v for k, v in config_dict.items() if v is not None}

Expand Down
5 changes: 5 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ def dict_to_message(
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)

def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
    """Build the compact dict used to present this message as a search result.

    Args:
        max_tool_id_length: Forwarded to ``to_openai_dict``.

    Returns:
        A dict of the form
        ``{"timestamp": self.created_at, "message": {"content": ..., "role": ...}}``
        where ``content`` and ``role`` are taken from ``to_openai_dict()``.
    """
    # Forward the caller's limit instead of silently ignoring the parameter.
    openai_dict = self.to_openai_dict(max_tool_id_length=max_tool_id_length)
    return {
        "timestamp": self.created_at,
        "message": {"content": openai_dict["content"], "role": openai_dict["role"]},
    }

def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
"""Go from Message class to ChatCompletion message object"""

Expand Down
4 changes: 2 additions & 2 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,12 @@ def get_all(self, start=0, count=None):

def text_search(self, query_string, count=None, start=None):
results = self.storage.query_text(query_string, count, start)
results_json = [message.to_openai_dict() for message in results]
results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)

def date_search(self, start_date, end_date, count=None, start=None):
results = self.storage.query_date(start_date, end_date, count, start)
results_json = [message.to_openai_dict() for message in results]
results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(
# Update storage URI to match passed in settings
# TODO: very hacky, fix in the future
for memory_type in ("archival", "recall", "metadata"):
if settings.memgpt_pg_uri:
if settings.memgpt_pg_uri_no_default:
# override with env
setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
self.config.save()
Expand Down
21 changes: 16 additions & 5 deletions memgpt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,27 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="memgpt_")

server_pass: Optional[str] = None
pg_db: Optional[str] = "memgpt"
pg_user: Optional[str] = "memgpt"
pg_password: Optional[str] = "memgpt"
pg_host: Optional[str] = "localhost"
pg_port: Optional[int] = 5432
pg_db: Optional[str] = None
pg_user: Optional[str] = None
pg_password: Optional[str] = None
pg_host: Optional[str] = None
pg_port: Optional[int] = None
pg_uri: Optional[str] = None # option to specify full uri
cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"]

@property
def memgpt_pg_uri(self) -> str:
    """Return the Postgres connection URI to use.

    Resolution order:
      1. ``pg_uri`` when it is set (truthy).
      2. A ``postgresql+pg8000://`` URI assembled from ``pg_user``,
         ``pg_password``, ``pg_host``, ``pg_port`` and ``pg_db`` when all
         five are truthy.
      3. Otherwise the hard-coded local default.
    """
    if self.pg_uri:
        return self.pg_uri

    components = (self.pg_db, self.pg_user, self.pg_password, self.pg_host, self.pg_port)
    if all(components):
        return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}"

    # Nothing usable was configured: fall back to the local default database.
    return "postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt"

# add this property to avoid being returned the default
# reference: https://github.com/cpacker/MemGPT/issues/1362
@property
def memgpt_pt_uri_no_default(self) -> str:
if self.pg_uri:
return self.pg_uri
elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port:
Expand Down
Loading