feat: Improve get_experiment_df execution speed
PiperOrigin-RevId: 619596185
vertex-sdk-bot authored and Copybara-Service committed Mar 27, 2024
1 parent 57bb955 commit 2e56acc
Showing 3 changed files with 75 additions and 63 deletions.
56 changes: 35 additions & 21 deletions google/cloud/aiplatform/metadata/experiment_resources.py
@@ -16,6 +16,7 @@
 #
 
 import abc
+import concurrent.futures
 from dataclasses import dataclass
 import logging
 from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union
@@ -448,28 +449,41 @@ def get_data_frame(self) -> "pd.DataFrame":  # noqa: F821
         executions = execution.Execution.list(filter_str, **service_request_args)
 
         rows = []
-        for metadata_context in contexts:
-            row_dict = (
-                _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
-                    metadata_context.schema_title
-                ]
-                ._query_experiment_row(metadata_context)
-                .to_dict()
-            )
-            row_dict.update({"experiment_name": self.name})
-            rows.append(row_dict)
-
-        # backward compatibility
-        for metadata_execution in executions:
-            row_dict = (
-                _SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
-                    metadata_execution.schema_title
-                ]
-                ._query_experiment_row(metadata_execution)
-                .to_dict()
-            )
-            row_dict.update({"experiment_name": self.name})
-            rows.append(row_dict)
+        if contexts or executions:
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=max([len(contexts), len(executions)])
+            ) as executor:
+                futures = [
+                    executor.submit(
+                        _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
+                            metadata_context.schema_title
+                        ]._query_experiment_row,
+                        metadata_context,
+                    )
+                    for metadata_context in contexts
+                ]
+
+                # backward compatibility
+                futures.extend(
+                    executor.submit(
+                        _SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
+                            metadata_execution.schema_title
+                        ]._query_experiment_row,
+                        metadata_execution,
+                    )
+                    for metadata_execution in executions
+                )
+
+                for future in futures:
+                    try:
+                        row_dict = future.result().to_dict()
+                    except Exception as exc:
+                        raise ValueError(
+                            f"Failed to get experiment row for {self.name}"
+                        ) from exc
+                    else:
+                        row_dict.update({"experiment_name": self.name})
+                        rows.append(row_dict)
 
         df = pd.DataFrame(rows)
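This hunk is the heart of the speedup: the two sequential loops, each issuing one blocking `_query_experiment_row` call per context or execution, are replaced by a `ThreadPoolExecutor` that submits every query up front and gathers the results afterward. Because each query is I/O-bound (a metadata-service RPC), threads overlap the waits. A minimal, self-contained sketch of the submit-then-collect pattern — `fetch_rows` and `query_row` are illustrative names, not SDK API:

```python
import concurrent.futures
from typing import Callable, Dict, List, Sequence


def fetch_rows(items: Sequence, query_row: Callable) -> List[Dict]:
    """Issue one query per item concurrently; collect results in submission order."""
    rows: List[Dict] = []
    if items:  # ThreadPoolExecutor requires max_workers > 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(items)) as executor:
            futures = [executor.submit(query_row, item) for item in items]
            for future in futures:
                try:
                    row = future.result()  # re-raises any worker exception here
                except Exception as exc:
                    raise ValueError("Failed to get experiment row") from exc
                rows.append(row)
    return rows


# Toy usage: each "query" would normally be a blocking RPC.
print(fetch_rows(["run-1", "run-2"], lambda name: {"experiment_name": name}))
```

Iterating the futures in submission order (rather than via `as_completed`) keeps the resulting DataFrame rows in the same order the old sequential version produced.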
80 changes: 39 additions & 41 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
@@ -437,6 +437,23 @@ def get(
         except exceptions.NotFound:
             return None
 
+    def _initialize_experiment_run(
+        self,
+        node: Union[context.Context, execution.Execution],
+        experiment: Optional[experiment_resources.Experiment] = None,
+    ):
+        self._experiment = experiment
+        self._run_name = node.display_name
+        self._metadata_node = node
+        self._largest_step = None
+
+        if self._is_legacy_experiment_run():
+            self._metadata_metric_artifact = self._v1_get_metric_artifact()
+            self._backing_tensorboard_run = None
+        else:
+            self._metadata_metric_artifact = None
+            self._backing_tensorboard_run = self._lookup_tensorboard_run_artifact()
+
     @classmethod
     def list(
         cls,
@@ -495,33 +512,17 @@ def list(
 
         run_executions = execution.Execution.list(filter=filter_str, **metadata_args)
 
-        def _initialize_experiment_run(context: context.Context) -> ExperimentRun:
+        def _create_experiment_run(context: context.Context) -> ExperimentRun:
             this_experiment_run = cls.__new__(cls)
-            this_experiment_run._experiment = experiment
-            this_experiment_run._run_name = context.display_name
-            this_experiment_run._metadata_node = context
-
-            with experiment_resources._SetLoggerLevel(resource):
-                tb_run = this_experiment_run._lookup_tensorboard_run_artifact()
-            if tb_run:
-                this_experiment_run._backing_tensorboard_run = tb_run
-            else:
-                this_experiment_run._backing_tensorboard_run = None
-
-            this_experiment_run._largest_step = None
+            this_experiment_run._initialize_experiment_run(context, experiment)
 
             return this_experiment_run
 
-        def _initialize_v1_experiment_run(
+        def _create_v1_experiment_run(
             execution: execution.Execution,
         ) -> ExperimentRun:
             this_experiment_run = cls.__new__(cls)
-            this_experiment_run._experiment = experiment
-            this_experiment_run._run_name = execution.display_name
-            this_experiment_run._metadata_node = execution
-            this_experiment_run._metadata_metric_artifact = (
-                this_experiment_run._v1_get_metric_artifact()
-            )
+            this_experiment_run._initialize_experiment_run(execution, experiment)
 
             return this_experiment_run

@@ -530,13 +531,13 @@ def _initialize_v1_experiment_run(
             max_workers=max([len(run_contexts), len(run_executions)])
         ) as executor:
             submissions = [
-                executor.submit(_initialize_experiment_run, context)
+                executor.submit(_create_experiment_run, context)
                 for context in run_contexts
             ]
             experiment_runs = [submission.result() for submission in submissions]
 
             submissions = [
-                executor.submit(_initialize_v1_experiment_run, execution)
+                executor.submit(_create_v1_experiment_run, execution)
                 for execution in run_executions
             ]

Expand All @@ -560,30 +561,20 @@ def _query_experiment_row(
Experiment run row that represents this run.
"""
this_experiment_run = cls.__new__(cls)
this_experiment_run._metadata_node = node
this_experiment_run._initialize_experiment_run(node)

row = experiment_resources._ExperimentRow(
experiment_run_type=node.schema_title,
name=node.display_name,
)

if isinstance(node, context.Context):
this_experiment_run._backing_tensorboard_run = (
this_experiment_run._lookup_tensorboard_run_artifact()
)
row.params = node.metadata[constants._PARAM_KEY]
row.metrics = node.metadata[constants._METRIC_KEY]
row.time_series_metrics = (
this_experiment_run._get_latest_time_series_metric_columns()
)
row.state = node.metadata[constants._STATE_KEY]
else:
this_experiment_run._metadata_metric_artifact = (
this_experiment_run._v1_get_metric_artifact()
)
row.params = node.metadata
row.metrics = this_experiment_run._metadata_metric_artifact.metadata
row.state = node.state.name
row.params = this_experiment_run.get_params()
row.metrics = this_experiment_run.get_metrics()
row.state = this_experiment_run.get_state()
row.time_series_metrics = (
this_experiment_run._get_latest_time_series_metric_columns()
)

return row

def _get_logged_pipeline_runs(self) -> List[context.Context]:
@@ -659,7 +650,7 @@ def log(
 
     @staticmethod
     def _validate_run_id(run_id: str):
-        """Validates the run id
+        """Validates the run id.
 
         Args:
             run_id(str): Required. The run id to validate.
@@ -1455,6 +1446,13 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
         else:
             return self._metadata_node.metadata[constants._METRIC_KEY]
 
+    def get_state(self) -> gca_execution.Execution.State:
+        """The state of this run."""
+        if self._is_legacy_experiment_run():
+            return self._metadata_node.state.name
+        else:
+            return self._metadata_node.metadata[constants._STATE_KEY]
+
     @_v1_not_supported
     def get_classification_metrics(self) -> List[Dict[str, Union[str, List]]]:
         """Get all the classification metrics logged to this run.
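The bulk of this file's change is a de-duplication: three call sites (the two local factories inside `list()` and `_query_experiment_row()`) previously each hand-assigned the run's private fields; they now delegate to the new `_initialize_experiment_run()`, which branches once on legacy-v1 versus context-backed runs. `_query_experiment_row` then reads params, metrics, and state through the accessors (`get_params()`, `get_metrics()`, and the new `get_state()`), which do the same legacy branching internally, so row building no longer needs an `isinstance` check. A toy sketch of the pattern — `FakeNode` and `FakeRun` are hypothetical stand-ins for the SDK classes, not real API:

```python
from typing import Optional


class FakeNode:
    """Illustrative stand-in for context.Context / execution.Execution."""

    def __init__(self, display_name: str, legacy: bool):
        self.display_name = display_name
        self.legacy = legacy


class FakeRun:
    """Illustrative stand-in for ExperimentRun: one shared initializer."""

    def _initialize_experiment_run(
        self, node: FakeNode, experiment: Optional[str] = None
    ) -> None:
        # All field setup lives in one place instead of being repeated per call site.
        self._experiment = experiment
        self._run_name = node.display_name
        self._metadata_node = node
        self._largest_step = None
        # Branch once on the run format (legacy v1 vs. context-backed).
        if node.legacy:
            self._metadata_metric_artifact = "v1-metric-artifact"  # placeholder
            self._backing_tensorboard_run = None
        else:
            self._metadata_metric_artifact = None
            self._backing_tensorboard_run = "tb-run-artifact"  # placeholder

    @classmethod
    def _create(cls, node: FakeNode, experiment: Optional[str] = None) -> "FakeRun":
        run = cls.__new__(cls)  # bypass __init__, which would hit the backend
        run._initialize_experiment_run(node, experiment)
        return run


run = FakeRun._create(FakeNode("run-1", legacy=False), experiment="my-experiment")
```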
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/metadata/metadata.py
@@ -780,7 +780,7 @@ def get_experiment_df(
         aiplatform.log_params({'learning_rate': 0.2})
         aiplatform.log_metrics({'accuracy': 0.95})
-        aiplatform.get_experiments_df()
+        aiplatform.get_experiment_df()
         ```
 
     Will result in the following DataFrame:
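The docstring fix corrects a typo: the documented method is `get_experiment_df`, not `get_experiments_df`. For reference, a condensed version of the documented flow — project and location values are placeholders:

```python
from google.cloud import aiplatform

aiplatform.init(
    experiment="my-experiment", project="my-project", location="us-central1"
)
aiplatform.start_run(run="run-1")
aiplatform.log_params({"learning_rate": 0.2})
aiplatform.log_metrics({"accuracy": 0.95})
aiplatform.end_run()

# The method this commit documents and speeds up:
df = aiplatform.get_experiment_df()
print(df.head())
```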
