30 changes: 26 additions & 4 deletions examples/pg_vectorstore.ipynb
@@ -258,22 +258,22 @@
"\n",
"docs = [\n",
" Document(\n",
" id=uuid.uuid4(),\n",
" id=str(uuid.uuid4()),\n",
" page_content=\"there are cats in the pond\",\n",
" metadata={\"likes\": 1, \"location\": \"pond\", \"topic\": \"animals\"},\n",
" ),\n",
" Document(\n",
" id=uuid.uuid4(),\n",
" id=str(uuid.uuid4()),\n",
" page_content=\"ducks are also found in the pond\",\n",
" metadata={\"likes\": 30, \"location\": \"pond\", \"topic\": \"animals\"},\n",
" ),\n",
" Document(\n",
" id=uuid.uuid4(),\n",
" id=str(uuid.uuid4()),\n",
" page_content=\"fresh apples are available at the market\",\n",
" metadata={\"likes\": 20, \"location\": \"market\", \"topic\": \"food\"},\n",
" ),\n",
" Document(\n",
" id=uuid.uuid4(),\n",
" id=str(uuid.uuid4()),\n",
" page_content=\"the market also sells fresh oranges\",\n",
" metadata={\"likes\": 5, \"location\": \"market\", \"topic\": \"food\"},\n",
" ),\n",
@@ -283,6 +283,28 @@
"await vectorstore.aadd_documents(documents=docs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get collection\n",
"\n",
"Get collection from the vectorstore using filters and parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"documents_with_apple = await vectorstore.aget(where_document={\"$ilike\": \"%apple%\"}, include=\"documents\")\n",
"paginated_ids = await vectorstore.aget(limit=3, offset=3)\n",
"\n",
"print(documents_with_apple[\"documents\"])\n",
"print(paginated_ids[\"ids\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
22 changes: 22 additions & 0 deletions examples/pg_vectorstore_how_to.ipynb
@@ -327,6 +327,28 @@
"await store.aadd_texts(all_texts, metadatas=metadatas, ids=ids)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get collection\n",
"\n",
"Get collection from the vectorstore using filters and parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"documents_with_apple = await store.aget(where_document={\"$ilike\": \"%apple%\"}, include=\"documents\")\n",
"paginated_ids = await store.aget(limit=3, offset=3)\n",
"\n",
"print(documents_with_apple[\"documents\"])\n",
"print(paginated_ids[\"ids\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
112 changes: 112 additions & 0 deletions langchain_postgres/v2/async_vectorstore.py
@@ -672,6 +672,39 @@ async def __query_collection(
return combined_results
return dense_results

async def __query_collection_with_filter(
self,
*,
limit: Optional[int] = None,
offset: Optional[int] = None,
filter: Optional[dict] = None,
columns: Optional[list[str]] = None,
**kwargs: Any,
) -> Sequence[RowMapping]:
"""Asynchronously query the database collection using filters and parameters and return matching rows."""

column_names = ", ".join(f'"{col}"' for col in columns)

safe_filter = None
filter_dict = None
if filter and isinstance(filter, dict):
safe_filter, filter_dict = self._create_filter_clause(filter)

suffix_id = str(uuid.uuid4()).split("-")[0]
where_filters = f"WHERE {safe_filter}" if safe_filter else ""
dense_query_stmt = f"""SELECT {column_names}
FROM "{self.schema_name}"."{self.table_name}" {where_filters} LIMIT :limit_{suffix_id} OFFSET :offset_{suffix_id};
"""
param_dict = {f"limit_{suffix_id}": limit, f"offset_{suffix_id}": offset}
if filter_dict:
param_dict.update(filter_dict)
async with self.engine.connect() as conn:
result = await conn.execute(text(dense_query_stmt), param_dict)
result_map = result.mappings()
results = result_map.fetchall()

return results

async def asimilarity_search(
self,
query: str,
@@ -995,6 +1028,71 @@ async def is_valid_index(
results = result_map.fetchall()
return bool(len(results) == 1)

async def aget(
Review comment (Collaborator): Can we also support an ids input and just call the get_by_ids method?

Reply (Contributor, author): But then the other parameters won't work. How about updating the filters with an ids filter? (This is what the implementation below does: ids are folded into the filter as an $in clause on the id column, so they compose with where, where_document, limit, and offset.)

self,
ids: Optional[Sequence[str]] = None,
where: Optional[dict] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[dict] = None,
include: Optional[list[str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Retrieve documents from the collection using filters and parameters."""
filter = {}
if ids:
filter.update({self.id_column: {"$in": ids}})
if where:
filter.update(where)
if where_document:
filter.update({self.content_column: where_document})

if include is None:
include = ["metadatas", "documents"]

fields_mapping = {
"embeddings": [self.embedding_column],
"metadatas": self.metadata_columns + [self.metadata_json_column]
if self.metadata_json_column
else self.metadata_columns,
"documents": [self.content_column],
}

included_fields = ["ids"]
columns = [self.id_column]

for field, cols in fields_mapping.items():
if field in include:
included_fields.append(field)
columns.extend(cols)

results = await self.__query_collection_with_filter(
limit=limit, offset=offset, filter=filter, columns=columns, **kwargs
)

final_results = {field: [] for field in included_fields}

for row in results:
final_results["ids"].append(str(row[self.id_column]))

if "metadatas" in final_results:
                metadata = (
                    (row.get(self.metadata_json_column) or {})
                    if self.metadata_json_column
                    else {}
                )
for col in self.metadata_columns:
metadata[col] = row[col]
final_results["metadatas"].append(metadata)

if "documents" in final_results:
final_results["documents"].append(row[self.content_column])

if "embeddings" in final_results:
final_results["embeddings"].append(row[self.embedding_column])

return final_results
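
A minimal usage sketch of the aget method above, assuming an initialized AsyncPGVectorStore named vs with documents already added (filter values are illustrative):

        result = await vs.aget(
            ids=["id-1", "id-2"],                 # folded into an $in filter
            where={"topic": {"$eq": "animals"}},  # metadata filter
            where_document={"$ilike": "%pond%"},  # content-column filter
            include=["documents", "metadatas"],   # "ids" is always returned
            limit=5,
            offset=0,
        )
        # result maps each included field to a parallel list:
        # {"ids": [...], "metadatas": [...], "documents": [...]}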

async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""

@@ -1249,6 +1347,20 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
else:
return "", {}

def get(
self,
ids: Optional[Sequence[str]] = None,
where: Optional[dict] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[dict] = None,
include: Optional[list[str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
)

def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
46 changes: 46 additions & 0 deletions langchain_postgres/v2/vectorstores.py
@@ -875,5 +875,51 @@ def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""
return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))

async def aget(
self,
ids: Optional[Sequence[str]] = None,
where: Optional[dict] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[dict] = None,
include: Optional[list[str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Retrieve documents from the collection using filters and parameters."""
return await self._engine._run_as_async(
self.__vs.aget(
ids=ids,
where=where,
limit=limit,
offset=offset,
where_document=where_document,
include=include,
**kwargs,
)
)

def get(
self,
ids: Optional[Sequence[str]] = None,
where: Optional[dict] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[dict] = None,
include: Optional[list[str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Retrieve documents from the collection using filters and parameters."""
return self._engine._run_as_sync(
self.__vs.aget(
ids=ids,
where=where,
limit=limit,
offset=offset,
where_document=where_document,
include=include,
**kwargs,
)
)
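
A matching sketch for the synchronous wrapper, assuming an initialized PGVectorStore named store; get accepts the same parameters as aget and returns the same dict, delegating to the async implementation through the engine:

        rows = store.get(where_document={"$ilike": "%apple%"}, include=["documents"])
        print(rows["documents"])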

def get_table_name(self) -> str:
return self.__vs.table_name
46 changes: 46 additions & 0 deletions tests/unit_tests/v2/test_async_pg_vectorstore_search.py
@@ -370,6 +370,52 @@ async def test_vectorstore_with_metadata_filters(
)
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter

async def test_async_vectorstore_get_ids(
self,
vs_custom_filter: AsyncPGVectorStore
) -> None:
"""Test end to end construction and filter."""

res = await vs_custom_filter.aget(ids=ids[:2])
assert set(res["ids"]) == set(ids[:2])

async def test_async_vectorstore_get_docs(
self,
vs_custom_filter: AsyncPGVectorStore
) -> None:
"""Test end to end construction and filter."""

res = await vs_custom_filter.aget(where_document={"$in": texts[:2]})
assert set(res["documents"]) == set(texts[:2])

@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
async def test_vectorstore_get(
self,
vs_custom_filter: AsyncPGVectorStore,
test_filter: dict,
expected_ids: list[str],
) -> None:
"""Test end to end construction and filter."""
res = await vs_custom_filter.aget(where=test_filter)
        assert {r["code"] for r in res["metadatas"]} == set(expected_ids), (
test_filter
)

async def test_vectorstore_get_limit_offset(
self,
vs_custom_filter: AsyncPGVectorStore,
) -> None:
"""Test limit and offset parameters of get method"""

all_ids = (await vs_custom_filter.aget())["ids"]
ids_from_combining = (
(await vs_custom_filter.aget(limit=1))["ids"]
+ (await vs_custom_filter.aget(limit=1, offset=1))["ids"]
+ (await vs_custom_filter.aget(offset=2))["ids"]
)
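        # Note: the underlying SELECT has no ORDER BY, so this comparison
        # assumes the table returns rows in a stable order across queries.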

assert all_ids == ids_from_combining

async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None:
results = await vs.asimilarity_search(
"foo", k=1, hybrid_search_config=HybridSearchConfig()
47 changes: 47 additions & 0 deletions tests/unit_tests/v2/test_pg_vectorstore_search.py
@@ -429,6 +429,53 @@ def test_sync_vectorstore_with_metadata_filters(
docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter

def test_sync_vectorstore_get_ids(
self,
vs_custom_filter_sync: PGVectorStore
) -> None:
"""Test end to end construction and filter."""

res = vs_custom_filter_sync.get(ids=ids[:2])
assert set(res["ids"]) == set(ids[:2])

def test_sync_vectorstore_get_docs(
self,
vs_custom_filter_sync: PGVectorStore
) -> None:
"""Test end to end construction and filter."""

res = vs_custom_filter_sync.get(where_document={"$in": texts[:2]})
assert set(res["documents"]) == set(texts[:2])

@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
def test_sync_vectorstore_get(
self,
vs_custom_filter_sync: PGVectorStore,
test_filter: dict,
expected_ids: list[str],
) -> None:
"""Test end to end construction and filter."""

res = vs_custom_filter_sync.get(where=test_filter)
        assert {r["code"] for r in res["metadatas"]} == set(expected_ids), (
test_filter
)

def test_sync_vectorstore_get_limit_offset(
self,
vs_custom_filter_sync: PGVectorStore,
) -> None:
"""Test limit and offset parameters of get method"""

all_ids = vs_custom_filter_sync.get()["ids"]
ids_from_combining = (
vs_custom_filter_sync.get(limit=1)["ids"]
+ vs_custom_filter_sync.get(limit=1, offset=1)["ids"]
+ vs_custom_filter_sync.get(offset=2)["ids"]
)

assert all_ids == ids_from_combining

@pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES)
def test_metadata_filter_negative_tests(
self, vs_custom_filter_sync: PGVectorStore, test_filter: dict
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.