Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: various breaking bugs with local LLM implementation and postgres docker. #1355

Merged
merged 11 commits into from
May 12, 2024
Empty file modified db/run_postgres.sh
100644 → 100755
Empty file.
33 changes: 31 additions & 2 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import os
import uuid
from datetime import datetime
from typing import Dict, Iterator, List, Optional

import numpy as np
Expand Down Expand Up @@ -379,29 +380,33 @@ def list_data_sources(self):
unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
return unique_data_sources

def query_date(self, start_date, end_date, offset=0, limit=None):
def query_date(self, start_date, end_date, limit=None, offset=0):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small question - is there any reason why this line is a diff? (args got swapped?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes what I reference in 2.1 above. The order of arguments is different from that of the callers, unless I'm interpreting the intent of what is supposed to be returned. Maybe a previous version used the keywords and order didn't matter, but the current callers of these functions don't so order does matter.

filters = self.get_filters({})
with self.session_maker() as session:
query = (
session.query(self.db_model)
.filter(*filters)
.filter(self.db_model.created_at >= start_date)
.filter(self.db_model.created_at <= end_date)
.filter(self.db_model.role != "system")
.filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
query = query.limit(limit)
results = query.all()
return [result.to_record() for result in results]

def query_text(self, query, offset=0, limit=None):
def query_text(self, query, limit=None, offset=0):
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
filters = self.get_filters({})
with self.session_maker() as session:
query = (
session.query(self.db_model)
.filter(*filters)
.filter(func.lower(self.db_model.text).contains(func.lower(query)))
.filter(self.db_model.role != "system")
.filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
Expand Down Expand Up @@ -527,6 +532,30 @@ def update(self, record: RecordType):
# Commit the changes to the database
session.commit()

def str_to_datetime(self, str_date):
val = str_date.split("-")
_datetime = datetime(int(val[0]), int(val[1]), int(val[2]))
return _datetime

def query_date(self, start_date, end_date, limit=None, offset=0):
filters = self.get_filters({})
_start_date = self.str_to_datetime(start_date)
_end_date = self.str_to_datetime(end_date)
with self.session_maker() as session:
query = (
session.query(self.db_model)
.filter(*filters)
.filter(self.db_model.created_at >= _start_date)
.filter(self.db_model.created_at <= _end_date)
.filter(self.db_model.role != "system")
.filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
query = query.limit(limit)
results = query.all()
return [result.to_record() for result in results]


class SQLLiteStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
Expand Down
5 changes: 5 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ def dict_to_message(
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)

def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
result_json = self.to_openai_dict()
search_result_json = {"timestamp": self.created_at, "message": {"content": result_json["content"], "role": result_json["role"]}}
return search_result_json

def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
"""Go from Message class to ChatCompletion message object"""

Expand Down
4 changes: 2 additions & 2 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,12 @@ def get_all(self, start=0, count=None):

def text_search(self, query_string, count=None, start=None):
results = self.storage.query_text(query_string, count, start)
results_json = [message.to_openai_dict() for message in results]
results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)

def date_search(self, start_date, end_date, count=None, start=None):
results = self.storage.query_date(start_date, end_date, count, start)
results_json = [message.to_openai_dict() for message in results]
results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)

def __repr__(self) -> str:
Expand Down