Skip to content

Commit

Permalink
feat: add encryption_spec_key_name, `enable_private_service_connect…
Browse files Browse the repository at this point in the history
…`,`project_allowlist` to MatchingEngineIndexEndpoint `create`.

PiperOrigin-RevId: 581328160
  • Loading branch information
lingyinw authored and Copybara-Service committed Nov 10, 2023
1 parent fcf05cb commit 750e17b
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 9 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/__init__.py
Expand Up @@ -111,6 +111,7 @@
types.model_garden_service = types.model_garden_service_v1beta1
types.model_monitoring = types.model_monitoring_v1beta1
types.model_service = types.model_service_v1beta1
types.service_networking = types.service_networking_v1beta1
types.operation = types.operation_v1beta1
types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
types.pipeline_job = types.pipeline_job_v1beta1
Expand Down Expand Up @@ -208,6 +209,7 @@
types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1
types.model_monitoring = types.model_monitoring_v1
types.model_service = types.model_service_v1
types.service_networking = types.service_networking_v1
types.operation = types.operation_v1
types.pipeline_failure_policy = types.pipeline_failure_policy_v1
types.pipeline_job = types.pipeline_job_v1
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Expand Up @@ -75,6 +75,7 @@
pipeline_state as pipeline_state_v1beta1,
prediction_service as prediction_service_v1beta1,
publisher_model as publisher_model_v1beta1,
service_networking as service_networking_v1beta1,
schedule as schedule_v1beta1,
schedule_service as schedule_service_v1beta1,
specialist_pool as specialist_pool_v1beta1,
Expand Down Expand Up @@ -147,6 +148,7 @@
publisher_model as publisher_model_v1,
schedule as schedule_v1,
schedule_service as schedule_service_v1,
service_networking as service_networking_v1,
specialist_pool as specialist_pool_v1,
specialist_pool_service as specialist_pool_service_v1,
study as study_v1,
Expand Down
Expand Up @@ -28,6 +28,8 @@
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
match_service_v1beta1 as gca_match_service_v1beta1,
index_v1beta1 as gca_index_v1beta1,
service_networking as gca_service_networking,
encryption_spec as gca_encryption_spec,
)
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import (
Expand Down Expand Up @@ -145,6 +147,9 @@ def create(
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
enable_private_service_connect: Optional[bool] = False,
project_allowlist: Optional[Sequence[str]] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndexEndpoint":
"""Creates a MatchingEngineIndexEndpoint resource.
Expand Down Expand Up @@ -205,6 +210,23 @@ def create(
Optional. Whether to execute this creation synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
enable_private_service_connect (bool):
If true, expose the index endpoint via private service connect.
project_allowlist (Sequence[str]):
Optional. List of projects from which the forwarding rule will
target the service attachment.
encryption_spec_key_name (str):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the index endpoint.
Has the form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this index endpoint and all sub-resources of this
index endpoint will be secured by this key.
The key needs to be in the same region as where the index
endpoint is created.
Returns:
MatchingEngineIndexEndpoint - IndexEndpoint resource object
Expand All @@ -214,14 +236,27 @@ def create(
"""
network = network or initializer.global_config.network

if not network and not public_endpoint_enabled:
if not (network or public_endpoint_enabled or enable_private_service_connect):
raise ValueError(
"Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
"Please provide `network` argument for Private Service Access endpoint,"
"or provide `enable_private_service_connect` for Private Service"
"Connect endpoint, or provide `public_endpoint_enabled` to"
"deploy to a public endpoint"
)

if network and public_endpoint_enabled:
if (
sum(
bool(network_setting)
for network_setting in [
network,
public_endpoint_enabled,
enable_private_service_connect,
]
)
> 1
):
raise ValueError(
"`network` and `public_endpoint_enabled` argument should not be set at the same time"
"One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
)

return cls._create(
Expand All @@ -235,6 +270,9 @@ def create(
credentials=credentials,
request_metadata=request_metadata,
sync=sync,
enable_private_service_connect=enable_private_service_connect,
project_allowlist=project_allowlist,
encryption_spec_key_name=encryption_spec_key_name,
)

@classmethod
Expand All @@ -251,6 +289,9 @@ def _create(
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
enable_private_service_connect: Optional[bool] = False,
project_allowlist: Optional[Sequence[str]] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndexEndpoint":
"""Helper method to ensure network synchronization and to
create a MatchingEngineIndexEndpoint resource.
Expand Down Expand Up @@ -304,20 +345,53 @@ def _create(
Optional. Whether to execute this creation synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
encryption_spec_key_name (str):
Immutable. Customer-managed encryption key
spec for an IndexEndpoint. If set, this
IndexEndpoint and all sub-resources of this
IndexEndpoint will be secured by this key.
enable_private_service_connect (bool):
Required. If true, expose the IndexEndpoint
via private service connect.
project_allowlist (MutableSequence[str]):
A list of Projects from which the forwarding
rule will target the service attachment.
Returns:
MatchingEngineIndexEndpoint - IndexEndpoint resource object
"""

# Public
if public_endpoint_enabled:
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name,
description=description,
public_endpoint_enabled=public_endpoint_enabled,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
),
)
# PSA
elif network:
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name,
description=description,
network=network,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
),
)
# PSC
else:
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name, description=description, network=network
display_name=display_name,
description=description,
private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig(
project_allowlist=project_allowlist,
enable_private_service_connect=enable_private_service_connect,
),
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
),
)

if labels:
Expand Down
96 changes: 93 additions & 3 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Expand Up @@ -34,6 +34,8 @@
index as gca_index,
match_service_v1beta1 as gca_match_service_v1beta1,
index_v1beta1 as gca_index_v1beta1,
service_networking as gca_service_networking,
encryption_spec as gca_encryption_spec,
)
from google.cloud.aiplatform.compat.services import (
index_endpoint_service_client,
Expand Down Expand Up @@ -236,6 +238,8 @@
_TEST_APPROX_NUM_NEIGHBORS = 2
_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8
_TEST_RETURN_FULL_DATAPOINT = True
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]


def uuid_mock():
Expand Down Expand Up @@ -619,6 +623,7 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync):
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
)

if not sync:
Expand All @@ -629,6 +634,42 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync):
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
),
)
create_index_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
index_endpoint=expected,
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_create_index_endpoint_with_private_service_connect(
self, create_index_endpoint_mock
):
aiplatform.init(project=_TEST_PROJECT)

aiplatform.MatchingEngineIndexEndpoint.create(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
enable_private_service_connect=True,
project_allowlist=_TEST_PROJECT_ALLOWLIST,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
)

expected = gca_index_endpoint.IndexEndpoint(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig(
project_allowlist=_TEST_PROJECT_ALLOWLIST,
enable_private_service_connect=True,
),
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
),
)
create_index_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
Expand All @@ -644,6 +685,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
)

expected = gca_index_endpoint.IndexEndpoint(
Expand All @@ -652,6 +694,9 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
public_endpoint_enabled=False,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
),
)

create_index_endpoint_mock.assert_called_once_with(
Expand All @@ -671,6 +716,7 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
public_endpoint_enabled=True,
labels=_TEST_LABELS,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
Expand All @@ -682,6 +728,9 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
public_endpoint_enabled=True,
labels=_TEST_LABELS,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
),
)

create_index_endpoint_mock.assert_called_once_with(
Expand All @@ -700,7 +749,12 @@ def test_create_index_endpoint_missing_argument_throw_error(
):
aiplatform.init(project=_TEST_PROJECT)

expected_message = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
expected_message = (
"Please provide `network` argument for Private Service Access endpoint,"
"or provide `enable_private_service_connect` for Private Service"
"Connect endpoint, or provide `public_endpoint_enabled` to"
"deploy to a public endpoint"
)

with pytest.raises(ValueError) as exception:
_ = aiplatform.MatchingEngineIndexEndpoint.create(
Expand All @@ -711,12 +765,12 @@ def test_create_index_endpoint_missing_argument_throw_error(

assert str(exception.value) == expected_message

def test_create_index_endpoint_set_both_throw_error(
def test_create_index_endpoint_set_both_psa_and_public_throw_error(
self, create_index_endpoint_mock
):
aiplatform.init(project=_TEST_PROJECT)

expected_message = "`network` and `public_endpoint_enabled` argument should not be set at the same time"
expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."

with pytest.raises(ValueError) as exception:
_ = aiplatform.MatchingEngineIndexEndpoint.create(
Expand All @@ -729,6 +783,42 @@ def test_create_index_endpoint_set_both_throw_error(

assert str(exception.value) == expected_message

def test_create_index_endpoint_set_both_psa_and_psc_throw_error(
self, create_index_endpoint_mock
):
aiplatform.init(project=_TEST_PROJECT)

expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."

with pytest.raises(ValueError) as exception:
_ = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
labels=_TEST_LABELS,
enable_private_service_connect=True,
)

assert str(exception.value) == expected_message

def test_create_index_endpoint_set_both_psc_and_public_throw_error(
self, create_index_endpoint_mock
):
aiplatform.init(project=_TEST_PROJECT)

expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."

with pytest.raises(ValueError) as exception:
_ = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
public_endpoint_enabled=True,
labels=_TEST_LABELS,
enable_private_service_connect=True,
)

assert str(exception.value) == expected_message

@pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock")
def test_deploy_index(self, deploy_index_mock, undeploy_index_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit 750e17b

Please sign in to comment.