Skip to content

Commit

Permalink
fix: Improve import time by moving TensorFlow to lazy import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613018264
  • Loading branch information
yinghsienwu authored and Copybara-Service committed Mar 6, 2024
1 parent 2690e72 commit f294ba8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 31 deletions.
3 changes: 2 additions & 1 deletion setup.py
Expand Up @@ -142,7 +142,8 @@
"pytest-asyncio",
"pytest-xdist",
"scikit-learn",
"tensorflow >= 2.3.0, <= 2.12.0",
# Lazy import requires > 2.12.0
"tensorflow == 2.13.0",
# TODO(jayceeli) torch 2.1.0 has conflict with pyfakefs, will check if
# future versions fix this issue
"torch >= 2.0.0, < 2.1.0",
Expand Down
10 changes: 0 additions & 10 deletions vertexai/preview/_workflow/executor/training_script.py
Expand Up @@ -33,16 +33,6 @@
from vertexai.preview.developer import remote_specs


try:
# This line ensures a tensorflow model to be loaded by cloudpickle correctly
# We put it in a try clause since not all models are tensorflow and if it is
# a tensorflow model, the dependency should've been installed and therefore
# import should work.
import tensorflow as tf # noqa: F401
except ImportError:
pass


os.environ["_IS_VERTEX_REMOTE_TRAINING"] = "True"

print(constants._START_EXECUTION_MSG)
Expand Down
54 changes: 38 additions & 16 deletions vertexai/preview/_workflow/serialization_engine/serializers.py
Expand Up @@ -25,7 +25,7 @@
import pickle
import shutil
import tempfile
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, TYPE_CHECKING
import uuid

from google.cloud.aiplatform.utils import gcs_utils
Expand All @@ -48,17 +48,18 @@

SERIALIZATION_METADATA_FRAMEWORK_KEY = "framework"

try:
from tensorflow import keras
import tensorflow as tf
if TYPE_CHECKING:
try:
from tensorflow import keras
import tensorflow as tf

KerasModel = keras.models.Model
TFDataset = tf.data.Dataset
except ImportError:
keras = None
tf = None
KerasModel = Any
TFDataset = Any
KerasModel = keras.models.Model
TFDataset = tf.data.Dataset
except ImportError:
keras = None
tf = None
KerasModel = Any
TFDataset = Any

try:
import torch
Expand Down Expand Up @@ -184,7 +185,7 @@ class KerasModelSerializer(serializers_base.Serializer):
)

def serialize(
self, to_serialize: KerasModel, gcs_path: str, **kwargs
self, to_serialize: "keras.models.Model", gcs_path: str, **kwargs # noqa: F821
) -> str: # pytype: disable=invalid-annotation
"""Serializes a tensorflow.keras.models.Model to a gcs path.
Expand Down Expand Up @@ -232,7 +233,9 @@ def serialize(
to_serialize.save(gcs_path, save_format=save_format)
return gcs_path

def deserialize(self, serialized_gcs_path: str, **kwargs) -> KerasModel:
def deserialize(
self, serialized_gcs_path: str, **kwargs
) -> "keras.models.Model": # noqa: F821
"""Deserialize a tensorflow.keras.models.Model given the gcs file name.
Args:
Expand Down Expand Up @@ -335,6 +338,7 @@ def deserialize(self, serialized_gcs_path: str, **kwargs):
Raises:
ValueError: if `serialized_gcs_path` is not a valid GCS uri.
"""
from tensorflow import keras

if not _is_valid_gcs_path(serialized_gcs_path):
raise ValueError(f"Invalid gcs path: {serialized_gcs_path}")
Expand Down Expand Up @@ -922,8 +926,12 @@ class TFDatasetSerializer(serializers_base.Serializer):
serializers_base.SerializationMetadata(serializer="TFDatasetSerializer")
)

def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str:
def serialize(
self, to_serialize: "tf.data.Dataset", gcs_path: str, **kwargs # noqa: F821
) -> str: # noqa: F821
del kwargs
import tensorflow as tf

if not _is_valid_gcs_path(gcs_path):
raise ValueError(f"Invalid gcs path: {gcs_path}")
TFDatasetSerializer._metadata.dependencies = (
Expand All @@ -936,8 +944,12 @@ def serialize(self, to_serialize: TFDataset, gcs_path: str, **kwargs) -> str:
tf.data.experimental.save(to_serialize, gcs_path)
return gcs_path

def deserialize(self, serialized_gcs_path: str, **kwargs) -> TFDataset:
def deserialize(
self, serialized_gcs_path: str, **kwargs
) -> "tf.data.Dataset": # noqa: F821
del kwargs
import tensorflow as tf

try:
deserialized = tf.data.Dataset.load(serialized_gcs_path)
except AttributeError:
Expand Down Expand Up @@ -1180,6 +1192,11 @@ def serialize(
return gcs_path

def _get_tfio_verison(self):
import tensorflow as tf

if tf.__version__ < "2.13.0":
raise ValueError("TensorFlow version < 2.13.0 is not supported.")

major, minor, _ = version.Version(tf.__version__).release
tf_version = f"{major}.{minor}"

Expand Down Expand Up @@ -1277,7 +1294,7 @@ def _deserialize_tensorflow(
serialized_gcs_path: str,
batch_size: Optional[int] = None,
target_col: Optional[str] = None,
) -> TFDataset:
) -> "tf.data.Dataset": # noqa: F821
"""Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
serialized_gcs_path is a folder containing one or more parquet files.
Expand All @@ -1287,6 +1304,11 @@ def _deserialize_tensorflow(
target_col = target_col.encode("ASCII") if target_col else b"target"

# Deserialization at remote environment
import tensorflow as tf

if tf.__version__ < "2.13.0":
raise ValueError("TensorFlow version < 2.13.0 is not supported.")

try:
import tensorflow_io as tfio
except ImportError as e:
Expand Down
14 changes: 10 additions & 4 deletions vertexai/preview/developer/remote_specs.py
Expand Up @@ -34,10 +34,6 @@
serializers,
)

try:
import tensorflow as tf
except ImportError:
pass
try:
import torch
except ImportError:
Expand Down Expand Up @@ -763,6 +759,11 @@ def _get_keras_distributed_strategy(enable_distributed: bool, accelerator_count:
Returns:
A tf.distribute.Strategy.
"""
import tensorflow as tf

if tf.__version__ < "2.13.0":
raise ValueError("TensorFlow version < 2.13.0 is not supported.")

if enable_distributed:
cluster_spec = _get_cluster_spec()
# Multiple workers, use tf.distribute.MultiWorkerMirroredStrategy().
Expand Down Expand Up @@ -793,6 +794,11 @@ def _set_keras_distributed_strategy(model: Any, strategy: Any):
A tf.distribute.Strategy.
"""
# Clone and compile model within scope of chosen strategy.
import tensorflow as tf

if tf.__version__ < "2.13.0":
raise ValueError("TensorFlow version < 2.13.0 is not supported.")

with strategy.scope():
cloned_model = tf.keras.models.clone_model(model)
cloned_model.compile_from_config(model.get_compile_config())
Expand Down

0 comments on commit f294ba8

Please sign in to comment.