Skip to content

Commit

Permalink
feat: Change the Metadata SDK _Context class to an external class (#1519)
Browse files (browse the repository at this point in the history)

* feat: Change the Metadata SDK _Context class to an external class

* Add base schema class for context

* Add additional context schema types

* Add additional context schema types

* Add create method to Context.

* Fix unit test failure.

* add unit tests

* fix lint issue

* Add Context to root __init__.

* correct import path
  • Loading branch information
SinaChavoshi committed Jul 21, 2022
1 parent fd55daf commit 95b107c
Show file tree
Hide file tree
Showing 11 changed files with 575 additions and 46 deletions.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/__init__.py
Expand Up @@ -95,6 +95,7 @@
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
Artifact = metadata.artifact.Artifact
Execution = metadata.execution.Execution
Context = metadata.context.Context


__all__ = (
Expand Down
154 changes: 152 additions & 2 deletions google/cloud/aiplatform/metadata/context.py
Expand Up @@ -19,6 +19,8 @@

import proto

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.metadata import utils as metadata_utils
Expand All @@ -31,10 +33,11 @@
)
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource


class _Context(resource._Resource):
class Context(resource._Resource):
"""Metadata Context resource for Vertex AI"""

_resource_noun = "contexts"
Expand Down Expand Up @@ -81,6 +84,153 @@ def get_artifacts(self) -> List[artifact.Artifact]:
credentials=self.credentials,
)

@classmethod
def create(
    cls,
    schema_title: str,
    *,
    resource_id: Optional[str] = None,
    display_name: Optional[str] = None,
    schema_version: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[Dict] = None,
    metadata_store_id: Optional[str] = "default",
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> "Context":
    """Creates a new Metadata Context.

    Public entry point: every argument is forwarded unchanged, by keyword,
    to the private ``_create`` classmethod.

    Args:
        schema_title (str):
            Required. schema_title identifies the schema title used by the Context.
            Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
        resource_id (str):
            Optional. The <resource_id> portion of the Context name, which is
            globally unique within a metadataStore. The name has the format:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Context.
        schema_version (str):
            Optional. schema_version specifies the version used by the Context.
            If not set, defaults to use the latest version.
        description (str):
            Optional. Describes the purpose of the Context to be created.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Context.
        metadata_store_id (str):
            Optional. The <metadata_store_id> portion of the resource name with
            the format:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>
            If not provided, the MetadataStore's ID will be set to "default".
        project (str):
            Optional. Project used to create this Context. Overrides project set in
            aiplatform.init.
        location (str):
            Optional. Location used to create this Context. Overrides location set in
            aiplatform.init.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials used to create this Context. Overrides
            credentials set in aiplatform.init.

    Returns:
        Context: Instantiated representation of the managed Metadata Context.
    """
    # Collect the arguments once so the delegation below is a single call.
    creation_kwargs = {
        "resource_id": resource_id,
        "schema_title": schema_title,
        "display_name": display_name,
        "schema_version": schema_version,
        "description": description,
        "metadata": metadata,
        "metadata_store_id": metadata_store_id,
        "project": project,
        "location": location,
        "credentials": credentials,
    }
    return cls._create(**creation_kwargs)

# TODO() refactor code to move _create to _Resource class.
@classmethod
def _create(
    cls,
    resource_id: Optional[str],
    schema_title: str,
    display_name: Optional[str] = None,
    schema_version: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[Dict] = None,
    metadata_store_id: Optional[str] = "default",
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> "Context":
    """Creates a new Metadata Context resource.

    Builds the parent MetadataStore resource name, issues the create call
    through ``_create_resource`` and wraps the returned resource in an
    instance built with ``_empty_constructor``.

    Args:
        resource_id (str):
            The <resource_id> portion of the resource name with
            the format:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
            May be None: the public ``create`` method forwards None when the
            caller supplies no ID, and the value is handed to
            ``_create_resource`` unchanged.
            NOTE(review): presumably the service then assigns an ID - confirm
            against the CreateContext API reference.
        schema_title (str):
            Required. schema_title identifies the schema title used by the resource.
        display_name (str):
            Optional. The user-defined name of the resource.
        schema_version (str):
            Optional. schema_version specifies the version used by the resource.
            If not set, defaults to use the latest version.
        description (str):
            Optional. Describes the purpose of the resource to be created.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the resource.
        metadata_store_id (str):
            Optional. The <metadata_store_id> portion of the resource name with
            the format:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
            If not provided, the MetadataStore's ID will be set to "default".
        project (str):
            Optional. Project used to create this resource. Overrides project set in
            aiplatform.init.
        location (str):
            Optional. Location used to create this resource. Overrides location set in
            aiplatform.init.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials used to create this resource. Overrides
            credentials set in aiplatform.init.

    Returns:
        Context: Instantiated representation of the managed Metadata resource.
    """
    api_client = cls._instantiate_client(location=location, credentials=credentials)

    # Expand the (possibly short) metadata store ID into the full parent
    # resource name under which the new context is created.
    parent = utils.full_resource_name(
        resource_name=metadata_store_id,
        resource_noun=metadata_store._MetadataStore._resource_noun,
        parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
        format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
        project=project,
        location=location,
    )

    gca_context = cls._create_resource(
        client=api_client,
        parent=parent,
        resource_id=resource_id,
        schema_title=schema_title,
        display_name=display_name,
        schema_version=schema_version,
        description=description,
        metadata=metadata,
    )

    # Build the wrapper without going through __init__ and attach the
    # resource that the create call returned.
    context = cls._empty_constructor(
        project=project, location=location, credentials=credentials
    )
    context._gca_resource = gca_context

    return context

@classmethod
def _create_resource(
cls,
Expand Down Expand Up @@ -147,7 +297,7 @@ def _list_resources(
)
return client.list_contexts(request=list_request)

def add_context_children(self, contexts: List["_Context"]):
def add_context_children(self, contexts: List["Context"]):
"""Adds the provided contexts as children of this context.
Args:
Expand Down
32 changes: 16 additions & 16 deletions google/cloud/aiplatform/metadata/experiment_resources.py
Expand Up @@ -119,13 +119,13 @@ def __init__(
)

with _SetLoggerLevel(resource):
experiment_context = context._Context(**metadata_args)
experiment_context = context.Context(**metadata_args)
self._validate_experiment_context(experiment_context)

self._metadata_context = experiment_context

@staticmethod
def _validate_experiment_context(experiment_context: context._Context):
def _validate_experiment_context(experiment_context: context.Context):
"""Validates this context is an experiment context.
Args:
Expand All @@ -146,7 +146,7 @@ def _validate_experiment_context(experiment_context: context._Context):
)

@staticmethod
def _is_tensorboard_experiment(context: context._Context) -> bool:
def _is_tensorboard_experiment(context: context.Context) -> bool:
"""Returns True if Experiment is a Tensorboard Experiment created by CustomJob."""
return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata

Expand Down Expand Up @@ -192,7 +192,7 @@ def create(
)

with _SetLoggerLevel(resource):
experiment_context = context._Context._create(
experiment_context = context.Context._create(
resource_id=experiment_name,
display_name=experiment_name,
description=description,
Expand Down Expand Up @@ -248,7 +248,7 @@ def get_or_create(
)

with _SetLoggerLevel(resource):
experiment_context = context._Context.get_or_create(
experiment_context = context.Context.get_or_create(
resource_id=experiment_name,
display_name=experiment_name,
description=description,
Expand Down Expand Up @@ -303,7 +303,7 @@ def list(
)

with _SetLoggerLevel(resource):
experiment_contexts = context._Context.list(
experiment_contexts = context.Context.list(
filter=filter_str,
project=project,
location=location,
Expand Down Expand Up @@ -341,7 +341,7 @@ def delete(self, *, delete_backing_tensorboard_runs: bool = False):
runs under this experiment that we used to store time series metrics.
"""

experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context._Context][
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
constants.SYSTEM_EXPERIMENT_RUN
].list(experiment=self)
for experiment_run in experiment_runs:
Expand Down Expand Up @@ -380,11 +380,11 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821

filter_str = metadata_utils._make_filter_string(
schema_title=sorted(
list(_SUPPORTED_LOGGABLE_RESOURCES[context._Context].keys())
list(_SUPPORTED_LOGGABLE_RESOURCES[context.Context].keys())
),
parent_contexts=[self._metadata_context.resource_name],
)
contexts = context._Context.list(filter_str, **service_request_args)
contexts = context.Context.list(filter_str, **service_request_args)

filter_str = metadata_utils._make_filter_string(
schema_title=list(
Expand All @@ -398,7 +398,7 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
rows = []
for metadata_context in contexts:
row_dict = (
_SUPPORTED_LOGGABLE_RESOURCES[context._Context][
_SUPPORTED_LOGGABLE_RESOURCES[context.Context][
metadata_context.schema_title
]
._query_experiment_row(metadata_context)
Expand Down Expand Up @@ -568,7 +568,7 @@ class _VertexResourceWithMetadata(NamedTuple):
"""Represents a resource coupled with it's metadata representation"""

resource: base.VertexAiResourceNoun
metadata: Union[artifact.Artifact, execution.Execution, context._Context]
metadata: Union[artifact.Artifact, execution.Execution, context.Context]


class _ExperimentLoggableSchema(NamedTuple):
Expand All @@ -581,7 +581,7 @@ class _ExperimentLoggableSchema(NamedTuple):
"""

title: str
type: Union[Type[context._Context], Type[execution.Execution]] = context._Context
type: Union[Type[context.Context], Type[execution.Execution]] = context.Context


class _ExperimentLoggable(abc.ABC):
Expand Down Expand Up @@ -618,7 +618,7 @@ class PipelineJob(..., experiment_loggable_schemas=
_SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls

@abc.abstractmethod
def _get_context(self) -> context._Context:
def _get_context(self) -> context.Context:
"""Should return the metadata context that represents this resource.
The subclass should enforce this context exists.
Expand All @@ -631,7 +631,7 @@ def _get_context(self) -> context._Context:
@classmethod
@abc.abstractmethod
def _query_experiment_row(
cls, node: Union[context._Context, execution.Execution]
cls, node: Union[context.Context, execution.Execution]
) -> _ExperimentRow:
"""Should return parameters and metrics for this resource as a run row.
Expand Down Expand Up @@ -716,6 +716,6 @@ def _associate_to_experiment(self, experiment: Union[str, Experiment]):
# Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun
# Execution -> 'system.Run' -> aiplatform.ExperimentRun
_SUPPORTED_LOGGABLE_RESOURCES: Dict[
Union[Type[context._Context], Type[execution.Execution]],
Union[Type[context.Context], Type[execution.Execution]],
Dict[str, _ExperimentLoggable],
] = {execution.Execution: dict(), context._Context: dict()}
] = {execution.Execution: dict(), context.Context: dict()}

0 comments on commit 95b107c

Please sign in to comment.