Skip to content

Commit

Permalink
Incorporate metadata filters into PGVector store querying (#6968)
Browse files Browse the repository at this point in the history
  • Loading branch information
sourabhdesai committed Jul 19, 2023
1 parent dd0bc90 commit d19b82b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features
- Added a `SentenceTransformerRerank` node post-processor for fast local re-ranking (#6934)
- Add numpy support for evaluating queries in pandas query engine (#6935)
- Add metadata filtering support for Postgres Vector Storage integration (#6968)

### Bug Fixes / Nits
- Added `model_name` to LLMMetadata (#6911)
Expand Down
39 changes: 28 additions & 11 deletions llama_index/vector_stores/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
NodeWithEmbedding,
VectorStoreQuery,
VectorStoreQueryResult,
MetadataFilters,
)
from llama_index.vector_stores.utils import node_to_metadata_dict, metadata_dict_to_node

Expand Down Expand Up @@ -122,23 +123,39 @@ def add(self, embedding_results: List[NodeWithEmbedding]) -> List[str]:
return ids

def _query_with_score(
    self,
    embedding: Optional[List[float]],
    limit: int = 10,
    metadata_filters: Optional[MetadataFilters] = None,
) -> List[Any]:
    """Fetch the rows closest to ``embedding`` together with their L2 distance.

    Args:
        embedding: Query vector handed to the embedding column's
            ``l2_distance`` comparator.
        limit: Maximum number of rows to return.
        metadata_filters: Optional exact-match filters. A row is kept only
            if, for every filter, the text stored under ``filter.key`` in
            the ``metadata_`` column equals ``str(filter.value)``.

    Returns:
        The result of ``Query.all()``: one ``(row, distance)`` entry per
        match, ordered by ascending distance.
    """
    import sqlalchemy

    # Build the distance expression once; it is both selected and sorted on.
    distance = self.table_class.embedding.l2_distance(embedding)

    with self._session() as session:
        with session.begin():
            query = session.query(self.table_class, distance).order_by(distance)
            if metadata_filters is not None:
                for index, filter_ in enumerate(metadata_filters.filters):
                    # Bind names come from the filter's position, never from
                    # the key: keys may contain characters that are not valid
                    # in a bind name, and two filters on the same key must
                    # not overwrite each other's value.
                    key_param = f"filter_key_{index}"
                    value_param = f"filter_value_{index}"
                    # The key is passed as a bound value rather than being
                    # interpolated into the SQL text, so a quote in the key
                    # cannot break out of the statement.
                    condition = sqlalchemy.text(
                        f"metadata_ ->> CAST(:{key_param} AS text) "
                        f"= :{value_param}"
                    ).bindparams(
                        **{
                            key_param: filter_.key,
                            value_param: str(filter_.value),
                        }
                    )
                    # Successive filter() calls are AND-ed together.
                    query = query.filter(condition)
            return query.limit(limit).all()

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
results = self._query_with_score(query.query_embedding, query.similarity_top_k)
results = self._query_with_score(
query.query_embedding, query.similarity_top_k, query.filters
)
nodes = []
similarities = []
ids = []
Expand Down
34 changes: 32 additions & 2 deletions tests/vector_stores/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode
from llama_index.vector_stores import PGVectorStore
from llama_index.vector_stores.types import NodeWithEmbedding, VectorStoreQuery
from llama_index.vector_stores.types import (
NodeWithEmbedding,
VectorStoreQuery,
MetadataFilters,
ExactMatchFilter,
)

# For testing, find the pgvector installation notes here: https://github.com/pgvector/pgvector#installation-notes

Expand Down Expand Up @@ -38,11 +43,12 @@ def conn() -> Any:
return conn_


@pytest.fixture(scope="session")
@pytest.fixture()
def db(conn: Any) -> Generator:
conn.autocommit = True

with conn.cursor() as c:
c.execute(f"DROP DATABASE IF EXISTS {TEST_DB}")
c.execute(f"CREATE DATABASE {TEST_DB}")
conn.commit()
yield
Expand All @@ -68,6 +74,7 @@ def node_embeddings() -> List[NodeWithEmbedding]:
text="dolor sit amet",
id_="bbb",
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="bbb")},
extra_info={"test_key": "test_value"},
),
),
]
Expand Down Expand Up @@ -101,6 +108,29 @@ def test_add_to_db_and_query(
assert res.nodes[0].node_id == "aaa"


@pytest.mark.skipif(postgres_not_available, reason="postgres db is not available")
def test_add_to_db_and_query_with_metadata_filters(
    db: None, node_embeddings: List[NodeWithEmbedding]
) -> None:
    """Query with an exact-match filter on ``test_key`` and expect only node "bbb"."""
    store = PGVectorStore.from_params(
        **PARAMS,  # type: ignore
        database=TEST_DB,
        table_name=TEST_TABLE_NAME,
    )
    store.add(node_embeddings)
    assert isinstance(store, PGVectorStore)

    # Restrict the search to nodes whose metadata has test_key == "test_value".
    exact_match = ExactMatchFilter(key="test_key", value="test_value")
    filtered_query = VectorStoreQuery(
        query_embedding=[0.5] * 1536,
        similarity_top_k=10,
        filters=MetadataFilters(filters=[exact_match]),
    )

    result = store.query(filtered_query)
    assert result.nodes
    assert len(result.nodes) == 1
    assert result.nodes[0].node_id == "bbb"


@pytest.mark.skipif(postgres_not_available, reason="postgres db is not available")
def test_add_to_db_query_and_delete(
db: None, node_embeddings: List[NodeWithEmbedding]
Expand Down

0 comments on commit d19b82b

Please sign in to comment.