Skip to content

Commit

Permalink
feat: Support shard_size for MatchingEngineIndex create index.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613432969
  • Loading branch information
lingyinw authored and Copybara-Service committed Mar 7, 2024
1 parent f294ba8 commit 6dbf7d3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
25 changes: 25 additions & 0 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def create_tree_ah_index(
index_update_method: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
create_request_timeout: Optional[float] = None,
shard_size: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the tree-AH algorithm.
Expand Down Expand Up @@ -525,6 +526,16 @@ def create_tree_ah_index(
created.
create_request_timeout (float):
Optional. The timeout for the request in seconds.
shard_size (str):
Optional. The size of each shard. Index will get resharded
based on specified shard size. During serving, each shard will
be served on a separate node and will scale independently.
Choose one of the following:
SHARD_SIZE_SMALL
SHARD_SIZE_MEDIUM
SHARD_SIZE_LARGE
Returns:
MatchingEngineIndex - Index resource object
Expand All @@ -541,6 +552,7 @@ def create_tree_ah_index(
algorithm_config=algorithm_config,
approximate_neighbors_count=approximate_neighbors_count,
distance_measure_type=distance_measure_type,
shard_size=shard_size,
)

return cls._create(
Expand Down Expand Up @@ -578,6 +590,7 @@ def create_brute_force_index(
index_update_method: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
create_request_timeout: Optional[float] = None,
shard_size: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the brute force algorithm.
Expand Down Expand Up @@ -659,6 +672,17 @@ def create_brute_force_index(
created.
create_request_timeout (float):
Optional. The timeout for the request in seconds.
shard_size (str):
Optional. The size of each shard. Index will get resharded
based on specified shard size. During serving, each shard will
be served on a separate node and will scale independently.
If not set, shard size is default to SHARD_SIZE_MEDIUM.
Choose one of the following:
SHARD_SIZE_SMALL
SHARD_SIZE_MEDIUM
SHARD_SIZE_LARGE
Returns:
MatchingEngineIndex - Index resource object
Expand All @@ -671,6 +695,7 @@ def create_brute_force_index(
dimensions=dimensions,
algorithm_config=algorithm_config,
distance_measure_type=distance_measure_type,
shard_size=shard_size,
)

return cls._create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ class MatchingEngineIndexConfig:
approximate search algorithm are reordered via a more expensive distance computation.
Required if tree-AH algorithm is used.
shard_size (str):
Optional. The size of each shard. Index will get resharded the
based on specified shard size. During serving,
each shard will be served on a separate node and will scale
independently.
distance_measure_type (DistanceMeasureType):
Optional. The distance measure used in nearest neighbor search.
"""
Expand All @@ -126,17 +131,19 @@ class MatchingEngineIndexConfig:
algorithm_config: AlgorithmConfig
approximate_neighbors_count: Optional[int] = None
distance_measure_type: Optional[DistanceMeasureType] = None
shard_size: Optional[str] = None

def as_dict(self) -> Dict[str, Any]:
"""Returns the configuration as a dictionary.
Returns:
Dict[str, Any]
"""

return {
res = {
"dimensions": self.dimensions,
"algorithmConfig": self.algorithm_config.as_dict(),
"approximateNeighborsCount": self.approximate_neighbors_count,
"distanceMeasureType": self.distance_measure_type,
"shardSize": self.shard_size,
}
return res
21 changes: 18 additions & 3 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT = 150
_TEST_LEAF_NODE_EMBEDDING_COUNT = 123
_TEST_LEAF_NODES_TO_SEARCH_PERCENT = 50
_TEST_SHARD_SIZES = ["SHARD_SIZE_SMALL", "SHARD_SIZE_LARGE", "SHARD_SIZE_MEDIUM"]

_TEST_INDEX_DESCRIPTION = test_constants.MatchingEngineConstants._TEST_INDEX_DESCRIPTION

Expand Down Expand Up @@ -361,7 +362,10 @@ def test_delete_index(self, delete_index_mock, sync):
_TEST_INDEX_INVALID_UPDATE_METHOD,
],
)
def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method):
@pytest.mark.parametrize("shard_size", _TEST_SHARD_SIZES)
def test_create_tree_ah_index(
self, create_index_mock, sync, index_update_method, shard_size
):
aiplatform.init(project=_TEST_PROJECT)

my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
Expand All @@ -378,6 +382,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
index_update_method=index_update_method,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
create_request_timeout=_TEST_TIMEOUT,
shard_size=shard_size,
)

if not sync:
Expand All @@ -398,6 +403,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"shardSize": shard_size,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
},
Expand Down Expand Up @@ -429,8 +435,9 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
_TEST_INDEX_INVALID_UPDATE_METHOD,
],
)
@pytest.mark.parametrize("shard_size", _TEST_SHARD_SIZES)
def test_create_tree_ah_index_with_empty_index(
self, create_index_mock, sync, index_update_method
self, create_index_mock, sync, index_update_method, shard_size
):
aiplatform.init(project=_TEST_PROJECT)

Expand All @@ -448,6 +455,7 @@ def test_create_tree_ah_index_with_empty_index(
index_update_method=index_update_method,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
create_request_timeout=_TEST_TIMEOUT,
shard_size=shard_size,
)

if not sync:
Expand All @@ -468,6 +476,7 @@ def test_create_tree_ah_index_with_empty_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"shardSize": shard_size,
},
},
description=_TEST_INDEX_DESCRIPTION,
Expand Down Expand Up @@ -518,6 +527,7 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"shardSize": None,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
},
Expand All @@ -543,8 +553,9 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
_TEST_INDEX_INVALID_UPDATE_METHOD,
],
)
@pytest.mark.parametrize("shard_size", _TEST_SHARD_SIZES)
def test_create_brute_force_index(
self, create_index_mock, sync, index_update_method
self, create_index_mock, sync, index_update_method, shard_size
):
aiplatform.init(project=_TEST_PROJECT)

Expand All @@ -559,6 +570,7 @@ def test_create_brute_force_index(
index_update_method=index_update_method,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
create_request_timeout=_TEST_TIMEOUT,
shard_size=shard_size,
)

if not sync:
Expand All @@ -574,6 +586,7 @@ def test_create_brute_force_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"shardSize": shard_size,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
},
Expand Down Expand Up @@ -635,6 +648,7 @@ def test_create_brute_force_index_with_empty_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"shardSize": None,
},
},
description=_TEST_INDEX_DESCRIPTION,
Expand Down Expand Up @@ -677,6 +691,7 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"shardSize": None,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
},
Expand Down

0 comments on commit 6dbf7d3

Please sign in to comment.