Skip to content

Commit

Permalink
Add support for the Negation('$not') operator (#62)
Browse files Browse the repository at this point in the history
Add support for the Negation(`$not`) operator in PGVector

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
  • Loading branch information
Raj725 and eyurtsev committed Jun 10, 2024
1 parent f724ab3 commit fb82e21
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 12 deletions.
43 changes: 33 additions & 10 deletions langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class DistanceStrategy(str, enum.Enum):
"$ilike",
}

LOGICAL_OPERATORS = {"$and", "$or"}
LOGICAL_OPERATORS = {"$and", "$or", "$not"}

SUPPORTED_OPERATORS = (
set(COMPARISONS_TO_NATIVE)
Expand Down Expand Up @@ -1248,26 +1248,25 @@ def _create_filter_clause(self, filters: Any) -> Any:
"""
if isinstance(filters, dict):
if len(filters) == 1:
# The only operators allowed at the top level are $AND and $OR
# The only operators allowed at the top level are $AND, $OR, and $NOT
# First check if an operator or a field
key, value = list(filters.items())[0]
if key.startswith("$"):
# Then it's an operator
if key.lower() not in ["$and", "$or"]:
if key.lower() not in ["$and", "$or", "$not"]:
raise ValueError(
f"Invalid filter condition. Expected $and or $or "
f"Invalid filter condition. Expected $and, $or or $not "
f"but got: {key}"
)
else:
# Then it's a field
return self._handle_field_filter(key, filters[key])

# Here we handle the $and and $or operators
if not isinstance(value, list):
raise ValueError(
f"Expected a list, but got {type(value)} for value: {value}"
)
if key.lower() == "$and":
if not isinstance(value, list):
raise ValueError(
f"Expected a list, but got {type(value)} for value: {value}"
)
and_ = [self._create_filter_clause(el) for el in value]
if len(and_) > 1:
return sqlalchemy.and_(*and_)
Expand All @@ -1279,6 +1278,10 @@ def _create_filter_clause(self, filters: Any) -> Any:
"but got an empty dictionary"
)
elif key.lower() == "$or":
if not isinstance(value, list):
raise ValueError(
f"Expected a list, but got {type(value)} for value: {value}"
)
or_ = [self._create_filter_clause(el) for el in value]
if len(or_) > 1:
return sqlalchemy.or_(*or_)
Expand All @@ -1289,9 +1292,29 @@ def _create_filter_clause(self, filters: Any) -> Any:
"Invalid filter condition. Expected a dictionary "
"but got an empty dictionary"
)
elif key.lower() == "$not":
if isinstance(value, list):
not_conditions = [
self._create_filter_clause(item) for item in value
]
not_ = sqlalchemy.and_(
*[
sqlalchemy.not_(condition)
for condition in not_conditions
]
)
return not_
elif isinstance(value, dict):
not_ = self._create_filter_clause(value)
return sqlalchemy.not_(not_)
else:
raise ValueError(
f"Invalid filter condition. Expected a dictionary "
f"or a list but got: {type(value)}"
)
else:
raise ValueError(
f"Invalid filter condition. Expected $and or $or "
f"Invalid filter condition. Expected $and, $or or $not "
f"but got: {key}"
)
elif len(filters) > 1:
Expand Down
37 changes: 35 additions & 2 deletions tests/unit_tests/fixtures/filtering_test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@

TYPE_2_FILTERING_TEST_CASES = [
# These involve equality checks and other operators
# like $ne, $gt, $gte, $lt, $lte, $not
# like $ne, $gt, $gte, $lt, $lte
(
{"id": 1},
[1],
Expand Down Expand Up @@ -168,7 +168,7 @@
]

TYPE_3_FILTERING_TEST_CASES = [
# These involve usage of AND and OR operators
# These involve usage of AND, OR and NOT operators
(
{"$or": [{"id": 1}, {"id": 2}]},
[1, 2],
Expand All @@ -185,6 +185,39 @@
{"$or": [{"id": 1}, {"id": 2}, {"id": 3}]},
[1, 2, 3],
),
# Test for $not operator
(
{"$not": {"id": 1}},
[2, 3],
),
(
{"$not": [{"id": 1}]},
[2, 3],
),
(
{"$not": {"name": "adam"}},
[2, 3],
),
(
{"$not": [{"name": "adam"}]},
[2, 3],
),
(
{"$not": {"is_active": True}},
[2],
),
(
{"$not": [{"is_active": True}]},
[2],
),
(
{"$not": {"height": {"$gt": 5.0}}},
[3],
),
(
{"$not": [{"height": {"$gt": 5.0}}]},
[3],
),
]

TYPE_4_FILTERING_TEST_CASES = [
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,7 @@ async def test_async_pgvector_with_with_metadata_filters_5(
{"$eq": {}},
{"$exists": {}},
{"$exists": 1},
{"$not": 2},
],
)
def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
Expand All @@ -1016,5 +1017,6 @@ def test_validate_operators() -> None:
"$lte",
"$ne",
"$nin",
"$not",
"$or",
]

0 comments on commit fb82e21

Please sign in to comment.