Skip to content

Commit f630b04

Browse files
committed
PYCBC-1597: Support for base64 encoded vector types
Motivation ========== Add support for base64 encoded vector types. Modification ============ * Update VectorQuery to handle vector of either List[float] or str. * Update VectorQuery vector validation. * Update search request to add vector_base64 to request if that is the VectorQuery vector type. * Add unit test to confirm functionality. Change-Id: I47525fff85a390513faf1a929ace85ee82172342 Reviewed-on: https://review.couchbase.org/c/couchbase-python-client/+/210668 Reviewed-by: Dimitris Christodoulou <dimitris.christodoulou@couchbase.com> Tested-by: Build Bot <build@couchbase.com>
1 parent d7391bc commit f630b04

File tree

3 files changed

+78
-16
lines changed

3 files changed

+78
-16
lines changed

couchbase/logic/search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,9 +961,12 @@ def encode_vector_search(self) -> Optional[List[Dict[str, Any]]]:
961961
for query in self._vector_search.queries:
962962
encoded_query = {
963963
'field': query.field_name,
964-
'vector': query.vector,
965964
'k': query.num_candidates if query.num_candidates is not None else 3
966965
}
966+
if query.vector is not None:
967+
encoded_query['vector'] = query.vector
968+
else:
969+
encoded_query['vector_base64'] = query.vector_base64
967970
if query.boost is not None:
968971
encoded_query['boost'] = query.boost
969972
encoded_queries.append(encoded_query)

couchbase/logic/vector_search.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import List, Optional
4+
from typing import (List,
5+
Optional,
6+
Union)
57

68
from couchbase.exceptions import InvalidArgumentException
79
from couchbase.options import VectorSearchOptions
@@ -37,30 +39,29 @@ class VectorQuery:
3739
3840
Args:
3941
field_name (str): The name of the field in the search index that stores the vector.
40-
vector (List[float]): The vector to use in the query.
42+
vector (Union[List[float], str]): The vector to use in the query.
4143
num_candidates (int, optional): Specifies the number of results returned. If provided, must be greater or equal to 1.
4244
boost (float, optional): Add boost to query.
4345
4446
Raises:
4547
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not provided.
46-
:class:`~couchbase.exceptions.InvalidArgumentException`: If all values of the provided vector are not instances of float.
48+
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not a list or str.
49+
:class:`~couchbase.exceptions.InvalidArgumentException`: If vector is a list and all values of the provided vector are not instances of float.
4750
4851
Returns:
4952
:class:`~couchbase.vector_search.VectorQuery`: The created vector query.
5053
""" # noqa: E501
5154

5255
def __init__(self,
5356
field_name, # type: str
54-
vector, # type: List[float]
57+
vector, # type: Union[List[float], str]
5558
num_candidates=None, # type: Optional[int]
5659
boost=None, # type: Optional[float]
5760
):
5861
self._field_name = field_name
59-
if vector is None or len(vector) == 0:
60-
raise InvalidArgumentException('Provided vector cannot be empty.')
61-
if not all(map(lambda q: isinstance(q, float), vector)):
62-
raise InvalidArgumentException('All vector values must be a float.')
63-
self._vector = vector
62+
self._vector = None
63+
self._vector_base64 = None
64+
self._validate_and_set_vector(vector)
6465
self._num_candidates = self._boost = None
6566
if num_candidates is not None:
6667
self.num_candidates = num_candidates
@@ -116,19 +117,46 @@ def num_candidates(self,
116117
self._num_candidates = value
117118

118119
@property
119-
def vector(self) -> List[float]:
120+
def vector(self) -> Optional[List[float]]:
120121
"""
121122
**UNCOMMITTED** This API is unlikely to change,
122123
but may still change as final consensus on its behavior has not yet been reached.
123124
124-
List[float]: Returns the vector query's vector.
125+
Optional[List[float]]: Returns the vector query's vector.
125126
"""
126127
return self._vector
127128

129+
@property
130+
def vector_base64(self) -> Optional[str]:
131+
"""
132+
**UNCOMMITTED** This API is unlikely to change,
133+
but may still change as final consensus on its behavior has not yet been reached.
134+
135+
Optional[str]: Returns the vector query's base64 vector str.
136+
"""
137+
return self._vector_base64
138+
139+
def _validate_and_set_vector(self,
140+
vector, # type: Union[List[float], str]
141+
) -> None:
142+
if vector is None:
143+
raise InvalidArgumentException('Provided vector cannot be empty.')
144+
if isinstance(vector, list):
145+
if len(vector) == 0:
146+
raise InvalidArgumentException('Provided vector cannot be empty.')
147+
if not all(map(lambda q: isinstance(q, float), vector)):
148+
raise InvalidArgumentException('All vector values must be a float.')
149+
self._vector = vector
150+
return
151+
elif not isinstance(vector, str):
152+
raise InvalidArgumentException('Provided vector must be either a List[float] or base64 encoded str.')
153+
154+
self._vector_base64 = vector
155+
128156
@classmethod
129157
def create(cls,
130158
field_name, # type: str
131-
vector, # type: List[float]
159+
vector, # type: Union[List[float], str]
132160
num_candidates=None, # type: Optional[int]
133161
boost=None, # type: Optional[float]
134162
) -> VectorQuery:
@@ -139,13 +167,14 @@ def create(cls,
139167
140168
Args:
141169
field_name (str): The name of the field in the search index that stores the vector.
142-
vector (List[float]): The vector to use in the query.
170+
vector (Union[List[float], str]): The vector to use in the query.
143171
num_candidates (int, optional): Specifies the number of results returned. If provided, must be greater or equal to 1.
144172
boost (float, optional): Add boost to query.
145173
146174
Raises:
147-
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not provided.
148-
:class:`~couchbase.exceptions.InvalidArgumentException`: If all values of the provided vector are not instances of float.
175+
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not provided.
176+
:class:`~couchbase.exceptions.InvalidArgumentException`: If the vector is not a list or str.
177+
:class:`~couchbase.exceptions.InvalidArgumentException`: If vector is a list and all values of the provided vector are not instances of float.
149178
150179
Returns:
151180
:class:`~couchbase.vector_search.VectorQuery`: The created vector query.

couchbase/tests/search_params_t.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,12 +868,15 @@ class VectorSearchParamTestSuite:
868868
-0.02059517428278923,
869869
0.019551364704966545]
870870

871+
TEST_VECTOR_BASE64 = 'SSdtIGp1c3QgYSB0ZXN0IHN0cmluZw=='
872+
871873
TEST_MANIFEST = [
872874
'test_search_request_invalid',
873875
'test_vector_query_invalid_boost',
874876
'test_vector_query_invalid_num_candidates',
875877
'test_vector_query_invalid_vector',
876878
'test_vector_search',
879+
'test_vector_search_base64',
877880
'test_vector_search_invalid',
878881
'test_vector_search_multiple_queries'
879882
]
@@ -933,6 +936,9 @@ def test_vector_query_invalid_vector(self):
933936
VectorQuery('vector_field', [1])
934937
with pytest.raises(InvalidArgumentException):
935938
VectorQuery('vector_field', [1.111, 2, 3.14159])
939+
# if not a list, should be a str
940+
with pytest.raises(InvalidArgumentException):
941+
VectorQuery('vector_field', {})
936942

937943
def test_vector_search(self, cb_env):
938944
exp_json = {
@@ -958,6 +964,30 @@ def test_vector_search(self, cb_env):
958964
encoded_q = cb_env.get_encoded_query(search_query)
959965
assert exp_json == encoded_q
960966

967+
def test_vector_search_base64(self, cb_env):
968+
exp_json = {
969+
'query': {'match_none': None},
970+
'index_name': cb_env.TEST_INDEX_NAME,
971+
'metrics': True,
972+
'show_request': False,
973+
'vector_search': [
974+
{
975+
'field': 'vector_field',
976+
'vector_base64': self.TEST_VECTOR_BASE64,
977+
'k': 3
978+
}
979+
]
980+
}
981+
982+
vector_search = VectorSearch.from_vector_query(VectorQuery('vector_field', self.TEST_VECTOR_BASE64))
983+
req = SearchRequest.create(vector_search)
984+
search_query = search.SearchQueryBuilder.create_search_query_from_request(
985+
cb_env.TEST_INDEX_NAME,
986+
req
987+
)
988+
encoded_q = cb_env.get_encoded_query(search_query)
989+
assert exp_json == encoded_q
990+
961991
def test_vector_search_invalid(self):
962992
with pytest.raises(InvalidArgumentException):
963993
VectorSearch([])

0 commit comments

Comments
 (0)