Skip to content

Commit

Permalink
fix: Register TensorFlow models from Ray checkpoints for more recent …
Browse files Browse the repository at this point in the history
…TensorFlow version, addressing the deprecation of SavedModel format in keras 3

PiperOrigin-RevId: 628562509
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Apr 27, 2024
1 parent 9809a3a commit 1341e2c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.preview.vertex_ray.predict.util import constants
from google.cloud.aiplatform.preview.vertex_ray.predict.util import (
predict_utils,
)
Expand All @@ -44,6 +45,7 @@ def register_tensorflow(
artifact_uri: Optional[str] = None,
_model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
display_name: Optional[str] = None,
tensorflow_version: Optional[str] = None,
**kwargs,
) -> aiplatform.Model:
"""Uploads a Ray Tensorflow Checkpoint as Tensorflow Model to Model Registry.
Expand Down Expand Up @@ -79,6 +81,11 @@ def create_model():
display_name (str):
Optional. The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
tensorflow_version (str):
Optional. The version of the Tensorflow serving container.
Supported versions:
https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
If the version is not specified, the latest version is used.
**kwargs:
Any kwargs will be passed to aiplatform.Model registration.
Expand All @@ -89,6 +96,9 @@ def create_model():
Raises:
ValueError: Invalid Argument.
"""

if tensorflow_version is None:
tensorflow_version = constants._TENSORFLOW_VERSION
artifact_uri = artifact_uri or initializer.global_config.staging_bucket
predict_utils.validate_artifact_uri(artifact_uri)
prefix = "ray-on-vertex-registered-tensorflow-model"
Expand All @@ -99,10 +109,16 @@ def create_model():
)
tf_model = _get_tensorflow_model_from(checkpoint, model=_model)
model_dir = os.path.join(artifact_uri, prefix)
tf_model.save(model_dir)
try:
import tensorflow as tf

tf.saved_model.save(tf_model, model_dir)
except ImportError:
logging.warning("TensorFlow must be installed to save the trained model.")
return aiplatform.Model.upload_tensorflow_saved_model(
saved_model_dir=model_dir,
display_name=display_model_name,
tensorflow_version=tensorflow_version,
**kwargs,
)

Expand Down Expand Up @@ -139,13 +155,13 @@ def _get_tensorflow_model_from(

return checkpoint.get_model(model)

# get_model() signature changed in future versions
try:
from tensorflow import keras
import tensorflow as tf

try:
return keras.models.load_model(checkpoint.path)
return tf.saved_model.load(checkpoint.path)
except OSError:
return keras.models.load_model("gs://" + checkpoint.path)
return tf.saved_model.load("gs://" + checkpoint.path)

except ImportError:
logging.warning("TensorFlow must be installed to load the trained model.")
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,8 @@
_PICKLE_EXTENTION = ".pkl"

_XGBOOST_VERSION = "1.6"
# TensorFlow 2.13 requires typing_extensions<4.6 and will cause errors in Ray.
# https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/tools/pip_package/setup.py#L100
# 2.13 is the latest supported version of Vertex prebuilt prediction container.
# Set 2.12 as default here since 2.13 cause errors.
_TENSORFLOW_VERSION = "2.12"
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def register_xgboost(
Optional. The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
xgboost_version (str): Optional. The version of the XGBoost serving container.
Supported versions: ["0.82", "0.90", "1.1", "1.2", "1.3", "1.4", "1.6", "1.7"].
If the version is not specified, the latest version is used.
Supported versions:
https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
If the version is not specified, the version 1.6 is used.
**kwargs:
Any kwargs will be passed to aiplatform.Model registration.
Expand Down

0 comments on commit 1341e2c

Please sign in to comment.