From 750e17b4c25c9030018521545b3c21e1fb1404c2 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Fri, 10 Nov 2023 11:58:43 -0800 Subject: [PATCH] feat: add `encryption_spec_key_name`, `enable_private_service_connect`,`project_allowlist` to MatchingEngineIndexEndpoint `create`. PiperOrigin-RevId: 581328160 --- google/cloud/aiplatform/compat/__init__.py | 2 + .../cloud/aiplatform/compat/types/__init__.py | 2 + .../matching_engine_index_endpoint.py | 86 +++++++++++++++-- .../test_matching_engine_index_endpoint.py | 96 ++++++++++++++++++- 4 files changed, 177 insertions(+), 9 deletions(-) diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index 965b8928cf..f4a5cdde26 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -111,6 +111,7 @@ types.model_garden_service = types.model_garden_service_v1beta1 types.model_monitoring = types.model_monitoring_v1beta1 types.model_service = types.model_service_v1beta1 + types.service_networking = types.service_networking_v1beta1 types.operation = types.operation_v1beta1 types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1 types.pipeline_job = types.pipeline_job_v1beta1 @@ -208,6 +209,7 @@ types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1 types.model_monitoring = types.model_monitoring_v1 types.model_service = types.model_service_v1 + types.service_networking = types.service_networking_v1 types.operation = types.operation_v1 types.pipeline_failure_policy = types.pipeline_failure_policy_v1 types.pipeline_job = types.pipeline_job_v1 diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index fb72dc7103..83387b064f 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -75,6 +75,7 @@ pipeline_state as pipeline_state_v1beta1, prediction_service as prediction_service_v1beta1, publisher_model as publisher_model_v1beta1, + service_networking as service_networking_v1beta1, schedule as schedule_v1beta1, schedule_service as schedule_service_v1beta1, specialist_pool as specialist_pool_v1beta1, @@ -147,6 +148,7 @@ publisher_model as publisher_model_v1, schedule as schedule_v1, schedule_service as schedule_service_v1, + service_networking as service_networking_v1, specialist_pool as specialist_pool_v1, specialist_pool_service as specialist_pool_service_v1, study as study_v1, diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 3ccfc06fb6..04f211c201 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -28,6 +28,8 @@ matching_engine_index_endpoint as gca_matching_engine_index_endpoint, match_service_v1beta1 as gca_match_service_v1beta1, index_v1beta1 as gca_index_v1beta1, + service_networking as gca_service_networking, + encryption_spec as gca_encryption_spec, ) from google.cloud.aiplatform.matching_engine._protos import match_service_pb2 from google.cloud.aiplatform.matching_engine._protos import ( @@ -145,6 +147,9 @@ def create( credentials: Optional[auth_credentials.Credentials] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), sync: bool = True, + enable_private_service_connect: Optional[bool] = False, + project_allowlist: Optional[Sequence[str]] = None, + encryption_spec_key_name: Optional[str] = None, ) -> "MatchingEngineIndexEndpoint": """Creates a MatchingEngineIndexEndpoint resource. @@ -205,6 +210,23 @@ def create( Optional. Whether to execute this creation synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + enable_private_service_connect (bool): + If true, expose the index endpoint via private service connect. + project_allowlist (Sequence[str]): + Optional. List of projects from which the forwarding rule will + target the service attachment. + encryption_spec_key_name (str): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the index endpoint. + Has the form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this index endpoint and all sub-resources of this + index endpoint will be secured by this key. + The key needs to be in the same region as where the index + endpoint is created. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object @@ -214,14 +236,27 @@ def create( """ network = network or initializer.global_config.network - if not network and not public_endpoint_enabled: + if not (network or public_endpoint_enabled or enable_private_service_connect): raise ValueError( - "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint" + "Please provide `network` argument for Private Service Access endpoint," + "or provide `enable_private_service_connect` for Private Service" + "Connect endpoint, or provide `public_endpoint_enabled` to" + "deploy to a public endpoint" ) - if network and public_endpoint_enabled: + if ( + sum( + bool(network_setting) + for network_setting in [ + network, + public_endpoint_enabled, + enable_private_service_connect, + ] + ) + > 1 + ): raise ValueError( - "`network` and `public_endpoint_enabled` argument should not be set at the same time" + "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set." ) return cls._create( @@ -235,6 +270,9 @@ def create( credentials=credentials, request_metadata=request_metadata, sync=sync, + enable_private_service_connect=enable_private_service_connect, + project_allowlist=project_allowlist, + encryption_spec_key_name=encryption_spec_key_name, ) @classmethod @@ -251,6 +289,9 @@ def _create( credentials: Optional[auth_credentials.Credentials] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), sync: bool = True, + enable_private_service_connect: Optional[bool] = False, + project_allowlist: Optional[Sequence[str]] = None, + encryption_spec_key_name: Optional[str] = None, ) -> "MatchingEngineIndexEndpoint": """Helper method to ensure network synchronization and to create a MatchingEngineIndexEndpoint resource. @@ -304,20 +345,53 @@ def _create( Optional. Whether to execute this creation synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + encryption_spec_key_name (str): + Immutable. Customer-managed encryption key + spec for an IndexEndpoint. If set, this + IndexEndpoint and all sub-resources of this + IndexEndpoint will be secured by this key. + enable_private_service_connect (bool): + Required. If true, expose the IndexEndpoint + via private service connect. + project_allowlist (MutableSequence[str]): + A list of Projects from which the forwarding + rule will target the service attachment. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object """ - + # Public if public_endpoint_enabled: gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint( display_name=display_name, description=description, public_endpoint_enabled=public_endpoint_enabled, + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=encryption_spec_key_name + ), + ) + # PSA + elif network: + gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint( + display_name=display_name, + description=description, + network=network, + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=encryption_spec_key_name + ), ) + # PSC else: gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint( - display_name=display_name, description=description, network=network + display_name=display_name, + description=description, + private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig( + project_allowlist=project_allowlist, + enable_private_service_connect=enable_private_service_connect, + ), + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=encryption_spec_key_name + ), ) if labels: diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 48e7c3c506..c578d786df 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -34,6 +34,8 @@ index as gca_index, match_service_v1beta1 as gca_match_service_v1beta1, index_v1beta1 as gca_index_v1beta1, + service_networking as gca_service_networking, + encryption_spec as gca_encryption_spec, ) from google.cloud.aiplatform.compat.services import ( index_endpoint_service_client, @@ -236,6 +238,8 @@ _TEST_APPROX_NUM_NEIGHBORS = 2 _TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8 _TEST_RETURN_FULL_DATAPOINT = True +_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name" +_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"] def uuid_mock(): @@ -619,6 +623,7 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync): network=_TEST_INDEX_ENDPOINT_VPC_NETWORK, description=_TEST_INDEX_ENDPOINT_DESCRIPTION, labels=_TEST_LABELS, + encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, ) if not sync: @@ -629,6 +634,42 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync): network=_TEST_INDEX_ENDPOINT_VPC_NETWORK, description=_TEST_INDEX_ENDPOINT_DESCRIPTION, labels=_TEST_LABELS, + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME + ), + ) + create_index_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + index_endpoint=expected, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_create_index_endpoint_with_private_service_connect( + self, create_index_endpoint_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + labels=_TEST_LABELS, + enable_private_service_connect=True, + project_allowlist=_TEST_PROJECT_ALLOWLIST, + encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, + ) + + expected = gca_index_endpoint.IndexEndpoint( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + labels=_TEST_LABELS, + private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig( + project_allowlist=_TEST_PROJECT_ALLOWLIST, + enable_private_service_connect=True, + ), + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME + ), ) create_index_endpoint_mock.assert_called_once_with( parent=_TEST_PARENT, @@ -644,6 +685,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, description=_TEST_INDEX_ENDPOINT_DESCRIPTION, labels=_TEST_LABELS, + encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, ) expected = gca_index_endpoint.IndexEndpoint( @@ -652,6 +694,9 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc description=_TEST_INDEX_ENDPOINT_DESCRIPTION, labels=_TEST_LABELS, public_endpoint_enabled=False, + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME + ), ) create_index_endpoint_mock.assert_called_once_with( @@ -671,6 +716,7 @@ def test_create_index_endpoint_with_public_endpoint_enabled( description=_TEST_INDEX_ENDPOINT_DESCRIPTION, public_endpoint_enabled=True, labels=_TEST_LABELS, + encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, ) my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( @@ -682,6 +728,9 @@ def test_create_index_endpoint_with_public_endpoint_enabled( description=_TEST_INDEX_ENDPOINT_DESCRIPTION, public_endpoint_enabled=True, labels=_TEST_LABELS, + encryption_spec=gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME + ), ) create_index_endpoint_mock.assert_called_once_with( @@ -700,7 +749,12 @@ def test_create_index_endpoint_missing_argument_throw_error( ): aiplatform.init(project=_TEST_PROJECT) - expected_message = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint" + expected_message = ( + "Please provide `network` argument for Private Service Access endpoint," + "or provide `enable_private_service_connect` for Private Service" + "Connect endpoint, or provide `public_endpoint_enabled` to" + "deploy to a public endpoint" + ) with pytest.raises(ValueError) as exception: _ = aiplatform.MatchingEngineIndexEndpoint.create( @@ -711,12 +765,12 @@ def test_create_index_endpoint_missing_argument_throw_error( assert str(exception.value) == expected_message - def test_create_index_endpoint_set_both_throw_error( + def test_create_index_endpoint_set_both_psa_and_public_throw_error( self, create_index_endpoint_mock ): aiplatform.init(project=_TEST_PROJECT) - expected_message = "`network` and `public_endpoint_enabled` argument should not be set at the same time" + expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set." with pytest.raises(ValueError) as exception: _ = aiplatform.MatchingEngineIndexEndpoint.create( @@ -729,6 +783,42 @@ def test_create_index_endpoint_set_both_throw_error( assert str(exception.value) == expected_message + def test_create_index_endpoint_set_both_psa_and_psc_throw_error( + self, create_index_endpoint_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set." + + with pytest.raises(ValueError) as exception: + _ = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + network=_TEST_INDEX_ENDPOINT_VPC_NETWORK, + labels=_TEST_LABELS, + enable_private_service_connect=True, + ) + + assert str(exception.value) == expected_message + + def test_create_index_endpoint_set_both_psc_and_public_throw_error( + self, create_index_endpoint_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set." + + with pytest.raises(ValueError) as exception: + _ = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + public_endpoint_enabled=True, + labels=_TEST_LABELS, + enable_private_service_connect=True, + ) + + assert str(exception.value) == expected_message + @pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock") def test_deploy_index(self, deploy_index_mock, undeploy_index_mock): aiplatform.init(project=_TEST_PROJECT)