Skip to content

Commit

Permalink
fix: list method for MLMD schema classes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 506383196
  • Loading branch information
jaycee-li authored and Copybara-Service committed Feb 1, 2023
1 parent 076308f commit 2401a1d
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 27 deletions.
26 changes: 3 additions & 23 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,34 +1418,14 @@ def get_experiment_models(self) -> List[google_artifact_schema.ExperimentModel]:
Returns:
List of ExperimentModel instances associated this run.
"""
# TODO(b/264194064) Replace this by ExperimentModel.list
artifact_list = artifact.Artifact.list(
filter=metadata_utils._make_filter_string(
in_context=[self.resource_name],
schema_title=google_artifact_schema.ExperimentModel.schema_title,
),
experiment_model_list = google_artifact_schema.ExperimentModel.list(
filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
project=self.project,
location=self.location,
credentials=self.credentials,
)

res = []
for model_artifact in artifact_list:
experiment_model = google_artifact_schema.ExperimentModel(
framework_name="",
framework_version="",
model_file="",
uri="",
)
experiment_model._gca_resource = model_artifact._gca_resource
experiment_model.project = model_artifact.project
experiment_model.location = model_artifact.location
experiment_model.credentials = model_artifact.credentials
experiment_model.api_client = model_artifact.api_client

res.append(experiment_model)

return res
return experiment_model_list

@_v1_not_supported
def associate_execution(self, execution: execution.Execution):
Expand Down
61 changes: 59 additions & 2 deletions google/cloud/aiplatform/metadata/schema/base_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@

import abc

from typing import Any, Optional, Dict
from typing import Any, Optional, Dict, List

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
Expand Down Expand Up @@ -202,6 +202,63 @@ def create(
self._init_with_resource_name(artifact_name=new_artifact_instance.resource_name)
return self

@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["BaseArtifactSchema"]:
"""List all the Artifact resources with a particular schema.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
A list of artifact resources with a particular schema.
"""
schema_filter = f'schema_title="{cls.schema_title}"'
if filter:
filter = f"{filter} AND {schema_filter}"
else:
filter = schema_filter

return super().list(
filter=filter,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)

def sync_resource(self):
"""Syncs local resource with the resource in metadata store.
Expand Down
59 changes: 58 additions & 1 deletion google/cloud/aiplatform/metadata/schema/base_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -160,6 +160,63 @@ def create(
self._init_with_resource_name(context_name=new_context.resource_name)
return self

@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["BaseContextSchema"]:
"""List all the Context resources with a particular schema.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
A list of context resources with a particular schema.
"""
schema_filter = f'schema_title="{cls.schema_title}"'
if filter:
filter = f"{filter} AND {schema_filter}"
else:
filter = schema_filter

return super().list(
filter=filter,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)

def add_artifacts_and_executions(
self,
artifact_resource_names: Optional[Sequence[str]] = None,
Expand Down
59 changes: 58 additions & 1 deletion google/cloud/aiplatform/metadata/schema/base_execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -170,6 +170,63 @@ def create(
)
return self

@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["BaseExecutionSchema"]:
"""List all the Execution resources with a particular schema.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
A list of execution resources with a particular schema.
"""
schema_filter = f'schema_title="{cls.schema_title}"'
if filter:
filter = f"{filter} AND {schema_filter}"
else:
filter = schema_filter

return super().list(
filter=filter,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)

def start_execution(
self,
*,
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/aiplatform/test_metadata_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,27 @@ def create_context_mock():
yield create_context_mock


@pytest.fixture
def list_artifacts_mock():
with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock:
list_artifacts_mock.return_value = []
yield list_artifacts_mock


@pytest.fixture
def list_executions_mock():
with patch.object(MetadataServiceClient, "list_executions") as list_executions_mock:
list_executions_mock.return_value = []
yield list_executions_mock


@pytest.fixture
def list_contexts_mock():
with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock:
list_contexts_mock.return_value = []
yield list_contexts_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestMetadataBaseArtifactSchema:
def setup_method(self):
Expand Down Expand Up @@ -369,6 +390,20 @@ class TestArtifact(base_artifact.BaseArtifactSchema):
"sdk_command/aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name"
]

def test_list_artifacts(self, list_artifacts_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

class TestArtifact(base_artifact.BaseArtifactSchema):
schema_title = _TEST_SCHEMA_TITLE

TestArtifact.list()
list_artifacts_mock.assert_called_once_with(
request={
"parent": f"{_TEST_PARENT}/metadataStores/default",
"filter": f'schema_title="{_TEST_SCHEMA_TITLE}"',
}
)


@pytest.mark.usefixtures("google_auth_mock")
class TestMetadataBaseExecutionSchema:
Expand Down Expand Up @@ -563,6 +598,20 @@ class TestExecution(base_execution.BaseExecutionSchema):
"sdk_command/aiplatform.metadata.schema.base_execution.BaseExecutionSchema._init_with_resource_name"
]

def test_list_executions(self, list_executions_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

class TestExecution(base_execution.BaseExecutionSchema):
schema_title = _TEST_SCHEMA_TITLE

TestExecution.list()
list_executions_mock.assert_called_once_with(
request={
"parent": f"{_TEST_PARENT}/metadataStores/default",
"filter": f'schema_title="{_TEST_SCHEMA_TITLE}"',
}
)


@pytest.mark.usefixtures("google_auth_mock")
class TestMetadataBaseContextSchema:
Expand Down Expand Up @@ -730,6 +779,20 @@ class TestContext(base_context.BaseContextSchema):
"sdk_command/aiplatform.metadata.schema.base_context.BaseContextSchema._init_with_resource_name"
]

def test_list_contexts(self, list_contexts_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

class TestContext(base_context.BaseContextSchema):
schema_title = _TEST_SCHEMA_TITLE

TestContext.list()
list_contexts_mock.assert_called_once_with(
request={
"parent": f"{_TEST_PARENT}/metadataStores/default",
"filter": f'schema_title="{_TEST_SCHEMA_TITLE}"',
}
)


@pytest.mark.usefixtures("google_auth_mock")
class TestMetadataGoogleArtifactSchema:
Expand Down

0 comments on commit 2401a1d

Please sign in to comment.