Skip to content

Commit

Permalink
fix: AttributeError for TorchModelSerializer.deserialize in torch >=2…
Browse files Browse the repository at this point in the history
….3.0

PiperOrigin-RevId: 631215839
  • Loading branch information
jaycee-li authored and Copybara-Service committed May 6, 2024
1 parent 195c77e commit 20b1866
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions vertexai/preview/_workflow/serialization_engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def _is_valid_gcs_path(path: str) -> bool:

def _load_torch_model(path: str, map_location: "torch.device") -> "torch.nn.Module":
import torch

try:
return torch.load(path, map_location=map_location)
except Exception:
Expand Down Expand Up @@ -434,7 +435,9 @@ class TorchModelSerializer(serializers_base.Serializer):
serializers_base.SerializationMetadata(serializer="TorchModelSerializer")
)

def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) -> str:
def serialize(
self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs
) -> str:
"""Serializes a torch.nn.Module to a gcs path.
Args:
Expand All @@ -450,6 +453,7 @@ def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) ->
ValueError: if `gcs_path` is not a valid GCS uri.
"""
import torch

del kwargs
if not _is_valid_gcs_path(gcs_path):
raise ValueError(f"Invalid gcs path: {gcs_path}")
Expand Down Expand Up @@ -500,11 +504,18 @@ def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.nn.Module":
except ImportError as e:
raise ImportError("torch is not installed.") from e

map_location = (
torch._GLOBAL_DEVICE_CONTEXT.device
if torch._GLOBAL_DEVICE_CONTEXT
else None
)
# Get the default device in the local torch environment.
# If `set_default_device` hasn't been called, _GLOBAL_DEVICE_CONTEXT
# should be None, then we set map_location to None as well.
map_location = None
# In torch 2.3.0, get_default_device is introduced
if hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device_context") and hasattr(
torch, "get_default_device"
):
map_location = torch.get_default_device()
# For older versions, we get default device from _GLOBAL_DEVICE_CONTEXT
elif hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device"):
map_location = torch._GLOBAL_DEVICE_CONTEXT.device

if serialized_gcs_path.startswith("gs://"):
with tempfile.NamedTemporaryFile() as temp_file:
Expand Down Expand Up @@ -731,7 +742,9 @@ class TorchDataLoaderSerializer(serializers_base.Serializer):
serializers_base.SerializationMetadata(serializer="TorchDataLoaderSerializer")
)

def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: str):
def _serialize_to_local(
self, to_serialize: "torch.utils.data.DataLoader", path: str
):
"""Serializes a torch.utils.data.DataLoader to a local path.
Args:
Expand Down Expand Up @@ -778,6 +791,7 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path:
# for default batch sampler we store batch_size, drop_last, and sampler object
# but not batch sampler object.
import torch

if isinstance(to_serialize.batch_sampler, torch.utils.data.BatchSampler):
pass_through_args["batch_size"] = to_serialize.batch_size
pass_through_args["drop_last"] = to_serialize.drop_last
Expand All @@ -797,7 +811,9 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path:
with open(f"{path}/pass_through_args.json", "w") as f:
json.dump(pass_through_args, f)

def serialize(self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs) -> str:
def serialize(
self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs
) -> str:
"""Serializes a torch.utils.data.DataLoader to a gcs path.
Args:
Expand Down Expand Up @@ -883,7 +899,9 @@ def _deserialize_from_local(self, path: str) -> "torch.utils.data.DataLoader":

return torch.utils.data.DataLoader(**kwargs)

def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.utils.data.DataLoader":
def deserialize(
self, serialized_gcs_path: str, **kwargs
) -> "torch.utils.data.DataLoader":
"""Deserialize a torch.utils.data.DataLoader given the gcs path.
Args:
Expand Down

0 comments on commit 20b1866

Please sign in to comment.