Fix import error when torch>=2.0.1 and torch.distributed is disabled (#…
natsukium committed Nov 8, 2023
1 parent 217e1a2 commit 76de60d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/utils/other.py
@@ -28,7 +28,7 @@
 from ..state import PartialState
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType
-from .imports import is_deepspeed_available, is_safetensors_available, is_tpu_available
+from .imports import is_deepspeed_available, is_safetensors_available, is_torch_distributed_available, is_tpu_available
 from .transformer_engine import convert_model
 from .versions import is_torch_version

@@ -77,7 +77,7 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True):
 
         options += (DeepSpeedEngine,)
 
-    if is_torch_version(">=", FSDP_PYTORCH_VERSION):
+    if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 
         options += (FSDP,)
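
Why the extra guard: PyTorch can be built without distributed support, in which case importing torch.distributed.fsdp fails even when the version check passes. The sketch below shows how such an availability check could work; it is an assumption for illustration, not necessarily Accelerate's exact implementation of is_torch_distributed_available.

import torch


def is_torch_distributed_available() -> bool:
    # torch.distributed.is_available() returns False when PyTorch was built
    # without distributed support (e.g. USE_DISTRIBUTED=0), so callers can
    # skip the FSDP import instead of hitting an ImportError.
    return torch.distributed.is_available()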
