Skip to content

Commit

Permalink
feat: Support empty index for MatchingEngineIndex create index.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599760961
  • Loading branch information
lingyinw authored and Copybara-Service committed Jan 19, 2024
1 parent b0b604e commit a00db07
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 14 deletions.
32 changes: 18 additions & 14 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Expand Up @@ -101,8 +101,8 @@ def description(self) -> str:
def _create(
cls,
display_name: str,
contents_delta_uri: str,
config: matching_engine_index_config.MatchingEngineIndexConfig,
contents_delta_uri: Optional[str] = None,
config: matching_engine_index_config.MatchingEngineIndexConfig = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
Expand All @@ -121,7 +121,7 @@ def _create(
The name can be up to 128 characters long and
can consist of any UTF-8 characters.
contents_delta_uri (str):
Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index.
The string must be a valid Google Cloud Storage directory path. If this
field is set when calling IndexService.UpdateIndex, then no other
Index field can be also updated as part of the same call.
Expand Down Expand Up @@ -188,13 +188,17 @@ def _create(
index_update_method
]

metadata = {"config": config.as_dict()}
if contents_delta_uri:
metadata = {
"config": config.as_dict(),
"contentsDeltaUri": contents_delta_uri,
}

gapic_index = gca_matching_engine_index.Index(
display_name=display_name,
description=description,
metadata={
"config": config.as_dict(),
"contentsDeltaUri": contents_delta_uri,
},
metadata=metadata,
index_update_method=index_update_method_enum,
)

Expand Down Expand Up @@ -399,9 +403,9 @@ def deployed_indexes(
def create_tree_ah_index(
cls,
display_name: str,
contents_delta_uri: str,
dimensions: int,
approximate_neighbors_count: int,
contents_delta_uri: Optional[str] = None,
dimensions: int = None,
approximate_neighbors_count: int = None,
leaf_node_embedding_count: Optional[int] = None,
leaf_nodes_to_search_percent: Optional[float] = None,
distance_measure_type: Optional[
Expand Down Expand Up @@ -439,7 +443,7 @@ def create_tree_ah_index(
The name can be up to 128 characters long and
can consist of any UTF-8 characters.
contents_delta_uri (str):
Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index.
The string must be a valid Google Cloud Storage directory path. If this
field is set when calling IndexService.UpdateIndex, then no other
Index field can be also updated as part of the same call.
Expand Down Expand Up @@ -543,8 +547,8 @@ def create_tree_ah_index(
def create_brute_force_index(
cls,
display_name: str,
contents_delta_uri: str,
dimensions: int,
contents_delta_uri: Optional[str] = None,
dimensions: int = None,
distance_measure_type: Optional[
matching_engine_index_config.DistanceMeasureType
] = None,
Expand Down Expand Up @@ -578,7 +582,7 @@ def create_brute_force_index(
The name can be up to 128 characters long and
can consist of any UTF-8 characters.
contents_delta_uri (str):
Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index.
The string must be a valid Google Cloud Storage directory path. If this
field is set when calling IndexService.UpdateIndex, then no other
Index field can be also updated as part of the same call.
Expand Down
125 changes: 125 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index.py
Expand Up @@ -409,6 +409,73 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
    "index_update_method",
    [
        _TEST_INDEX_STREAM_UPDATE_METHOD,
        _TEST_INDEX_BATCH_UPDATE_METHOD,
        _TEST_INDEX_EMPTY_UPDATE_METHOD,
        _TEST_INDEX_INVALID_UPDATE_METHOD,
    ],
)
def test_create_tree_ah_index_with_empty_index(
    self, create_index_mock, sync, index_update_method
):
    """Check the create request for a tree-AH index built without contents.

    ``create_tree_ah_index`` is called with ``contents_delta_uri=None`` for
    every parametrized ``sync`` / ``index_update_method`` combination. The
    index handed to ``create_index_mock`` is expected to carry metadata with
    a ``"config"`` entry only, i.e. no ``"contentsDeltaUri"`` key.
    """
    aiplatform.init(project=_TEST_PROJECT)

    created_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
        contents_delta_uri=None,
        dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
        approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
        distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
        leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
        leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
        index_update_method=index_update_method,
        encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
        sync=sync,
    )

    # In the non-sync parametrization, block on the returned index before
    # inspecting the mock.
    if not sync:
        created_index.wait()

    tree_ah_settings = {
        "leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT,
        "leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT,
    }
    algorithm_config = {"treeAhConfig": tree_ah_settings}

    # Only "config" is expected in the metadata; there is no
    # "contentsDeltaUri" entry for an empty index.
    expected_metadata = {
        "config": {
            "algorithmConfig": algorithm_config,
            "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
            "approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
            "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
        },
    }
    expected_update_method = _TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
        index_update_method
    ]
    expected_encryption_spec = gca_encryption_spec.EncryptionSpec(
        kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
    )

    expected_index = gca_index.Index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        metadata=expected_metadata,
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
        index_update_method=expected_update_method,
        encryption_spec=expected_encryption_spec,
    )

    create_index_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        index=expected_index,
        metadata=_TEST_REQUEST_METADATA,
    )

@pytest.mark.usefixtures("get_index_mock")
def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand Down Expand Up @@ -513,6 +580,64 @@ def test_create_brute_force_index(
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
    "index_update_method",
    [
        _TEST_INDEX_STREAM_UPDATE_METHOD,
        _TEST_INDEX_BATCH_UPDATE_METHOD,
        _TEST_INDEX_EMPTY_UPDATE_METHOD,
        _TEST_INDEX_INVALID_UPDATE_METHOD,
    ],
)
def test_create_brute_force_index_with_empty_index(
    self, create_index_mock, sync, index_update_method
):
    """Check the create request for a brute-force index built without contents.

    ``create_brute_force_index`` is called without a ``contents_delta_uri``
    argument for every parametrized ``sync`` / ``index_update_method``
    combination. The index handed to ``create_index_mock`` is expected to
    carry metadata with a ``"config"`` entry only, i.e. no
    ``"contentsDeltaUri"`` key.
    """
    aiplatform.init(project=_TEST_PROJECT)

    created_index = aiplatform.MatchingEngineIndex.create_brute_force_index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
        dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
        distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
        index_update_method=index_update_method,
        encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
        sync=sync,
    )

    # In the non-sync parametrization, block on the returned index before
    # inspecting the mock.
    if not sync:
        created_index.wait()

    algorithm_config = {"bruteForceConfig": {}}

    # No approximate_neighbors_count is passed above, and the expected config
    # lists "approximateNeighborsCount" as None.
    expected_metadata = {
        "config": {
            "algorithmConfig": algorithm_config,
            "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
            "approximateNeighborsCount": None,
            "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
        },
    }
    expected_update_method = _TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
        index_update_method
    ]
    expected_encryption_spec = gca_encryption_spec.EncryptionSpec(
        kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
    )

    expected_index = gca_index.Index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        metadata=expected_metadata,
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
        index_update_method=expected_update_method,
        encryption_spec=expected_encryption_spec,
    )

    create_index_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        index=expected_index,
        metadata=_TEST_REQUEST_METADATA,
    )

@pytest.mark.usefixtures("get_index_mock")
def test_create_brute_force_index_backward_compatibility(self, create_index_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit a00db07

Please sign in to comment.