diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py b/vertexai/preview/_workflow/serialization_engine/serializers.py index e4dca7e3a5..cf49f1f2b8 100644 --- a/vertexai/preview/_workflow/serialization_engine/serializers.py +++ b/vertexai/preview/_workflow/serialization_engine/serializers.py @@ -144,6 +144,7 @@ def _is_valid_gcs_path(path: str) -> bool: def _load_torch_model(path: str, map_location: "torch.device") -> "torch.nn.Module": import torch + try: return torch.load(path, map_location=map_location) except Exception: @@ -434,7 +435,9 @@ class TorchModelSerializer(serializers_base.Serializer): serializers_base.SerializationMetadata(serializer="TorchModelSerializer") ) - def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) -> str: + def serialize( + self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs + ) -> str: """Serializes a torch.nn.Module to a gcs path. Args: @@ -450,6 +453,7 @@ def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) -> ValueError: if `gcs_path` is not a valid GCS uri. """ import torch + del kwargs if not _is_valid_gcs_path(gcs_path): raise ValueError(f"Invalid gcs path: {gcs_path}") @@ -500,11 +504,18 @@ def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.nn.Module": except ImportError as e: raise ImportError("torch is not installed.") from e - map_location = ( - torch._GLOBAL_DEVICE_CONTEXT.device - if torch._GLOBAL_DEVICE_CONTEXT - else None - ) + # Get the default device in the local torch environment. + # If `set_default_device` hasn't been called, _GLOBAL_DEVICE_CONTEXT + # should be None, then we set map_location to None as well. + map_location = None + # In torch 2.3.0, get_default_device is introduced + if hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device_context") and hasattr( + torch, "get_default_device" + ): + map_location = torch.get_default_device() + # For older versions, we get default device from _GLOBAL_DEVICE_CONTEXT + elif hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device"): + map_location = torch._GLOBAL_DEVICE_CONTEXT.device if serialized_gcs_path.startswith("gs://"): with tempfile.NamedTemporaryFile() as temp_file: @@ -731,7 +742,9 @@ class TorchDataLoaderSerializer(serializers_base.Serializer): serializers_base.SerializationMetadata(serializer="TorchDataLoaderSerializer") ) - def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: str): + def _serialize_to_local( + self, to_serialize: "torch.utils.data.DataLoader", path: str + ): """Serializes a torch.utils.data.DataLoader to a local path. Args: @@ -778,6 +791,7 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: # for default batch sampler we store batch_size, drop_last, and sampler object # but not batch sampler object. import torch + if isinstance(to_serialize.batch_sampler, torch.utils.data.BatchSampler): pass_through_args["batch_size"] = to_serialize.batch_size pass_through_args["drop_last"] = to_serialize.drop_last @@ -797,7 +811,9 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: with open(f"{path}/pass_through_args.json", "w") as f: json.dump(pass_through_args, f) - def serialize(self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs) -> str: + def serialize( + self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs + ) -> str: """Serializes a torch.utils.data.DataLoader to a gcs path. Args: @@ -883,7 +899,9 @@ def _deserialize_from_local(self, path: str) -> "torch.utils.data.DataLoader": return torch.utils.data.DataLoader(**kwargs) - def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.utils.data.DataLoader": + def deserialize( + self, serialized_gcs_path: str, **kwargs + ) -> "torch.utils.data.DataLoader": """Deserialize a torch.utils.data.DataLoader given the gcs path. Args: