Skip to content

Commit

Permalink
feat: add index_update_method to MatchingEngineIndex create()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580589542
  • Loading branch information
lingyinw authored and Copybara-Service committed Nov 8, 2023
1 parent 21686ae commit dcb6205
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 17 deletions.
45 changes: 30 additions & 15 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _create(
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
index_update_method: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource.
Expand Down Expand Up @@ -153,27 +154,33 @@ def _create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Optional. Strings which should be sent along with the request as metadata.
encryption_spec (str):
Optional. Customer-managed encryption key
spec for data storage. If set, both of the
online and offline data storage will be secured
by this key.
sync (bool):
Optional. Whether to execute this creation synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
index_update_method (str):
Optional. The update method to use with this index. Choose
stream_update or batch_update. If not set, batch update will be
used by default.
Returns:
MatchingEngineIndex - Index resource object
"""
index_update_method_enum = None
if index_update_method in _INDEX_UPDATE_METHOD_TO_ENUM_VALUE:
index_update_method_enum = _INDEX_UPDATE_METHOD_TO_ENUM_VALUE[
index_update_method
]

gapic_index = gca_matching_engine_index.Index(
display_name=display_name,
description=description,
metadata={
"config": config.as_dict(),
"contentsDeltaUri": contents_delta_uri,
},
index_update_method=index_update_method_enum,
)

if labels:
Expand Down Expand Up @@ -386,6 +393,7 @@ def create_tree_ah_index(
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
index_update_method: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the tree-AH algorithm.
Expand Down Expand Up @@ -456,15 +464,14 @@ def create_tree_ah_index(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Optional. Strings which should be sent along with the request as metadata.
encryption_spec (str):
Optional. Customer-managed encryption key
spec for data storage. If set, both of the
online and offline data storage will be secured
by this key.
sync (bool):
Optional. Whether to execute this creation synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
index_update_method (str):
Optional. The update method to use with this index. Choose
STREAM_UPDATE or BATCH_UPDATE. If not set, batch update will be
used by default.
Returns:
MatchingEngineIndex - Index resource object
Expand Down Expand Up @@ -494,6 +501,7 @@ def create_tree_ah_index(
credentials=credentials,
request_metadata=request_metadata,
sync=sync,
index_update_method=index_update_method,
)

@classmethod
Expand All @@ -512,6 +520,7 @@ def create_brute_force_index(
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
index_update_method: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the brute force algorithm.
Expand Down Expand Up @@ -571,15 +580,14 @@ def create_brute_force_index(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Optional. Strings which should be sent along with the request as metadata.
encryption_spec (str):
Optional. Customer-managed encryption key
spec for data storage. If set, both of the
online and offline data storage will be secured
by this key.
sync (bool):
Optional. Whether to execute this creation synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
index_update_method (str):
Optional. The update method to use with this index. Choose
stream_update or batch_update. If not set, batch update will be
used by default.
Returns:
MatchingEngineIndex - Index resource object
Expand All @@ -605,4 +613,11 @@ def create_brute_force_index(
credentials=credentials,
request_metadata=request_metadata,
sync=sync,
index_update_method=index_update_method,
)


_INDEX_UPDATE_METHOD_TO_ENUM_VALUE = {
"STREAM_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.STREAM_UPDATE,
"BATCH_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.BATCH_UPDATE,
}
44 changes: 42 additions & 2 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@
),
]

# Index update method
_TEST_INDEX_BATCH_UPDATE_METHOD = "BATCH_UPDATE"
_TEST_INDEX_STREAM_UPDATE_METHOD = "STREAM_UPDATE"
_TEST_INDEX_EMPTY_UPDATE_METHOD = None
_TEST_INDEX_INVALID_UPDATE_METHOD = "INVALID_UPDATE_METHOD"
_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP = {
_TEST_INDEX_BATCH_UPDATE_METHOD: gca_index.Index.IndexUpdateMethod.BATCH_UPDATE,
_TEST_INDEX_STREAM_UPDATE_METHOD: gca_index.Index.IndexUpdateMethod.STREAM_UPDATE,
_TEST_INDEX_EMPTY_UPDATE_METHOD: None,
_TEST_INDEX_INVALID_UPDATE_METHOD: None,
}


def uuid_mock():
return uuid.UUID(int=1)
Expand Down Expand Up @@ -273,7 +285,16 @@ def test_delete_index(self, delete_index_mock, sync):

@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_tree_ah_index(self, create_index_mock, sync):
@pytest.mark.parametrize(
"index_update_method",
[
_TEST_INDEX_STREAM_UPDATE_METHOD,
_TEST_INDEX_BATCH_UPDATE_METHOD,
_TEST_INDEX_EMPTY_UPDATE_METHOD,
_TEST_INDEX_INVALID_UPDATE_METHOD,
],
)
def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method):
aiplatform.init(project=_TEST_PROJECT)

my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
Expand All @@ -287,6 +308,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync):
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
sync=sync,
index_update_method=index_update_method,
)

if not sync:
Expand All @@ -312,6 +334,9 @@ def test_create_tree_ah_index(self, create_index_mock, sync):
},
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
index_update_method
],
)

create_index_mock.assert_called_once_with(
Expand All @@ -322,7 +347,18 @@ def test_create_tree_ah_index(self, create_index_mock, sync):

@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_brute_force_index(self, create_index_mock, sync):
@pytest.mark.parametrize(
"index_update_method",
[
_TEST_INDEX_STREAM_UPDATE_METHOD,
_TEST_INDEX_BATCH_UPDATE_METHOD,
_TEST_INDEX_EMPTY_UPDATE_METHOD,
_TEST_INDEX_INVALID_UPDATE_METHOD,
],
)
def test_create_brute_force_index(
self, create_index_mock, sync, index_update_method
):
aiplatform.init(project=_TEST_PROJECT)

my_index = aiplatform.MatchingEngineIndex.create_brute_force_index(
Expand All @@ -333,6 +369,7 @@ def test_create_brute_force_index(self, create_index_mock, sync):
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
sync=sync,
index_update_method=index_update_method,
)

if not sync:
Expand All @@ -353,6 +390,9 @@ def test_create_brute_force_index(self, create_index_mock, sync):
},
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
index_update_method
],
)

create_index_mock.assert_called_once_with(
Expand Down

0 comments on commit dcb6205

Please sign in to comment.