
Gate deep imports from torch.distributed #13673

Open
hlky wants to merge 1 commit into huggingface:main from hlky:torch-distributed-gate

Conversation

@hlky
Contributor

hlky commented May 1, 2026

What does this PR do?

Some PyTorch builds, such as ROCm Windows, ship a non-functional torch.distributed. The module itself exists, but any deeper import, such as from torch.distributed.fsdp import CPUOffload, ShardingStrategy, will fail.

This affects various tests that indirectly import from diffusers.training_utils:

tests\pipelines\flux2\test_pipeline_flux2.py:17: in <module>
    from ..test_pipelines_common import (
.venv\Lib\site-packages\_pytest\assertion\rewrite.py:197: in exec_module
    exec(co, module.__dict__)
tests\pipelines\test_pipelines_common.py:61: in <module>
    from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
.venv\Lib\site-packages\_pytest\assertion\rewrite.py:197: in exec_module
    exec(co, module.__dict__)
tests\models\transformers\test_models_transformer_flux.py:27: in <module>
    from ..testing_utils import (
tests\models\testing_utils\__init__.py:37: in <module>
    from .training import TrainingTesterMixin
tests\models\testing_utils\training.py:22: in <module>
    from diffusers.training_utils import EMAModel
src\diffusers\training_utils.py:18: in <module>
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
.venv\Lib\site-packages\torch\distributed\fsdp\__init__.py:1: in <module>
    from ._flat_param import FlatParameter as FlatParameter
.venv\Lib\site-packages\torch\distributed\fsdp\_flat_param.py:31: in <module>
    from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
.venv\Lib\site-packages\torch\testing\_internal\distributed\fake_pg.py:4: in <module>
    from torch._C._distributed_c10d import FakeProcessGroup
E   ModuleNotFoundError: No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package

Currently, the imports are guarded with getattr(torch, "distributed", None) is not None. I have found no evidence that any PyTorch build omits the distributed module entirely; it appears the module always exists, and torch.distributed.is_available() returns False when it is non-functional.
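To illustrate the gap, here is a minimal sketch of what happens on an affected build (assuming ROCm Windows behaves as described above): the existing getattr guard passes even though the distributed backend is missing.

```python
import torch

# The current guard only checks that the torch.distributed module object exists,
# which is true even on builds where the distributed backend was not compiled in.
print(getattr(torch, "distributed", None) is not None)  # True on the affected builds

# is_available() reports whether the distributed backend is actually usable.
print(torch.distributed.is_available())  # False on the affected builds

# With a non-functional backend, the deep import used by training_utils raises
# ModuleNotFoundError (see the traceback above):
# from torch.distributed.fsdp import CPUOffload, ShardingStrategy
```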

This PR replaces the guard with torch.distributed.is_available().
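A minimal sketch of the guarded import pattern this moves towards; the exact diff in src/diffusers/training_utils.py may differ, and the None fallbacks below are purely illustrative.

```python
import torch

if torch.distributed.is_available():
    # The distributed backend is functional, so deep imports are safe.
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
else:
    # Illustrative placeholders so module-level references do not raise NameError;
    # not necessarily what the PR itself does.
    CPUOffload, ShardingStrategy = None, None
```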

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

github-actions bot added the size/S (PR with diff < 50 LOC) label May 1, 2026
