From c2998bd2fb29f4105dfd9c48d98633e349c90d04 Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Tue, 13 Feb 2024 10:36:03 -0800 Subject: [PATCH 1/2] fix: fix linter and encoding of MemorystoreChatMessageHistory class. --- .../__init__.py | 5 ++++- .../chat_message_history.py | 15 +++++++-------- tests/test_chat_message_history.py | 8 ++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index da471a5..b9d1db6 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from langchain_google_memorystore_redis.chat_message_history import MemorystoreChatMessageHistory +from langchain_google_memorystore_redis.chat_message_history import ( + MemorystoreChatMessageHistory, +) + from .vector_store import FLATConfig, HNSWConfig, RedisVectorStore __all__ = ["MemorystoreChatMessageHistory"] diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py index 2662230..b2c80cf 100644 --- a/src/langchain_google_memorystore_redis/chat_message_history.py +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -13,15 +13,11 @@ # limitations under the License. import json -import redis from typing import List, Optional +import redis from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - message_to_dict, - messages_from_dict, -) +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict class MemorystoreChatMessageHistory(BaseChatMessageHistory): @@ -46,13 +42,16 @@ def __init__( self._redis = client self._key = session_id self._ttl = ttl + self._encoder = client.connection_pool.get_encoder() @property - def messages(self) -> List[BaseMessage]: + def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages chronologically stored in this session.""" all_elements = self._redis.lrange(self._key, 0, -1) + + assert isinstance(all_elements, list) messages = messages_from_dict( - [json.loads(e.decode("utf-8")) for e in all_elements] + [json.loads(self._encoder.decode(e)) for e in all_elements] ) return messages diff --git a/tests/test_chat_message_history.py b/tests/test_chat_message_history.py index 5094841..7a0448a 100644 --- a/tests/test_chat_message_history.py +++ b/tests/test_chat_message_history.py @@ -15,15 +15,15 @@ import os import uuid + +import redis from langchain_core.messages import AIMessage, BaseMessage, HumanMessage + from langchain_google_memorystore_redis import MemorystoreChatMessageHistory -import redis def test_redis_multiple_sessions() -> None: - client = redis.from_url( - get_env_var("REDIS_URL", "URL of the Redis instance") - ) + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) session_id1 = uuid.uuid4().hex history1 = MemorystoreChatMessageHistory( From 077c65d7a6e05192e6f42db3ada2af2528d5345e Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Tue, 13 Feb 2024 11:03:12 -0800 Subject: [PATCH 2/2] fix: use relative path for __init__ imports --- src/langchain_google_memorystore_redis/__init__.py | 5 +---- src/langchain_google_memorystore_redis/vector_store.py | 2 -- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index b9d1db6..830a5f1 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from langchain_google_memorystore_redis.chat_message_history import ( - MemorystoreChatMessageHistory, -) - +from .chat_message_history import MemorystoreChatMessageHistory from .vector_store import FLATConfig, HNSWConfig, RedisVectorStore __all__ = ["MemorystoreChatMessageHistory"] diff --git a/src/langchain_google_memorystore_redis/vector_store.py b/src/langchain_google_memorystore_redis/vector_store.py index 7e5d35a..468b08e 100644 --- a/src/langchain_google_memorystore_redis/vector_store.py +++ b/src/langchain_google_memorystore_redis/vector_store.py @@ -227,7 +227,6 @@ def __init__( class RedisVectorStore(VectorStore): - DEFAULT_CONTENT_FIELD = "page_content" DEFAULT_VECTOR_FIELD = "vector" DEFAULT_DATA_TYPE = "float32" @@ -503,7 +502,6 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[boo def _similarity_search_by_vector_with_score_and_embeddings( self, query_embedding: List[float], k: int = 4, **kwargs: Any ) -> List[Tuple[Document, float, List[float]]]: - distance_threshold = kwargs.get( "distance_threshold", kwargs.get("score_threshold") )