diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py b/vertexai/preview/_workflow/serialization_engine/serializers.py index 33361d07af..a1291b5591 100644 --- a/vertexai/preview/_workflow/serialization_engine/serializers.py +++ b/vertexai/preview/_workflow/serialization_engine/serializers.py @@ -38,6 +38,8 @@ serializers_base, ) +from packaging import version + try: import cloudpickle except ImportError: @@ -125,6 +127,21 @@ _LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/" SERIALIZATION_METADATA_FILENAME = "serialization_metadata" +# Map tf major.minor version to tfio version from https://pypi.org/project/tensorflow-io/ +_TFIO_VERSION_DICT = { + "2.3": "0.16.0", # Align with testing_extra_require: tensorflow >= 2.3.0 + "2.4": "0.17.1", + "2.5": "0.19.1", + "2.6": "0.21.0", + "2.7": "0.23.1", + "2.8": "0.25.0", + "2.9": "0.26.0", + "2.10": "0.27.0", + "2.11": "0.31.0", + "2.12": "0.32.0", + "2.13": "0.34.0", # TODO(b/295580335): Support TF 2.13 +} + def get_uri_prefix(gcs_uri: str) -> str: """Gets the directory of the gcs_uri. @@ -1117,13 +1134,7 @@ def serialize( gcs_path: str, **kwargs, ) -> str: - # All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet) - # Record the framework in metadata for deserialization - detected_framework = kwargs.get("framework") - BigframeSerializer._metadata.framework = detected_framework - if detected_framework == "torch": - self.register_custom_command("pip install torchdata") - self.register_custom_command("pip install torcharrow") + # All bigframe serializers will convert bigframes.dataframe.DataFrame --> parquet if not _is_valid_gcs_path(gcs_path): raise ValueError(f"Invalid gcs path: {gcs_path}") @@ -1131,6 +1142,16 @@ def serialize( supported_frameworks._get_bigframe_deps() ) + # Record the framework in metadata for deserialization + detected_framework = kwargs.get("framework") + BigframeSerializer._metadata.framework = detected_framework + if detected_framework == "torch": + self.register_custom_command("pip install torchdata") + self.register_custom_command("pip install torcharrow") + elif detected_framework == "tensorflow": + tensorflow_io_dep = "tensorflow-io==" + self._get_tfio_verison() + BigframeSerializer._metadata.dependencies.append(tensorflow_io_dep) + # Check if index.name is default and set index.name if not if to_serialize.index.name and to_serialize.index.name != "index": raise ValueError("Index name must be 'index'") @@ -1141,6 +1162,17 @@ def serialize( parquet_gcs_path = gcs_path + "/*" # path is required to contain '*' to_serialize.to_parquet(parquet_gcs_path, index=True) + def _get_tfio_verison(self): + major, minor, _ = version.Version(tf.__version__).release + tf_version = f"{major}.{minor}" + + if tf_version not in _TFIO_VERSION_DICT: + raise ValueError( + f"Tensorflow version {tf_version} is not supported for Bigframes." + + " Supported versions: tensorflow >= 2.3.0, <= 2.12.0." + ) + return _TFIO_VERSION_DICT[tf_version] + def deserialize( self, serialized_gcs_path: str, **kwargs ) -> Union[PandasData, BigframesData]: