Skip to content

Commit

Permalink
fix: Update get_experiment_df to pass Experiment and allow empty metr…
Browse files Browse the repository at this point in the history
…ics.

PiperOrigin-RevId: 634867806
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed May 17, 2024
1 parent abacda6 commit de5d0f3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 10 deletions.
6 changes: 4 additions & 2 deletions google/cloud/aiplatform/metadata/experiment_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@ def get_data_frame(
metadata_context.schema_title
]._query_experiment_row,
metadata_context,
include_time_series,
experiment=self,
include_time_series=include_time_series,
)
for metadata_context in contexts
]
Expand All @@ -494,7 +495,8 @@ def get_data_frame(
metadata_execution.schema_title
]._query_experiment_row,
metadata_execution,
include_time_series,
experiment=self,
include_time_series=include_time_series,
)
for metadata_execution in executions
)
Expand Down
16 changes: 12 additions & 4 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ def _initialize_experiment_run(
self._metadata_metric_artifact = self._v1_get_metric_artifact()
if not self._is_legacy_experiment_run() and lookup_tensorboard_run:
self._backing_tensorboard_run = self._lookup_tensorboard_run_artifact()
if not self._backing_tensorboard_run:
self._assign_to_experiment_backing_tensorboard()

@classmethod
def list(
Expand Down Expand Up @@ -553,13 +555,16 @@ def _create_v1_experiment_run(
def _query_experiment_row(
cls,
node: Union[context.Context, execution.Execution],
include_time_series: Optional[bool] = True,
experiment: Optional[experiment_resources.Experiment] = None,
include_time_series: bool = True,
) -> experiment_resources._ExperimentRow:
"""Retrieves the runs metric and parameters into an experiment run row.
Args:
node (Union[context._Context, execution.Execution]):
Required. Metadata node instance that represents this run.
experiment:
Optional. Experiment associated with this run.
include_time_series (bool):
Optional. Whether or not to include time series metrics in df.
Default is True.
Expand All @@ -568,7 +573,7 @@ def _query_experiment_row(
"""
this_experiment_run = cls.__new__(cls)
this_experiment_run._initialize_experiment_run(
node, lookup_tensorboard_run=include_time_series
node, experiment=experiment, lookup_tensorboard_run=include_time_series
)

row = experiment_resources._ExperimentRow(
Expand Down Expand Up @@ -620,8 +625,11 @@ def _get_latest_time_series_metric_columns(self) -> Dict[str, Union[float, int]]
return {
display_name: data.values[-1].scalar.value
for display_name, data in time_series_metrics.items()
if data.value_type
== gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
if (
data.values
and data.value_type
== gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
)
}
return {}

Expand Down
13 changes: 10 additions & 3 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,17 +887,24 @@ def _get_context(self) -> context.Context:

@classmethod
def _query_experiment_row(
cls, node: context.Context, include_time_series: Optional[bool] = True
cls,
node: context.Context,
experiment: Optional[experiment_resources.Experiment] = None,
include_time_series: bool = True,
) -> experiment_resources._ExperimentRow:
"""Queries the PipelineJob metadata as an experiment run parameter and metric row.
Parameters are retrieved from the system.Run Execution.metadata of the PipelineJob.
Parameters are retrieved from the system.Run Execution.metadata of the
PipelineJob.
Metrics are retrieved from the system.Metric Artifacts.metadata produced by this PipelineJob.
Metrics are retrieved from the system.Metric Artifacts.metadata produced
by this PipelineJob.
Args:
node (context._Context):
Required. System.PipelineRun context that represents a PipelineJob Run.
experiment:
Optional. Experiment associated with this run.
include_time_series (bool):
Optional. Whether or not to include time series metrics in df.
Default is True.
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/aiplatform/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import copy
from importlib import reload
from unittest import mock
from unittest import TestCase, mock
from unittest.mock import patch, call

import numpy as np
Expand Down Expand Up @@ -2048,6 +2048,32 @@ def test_log_pipeline_job(
],
)

@pytest.mark.usefixtures(
"get_experiment_mock",
)
def test_get_experiment_df_passes_experiment_variable(
self,
list_context_mock_for_experiment_dataframe_mock,
list_artifact_mock_for_experiment_dataframe,
list_executions_mock_for_experiment_dataframe,
get_tensorboard_run_artifact_mock,
get_tensorboard_run_mock,
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

with patch.object(
experiment_run_resource.ExperimentRun, "_query_experiment_row"
) as query_experiment_row_mock:
row = experiment_resources._ExperimentRow(
experiment_run_type=constants.SYSTEM_EXPERIMENT_RUN,
name=_TEST_EXPERIMENT,
)
query_experiment_row_mock.return_value = row

aiplatform.get_experiment_df(_TEST_EXPERIMENT)
_, kwargs = query_experiment_row_mock.call_args_list[0]
TestCase.assertTrue(self, kwargs["experiment"].name == _TEST_EXPERIMENT)

@pytest.mark.usefixtures(
"get_experiment_mock",
"list_tensorboard_time_series_mock",
Expand Down

0 comments on commit de5d0f3

Please sign in to comment.