Skip to content

Commit

Permalink
feat: add upsert_datapoints() to MatchingEngineIndex to support s…
Browse files Browse the repository at this point in the history
…treaming update index.

PiperOrigin-RevId: 583089201
  • Loading branch information
lingyinw authored and Copybara-Service committed Nov 16, 2023
1 parent ba2fb39 commit 7ca484d
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 14 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1
types.index = types.index_v1beta1
types.index_endpoint = types.index_endpoint_v1beta1
types.index_service = types.index_service_v1beta1
types.io = types.io_v1beta1
types.job_service = types.job_service_v1beta1
types.job_state = types.job_state_v1beta1
Expand Down Expand Up @@ -189,6 +190,7 @@
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1
types.index = types.index_v1
types.index_endpoint = types.index_endpoint_v1
types.index_service = types.index_service_v1
types.io = types.io_v1
types.job_service = types.job_service_v1
types.job_state = types.job_state_v1
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
hyperparameter_tuning_job as hyperparameter_tuning_job_v1,
index as index_v1,
index_endpoint as index_endpoint_v1,
index_service as index_service_v1,
io as io_v1,
job_service as job_service_v1,
job_state as job_state_v1,
Expand Down Expand Up @@ -204,6 +205,7 @@
matching_engine_deployed_index_ref_v1,
index_v1,
index_endpoint_v1,
index_service_v1,
metadata_service_v1,
metadata_schema_v1,
metadata_store_v1,
Expand Down
49 changes: 40 additions & 9 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from google.protobuf import field_mask_pb2
from google.cloud.aiplatform import base
from google.cloud.aiplatform.compat.types import (
index_service_v1beta1 as gca_index_service_v1beta1,
index_service as gca_index_service,
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
matching_engine_index as gca_matching_engine_index,
encryption_spec as gca_encryption_spec,
Expand Down Expand Up @@ -665,6 +665,42 @@ def create_brute_force_index(
encryption_spec_key_name=encryption_spec_key_name,
)

def upsert_datapoints(
self,
datapoints: Sequence[gca_matching_engine_index.IndexDatapoint],
) -> "MatchingEngineIndex":
"""Upsert datapoints to this index.
Args:
datapoints (Sequence[gca_matching_engine_index.IndexDatapoint]):
Required. Datapoints to be upserted to this index.
Returns:
MatchingEngineIndex - Index resource object
"""

self.wait()

_LOGGER.log_action_start_against_resource(
"Upserting datapoints",
"index",
self,
)

self.api_client.upsert_datapoints(
gca_index_service.UpsertDatapointsRequest(
index=self.resource_name,
datapoints=datapoints,
)
)

_LOGGER.log_action_completed_against_resource(
"index", "Upserted datapoints", self
)

return self

def remove_datapoints(
self,
datapoint_ids: Sequence[str],
Expand All @@ -678,6 +714,7 @@ def remove_datapoints(
Returns:
MatchingEngineIndex - Index resource object
"""

self.wait()

_LOGGER.log_action_start_against_resource(
Expand All @@ -686,19 +723,13 @@ def remove_datapoints(
self,
)

remove_lro = self.api_client.remove_datapoints(
gca_index_service_v1beta1.RemoveDatapointsRequest(
self.api_client.remove_datapoints(
gca_index_service.RemoveDatapointsRequest(
index=self.resource_name,
datapoint_ids=datapoint_ids,
)
)

_LOGGER.log_action_started_against_resource_with_lro(
"Remove datapoints", "index", self.__class__, remove_lro
)

self._gca_resource = remove_lro.result(timeout=None)

_LOGGER.log_action_completed_against_resource(
"index", "Removed datapoints", self
)
Expand Down
66 changes: 61 additions & 5 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from google.cloud.aiplatform.compat.types import (
index as gca_index,
encryption_spec as gca_encryption_spec,
index_service_v1beta1 as gca_index_service_v1beta1,
index_service as gca_index_service,
)
import constants as test_constants

Expand Down Expand Up @@ -111,8 +111,42 @@
# Encryption spec
_TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC"

# Streaming update
_TEST_DATAPOINT_IDS = ("1", "2")
_TEST_DATAPOINT_1 = gca_index.IndexDatapoint(
datapoint_id="0",
feature_vector=[0.00526886899, -0.0198396724],
restricts=[
gca_index.IndexDatapoint.Restriction(namespace="Color", allow_list=["red"])
],
numeric_restricts=[
gca_index.IndexDatapoint.NumericRestriction(
namespace="cost",
value_int=1,
)
],
)
_TEST_DATAPOINT_2 = gca_index.IndexDatapoint(
datapoint_id="1",
feature_vector=[0.00526886899, -0.0198396724],
numeric_restricts=[
gca_index.IndexDatapoint.NumericRestriction(
namespace="cost",
value_double=0.1,
)
],
crowding_tag=gca_index.IndexDatapoint.CrowdingTag(crowding_attribute="crowding"),
)
_TEST_DATAPOINT_3 = gca_index.IndexDatapoint(
datapoint_id="2",
feature_vector=[0.00526886899, -0.0198396724],
numeric_restricts=[
gca_index.IndexDatapoint.NumericRestriction(
namespace="cost",
value_float=1.1,
)
],
)
_TEST_DATAPOINTS = (_TEST_DATAPOINT_1, _TEST_DATAPOINT_2, _TEST_DATAPOINT_3)


def uuid_mock():
Expand Down Expand Up @@ -196,13 +230,19 @@ def create_index_mock():
yield create_index_mock


@pytest.fixture
def upsert_datapoints_mock():
with patch.object(
index_service_client.IndexServiceClient, "upsert_datapoints"
) as upsert_datapoints_mock:
yield upsert_datapoints_mock


@pytest.fixture
def remove_datapoints_mock():
with patch.object(
index_service_client.IndexServiceClient, "remove_datapoints"
) as remove_datapoints_mock:
remove_datapoints_lro_mock = mock.Mock(operation.Operation)
remove_datapoints_mock.return_value = remove_datapoints_lro_mock
yield remove_datapoints_mock


Expand Down Expand Up @@ -509,6 +549,22 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
def test_upsert_datapoints(self, upsert_datapoints_mock):
aiplatform.init(project=_TEST_PROJECT)

my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
my_index.upsert_datapoints(
datapoints=_TEST_DATAPOINTS,
)

upsert_datapoints_request = gca_index_service.UpsertDatapointsRequest(
index=_TEST_INDEX_NAME,
datapoints=_TEST_DATAPOINTS,
)

upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request)

@pytest.mark.usefixtures("get_index_mock")
def test_remove_datapoints(self, remove_datapoints_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand All @@ -518,7 +574,7 @@ def test_remove_datapoints(self, remove_datapoints_mock):
datapoint_ids=_TEST_DATAPOINT_IDS,
)

remove_datapoints_request = gca_index_service_v1beta1.RemoveDatapointsRequest(
remove_datapoints_request = gca_index_service.RemoveDatapointsRequest(
index=_TEST_INDEX_NAME,
datapoint_ids=_TEST_DATAPOINT_IDS,
)
Expand Down

0 comments on commit 7ca484d

Please sign in to comment.