diff --git a/libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py b/libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py index 919d4e9..42390de 100644 --- a/libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py +++ b/libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py @@ -12,6 +12,7 @@ import functools import hashlib import inspect +import json import logging import os import re @@ -27,7 +28,6 @@ Optional, Tuple, Type, - TypedDict, TypeVar, Union, cast, @@ -48,6 +48,8 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +from ..embeddings import OracleEmbeddings + logger = logging.getLogger(__name__) log_level = os.getenv("LOG_LEVEL", "ERROR").upper() logging.basicConfig( @@ -59,58 +61,294 @@ # define a type variable that can be any kind of function T = TypeVar("T", bound=Callable[..., Any]) +LOGICAL_MAP = { + "$and": (" AND ", "({0})"), + "$or": (" OR ", "({0})"), + "$nor": (" OR ", "( NOT ({0}) )"), +} + +COMPARISON_MAP = { + "$exists": "", + "$eq": "@ == {0}", + "$ne": "@ != {0}", + "$gt": "@ > {0}", + "$lt": "@ < {0}", + "$gte": "@ >= {0}", + "$lte": "@ <= {0}", + "$between": "", + "$startsWith": "@ starts with {0}", + "$hasSubstring": "@ has substring {0}", + "$instr": "@ has substring {0}", + "$regex": "@ like_regex {0}", + "$like": "@ like {0}", + "$in": "", + "$nin": "", + "$all": "", + "$not": "", +} + +# operations that may need negation +NOT_OPERS = ["$nin", "$not", "$exists"] + + +def _get_comparison_string( + oper: str, value: Any, bind_variables: List[str] +) -> tuple[str, str]: + if oper not in COMPARISON_MAP: + raise ValueError(f"Invalid operator: {oper}") + + # usual two sided operator case + if COMPARISON_MAP[oper] != "": + bind_l = len(bind_variables) + bind_variables.append(value) -class FilterCondition(TypedDict): - key: str - oper: str - value: str + return ( + COMPARISON_MAP[oper].format(f"$val{bind_l}"), + f':value{bind_l} as "val{bind_l}"', + ) + # between - needs two bindings + elif oper == "$between": + if not isinstance(value, List) or len(value) != 2: + raise ValueError( + f"Invalid value for $between: {value}. " + "It must be a list containing exactly 2 elements." 
+ ) -class FilterGroup(TypedDict, total=False): - _and: Optional[List[Union["FilterCondition", "FilterGroup"]]] - _or: Optional[List[Union["FilterCondition", "FilterGroup"]]] + min_val, max_val = value + if min_val is None and max_val is None: + raise ValueError("At least one bound in $between must be non-null.") + conditions = [] + passings = [] + if min_val is not None: + bind_l = len(bind_variables) + bind_variables.append(min_val) -def _convert_oper_to_sql(oper: str) -> str: - oper_map = {"EQ": "==", "GT": ">", "LT": "<", "GTE": ">=", "LTE": "<="} - if oper not in oper_map: - raise ValueError("Filter operation {} not supported".format(oper)) - return oper_map.get(oper, "==") + conditions.append(f"@ >= $val{bind_l}") + passings.append(f':value{bind_l} as "val{bind_l}"') + if max_val is not None: + bind_l = len(bind_variables) + bind_variables.append(max_val) -def _generate_condition(condition: FilterCondition) -> str: - key = condition["key"] - oper = _convert_oper_to_sql(condition["oper"]) - value = condition["value"] - if isinstance(value, str): - value = f'"{value}"' - return f"JSON_EXISTS(metadata, '$.{key}?(@ {oper} {value})')" + conditions.append(f"@ <= $val{bind_l}") + passings.append(f':value{bind_l} as "val{bind_l}"') + passing_bind = ",".join(passings) -def _generate_where_clause(db_filter: Union[FilterCondition, FilterGroup]) -> str: - if "key" in db_filter: # identify as FilterCondition - return _generate_condition(cast(FilterCondition, db_filter)) + return " && ".join(conditions), passing_bind - if "_and" in db_filter and db_filter["_and"] is not None: - and_conditions = [ - _generate_where_clause(cond) - for cond in db_filter["_and"] - if isinstance(cond, dict) - ] - return "(" + " AND ".join(and_conditions) + ")" + # in/nin/all needs N bindings + elif oper in ["$in", "$nin", "$all"]: + if not isinstance(value, List): + raise ValueError( + f"Invalid value for $in: {value}. It must be a non-empty list." + ) - if "_or" in db_filter and db_filter["_or"] is not None: - or_conditions = [ - _generate_where_clause(cond) - for cond in db_filter["_or"] - if isinstance(cond, dict) - ] - return "(" + " OR ".join(or_conditions) + ")" + value_binds = [] + passings = [] + for val in value: + bind_l = len(bind_variables) + bind_variables.append(val) + + value_binds.append(f"$val{bind_l}") + passings.append(f':value{bind_l} as "val{bind_l}"') + + passing_bind = ",".join(passings) + condition = "" + + if oper == "$all": + condition = "@ == " + " && @ == ".join(value_binds) + + else: + value_bind = ",".join(value_binds) + condition = f"@ in ({value_bind})" + + return condition, passing_bind + + else: + raise ValueError(f"Invalid operator: {oper}. ") + + +def _validate_metadata_key(metadata_key: str) -> None: + # Allow letters, digits, underscore, dot, brackets, comma, *, space (for 'to') + pattern = re.compile(r"[a-zA-Z0-9_\.\[\],\s\*]*") + + if not pattern.fullmatch(metadata_key): + raise ValueError( + f"Invalid metadata key '{metadata_key}'. " + "Only letters, numbers, underscores, nesting via '.', " + "and array wildcards '[*]' are allowed." 
+ ) + + +def _generate_condition( + metadata_key: str, value: Any, bind_variables: List[str] +) -> str: + # single check inside a JSON_EXISTS + SINGLE_MASK = ( + "JSON_EXISTS(metadata, '$.{key}?(@ {oper} $val)' " + 'PASSING {value_bind} AS "val")' + ) + # combined checks with multiple operators and passing values + MULTIPLE_MASK = "JSON_EXISTS(metadata, '$.{key}?({filters})' PASSING {passes})" - raise ValueError(f"Invalid filter structure: {db_filter}") + _validate_metadata_key(metadata_key) + if not isinstance(value, (dict, list, tuple)): + # scalar-equality Clause + bind = f":value{len(bind_variables)}" + bind_variables.append(value) -def _get_connection(client: Any) -> Connection | None: + return SINGLE_MASK.format(key=metadata_key, oper="==", value_bind=bind) + + elif isinstance(value, dict): + # all values are filters + result: str + passings: str + + # comparison operator keys + if all(value_key.startswith("$") for value_key in value.keys()): + not_dict = {} + + passing_values = [] + comparison_values = [] + + for k, v in value.items(): + # if need to negate, cannot combine in single JSON_EXISTS + if ( + k in NOT_OPERS + or (k == "$eq" and isinstance(v, (list, dict))) + or (k == "$ne" and isinstance(v, (list, dict))) + ): + not_dict[k] = v + continue + + result, passings = _get_comparison_string(k, v, bind_variables) + + comparison_values.append(result) + passing_values.append(passings) + + # combine all operators in a single JSON_EXISTS + all_conditions = [] + if len(comparison_values) != 0: + all_conditions.append( + MULTIPLE_MASK.format( + key=metadata_key, + filters=" && ".join(comparison_values), + passes=" , ".join(passing_values), + ) + ) + + # handle negated filters one by one, one JSON_EXISTS for each + for k, v in not_dict.items(): + if k == "$not": + condition = _generate_condition(metadata_key, v, bind_variables) + all_conditions.append(f"NOT ({condition})") + + elif k == "$exists": + if not isinstance(v, bool): + raise ValueError( + f"Invalid value for $exists: {value}. " + "It must be a boolean (true or false)." + ) + + if v: + all_conditions.append( + f"JSON_EXISTS(metadata, '$.{metadata_key}')" + ) + else: + all_conditions.append( + f"NOT (JSON_EXISTS(metadata, '$.{metadata_key}'))" + ) + + elif k == "$nin": # for now only $nin + result, passings = _get_comparison_string(k, v, bind_variables) + + all_conditions.append( + " NOT " + + MULTIPLE_MASK.format( + key=metadata_key, filters=result, passes=passings + ) + ) + + elif k == "$eq": + bind_l = len(bind_variables) + bind_variables.append(json.dumps(v)) + + all_conditions.append( + "JSON_EQUAL(" + f" JSON_QUERY(metadata, '$.{metadata_key}' )," + f" JSON(:value{bind_l})" + ")" + ) + + elif k == "$ne": + bind_l = len(bind_variables) + bind_variables.append(json.dumps(v)) + + all_conditions.append( + "NOT (JSON_EQUAL(" + f" JSON_QUERY(metadata, '$.{metadata_key}' )," + f" JSON(:value{bind_l})" + "))" + ) + + res = " AND ".join(all_conditions) + + if len(all_conditions) > 1: + return "(" + res + ")" + + return res + + else: + raise ValueError("Nested filters are not supported.") + + else: + raise ValueError("Filter format is invalid.") + + +def _generate_where_clause(db_filter: dict, bind_variables: List[str]) -> str: + if not isinstance(db_filter, dict): + raise ValueError("Filter syntax is incorrect. 
Must be a dictionary.") + + all_conditions = [] + + for key, value in db_filter.items(): + # must be a logical if on a high level + if key.startswith("$"): + if key not in LOGICAL_MAP.keys(): + raise ValueError(f"'{key}' is not a recognized logical operator.") + + filter_format = LOGICAL_MAP[key] + + if not isinstance(value, list): + raise ValueError("Logical operators require an array of values.") + + combine_conditions = [ + _generate_where_clause(v, bind_variables) for v in value + ] + + res = filter_format[1].format(filter_format[0].join(combine_conditions)) + + all_conditions.append(res) + + else: + # this is a metadata key - not an operator + res = _generate_condition(key, value, bind_variables) + all_conditions.append(res) + + # combine everything with AND + res = " AND ".join(all_conditions) + + if len(all_conditions) > 1: + res = "(" + res + ")" + + return res + + +def _get_connection(client: Any) -> Optional[Connection]: # check if ConnectionPool exists connection_pool_class = getattr(oracledb, "ConnectionPool", None) @@ -127,7 +365,7 @@ def _get_connection(client: Any) -> Connection | None: ) -async def _aget_connection(client: Any) -> AsyncConnection | None: +async def _aget_connection(client: Any) -> Optional[AsyncConnection]: # check if ConnectionPool exists connection_pool_class = getattr(oracledb, "AsyncConnectionPool", None) @@ -852,12 +1090,13 @@ def _get_similarity_search_query( table_name: str, distance_strategy: DistanceStrategy, k: int, - db_filter: Optional[FilterGroup], + db_filter: Optional[dict] = None, return_embeddings: bool = False, -) -> str: +) -> Tuple[str, list[str]]: where_clause = "" + bind_variables: list[str] = [] if db_filter: - where_clause = _generate_where_clause(db_filter) + where_clause = _generate_where_clause(db_filter, bind_variables) query = f""" SELECT id, @@ -872,7 +1111,7 @@ def _get_similarity_search_query( FETCH APPROX FIRST {k} ROWS ONLY """ - return query + return query, bind_variables async def _handle_context( @@ -1116,10 +1355,11 @@ async def _aembed_query(self, text: str) -> List[float]: def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[Dict[Any, Any]]] = None, - ids: Optional[List[str]] = None, + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Add more texts to the vectorstore index. Args: texts: Iterable of strings to add to the vectorstore. 
@@ -1132,33 +1372,61 @@ def add_texts( texts = list(texts) processed_ids = get_processed_ids(texts, metadatas, ids) - embeddings = self._embed_documents(texts) if not metadatas: metadatas = [{} for _ in texts] - docs: List[Tuple[Any, Any, Any, Any]] = [ - ( - id_, - array.array("f", embedding), - metadata, - text, - ) - for id_, embedding, metadata, text in zip( - processed_ids, embeddings, metadatas, texts - ) - ] + docs: Any + if not isinstance(self.embeddings, OracleEmbeddings): + embeddings = self._embed_documents(texts) + + docs = [ + ( + id_, + array.array("f", embedding), + metadata, + text, + ) + for id_, embedding, metadata, text in zip( + processed_ids, embeddings, metadatas, texts + ) + ] + else: + docs = list(zip(processed_ids, metadatas, texts)) connection = _get_connection(self.client) if connection is None: raise ValueError("Failed to acquire a connection.") with connection.cursor() as cursor: - cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None) - cursor.executemany( - f"INSERT INTO {self.table_name} (id, embedding, metadata, " - f"text) VALUES (:1, :2, :3, :4)", - docs, - ) - connection.commit() + if not isinstance(self.embeddings, OracleEmbeddings): + cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None) + cursor.executemany( + f"INSERT INTO {self.table_name} (id, embedding, metadata, " + f"text) VALUES (:1, :2, :3, :4)", + docs, + ) + connection.commit() + else: + if self.embeddings.proxy: + cursor.execute( + "begin utl_http.set_proxy(:proxy); end;", + proxy=self.embeddings.proxy, + ) + + cursor.setinputsizes(None, oracledb.DB_TYPE_JSON, None) + cursor.executemany( + f"INSERT INTO {self.table_name} (id, metadata, " + f"text) VALUES (:1, :2, :3)", + docs, + ) + + cursor.setinputsizes(oracledb.DB_TYPE_JSON) + update_sql = ( + f"UPDATE {self.table_name} " + "SET embedding = dbms_vector_chain.utl_to_embedding(text, json(:1))" + ) + cursor.execute(update_sql, [self.embeddings.params]) + connection.commit() + return processed_ids @_ahandle_exceptions @@ -1182,33 +1450,60 @@ async def aadd_texts( texts = list(texts) processed_ids = get_processed_ids(texts, metadatas, ids) - embeddings = await self._aembed_documents(texts) if not metadatas: metadatas = [{} for _ in texts] - docs: List[Tuple[Any, Any, Any, Any]] = [ - ( - id_, - array.array("f", embedding), - metadata, - text, - ) - for id_, embedding, metadata, text in zip( - processed_ids, embeddings, metadatas, texts - ) - ] + docs: Any + if not isinstance(self.embeddings, OracleEmbeddings): + embeddings = await self._aembed_documents(texts) + + docs = [ + ( + id_, + array.array("f", embedding), + metadata, + text, + ) + for id_, embedding, metadata, text in zip( + processed_ids, embeddings, metadatas, texts + ) + ] + else: + docs = list(zip(processed_ids, metadatas, texts)) async def context(connection: Any) -> None: if connection is None: raise ValueError("Failed to acquire a connection.") with connection.cursor() as cursor: - cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None) - await cursor.executemany( - f"INSERT INTO {self.table_name} (id, embedding, metadata, " - f"text) VALUES (:1, :2, :3, :4)", - docs, - ) - await connection.commit() + if not isinstance(self.embeddings, OracleEmbeddings): + cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None) + await cursor.executemany( + f"INSERT INTO {self.table_name} (id, embedding, metadata, " + f"text) VALUES (:1, :2, :3, :4)", + docs, + ) + await connection.commit() + else: + if self.embeddings.proxy: + await cursor.execute( + "begin 
utl_http.set_proxy(:proxy); end;", + proxy=self.embeddings.proxy, + ) + + cursor.setinputsizes(None, oracledb.DB_TYPE_JSON, None) + await cursor.executemany( + f"INSERT INTO {self.table_name} (id, metadata, " + f"text) VALUES (:1, :2, :3)", + docs, + ) + + cursor.setinputsizes(oracledb.DB_TYPE_JSON) + update_sql = ( + f"UPDATE {self.table_name} " + "SET embedding = dbms_vector_chain.utl_to_embedding(text, json(:1))" # noqa: E501 + ) + await cursor.execute(update_sql, [self.embeddings.params]) + await connection.commit() await _handle_context(self.client, context) @@ -1218,15 +1513,15 @@ def similarity_search( self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to query.""" - embedding: List[float] = [] - if isinstance(self.embedding_function, Embeddings): - embedding = self.embedding_function.embed_query(query) + embedding: List[float] = self._embed_query(query) + documents = self.similarity_search_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, db_filter=db_filter, **kwargs ) return documents @@ -1234,15 +1529,15 @@ async def asimilarity_search( self, query: str, k: int = 4, - filter: Optional[Dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to query.""" - embedding: List[float] = [] - if isinstance(self.embedding_function, Embeddings): - embedding = await self.embedding_function.aembed_query(query) + embedding: List[float] = await self._aembed_query(query) + documents = await self.asimilarity_search_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, db_filter=db_filter, **kwargs ) return documents @@ -1250,11 +1545,12 @@ def similarity_search_by_vector( self, embedding: List[float], k: int = 4, - filter: Optional[dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, db_filter=db_filter, **kwargs ) return [doc for doc, _ in docs_and_scores] @@ -1262,11 +1558,12 @@ async def asimilarity_search_by_vector( self, embedding: List[float], k: int = 4, - filter: Optional[dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: docs_and_scores = await self.asimilarity_search_by_vector_with_relevance_scores( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, db_filter=db_filter, **kwargs ) return [doc for doc, _ in docs_and_scores] @@ -1274,15 +1571,14 @@ def similarity_search_with_score( self, query: str, k: int = 4, - filter: Optional[dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query.""" - embedding: List[float] = [] - if isinstance(self.embedding_function, Embeddings): - embedding = self.embedding_function.embed_query(query) + embedding: List[float] = self._embed_query(query) docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, db_filter=db_filter, **kwargs ) return docs_and_scores @@ -1290,15 +1586,14 @@ async def asimilarity_search_with_score( self, query: str, k: int = 4, - filter: Optional[dict[str, Any]] = None, + *, + db_filter: 
Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query.""" - embedding: List[float] = [] - if isinstance(self.embedding_function, Embeddings): - embedding = await self.embedding_function.aembed_query(query) + embedding: List[float] = await self._aembed_query(query) docs_and_scores = await self.asimilarity_search_by_vector_with_relevance_scores( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, db_filter=db_filter, **kwargs ) return docs_and_scores @@ -1307,15 +1602,15 @@ def similarity_search_by_vector_with_relevance_scores( self, embedding: List[float], k: int = 4, - filter: Optional[dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: docs_and_scores = [] embedding_arr: Any = array.array("f", embedding) - db_filter: Optional[FilterGroup] = kwargs.get("db_filter", None) - query = _get_similarity_search_query( + query, bind_variables = _get_similarity_search_query( self.table_name, self.distance_strategy, k, @@ -1329,10 +1624,13 @@ def similarity_search_by_vector_with_relevance_scores( raise ValueError("Failed to acquire a connection.") with connection.cursor() as cursor: cursor.outputtypehandler = output_type_string_handler - cursor.execute(query, embedding=embedding_arr) + params = {"embedding": embedding_arr} + for i, value in enumerate(bind_variables): + params[f"value{i}"] = value + + cursor.execute(query, **params) results = cursor.fetchall() - # filter results if filter is provided for result in results: metadata = result[2] or {} page_content_str = result[1] if result[1] is not None else "" @@ -1346,11 +1644,7 @@ def similarity_search_by_vector_with_relevance_scores( ) distance = result[3] - # apply filtering based on the 'filter' dictionary - if not filter or all( - metadata.get(key) in value for key, value in filter.items() - ): - docs_and_scores.append((doc, distance)) + docs_and_scores.append((doc, distance)) return docs_and_scores @@ -1359,15 +1653,15 @@ async def asimilarity_search_by_vector_with_relevance_scores( self, embedding: List[float], k: int = 4, - filter: Optional[dict[str, Any]] = None, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: docs_and_scores = [] embedding_arr: Any = array.array("f", embedding) - db_filter: Optional[FilterGroup] = kwargs.get("db_filter", None) - query = _get_similarity_search_query( + query, bind_variables = _get_similarity_search_query( self.table_name, self.distance_strategy, k, @@ -1379,10 +1673,13 @@ async def context(connection: Any) -> List: # execute the query with connection.cursor() as cursor: cursor.outputtypehandler = output_type_string_handler - await cursor.execute(query, embedding=embedding_arr) + params = {"embedding": embedding_arr} + for i, value in enumerate(bind_variables): + params[f"value{i}"] = value + + await cursor.execute(query, **params) results = await cursor.fetchall() - # filter results if filter is provided for result in results: metadata = result[2] or {} page_content_str = result[1] if result[1] is not None else "" @@ -1395,11 +1692,7 @@ async def context(connection: Any) -> List: ) distance = result[3] - # apply filtering based on the 'filter' dictionary - if not filter or all( - metadata.get(key) in value for key, value in filter.items() - ): - docs_and_scores.append((doc, distance)) + docs_and_scores.append((doc, distance)) return docs_and_scores @@ -1409,16 +1702,16 @@ async def context(connection: Any) -> List: def 
similarity_search_by_vector_returning_embeddings( self, embedding: List[float], - k: int, - filter: Optional[Dict[str, Any]] = None, + k: int = 4, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float, NDArray[np.float32]]]: embedding_arr: Any = array.array("f", embedding) documents = [] - db_filter: Optional[FilterGroup] = kwargs.get("db_filter", None) - query = _get_similarity_search_query( + query, bind_variables = _get_similarity_search_query( self.table_name, self.distance_strategy, k, @@ -1432,7 +1725,11 @@ def similarity_search_by_vector_returning_embeddings( raise ValueError("Failed to acquire a connection.") with connection.cursor() as cursor: cursor.outputtypehandler = output_type_string_handler - cursor.execute(query, embedding=embedding_arr) + params = {"embedding": embedding_arr} + for i, value in enumerate(bind_variables): + params[f"value{i}"] = value + + cursor.execute(query, **params) results = cursor.fetchall() for result in results: @@ -1441,25 +1738,18 @@ def similarity_search_by_vector_returning_embeddings( raise Exception("Unexpected type:", type(page_content_str)) metadata = result[2] or {} - # apply filter if provided and matches; otherwise, add all - # documents - if not filter or all( - metadata.get(key) in value for key, value in filter.items() - ): - document = Document( - page_content=page_content_str, metadata=metadata - ) - distance = result[3] + document = Document(page_content=page_content_str, metadata=metadata) + distance = result[3] - # assuming result[4] is already in the correct format; - # adjust if necessary - current_embedding = ( - np.array(result[4], dtype=np.float32) - if result[4] - else np.empty(0, dtype=np.float32) - ) + # assuming result[4] is already in the correct format; + # adjust if necessary + current_embedding = ( + np.array(result[4], dtype=np.float32) + if result[4] + else np.empty(0, dtype=np.float32) + ) - documents.append((document, distance, current_embedding)) + documents.append((document, distance, current_embedding)) return documents @@ -1467,16 +1757,16 @@ def similarity_search_by_vector_returning_embeddings( async def asimilarity_search_by_vector_returning_embeddings( self, embedding: List[float], - k: int, - filter: Optional[Dict[str, Any]] = None, + k: int = 4, + *, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float, NDArray[np.float32]]]: embedding_arr: Any = array.array("f", embedding) documents = [] - db_filter: Optional[FilterGroup] = kwargs.get("db_filter", None) - query = _get_similarity_search_query( + query, bind_variables = _get_similarity_search_query( self.table_name, self.distance_strategy, k, @@ -1488,7 +1778,11 @@ async def context(connection: Any) -> List: # execute the query with connection.cursor() as cursor: cursor.outputtypehandler = output_type_string_handler - await cursor.execute(query, embedding=embedding_arr) + params = {"embedding": embedding_arr} + for i, value in enumerate(bind_variables): + params[f"value{i}"] = value + + await cursor.execute(query, **params) results = await cursor.fetchall() for result in results: @@ -1497,25 +1791,20 @@ async def context(connection: Any) -> List: raise Exception("Unexpected type:", type(page_content_str)) metadata = result[2] or {} - # apply filter if provided and matches; otherwise, add all - # documents - if not filter or all( - metadata.get(key) in value for key, value in filter.items() - ): - document = Document( - page_content=page_content_str, metadata=metadata - ) - distance = result[3] - 
- # assuming result[4] is already in the correct format; - # adjust if necessary - current_embedding = ( - np.array(result[4], dtype=np.float32) - if result[4] - else np.empty(0, dtype=np.float32) - ) + document = Document( + page_content=page_content_str, metadata=metadata + ) + distance = result[3] + + # assuming result[4] is already in the correct format; + # adjust if necessary + current_embedding = ( + np.array(result[4], dtype=np.float32) + if result[4] + else np.empty(0, dtype=np.float32) + ) - documents.append((document, distance, current_embedding)) + documents.append((document, distance, current_embedding)) return documents @@ -1528,7 +1817,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + db_filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs and their similarity scores selected using the maximal marginal @@ -1544,8 +1833,8 @@ def max_marginal_relevance_search_with_score_by_vector( k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch before filtering to pass to MMR algorithm. - filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults - to None. + db_filter: (Optional[dict]): Filter by metadata. + Defaults to None. lambda_mult: Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. @@ -1558,7 +1847,7 @@ def max_marginal_relevance_search_with_score_by_vector( # fetch documents and their scores docs_scores_embeddings = self.similarity_search_by_vector_returning_embeddings( - embedding, fetch_k, filter=filter + embedding, fetch_k, db_filter=db_filter ) # assuming documents_with_scores is a list of tuples (Document, score) mmr_selected_documents_with_scores = mmr_from_docs_embeddings( @@ -1574,7 +1863,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + db_filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: """Return docs and their similarity scores selected using the maximal marginal @@ -1590,8 +1879,8 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch before filtering to pass to MMR algorithm. - filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults - to None. + db_filter: (Optional[dict]): Filter by metadata. + Defaults to None. lambda_mult: Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. @@ -1605,7 +1894,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( # fetch documents and their scores docs_scores_embeddings = ( await self.asimilarity_search_by_vector_returning_embeddings( - embedding, fetch_k, filter=filter + embedding, fetch_k, db_filter=db_filter ) ) # assuming documents_with_scores is a list of tuples (Document, score) @@ -1621,7 +1910,7 @@ def max_marginal_relevance_search_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. 
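Because the filter now runs inside the database, MMR re-ranking only ever sees candidates that already satisfy `db_filter`. A hypothetical usage sketch, assuming an initialized `OracleVS` instance `vs` whose `embedding_function` is a LangChain `Embeddings` object (names are illustrative):

```python
# embed the query client-side, then fetch fetch_k filtered candidates
# (with their embeddings) and let MMR select the final k
query_vec = vs.embedding_function.embed_query("strawberry")

docs_and_scores = vs.max_marginal_relevance_search_with_score_by_vector(
    query_vec,
    k=2,              # results returned after MMR re-ranking
    fetch_k=20,       # candidates fetched with db_filter applied in the SQL query
    lambda_mult=0.5,  # 0 = maximum diversity, 1 = pure relevance
    db_filter={"age": {"$between": [40, 70]}},
)
for doc, score in docs_and_scores:
    print(score, doc.metadata)
```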
@@ -1639,13 +1928,18 @@ def max_marginal_relevance_search_by_vector( of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - filter: Optional[Dict[str, Any]] + db_filter: (Optional[dict]): Filter by metadata. + Defaults to None. **kwargs: Any Returns: List of Documents selected by maximal marginal relevance. """ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( - embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + db_filter=db_filter, ) return [doc for doc, _ in docs_and_scores] @@ -1655,7 +1949,7 @@ async def amax_marginal_relevance_search_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1673,14 +1967,19 @@ async def amax_marginal_relevance_search_by_vector( of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - filter: Optional[Dict[str, Any]] + db_filter: (Optional[dict]): Filter by metadata. + Defaults to None. **kwargs: Any Returns: List of Documents selected by maximal marginal relevance. """ docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + db_filter=db_filter, ) ) return [doc for doc, _ in docs_and_scores] @@ -1692,7 +1991,7 @@ def max_marginal_relevance_search( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1710,7 +2009,8 @@ def max_marginal_relevance_search( of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - filter: Optional[Dict[str, Any]] + db_filter: (Optional[dict]): Filter by metadata. + Defaults to None. **kwargs Returns: List of Documents selected by maximal marginal relevance. @@ -1724,7 +2024,7 @@ def max_marginal_relevance_search( k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, - filter=filter, + db_filter=db_filter, **kwargs, ) return documents @@ -1736,7 +2036,7 @@ async def amax_marginal_relevance_search( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, + db_filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -1754,7 +2054,8 @@ async def amax_marginal_relevance_search( of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - filter: Optional[Dict[str, Any]] + db_filter: (Optional[dict]): Filter by metadata. + Defaults to None. **kwargs Returns: List of Documents selected by maximal marginal relevance. 
@@ -1768,7 +2069,7 @@ async def amax_marginal_relevance_search( k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, - filter=filter, + db_filter=db_filter, **kwargs, ) return documents @@ -1812,9 +2113,6 @@ async def context(connection: Any) -> None: @classmethod def _from_texts_helper( cls: Type[OracleVS], - texts: Iterable[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> Tuple[Any, str, DistanceStrategy, str, Dict]: client: Any = kwargs.get("client", None) @@ -1852,7 +2150,7 @@ def from_texts( distance_strategy, query, params, - ) = OracleVS._from_texts_helper(texts, embedding, metadatas, **kwargs) + ) = OracleVS._from_texts_helper(**kwargs) vss = cls( client=client, @@ -1881,7 +2179,7 @@ async def afrom_texts( distance_strategy, query, params, - ) = OracleVS._from_texts_helper(texts, embedding, metadatas, **kwargs) + ) = OracleVS._from_texts_helper(**kwargs) vss = await OracleVS.acreate( client=client, diff --git a/libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py b/libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py index e194f5f..6f46583 100644 --- a/libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py +++ b/libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py @@ -18,6 +18,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.vectorstores.utils import DistanceStrategy +from langchain_oracledb.embeddings import OracleEmbeddings from langchain_oracledb.vectorstores.oraclevs import ( OracleVS, _acreate_table, @@ -51,6 +52,8 @@ ############################ ####### table_exists ####### ############################ + + def test_table_exists_test() -> None: try: connection = oracledb.connect(user=username, password=password, dsn=dsn) @@ -1676,26 +1679,32 @@ def test_perform_search_test() -> None: # perform search query = "YashB" - filter = {"id": ["106", "108", "yash"]} + db_filter: dict = { + "$or": [ # dict + {"id": "106"}, + {"id": "108"}, + {"id": "yash"}, + ] + } # similarity_searh without filter vs.similarity_search(query, 2) # similarity_searh with filter - vs.similarity_search(query, 2, filter=filter) + vs.similarity_search(query, 2, db_filter=db_filter) # Similarity search with relevance score vs.similarity_search_with_score(query, 2) # Similarity search with relevance score with filter - vs.similarity_search_with_score(query, 2, filter=filter) + vs.similarity_search_with_score(query, 2, db_filter=db_filter) # Max marginal relevance search vs.max_marginal_relevance_search(query, 2, fetch_k=20, lambda_mult=0.5) # Max marginal relevance search with filter vs.max_marginal_relevance_search( - query, 2, fetch_k=20, lambda_mult=0.5, filter=filter + query, 2, fetch_k=20, lambda_mult=0.5, db_filter=db_filter ) drop_table_purge(connection, "TB10") @@ -1762,26 +1771,32 @@ async def test_perform_search_test_async() -> None: # perform search query = "YashB" - filter = {"id": ["106", "108", "yash"]} + db_filter: dict = { + "$or": [ # dict + {"id": "106"}, + {"id": "108"}, + {"id": "yash"}, + ] + } # similarity_searh without filter await vs.asimilarity_search(query, 2) # similarity_searh with filter - await vs.asimilarity_search(query, 2, filter=filter) + await vs.asimilarity_search(query, 2, db_filter=db_filter) # Similarity search with relevance score await vs.asimilarity_search_with_score(query, 2) # Similarity search with relevance score with filter - await vs.asimilarity_search_with_score(query, 2, filter=filter) + await 
vs.asimilarity_search_with_score(query, 2, db_filter=db_filter) # Max marginal relevance search await vs.amax_marginal_relevance_search(query, 2, fetch_k=20, lambda_mult=0.5) # Max marginal relevance search with filter await vs.amax_marginal_relevance_search( - query, 2, fetch_k=20, lambda_mult=0.5, filter=filter + query, 2, fetch_k=20, lambda_mult=0.5, db_filter=db_filter ) await adrop_table_purge(connection, "TB10") @@ -1795,6 +1810,19 @@ async def test_perform_search_test_async() -> None: ################################## ##### perform_filter_search ###### ################################## + +FILTERED_FUNCTIONS = [ + "similarity_search", + "similarity_search_by_vector", + "similarity_search_with_score", + "similarity_search_by_vector_with_relevance_scores", + "similarity_search_by_vector_returning_embeddings", + "max_marginal_relevance_search_with_score_by_vector", + "max_marginal_relevance_search_by_vector", + "max_marginal_relevance_search", +] + + def test_db_filter_test() -> None: try: connection = oracledb.connect(user=username, password=password, dsn=dsn) @@ -1803,6 +1831,13 @@ def test_db_filter_test() -> None: model1 = HuggingFaceEmbeddings( model_name="sentence-transformers/paraphrase-mpnet-base-v2" ) + drop_table_purge(connection, "TB10") + drop_table_purge(connection, "TB11") + drop_table_purge(connection, "TB12") + drop_table_purge(connection, "TB13") + drop_table_purge(connection, "TB14") + drop_table_purge(connection, "TB15") + vs_1 = OracleVS(connection, model1, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE) vs_2 = OracleVS(connection, model1, "TB11", DistanceStrategy.DOT_PRODUCT) vs_3 = OracleVS(connection, model1, "TB12", DistanceStrategy.COSINE) @@ -1835,48 +1870,51 @@ def test_db_filter_test() -> None: # perform search query = "Strawberry" - filter = {"id": ["bl"]} - db_filter = {"key": "id", "oper": "EQ", "value": "bl"} # FilterCondition - - # similarity_search without filter - result = vs.similarity_search(query, 1) - assert result[0].metadata["id"] == "st" - - # similarity_search with filter - result = vs.similarity_search(query, 1, filter=filter) - assert len(result) == 0 - - # similarity_search with db_filter - result = vs.similarity_search(query, 1, db_filter=db_filter) - assert result[0].metadata["id"] == "bl" - - # similarity_search with filter and db_filter - result = vs.similarity_search(query, 1, filter=filter, db_filter=db_filter) - assert result[0].metadata["id"] == "bl" + db_filter: dict = {"id": {"$eq": "bl"}} # dict # nested db filter - db_filter_nested = { - "_or": [ - {"key": "id", "oper": "EQ", "value": "ba"}, # FilterCondition + db_filter_nested: dict = { + "$or": [ + {"id": "ba"}, # dict { - "_and": [ # FilterGroup - {"key": "order", "oper": "LTE", "value": 4}, - {"key": "id", "oper": "EQ", "value": "st"}, + "$and": [ # dict + {"order": {"$lte": 4}}, + {"id": "st"}, ] }, ] } - # similarity_search with db_filter - result = vs.similarity_search(query, 1, db_filter=db_filter_nested) - assert result[0].metadata["id"] == "st" + for filtered_function in FILTERED_FUNCTIONS: + method = getattr(vs, filtered_function) + + query_emb: list[float] | str = query + if "_by_vector" in filtered_function: + query_emb = vs.embedding_function.embed_query(query) # type: ignore[union-attr] + + # search without filter + result = method(query_emb, k=1) + result = result[0] if not isinstance(result[0], tuple) else result[0][0] + assert result.metadata["id"] == "st" + + # search with filter + result = method(query_emb, k=5, db_filter=db_filter) + assert len(result) == 1 + 
result = result[0] if not isinstance(result[0], tuple) else result[0][0] + assert result.metadata["id"] == "bl" + + # search with nested filter + result = method(query_emb, k=5, db_filter=db_filter_nested) + assert len(result) == 2 + result = result[0] if not isinstance(result[0], tuple) else result[0][0] + assert result.metadata["id"] == "st" exception_occurred = False try: - db_filter_exc = { + db_filter_exc: dict = { # type: ignore[typeddict-unknown-key] "_xor": [ # Incorrect operation _xor - {"key": "id", "oper": "EQ", "value": "ba"}, - {"key": "order", "oper": "LTE", "value": 4}, + {"order": {"$lte": 4}}, + {"id": "st"}, ] } result = vs.similarity_search(query, 1, db_filter=db_filter_exc) @@ -1888,13 +1926,9 @@ def test_db_filter_test() -> None: exception_occurred = False try: db_filter_exc = { - "_or": [ - { - "key": "id", - "oper": "XEQ", - "value": "ba", - }, # Incorrect operation XEQ - {"key": "order", "oper": "LTE", "value": 4}, + "$or": [ + {"order": {"$xeq": 4}}, # Incorrect operation XEQ + {"id": "st"}, ] } result = vs.similarity_search(query, 1, db_filter=db_filter_exc) @@ -1967,50 +2001,51 @@ async def test_db_filter_test_async() -> None: # perform search query = "Strawberry" - filter = {"id": ["bl"]} - db_filter = {"key": "id", "oper": "EQ", "value": "bl"} # FilterCondition - - # similarity_search without filter - result = await vs.asimilarity_search(query, 1) - assert result[0].metadata["id"] == "st" - - # similarity_search with filter - result = await vs.asimilarity_search(query, 1, filter=filter) - assert len(result) == 0 - - # similarity_search with db_filter - result = await vs.asimilarity_search(query, 1, db_filter=db_filter) - assert result[0].metadata["id"] == "bl" - - # similarity_search with filter and db_filter - result = await vs.asimilarity_search( - query, 1, filter=filter, db_filter=db_filter - ) - assert result[0].metadata["id"] == "bl" + db_filter: dict = {"id": {"$eq": "bl"}} # dict # nested db filter - db_filter_nested = { - "_or": [ - {"key": "id", "oper": "EQ", "value": "ba"}, # FilterCondition + db_filter_nested: dict = { + "$or": [ + {"id": "ba"}, # dict { - "_and": [ # FilterGroup - {"key": "order", "oper": "LTE", "value": 4}, - {"key": "id", "oper": "EQ", "value": "st"}, + "$and": [ # dict + {"order": {"$lte": 4}}, + {"id": "st"}, ] }, ] } - # similarity_search with db_filter - result = await vs.asimilarity_search(query, 1, db_filter=db_filter_nested) - assert result[0].metadata["id"] == "st" + for filtered_function in FILTERED_FUNCTIONS: + method = getattr(vs, "a" + filtered_function) + + query_emb: list[float] | str = query + if "_by_vector" in filtered_function: + query_emb = vs.embedding_function.embed_query(query) # type: ignore[union-attr] + + # search without filter + result = await method(query_emb, k=1) + result = result[0] if not isinstance(result[0], tuple) else result[0][0] + assert result.metadata["id"] == "st" + + # search with filter + result = await method(query_emb, k=5, db_filter=db_filter) + assert len(result) == 1 + result = result[0] if not isinstance(result[0], tuple) else result[0][0] + assert result.metadata["id"] == "bl" + + # search with nested filter + result = await method(query_emb, k=5, db_filter=db_filter_nested) + assert len(result) == 2 + result = result[0] if not isinstance(result[0], tuple) else result[0][0] + assert result.metadata["id"] == "st" exception_occurred = False try: - db_filter_exc = { + db_filter_exc: dict = { # type: ignore[typeddict-unknown-key] "_xor": [ # Incorrect operation _xor - {"key": "id", 
"oper": "EQ", "value": "ba"}, - {"key": "order", "oper": "LTE", "value": 4}, + {"order": {"$lte": 4}}, + {"id": "st"}, ] } result = await vs.asimilarity_search(query, 1, db_filter=db_filter_exc) @@ -2022,15 +2057,12 @@ async def test_db_filter_test_async() -> None: exception_occurred = False try: db_filter_exc = { - "_or": [ - { - "key": "id", - "oper": "XEQ", - "value": "ba", - }, # Incorrect operation XEQ - {"key": "order", "oper": "LTE", "value": 4}, + "$or": [ + {"order": {"$xeq": 4}}, # Incorrect operation XEQ + {"id": "st"}, ] } + result = await vs.asimilarity_search(query, 1, db_filter=db_filter_exc) except ValueError: exception_occurred = True @@ -2443,12 +2475,92 @@ async def test_index_table_case_async(caplog: pytest.LogCaptureFixture) -> None: await adrop_table_purge(connection, "TB2") +################################## +##### test_oracle_embeddings #### +################################## + + +def test_oracle_embeddings() -> None: + try: + connection = oracledb.connect(user=username, password=password, dsn=dsn) + except Exception: + sys.exit(1) + + drop_table_purge(connection, "TB1") + + texts = ["Database Document", "Code Document"] + metadata = [ + {"id": "100", "link": "Document Example Test 1"}, + {"id": "101", "link": "Document Example Test 2"}, + ] + embedder_params = {"provider": "database", "model": "allminilm"} + proxy = "" + + # instance + model = OracleEmbeddings(conn=connection, params=embedder_params, proxy=proxy) + + vs_obj = OracleVS(connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE) + + vs_obj.add_texts(texts, metadata) + res = vs_obj.similarity_search("database", 1) + + assert "Database" in res[0].page_content + + drop_table_purge(connection, "TB1") + + connection.close() + + +@pytest.mark.asyncio +async def test_oracle_embeddings_async(caplog: pytest.LogCaptureFixture) -> None: + try: + connection = await oracledb.connect_async( + user=username, password=password, dsn=dsn + ) + + connection_sync = oracledb.connect(user=username, password=password, dsn=dsn) + except Exception: + sys.exit(1) + + await adrop_table_purge(connection, "TB1") + + texts = ["Database Document", "Code Document"] + metadata = [ + {"id": "100", "link": "Document Example Test 1"}, + {"id": "101", "link": "Document Example Test 2"}, + ] + embedder_params = {"provider": "database", "model": "allminilm"} + proxy = "" + + # instance + model = OracleEmbeddings(conn=connection_sync, params=embedder_params, proxy=proxy) + + vs_obj = await OracleVS.acreate( + connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE + ) + + await vs_obj.aadd_texts(texts, metadata) + res = await vs_obj.asimilarity_search("database", 1) + + assert "Database" in res[0].page_content + + await adrop_table_purge(connection, "TB1") + + await connection.close() + + +################################## +##### test_quote_identifier ##### +################################## + + def test_quote_identifier() -> None: # unquoted assert _quote_indentifier("hello") == '"hello"' assert _quote_indentifier("--") == '"--"' assert _quote_indentifier("U1.table") == '"U1"."table"' assert _quote_indentifier("hnsw_idx2") == '"hnsw_idx2"' + assert _quote_indentifier("'") == '"\'"' with pytest.raises(ValueError): _quote_indentifier('hnsw_"idx2') @@ -2475,3 +2587,390 @@ def test_quote_identifier() -> None: # mixed assert _quote_indentifier('"U1".table') == '"U1"."table"' + + +################################## +########## test_filters ######### +################################## + + +def test_filters() -> None: + try: + 
connection = oracledb.connect(user=username, password=password, dsn=dsn) + except Exception: + sys.exit(1) + + def model1(_) -> list[float]: # type: ignore[no-untyped-def] + return [0.1, 0.2, 0.3] + + # model1 = lambda x: [0.1, 0.2, 0.3] + + drop_table_purge(connection, "TB10") + + vs = OracleVS(connection, model1, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE) + + texts = ["Strawberry", "Banana", "Blueberry"] + metadatas = [ + { + "id": "1", + "name": "Jason", + "age": 45, + "address": [ + { + "street": "25 A street", + "city": "Mono Vista", + "zip": 94088, + "state": "CA", + } + ], + "drinks": "tea", + }, + { + "id": "2", + "name": "Mary", + "age": 50, + "address": [ + { + "street": "15 C street", + "city": "Mono Vista", + "zip": 97090, + "state": "OR", + }, + { + "street": "30 ABC avenue", + "city": "Markstown", + "zip": 90001, + "state": "CA", + }, + ], + }, + {"id": "3", "name": "Mark", "age": 65, "drinks": ["soda", "tea"]}, + ] + + vs.add_texts(texts, metadatas) + + filter_res: list[tuple[dict, list[str]]] = [ + ({"drinks": {"$exists": True}}, ["1", "3"]), + ({"address.zip": 94088}, ["1"]), + ({"name": {"$eq": "Jason"}}, ["1"]), + ({"drinks": {"$ne": "tea"}}, ["3"]), # exits and not equal + ({"drinks": {"$eq": ["soda", "tea"]}}, ["3"]), + ({"drinks": {"$ne": ["soda", "tea"]}}, ["1"]), + ( + { + "address[0]": { + "$eq": { + "street": "25 A street", + "city": "Mono Vista", + "zip": 94088, + "state": "CA", + } + } + }, + ["1"], + ), + ( + { + "address[0]": { + "$ne": { + "street": "25 A street", + "city": "Mono Vista", + "zip": 94088, + "state": "CA", + } + } + }, + ["2"], + ), + ( + {"$or": [{"drinks": {"$exists": False}}, {"drinks": {"$ne": "tea"}}]}, + ["2", "3"], + ), + ( + { + "$or": [ + {"drinks": {"$exists": False}}, + {"drinks": {"$ne": ["soda", "tea"]}}, + ] + }, + ["1", "2"], + ), + ({"age": {"$gt": 45, "$lt": 55}}, ["2"]), + ({"age": {"$gt": 45}}, ["2", "3"]), + ({"age": {"$lt": 55}}, ["1", "2"]), + ({"age": {"$gte": 65}}, ["3"]), + ({"age": {"$lte": 50}}, ["1", "2"]), + ({"age": {"$between": [49, 51]}}, ["2"]), + ({"name": {"$startsWith": "Mar"}}, ["2", "3"]), + ({"name": {"$hasSubstring": "ar"}}, ["2", "3"]), + ({"name": {"$instr": "ar"}}, ["2", "3"]), + ({"name": {"$regex": ".*ar.*"}}, ["2", "3"]), + ({"name": {"$like": "%ar%"}}, ["2", "3"]), + ({"name": {"$in": ["Mark", "Mary"]}}, ["2", "3"]), + ({"name": {"$nin": ["Mark", "Mary"]}}, ["1"]), + ({"drinks": {"$all": ["tea", "soda"]}}, ["3"]), + ({"drinks": {"$all": ["tea"]}}, ["1", "3"]), + ({"drinks": {"$not": {"$all": ["tea", "soda"]}}}, ["1", "2"]), + ({"address[*].zip": {"$in": [94088, 1]}}, ["1"]), + ({"address[1].zip": 90001}, ["2"]), + ({"drinks[0,1]": "soda"}, ["3"]), + ({"drinks[1 to 2]": "soda"}, []), + ({"drinks": "tea"}, ["1", "3"]), + ({"drinks[*]": "tea"}, ["1", "3"]), + ({"name": "Jason"}, ["1"]), + ({"address.zip": {"$not": {"$eq": "90001"}}}, ["1", "3"]), + ({"age": {"$not": {"$gt": 46, "$lt": 65}}}, ["1", "3"]), + ({"$and": [{"name": {"$startsWith": "Ja"}}, {"drinks": "tea"}]}, ["1"]), + ({"name": {"$startsWith": "Ja"}, "drinks": "tea"}, ["1"]), + ({"$or": [{"drinks": "soda"}, {"address.zip": {"$lt": 94000}}]}, ["2", "3"]), + ({"$nor": [{"drinks": "soda"}, {"address.zip": {"$lt": 94000}}]}, ["1"]), + ( + { + "$and": [ + {"age": {"$gte": 60}}, + {"$or": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + ] + }, + ["3"], + ), + ( + { + "$or": [ + {"$and": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + {"$nor": [{"age": {"$lt": 65}}, {"name": "Jason"}]}, + ] + }, + ["1", "3"], + ), + ( 
+ { + "age": 65, + "$or": [ + {"$and": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + {"$nor": [{"age": {"$lt": 65}}, {"name": "Jason"}]}, + ], + }, + ["3"], + ), + ( + { + "age": 65, + "name": {"$regex": "*rk"}, + "$or": [ + {"$and": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + {"$nor": [{"age": {"$lt": 65}}, {"name": "Jason"}]}, + ], + }, + ["3"], + ), + ] + + for _f, _r in filter_res: + # search with filter + result = vs.similarity_search("Hello", k=3, db_filter=_f) + ids = [res.metadata["id"] for res in result] + + assert set(ids) == set(_r) + + with pytest.raises(ValueError, match="Invalid metadata key"): + _f = {"ss')--": "HELLOE"} + result = vs.similarity_search("Hello", k=3, db_filter=_f) + + with pytest.raises(ValueError, match="Invalid operator"): + _f = {"drinks": {"$neq": ["soda", "tea"]}} + result = vs.similarity_search("Hello", k=3, db_filter=_f) + + drop_table_purge(connection, "TB10") + + +async def test_filters_async() -> None: + try: + connection = await oracledb.connect_async( + user=username, password=password, dsn=dsn + ) + except Exception: + sys.exit(1) + + def model1(_) -> list[float]: # type: ignore[no-untyped-def] + return [0.1, 0.2, 0.3] + + # model1 = lambda x: [0.1, 0.2, 0.3] + + await adrop_table_purge(connection, "TB10") + + vs = await OracleVS.acreate( + connection, model1, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE + ) + + texts = ["Strawberry", "Banana", "Blueberry"] + metadatas = [ + { + "id": "1", + "name": "Jason", + "age": 45, + "address": [ + { + "street": "25 A street", + "city": "Mono Vista", + "zip": 94088, + "state": "CA", + } + ], + "drinks": "tea", + }, + { + "id": "2", + "name": "Mary", + "age": 50, + "address": [ + { + "street": "15 C street", + "city": "Mono Vista", + "zip": 97090, + "state": "OR", + }, + { + "street": "30 ABC avenue", + "city": "Markstown", + "zip": 90001, + "state": "CA", + }, + ], + }, + {"id": "3", "name": "Mark", "age": 65, "drinks": ["soda", "tea"]}, + ] + + await vs.aadd_texts(texts, metadatas) + + filter_res: list[tuple[dict, list[str]]] = [ + ({"drinks": {"$exists": True}}, ["1", "3"]), + ({"address.zip": 94088}, ["1"]), + ({"name": {"$eq": "Jason"}}, ["1"]), + ({"drinks": {"$ne": "tea"}}, ["3"]), # exits and not equal + ({"drinks": {"$eq": ["soda", "tea"]}}, ["3"]), + ({"drinks": {"$ne": ["soda", "tea"]}}, ["1"]), + ( + { + "address[0]": { + "$eq": { + "street": "25 A street", + "city": "Mono Vista", + "zip": 94088, + "state": "CA", + } + } + }, + ["1"], + ), + ( + { + "address[0]": { + "$ne": { + "street": "25 A street", + "city": "Mono Vista", + "zip": 94088, + "state": "CA", + } + } + }, + ["2"], + ), + ( + {"$or": [{"drinks": {"$exists": False}}, {"drinks": {"$ne": "tea"}}]}, + ["2", "3"], + ), + ( + { + "$or": [ + {"drinks": {"$exists": False}}, + {"drinks": {"$ne": ["soda", "tea"]}}, + ] + }, + ["1", "2"], + ), + ({"age": {"$gt": 45, "$lt": 55}}, ["2"]), + ({"age": {"$gt": 45}}, ["2", "3"]), + ({"age": {"$lt": 55}}, ["1", "2"]), + ({"age": {"$gte": 65}}, ["3"]), + ({"age": {"$lte": 50}}, ["1", "2"]), + ({"age": {"$between": [49, 51]}}, ["2"]), + ({"name": {"$startsWith": "Mar"}}, ["2", "3"]), + ({"name": {"$hasSubstring": "ar"}}, ["2", "3"]), + ({"name": {"$instr": "ar"}}, ["2", "3"]), + ({"name": {"$regex": ".*ar.*"}}, ["2", "3"]), + ({"name": {"$like": "%ar%"}}, ["2", "3"]), + ({"name": {"$in": ["Mark", "Mary"]}}, ["2", "3"]), + ({"name": {"$nin": ["Mark", "Mary"]}}, ["1"]), + ({"drinks": {"$all": ["tea", "soda"]}}, ["3"]), + ({"drinks": {"$all": ["tea"]}}, ["1", "3"]), + 
({"drinks": {"$not": {"$all": ["tea", "soda"]}}}, ["1", "2"]), + ({"address[*].zip": {"$in": [94088, 1]}}, ["1"]), + ({"address[1].zip": 90001}, ["2"]), + ({"drinks[0,1]": "soda"}, ["3"]), + ({"drinks[1 to 2]": "soda"}, []), + ({"drinks": "tea"}, ["1", "3"]), + ({"drinks[*]": "tea"}, ["1", "3"]), + ({"name": "Jason"}, ["1"]), + ({"address.zip": {"$not": {"$eq": "90001"}}}, ["1", "3"]), + ({"age": {"$not": {"$gt": 46, "$lt": 65}}}, ["1", "3"]), + ({"$and": [{"name": {"$startsWith": "Ja"}}, {"drinks": "tea"}]}, ["1"]), + ({"name": {"$startsWith": "Ja"}, "drinks": "tea"}, ["1"]), + ({"$or": [{"drinks": "soda"}, {"address.zip": {"$lt": 94000}}]}, ["2", "3"]), + ({"$nor": [{"drinks": "soda"}, {"address.zip": {"$lt": 94000}}]}, ["1"]), + ( + { + "$and": [ + {"age": {"$gte": 60}}, + {"$or": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + ] + }, + ["3"], + ), + ( + { + "$or": [ + {"$and": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + {"$nor": [{"age": {"$lt": 65}}, {"name": "Jason"}]}, + ] + }, + ["1", "3"], + ), + ( + { + "age": 65, + "$or": [ + {"$and": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + {"$nor": [{"age": {"$lt": 65}}, {"name": "Jason"}]}, + ], + }, + ["3"], + ), + ( + { + "age": 65, + "name": {"$regex": "*rk"}, + "$or": [ + {"$and": [{"name": "Jason"}, {"drinks": {"$in": ["tea", "soda"]}}]}, + {"$nor": [{"age": {"$lt": 65}}, {"name": "Jason"}]}, + ], + }, + ["3"], + ), + ] + + for _f, _r in filter_res: + # search with filter + result = await vs.asimilarity_search("Hello", k=3, db_filter=_f) + ids = [res.metadata["id"] for res in result] + + assert set(ids) == set(_r) + + with pytest.raises(ValueError, match="Invalid metadata key"): + _f = {"ss')--": "HELLOE"} + result = await vs.asimilarity_search("Hello", k=3, db_filter=_f) + + with pytest.raises(ValueError, match="Invalid operator"): + _f = {"drinks": {"$neq": ["soda", "tea"]}} + result = await vs.asimilarity_search("Hello", k=3, db_filter=_f) + + await adrop_table_purge(connection, "TB10")
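Taken together, the two features exercised by these tests compose: in-database embedding via `OracleEmbeddings` (no client-side `embed_documents` call) and SQL-side filtering via `db_filter`. A hedged end-to-end sketch; the credentials, table name, and ONNX model name `"allminilm"` are illustrative values borrowed from the tests above:

```python
import oracledb
from langchain_community.vectorstores.utils import DistanceStrategy

from langchain_oracledb.embeddings import OracleEmbeddings
from langchain_oracledb.vectorstores.oraclevs import OracleVS

connection = oracledb.connect(user="scott", password="tiger", dsn="localhost/orclpdb1")

# with OracleEmbeddings, add_texts inserts id/metadata/text only, then issues
# a single bulk UPDATE that calls dbms_vector_chain.utl_to_embedding in-database
model = OracleEmbeddings(
    conn=connection, params={"provider": "database", "model": "allminilm"}, proxy=""
)
vs = OracleVS(connection, model, "DEMO_VS", DistanceStrategy.EUCLIDEAN_DISTANCE)

vs.add_texts(
    ["Database Document", "Code Document"],
    [{"id": "100", "topic": "db"}, {"id": "101", "topic": "code"}],
)

# the filter compiles to JSON_EXISTS predicates plus bind variables, so
# matching happens in SQL rather than in Python after the fetch
res = vs.similarity_search("database", 1, db_filter={"topic": {"$eq": "db"}})
print(res[0].page_content)  # expected: "Database Document"
```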