diff --git a/vertexai/preview/developer/remote_specs.py b/vertexai/preview/developer/remote_specs.py index fd68a276be..c8a418b880 100644 --- a/vertexai/preview/developer/remote_specs.py +++ b/vertexai/preview/developer/remote_specs.py @@ -34,10 +34,6 @@ serializers, ) -try: - import torch -except ImportError: - pass _LOGGER = base.Logger(__name__) @@ -842,6 +838,8 @@ def my_train_method(self, ...): Returns: A custom model built on top of `torch.nn.Module` wrapped in DistributedDataParallel. """ + import torch + if not model.cluster_spec: # cluster_spec is populated for multi-worker training return model