Skip to content

Commit

Permalink
feat: Create Vertex Experiment when uploading Tensorboard logs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626105964
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Apr 19, 2024
1 parent c0e7acc commit 2aa8d05
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,13 @@ def create(
cls._validate_run_id(run_id)

def _create_context():
print("Creating context")
print("run_id: " + run_id)
print("run_name: " + run_name)
print("state: " + state.name)
print("project: " + project)
print("location: " + location)
print("credentials: " + credentials)
with experiment_resources._SetLoggerLevel(resource):
return context.Context._create(
resource_id=run_id,
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _get_global_tensorboard(self) -> Optional[tensorboard_resource.Tensorboard]:
Returns:
tensorboard_resource.Tensorboard: the global TensorBoard instance.
"""
if self._global_tensorboard:
if self._global_tensorboard and hasattr(self._global_tensorboard, "resource_name"):
credentials, _ = google.auth.default()
if self.experiment and self.experiment._metadata_context.credentials:
credentials = self.experiment._metadata_context.credentials
Expand Down Expand Up @@ -468,6 +468,9 @@ def start_run(
)

else:
print('run_name: ', run)
print('experiment: ', self.experiment)
print('tensorboard: ', tensorboard)
self._experiment_run = experiment_run_resource.ExperimentRun.create(
run_name=run, experiment=self.experiment, tensorboard=tensorboard
)
Expand Down
66 changes: 48 additions & 18 deletions google/cloud/aiplatform/tensorboard/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
from google.cloud.aiplatform.compat.types import tensorboard_experiment
from google.cloud.aiplatform.compat.types import tensorboard_service
from google.cloud.aiplatform.compat.types import tensorboard_time_series
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import experiment_run_resource
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.tensorboard import logdir_loader
from google.cloud.aiplatform.tensorboard import tensorboard_resource
from google.cloud.aiplatform.tensorboard import upload_tracker
from google.cloud.aiplatform.tensorboard import uploader_constants
from google.cloud.aiplatform.tensorboard import uploader_utils
Expand Down Expand Up @@ -249,36 +253,50 @@ def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperim
"""
logger.info("Creating experiment")

tb_experiment = tensorboard_experiment.TensorboardExperiment(
description=self._description, display_name=self._experiment_display_name
)
# tb_experiment = tensorboard_experiment.TensorboardExperiment(
# description=self._description,
# display_name=self._experiment_display_name,
# )

try:
experiment = self._api.create_tensorboard_experiment(
parent=self._tensorboard_resource_name,
tensorboard_experiment=tb_experiment,
experiment = tensorboard_resource.TensorboardExperiment(
self._experiment_resource_name,
)
except exceptions.NotFound:
experiment = tensorboard_resource.TensorboardExperiment.create(
tensorboard_experiment_id=self._experiment_name,
display_name=self._experiment_display_name,
tensorboard_name=self._tensorboard_resource_name,
labels=constants._VERTEX_EXPERIMENT_TB_EXPERIMENT_LABEL,
)
self._is_brand_new_experiment = True
except exceptions.AlreadyExists:
logger.info("Creating experiment failed. Retrieving experiment.")
experiment_name = os.path.join(
self._tensorboard_resource_name, "experiments", self._experiment_name
)
experiment = self._api.get_tensorboard_experiment(name=experiment_name)
# except exceptions.AlreadyExists:
# logger.info("Creating experiment failed. Retrieving experiment.")
# experiment_name = os.path.join(
# self._tensorboard_resource_name, "experiments", self._experiment_name
# )
# experiment = self._api.get_tensorboard_experiment(name=experiment_name)
return experiment

def create_experiment(self):
"""Creates an Experiment for this upload session and returns the ID."""

experiment = self._create_or_get_experiment()
self._experiment = experiment
metadata._experiment_tracker.set_tensorboard(
tensorboard=self._tensorboard_resource_name
)
metadata._experiment_tracker.set_experiment(
experiment=self._experiment_name,
description=self._description,
backing_tensorboard=self._tensorboard_resource_name,
)
self._experiment_resource_name = f"{self._tensorboard_resource_name}/experiments/{self._experiment_name}"
self._create_or_get_experiment()
self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
self._experiment.name, self._api
self._experiment_resource_name, self._api
)

self._request_sender = _BatchedRequestSender(
self._experiment.name,
self._experiment_resource_name,
self._api,
allowed_plugins=self._allowed_plugins,
upload_limits=self._upload_limits,
Expand All @@ -294,7 +312,7 @@ def create_experiment(self):
# Update partials with experiment name
for sender in self._additional_senders.keys():
self._additional_senders[sender] = self._additional_senders[sender](
experiment_resource_name=self._experiment.name,
experiment_resource_name=self._experiment_resource_name,
)

self._dispatcher = _Dispatcher(
Expand Down Expand Up @@ -333,7 +351,7 @@ def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
)

def get_experiment_resource_name(self):
return self._experiment.name
return self._experiment_resource_name

def start_uploading(self):
"""Blocks forever to continuously upload data from the logdir.
Expand Down Expand Up @@ -814,6 +832,18 @@ def flush(self):
run_name,
tag_to_time_series_data,
) in self._run_to_tag_to_time_series_data.items():
# print('experiment_resource_id: ', self._experiment_resource_id)
# tensorboard = self._experiment_resource_id.split("/experiments/")[0]
# experiment = self._experiment_resource_id.split("/experiments/")[1]
# print('experiment: ', experiment)
# print('tensorboard: ', tensorboard)
metadata._experiment_tracker.start_run(run_name)
# print('experiment: ', experiment)
# experiment_run_resource.ExperimentRun.create(
# run_name=run_name,
# experiment=experiment,
# #tensorboard=tensorboard,
# )
r = tensorboard_service.WriteTensorboardRunDataRequest(
tensorboard_run=self._one_platform_resource_manager.get_run_resource_name(
run_name
Expand Down

0 comments on commit 2aa8d05

Please sign in to comment.