Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support Vector Search #896

Merged
merged 59 commits into from Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
c361fe2
Support vector data type
Jan 17, 2024
5b954cb
Simplify the test structure
Jan 17, 2024
669e935
Copy the vector search proto from preview branch
Jan 24, 2024
b9869f7
initial vector search
Jan 30, 2024
6c46e71
Add firestore vector search example
Jan 31, 2024
607c883
Update examples with comments
Jan 31, 2024
e21ebab
Update
Mar 25, 2024
559a414
Merge branch 'main' of https://github.com/pl04351820/python-firestore
Mar 25, 2024
b6dd533
Merge branch 'main' into vector
Mar 25, 2024
e97ebc3
Revert type change
Mar 25, 2024
b5375be
Better unittest coverage on vector query
Mar 25, 2024
54ea8e8
Reorder vector type in SDK
Mar 25, 2024
0b05eda
Add tests for system test
Mar 25, 2024
5c418d7
Revert unnecessary change in test_system
Mar 25, 2024
da6cd34
Delete vector_search_example
Mar 25, 2024
453e06c
Fix type lint
Mar 26, 2024
fe8eb13
Lint for 3.7
Mar 26, 2024
7299b5f
More cleanup
Mar 28, 2024
7457976
more lint
Mar 28, 2024
3b155bc
type annotation
Mar 28, 2024
ebb3b22
More lint
Mar 28, 2024
c1846d1
lint
Mar 28, 2024
bd54464
lint
Mar 28, 2024
5601cfd
lint
Mar 28, 2024
8deb2ae
more lint
Mar 28, 2024
042f7bc
lint
Mar 28, 2024
8cecc4f
more lint
Mar 28, 2024
43bca61
better test coverage
Mar 28, 2024
3545594
Throw ValueError
Mar 28, 2024
aca325a
test coverage with txn
Mar 28, 2024
a88212b
More coverage
Mar 28, 2024
4d8a14b
lint
Mar 28, 2024
79aa0e0
lint
Mar 28, 2024
9f84caa
More test coverage
Mar 28, 2024
5cd9e17
Merge branch 'googleapis:main' into vector
pl04351820 Mar 28, 2024
9d126b8
Better docs
Mar 29, 2024
e38cd2b
More doc
Mar 29, 2024
69e65d6
lint
Mar 29, 2024
42dfbbc
index_bootstrap_scipt and system_test
Mar 29, 2024
d3cf322
lint
Mar 29, 2024
ab8be44
Skip emulator for vector index
Mar 29, 2024
b87484e
Resolve the comment
Apr 1, 2024
e28350b
lint
Apr 1, 2024
6c232e1
lint
Apr 1, 2024
051774a
more lint
Apr 1, 2024
3d61c3c
remove abc type
Apr 1, 2024
7eb35db
lint
Apr 2, 2024
a4b0044
Remove unused import
Apr 2, 2024
47fef4b
Implement sequence
Apr 2, 2024
107fd54
lint and better comment
Apr 2, 2024
5f00712
resolve comments
Apr 2, 2024
ecfb49a
typing_extensions for Self
Apr 2, 2024
0d734de
TypeVar
Apr 2, 2024
b0f786e
Resolve comments
Apr 2, 2024
3e5e536
Decouple order from TypeOrder to avoid braking change
Apr 2, 2024
f2e2e35
update order
Apr 2, 2024
3790a91
Resolve comment
Apr 2, 2024
001e951
Add coverage for elements
Apr 2, 2024
5b2e032
lint
Apr 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 15 additions & 4 deletions google/cloud/firestore_v1/_helpers.py
Expand Up @@ -26,6 +26,7 @@

from google.cloud import exceptions # type: ignore
from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore
from google.cloud.firestore_v1.vector import Vector
from google.cloud.firestore_v1.types.write import DocumentTransform
from google.cloud.firestore_v1 import transforms
from google.cloud.firestore_v1 import types
Expand Down Expand Up @@ -160,7 +161,8 @@ def encode_value(value) -> types.document.Value:

Args:
value (Union[NoneType, bool, int, float, datetime.datetime, \
str, bytes, dict, ~google.cloud.Firestore.GeoPoint]): A native
str, bytes, dict, ~google.cloud.Firestore.GeoPoint, \
~google.cloud.firestore_v1.vector.Vector]): A native
Python value to convert to a protobuf field.

Returns:
Expand Down Expand Up @@ -209,6 +211,9 @@ def encode_value(value) -> types.document.Value:
value_pb = document.ArrayValue(values=value_list)
return document.Value(array_value=value_pb)

if isinstance(value, Vector):
return encode_value(value.to_map_value())
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(value, dict):
value_dict = encode_dict(value)
value_pb = document.MapValue(fields=value_dict)
Expand Down Expand Up @@ -331,7 +336,9 @@ def reference_value_to_document(reference_value, client) -> Any:

def decode_value(
value, client
) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]:
) -> Union[
None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint, Vector
]:
"""Converts a Firestore protobuf ``Value`` to a native Python value.

Args:
Expand Down Expand Up @@ -382,7 +389,7 @@ def decode_value(
raise ValueError("Unknown ``value_type``", value_type)


def decode_dict(value_fields, client) -> dict:
def decode_dict(value_fields, client) -> Union[dict, Vector]:
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved
"""Converts a protobuf map of Firestore ``Value``-s.

Args:
Expand All @@ -397,8 +404,12 @@ def decode_dict(value_fields, client) -> dict:
of native Python values converted from the ``value_fields``.
"""
value_fields_pb = getattr(value_fields, "_pb", value_fields)
res = {key: decode_value(value, client) for key, value in value_fields_pb.items()}

if res.get("__type__", None) == "__vector__":
return Vector(res["value"])
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved

return {key: decode_value(value, client) for key, value in value_fields_pb.items()}
return res


def get_doc_id(document_pb, expected_prefix) -> str:
Expand Down
32 changes: 32 additions & 0 deletions google/cloud/firestore_v1/base_collection.py
Expand Up @@ -19,9 +19,12 @@
from google.api_core import retry as retries

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
from google.cloud.firestore_v1.base_query import QueryType
from google.cloud.firestore_v1.vector import Vector


from typing import (
Expand All @@ -46,6 +49,7 @@
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.field_path import FieldPath
from firestore_v1.vector_query import VectorQuery

_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

Expand Down Expand Up @@ -120,6 +124,9 @@ def _query(self) -> QueryType:
def _aggregation_query(self) -> BaseAggregationQuery:
raise NotImplementedError

def _vector_query(self) -> BaseVectorQuery:
raise NotImplementedError

def document(self, document_id: Optional[str] = None) -> DocumentReference:
"""Create a sub-document underneath the current collection.

Expand Down Expand Up @@ -539,6 +546,31 @@ def avg(self, field_ref: str | FieldPath, alias=None):
"""
return self._aggregation_query().avg(field_ref, alias=alias)

def find_nearest(
self,
vector_field: str,
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
) -> VectorQuery:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.

Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
"""
return self._vector_query().find_nearest(
vector_field, query_vector, limit, distance_measure
)


def _auto_id() -> str:
"""Generate a "random" automatically generated ID.
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Expand Up @@ -33,6 +33,7 @@
from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1 import field_path as field_path_module
from google.cloud.firestore_v1 import transforms
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.types import StructuredQuery
from google.cloud.firestore_v1.types import query
from google.cloud.firestore_v1.types import Cursor
Expand All @@ -51,11 +52,13 @@
Union,
TYPE_CHECKING,
)
from google.cloud.firestore_v1.vector import Vector

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
from google.cloud.firestore_v1.field_path import FieldPath

_BAD_DIR_STRING: str
Expand Down Expand Up @@ -972,6 +975,15 @@ def _to_protobuf(self) -> StructuredQuery:
query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit)
return query.StructuredQuery(**query_kwargs)

def find_nearest(
self,
vector_field: str,
queryVector: Vector,
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved
limit: int,
distance_measure: DistanceMeasure,
) -> BaseVectorQuery:
raise NotImplementedError

def count(
self, alias: str | None = None
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
Expand Down
119 changes: 119 additions & 0 deletions google/cloud/firestore_v1/base_vector_query.py
@@ -0,0 +1,119 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for representing vector queries for the Google Cloud Firestore API.
"""
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved

import abc

from abc import ABC
from enum import Enum
from typing import Iterable, Optional, Tuple, Union
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.types import query
from google.cloud.firestore_v1.vector import Vector
from google.cloud.firestore_v1 import _helpers


class DistanceMeasure(Enum):
EUCLIDEAN = 1
COSINE = 2
DOT_PRODUCT = 3


class BaseVectorQuery(ABC):
"""Represents a vector query to the Firestore API."""

def __init__(self, nested_query) -> None:
self._nested_query = nested_query
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved
self._collection_ref = nested_query._parent
self._vector_field: Optional[str] = None
self._query_vector: Optional[Vector] = None
self._limit: Optional[int] = None
self._distance_measure: Optional[DistanceMeasure] = None

@property
def _client(self):
return self._collection_ref._client

def _to_protobuf(self) -> query.StructuredQuery:
pb = query.StructuredQuery()

distance_measure_proto = None
pl04351820 marked this conversation as resolved.
Show resolved Hide resolved
if self._distance_measure == DistanceMeasure.EUCLIDEAN:
distance_measure_proto = (
query.StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN
)
elif self._distance_measure == DistanceMeasure.COSINE:
distance_measure_proto = (
query.StructuredQuery.FindNearest.DistanceMeasure.COSINE
)
elif self._distance_measure == DistanceMeasure.DOT_PRODUCT:
distance_measure_proto = (
query.StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT
)
else:
raise ValueError("Invalid distance_measure")

pb = self._nested_query._to_protobuf()
pb.find_nearest = query.StructuredQuery.FindNearest(
vector_field=query.StructuredQuery.FieldReference(
field_path=self._vector_field
),
query_vector=_helpers.encode_value(self._query_vector),
distance_measure=distance_measure_proto,
limit=self._limit,
)
return pb

def _prep_stream(
self,
transaction=None,
retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None,
timeout: Optional[float] = None,
) -> Tuple[dict, str, dict]:
parent_path, expected_prefix = self._collection_ref._parent_info()
request = {
"parent": parent_path,
"structured_query": self._to_protobuf(),
"transaction": _helpers.get_transaction_id(transaction),
}
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return request, expected_prefix, kwargs

@abc.abstractmethod
def get(
self,
transaction=None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> Iterable[DocumentSnapshot]:
"""Runs the vector query."""

def find_nearest(
self,
vector_field: str,
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
):
"""Finds the closest vector embeddings to the given query vector."""
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
self._distance_measure = distance_measure
return self
9 changes: 9 additions & 0 deletions google/cloud/firestore_v1/collection.py
Expand Up @@ -23,6 +23,7 @@
)
from google.cloud.firestore_v1 import query as query_mod
from google.cloud.firestore_v1 import aggregation
from google.cloud.firestore_v1 import vector_query
from google.cloud.firestore_v1.watch import Watch
from google.cloud.firestore_v1 import document
from typing import Any, Callable, Generator, Tuple, Union
Expand Down Expand Up @@ -76,6 +77,14 @@ def _aggregation_query(self) -> aggregation.AggregationQuery:
"""
return aggregation.AggregationQuery(self._query())

def _vector_query(self) -> vector_query.VectorQuery:
"""VectorQuery factory.

Returns:
:class:`~google.cloud.firestore_v1.vector_query.VectorQuery`
"""
return vector_query.VectorQuery(self._query())

def add(
self,
document_data: dict,
Expand Down