diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 2a1d2cc6ceea..f48b4c4969c4 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -79,29 +79,47 @@ def __post_init__(self): if self.ulysses_degree is None: self.ulysses_degree = 1 + if self.ring_degree == 1 and self.ulysses_degree == 1: + raise ValueError( + "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference" + ) + if self.ring_degree < 1 or self.ulysses_degree < 1: + raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") + if self.ring_degree > 1 and self.ulysses_degree > 1: + raise ValueError( + "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + ) + if self.rotate_method != "allgather": + raise NotImplementedError( + f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." + ) + + @property + def mesh_shape(self) -> Tuple[int, int]: + """Shape of the device mesh (ring_degree, ulysses_degree).""" + return (self.ring_degree, self.ulysses_degree) + + @property + def mesh_dim_names(self) -> Tuple[str, str]: + """Dimension names for the device mesh.""" + return ("ring", "ulysses") + def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh): self._rank = rank self._world_size = world_size self._device = device self._mesh = mesh - if self.ring_degree is None: - self.ring_degree = 1 - if self.ulysses_degree is None: - self.ulysses_degree = 1 - if self.rotate_method != "allgather": - raise NotImplementedError( - f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." + + if self.ulysses_degree * self.ring_degree > world_size: + raise ValueError( + f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})." ) - if self._flattened_mesh is None: - self._flattened_mesh = self._mesh._flatten() - if self._ring_mesh is None: - self._ring_mesh = self._mesh["ring"] - if self._ulysses_mesh is None: - self._ulysses_mesh = self._mesh["ulysses"] - if self._ring_local_rank is None: - self._ring_local_rank = self._ring_mesh.get_local_rank() - if self._ulysses_local_rank is None: - self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() + + self._flattened_mesh = self._mesh._flatten() + self._ring_mesh = self._mesh["ring"] + self._ulysses_mesh = self._mesh["ulysses"] + self._ring_local_rank = self._ring_mesh.get_local_rank() + self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() @dataclass @@ -119,7 +137,7 @@ class ParallelConfig: _rank: int = None _world_size: int = None _device: torch.device = None - _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None def setup( self, @@ -127,14 +145,14 @@ def setup( world_size: int, device: torch.device, *, - cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, ): self._rank = rank self._world_size = world_size self._device = device - self._cp_mesh = cp_mesh + self._mesh = mesh if self.context_parallel_config is not None: - self.context_parallel_config.setup(rank, world_size, device, cp_mesh) + self.context_parallel_config.setup(rank, world_size, device, mesh) @dataclass(frozen=True) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index e1694910997a..dff6f43934fc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -207,7 +207,7 @@ class _AttentionBackendRegistry: _backends = {} _constraints = {} _supported_arg_names = {} - _supports_context_parallel = {} + _supports_context_parallel = set() _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) _checks_enabled = DIFFUSERS_ATTN_CHECKS @@ -224,7 +224,9 @@ def decorator(func): cls._backends[backend] = func cls._constraints[backend] = constraints or [] cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) - cls._supports_context_parallel[backend] = supports_context_parallel + if supports_context_parallel: + cls._supports_context_parallel.add(backend.value) + return func return decorator @@ -238,15 +240,12 @@ def list_backends(cls): return list(cls._backends.keys()) @classmethod - def _is_context_parallel_enabled( - cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] + def _is_context_parallel_available( + cls, + backend: AttentionBackendName, ) -> bool: - supports_context_parallel = backend in cls._supports_context_parallel - is_degree_greater_than_1 = parallel_config is not None and ( - parallel_config.context_parallel_config.ring_degree > 1 - or parallel_config.context_parallel_config.ulysses_degree > 1 - ) - return supports_context_parallel and is_degree_greater_than_1 + supports_context_parallel = backend.value in cls._supports_context_parallel + return supports_context_parallel @contextlib.contextmanager @@ -293,14 +292,6 @@ def dispatch_attention_fn( backend_name = AttentionBackendName(backend) backend_fn = _AttentionBackendRegistry._backends.get(backend_name) - if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled( - backend_name, parallel_config - ): - raise ValueError( - f"Backend {backend_name} either does not support context parallelism or context parallelism " - f"was enabled with a world size of 1." - ) - kwargs = { "query": query, "key": key, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1af7ba9ac511..d8ba27a2fcaa 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1483,59 +1483,72 @@ def enable_parallelism( config: Union[ParallelConfig, ContextParallelConfig], cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None, ): - from ..hooks.context_parallel import apply_context_parallel - from .attention import AttentionModuleMixin - from .attention_processor import Attention, MochiAttention - logger.warning( "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." ) + if not torch.distributed.is_available() and not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed must be available and initialized before calling `enable_parallelism`." + ) + + from ..hooks.context_parallel import apply_context_parallel + from .attention import AttentionModuleMixin + from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry + from .attention_processor import Attention, MochiAttention + if isinstance(config, ContextParallelConfig): config = ParallelConfig(context_parallel_config=config) - if not torch.distributed.is_initialized(): - raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.") - rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() device_type = torch._C._get_accelerator().type device_module = torch.get_device_module(device_type) device = torch.device(device_type, rank % device_module.device_count()) - cp_mesh = None + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + + # Step 1: Validate attention backend supports context parallelism if enabled if config.context_parallel_config is not None: - cp_config = config.context_parallel_config - if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1: - raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) - if cp_config.ring_degree * cp_config.ulysses_degree > world_size: - raise ValueError( - f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})." - ) - cp_mesh = torch.distributed.device_mesh.init_device_mesh( - device_type=device_type, - mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree), - mesh_dim_names=("ring", "ulysses"), - ) + for module in self.modules(): + if not isinstance(module, attention_classes): + continue - config.setup(rank, world_size, device, cp_mesh=cp_mesh) + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue - if cp_plan is None and self._cp_plan is None: - raise ValueError( - "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute." - ) - cp_plan = cp_plan if cp_plan is not None else self._cp_plan + attention_backend = processor._attention_backend + if attention_backend is None: + attention_backend, _ = _AttentionBackendRegistry.get_active_backend() + else: + attention_backend = AttentionBackendName(attention_backend) + + if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend): + compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel) + raise ValueError( + f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' " + f"is using backend '{attention_backend.value}' which does not support context parallelism. " + f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before " + f"calling `enable_parallelism()`." + ) + # All modules use the same attention processor and backend. We don't need to + # iterate over all modules after checking the first processor + break + + mesh = None if config.context_parallel_config is not None: - apply_context_parallel(self, config.context_parallel_config, cp_plan) + cp_config = config.context_parallel_config + mesh = torch.distributed.device_mesh.init_device_mesh( + device_type=device_type, + mesh_shape=cp_config.mesh_shape, + mesh_dim_names=cp_config.mesh_dim_names, + ) + config.setup(rank, world_size, device, mesh=mesh) self._parallel_config = config - attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): if not isinstance(module, attention_classes): continue @@ -1544,6 +1557,14 @@ def enable_parallelism( continue processor._parallel_config = config + if config.context_parallel_config is not None: + if cp_plan is None and self._cp_plan is None: + raise ValueError( + "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute." + ) + cp_plan = cp_plan if cp_plan is not None else self._cp_plan + apply_context_parallel(self, config.context_parallel_config, cp_plan) + @classmethod def _load_pretrained_model( cls,