Skip to content

Commit

Permalink
feat: Enable Tensorboard profile plugin in all regions by default.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638377255
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed May 29, 2024
1 parent cb2f4aa commit 8a4a41a
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 27 deletions.
24 changes: 14 additions & 10 deletions google/cloud/aiplatform/tensorboard/uploader_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
from tensorboard.plugins.image import metadata as images_metadata
from tensorboard.plugins.scalar import metadata as scalar_metadata
from tensorboard.plugins.text import metadata as text_metadata

ALLOWED_PLUGINS = [
scalar_metadata.PLUGIN_NAME,
histogram_metadata.PLUGIN_NAME,
distribution_metadata.PLUGIN_NAME,
text_metadata.PLUGIN_NAME,
hparams_metadata.PLUGIN_NAME,
images_metadata.PLUGIN_NAME,
graphs_metadata.PLUGIN_NAME,
]
from tensorboard_plugin_profile import profile_plugin

ALLOWED_PLUGINS = frozenset(
[
scalar_metadata.PLUGIN_NAME,
histogram_metadata.PLUGIN_NAME,
distribution_metadata.PLUGIN_NAME,
text_metadata.PLUGIN_NAME,
hparams_metadata.PLUGIN_NAME,
images_metadata.PLUGIN_NAME,
graphs_metadata.PLUGIN_NAME,
profile_plugin.PLUGIN_NAME,
]
)

# Minimum length of a logdir polling cycle in seconds. Shorter cycles will
# sleep to avoid spinning over the logdir, which isn't great for disks and can
Expand Down
12 changes: 5 additions & 7 deletions google/cloud/aiplatform/tensorboard/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,11 @@ def main(argv):
experiment_name, FLAGS.experiment_display_name, project_id, region
)

plugins = uploader_constants.ALLOWED_PLUGINS
if FLAGS.allowed_plugins:
plugins += [
plugin
for plugin in FLAGS.allowed_plugins
if plugin not in uploader_constants.ALLOWED_PLUGINS
]
plugins = (
uploader_constants.ALLOWED_PLUGINS.union(FLAGS.allowed_plugins)
if FLAGS.allowed_plugins
else uploader_constants.ALLOWED_PLUGINS
)

tb_uploader = uploader.TensorBoardUploader(
experiment_name=experiment_name,
Expand Down
12 changes: 5 additions & 7 deletions google/cloud/aiplatform/tensorboard/uploader_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,11 @@ def _create_uploader(
api_client, tensorboard_resource_name, project
)

plugins = uploader_constants.ALLOWED_PLUGINS
if allowed_plugins:
plugins += [
plugin
for plugin in allowed_plugins
if plugin not in uploader_constants.ALLOWED_PLUGINS
]
plugins = (
uploader_constants.ALLOWED_PLUGINS.union(allowed_plugins)
if allowed_plugins
else uploader_constants.ALLOWED_PLUGINS
)

tensorboard_uploader = TensorBoardUploader(
experiment_name=tensorboard_experiment_name,
Expand Down
11 changes: 8 additions & 3 deletions tests/unit/aiplatform/test_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,18 @@ def _create_uploader(
max_blob_size=max_blob_size,
)

plugins = (
uploader_constants.ALLOWED_PLUGINS.union(allowed_plugins)
if allowed_plugins
else uploader_constants.ALLOWED_PLUGINS
)

return uploader_lib.TensorBoardUploader(
experiment_name=experiment_name,
tensorboard_resource_name=tensorboard_resource_name,
writer_client=writer_client,
logdir=logdir,
allowed_plugins=allowed_plugins,
allowed_plugins=plugins,
upload_limits=upload_limits,
blob_storage_bucket=blob_storage_bucket,
blob_storage_folder=blob_storage_folder,
Expand Down Expand Up @@ -1239,7 +1245,7 @@ def create_time_series(tensorboard_time_series, parent=None):
)
@patch.object(metadata, "_experiment_tracker", autospec=True)
@patch.object(experiment_resources, "Experiment", autospec=True)
def test_add_profile_plugin(
def test_profile_plugin_included_by_default(
self, experiment_resources_mock, experiment_tracker_mock, run_resource_mock
):
experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME
Expand All @@ -1259,7 +1265,6 @@ def test_add_profile_plugin(
_create_mock_client(),
logdir,
one_shot=True,
allowed_plugins=frozenset(("profile",)),
run_name_prefix=run_name,
)

Expand Down

0 comments on commit 8a4a41a

Please sign in to comment.