33 changes: 33 additions & 0 deletions samples/langchain_quick_start.ipynb
@@ -0,0 +1,33 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
99 changes: 99 additions & 0 deletions src/langchain_google_memorystore_redis/doc_loader.py
@@ -0,0 +1,99 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Iterator, List, Optional, Set, Union

import redis
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents.base import Document


class MemorystoreDocumentLoader(BaseLoader):
"""Document Loader for Cloud Memorystore for Redis database."""

def __init__(
self,
client: redis.Redis,
key_prefix: str,
content_fields: Set[str],
metadata_fields: Optional[Set[str]] = None,
):
"""Initializes the Document Loader for Memorystore for Redis.

Args:
client: A redis.Redis client object.
key_prefix: A prefix for the keys to store Documents in Redis.
            content_fields: The set of hash fields that hold the page_content
                of the Document. If more than one field is specified, the
                page_content is a JSON-encoded dict with the specified fields
                as top-level keys.
            metadata_fields: The hash fields to load into the Document
                metadata. If None, all fields other than the content fields
                are loaded as metadata.
"""

self._redis = client
self._content_fields = content_fields
self._metadata_fields = metadata_fields
        overlap = content_fields & (metadata_fields or set())
        if overlap:
            raise ValueError(
                f"Fields {overlap} are specified in both content_fields and"
                " metadata_fields."
            )
self._key_prefix = key_prefix if key_prefix else ""
self._encoding = client.get_encoder().encoding

def lazy_load(self) -> Iterator[Document]:
"""Lazy load the Documents and yield them one by one."""
for key in self._redis.scan_iter(match=f"{self._key_prefix}*", _type="HASH"):
doc = {}
stored_value = self._redis.hgetall(key)
if not isinstance(stored_value, dict):
                raise RuntimeError(f"Key {key} returned an unexpected value: {stored_value}")
decoded_value = {
k.decode(self._encoding): v.decode(self._encoding)
for k, v in stored_value.items()
}

if len(self._content_fields) == 1:
doc["page_content"] = decoded_value[next(iter(self._content_fields))]
else:
doc["page_content"] = json.dumps(
{k: decoded_value[k] for k in self._content_fields}
)

filtered_fields = (
self._metadata_fields if self._metadata_fields else decoded_value.keys()
)
filtered_fields = filtered_fields - self._content_fields
doc["metadata"] = {
k: self._decode_if_json_parsable(decoded_value[k])
for k in filtered_fields
}

yield Document.construct(**doc)

def load(self) -> List[Document]:
"""Load all Documents at once."""
return list(self.lazy_load())

@staticmethod
def _decode_if_json_parsable(s: str) -> Union[str, dict]:
"""Decode a JSON string to a dict if it is JSON."""
try:
decoded = json.loads(s)
return decoded
except ValueError:
pass
return s
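
A minimal usage sketch for the loader (illustrative, not part of this PR; the Redis URL, key prefix, and field name are assumptions):

import redis
from langchain_google_memorystore_redis.doc_loader import MemorystoreDocumentLoader

# Hypothetical local Memorystore for Redis endpoint.
client = redis.from_url("redis://localhost:6379")
loader = MemorystoreDocumentLoader(
    client=client,
    key_prefix="docs:",
    content_fields={"page_content"},
)
# lazy_load() streams Documents one at a time instead of materializing all of them.
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)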
113 changes: 113 additions & 0 deletions src/langchain_google_memorystore_redis/doc_saver.py
@@ -0,0 +1,113 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import uuid
from typing import Optional, Sequence, Set, Union

import redis
from langchain_core.documents.base import Document


class MemorystoreDocumentSaver:
"""Document Saver for Cloud Memorystore for Redis database."""

def __init__(
self,
client: redis.Redis,
key_prefix: str,
content_field: str,
metadata_fields: Optional[Set[str]] = None,
):
"""Initializes the Document Saver for Memorystore for Redis.

Args:
client: A redis.Redis client object.
key_prefix: A prefix for the keys to store Documents in Redis.
content_field: The field of the hash that Redis uses to store the
page_content of the Document.
            metadata_fields: The metadata fields of the Document to store in
                Redis. If None, all metadata fields are stored.
"""

self._redis = client
if not key_prefix:
raise ValueError("key_prefix must not be empty")
self._key_prefix = key_prefix
self._content_field = content_field
self._metadata_fields = metadata_fields

def add_documents(
self,
documents: Sequence[Document],
ids: Optional[Sequence[str]] = None,
batch_size: int = 1000,
) -> None:
"""Save a list of Documents to Redis.

Args:
documents: A List of Documents.
            ids: The list of key suffixes under which the Documents are stored
                in Redis. If specified, its length must match the number of
                Documents. If not specified, a random UUID appended to the key
                prefix is used for each Document.
batch_size: The number of documents to process in a single batch
operation. This parameter helps manage memory and performance
when adding a large number of documents. Defaults to 1000.
"""
if ids and len(documents) != len(ids):
raise ValueError("The length of documents must match the length of the IDs")
if batch_size <= 0:
raise ValueError("batch_size must be greater than 0")

doc_ids = ids if ids else [str(uuid.uuid4()) for _ in documents]
doc_ids = [self._key_prefix + doc_id for doc_id in doc_ids]

pipeline = self._redis.pipeline(transaction=False)
for i, doc in enumerate(documents):
mapping = self._filter_metadata_by_fields(doc.metadata)
mapping.update({self._content_field: doc.page_content})

# Remove existing key in Redis to avoid reusing the doc ID.
pipeline.delete(doc_ids[i])
pipeline.hset(doc_ids[i], mapping=mapping)
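            # Flush the pipeline once per full batch and after the final document.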
if (i + 1) % batch_size == 0 or i == len(documents) - 1:
pipeline.execute()

def _filter_metadata_by_fields(self, metadata: Optional[dict]) -> dict:
"""Filter metadata fields to be stored in Redis.

Args:
metadata: The metadata field of a Document object.

Returns:
dict: A subset dict of the metadata that only contains the fields
specified in the initialization of the saver. The value of each
metadata key is serialized by JSON if it is a dict.
"""
if not metadata:
return {}
filtered_fields = (
self._metadata_fields & metadata.keys()
if self._metadata_fields
else metadata.keys()
)
filtered_metadata = {
k: self._jsonify_if_dict(metadata[k]) for k in filtered_fields
}
return filtered_metadata

@staticmethod
def _jsonify_if_dict(s: Union[str, dict]) -> str:
return s if isinstance(s, str) else json.dumps(s)
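
A minimal usage sketch for the saver (illustrative, not part of this PR; the URL and values are assumptions):

import redis
from langchain_core.documents.base import Document
from langchain_google_memorystore_redis.doc_saver import MemorystoreDocumentSaver

client = redis.from_url("redis://localhost:6379")  # hypothetical instance
saver = MemorystoreDocumentSaver(
    client=client,
    key_prefix="docs:",
    content_field="page_content",
)
# Stored as a Redis hash under the key "docs:intro"; dict metadata values
# are JSON-encoded by _jsonify_if_dict before being written.
saver.add_documents(
    [Document.construct(page_content="hello", metadata={"lang": "en"})],
    ids=["intro"],
)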
123 changes: 123 additions & 0 deletions tests/test_doc_loader.py
@@ -0,0 +1,123 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os

import pytest
import redis
from langchain_core.documents.base import Document

from langchain_google_memorystore_redis.doc_loader import MemorystoreDocumentLoader
from langchain_google_memorystore_redis.doc_saver import MemorystoreDocumentSaver


@pytest.mark.parametrize(
"page_content,metadata,content_field,metadata_fields",
[
(
'"content1"',
{"key1": "doc1_value1", "key2": "doc1_value2"},
"page_content",
None,
),
(
'"content2"',
{"key1": {'"nested_key"': {'"double_nested"': '"doc2_value1"'}}},
"special_page_content",
None,
),
(
'"content3"',
{"key1": {"k": "not_in_filter"}, "key2": {'"key"': "in_filter"}},
"page_content",
set(["key2"]),
),
],
)
def test_doc_loader_one_doc(page_content, metadata, content_field, metadata_fields):
client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance"))

prefix = "prefix:"
saver = MemorystoreDocumentSaver(
client=client,
key_prefix=prefix,
content_field=content_field,
metadata_fields=metadata_fields,
)
doc = Document.construct(page_content=page_content, metadata=metadata)
doc_id = "saved_doc"
saver.add_documents([doc], [doc_id])

loader = MemorystoreDocumentLoader(
client=client,
key_prefix=prefix,
content_fields=set([content_field]),
metadata_fields=metadata_fields,
)
loaded_docs = loader.load()
expected_doc = (
doc
if not metadata_fields
else Document.construct(
page_content=page_content,
metadata={k: metadata[k] for k in metadata_fields},
)
)
assert loaded_docs == [expected_doc]
client.delete(prefix + doc_id)


def test_doc_loader_multiple_docs():
client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance"))

prefix = "multidocs:"
# Clean up stored documents with the same prefix
for key in client.keys(f"{prefix}*"):
client.delete(key)

content_field = "page_content"
saver = MemorystoreDocumentSaver(
client=client,
key_prefix=prefix,
content_field=content_field,
)
docs = []
for content in range(10):
docs.append(
Document.construct(
page_content=f"{content}",
metadata={"metadata": f"meta: {content}"},
)
)

saver.add_documents(docs)

loader = MemorystoreDocumentLoader(
client=client,
key_prefix=prefix,
content_fields=set([content_field]),
)
    loaded_docs = list(loader.lazy_load())
assert sorted(loaded_docs, key=lambda d: d.page_content) == docs


def get_env_var(key: str, desc: str) -> str:
v = os.environ.get(key)
if v is None:
raise ValueError(f"Must set env var {key} to: {desc}")
return v
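
These are integration tests against a live Redis instance; as get_env_var above shows, they read the endpoint from the REDIS_URL environment variable. A sketch of the setup (the URL is an assumed local test instance):

import os

# Point the tests at a (hypothetical) local Redis before invoking pytest.
os.environ["REDIS_URL"] = "redis://localhost:6379"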