Skip to content

Commit

Permalink
feat(cassandra/astradb): hybrid search support (langflow-ai#2396)
Browse files Browse the repository at this point in the history
* cassandra/astradb: hybrid search support

* fix

* fix

(cherry picked from commit 30c369f)
  • Loading branch information
nicoloboschi committed Jul 2, 2024
1 parent 9aa8799 commit 4da8be3
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 54 deletions.
8 changes: 7 additions & 1 deletion src/backend/base/langflow/base/vectorstores/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
"""
vector_store = self.build_vector_store()
if hasattr(vector_store, "as_retriever"):
retriever = vector_store.as_retriever()
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
if self.status is None:
self.status = "Retriever built successfully."
return retriever
Expand All @@ -106,3 +106,9 @@ def search_documents(self) -> List[Data]:
)
self.status = search_results
return search_results

def get_retriever_kwargs(self):
    """Return the keyword arguments unpacked into ``as_retriever``.

    The base implementation supplies no arguments. Subclasses can override
    this hook to provide custom retriever kwargs.
    """
    retriever_kwargs = {}
    return retriever_kwargs
88 changes: 65 additions & 23 deletions src/backend/base/langflow/components/vectorstores/AstraDB.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from loguru import logger

from langchain_core.vectorstores import VectorStore
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers import docs_to_data
from langflow.inputs import FloatInput, DictInput
from langflow.io import (
BoolInput,
DataInput,
Expand All @@ -20,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
icon: str = "AstraDB"

_cached_vectorstore: VectorStore = None

inputs = [
StrInput(
name="collection_name",
Expand Down Expand Up @@ -124,23 +129,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
info="Optional dictionary defining the indexing policy for the collection.",
advanced=True,
),
IntInput(
name="number_of_results",
display_name="Number of Results",
info="Number of results to return.",
advanced=True,
value=4,
),
DropdownInput(
name="search_type",
display_name="Search Type",
options=["Similarity", "MMR"],
info="Search type to use",
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
value="Similarity",
advanced=True,
),
IntInput(
name="number_of_results",
display_name="Number of Results",
info="Number of results to return.",
FloatInput(
name="search_score_threshold",
display_name="Search Score Threshold",
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
value=0,
advanced=True,
value=4,
),
DictInput(
name="search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True,
),
]

def _build_vector_store_no_ingest(self):
if self._cached_vectorstore:
return self._cached_vectorstore
try:
from langchain_astradb import AstraDBVectorStore
from langchain_astradb.utils.astradb import SetupMode
Expand Down Expand Up @@ -199,13 +221,13 @@ def _build_vector_store_no_ingest(self):
except Exception as e:
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e

self._cached_vectorstore = vector_store

return vector_store

def build_vector_store(self):
vector_store = self._build_vector_store_no_ingest()
if hasattr(self, "ingest_data") and self.ingest_data:
logger.debug("Ingesting data into the Vector Store.")
self._add_documents_to_vector_store(vector_store)
self._add_documents_to_vector_store(vector_store)
return vector_store

def _add_documents_to_vector_store(self, vector_store):
Expand All @@ -216,7 +238,7 @@ def _add_documents_to_vector_store(self, vector_store):
else:
raise ValueError("Vector Store Inputs must be Data objects.")

if documents and self.embedding is not None:
if documents:
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
try:
vector_store.add_documents(documents)
Expand All @@ -225,36 +247,56 @@ def _add_documents_to_vector_store(self, vector_store):
else:
logger.debug("No documents to add to the Vector Store.")

def _map_search_type(self):
    """Translate the selected search-type label into a search_type key.

    Returns "mmr" or "similarity_score_threshold" for the two matching
    labels; every other value of ``self.search_type`` yields "similarity".
    """
    selected = self.search_type
    if selected == "MMR (Max Marginal Relevance)":
        return "mmr"
    if selected == "Similarity with score threshold":
        return "similarity_score_threshold"
    # Anything else, including the plain "Similarity" label, uses the default.
    return "similarity"

def search_documents(self) -> list[Data]:
vector_store = self._build_vector_store_no_ingest()
self._add_documents_to_vector_store(vector_store)

logger.debug(f"Search input: {self.search_input}")
logger.debug(f"Search type: {self.search_type}")
logger.debug(f"Number of results: {self.number_of_results}")

if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
try:
if self.search_type == "Similarity":
docs = vector_store.similarity_search(
query=self.search_input,
k=self.number_of_results,
)
elif self.search_type == "MMR":
docs = vector_store.max_marginal_relevance_search(
query=self.search_input,
k=self.number_of_results,
)
else:
raise ValueError(f"Invalid search type: {self.search_type}")
search_type = self._map_search_type()
search_args = self._build_search_args()

docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
except Exception as e:
raise ValueError(f"Error performing search in AstraDBVectorStore: {str(e)}") from e

logger.debug(f"Retrieved documents: {len(docs)}")

data = [Data.from_document(doc) for doc in docs]
data = docs_to_data(docs)
logger.debug(f"Converted documents to data: {len(data)}")
self.status = data
return data
else:
logger.debug("No search input provided. Skipping search.")
return []

def _build_search_args(self):
    """Assemble the keyword arguments used for a vector store search.

    Returns:
        A dict with "k" (``self.number_of_results``) and "score_threshold"
        (``self.search_score_threshold``). A "filter" entry holding the
        cleaned ``self.search_filter`` is added only when at least one
        usable key/value pair remains.
    """
    args = {
        "k": self.number_of_results,
        "score_threshold": self.search_score_threshold,
    }

    raw_filter = self.search_filter or {}
    # Discard blank rows (empty key, or a None / empty value), but keep
    # numeric and boolean values: 0 and False are legitimate metadata
    # values and must not be silently dropped from the filter.
    clean_filter = {
        key: value
        for key, value in raw_filter.items()
        if key and (value or isinstance(value, (int, float)))
    }
    if clean_filter:
        args["filter"] = clean_filter
    return args

def get_retriever_kwargs(self):
    """Build the retriever options from this component's search settings.

    Returns:
        A dict with "search_type" (result of ``_map_search_type``) and
        "search_kwargs" (result of ``_build_search_args``).
    """
    search_kwargs = self._build_search_args()
    search_type = self._map_search_type()
    return {"search_type": search_type, "search_kwargs": search_kwargs}
Loading

0 comments on commit 4da8be3

Please sign in to comment.