cannot import accelerate when torch>=2.0.1 and torch.distributed is disabled #1787

Closed

natsukium opened this issue Jul 28, 2023 · 1 comment · Fixed by #1800 or #2121
natsukium commented Jul 28, 2023

System Info

I can't run `accelerate env` because of an import error.

accelerate: 0.21.0
OS: macOS
python: 3.10.12
numpy: 1.24.2
torch: 2.0.1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

  1. Build torch >= 2.0.1 with USE_DISTRIBUTED=0
  2. Install accelerate == 0.21.0
  3. Run python -c "import accelerate"
  4. The import raises ModuleNotFoundError: No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/accelerate/__init__.py", line 3, in <module>
    from .accelerator import Accelerator
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/accelerate/accelerator.py", line 35, in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/accelerate/checkpointing.py", line 24, in <module>
    from .utils import (
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/accelerate/utils/__init__.py", line 132, in <module>
    from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 24, in <module>
    import torch.distributed.checkpoint as dist_cp
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/torch/distributed/checkpoint/__init__.py", line 1, in <module>
    from .metadata import (
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/torch/distributed/checkpoint/metadata.py", line 3, in <module>
    from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/torch/distributed/_shard/__init__.py", line 1, in <module>
    from .api import (
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/torch/distributed/_shard/api.py", line 5, in <module>
    from torch.distributed import distributed_c10d
  File "/nix/store/v9h5iiawvw6y0j03840qxjpqc9nbk4c2-python3-3.10.12-env/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 16, in <module>
    from torch._C._distributed_c10d import (
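
For reference, a quick way to confirm whether the installed torch build ships distributed support (a diagnostic sketch, not part of the repro steps):

```python
import torch

# On a build compiled with USE_DISTRIBUTED=0 this prints False;
# on a standard PyPI wheel it prints True.
print(torch.distributed.is_available())
```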

Expected behavior

This is the line that triggers the error, in accelerate/utils/fsdp_utils.py:

if is_torch_version(">=", FSDP_PYTORCH_VERSION):
    import torch.distributed.checkpoint as dist_cp

I think it would be better to decide whether to import torch.distributed based on torch.distributed.is_available() in addition to the torch version.
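
A minimal standalone sketch of such a guard; the FSDP_PYTORCH_VERSION value and the packaging-based version comparison are stand-ins for accelerate's own constant and is_torch_version helper:

```python
import torch
from packaging import version

# Stand-in for accelerate's FSDP_PYTORCH_VERSION constant (exact value assumed here).
FSDP_PYTORCH_VERSION = "2.0.0"

# Import the distributed checkpoint API only when the torch version is new enough
# AND the build actually ships torch.distributed (i.e. USE_DISTRIBUTED=1).
if (
    version.parse(torch.__version__) >= version.parse(FSDP_PYTORCH_VERSION)
    and torch.distributed.is_available()
):
    import torch.distributed.checkpoint as dist_cp
```

With a guard like this, `import accelerate` would no longer fail on distributed-free builds; FSDP-related features would simply be unavailable.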

sgugger commented Jul 28, 2023

Yes, we need a torch.distributed.is_available() check in that test in case PyTorch was built without distributed support, cc @pacman100
