diff --git a/samples/langchain_quick_start.ipynb b/samples/langchain_quick_start.ipynb new file mode 100644 index 0000000..b0e2c99 --- /dev/null +++ b/samples/langchain_quick_start.ipynb @@ -0,0 +1,33 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2024 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/langchain_google_memorystore_redis/doc_loader.py b/src/langchain_google_memorystore_redis/doc_loader.py new file mode 100644 index 0000000..6097978 --- /dev/null +++ b/src/langchain_google_memorystore_redis/doc_loader.py @@ -0,0 +1,99 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import json
from typing import Iterator, List, Optional, Sequence, Set, Union

import redis
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents.base import Document


class MemorystoreDocumentLoader(BaseLoader):
    """Document Loader for Cloud Memorystore for Redis database."""

    def __init__(
        self,
        client: redis.Redis,
        key_prefix: str,
        content_fields: Set[str],
        metadata_fields: Optional[Set[str]] = None,
    ):
        """Create a loader that reads Documents out of Redis hashes.

        Args:
            client: A redis.Redis client used for all reads.
            key_prefix: Only hashes whose keys start with this prefix are
                loaded. An empty/falsy prefix matches every hash.
            content_fields: Hash field(s) that make up the page_content of
                each Document. With a single field its value is used
                verbatim; with several fields, page_content becomes a JSON
                object keyed by field name.
            metadata_fields: Hash fields surfaced as Document metadata.
                When None, every field outside content_fields is used.

        Raises:
            ValueError: If content_fields and metadata_fields overlap.
        """
        overlap = content_fields & metadata_fields if metadata_fields else set()
        if overlap:
            raise ValueError(
                "Fields {} are specified in both content_fields and"
                " metadata_fields.".format(overlap)
            )
        self._redis = client
        self._content_fields = content_fields
        self._metadata_fields = metadata_fields
        self._key_prefix = key_prefix if key_prefix else ""
        # Mirror the client's configured encoding when decoding raw bytes.
        self._encoding = client.get_encoder().encoding

    def lazy_load(self) -> Iterator[Document]:
        """Yield one Document per matching Redis hash, fetched on demand."""
        for key in self._redis.scan_iter(match=f"{self._key_prefix}*", _type="HASH"):
            raw = self._redis.hgetall(key)
            if not isinstance(raw, dict):
                raise RuntimeError(f"{key} returns unexpected {raw}")
            fields = {
                name.decode(self._encoding): value.decode(self._encoding)
                for name, value in raw.items()
            }
            yield Document.construct(
                page_content=self._extract_content(fields),
                metadata=self._extract_metadata(fields),
            )

    def _extract_content(self, fields: dict) -> str:
        """Build page_content from decoded hash fields.

        A single content field yields its raw value; multiple content fields
        are bundled into a JSON object keyed by field name.
        """
        if len(self._content_fields) == 1:
            (only_field,) = self._content_fields
            return fields[only_field]
        return json.dumps({f: fields[f] for f in self._content_fields})

    def _extract_metadata(self, fields: dict) -> dict:
        """Build the metadata dict from decoded hash fields.

        Uses the configured metadata_fields when given, otherwise every
        field; content fields are always excluded. Values that parse as
        JSON are decoded back into Python objects.
        """
        selected = self._metadata_fields if self._metadata_fields else fields.keys()
        selected = selected - self._content_fields
        return {f: self._decode_if_json_parsable(fields[f]) for f in selected}

    def load(self) -> List[Document]:
        """Eagerly load every matching Document into a list."""
        return list(self.lazy_load())

    @staticmethod
    def _decode_if_json_parsable(s: str) -> Union[str, dict]:
        """Return json.loads(s) when s is valid JSON, else s unchanged."""
        try:
            return json.loads(s)
        except ValueError:
            return s
import json
import uuid
from typing import Optional, Sequence, Set, Union

import redis
from langchain_core.documents.base import Document


class MemorystoreDocumentSaver:
    """Document Saver for Cloud Memorystore for Redis database."""

    def __init__(
        self,
        client: redis.Redis,
        key_prefix: str,
        content_field: str,
        metadata_fields: Optional[Set[str]] = None,
    ):
        """Create a saver that writes Documents into Redis hashes.

        Args:
            client: A redis.Redis client used for all writes.
            key_prefix: Non-empty prefix prepended to every stored key.
            content_field: Hash field that receives the page_content.
            metadata_fields: Metadata keys to persist; when None, all
                metadata entries of a Document are stored.

        Raises:
            ValueError: If key_prefix is empty.
        """
        if not key_prefix:
            raise ValueError("key_prefix must not be empty")
        self._redis = client
        self._key_prefix = key_prefix
        self._content_field = content_field
        self._metadata_fields = metadata_fields

    def add_documents(
        self,
        documents: Sequence[Document],
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 1000,
    ) -> None:
        """Save a list of Documents to Redis.

        Args:
            documents: The Documents to persist.
            ids: Key suffixes (appended to the prefix) for each Document.
                Must match documents in length when given; otherwise random
                UUID suffixes are generated.
            batch_size: Number of documents sent per pipeline flush; keeps
                memory bounded for large inputs. Defaults to 1000.

        Raises:
            ValueError: If ids length mismatches documents, or batch_size
                is not positive.
        """
        if ids and len(documents) != len(ids):
            raise ValueError("The length of documents must match the length of the IDs")
        if batch_size <= 0:
            raise ValueError("batch_size must be greater than 0")

        suffixes = ids if ids else [str(uuid.uuid4()) for _ in documents]
        keys = [self._key_prefix + suffix for suffix in suffixes]

        last = len(documents) - 1
        pipe = self._redis.pipeline(transaction=False)
        for index, doc in enumerate(documents):
            fields = self._filter_metadata_by_fields(doc.metadata)
            fields[self._content_field] = doc.page_content

            # Drop any stale hash under this key before writing the new one.
            pipe.delete(keys[index])
            pipe.hset(keys[index], mapping=fields)
            if index == last or (index + 1) % batch_size == 0:
                pipe.execute()

    def _filter_metadata_by_fields(self, metadata: Optional[dict]) -> dict:
        """Select (and serialize) the metadata entries to be stored.

        Args:
            metadata: The metadata dict of a Document object.

        Returns:
            dict: The subset of metadata restricted to the configured
            metadata_fields (or all entries when none were configured).
            Dict values are serialized to JSON strings.
        """
        if not metadata:
            return {}
        if self._metadata_fields:
            keep = self._metadata_fields & metadata.keys()
        else:
            keep = metadata.keys()
        return {field: self._jsonify_if_dict(metadata[field]) for field in keep}

    @staticmethod
    def _jsonify_if_dict(s: Union[str, dict]) -> str:
        """Pass strings through unchanged; JSON-encode anything else."""
        if isinstance(s, str):
            return s
        return json.dumps(s)
import json
import os

import pytest
import redis
from langchain_core.documents.base import Document

from langchain_google_memorystore_redis.doc_loader import MemorystoreDocumentLoader
from langchain_google_memorystore_redis.doc_saver import MemorystoreDocumentSaver


@pytest.mark.parametrize(
    "page_content,metadata,content_field,metadata_fields",
    [
        (
            '"content1"',
            {"key1": "doc1_value1", "key2": "doc1_value2"},
            "page_content",
            None,
        ),
        (
            '"content2"',
            {"key1": {'"nested_key"': {'"double_nested"': '"doc2_value1"'}}},
            "special_page_content",
            None,
        ),
        (
            '"content3"',
            {"key1": {"k": "not_in_filter"}, "key2": {'"key"': "in_filter"}},
            "page_content",
            set(["key2"]),
        ),
    ],
)
def test_doc_loader_one_doc(page_content, metadata, content_field, metadata_fields):
    """Round-trip a single Document through saver then loader."""
    client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance"))

    prefix = "prefix:"
    saver = MemorystoreDocumentSaver(
        client=client,
        key_prefix=prefix,
        content_field=content_field,
        metadata_fields=metadata_fields,
    )
    doc = Document.construct(page_content=page_content, metadata=metadata)
    doc_id = "saved_doc"
    saver.add_documents([doc], [doc_id])

    loader = MemorystoreDocumentLoader(
        client=client,
        key_prefix=prefix,
        content_fields={content_field},
        metadata_fields=metadata_fields,
    )
    # When a metadata filter is set, only the filtered keys survive the trip.
    if metadata_fields:
        expected = Document.construct(
            page_content=page_content,
            metadata={field: metadata[field] for field in metadata_fields},
        )
    else:
        expected = doc
    assert loader.load() == [expected]
    client.delete(prefix + doc_id)


def test_doc_loader_multiple_docs():
    """Round-trip several Documents and compare them order-independently."""
    client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance"))

    prefix = "multidocs:"
    # Remove leftovers under this prefix so the loaded set matches exactly.
    for stale_key in client.keys(f"{prefix}*"):
        client.delete(stale_key)

    content_field = "page_content"
    saver = MemorystoreDocumentSaver(
        client=client,
        key_prefix=prefix,
        content_field=content_field,
    )
    docs = [
        Document.construct(
            page_content=f"{i}",
            metadata={"metadata": f"meta: {i}"},
        )
        for i in range(10)
    ]

    saver.add_documents(docs)

    loader = MemorystoreDocumentLoader(
        client=client,
        key_prefix=prefix,
        content_fields={content_field},
    )
    loaded = list(loader.lazy_load())
    # Keys are random UUIDs, so scan order is arbitrary; sort before comparing.
    assert sorted(loaded, key=lambda d: d.page_content) == docs


def get_env_var(key: str, desc: str) -> str:
    """Return the value of env var *key*, raising with *desc* if unset."""
    value = os.environ.get(key)
    if value is None:
        raise ValueError(f"Must set env var {key} to: {desc}")
    return value
import json
import os

import pytest
import redis
from langchain_core.documents.base import Document

from langchain_google_memorystore_redis.doc_saver import MemorystoreDocumentSaver


@pytest.mark.parametrize(
    "page_content,metadata,content_field,metadata_fields",
    [
        (
            '"content1"',
            {"key1": "doc1_value1", "key2": "doc1_value2"},
            "page_content",
            None,
        ),
        (
            '"content2"',
            {"key1": {'"nested_key"': {'"double_nested"': '"doc2_value1"'}}},
            "special_page_content",
            None,
        ),
        (
            '"content3"',
            {"key1": {"k": "not_in_filter"}, "key2": {'"key"': "in_filter"}},
            "page_content",
            set(["key2"]),
        ),
    ],
)
def test_doc_saver_add_documents_one_doc(
    page_content, metadata, content_field, metadata_fields
):
    """Save one Document and verify the raw hash stored in Redis."""
    client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance"))
    prefix = "prefix:"

    saver = MemorystoreDocumentSaver(
        client=client,
        key_prefix=prefix,
        content_field=content_field,
        metadata_fields=metadata_fields,
    )

    doc = Document.construct(page_content=page_content, metadata=metadata)
    doc_id = "doc"
    saver.add_documents([doc], [doc_id])

    # Only verify the metadata keys given in the metadata_fields
    metadata_to_verify = {}
    for k, v in metadata.items():
        if not metadata_fields or k in metadata_fields:
            metadata_to_verify[k] = v

    verify_stored_values(
        client,
        prefix + doc_id,
        page_content,
        content_field,
        metadata_to_verify,
    )

    client.delete(prefix + doc_id)


def verify_stored_values(
    client: redis.Redis,
    key: str,
    page_content: str,
    content_field: str,
    metadata_to_verify: dict,
):
    """Assert that the hash at *key* holds exactly the expected fields.

    Checks the content field verbatim and every metadata field against
    metadata_to_verify, decoding stored JSON values back to Python objects
    before comparing.
    """
    stored_value = client.hgetall(key)
    assert isinstance(stored_value, dict)
    # One content field plus each expected metadata field, nothing more.
    assert len(stored_value) == 1 + len(metadata_to_verify)

    for k, v in stored_value.items():
        decoded_value = v.decode()
        if k == content_field.encode():
            assert page_content == decoded_value
        else:
            # BUG FIX: the original `assert a == b if cond else c` parsed as
            # `assert ((a == b) if cond else c)`, so on the non-JSON branch it
            # only asserted the truthiness of the raw string and never
            # compared it to the expected metadata value. Compute the
            # expected value first, then assert the comparison.
            if is_json_parsable(decoded_value):
                expected = json.loads(decoded_value)
            else:
                expected = decoded_value
            assert metadata_to_verify[k.decode()] == expected


def is_json_parsable(s: str) -> bool:
    """Return True when *s* is a valid JSON document."""
    try:
        json.loads(s)
        return True
    except ValueError:
        return False


def get_env_var(key: str, desc: str) -> str:
    """Return the value of env var *key*, raising with *desc* if unset."""
    v = os.environ.get(key)
    if v is None:
        raise ValueError(f"Must set env var {key} to: {desc}")
    return v