From faf61a4877cc26c87349748b50d6b732a2d662d4 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 14:42:35 +0530
Subject: [PATCH 01/10] update

---
 src/diffusers/models/attention_dispatch.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..025cd443f0af 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -238,10 +238,12 @@ def list_backends(cls):
         return list(cls._backends.keys())
 
     @classmethod
-    def _is_context_parallel_enabled(
+    def _is_context_parallel_available(
         cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
     ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
+        supports_context_parallel = (
+            backend in cls._supports_context_parallel and cls._supports_context_parallel[backend]
+        )
         is_degree_greater_than_1 = parallel_config is not None and (
             parallel_config.context_parallel_config.ring_degree > 1
             or parallel_config.context_parallel_config.ulysses_degree > 1
@@ -293,7 +295,7 @@ def dispatch_attention_fn(
         backend_name = AttentionBackendName(backend)
         backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
+    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_available(
         backend_name, parallel_config
     ):
         raise ValueError(

From 428399b5906f03445c5522517c230f2921ac81c2 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 15:54:04 +0530
Subject: [PATCH 02/10] update

---
 src/diffusers/models/_modeling_parallel.py | 60 ++++++++++++++--------
 src/diffusers/models/attention_dispatch.py |  6 +--
 src/diffusers/models/modeling_utils.py     | 35 +++++--------
 3 files changed, 52 insertions(+), 49 deletions(-)

diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 2a1d2cc6ceea..f48b4c4969c4 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -79,29 +79,47 @@ def __post_init__(self):
         if self.ulysses_degree is None:
             self.ulysses_degree = 1
+        if self.ring_degree == 1 and self.ulysses_degree == 1:
+            raise ValueError(
+                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+            )
+        if self.ring_degree < 1 or self.ulysses_degree < 1:
+            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+        if self.ring_degree > 1 and self.ulysses_degree > 1:
+            raise ValueError(
+                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+            )
+        if self.rotate_method != "allgather":
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
+
+    @property
+    def mesh_shape(self) -> Tuple[int, int]:
+        """Shape of the device mesh (ring_degree, ulysses_degree)."""
+        return (self.ring_degree, self.ulysses_degree)
+
+    @property
+    def mesh_dim_names(self) -> Tuple[str, str]:
+        """Dimension names for the device mesh."""
+        return ("ring", "ulysses")
+
     def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
         self._rank = rank
         self._world_size = world_size
         self._device = device
         self._mesh = mesh
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-        if self.rotate_method != "allgather":
-            raise NotImplementedError(
-                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+
+        if self.ulysses_degree * self.ring_degree > world_size:
+            raise ValueError(
+                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        if self._flattened_mesh is None:
-            self._flattened_mesh = self._mesh._flatten()
-        if self._ring_mesh is None:
-            self._ring_mesh = self._mesh["ring"]
-        if self._ulysses_mesh is None:
-            self._ulysses_mesh = self._mesh["ulysses"]
-        if self._ring_local_rank is None:
-            self._ring_local_rank = self._ring_mesh.get_local_rank()
-        if self._ulysses_local_rank is None:
-            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+
+        self._flattened_mesh = self._mesh._flatten()
+        self._ring_mesh = self._mesh["ring"]
+        self._ulysses_mesh = self._mesh["ulysses"]
+        self._ring_local_rank = self._ring_mesh.get_local_rank()
+        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
 
 
 @dataclass
@@ -119,7 +137,7 @@ class ParallelConfig:
     _rank: int = None
     _world_size: int = None
     _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None
 
     def setup(
         self,
@@ -127,14 +145,14 @@ def setup(
         world_size: int,
         device: torch.device,
         *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
         self._rank = rank
         self._world_size = world_size
         self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh
         if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)
 
 
 @dataclass(frozen=True)
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 025cd443f0af..efbb3afc5d39 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -244,11 +244,7 @@ def _is_context_parallel_available(
         supports_context_parallel = (
             backend in cls._supports_context_parallel and cls._supports_context_parallel[backend]
         )
-        is_degree_greater_than_1 = parallel_config is not None and (
-            parallel_config.context_parallel_config.ring_degree > 1
-            or parallel_config.context_parallel_config.ulysses_degree > 1
-        )
-        return supports_context_parallel and is_degree_greater_than_1
+        return supports_context_parallel and parallel_config.context_parallel_config is not None
 
 
 @contextlib.contextmanager
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 1af7ba9ac511..57c3c8866fc8 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1483,46 +1483,35 @@ def enable_parallelism(
         config: Union[ParallelConfig, ContextParallelConfig],
         cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
     ):
-        from ..hooks.context_parallel import apply_context_parallel
-        from .attention import AttentionModuleMixin
-        from .attention_processor import Attention, MochiAttention
-
         logger.warning(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )
 
-        if isinstance(config, ContextParallelConfig):
-            config = ParallelConfig(context_parallel_config=config)
-
         if not torch.distributed.is_initialized():
             raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
 
+        from ..hooks.context_parallel import apply_context_parallel
+        from .attention import AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        if isinstance(config, ContextParallelConfig):
+            config = ParallelConfig(context_parallel_config=config)
+
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         device_type = torch._C._get_accelerator().type
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())
 
-        cp_mesh = None
+        mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
-            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
-                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
-                raise ValueError(
-                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-                )
-            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
-                raise ValueError(
-                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
-                )
-            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
+            mesh = torch.distributed.device_mesh.init_device_mesh(
                 device_type=device_type,
-                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
-                mesh_dim_names=("ring", "ulysses"),
+                mesh_shape=cp_config.mesh_shape,
+                mesh_dim_names=cp_config.mesh_dim_names,
             )
-
-        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
+            config.setup(rank, world_size, device, mesh=mesh)
 
         if cp_plan is None and self._cp_plan is None:
             raise ValueError(

From 1d763226751b9dca6089c326bcb5c6f34a13c3d1 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 16:54:14 +0530
Subject: [PATCH 03/10] update

---
 src/diffusers/models/attention_dispatch.py |  5 ++-
 src/diffusers/models/modeling_utils.py     | 52 +++++++++++++++++-----
 2 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index efbb3afc5d39..f105e28ae735 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -239,12 +239,13 @@ def list_backends(cls):
 
     @classmethod
     def _is_context_parallel_available(
-        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+        cls,
+        backend: AttentionBackendName,
     ) -> bool:
         supports_context_parallel = (
             backend in cls._supports_context_parallel and cls._supports_context_parallel[backend]
         )
-        return supports_context_parallel and parallel_config.context_parallel_config is not None
+        return supports_context_parallel
 
 
 @contextlib.contextmanager
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 57c3c8866fc8..0ad53c1a9ee0 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1503,6 +1503,38 @@ def enable_parallelism(
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())
 
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        # Step 1: Validate attention backend supports context parallelism if enabled
+        if config.context_parallel_config is not None:
+            from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
+
+            for module in self.modules():
+                if not isinstance(module, attention_classes):
+                    continue
+
+                processor = module.processor
+                if processor is None or not hasattr(processor, "_attention_backend"):
+                    continue
+
+                attention_backend = processor._attention_backend
+                if attention_backend is None:
+                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
+                else:
+                    attention_backend = AttentionBackendName(attention_backend)
+
+                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+                    raise ValueError(
+                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
+                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
+                        f"Please set a compatible attention backend using `model.set_attention_backend()` before "
+                        f"calling `enable_parallelism()`."
+                    )
+
+                # All modules use the same attention processor and backend. We don't need to
+                # iterate over all modules after checking the first processor
+                break
+
         mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
@@ -1511,20 +1543,10 @@ def enable_parallelism(
                 mesh_shape=cp_config.mesh_shape,
                 mesh_dim_names=cp_config.mesh_dim_names,
             )
-            config.setup(rank, world_size, device, mesh=mesh)
-
-        if cp_plan is None and self._cp_plan is None:
-            raise ValueError(
-                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
-            )
-        cp_plan = cp_plan if cp_plan is not None else self._cp_plan
-
-        if config.context_parallel_config is not None:
-            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+        config.setup(rank, world_size, device, mesh=mesh)
 
         self._parallel_config = config
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -1533,6 +1555,14 @@ def enable_parallelism(
                 continue
             processor._parallel_config = config
 
+        if config.context_parallel_config is not None:
+            if cp_plan is None and self._cp_plan is None:
+                raise ValueError(
+                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+                )
+            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
     @classmethod
     def _load_pretrained_model(
         cls,

From a66787b62b81ae937cb07f5602435cc788e20384 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 17:00:10 +0530
Subject: [PATCH 04/10] update

---
 src/diffusers/models/attention_dispatch.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f105e28ae735..0901337a6f2b 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -292,14 +292,6 @@ def dispatch_attention_fn(
         backend_name = AttentionBackendName(backend)
         backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_available(
-        backend_name, parallel_config
-    ):
-        raise ValueError(
-            f"Backend {backend_name} either does not support context parallelism or context parallelism "
-            f"was enabled with a world size of 1."
-        )
-
     kwargs = {
         "query": query,
         "key": key,

From 881e262c08cee1ce2b20a724228f3d49e3bb0033 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 17:35:04 +0530
Subject: [PATCH 05/10] update

---
 src/diffusers/models/attention_dispatch.py | 10 +++++-----
 src/diffusers/models/modeling_utils.py     |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 0901337a6f2b..e412bbfa6ed6 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -207,7 +207,7 @@ class _AttentionBackendRegistry:
     _backends = {}
     _constraints = {}
    _supported_arg_names = {}
-    _supports_context_parallel = {}
+    _supports_context_parallel = set()
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS
 
@@ -224,7 +224,9 @@ def decorator(func):
             cls._backends[backend] = func
             cls._constraints[backend] = constraints or []
             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
-            cls._supports_context_parallel[backend] = supports_context_parallel
+            if supports_context_parallel:
+                cls._supports_context_parallel.add(backend)
+
             return func
 
         return decorator
@@ -242,9 +244,7 @@ def _is_context_parallel_available(
         cls,
         backend: AttentionBackendName,
     ) -> bool:
-        supports_context_parallel = (
-            backend in cls._supports_context_parallel and cls._supports_context_parallel[backend]
-        )
+        supports_context_parallel = backend in cls._supports_context_parallel
         return supports_context_parallel
 
 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 0ad53c1a9ee0..bbd8f9b2257e 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1527,7 +1527,7 @@ def enable_parallelism(
                     raise ValueError(
                         f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                         f"is using backend '{attention_backend.value}' which does not support context parallelism. "
-                        f"Please set a compatible attention backend using `model.set_attention_backend()` before "
+                        f"Please set a compatible attention backend: {_AttentionBackendRegistry._supports_context_parallel} using `model.set_attention_backend()` before "
                         f"calling `enable_parallelism()`."
                     )

From 0845ca07d30a2f76327f8381213ff03c62559614 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 17:37:50 +0530
Subject: [PATCH 06/10] update

---
 src/diffusers/models/attention_dispatch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e412bbfa6ed6..dff6f43934fc 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -225,7 +225,7 @@ def decorator(func):
             cls._constraints[backend] = constraints or []
             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
             if supports_context_parallel:
-                cls._supports_context_parallel.add(backend)
+                cls._supports_context_parallel.add(backend.value)
 
             return func
 
@@ -244,7 +244,7 @@ def _is_context_parallel_available(
         cls,
         backend: AttentionBackendName,
     ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
+        supports_context_parallel = backend.value in cls._supports_context_parallel
         return supports_context_parallel
 
 

From 8018a6a733fa2891cc637a5eeab1ba4ff8d1be66 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 17:45:42 +0530
Subject: [PATCH 07/10] update

---
 src/diffusers/models/modeling_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index bbd8f9b2257e..9ba04bb556e9 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1524,10 +1524,11 @@ def enable_parallelism(
                     attention_backend = AttentionBackendName(attention_backend)
 
                 if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
                     raise ValueError(
                         f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                         f"is using backend '{attention_backend.value}' which does not support context parallelism. "
-                        f"Please set a compatible attention backend: {_AttentionBackendRegistry._supports_context_parallel} using `model.set_attention_backend()` before "
+                        f"Please set a compatible attention backend: {compatible_backends}) using `model.set_attention_backend()` before "
                        f"calling `enable_parallelism()`."
                     )

From f92578342ff460bc86e14fa6e9f98e9ac3cd0a1d Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 17:47:02 +0530
Subject: [PATCH 08/10] update

---
 src/diffusers/models/modeling_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 9ba04bb556e9..5475858dc09b 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1528,7 +1528,7 @@ def enable_parallelism(
                     raise ValueError(
                         f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                         f"is using backend '{attention_backend.value}' which does not support context parallelism. "
-                        f"Please set a compatible attention backend: {compatible_backends}) using `model.set_attention_backend()` before "
+                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
                        f"calling `enable_parallelism()`."
                     )

From 5bfc7dd419747d824581ecaa527631d8a547f3bc Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 7 Oct 2025 18:40:24 +0530
Subject: [PATCH 09/10] update

---
 src/diffusers/models/modeling_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 5475858dc09b..7c647b5c0adb 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1487,8 +1487,10 @@ def enable_parallelism(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )
 
-        if not torch.distributed.is_initialized():
-            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
+        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
+            raise RuntimeError(
+                "torch.distributed must be available and initialized before calling `enable_parallelism`."
+            )
 
         from ..hooks.context_parallel import apply_context_parallel
         from .attention import AttentionModuleMixin

From fb15ff526f79bb0aecb02bc6e018e2c2b394bdc5 Mon Sep 17 00:00:00 2001
From: DN6
Date: Wed, 8 Oct 2025 14:35:36 +0530
Subject: [PATCH 10/10] update

---
 src/diffusers/models/modeling_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 7c647b5c0adb..d8ba27a2fcaa 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1494,6 +1494,7 @@ def enable_parallelism(
 
         from ..hooks.context_parallel import apply_context_parallel
         from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
         from .attention_processor import Attention, MochiAttention
 
         if isinstance(config, ContextParallelConfig):
@@ -1509,8 +1510,6 @@ def enable_parallelism(
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
 
         # Step 1: Validate attention backend supports context parallelism if enabled
         if config.context_parallel_config is not None:
-            from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
-
             for module in self.modules():
                 if not isinstance(module, attention_classes):
                     continue
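
Usage note (illustrative only, not part of the patch series): the sketch below shows one way the API touched by these patches might be exercised under torchrun on two or more GPUs. The checkpoint name, the "flash" backend choice, and the import path for ContextParallelConfig are assumptions inferred from the file paths in the diffs, not something these patches prescribe.

    # Hypothetical driver script; launch with: torchrun --nproc-per-node=2 run_cp.py
    import torch
    import torch.distributed as dist

    from diffusers import FluxTransformer2DModel
    from diffusers.models._modeling_parallel import ContextParallelConfig  # module path per the diffs above

    # enable_parallelism requires torch.distributed to be initialized first (PATCH 09).
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    model = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    ).to("cuda")

    # The chosen backend must be registered with supports_context_parallel=True,
    # otherwise the validation added in PATCH 03 raises a ValueError.
    model.set_attention_backend("flash")

    # Ring attention across all ranks; ring_degree * ulysses_degree must not
    # exceed the world size (validated in PATCH 02).
    model.enable_parallelism(config=ContextParallelConfig(ring_degree=dist.get_world_size()))

The model used here is assumed to ship a `_cp_plan`; otherwise a `cp_plan` argument would have to be passed explicitly, as the error message added in PATCH 03 explains.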