diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index 20b3173..564e338 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -174,7 +174,7 @@ "outputs": [], "source": [ "rvs = RedisVectorStore(\n", - " client=redis_client, index_name=\"my_vector_index\", embedding_service=embeddings\n", + " client=redis_client, index_name=\"my_vector_index\", embeddings=embeddings\n", ")\n", "ids = rvs.add_texts(\n", " texts=[d.page_content for d in docs], metadatas=[d.metadata for d in docs]\n", diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index 94cb66f..84b472d 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -13,12 +13,19 @@ # limitations under the License. from .chat_message_history import MemorystoreChatMessageHistory from .doc_saver import MemorystoreDocumentSaver -from .vector_store import DistanceStrategy, FLATConfig, HNSWConfig, RedisVectorStore +from .vector_store import ( + DistanceStrategy, + FLATConfig, + HNSWConfig, + RedisVectorStore, + VectorIndexConfig, +) __all__ = [ "MemorystoreChatMessageHistory", "MemorystoreDocumentSaver", "DistanceStrategy", + "VectorIndexConfig", "FLATConfig", "HNSWConfig", "RedisVectorStore", diff --git a/src/langchain_google_memorystore_redis/chat_message_history.py b/src/langchain_google_memorystore_redis/chat_message_history.py index 90c32fe..bad05a3 100644 --- a/src/langchain_google_memorystore_redis/chat_message_history.py +++ b/src/langchain_google_memorystore_redis/chat_message_history.py @@ -49,7 +49,9 @@ def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages chronologically stored in this session.""" all_elements = self._redis.lrange(self._key, 0, -1) - assert isinstance(all_elements, list) + if not isinstance(all_elements, list): + raise TypeError("Expected a list from `lrange` but got a different type.") + loaded_messages = messages_from_dict( [json.loads(e.decode(self._encoding)) for e in all_elements] ) diff --git a/src/langchain_google_memorystore_redis/vector_store.py b/src/langchain_google_memorystore_redis/vector_store.py index 0fc9ac9..e40cac0 100644 --- a/src/langchain_google_memorystore_redis/vector_store.py +++ b/src/langchain_google_memorystore_redis/vector_store.py @@ -15,11 +15,9 @@ import json import logging import operator -import pprint import re import uuid from abc import ABC -from enum import Enum, auto from itertools import zip_longest from typing import Any, Iterable, List, Optional, Tuple @@ -29,12 +27,9 @@ DistanceStrategy, maximal_marginal_relevance, ) -from langchain_core._api import deprecated -from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.utils import get_from_dict_or_env -from langchain_core.vectorstores import VectorStore, VectorStoreRetriever +from langchain_core.vectorstores import VectorStore # Setting up a basic logger logger = logging.getLogger(__name__) @@ -46,7 +41,7 @@ DEFAULT_CONTENT_FIELD = "page_content" DEFAULT_VECTOR_FIELD = "vector" -DEFAULT_DATA_TYPE = "float32" +DEFAULT_DATA_TYPE = "FLOAT32" DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE @@ -131,6 +126,12 @@ def __init__( f"Supported strategies are: {supported_strategies}." ) + if data_type.upper() != DEFAULT_DATA_TYPE: + raise ValueError(f"Unsupported data type: {data_type}") + + if vector_size < 0: + raise ValueError(f"Unsupported vector size: {vector_size}") + super().__init__(name, field_name, type) self.distance_strategy = distance_strategy self.vector_size = vector_size @@ -148,6 +149,12 @@ def distance_metric(self): class HNSWConfig(VectorIndexConfig): + DEFAULT_VECTOR_SIZE = 128 + DEFAULT_INITIAL_CAP = 10000 + DEFAULT_M = 16 + DEFAULT_EF_CONSTRUCTION = 200 + DEFAULT_EF_RUNTIME = 10 + """ Configuration class for HNSW (Hierarchical Navigable Small World) vector indexes. """ @@ -155,13 +162,13 @@ class HNSWConfig(VectorIndexConfig): def __init__( self, name: str, - field_name=None, - vector_size: int = 128, + field_name: Optional[str] = None, + vector_size: int = DEFAULT_VECTOR_SIZE, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - initial_cap: int = 10000, - m: int = 16, - ef_construction: int = 200, - ef_runtime: int = 10, + initial_cap: int = DEFAULT_INITIAL_CAP, + m: int = DEFAULT_M, + ef_construction: int = DEFAULT_EF_CONSTRUCTION, + ef_runtime: int = DEFAULT_EF_RUNTIME, ): """ Initializes the HNSWConfig object. @@ -179,21 +186,19 @@ def __init__( are ranked. initial_cap (int): Specifies the initial capacity of the index in terms of the number of vectors it can hold, impacting the initial memory allocation. - Defaults to 10000. m (int): Determines the maximum number of outgoing edges each node in the index graph can have, directly affecting the graph's connectivity and - search performance. Defaults to 16. + search performance. ef_construction (int): Controls the size of the dynamic candidate list during the construction of the index, influencing the index build time and quality. - Defaults to 200. ef_runtime (int): Sets the size of the dynamic candidate list during search - queries, balancing between search speed and accuracy. Defaults to 10. + queries, balancing between search speed and accuracy. """ if field_name is None: field_name = DEFAULT_VECTOR_FIELD super().__init__( - name, field_name, "HNSW", distance_strategy, vector_size, "FLOAT32" + name, field_name, "HNSW", distance_strategy, vector_size, DEFAULT_DATA_TYPE ) self.initial_cap = initial_cap self.m = m @@ -206,11 +211,13 @@ class FLATConfig(VectorIndexConfig): Configuration class for FLAT vector indexes, utilizing brute-force search. """ + DEFAULT_VECTOR_SIZE = 128 + def __init__( self, name: str, - field_name=None, - vector_size: int = 128, + field_name: Optional[str] = None, + vector_size: int = DEFAULT_VECTOR_SIZE, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ): """ @@ -231,7 +238,7 @@ def __init__( if field_name is None: field_name = DEFAULT_VECTOR_FIELD super().__init__( - name, field_name, "FLAT", distance_strategy, vector_size, "FLOAT32" + name, field_name, "FLAT", distance_strategy, vector_size, DEFAULT_DATA_TYPE ) @@ -240,7 +247,7 @@ def __init__( self, client: redis.Redis, index_name: str, - embedding_service: Embeddings, + embeddings: Embeddings, content_field: str = DEFAULT_CONTENT_FIELD, vector_field: str = DEFAULT_VECTOR_FIELD, ): @@ -254,7 +261,7 @@ def __init__( index_name (str): The name assigned to the vector index within Redis. This name is used to identify the index for operations such as searching and indexing. - embedding_service (Embeddings): An instance of an embedding service or model + embeddings (Embeddings): An instance of an embedding service or model capable of generating vector embeddings from document content. This service is utilized to convert text documents into vector representations for storage and search. @@ -267,24 +274,24 @@ def __init__( when adding new documents to the store and when retrieving or searching documents based on their vector embeddings. Defaults to 'vector'. """ - if client == None: + if not isinstance(client, redis.Redis): raise ValueError( "A Redis 'client' must be provided to initialize RedisVectorStore" ) - if index_name == None: + if not isinstance(index_name, str): raise ValueError( "A 'index_name' must be provided to initialize RedisVectorStore" ) - if embedding_service == None: + if not isinstance(embeddings, Embeddings): raise ValueError( - "An 'embedding_service' must be provided to initialize RedisVectorStore" + "An 'embeddings' must be provided to initialize RedisVectorStore" ) self._client = client self.index_name = index_name - self.embedding_service = embedding_service + self.embeddings_service = embeddings self.key_prefix = self.get_key_prefix(index_name) self.content_field = content_field self.vector_field = vector_field @@ -362,7 +369,8 @@ def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, - batch_size: int = 1000, + ids: Optional[List[str]] = None, + batch_size: Optional[int] = 1000, **kwargs: Any, ) -> List[str]: """ @@ -375,49 +383,57 @@ def add_texts( where each dictionary corresponds to a text document in the same order as the `texts` iterable. Each metadata dictionary should contain key-value pairs representing the metadata attributes for the associated text document. + ids (Optional[List[str]], optional): An optional list of unique identifiers for each text document. + If not provided, the system will generate unique identifiers for each document. If provided, + the length of this list should match the length of `texts`. batch_size (int, optional): The number of documents to process in a single batch operation. This parameter helps manage memory and performance when adding a large number of documents. Defaults to 1000. - **kwargs (Any): Additional keyword arguments for extended functionality. This includes: - - 'keys' or 'ids' (List[str], optional): Custom identifiers for each document. If provided, - the length of this list should match the length of `texts`. If not provided, the system - will generate unique identifiers. Returns: List[str]: A list containing the unique keys or identifiers for each added document. These keys can be used to retrieve or reference the documents within the vector store. Note: - If both 'keys' (or 'ids') and 'metadatas' are provided, they must be of the same length as the - `texts` iterable to ensure each document is correctly associated with its metadata and identifier. + If both 'ids' and 'metadatas' are provided, they must be of the same length as the `texts` + iterable to ensure each document is correctly associated with its metadata and identifier. """ - # Generate or extend keys/IDs for the documents - keys_or_ids = kwargs.get("keys", kwargs.get("ids", [])) - if keys_or_ids and len(keys_or_ids) != len(list(texts)): - raise ValueError( - "The length of keys or ids must match the length of the texts" - ) - if not keys_or_ids: - keys_or_ids = [str(uuid.uuid4()) for _ in texts] - # Ensure there's a unique ID for each text document - # Fallback for empty metadata - metadatas = metadatas if metadatas is not None else [{} for _ in texts] + + # Generate ids if not provided + if ids is None: + ids = [str(uuid.uuid4()) for _ in texts] + + # Check if both ids and metadatas are provided and have the same length + if ids is not None: + if len(ids) != len(list(texts)): + raise ValueError("The length of 'ids' and 'texts' must be the same.") + + if not metadatas: + metadatas = [{} for _ in texts] + + if metadatas is not None: + if len(metadatas) != len(list(texts)): + raise ValueError( + "The length of 'metadatas' and 'texts' must be the same." + ) + + if not batch_size or batch_size <= 0: + raise ValueError("batch_size must be greater than 0.") + # Generate embeddings for all documents - embeddings = self.embedding_service.embed_documents(list(texts)) + embeddings = self.embeddings_service.embed_documents(list(texts)) - ids = [] + new_ids = [] pipeline = self._client.pipeline(transaction=False) - for i, bundle in enumerate( - zip_longest(keys_or_ids, texts, embeddings, metadatas), start=1 - ): - key, text, embedding, metadata = bundle - key = self.key_prefix + key + for i, bundle in enumerate(zip(ids, texts, embeddings, metadatas), start=1): + id, text, embedding, metadata = bundle + new_id = self.key_prefix + id # Initialize the mapping with content and vector fields mapping = { self.content_field: text, self.vector_field: np.array(embedding) - .astype(DEFAULT_DATA_TYPE) + .astype(DEFAULT_DATA_TYPE.lower()) .tobytes(), } @@ -431,8 +447,8 @@ def add_texts( mapping[meta_key] = str(meta_value) # Add the document to the Redis hash - pipeline.hset(key, mapping=mapping) - ids.append(key) + pipeline.hset(new_id, mapping=mapping) + new_ids.append(new_id) # Ensure to execute any remaining commands in the pipeline after the loop if i % batch_size == 0: @@ -443,7 +459,7 @@ def add_texts( logger.info(f"{len(ids)} documents ingested into Redis.") - return ids + return new_ids @classmethod def from_texts( @@ -451,8 +467,9 @@ def from_texts( texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, - client=None, - index_name=None, + ids: Optional[List[str]] = None, + client: Optional[redis.Redis] = None, + index_name: Optional[str] = None, **kwargs: Any, ) -> "RedisVectorStore": """ @@ -465,26 +482,29 @@ def from_texts( metadatas (Optional[List[dict]]): A list of dictionaries where each dictionary contains metadata corresponding to each text document in `texts`. If provided, the length of `metadatas` must match the length of `texts`. - **kwargs (Any): Additional keyword arguments that can include: - - 'client': A Redis client instance to be used by the RedisVectorStore. - - 'index_name': The name of the index to be created or used in Redis. - If not provided, a default name may be used. + ids (Optional[List[str]], optional): An optional list of unique identifiers for + each text document. If not provided, the system will generate unique identifiers + for each document. If provided, the length of this list should match the length + of `texts`. + client (redis.Redis): The Redis client instance to be used for database + operations, providing connectivity and command execution against the + Redis instance. + index_name (str): The name assigned to the vector index within Redis. This + name is used to identify the index for operations such as searching and + indexing. + **kwargs (Any): Additional keyword arguments Returns: RedisVectorStore: An instance of RedisVectorStore that has been populated with the embeddings of the provided texts, along with their associated metadata. - - Raises: - ValueError: If a Redis client instance is not provided in `kwargs`, indicating - that the method cannot proceed without a connection to a Redis database. """ - if "client" == None: + if not isinstance(client, redis.Redis): raise ValueError( "A 'client' must be provided to initialize RedisVectorStore" ) - if "index_name" == None: + if not isinstance(index_name, str): raise ValueError( "A 'index_name' must be provided to initialize RedisVectorStore" ) @@ -498,7 +518,7 @@ def from_texts( ) # Add texts and their corresponding metadata to the instance - instance.add_texts(texts, metadatas) + instance.add_texts(texts, metadatas, ids) return instance @@ -558,7 +578,7 @@ def _similarity_search_by_vector_with_score_and_embeddings( "PARAMS", 2, "query_vector", - np.array([query_embedding]).astype(DEFAULT_DATA_TYPE).tobytes(), + np.array([query_embedding]).astype(DEFAULT_DATA_TYPE.lower()).tobytes(), "DIALECT", 2, ] @@ -584,7 +604,9 @@ def _similarity_search_by_vector_with_score_and_embeddings( if key == self.content_field: page_content = value.decode(self.encoding) elif key == self.vector_field: - embedding = np.frombuffer(value, dtype=DEFAULT_DATA_TYPE).tolist() + embedding = np.frombuffer( + value, dtype=DEFAULT_DATA_TYPE.lower() + ).tolist() elif key == "distance": distance = float(value.decode(self.encoding)) else: @@ -652,7 +674,7 @@ def similarity_search_with_score( documents most relevant to the query according to the similarity scores. """ # Embed the query using the embedding function - query_embedding = self.embedding_service.embed_query(query) + query_embedding = self.embeddings_service.embed_query(query) return self._similarity_search_by_vector_with_score( query_embedding, k, **kwargs ) @@ -710,7 +732,7 @@ def similarity_search( the search backend. """ # Embed the query using the embedding function - query_embedding = self.embedding_service.embed_query(query) + query_embedding = self.embeddings_service.embed_query(query) return self.similarity_search_by_vector(query_embedding, k, **kwargs) def max_marginal_relevance_search( @@ -748,7 +770,7 @@ def max_marginal_relevance_search( raise ValueError("lambda_mult must be between 0 and 1.") # Embed the query using a hypothetical method to convert text to vector. - query_embedding = self.embedding_service.embed_query(query) + query_embedding = self.embeddings_service.embed_query(query) # Fetch initial documents based on query embedding. initial_results = self._similarity_search_by_vector_with_score_and_embeddings( diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py new file mode 100644 index 0000000..3a84272 --- /dev/null +++ b/tests/test_vector_store.py @@ -0,0 +1,253 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import uuid + +import numpy +import pytest +import redis +from langchain_community.embeddings.fake import FakeEmbeddings +from langchain_core.documents.base import Document + +from langchain_google_memorystore_redis import ( + DistanceStrategy, + HNSWConfig, + RedisVectorStore, + VectorIndexConfig, +) + + +def test_vector_store_init_index(): + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) + index_name = str(uuid.uuid4()) + + index_config = HNSWConfig( + name=index_name, distance_strategy=DistanceStrategy.COSINE, vector_size=128 + ) + + assert not check_index_exists(client, index_name, index_config) + RedisVectorStore.init_index(client=client, index_config=index_config) + assert check_index_exists(client, index_name, index_config) + RedisVectorStore.drop_index(client=client, index_name=index_name) + assert not check_index_exists(client, index_name, index_config) + client.flushall() + + +@pytest.mark.parametrize( + "texts,metadatas,ids", + [ + # Test case 1: Basic scenario with texts only + (["text1", "text2"], None, None), + # Test case 2: Texts with metadatas + (["text1", "text2"], [{"meta1": "data1"}, {"meta2": "data2"}], None), + # Test case 3: Texts with metadatas and ids + (["text1", "text2"], [{"meta1": "data1"}, {"meta2": "data2"}], ["id1", "id2"]), + # Test case 4: Texts with ids only + (["text1", "text2"], None, ["id1", "id2"]), + # Additional test cases can be added as needed + ], +) +def test_vector_store_add_texts(texts, metadatas, ids): + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) + + # Initialize the vector index + index_name = str(uuid.uuid4()) + index_config = HNSWConfig( + name=index_name, distance_strategy=DistanceStrategy.COSINE, vector_size=128 + ) + RedisVectorStore.init_index(client=client, index_config=index_config) + + # Insert the documents + rvs = RedisVectorStore( + client=client, index_name=index_name, embeddings=FakeEmbeddings(size=128) + ) + returned_ids = rvs.add_texts(texts=texts, metadatas=metadatas, ids=ids) + + original_metadatas = metadatas if metadatas is not None else [None] * len(texts) + original_ids = ids if ids is not None else [""] * len(texts) + + # Validate the results + for original_id, text, original_metadata, returned_id in zip( + original_ids, texts, original_metadatas, returned_ids + ): + expected_id = f"{index_name}{original_id}" + # Check if original_id is empty and adjust assertion accordingly + if original_id == "": + assert returned_id.startswith( + expected_id + ), f"Returned ID {returned_id} does not start with expected prefix {expected_id}" + else: + assert ( + returned_id == expected_id + ), f"Returned ID {returned_id} does not match expected {expected_id}" + + # Fetch the record from Redis + hash_record = client.hgetall(returned_id) + + # Validate page_content + fetched_page_content = hash_record[b"page_content"].decode("utf-8") + assert fetched_page_content == text, "Page content does not match" + + # Validate vector embedding + vector = numpy.frombuffer(hash_record[b"vector"], dtype=numpy.float32) + assert ( + len(vector) == 128 + ), f"Decoded 'vector' length is {len(vector)}, expected 128" + + # Iterate over each key-value pair in the hash_record + fetched_metadata = {} + for key, value in hash_record.items(): + # Decode the key from bytes to string + key_decoded = key.decode("utf-8") + + # Skip 'page_content' and 'vector' keys, include all others in fetched_metadata + if key_decoded not in ["page_content", "vector"]: + # Decode the value from bytes to string or JSON as needed + try: + # Attempt to load JSON content if applicable + value_decoded = json.loads(value.decode("utf-8")) + except json.JSONDecodeError: + # Fallback to simple string decoding if it's not JSON + value_decoded = value.decode("utf-8") + + # Add the decoded key-value pair to fetched_metadata + fetched_metadata[key_decoded] = value_decoded + + if original_metadata is None: + original_metadata = {} + + assert fetched_metadata == original_metadata, "Metadata does not match" + + # Verify no extra keys are present + all_keys = [key.decode("utf-8") for key in client.keys(f"{index_name}*")] + # Currently RedisQuery stores the index schema as a key using the index_name + assert len(all_keys) == len(returned_ids) + 1, "Found unexpected keys in Redis" + + # Clena up + RedisVectorStore.drop_index(client=client, index_name=index_name) + client.flushall() + + +def test_vector_store_knn_query(): + texts = [ + "The quick brown fox jumps over the lazy dog", + "A clever fox outwitted the guard dog to sneak into the farmyard at night", + "Exploring the mysteries of deep space and black holes", + "Delicious recipes for homemade pasta and pizza", + "Advanced techniques in machine learning and artificial intelligence", + "Sustainable living: Tips for reducing your carbon footprint", + ] + + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) + + # Initialize the vector index + index_name = str(uuid.uuid4()) + index_config = HNSWConfig( + name=index_name, distance_strategy=DistanceStrategy.COSINE, vector_size=128 + ) + RedisVectorStore.init_index(client=client, index_config=index_config) + + # Insert the documents + rvs = RedisVectorStore( + client=client, index_name=index_name, embeddings=FakeEmbeddings(size=128) + ) + rvs.add_texts(texts=texts) + + # Validate knn query + query_result = rvs.similarity_search(query="fox dog", k=2) + assert len(query_result) == 2, "Expected 2 documents to be returned" + + # Clean up + RedisVectorStore.drop_index(client=client, index_name=index_name) + client.flushall() + + +@pytest.mark.parametrize( + "distance_strategy, distance_threshold", + [ + (DistanceStrategy.COSINE, 0.8), + (DistanceStrategy.MAX_INNER_PRODUCT, 1.0), + (DistanceStrategy.EUCLIDEAN_DISTANCE, 2.0), + ], +) +def test_vector_store_range_query(distance_strategy, distance_threshold): + texts = [ + "The quick brown fox jumps over the lazy dog", + "A clever fox outwitted the guard dog to sneak into the farmyard at night", + "Exploring the mysteries of deep space and black holes", + "Delicious recipes for homemade pasta and pizza", + "Advanced techniques in machine learning and artificial intelligence", + "Sustainable living: Tips for reducing your carbon footprint", + ] + + client = redis.from_url(get_env_var("REDIS_URL", "URL of the Redis instance")) + + # Initialize the vector index + index_name = str(uuid.uuid4()) + index_config = HNSWConfig( + name=index_name, distance_strategy=distance_strategy, vector_size=128 + ) + RedisVectorStore.init_index(client=client, index_config=index_config) + + # Insert the documents + rvs = RedisVectorStore( + client=client, index_name=index_name, embeddings=FakeEmbeddings(size=128) + ) + rvs.add_texts(texts=texts) + + # Validate range query + query_result = rvs.similarity_search_with_score( + query="dog", + k=3, + distance_strategy=distance_strategy, + distance_threshold=distance_threshold, + ) + assert len(query_result) <= 3, "Expected less than 3 documents to be returned" + for _, score in query_result: + if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + assert ( + score > distance_threshold + ), f"Score {score} is not greater than {distance_threshold} for {distance_strategy}" + else: + assert ( + score < distance_threshold + ), f"Score {score} is not less than {distance_threshold} for {distance_strategy}" + + # Clean up + RedisVectorStore.drop_index(client=client, index_name=index_name) + client.flushall() + + +def check_index_exists( + client: redis.Redis, index_name: str, index_config: VectorIndexConfig +) -> bool: + try: + index_info = client.ft(index_name).info() + except: + return False + + return ( + index_info["index_name"] == index_name + and index_info["index_definition"][1] == b"HASH" + ) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v