Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/langchain_google_memorystore_redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from langchain_google_memorystore_redis.chat_message_history import MemorystoreChatMessageHistory
from .chat_message_history import MemorystoreChatMessageHistory
from .vector_store import FLATConfig, HNSWConfig, RedisVectorStore

__all__ = ["MemorystoreChatMessageHistory"]
15 changes: 7 additions & 8 deletions src/langchain_google_memorystore_redis/chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
# limitations under the License.

import json
import redis
from typing import List, Optional

import redis
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict


class MemorystoreChatMessageHistory(BaseChatMessageHistory):
Expand All @@ -46,13 +42,16 @@ def __init__(
self._redis = client
self._key = session_id
self._ttl = ttl
self._encoder = client.connection_pool.get_encoder()

@property
def messages(self) -> List[BaseMessage]:
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages chronologically stored in this session."""
all_elements = self._redis.lrange(self._key, 0, -1)

assert isinstance(all_elements, list)
messages = messages_from_dict(
[json.loads(e.decode("utf-8")) for e in all_elements]
[json.loads(self._encoder.decode(e)) for e in all_elements]
)
return messages

Expand Down
2 changes: 0 additions & 2 deletions src/langchain_google_memorystore_redis/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def __init__(


class RedisVectorStore(VectorStore):

DEFAULT_CONTENT_FIELD = "page_content"
DEFAULT_VECTOR_FIELD = "vector"
DEFAULT_DATA_TYPE = "float32"
Expand Down Expand Up @@ -503,7 +502,6 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[boo
def _similarity_search_by_vector_with_score_and_embeddings(
self, query_embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float, List[float]]]:

distance_threshold = kwargs.get(
"distance_threshold", kwargs.get("score_threshold")
)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import os
import uuid

import redis
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

from langchain_google_memorystore_redis import MemorystoreChatMessageHistory
import redis


def test_redis_multiple_sessions() -> None:
client = redis.from_url(
get_env_var("REDIS_URL", "URL of the Redis instance")
)
client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance"))

session_id1 = uuid.uuid4().hex
history1 = MemorystoreChatMessageHistory(
Expand Down