Skip to content

Commit

Permalink
feat: Enable continuous upload for profile logs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624258810
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Apr 12, 2024
1 parent 894c73f commit f05924d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def _profile_dir(self, run_name: str) -> str:
Returns:
Full path for run name.
"""
if run_name is None:
return os.path.join(self._logdir, self.PROFILE_PATH)
return os.path.join(self._logdir, run_name, self.PROFILE_PATH)

def send_request(self, run_name: str):
Expand All @@ -171,7 +173,7 @@ def send_request(self, run_name: str):
"""

if not self._is_valid_event(run_name):
logger.warning("No such profile run for %s", run_name)
logger.debug("No such profile run for %s", run_name)
return

# Create a profiler loader if one is not created.
Expand Down
4 changes: 0 additions & 4 deletions google/cloud/aiplatform/tensorboard/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,6 @@ def create_experiment(self):
def _should_profile(self) -> bool:
"""Indicate if profile plugin should be enabled."""
if "profile" in self._allowed_plugins:
if not self._one_shot:
raise ValueError(
"Profile plugin currently only supported for one shot."
)
logger.info("Profile plugin is enabled.")
return True
return False
Expand Down
63 changes: 56 additions & 7 deletions tests/unit/aiplatform/test_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@
)
)

_SCALARS_HISTOGRAMS_AND_PROFILE = frozenset(
(
scalars_metadata.PLUGIN_NAME,
"profile",
)
)


# Sentinel for `_create_*` helpers, for arguments for which we want to
# supply a default other than the `None` used by the code under test.
_USE_DEFAULT = object()
Expand Down Expand Up @@ -1095,7 +1103,23 @@ def test_thread_continuously_uploads(self):

logdir = self.get_temp_dir()
mock_client = _create_mock_client()
uploader = _create_uploader(mock_client, logdir)
builder = _create_dispatcher(
experiment_resource_name=_TEST_ONE_PLATFORM_EXPERIMENT_NAME,
api=mock_client,
allowed_plugins=_SCALARS_HISTOGRAMS_AND_PROFILE,
logdir=logdir,
)
mock_rate_limiter = mock.create_autospec(util.RateLimiter)
mock_bucket = _create_mock_blob_storage()

uploader = _create_uploader(
mock_client,
logdir,
allowed_plugins=_SCALARS_HISTOGRAMS_AND_PROFILE,
rpc_rate_limiter=mock_rate_limiter,
blob_storage_bucket=mock_bucket,
)
uploader._dispatcher = builder
uploader.create_experiment()

# Convenience helpers for constructing expected requests.
Expand All @@ -1104,7 +1128,7 @@ def test_thread_continuously_uploads(self):
scalar = tensorboard_data.Scalar

# Directory with scalar data
writer = FileWriter(logdir)
writer = FileWriter(os.path.join(logdir, "a"))
metadata = summary_pb2.SummaryMetadata(
plugin_data=summary_pb2.SummaryMetadata.PluginData(
plugin_name="scalars", content=b"12345"
Expand All @@ -1121,18 +1145,43 @@ def test_thread_continuously_uploads(self):
value_metadata=metadata,
)
writer.flush()
writer_a = FileWriter(os.path.join(logdir, "a"))
writer_a = FileWriter(os.path.join(logdir, "b"))
writer_a.add_test_summary("qux", simple_value=9.0, step=2)
writer_a.flush()

# Directory with profile data
prof_run_name = "2024_04_04_04_24_24"
prof_path = os.path.join(
logdir, profile_uploader.ProfileRequestSender.PROFILE_PATH
)
os.makedirs(prof_path)
run_path = os.path.join(prof_path, prof_run_name)
os.makedirs(run_path)
tempfile.NamedTemporaryFile(
prefix="c", suffix=".xplane.pb", dir=run_path, delete=False
)
self.assertNotEmpty(os.listdir(run_path))

uploader_thread = threading.Thread(target=uploader.start_uploading)
uploader_thread.start()
time.sleep(5)
self.assertEqual(3, mock_client.create_tensorboard_time_series.call_count)

# Check create_time_series calls
self.assertEqual(4, mock_client.create_tensorboard_time_series.call_count)
call_args_list = mock_client.create_tensorboard_time_series.call_args_list
request = call_args_list[1][1]["tensorboard_time_series"]
self.assertEqual("scalars", request.plugin_name)
self.assertEqual(b"12345", request.plugin_data)
request1, request2, request3, request4 = (
call_args_list[0][1]["tensorboard_time_series"],
call_args_list[1][1]["tensorboard_time_series"],
call_args_list[2][1]["tensorboard_time_series"],
call_args_list[3][1]["tensorboard_time_series"],
)
self.assertEqual("scalars", request1.plugin_name)
self.assertEqual("scalars", request2.plugin_name)
self.assertEqual(b"12345", request2.plugin_data)
self.assertEqual("scalars", request3.plugin_name)
self.assertEqual("profile", request4.plugin_name)

# Check write_tensorboard_experiment_data calls
self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count)
call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list
request1, request2 = (
Expand Down

0 comments on commit f05924d

Please sign in to comment.