Skip to content

Commit

Permalink
fix: Fix bug that broke profiler with '0-rc2' tensorflow versions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 491683085
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Nov 29, 2022
1 parent 3e95e8d commit 8779df5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

"""A plugin to handle remote tensoflow profiler sessions for Vertex AI."""

from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
from google.cloud.aiplatform.training_utils.cloud_profiler import (
cloud_profiler_utils,
)

try:
import tensorflow as tf
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
from tensorboard_plugin_profile.profile_plugin import (
ProfilePlugin,
)
except ImportError as err:
raise ImportError(cloud_profiler_utils.import_error_msg) from err

Expand All @@ -36,10 +40,14 @@
import tensorboard.plugins.base_plugin as tensorboard_base_plugin
from werkzeug import Response

from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import (
profile_uploader,
)
from google.cloud.aiplatform.training_utils import environment_variables
from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import (
base_plugin,
)
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
tensorboard_api,
)
Expand Down Expand Up @@ -68,8 +76,7 @@ def _get_tf_versioning() -> Optional[Version]:
versioning = version.split(".")
if len(versioning) != 3:
return

return Version(int(versioning[0]), int(versioning[1]), int(versioning[2]))
return Version(int(versioning[0]), int(versioning[1]), versioning[2])


def _is_compatible_version(version: Version) -> bool:
Expand Down Expand Up @@ -228,7 +235,7 @@ def warn_tensorboard_env_var(var_name: str):
Required. The name of the missing environment variable.
"""
logging.warning(
f"Environment variable `{var_name}` must be set. " + _BASE_TB_ENV_WARNING
"Environment variable `%s` must be set. %s", var_name, _BASE_TB_ENV_WARNING
)


Expand Down
18 changes: 14 additions & 4 deletions tests/unit/aiplatform/test_cloud_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@
from google.api_core import exceptions
from google.cloud import aiplatform
from google.cloud.aiplatform import training_utils
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import (
profile_uploader,
)
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import (
base_plugin,
)
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
tf_profiler,
)
Expand Down Expand Up @@ -175,15 +179,21 @@ def tf_import_mock(name, *args, **kwargs):
def testCanInitializeTFVersion(self):
import tensorflow

with mock.patch.object(tensorflow, "__version__", return_value="1.2.3.4"):
with mock.patch.object(tensorflow, "__version__", "1.2.3.4"):
assert not TFProfiler.can_initialize()

def testCanInitializeOldTFVersion(self):
import tensorflow

with mock.patch.object(tensorflow, "__version__", return_value="2.3.0"):
with mock.patch.object(tensorflow, "__version__", "2.3.0"):
assert not TFProfiler.can_initialize()

def testCanInitializeRcTFVersion(self):
import tensorflow as tf

with mock.patch.object(tf, "__version__", "2.4.0-rc2"):
assert TFProfiler.can_initialize()

def testCanInitializeNoProfilePlugin(self):
orig_find_spec = importlib.util.find_spec

Expand Down

0 comments on commit 8779df5

Please sign in to comment.