From f02dc5250e06779d6a6bb7ed856cd630b02f9780 Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Mon, 12 Feb 2024 14:50:55 -0800 Subject: [PATCH 1/5] feat: add integration test for MemorystoreChatMessageHistory (#13) --- integration.cloudbuild.yaml | 6 +++ tests/test_chat_message_history.py | 67 ++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 tests/test_chat_message_history.py diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index c78148e..3edce93 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -22,3 +22,9 @@ steps: name: python:3.11 entrypoint: python args: ["-m", "pytest"] + env: + - 'REDIS_URL=$_REDIS_URL' + +options: + pool: + name: 'projects/$PROJECT_ID/locations/$LOCATION/workerPools/redis' diff --git a/tests/test_chat_message_history.py b/tests/test_chat_message_history.py new file mode 100644 index 0000000..5094841 --- /dev/null +++ b/tests/test_chat_message_history.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import uuid +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_google_memorystore_redis import MemorystoreChatMessageHistory +import redis + + +def test_redis_multiple_sessions() -> None: + client = redis.from_url( + get_env_var("REDIS_URL", "URL of the Redis instance") + ) + + session_id1 = uuid.uuid4().hex + history1 = MemorystoreChatMessageHistory( + client=client, + session_id=session_id1, + ) + session_id2 = uuid.uuid4().hex + history2 = MemorystoreChatMessageHistory( + client=client, + session_id=session_id2, + ) + + history1.add_ai_message("Hey! I am AI!") + history1.add_user_message("Hey! I am human!") + history2.add_user_message("Hey! I am human in another session!") + messages1 = history1.messages + messages2 = history2.messages + + assert len(messages1) == 2 + assert len(messages2) == 1 + assert isinstance(messages1[0], AIMessage) + assert messages1[0].content == "Hey! I am AI!" + assert isinstance(messages1[1], HumanMessage) + assert messages1[1].content == "Hey! I am human!" + assert isinstance(messages2[0], HumanMessage) + assert messages2[0].content == "Hey! I am human in another session!" + + history1.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 1 + + history2.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 0 + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v From bee553b0111cf3b3ee3d38e697d14dbc07ea12da Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Wed, 14 Feb 2024 11:19:27 -0800 Subject: [PATCH 2/5] fix: fix encoding of the Redis client --- .../chat_message_history.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py index 2662230..8f70f37 100644 --- a/src/langchain_google_memorystore_redis/chat_message_history.py +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -46,15 +46,18 @@ def __init__( self._redis = client self._key = session_id self._ttl = ttl + self._encoding = client.get_encoder().encoding @property def messages(self) -> List[BaseMessage]: """Retrieve all messages chronologically stored in this session.""" all_elements = self._redis.lrange(self._key, 0, -1) - messages = messages_from_dict( - [json.loads(e.decode("utf-8")) for e in all_elements] + + assert isinstance(all_elements, list) + loaded_messages = messages_from_dict( + [json.loads(e.decode(self._encoding)) for e in all_elements] ) - return messages + return loaded_messages def add_message(self, message: BaseMessage) -> None: """Append one message to this session.""" From ea4b695e4c610377523e86a5479492567e356908 Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Wed, 14 Feb 2024 11:21:18 -0800 Subject: [PATCH 3/5] fix: use if statement instead of assert --- .../chat_message_history.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py index 8f70f37..0f0c252 100644 --- a/src/langchain_google_memorystore_redis/chat_message_history.py +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -53,10 +53,13 @@ def messages(self) -> List[BaseMessage]: """Retrieve all messages chronologically stored in this session.""" all_elements = self._redis.lrange(self._key, 0, -1) - assert isinstance(all_elements, list) - loaded_messages = messages_from_dict( - [json.loads(e.decode(self._encoding)) for e in all_elements] - ) + loaded_messages = [] + if isinstance(all_elements, list): + loaded_messages = messages_from_dict( + [json.loads(e.decode(self._encoding)) for e in all_elements] + ) + else: + raise RuntimeError("redis-py returns result of unexpected type") return loaded_messages def add_message(self, message: BaseMessage) -> None: From 9e908be35e27f4dc2d87292bcfaac4499d437244 Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Wed, 14 Feb 2024 13:01:45 -0800 Subject: [PATCH 4/5] fix: fix worker pool name for integration --- integration.cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 3edce93..62d8d22 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -27,4 +27,4 @@ steps: options: pool: - name: 'projects/$PROJECT_ID/locations/$LOCATION/workerPools/redis' + name: '$_WORKER_POOL' From c8b5f0f8ab1d0e894815fa3cef0a0eaacd422aa8 Mon Sep 17 00:00:00 2001 From: Craig Chi Date: Wed, 14 Feb 2024 13:15:14 -0800 Subject: [PATCH 5/5] fix: fix linter --- src/langchain_google_memorystore_redis/__init__.py | 4 +++- .../chat_message_history.py | 10 +++------- tests/test_chat_message_history.py | 8 ++++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index 14e510c..5e4d1eb 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from langchain_google_memorystore_redis.chat_message_history import MemorystoreChatMessageHistory +from langchain_google_memorystore_redis.chat_message_history import ( + MemorystoreChatMessageHistory, +) __all__ = ["MemorystoreChatMessageHistory"] diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py index 0f0c252..bc5949c 100644 --- a/src/langchain_google_memorystore_redis/chat_message_history.py +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -13,15 +13,11 @@ # limitations under the License. import json -import redis from typing import List, Optional +import redis from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - message_to_dict, - messages_from_dict, -) +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict class MemorystoreChatMessageHistory(BaseChatMessageHistory): @@ -49,7 +45,7 @@ def __init__( self._encoding = client.get_encoder().encoding @property - def messages(self) -> List[BaseMessage]: + def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages chronologically stored in this session.""" all_elements = self._redis.lrange(self._key, 0, -1) diff --git a/tests/test_chat_message_history.py b/tests/test_chat_message_history.py index 5094841..7a0448a 100644 --- a/tests/test_chat_message_history.py +++ b/tests/test_chat_message_history.py @@ -15,15 +15,15 @@ import os import uuid + +import redis from langchain_core.messages import AIMessage, BaseMessage, HumanMessage + from langchain_google_memorystore_redis import MemorystoreChatMessageHistory -import redis def test_redis_multiple_sessions() -> None: - client = redis.from_url( - get_env_var("REDIS_URL", "URL of the Redis instance") - ) + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) session_id1 = uuid.uuid4().hex history1 = MemorystoreChatMessageHistory(