diff --git a/pyproject.toml b/pyproject.toml index cd4e1ac..2017123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ license = {file = "LICENSE"} requires-python = ">=3.8" dependencies = [ "langchain==0.1.1", + "redis>=5.0.0", ] [project.urls] @@ -37,4 +38,4 @@ profile = "black" [tool.mypy] python_version = 3.8 -warn_unused_configs = true \ No newline at end of file +warn_unused_configs = true diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index 6d5e14b..14e510c 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from langchain_google_memorystore_redis.chat_message_history import MemorystoreChatMessageHistory + +__all__ = ["MemorystoreChatMessageHistory"] diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py new file mode 100644 index 0000000..2662230 --- /dev/null +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import redis +from typing import List, Optional + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import ( + BaseMessage, + message_to_dict, + messages_from_dict, +) + + +class MemorystoreChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a Cloud Memorystore for Redis database.""" + + def __init__( + self, + client: redis.Redis, + session_id: str, + ttl: Optional[int] = None, + ): + """Initializes the chat message history for Memorystore for Redis. + + Args: + client: A redis.Redis client object. + session_id: A string that uniquely identifies the chat history. + ttl: Specifies the time in seconds after which the session will + expire and be eliminated from the Redis instance since the most + recent message is added. + """ + + self._redis = client + self._key = session_id + self._ttl = ttl + + @property + def messages(self) -> List[BaseMessage]: + """Retrieve all messages chronologically stored in this session.""" + all_elements = self._redis.lrange(self._key, 0, -1) + messages = messages_from_dict( + [json.loads(e.decode("utf-8")) for e in all_elements] + ) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Append one message to this session.""" + self._redis.rpush(self._key, json.dumps(message_to_dict(message))) + if self._ttl: + self._redis.expire(self._key, self._ttl) + + def clear(self) -> None: + """Clear all messages in this session.""" + self._redis.delete(self._key)