diff --git a/examples/pg_vectorstore.ipynb b/examples/pg_vectorstore.ipynb
index 2c20e90..b5d26f3 100644
--- a/examples/pg_vectorstore.ipynb
+++ b/examples/pg_vectorstore.ipynb
@@ -258,22 +258,22 @@
     "\n",
     "docs = [\n",
     "    Document(\n",
-    "        id=uuid.uuid4(),\n",
+    "        id=str(uuid.uuid4()),\n",
     "        page_content=\"there are cats in the pond\",\n",
     "        metadata={\"likes\": 1, \"location\": \"pond\", \"topic\": \"animals\"},\n",
     "    ),\n",
     "    Document(\n",
-    "        id=uuid.uuid4(),\n",
+    "        id=str(uuid.uuid4()),\n",
     "        page_content=\"ducks are also found in the pond\",\n",
     "        metadata={\"likes\": 30, \"location\": \"pond\", \"topic\": \"animals\"},\n",
     "    ),\n",
     "    Document(\n",
-    "        id=uuid.uuid4(),\n",
+    "        id=str(uuid.uuid4()),\n",
     "        page_content=\"fresh apples are available at the market\",\n",
     "        metadata={\"likes\": 20, \"location\": \"market\", \"topic\": \"food\"},\n",
     "    ),\n",
     "    Document(\n",
-    "        id=uuid.uuid4(),\n",
+    "        id=str(uuid.uuid4()),\n",
     "        page_content=\"the market also sells fresh oranges\",\n",
     "        metadata={\"likes\": 5, \"location\": \"market\", \"topic\": \"food\"},\n",
     "    ),\n",
@@ -283,6 +283,28 @@
     "await vectorstore.aadd_documents(documents=docs)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Get collection\n",
+    "\n",
+    "Retrieve documents from the collection using filters and pagination parameters."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "documents_with_apple = await vectorstore.aget(where_document={\"$ilike\": \"%apple%\"}, include=[\"documents\"])\n",
+    "paginated_ids = await vectorstore.aget(limit=3, offset=3)\n",
+    "\n",
+    "print(documents_with_apple[\"documents\"])\n",
+    "print(paginated_ids[\"ids\"])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
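The new notebook cells exercise the `aget` API added later in this diff; it returns a dict keyed by the requested fields, with `"ids"` always present. A minimal sketch of what the cell above yields, given the four documents inserted earlier (uses notebook-style top-level `await`):

```python
# Sketch: only the third document contains "apple", so the $ilike filter
# should return exactly one match (its generated uuid is not shown here).
documents_with_apple = await vectorstore.aget(
    where_document={"$ilike": "%apple%"}, include=["documents"]
)
assert documents_with_apple["documents"] == [
    "fresh apples are available at the market"
]
assert len(documents_with_apple["ids"]) == 1  # "ids" accompanies every result
```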
diff --git a/examples/pg_vectorstore_how_to.ipynb b/examples/pg_vectorstore_how_to.ipynb
index 2c5e75a..e238fe9 100644
--- a/examples/pg_vectorstore_how_to.ipynb
+++ b/examples/pg_vectorstore_how_to.ipynb
@@ -327,6 +327,28 @@
     "await store.aadd_texts(all_texts, metadatas=metadatas, ids=ids)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Get collection\n",
+    "\n",
+    "Retrieve documents from the collection using filters and pagination parameters."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "documents_with_apple = await store.aget(where_document={\"$ilike\": \"%apple%\"}, include=[\"documents\"])\n",
+    "paginated_ids = await store.aget(limit=3, offset=3)\n",
+    "\n",
+    "print(documents_with_apple[\"documents\"])\n",
+    "print(paginated_ids[\"ids\"])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py
index c83930e..73f78f8 100644
--- a/langchain_postgres/v2/async_vectorstore.py
+++ b/langchain_postgres/v2/async_vectorstore.py
@@ -672,6 +672,39 @@ async def __query_collection(
             return combined_results
         return dense_results
 
+    async def __query_collection_with_filter(
+        self,
+        *,
+        limit: Optional[int] = None,
+        offset: Optional[int] = None,
+        filter: Optional[dict] = None,
+        columns: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> Sequence[RowMapping]:
+        """Asynchronously query the collection with optional filters and pagination, returning matching rows."""
+
+        column_names = ", ".join(f'"{col}"' for col in columns) if columns else "*"
+
+        safe_filter = None
+        filter_dict = None
+        if filter and isinstance(filter, dict):
+            safe_filter, filter_dict = self._create_filter_clause(filter)
+
+        suffix_id = str(uuid.uuid4()).split("-")[0]
+        where_filters = f"WHERE {safe_filter}" if safe_filter else ""
+        dense_query_stmt = f"""SELECT {column_names}
+            FROM "{self.schema_name}"."{self.table_name}" {where_filters} LIMIT :limit_{suffix_id} OFFSET :offset_{suffix_id};
+        """
+        param_dict = {f"limit_{suffix_id}": limit, f"offset_{suffix_id}": offset}
+        if filter_dict:
+            param_dict.update(filter_dict)
+        async with self.engine.connect() as conn:
+            result = await conn.execute(text(dense_query_stmt), param_dict)
+            result_map = result.mappings()
+            results = result_map.fetchall()
+
+        return results
+
     async def asimilarity_search(
         self,
         query: str,
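For concreteness, a sketch of the statement this helper builds. The table name, rendered WHERE clause, and filter bind name below are illustrative, not taken from the diff; the real WHERE text and filter parameters come from `_create_filter_clause`, and the suffix is a random uuid fragment:

```python
from sqlalchemy import text

# Illustrative shape of dense_query_stmt for columns=["langchain_id", "content"]
# and a filter on "topic"; "1a2b3c4d" stands in for the per-query uuid suffix,
# which keeps the limit/offset bind names from colliding with filter parameters.
stmt = text(
    'SELECT "langchain_id", "content"\n'
    '    FROM "public"."vectorstore" WHERE "topic" = :topic_param'
    " LIMIT :limit_1a2b3c4d OFFSET :offset_1a2b3c4d;"
)
# Unset limit/offset bind as SQL NULL, which PostgreSQL treats as
# "no limit" / "no offset", so a bare aget() returns every row.
params = {"topic_param": "animals", "limit_1a2b3c4d": None, "offset_1a2b3c4d": None}
```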
@@ -995,6 +1028,71 @@ async def is_valid_index(
             results = result_map.fetchall()
         return bool(len(results) == 1)
 
+    async def aget(
+        self,
+        ids: Optional[Sequence[str]] = None,
+        where: Optional[dict] = None,
+        limit: Optional[int] = None,
+        offset: Optional[int] = None,
+        where_document: Optional[dict] = None,
+        include: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Retrieve documents from the collection using filters and parameters."""
+        filter = {}
+        if ids:
+            filter.update({self.id_column: {"$in": ids}})
+        if where:
+            filter.update(where)
+        if where_document:
+            filter.update({self.content_column: where_document})
+
+        if include is None:
+            include = ["metadatas", "documents"]
+
+        fields_mapping = {
+            "embeddings": [self.embedding_column],
+            "metadatas": self.metadata_columns + [self.metadata_json_column]
+            if self.metadata_json_column
+            else self.metadata_columns,
+            "documents": [self.content_column],
+        }
+
+        included_fields = ["ids"]
+        columns = [self.id_column]
+
+        for field, cols in fields_mapping.items():
+            if field in include:
+                included_fields.append(field)
+                columns.extend(cols)
+
+        results = await self.__query_collection_with_filter(
+            limit=limit, offset=offset, filter=filter, columns=columns, **kwargs
+        )
+
+        final_results = {field: [] for field in included_fields}
+
+        for row in results:
+            final_results["ids"].append(str(row[self.id_column]))
+
+            if "metadatas" in final_results:
+                metadata = (
+                    row.get(self.metadata_json_column) or {}
+                    if self.metadata_json_column
+                    else {}
+                )
+                for col in self.metadata_columns:
+                    metadata[col] = row[col]
+                final_results["metadatas"].append(metadata)
+
+            if "documents" in final_results:
+                final_results["documents"].append(row[self.content_column])
+
+            if "embeddings" in final_results:
+                final_results["embeddings"].append(row[self.embedding_column])
+
+        return final_results
+
     async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
         """Get documents by ids."""
 
@@ -1249,6 +1347,20 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
         else:
             return "", {}
 
+    def get(
+        self,
+        ids: Optional[Sequence[str]] = None,
+        where: Optional[dict] = None,
+        limit: Optional[int] = None,
+        offset: Optional[int] = None,
+        where_document: Optional[dict] = None,
+        include: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
+        )
+
     def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
         raise NotImplementedError(
             "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
diff --git a/langchain_postgres/v2/vectorstores.py b/langchain_postgres/v2/vectorstores.py
index edbfb57..dd0c0c6 100644
--- a/langchain_postgres/v2/vectorstores.py
+++ b/langchain_postgres/v2/vectorstores.py
@@ -875,5 +875,51 @@ def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
         """Get documents by ids."""
         return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))
 
+    async def aget(
+        self,
+        ids: Optional[Sequence[str]] = None,
+        where: Optional[dict] = None,
+        limit: Optional[int] = None,
+        offset: Optional[int] = None,
+        where_document: Optional[dict] = None,
+        include: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Retrieve documents from the collection using filters and parameters."""
+        return await self._engine._run_as_async(
+            self.__vs.aget(
+                ids=ids,
+                where=where,
+                limit=limit,
+                offset=offset,
+                where_document=where_document,
+                include=include,
+                **kwargs,
+            )
+        )
+
+    def get(
+        self,
+        ids: Optional[Sequence[str]] = None,
+        where: Optional[dict] = None,
+        limit: Optional[int] = None,
+        offset: Optional[int] = None,
+        where_document: Optional[dict] = None,
+        include: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Retrieve documents from the collection using filters and parameters."""
+        return self._engine._run_as_sync(
+            self.__vs.aget(
+                ids=ids,
+                where=where,
+                limit=limit,
+                offset=offset,
+                where_document=where_document,
+                include=include,
+                **kwargs,
+            )
+        )
+
     def get_table_name(self) -> str:
         return self.__vs.table_name
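Taken together, the wrappers mirror the existing `get_by_ids`/`aget_by_ids` split: `PGVectorStore` exposes both a sync `get` and an async `aget`, while the low-level `AsyncPGVectorStore.get` only raises. A usage sketch, assuming a `store: PGVectorStore` built as in the notebooks:

```python
# Sync path: _run_as_sync drives the async implementation on the engine's
# background loop, so this works outside async code.
records = store.get(
    where={"topic": {"$eq": "animals"}},
    include=["documents", "metadatas"],
)
for doc_id, doc in zip(records["ids"], records["documents"]):
    print(doc_id, doc)

# Async path, same arguments:
# records = await store.aget(where={"topic": {"$eq": "animals"}})
```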
set(res["documents"]) == set(texts[:2]) + + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + async def test_vectorstore_get( + self, + vs_custom_filter: AsyncPGVectorStore, + test_filter: dict, + expected_ids: list[str], + ) -> None: + """Test end to end construction and filter.""" + res = await vs_custom_filter.aget(where=test_filter) + assert set([r["code"] for r in res["metadatas"]]) == set(expected_ids), ( + test_filter + ) + + async def test_vectorstore_get_limit_offset( + self, + vs_custom_filter: AsyncPGVectorStore, + ) -> None: + """Test limit and offset parameters of get method""" + + all_ids = (await vs_custom_filter.aget())["ids"] + ids_from_combining = ( + (await vs_custom_filter.aget(limit=1))["ids"] + + (await vs_custom_filter.aget(limit=1, offset=1))["ids"] + + (await vs_custom_filter.aget(offset=2))["ids"] + ) + + assert all_ids == ids_from_combining + async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None: results = await vs.asimilarity_search( "foo", k=1, hybrid_search_config=HybridSearchConfig() diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index 7815a25..a2dba46 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -429,6 +429,53 @@ def test_sync_vectorstore_with_metadata_filters( docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + def test_sync_vectorstore_get_ids( + self, + vs_custom_filter_sync: PGVectorStore + ) -> None: + """Test end to end construction and filter.""" + + res = vs_custom_filter_sync.get(ids=ids[:2]) + assert set(res["ids"]) == set(ids[:2]) + + def test_sync_vectorstore_get_docs( + self, + vs_custom_filter_sync: PGVectorStore + ) -> None: + """Test end to end construction and filter.""" + + res = vs_custom_filter_sync.get(where_document={"$in": texts[:2]}) + assert set(res["documents"]) == set(texts[:2]) + + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + def test_sync_vectorstore_get( + self, + vs_custom_filter_sync: PGVectorStore, + test_filter: dict, + expected_ids: list[str], + ) -> None: + """Test end to end construction and filter.""" + + res = vs_custom_filter_sync.get(where=test_filter) + assert set([r["code"] for r in res["metadatas"]]) == set(expected_ids), ( + test_filter + ) + + def test_sync_vectorstore_get_limit_offset( + self, + vs_custom_filter_sync: PGVectorStore, + ) -> None: + """Test limit and offset parameters of get method""" + + all_ids = vs_custom_filter_sync.get()["ids"] + ids_from_combining = ( + vs_custom_filter_sync.get(limit=1)["ids"] + + vs_custom_filter_sync.get(limit=1, offset=1)["ids"] + + vs_custom_filter_sync.get(offset=2)["ids"] + ) + + assert all_ids == ids_from_combining + @pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES) def test_metadata_filter_negative_tests( self, vs_custom_filter_sync: PGVectorStore, test_filter: dict diff --git a/uv.lock b/uv.lock index 933849d..28773e3 100644 --- a/uv.lock +++ b/uv.lock @@ -621,7 +621,7 @@ wheels = [ [[package]] name = "langchain-postgres" -version = "0.0.15" +version = "0.0.16" source = { editable = "." } dependencies = [ { name = "asyncpg" },