From 09d1946711fb022bf584137299ed187bf885cb23 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Tue, 20 Feb 2024 18:50:51 -0800 Subject: [PATCH] feat: Support custom `timeout` for `MatchingEngineIndex` and `MatchingEngineIndexEndpoint` APIs. PiperOrigin-RevId: 608815992 --- .../matching_engine/matching_engine_index.py | 20 +++++++++++++ .../matching_engine_index_endpoint.py | 29 +++++++++++++++++++ .../aiplatform/test_matching_engine_index.py | 18 +++++++++++- .../test_matching_engine_index_endpoint.py | 13 +++++++++ 4 files changed, 79 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index 734d559c1d..012eaa7b2d 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -112,6 +112,7 @@ def _create( sync: bool = True, index_update_method: Optional[str] = None, encryption_spec_key_name: Optional[str] = None, + create_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndex": """Creates a MatchingEngineIndex resource. @@ -177,6 +178,8 @@ def _create( secured by this key. The key needs to be in the same region as where the index is created. + create_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndex - Index resource object @@ -220,6 +223,7 @@ def _create( ), index=gapic_index, metadata=request_metadata, + timeout=create_request_timeout, ) _LOGGER.log_create_with_lro(cls, create_lro) @@ -243,6 +247,7 @@ def update_metadata( description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + update_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndex": """Updates the metadata for this index. @@ -269,6 +274,8 @@ def update_metadata( "aiplatform.googleapis.com/" and are immutable. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + update_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndex - The updated index resource object. @@ -307,6 +314,7 @@ def update_metadata( index=gapic_index, update_mask=update_mask, metadata=request_metadata, + timeout=update_request_timeout, ) _LOGGER.log_action_started_against_resource_with_lro( @@ -324,6 +332,7 @@ def update_embeddings( contents_delta_uri: str, is_complete_overwrite: Optional[bool] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + update_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndex": """Updates the embeddings for this index. @@ -341,6 +350,8 @@ def update_embeddings( then existing content of the Index will be replaced by the data from the contentsDeltaUri. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + update_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndex - The updated index resource object. @@ -373,6 +384,7 @@ def update_embeddings( index=gapic_index, update_mask=update_mask, metadata=request_metadata, + timeout=update_request_timeout, ) _LOGGER.log_action_started_against_resource_with_lro( @@ -420,6 +432,7 @@ def create_tree_ah_index( sync: bool = True, index_update_method: Optional[str] = None, encryption_spec_key_name: Optional[str] = None, + create_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndex": """Creates a MatchingEngineIndex resource that uses the tree-AH algorithm. @@ -510,6 +523,8 @@ def create_tree_ah_index( secured by this key. The key needs to be in the same region as where the index is created. + create_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndex - Index resource object @@ -541,6 +556,7 @@ def create_tree_ah_index( sync=sync, index_update_method=index_update_method, encryption_spec_key_name=encryption_spec_key_name, + create_request_timeout=create_request_timeout, ) @classmethod @@ -561,6 +577,7 @@ def create_brute_force_index( sync: bool = True, index_update_method: Optional[str] = None, encryption_spec_key_name: Optional[str] = None, + create_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndex": """Creates a MatchingEngineIndex resource that uses the brute force algorithm. @@ -640,6 +657,8 @@ def create_brute_force_index( secured by this key. The key needs to be in the same region as where the index is created. + create_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndex - Index resource object @@ -667,6 +686,7 @@ def create_brute_force_index( sync=sync, index_update_method=index_update_method, encryption_spec_key_name=encryption_spec_key_name, + create_request_timeout=create_request_timeout, ) def upsert_datapoints( diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 95a839c341..7aba68b133 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -346,6 +346,7 @@ def create( enable_private_service_connect: bool = False, project_allowlist: Optional[Sequence[str]] = None, encryption_spec_key_name: Optional[str] = None, + create_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndexEndpoint": """Creates a MatchingEngineIndexEndpoint resource. @@ -423,6 +424,8 @@ def create( index endpoint will be secured by this key. The key needs to be in the same region as where the index endpoint is created. + create_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object @@ -469,6 +472,7 @@ def create( enable_private_service_connect=enable_private_service_connect, project_allowlist=project_allowlist, encryption_spec_key_name=encryption_spec_key_name, + create_request_timeout=create_request_timeout, ) @classmethod @@ -488,6 +492,7 @@ def _create( enable_private_service_connect: bool = False, project_allowlist: Optional[Sequence[str]] = None, encryption_spec_key_name: Optional[str] = None, + create_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndexEndpoint": """Helper method to ensure network synchronization and to create a MatchingEngineIndexEndpoint resource. @@ -552,6 +557,8 @@ def _create( project_allowlist (MutableSequence[str]): A list of Projects from which the forwarding rule will target the service attachment. + create_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object @@ -596,6 +603,7 @@ def _create( ), index_endpoint=gapic_index_endpoint, metadata=request_metadata, + timeout=create_request_timeout, ) _LOGGER.log_create_with_lro(cls, create_lro) @@ -716,6 +724,7 @@ def update( description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + update_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndexEndpoint": """Updates an existing index endpoint resource. @@ -742,6 +751,8 @@ def update( "aiplatform.googleapis.com/" and are immutable. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + update_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndexEndpoint - The updated index endpoint resource object. @@ -774,6 +785,7 @@ def update( index_endpoint=gapic_index_endpoint, update_mask=update_mask, metadata=request_metadata, + timeout=update_request_timeout, ) return self @@ -937,6 +949,7 @@ def deploy_index( auth_config_audiences: Optional[Sequence[str]] = None, auth_config_allowed_issuers: Optional[Sequence[str]] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + deploy_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndexEndpoint": """Deploys an existing index resource to this endpoint resource. @@ -1030,6 +1043,9 @@ def deploy_index( auth_config_audiences and auth_config_allowed_issuers must be passed together. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + + deploy_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object """ @@ -1060,6 +1076,7 @@ def deploy_index( index_endpoint=self.resource_name, deployed_index=deployed_index, metadata=request_metadata, + timeout=deploy_request_timeout, ) _LOGGER.log_action_started_against_resource_with_lro( @@ -1081,6 +1098,7 @@ def undeploy_index( self, deployed_index_id: str, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + undeploy_request_timeout: Optional[float] = None, ) -> "MatchingEngineIndexEndpoint": """Undeploy a deployed index endpoint resource. @@ -1090,6 +1108,8 @@ def undeploy_index( to be undeployed from the IndexEndpoint. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + undeploy_request_timeout (float): + Optional. The timeout for the request in seconds. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object """ @@ -1106,6 +1126,7 @@ def undeploy_index( index_endpoint=self.resource_name, deployed_index_id=deployed_index_id, metadata=request_metadata, + timeout=undeploy_request_timeout, ) _LOGGER.log_action_started_against_resource_with_lro( @@ -1126,6 +1147,7 @@ def mutate_deployed_index( min_replica_count: int = 1, max_replica_count: int = 1, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + mutate_request_timeout: Optional[float] = None, ): """Updates an existing deployed index under this endpoint resource. @@ -1157,6 +1179,8 @@ def mutate_deployed_index( will automatically be increased to be min_replica_count. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + timeout (float): + Optional. The timeout for the request in seconds. """ self.wait() @@ -1178,6 +1202,7 @@ def mutate_deployed_index( index_endpoint=self.resource_name, deployed_index=deployed_index, metadata=request_metadata, + timeout=mutate_request_timeout, ) _LOGGER.log_action_started_against_resource_with_lro( @@ -1211,6 +1236,7 @@ def _undeploy( deployed_index_id: str, metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, + undeploy_request_timeout: Optional[float] = None, ) -> None: """Undeploys a deployed index. @@ -1221,6 +1247,8 @@ def _undeploy( metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + timeout (float): + Optional. The timeout for the request in seconds. """ self._sync_gca_resource() @@ -1230,6 +1258,7 @@ def _undeploy( index_endpoint=self.resource_name, deployed_index_id=deployed_index_id, metadata=metadata, + timeout=undeploy_request_timeout, ) _LOGGER.log_action_started_against_resource_with_lro( diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index 1a32d74dd1..e221e2a4b9 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -147,6 +147,7 @@ ], ) _TEST_DATAPOINTS = (_TEST_DATAPOINT_1, _TEST_DATAPOINT_2, _TEST_DATAPOINT_3) +_TEST_TIMEOUT = 1800.0 def uuid_mock(): @@ -262,7 +263,8 @@ def test_init_index(self, index_name, get_index_mock): my_index = aiplatform.MatchingEngineIndex(index_name=index_name) get_index_mock.assert_called_once_with( - name=my_index.resource_name, retry=base._DEFAULT_RETRY + name=my_index.resource_name, + retry=base._DEFAULT_RETRY, ) @pytest.mark.usefixtures("get_index_mock") @@ -274,6 +276,7 @@ def test_update_index_metadata(self, update_index_metadata_mock): display_name=_TEST_DISPLAY_NAME_UPDATE, description=_TEST_DESCRIPTION_UPDATE, labels=_TEST_LABELS_UPDATE, + update_request_timeout=_TEST_TIMEOUT, ) expected = gca_index.Index( @@ -289,6 +292,7 @@ def test_update_index_metadata(self, update_index_metadata_mock): paths=["labels", "display_name", "description"] ), metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) assert updated_index.gca_resource == expected @@ -301,6 +305,7 @@ def test_update_index_embeddings(self, update_index_embeddings_mock): updated_index = my_index.update_embeddings( contents_delta_uri=_TEST_CONTENTS_DELTA_URI_UPDATE, is_complete_overwrite=_TEST_IS_COMPLETE_OVERWRITE_UPDATE, + update_request_timeout=_TEST_TIMEOUT, ) expected = gca_index.Index( @@ -315,6 +320,7 @@ def test_update_index_embeddings(self, update_index_embeddings_mock): index=expected, update_mask=field_mask_pb2.FieldMask(paths=["metadata"]), metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) # The service only returns the name of the Index @@ -370,6 +376,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method sync=sync, index_update_method=index_update_method, encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + create_request_timeout=_TEST_TIMEOUT, ) if not sync: @@ -407,6 +414,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) @pytest.mark.usefixtures("get_index_mock") @@ -438,6 +446,7 @@ def test_create_tree_ah_index_with_empty_index( sync=sync, index_update_method=index_update_method, encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + create_request_timeout=_TEST_TIMEOUT, ) if not sync: @@ -474,6 +483,7 @@ def test_create_tree_ah_index_with_empty_index( parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) @pytest.mark.usefixtures("get_index_mock") @@ -518,6 +528,7 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock): parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA, + timeout=None, ) @pytest.mark.usefixtures("get_index_mock") @@ -546,6 +557,7 @@ def test_create_brute_force_index( sync=sync, index_update_method=index_update_method, encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + create_request_timeout=_TEST_TIMEOUT, ) if not sync: @@ -578,6 +590,7 @@ def test_create_brute_force_index( parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) @pytest.mark.usefixtures("get_index_mock") @@ -605,6 +618,7 @@ def test_create_brute_force_index_with_empty_index( sync=sync, index_update_method=index_update_method, encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + create_request_timeout=_TEST_TIMEOUT, ) if not sync: @@ -636,6 +650,7 @@ def test_create_brute_force_index_with_empty_index( parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) @pytest.mark.usefixtures("get_index_mock") @@ -672,6 +687,7 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA, + timeout=None, ) @pytest.mark.usefixtures("get_index_mock") diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 02aaa298e7..6e0d052bb7 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -305,6 +305,7 @@ ), ] ] +_TEST_TIMEOUT = 1800.0 def uuid_mock(): @@ -689,6 +690,7 @@ def test_update_index_endpoint(self, update_index_endpoint_mock): description=_TEST_DESCRIPTION_UPDATE, labels=_TEST_LABELS_UPDATE, request_metadata=_TEST_REQUEST_METADATA, + update_request_timeout=_TEST_TIMEOUT, ) expected = gca_index_endpoint.IndexEndpoint( @@ -704,6 +706,7 @@ def test_update_index_endpoint(self, update_index_endpoint_mock): paths=["labels", "display_name", "description"] ), metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) assert updated_endpoint.gca_resource == expected @@ -747,6 +750,7 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync): network=_TEST_INDEX_ENDPOINT_VPC_NETWORK, description=_TEST_INDEX_ENDPOINT_DESCRIPTION, labels=_TEST_LABELS, + create_request_timeout=_TEST_TIMEOUT, ) if not sync: @@ -762,6 +766,7 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync): parent=_TEST_PARENT, index_endpoint=expected, metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) @pytest.mark.usefixtures("get_index_endpoint_mock") @@ -795,6 +800,7 @@ def test_create_index_endpoint_with_private_service_connect( parent=_TEST_PARENT, index_endpoint=expected, metadata=_TEST_REQUEST_METADATA, + timeout=None, ) @pytest.mark.usefixtures("get_index_endpoint_mock") @@ -823,6 +829,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc parent=_TEST_PARENT, index_endpoint=expected, metadata=_TEST_REQUEST_METADATA, + timeout=None, ) @pytest.mark.usefixtures("get_index_public_endpoint_mock") @@ -857,6 +864,7 @@ def test_create_index_endpoint_with_public_endpoint_enabled( parent=_TEST_PARENT, index_endpoint=expected, metadata=_TEST_REQUEST_METADATA, + timeout=None, ) assert ( @@ -962,6 +970,7 @@ def test_deploy_index(self, deploy_index_mock, undeploy_index_mock): auth_config_audiences=_TEST_AUTH_CONFIG_AUDIENCES, auth_config_allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, request_metadata=_TEST_REQUEST_METADATA, + deploy_request_timeout=_TEST_TIMEOUT, ) deploy_index_mock.assert_called_once_with( @@ -985,6 +994,7 @@ def test_deploy_index(self, deploy_index_mock, undeploy_index_mock): ), ), metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) my_index_endpoint = my_index_endpoint.undeploy_index( @@ -995,6 +1005,7 @@ def test_deploy_index(self, deploy_index_mock, undeploy_index_mock): index_endpoint=my_index_endpoint.resource_name, deployed_index_id=_TEST_DEPLOYED_INDEX_ID, metadata=_TEST_REQUEST_METADATA, + timeout=None, ) @pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock") @@ -1010,6 +1021,7 @@ def test_mutate_deployed_index(self, mutate_deployed_index_mock): min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED, max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED, request_metadata=_TEST_REQUEST_METADATA, + mutate_request_timeout=_TEST_TIMEOUT, ) mutate_deployed_index_mock.assert_called_once_with( @@ -1022,6 +1034,7 @@ def test_mutate_deployed_index(self, mutate_deployed_index_mock): }, ), metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, ) @pytest.mark.usefixtures("get_index_endpoint_mock")