diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index c78148e..62d8d22 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -22,3 +22,9 @@ steps: name: python:3.11 entrypoint: python args: ["-m", "pytest"] + env: + - 'REDIS_URL=$_REDIS_URL' + +options: + pool: + name: '$_WORKER_POOL' diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index 14e510c..5e4d1eb 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from langchain_google_memorystore_redis.chat_message_history import MemorystoreChatMessageHistory +from langchain_google_memorystore_redis.chat_message_history import ( + MemorystoreChatMessageHistory, +) __all__ = ["MemorystoreChatMessageHistory"] diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py index 2662230..bc5949c 100644 --- a/src/langchain_google_memorystore_redis/chat_message_history.py +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -13,15 +13,11 @@ # limitations under the License. import json -import redis from typing import List, Optional +import redis from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - message_to_dict, - messages_from_dict, -) +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict class MemorystoreChatMessageHistory(BaseChatMessageHistory): @@ -46,15 +42,21 @@ def __init__( self._redis = client self._key = session_id self._ttl = ttl + self._encoding = client.get_encoder().encoding @property - def messages(self) -> List[BaseMessage]: + def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages chronologically stored in this session.""" all_elements = self._redis.lrange(self._key, 0, -1) - messages = messages_from_dict( - [json.loads(e.decode("utf-8")) for e in all_elements] - ) - return messages + + loaded_messages = [] + if isinstance(all_elements, list): + loaded_messages = messages_from_dict( + [json.loads(e.decode(self._encoding)) for e in all_elements] + ) + else: + raise RuntimeError("redis-py returns result of unexpected type") + return loaded_messages def add_message(self, message: BaseMessage) -> None: """Append one message to this session.""" diff --git a/tests/test_chat_message_history.py b/tests/test_chat_message_history.py new file mode 100644 index 0000000..7a0448a --- /dev/null +++ b/tests/test_chat_message_history.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import uuid + +import redis +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage + +from langchain_google_memorystore_redis import MemorystoreChatMessageHistory + + +def test_redis_multiple_sessions() -> None: + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) + + session_id1 = uuid.uuid4().hex + history1 = MemorystoreChatMessageHistory( + client=client, + session_id=session_id1, + ) + session_id2 = uuid.uuid4().hex + history2 = MemorystoreChatMessageHistory( + client=client, + session_id=session_id2, + ) + + history1.add_ai_message("Hey! I am AI!") + history1.add_user_message("Hey! I am human!") + history2.add_user_message("Hey! I am human in another session!") + messages1 = history1.messages + messages2 = history2.messages + + assert len(messages1) == 2 + assert len(messages2) == 1 + assert isinstance(messages1[0], AIMessage) + assert messages1[0].content == "Hey! I am AI!" + assert isinstance(messages1[1], HumanMessage) + assert messages1[1].content == "Hey! I am human!" + assert isinstance(messages2[0], HumanMessage) + assert messages2[0].content == "Hey! I am human in another session!" + + history1.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 1 + + history2.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 0 + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v