diff --git a/src/backend/base/langflow/base/vectorstores/model.py b/src/backend/base/langflow/base/vectorstores/model.py index e05be744e698..a00b56e99c4e 100644 --- a/src/backend/base/langflow/base/vectorstores/model.py +++ b/src/backend/base/langflow/base/vectorstores/model.py @@ -79,7 +79,7 @@ def build_base_retriever(self) -> Retriever: # type: ignore[type-var] """ vector_store = self.build_vector_store() if hasattr(vector_store, "as_retriever"): - retriever = vector_store.as_retriever() + retriever = vector_store.as_retriever(**self.get_retriever_kwargs()) if self.status is None: self.status = "Retriever built successfully." return retriever @@ -106,3 +106,9 @@ def search_documents(self) -> List[Data]: ) self.status = search_results return search_results + + def get_retriever_kwargs(self): + """ + Get the retriever kwargs. Implementations can override this method to provide custom retriever kwargs. + """ + return {} diff --git a/src/backend/base/langflow/components/vectorstores/AstraDB.py b/src/backend/base/langflow/components/vectorstores/AstraDB.py index d633953b13c4..1b55ebe24220 100644 --- a/src/backend/base/langflow/components/vectorstores/AstraDB.py +++ b/src/backend/base/langflow/components/vectorstores/AstraDB.py @@ -1,6 +1,9 @@ from loguru import logger +from langchain_core.vectorstores import VectorStore from langflow.base.vectorstores.model import LCVectorStoreComponent +from langflow.helpers import docs_to_data +from langflow.inputs import FloatInput, DictInput from langflow.io import ( BoolInput, DataInput, @@ -20,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb" icon: str = "AstraDB" + _cached_vectorstore: VectorStore = None + inputs = [ StrInput( name="collection_name", @@ -124,23 +129,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent): info="Optional dictionary defining the indexing policy for the collection.", advanced=True, ), + IntInput( + name="number_of_results", + display_name="Number of Results", + info="Number of results to return.", + advanced=True, + value=4, + ), DropdownInput( name="search_type", display_name="Search Type", - options=["Similarity", "MMR"], + info="Search type to use", + options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"], value="Similarity", advanced=True, ), - IntInput( - name="number_of_results", - display_name="Number of Results", - info="Number of results to return.", + FloatInput( + name="search_score_threshold", + display_name="Search Score Threshold", + info="Minimum similarity score threshold for search results. 
(when using 'Similarity with score threshold')", + value=0, advanced=True, - value=4, + ), + DictInput( + name="search_filter", + display_name="Search Metadata Filter", + info="Optional dictionary of filters to apply to the search query.", + advanced=True, + is_list=True, ), ] def _build_vector_store_no_ingest(self): + if self._cached_vectorstore: + return self._cached_vectorstore try: from langchain_astradb import AstraDBVectorStore from langchain_astradb.utils.astradb import SetupMode @@ -199,13 +221,13 @@ def _build_vector_store_no_ingest(self): except Exception as e: raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e + self._cached_vectorstore = vector_store + return vector_store def build_vector_store(self): vector_store = self._build_vector_store_no_ingest() - if hasattr(self, "ingest_data") and self.ingest_data: - logger.debug("Ingesting data into the Vector Store.") - self._add_documents_to_vector_store(vector_store) + self._add_documents_to_vector_store(vector_store) return vector_store def _add_documents_to_vector_store(self, vector_store): @@ -216,7 +238,7 @@ def _add_documents_to_vector_store(self, vector_store): else: raise ValueError("Vector Store Inputs must be Data objects.") - if documents and self.embedding is not None: + if documents: logger.debug(f"Adding {len(documents)} documents to the Vector Store.") try: vector_store.add_documents(documents) @@ -225,8 +247,17 @@ def _add_documents_to_vector_store(self, vector_store): else: logger.debug("No documents to add to the Vector Store.") + def _map_search_type(self): + if self.search_type == "Similarity with score threshold": + return "similarity_score_threshold" + elif self.search_type == "MMR (Max Marginal Relevance)": + return "mmr" + else: + return "similarity" + def search_documents(self) -> list[Data]: vector_store = self._build_vector_store_no_ingest() + self._add_documents_to_vector_store(vector_store) logger.debug(f"Search input: {self.search_input}") logger.debug(f"Search type: {self.search_type}") @@ -234,27 +265,38 @@ def search_documents(self) -> list[Data]: if self.search_input and isinstance(self.search_input, str) and self.search_input.strip(): try: - if self.search_type == "Similarity": - docs = vector_store.similarity_search( - query=self.search_input, - k=self.number_of_results, - ) - elif self.search_type == "MMR": - docs = vector_store.max_marginal_relevance_search( - query=self.search_input, - k=self.number_of_results, - ) - else: - raise ValueError(f"Invalid search type: {self.search_type}") + search_type = self._map_search_type() + search_args = self._build_search_args() + + docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args) except Exception as e: raise ValueError(f"Error performing search in AstraDBVectorStore: {str(e)}") from e logger.debug(f"Retrieved documents: {len(docs)}") - data = [Data.from_document(doc) for doc in docs] + data = docs_to_data(docs) logger.debug(f"Converted documents to data: {len(data)}") self.status = data return data else: logger.debug("No search input provided. 
Skipping search.") return [] + + def _build_search_args(self): + args = { + "k": self.number_of_results, + "score_threshold": self.search_score_threshold, + } + + if self.search_filter: + clean_filter = {k: v for k, v in self.search_filter.items() if k and v} + if len(clean_filter) > 0: + args["filter"] = clean_filter + return args + + def get_retriever_kwargs(self): + search_args = self._build_search_args() + return { + "search_type": self._map_search_type(), + "search_kwargs": search_args, + } diff --git a/src/backend/base/langflow/components/vectorstores/Cassandra.py b/src/backend/base/langflow/components/vectorstores/Cassandra.py index 90800e70b15b..81d56765cf44 100644 --- a/src/backend/base/langflow/components/vectorstores/Cassandra.py +++ b/src/backend/base/langflow/components/vectorstores/Cassandra.py @@ -1,10 +1,10 @@ -from typing import List, Optional +from typing import List from langchain_community.vectorstores import Cassandra from langflow.base.vectorstores.model import LCVectorStoreComponent from langflow.helpers.data import docs_to_data -from langflow.inputs import DictInput +from langflow.inputs import DictInput, FloatInput, BoolInput from langflow.io import ( DataInput, DropdownInput, @@ -15,6 +15,7 @@ SecretStrInput, ) from langflow.schema import Data +from loguru import logger class CassandraVectorStoreComponent(LCVectorStoreComponent): @@ -23,6 +24,8 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent): documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra" icon = "Cassandra" + _cached_vectorstore: Cassandra = None + inputs = [ MessageTextInput( name="database_ref", @@ -64,12 +67,6 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent): value=16, advanced=True, ), - MessageTextInput( - name="body_index_options", - display_name="Body Index Options", - info="Optional options used to create the body index.", - advanced=True, - ), DropdownInput( name="setup_mode", display_name="Setup Mode", @@ -99,14 +96,52 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent): value=4, advanced=True, ), + DropdownInput( + name="search_type", + display_name="Search Type", + info="Search type to use", + options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"], + value="Similarity", + advanced=True, + ), + FloatInput( + name="search_score_threshold", + display_name="Search Score Threshold", + info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')", + value=0, + advanced=True, + ), + DictInput( + name="search_filter", + display_name="Search Metadata Filter", + info="Optional dictionary of filters to apply to the search query.", + advanced=True, + is_list=True, + ), + MessageTextInput( + name="body_search", + display_name="Search Body", + info="Document textual search terms to apply to the search query.", + advanced=True, + ), + BoolInput( + name="enable_body_search", + display_name="Enable Body Search", + info="Flag to enable body search. 
This must be enabled BEFORE the table is created.", + value=False, + advanced=True, + ), ] def build_vector_store(self) -> Cassandra: - return self._build_cassandra(ingest=True) + return self._build_cassandra() - def _build_cassandra(self, ingest: bool) -> Cassandra: + def _build_cassandra(self) -> Cassandra: + if self._cached_vectorstore: + return self._cached_vectorstore try: import cassio + from langchain_community.utilities.cassandra import SetupMode except ImportError: raise ImportError( "Could not import cassio integration package. " "Please install it with `pip install cassio`." @@ -138,49 +173,73 @@ def _build_cassandra(self, ingest: bool) -> Cassandra: password=self.token, cluster_kwargs=self.cluster_kwargs, ) - ttl_seconds: Optional[int] = self.ttl_seconds - documents = [] - if ingest: - for _input in self.ingest_data or []: - if isinstance(_input, Data): - documents.append(_input.to_lc_document()) - else: - documents.append(_input) + for _input in self.ingest_data or []: + if isinstance(_input, Data): + documents.append(_input.to_lc_document()) + else: + documents.append(_input) + + if self.enable_body_search: + body_index_options = [("index_analyzer", "STANDARD")] + else: + body_index_options = None + + if self.setup_mode == "Off": + setup_mode = SetupMode.OFF + elif self.setup_mode == "Sync": + setup_mode = SetupMode.SYNC + else: + setup_mode = SetupMode.ASYNC if documents: + logger.debug(f"Adding {len(documents)} documents to the Vector Store.") table = Cassandra.from_documents( documents=documents, embedding=self.embedding, table_name=self.table_name, keyspace=self.keyspace, - ttl_seconds=ttl_seconds, + ttl_seconds=self.ttl_seconds or None, batch_size=self.batch_size, - body_index_options=self.body_index_options, + body_index_options=body_index_options, ) - else: + logger.debug("No documents to add to the Vector Store.") table = Cassandra( embedding=self.embedding, table_name=self.table_name, keyspace=self.keyspace, - ttl_seconds=ttl_seconds, - body_index_options=self.body_index_options, - setup_mode=self.setup_mode, + ttl_seconds=self.ttl_seconds or None, + body_index_options=body_index_options, + setup_mode=setup_mode, ) - + self._cached_vectorstore = table return table + def _map_search_type(self): + if self.search_type == "Similarity with score threshold": + return "similarity_score_threshold" + elif self.search_type == "MMR (Max Marginal Relevance)": + return "mmr" + else: + return "similarity" + def search_documents(self) -> List[Data]: - vector_store = self._build_cassandra(ingest=False) + vector_store = self._build_cassandra() + + logger.debug(f"Search input: {self.search_query}") + logger.debug(f"Search type: {self.search_type}") + logger.debug(f"Number of results: {self.number_of_results}") if self.search_query and isinstance(self.search_query, str) and self.search_query.strip(): try: - docs = vector_store.similarity_search( - query=self.search_query, - k=self.number_of_results, - ) + search_type = self._map_search_type() + search_args = self._build_search_args() + + logger.debug(f"Search args: {str(search_args)}") + + docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args) except KeyError as e: if "content" in str(e): raise ValueError( @@ -189,8 +248,33 @@ def search_documents(self) -> List[Data]: else: raise e + logger.debug(f"Retrieved documents: {len(docs)}") + data = docs_to_data(docs) self.status = data return data else: return [] + + def _build_search_args(self): + args = { + "k": self.number_of_results, + 
"score_threshold": self.search_score_threshold, + } + + if self.search_filter: + clean_filter = {k: v for k, v in self.search_filter.items() if k and v} + if len(clean_filter) > 0: + args["filter"] = clean_filter + if self.body_search: + if not self.enable_body_search: + raise ValueError("You should enable body search when creating the table to search the body field.") + args["body_search"] = self.body_search + return args + + def get_retriever_kwargs(self): + search_args = self._build_search_args() + return { + "search_type": self._map_search_type(), + "search_kwargs": search_args, + }