Skip to content

Commit

Permalink
fix: project/location parsing for nested resources (#1700)
Browse files Browse the repository at this point in the history
* testing parsing

* adding util function

* removing print statements

* adding changes

* using regex and dict

* lint check

* adding fs test for passing in location and project

* comment fix

* adding docstring changes

* fixing featurestore unit tests

* lint
  • Loading branch information
nayaknishant committed Sep 30, 2022
1 parent ed0492e commit 9e1d796
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 2 deletions.
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/base.py
Expand Up @@ -1077,6 +1077,13 @@ def _list(
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""
if parent:
parent_resources = utils.extract_project_and_location_from_parent(parent)
if parent_resources:
project, location = (
parent_resources["project"],
parent_resources["location"],
)

resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
Expand Down
28 changes: 28 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Expand Up @@ -325,6 +325,34 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona
return (gcs_bucket, gcs_blob_prefix)


def extract_project_and_location_from_parent(
parent: str,
) -> Dict[str, str]:
"""Given a complete parent resource name, return the project and location as a dict.
Example Usage:
parent_resources = extract_project_and_location_from_parent(
"projects/123/locations/us-central1/datasets/456"
)
parent_resources["project"] = "123"
parent_resources["location"] = "us-central1"
Args:
parent (str):
Required. A complete parent resource name.
Returns:
Dict[str, str]
A project, location dict from provided parent resource name.
"""
parent_resources = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)(/|$)", parent
)
return parent_resources.groupdict() if parent_resources else {}


class ClientWithOverride:
class WrappedClient:
"""Wrapper class for client that creates client at API invocation
Expand Down
37 changes: 35 additions & 2 deletions tests/unit/aiplatform/test_featurestores.py
Expand Up @@ -1023,7 +1023,23 @@ def test_list_entity_types(self, list_entity_types_mock):
aiplatform.init(project=_TEST_PROJECT)

my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_ID
featurestore_name=_TEST_FEATURESTORE_ID,
)
my_entity_type_list = my_featurestore.list_entity_types()

list_entity_types_mock.assert_called_once_with(
request={"parent": _TEST_FEATURESTORE_NAME}
)
assert len(my_entity_type_list) == len(_TEST_ENTITY_TYPE_LIST)
for my_entity_type in my_entity_type_list:
assert type(my_entity_type) == aiplatform.EntityType

@pytest.mark.usefixtures("get_featurestore_mock")
def test_list_entity_types_with_no_init(self, list_entity_types_mock):
my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
my_entity_type_list = my_featurestore.list_entity_types()

Expand Down Expand Up @@ -1762,7 +1778,7 @@ def test_update_entity_type(self, update_entity_type_mock):
@pytest.mark.parametrize(
"featurestore_name", [_TEST_FEATURESTORE_NAME, _TEST_FEATURESTORE_ID]
)
def test_list_entity_types(self, featurestore_name, list_entity_types_mock):
def test_list_entity_type(self, featurestore_name, list_entity_types_mock):
aiplatform.init(project=_TEST_PROJECT)

my_entity_type_list = aiplatform.EntityType.list(
Expand Down Expand Up @@ -1790,6 +1806,23 @@ def test_list_features(self, list_features_mock):
for my_feature in my_feature_list:
assert type(my_feature) == aiplatform.Feature

@pytest.mark.usefixtures("get_entity_type_mock")
def test_list_features_with_no_init(self, list_features_mock):
my_entity_type = aiplatform.EntityType(
entity_type_name=_TEST_ENTITY_TYPE_ID,
featurestore_id=_TEST_FEATURESTORE_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
my_feature_list = my_entity_type.list_features()

list_features_mock.assert_called_once_with(
request={"parent": _TEST_ENTITY_TYPE_NAME}
)
assert len(my_feature_list) == len(_TEST_FEATURE_LIST)
for my_feature in my_feature_list:
assert type(my_feature) == aiplatform.Feature

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock")
def test_delete_features(self, delete_feature_mock, sync):
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Expand Up @@ -320,6 +320,30 @@ def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple)
assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path)


@pytest.mark.parametrize(
"parent, expected",
[
(
"projects/123/locations/us-central1/datasets/456",
{"project": "123", "location": "us-central1"},
),
(
"projects/123/locations/us-central1/",
{"project": "123", "location": "us-central1"},
),
(
"projects/123/locations/us-central1",
{"project": "123", "location": "us-central1"},
),
("projects/123/locations/", {}),
("projects/123", {}),
],
)
def test_extract_project_and_location_from_parent(parent: str, expected: tuple):
# Given a parent resource name, ensure correct project and location are extracted
assert expected == utils.extract_project_and_location_from_parent(parent)


@pytest.mark.usefixtures("google_auth_mock")
def test_wrapped_client():
test_client_info = gapic_v1.client_info.ClientInfo()
Expand Down

0 comments on commit 9e1d796

Please sign in to comment.