diff --git a/google/cloud/aiplatform/tensorboard/uploader_constants.py b/google/cloud/aiplatform/tensorboard/uploader_constants.py index 642f8cbcd5..6c82210f82 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_constants.py +++ b/google/cloud/aiplatform/tensorboard/uploader_constants.py @@ -13,16 +13,20 @@ from tensorboard.plugins.image import metadata as images_metadata from tensorboard.plugins.scalar import metadata as scalar_metadata from tensorboard.plugins.text import metadata as text_metadata - -ALLOWED_PLUGINS = [ - scalar_metadata.PLUGIN_NAME, - histogram_metadata.PLUGIN_NAME, - distribution_metadata.PLUGIN_NAME, - text_metadata.PLUGIN_NAME, - hparams_metadata.PLUGIN_NAME, - images_metadata.PLUGIN_NAME, - graphs_metadata.PLUGIN_NAME, -] +from tensorboard_plugin_profile import profile_plugin + +ALLOWED_PLUGINS = frozenset( + [ + scalar_metadata.PLUGIN_NAME, + histogram_metadata.PLUGIN_NAME, + distribution_metadata.PLUGIN_NAME, + text_metadata.PLUGIN_NAME, + hparams_metadata.PLUGIN_NAME, + images_metadata.PLUGIN_NAME, + graphs_metadata.PLUGIN_NAME, + profile_plugin.PLUGIN_NAME, + ] +) # Minimum length of a logdir polling cycle in seconds. Shorter cycles will # sleep to avoid spinning over the logdir, which isn't great for disks and can diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py index 31fc51c9af..e1b131ba83 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_main.py +++ b/google/cloud/aiplatform/tensorboard/uploader_main.py @@ -103,13 +103,11 @@ def main(argv): experiment_name, FLAGS.experiment_display_name, project_id, region ) - plugins = uploader_constants.ALLOWED_PLUGINS - if FLAGS.allowed_plugins: - plugins += [ - plugin - for plugin in FLAGS.allowed_plugins - if plugin not in uploader_constants.ALLOWED_PLUGINS - ] + plugins = ( + uploader_constants.ALLOWED_PLUGINS.union(FLAGS.allowed_plugins) + if FLAGS.allowed_plugins + else uploader_constants.ALLOWED_PLUGINS + ) tb_uploader = uploader.TensorBoardUploader( experiment_name=experiment_name, diff --git a/google/cloud/aiplatform/tensorboard/uploader_tracker.py b/google/cloud/aiplatform/tensorboard/uploader_tracker.py index 5cb46256f8..023c2b99b0 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_tracker.py +++ b/google/cloud/aiplatform/tensorboard/uploader_tracker.py @@ -263,13 +263,11 @@ def _create_uploader( api_client, tensorboard_resource_name, project ) - plugins = uploader_constants.ALLOWED_PLUGINS - if allowed_plugins: - plugins += [ - plugin - for plugin in allowed_plugins - if plugin not in uploader_constants.ALLOWED_PLUGINS - ] + plugins = ( + uploader_constants.ALLOWED_PLUGINS.union(allowed_plugins) + if allowed_plugins + else uploader_constants.ALLOWED_PLUGINS + ) tensorboard_uploader = TensorBoardUploader( experiment_name=tensorboard_experiment_name, diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py index f0e409fe63..31fefdbe5d 100644 --- a/tests/unit/aiplatform/test_uploader.py +++ b/tests/unit/aiplatform/test_uploader.py @@ -255,12 +255,18 @@ def _create_uploader( max_blob_size=max_blob_size, ) + plugins = ( + uploader_constants.ALLOWED_PLUGINS.union(allowed_plugins) + if allowed_plugins + else uploader_constants.ALLOWED_PLUGINS + ) + return uploader_lib.TensorBoardUploader( experiment_name=experiment_name, tensorboard_resource_name=tensorboard_resource_name, writer_client=writer_client, logdir=logdir, - allowed_plugins=allowed_plugins, + allowed_plugins=plugins, upload_limits=upload_limits, blob_storage_bucket=blob_storage_bucket, blob_storage_folder=blob_storage_folder, @@ -1239,7 +1245,7 @@ def create_time_series(tensorboard_time_series, parent=None): ) @patch.object(metadata, "_experiment_tracker", autospec=True) @patch.object(experiment_resources, "Experiment", autospec=True) - def test_add_profile_plugin( + def test_profile_plugin_included_by_default( self, experiment_resources_mock, experiment_tracker_mock, run_resource_mock ): experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME @@ -1259,7 +1265,6 @@ def test_add_profile_plugin( _create_mock_client(), logdir, one_shot=True, - allowed_plugins=frozenset(("profile",)), run_name_prefix=run_name, )