From a00db077a3ca77ee86117beb0b15d70d02e85e87 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Fri, 19 Jan 2024 01:30:06 -0800 Subject: [PATCH] feat: Support empty index for `MatchingEngineIndex` create index. PiperOrigin-RevId: 599760961 --- .../matching_engine/matching_engine_index.py | 32 +++-- .../aiplatform/test_matching_engine_index.py | 125 ++++++++++++++++++ 2 files changed, 143 insertions(+), 14 deletions(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index d7a1c742ca..734d559c1d 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -101,8 +101,8 @@ def description(self) -> str: def _create( cls, display_name: str, - contents_delta_uri: str, - config: matching_engine_index_config.MatchingEngineIndexConfig, + contents_delta_uri: Optional[str] = None, + config: matching_engine_index_config.MatchingEngineIndexConfig = None, description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, project: Optional[str] = None, @@ -121,7 +121,7 @@ def _create( The name can be up to 128 characters long and can be consist of any UTF-8 characters. contents_delta_uri (str): - Required. Allows inserting, updating or deleting the contents of the Matching Engine Index. + Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index. The string must be a valid Google Cloud Storage directory path. If this field is set when calling IndexService.UpdateIndex, then no other Index field can be also updated as part of the same call. @@ -188,13 +188,17 @@ def _create( index_update_method ] + metadata = {"config": config.as_dict()} + if contents_delta_uri: + metadata = { + "config": config.as_dict(), + "contentsDeltaUri": contents_delta_uri, + } + gapic_index = gca_matching_engine_index.Index( display_name=display_name, description=description, - metadata={ - "config": config.as_dict(), - "contentsDeltaUri": contents_delta_uri, - }, + metadata=metadata, index_update_method=index_update_method_enum, ) @@ -399,9 +403,9 @@ def deployed_indexes( def create_tree_ah_index( cls, display_name: str, - contents_delta_uri: str, - dimensions: int, - approximate_neighbors_count: int, + contents_delta_uri: Optional[str] = None, + dimensions: int = None, + approximate_neighbors_count: int = None, leaf_node_embedding_count: Optional[int] = None, leaf_nodes_to_search_percent: Optional[float] = None, distance_measure_type: Optional[ @@ -439,7 +443,7 @@ def create_tree_ah_index( The name can be up to 128 characters long and can be consist of any UTF-8 characters. contents_delta_uri (str): - Required. Allows inserting, updating or deleting the contents of the Matching Engine Index. + Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index. The string must be a valid Google Cloud Storage directory path. If this field is set when calling IndexService.UpdateIndex, then no other Index field can be also updated as part of the same call. @@ -543,8 +547,8 @@ def create_tree_ah_index( def create_brute_force_index( cls, display_name: str, - contents_delta_uri: str, - dimensions: int, + contents_delta_uri: Optional[str] = None, + dimensions: int = None, distance_measure_type: Optional[ matching_engine_index_config.DistanceMeasureType ] = None, @@ -578,7 +582,7 @@ def create_brute_force_index( The name can be up to 128 characters long and can be consist of any UTF-8 characters. contents_delta_uri (str): - Required. Allows inserting, updating or deleting the contents of the Matching Engine Index. + Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index. The string must be a valid Google Cloud Storage directory path. If this field is set when calling IndexService.UpdateIndex, then no other Index field can be also updated as part of the same call. diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index e2b6e51d71..1a32d74dd1 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -409,6 +409,73 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_index_mock") + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.parametrize( + "index_update_method", + [ + _TEST_INDEX_STREAM_UPDATE_METHOD, + _TEST_INDEX_BATCH_UPDATE_METHOD, + _TEST_INDEX_EMPTY_UPDATE_METHOD, + _TEST_INDEX_INVALID_UPDATE_METHOD, + ], + ) + def test_create_tree_ah_index_with_empty_index( + self, create_index_mock, sync, index_update_method + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index( + display_name=_TEST_INDEX_DISPLAY_NAME, + contents_delta_uri=None, + dimensions=_TEST_INDEX_CONFIG_DIMENSIONS, + approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT, + distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE, + leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT, + leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + sync=sync, + index_update_method=index_update_method, + encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + ) + + if not sync: + my_index.wait() + + config = { + "treeAhConfig": { + "leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT, + "leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT, + } + } + + expected = gca_index.Index( + display_name=_TEST_INDEX_DISPLAY_NAME, + metadata={ + "config": { + "algorithmConfig": config, + "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS, + "approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT, + "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE, + }, + }, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[ + index_update_method + ], + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME + ), + ) + + create_index_mock.assert_called_once_with( + parent=_TEST_PARENT, + index=expected, + metadata=_TEST_REQUEST_METADATA, + ) + @pytest.mark.usefixtures("get_index_mock") def test_create_tree_ah_index_backward_compatibility(self, create_index_mock): aiplatform.init(project=_TEST_PROJECT) @@ -513,6 +580,64 @@ def test_create_brute_force_index( metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_index_mock") + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.parametrize( + "index_update_method", + [ + _TEST_INDEX_STREAM_UPDATE_METHOD, + _TEST_INDEX_BATCH_UPDATE_METHOD, + _TEST_INDEX_EMPTY_UPDATE_METHOD, + _TEST_INDEX_INVALID_UPDATE_METHOD, + ], + ) + def test_create_brute_force_index_with_empty_index( + self, create_index_mock, sync, index_update_method + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index = aiplatform.MatchingEngineIndex.create_brute_force_index( + display_name=_TEST_INDEX_DISPLAY_NAME, + dimensions=_TEST_INDEX_CONFIG_DIMENSIONS, + distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + sync=sync, + index_update_method=index_update_method, + encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + ) + + if not sync: + my_index.wait() + + config = {"bruteForceConfig": {}} + + expected = gca_index.Index( + display_name=_TEST_INDEX_DISPLAY_NAME, + metadata={ + "config": { + "algorithmConfig": config, + "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS, + "approximateNeighborsCount": None, + "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE, + }, + }, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[ + index_update_method + ], + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + ), + ) + + create_index_mock.assert_called_once_with( + parent=_TEST_PARENT, + index=expected, + metadata=_TEST_REQUEST_METADATA, + ) + @pytest.mark.usefixtures("get_index_mock") def test_create_brute_force_index_backward_compatibility(self, create_index_mock): aiplatform.init(project=_TEST_PROJECT)