diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4879a7bf045e..07408cc30e80 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -70,6 +70,8 @@ title: Reduce memory usage - local: optimization/speed-memory-optims title: Compiling and offloading quantized models + - local: api/parallel + title: Parallel inference - title: Community optimizations sections: - local: optimization/pruna diff --git a/docs/source/en/api/parallel.md b/docs/source/en/api/parallel.md new file mode 100644 index 000000000000..e38ffe571eac --- /dev/null +++ b/docs/source/en/api/parallel.md @@ -0,0 +1,24 @@ + + +# Parallelism + +Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times. + +## ParallelConfig + +[[autodoc]] ParallelConfig + +## ContextParallelConfig + +[[autodoc]] ContextParallelConfig + +[[autodoc]] hooks.apply_context_parallel diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 741fcd14f283..8867250deda8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -202,6 +202,7 @@ "CogView4Transformer2DModel", "ConsisIDTransformer3DModel", "ConsistencyDecoderVAE", + "ContextParallelConfig", "ControlNetModel", "ControlNetUnionModel", "ControlNetXSAdapter", @@ -229,6 +230,7 @@ "MultiAdapter", "MultiControlNetModel", "OmniGenTransformer2DModel", + "ParallelConfig", "PixArtTransformer2DModel", "PriorTransformer", "QwenImageControlNetModel", @@ -888,6 +890,7 @@ CogView4Transformer2DModel, ConsisIDTransformer3DModel, ConsistencyDecoderVAE, + ContextParallelConfig, ControlNetModel, ControlNetUnionModel, ControlNetXSAdapter, @@ -915,6 +918,7 @@ MultiAdapter, MultiControlNetModel, OmniGenTransformer2DModel, + ParallelConfig, PixArtTransformer2DModel, PriorTransformer, QwenImageControlNetModel, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 525a0747da8b..524a92ea9966 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py new file mode 100644 index 000000000000..83406d4969b7 --- /dev/null +++ b/src/diffusers/hooks/context_parallel.py @@ -0,0 +1,297 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
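Editorial note: before the hook implementation that follows, here is a minimal sketch of how the pieces introduced in this PR might be wired together by hand. It is a sketch under stated assumptions, not the PR's user-facing API: it assumes the processes are launched with `torchrun`, that `transformer` is a hypothetical model object that defines a `_cp_plan`-style dictionary (see the example plan in `_modeling_parallel.py` further down), and that the mesh dimension names `"ring"`/`"ulysses"` are used because `ContextParallelConfig.setup` later in this diff indexes the mesh by those names.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from diffusers import ContextParallelConfig, ParallelConfig
from diffusers.hooks import apply_context_parallel

# Assumes a launch such as: torchrun --nproc-per-node=2 example.py
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

# One mesh dimension per flavor of context parallelism; here: ring attention only.
mesh = init_device_mesh("cuda", (world_size, 1), mesh_dim_names=("ring", "ulysses"))

cp_config = ContextParallelConfig(ring_degree=world_size)
parallel_config = ParallelConfig(context_parallel_config=cp_config)
parallel_config.setup(rank, world_size, device, cp_mesh=mesh)

# `transformer` is a hypothetical placeholder for a diffusion transformer that defines a
# `_cp_plan` dictionary mapping module names to ContextParallelInput/ContextParallelOutput
# entries. The registered hooks shard the listed inputs before each forward pass and gather
# the listed outputs afterwards.
apply_context_parallel(transformer, cp_config, plan=transformer._cp_plan)
```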
+ +import inspect +from dataclasses import dataclass +from typing import Dict, List, Type, Union + +import torch +import torch.distributed._functional_collectives as funcol + +from ..models._modeling_parallel import ( + ContextParallelConfig, + ContextParallelInput, + ContextParallelModelPlan, + ContextParallelOutput, +) +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}" +_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" + + +# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata +@dataclass +class ModuleForwardMetadata: + cached_parameter_indices: Dict[str, int] = None + _cls: Type = None + + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): + kwargs = kwargs or {} + + if identifier in kwargs: + return kwargs[identifier], True, None + + if self.cached_parameter_indices is not None: + index = self.cached_parameter_indices.get(identifier, None) + if index is None: + raise ValueError(f"Parameter '{identifier}' not found in cached indices.") + return args[index], False, index + + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) + parameters = parameters[1:] # skip `self` + self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + + if identifier not in self.cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") + + index = self.cached_parameter_indices[identifier] + + if index >= len(args): + raise ValueError(f"Expected {index} arguments but got {len(args)}.") + + return args[index], False, index + + +def apply_context_parallel( + module: torch.nn.Module, + parallel_config: ContextParallelConfig, + plan: Dict[str, ContextParallelModelPlan], +) -> None: + """Apply context parallel on a model.""" + logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}") + + for module_id, cp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not isinstance(submodule, list): + submodule = [submodule] + + logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules") + + for m in submodule: + if isinstance(cp_model_plan, dict): + hook = ContextParallelSplitHook(cp_model_plan, parallel_config) + hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): + if isinstance(cp_model_plan, ContextParallelOutput): + cp_model_plan = [cp_model_plan] + if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan): + raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}") + hook = ContextParallelGatherHook(cp_model_plan, parallel_config) + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") + registry = HookRegistry.check_if_exists_or_initialize(m) + registry.register_hook(hook, hook_name) + + +def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None: + for module_id, cp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not 
isinstance(submodule, list): + submodule = [submodule] + + for m in submodule: + registry = HookRegistry.check_if_exists_or_initialize(m) + if isinstance(cp_model_plan, dict): + hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") + registry.remove_hook(hook_name) + + +class ContextParallelSplitHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None: + super().__init__() + self.metadata = metadata + self.parallel_config = parallel_config + self.module_forward_metadata = None + + def initialize_hook(self, module): + cls = unwrap_module(module).__class__ + self.module_forward_metadata = ModuleForwardMetadata(_cls=cls) + return module + + def pre_forward(self, module, *args, **kwargs): + args_list = list(args) + + for name, cpm in self.metadata.items(): + if isinstance(cpm, ContextParallelInput) and cpm.split_output: + continue + + # Maybe the parameter was passed as a keyword argument + input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs( + name, args_list, kwargs + ) + + if input_val is None: + continue + + # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard + # the output instead of input for a particular layer by setting split_output=True + if isinstance(input_val, torch.Tensor): + input_val = self._prepare_cp_input(input_val, cpm) + elif isinstance(input_val, (list, tuple)): + if len(input_val) != len(cpm): + raise ValueError( + f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}." + ) + sharded_input_val = [] + for i, x in enumerate(input_val): + if torch.is_tensor(x) and not cpm[i].split_output: + x = self._prepare_cp_input(x, cpm[i]) + sharded_input_val.append(x) + input_val = sharded_input_val + else: + raise ValueError(f"Unsupported input type: {type(input_val)}") + + if is_kwarg: + kwargs[name] = input_val + elif index is not None and index < len(args_list): + args_list[index] = input_val + else: + raise ValueError( + f"An unexpected error occurred while processing the input '{name}'. Please open an " + f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible " + f"example along with the full stack trace." 
+ ) + + return tuple(args_list), kwargs + + def post_forward(self, module, output): + is_tensor = isinstance(output, torch.Tensor) + is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output) + + if not is_tensor and not is_tensor_list: + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + + output = [output] if is_tensor else list(output) + for index, cpm in self.metadata.items(): + if not isinstance(cpm, ContextParallelInput) or not cpm.split_output: + continue + if index >= len(output): + raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.") + current_output = output[index] + current_output = self._prepare_cp_input(current_output, cpm) + output[index] = current_output + + return output[0] if is_tensor else tuple(output) + + def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor: + if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: + raise ValueError( + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." + ) + return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) + + +class ContextParallelGatherHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None: + super().__init__() + self.metadata = metadata + self.parallel_config = parallel_config + + def post_forward(self, module, output): + is_tensor = isinstance(output, torch.Tensor) + + if is_tensor: + output = [output] + elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)): + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + + output = list(output) + + if len(output) != len(self.metadata): + raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.") + + for i, cpm in enumerate(self.metadata): + if cpm is None: + continue + output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + + return output[0] if is_tensor else tuple(output) + + +class AllGatherFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, dim, group): + ctx.dim = dim + ctx.group = group + ctx.world_size = torch.distributed.get_world_size(group) + ctx.rank = torch.distributed.get_rank(group) + return funcol.all_gather_tensor(tensor, dim, group=group) + + @staticmethod + def backward(ctx, grad_output): + grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim) + return grad_chunks[ctx.rank], None, None + + +class EquipartitionSharder: + @classmethod + def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + # NOTE: the following assertion does not have to be true in general. We simply enforce it for now + # because the alternate case has not yet been tested/required for any model. 
+ assert tensor.size()[dim] % mesh.size() == 0, ( + "Tensor size along dimension to be sharded must be divisible by mesh size" + ) + + # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank) + # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] + + return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())] + + @classmethod + def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + tensor = tensor.contiguous() + tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group()) + return tensor + + +def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: + if name.count("*") > 1: + raise ValueError("Wildcard '*' can only be used once in the name") + return _find_submodule_by_name(model, name) + + +def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: + if name == "": + return model + first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "") + if first_atom == "*": + if not isinstance(model, torch.nn.ModuleList): + raise ValueError("Wildcard '*' can only be used with ModuleList") + submodules = [] + for submodule in model: + subsubmodules = _find_submodule_by_name(submodule, remaining_name) + if not isinstance(subsubmodules, list): + subsubmodules = [subsubmodules] + submodules.extend(subsubmodules) + return submodules + else: + if hasattr(model, first_atom): + submodule = getattr(model, first_atom) + return _find_submodule_by_name(submodule, remaining_name) + else: + raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 49ac2a1c56fd..457f70448af3 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -25,6 +25,7 @@ _import_structure = {} if is_torch_available(): + _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] @@ -119,6 +120,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): + from ._modeling_parallel import ContextParallelConfig, ParallelConfig from .adapter import MultiAdapter, T2IAdapter from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py new file mode 100644 index 000000000000..2a1d2cc6ceea --- /dev/null +++ b/src/diffusers/models/_modeling_parallel.py @@ -0,0 +1,241 @@ +# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨 +# Experimental changes are subject to change and APIs may break without warning. + +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import get_logger + + +if TYPE_CHECKING: + pass + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(aryan): add support for the following: +# - Unified Attention +# - More dispatcher attention backends +# - CFG/Data Parallel +# - Tensor Parallel + + +@dataclass +class ContextParallelConfig: + """ + Configuration for context parallelism. + + Args: + ring_degree (`int`, *optional*, defaults to `1`): + Number of devices to use for ring attention within a context parallel region. Must be a divisor of the + total number of devices in the context parallel mesh. + ulysses_degree (`int`, *optional*, defaults to `1`): + Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the + total number of devices in the context parallel mesh. + convert_to_fp32 (`bool`, *optional*, defaults to `True`): + Whether to convert output and LSE to float32 for ring attention numerical stability. + rotate_method (`str`, *optional*, defaults to `"allgather"`): + Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"` + is supported. + + """ + + ring_degree: Optional[int] = None + ulysses_degree: Optional[int] = None + convert_to_fp32: bool = True + # TODO: support alltoall + rotate_method: Literal["allgather", "alltoall"] = "allgather" + + _rank: int = None + _world_size: int = None + _device: torch.device = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None + _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None + _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None + _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None + _ring_local_rank: int = None + _ulysses_local_rank: int = None + + def __post_init__(self): + if self.ring_degree is None: + self.ring_degree = 1 + if self.ulysses_degree is None: + self.ulysses_degree = 1 + + def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh): + self._rank = rank + self._world_size = world_size + self._device = device + self._mesh = mesh + if self.ring_degree is None: + self.ring_degree = 1 + if self.ulysses_degree is None: + self.ulysses_degree = 1 + if self.rotate_method != "allgather": + raise NotImplementedError( + f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." + ) + if self._flattened_mesh is None: + self._flattened_mesh = self._mesh._flatten() + if self._ring_mesh is None: + self._ring_mesh = self._mesh["ring"] + if self._ulysses_mesh is None: + self._ulysses_mesh = self._mesh["ulysses"] + if self._ring_local_rank is None: + self._ring_local_rank = self._ring_mesh.get_local_rank() + if self._ulysses_local_rank is None: + self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() + + +@dataclass +class ParallelConfig: + """ + Configuration for applying different parallelisms. + + Args: + context_parallel_config (`ContextParallelConfig`, *optional*): + Configuration for context parallelism. 
+ """ + + context_parallel_config: Optional[ContextParallelConfig] = None + + _rank: int = None + _world_size: int = None + _device: torch.device = None + _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None + + def setup( + self, + rank: int, + world_size: int, + device: torch.device, + *, + cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + ): + self._rank = rank + self._world_size = world_size + self._device = device + self._cp_mesh = cp_mesh + if self.context_parallel_config is not None: + self.context_parallel_config.setup(rank, world_size, device, cp_mesh) + + +@dataclass(frozen=True) +class ContextParallelInput: + """ + Configuration for splitting an input tensor across context parallel region. + + Args: + split_dim (`int`): + The dimension along which to split the tensor. + expected_dims (`int`, *optional*): + The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the + tensor has the expected number of dimensions before splitting. + split_output (`bool`, *optional*, defaults to `False`): + Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor. + This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex: + RoPE). + """ + + split_dim: int + expected_dims: Optional[int] = None + split_output: bool = False + + def __repr__(self): + return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})" + + +@dataclass(frozen=True) +class ContextParallelOutput: + """ + Configuration for gathering an output tensor across context parallel region. + + Args: + gather_dim (`int`): + The dimension along which to gather the tensor. + expected_dims (`int`, *optional*): + The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the + tensor has the expected number of dimensions before gathering. + """ + + gather_dim: int + expected_dims: Optional[int] = None + + def __repr__(self): + return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})" + + +# A dictionary where keys denote the input to be split across context parallel region, and the +# value denotes the sharding configuration. +# If the key is a string, it denotes the name of the parameter in the forward function. +# If the key is an integer, split_output must be set to True, and it denotes the index of the output +# to be split across context parallel region. +ContextParallelInputType = Dict[ + Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]] +] + +# A dictionary where keys denote the output to be gathered across context parallel region, and the +# value denotes the gathering configuration. +ContextParallelOutputType = Union[ + ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...] +] + +# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of +# the module should be split/gathered across context parallel region. 
+ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+#     "": {
+#         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+#     },
+#     "pos_embed": {
+#         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#     },
+#     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary maps module names to their respective CP plan. The inputs/outputs of each listed layer are
+# split/gathered according to this plan at the respective module level. Here, the following happens:
+# - "":
+#     we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+#     the actual forward logic of the QwenImageTransformer2DModel is run, we split the inputs)
+# - "pos_embed":
+#     we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
+#     we can individually specify how they should be split
+# - "proj_out":
+#     before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the
+#     linear layer forward has run).
+#
+# ContextParallelInput:
+#     specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+#     specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f71be7c8ecc0..0a2ad681237b 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,9 +17,10 @@
 import inspect
 import math
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
+import torch.distributed._functional_collectives as funcol
 
 from ..utils import (
     get_logger,
@@ -39,6 +40,9 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
 
 
+if TYPE_CHECKING:
+    from ._modeling_parallel import ParallelConfig
+
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
@@ -56,9 +60,12 @@
 
 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
+    _wrapped_flash_attn_backward = None
+    _wrapped_flash_attn_forward = None
 
 
 if _CAN_USE_FLASH_ATTN_3:
@@ -197,17 +204,24 @@ class _AttentionBackendRegistry:
     _backends = {}
     _constraints = {}
     _supported_arg_names = {}
+    _supports_context_parallel = {}
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS
 
     @classmethod
-    def register(cls, backend: AttentionBackendName,
constraints: Optional[List[Callable]] = None): + def register( + cls, + backend: AttentionBackendName, + constraints: Optional[List[Callable]] = None, + supports_context_parallel: bool = False, + ): logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") def decorator(func): cls._backends[backend] = func cls._constraints[backend] = constraints or [] cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + cls._supports_context_parallel[backend] = supports_context_parallel return func return decorator @@ -220,6 +234,17 @@ def get_active_backend(cls): def list_backends(cls): return list(cls._backends.keys()) + @classmethod + def _is_context_parallel_enabled( + cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] + ) -> bool: + supports_context_parallel = backend in cls._supports_context_parallel + is_degree_greater_than_1 = parallel_config is not None and ( + parallel_config.context_parallel_config.ring_degree > 1 + or parallel_config.context_parallel_config.ulysses_degree > 1 + ) + return supports_context_parallel and is_degree_greater_than_1 + @contextlib.contextmanager def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): @@ -253,6 +278,7 @@ def dispatch_attention_fn( attention_kwargs: Optional[Dict[str, Any]] = None, *, backend: Optional[AttentionBackendName] = None, + parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} @@ -264,6 +290,14 @@ def dispatch_attention_fn( backend_name = AttentionBackendName(backend) backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled( + backend_name, parallel_config + ): + raise ValueError( + f"Backend {backend_name} either does not support context parallelism or context parallelism " + f"was enabled with a world size of 1." + ) + kwargs = { "query": query, "key": key, @@ -273,6 +307,7 @@ def dispatch_attention_fn( "is_causal": is_causal, "scale": scale, **attention_kwargs, + "_parallel_config": parallel_config, } if is_torch_version(">=", "2.5.0"): kwargs["enable_gqa"] = enable_gqa @@ -521,22 +556,621 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): # Registrations are required for fullgraph tracing compatibility # TODO: this is only required because the beta release FA3 does not have it. 
There is a PR adding # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 - - -@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _wrapped_flash_attn_3_original( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor +@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") +def _wrapped_flash_attn_3( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + qv: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + attention_chunk: int = 0, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: - out, lse = flash_attn_3_func(query, key, value) + # Hardcoded for now because pytorch does not support tuple/int type hints + window_size = (-1, -1) + out, lse, *_ = flash_attn_3_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=causal, + qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + sm_margin=sm_margin, + ) lse = lse.permute(0, 2, 1) return out, lse -@_register_fake("flash_attn_3::_flash_attn_forward") -def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, seq_len, num_heads, head_dim = query.shape +@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward") +def _( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + qv: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + attention_chunk: int = 0, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + window_size = (-1, -1) # noqa: F841 + # A lot of the parameters here are not yet used in any way within diffusers. + # We can safely ignore for now and keep the fake op shape propagation simple. + batch_size, seq_len, num_heads, head_dim = q.shape lse_shape = (batch_size, seq_len, num_heads) - return torch.empty_like(query), query.new_empty(lse_shape) + return torch.empty_like(q), q.new_empty(lse_shape) + + +# ===== Helper functions to use attention backends with templated CP autograd functions ===== + + +# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 +# forward declaration: +# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +def _cudnn_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.") + + tensors_to_save = () + + # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results + # if the input tensors are not contiguous. + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + tensors_to_save += (query, key, value) + + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + torch.ops.aten._scaled_dot_product_cudnn_attention( + query=query, + key=key, + value=value, + attn_bias=attn_mask, + compute_log_sumexp=return_lse, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + ) + ) + + tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + if _save_ctx: + ctx.save_for_backward(*tensors_to_save) + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.attn_mask = attn_mask + ctx.max_q = max_q + ctx.max_k = max_k + + out = out.transpose(1, 2).contiguous() + if lse is not None: + lse = lse.transpose(1, 2).contiguous() + return (out, lse) if return_lse else out + + +# backward declaration: +# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? 
scale=None) -> (Tensor, Tensor, Tensor) +def _cudnn_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + + grad_out = grad_out.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341 + grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward( + grad_out, + query, + key, + value, + out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=ctx.attn_mask, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value)) + + return grad_query, grad_key, grad_value + + +# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 +def _flash_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for flash-attn 2.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.") + + # Hardcoded for now + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + grad_enabled = any(x.requires_grad for x in (query, key, value)) + + if scale is None: + scale = query.shape[-1] ** (-0.5) + + # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. 
+ if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + dropout_p = dropout_p if dropout_p > 0 else 1e-30 + + with torch.set_grad_enabled(grad_enabled): + out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward( + query, + key, + value, + dropout_p, + scale, + is_causal, + window_size[0], + window_size[1], + softcap, + alibi_slopes, + return_lse, + ) + lse = lse.permute(0, 2, 1) + + if _save_ctx: + ctx.save_for_backward(query, key, value, out, lse, rng_state) + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + return (out, lse) if return_lse else out + + +def _flash_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value, out, lse, rng_state = ctx.saved_tensors + grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + + lse_d = _wrapped_flash_attn_backward( # noqa: F841 + grad_out, + query, + key, + value, + out, + lse, + grad_query, + grad_key, + grad_value, + ctx.dropout_p, + ctx.scale, + ctx.is_causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state, + ) + + # Head dimension may have been padded + grad_query = grad_query[..., : grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] + + return grad_query, grad_key, grad_value + + +def _sage_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for Sage attention.") + if dropout_p > 0.0: + raise ValueError("`dropout_p` is not yet supported for Sage attention.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for Sage attention.") + + out = sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + lse = None + if return_lse: + out, lse, *_ = out + lse = lse.permute(0, 2, 1) + + return (out, lse) if return_lse else out + + +def _sage_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, +): + raise NotImplementedError("Backward pass is not implemented for Sage attention.") + + +# ===== Context parallel ===== + + +# Reference: +# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827 +# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246 +# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method): +def _wait_tensor(tensor): + if isinstance(tensor, funcol.AsyncCollectiveTensor): + tensor = tensor.wait() + return tensor + + +def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: + shape = x.shape + # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization + # to 
benchmark triton codegen fails somewhere: + # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3') + # ValueError: Tensors must be contiguous + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + x = x.reshape(shape) + x = _wait_tensor(x) + return x + + +class TemplatedRingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + rank = _parallel_config.context_parallel_config._ring_local_rank + world_size = _parallel_config.context_parallel_config.ring_degree + next_rank = (rank + 1) % world_size + prev_out = prev_lse = None + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx.q_shape = query.shape + ctx.kv_shape = key.shape + ctx._parallel_config = _parallel_config + + kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() + kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) + kv_buffer = kv_buffer.chunk(world_size) + + for i in range(world_size): + if i > 0: + kv = kv_buffer[next_rank] + key_numel = key.numel() + key = kv[:key_numel].reshape_as(key) + value = kv[key_numel:].reshape_as(value) + next_rank = (next_rank + 1) % world_size + + out, lse = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + True, + _save_ctx=i == 0, + _parallel_config=_parallel_config, + ) + + if _parallel_config.context_parallel_config.convert_to_fp32: + out = out.to(torch.float32) + lse = lse.to(torch.float32) + + lse = lse.unsqueeze(-1) + if prev_out is not None: + out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) + lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) + prev_out = out + prev_lse = lse + + out = out.to(query.dtype) + lse = lse.squeeze(-1) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh + rank = ctx._parallel_config.context_parallel_config._ring_local_rank + world_size = ctx._parallel_config.context_parallel_config.ring_degree + next_rank = (rank + 1) % world_size + next_ranks = list(range(1, world_size)) + [0] + + accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype + grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device) + grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device) + grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device) + next_grad_kv = None + + query, key, value, *_ = ctx.saved_tensors + kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() + kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) + kv_buffer = kv_buffer.chunk(world_size) + + for i in range(world_size): + if i > 0: + kv = kv_buffer[next_rank] + key_numel = key.numel() + key = kv[:key_numel].reshape_as(key) + value = kv[key_numel:].reshape_as(value) + next_rank = (next_rank + 1) % world_size + + grad_query_op, 
grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) + + if i > 0: + grad_kv_buffer = _wait_tensor(next_grad_kv) + grad_key_numel = grad_key.numel() + grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key) + grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value) + + grad_query += grad_query_op + grad_key += grad_key_op + grad_value += grad_value_op + + if i < world_size - 1: + grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous() + next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group()) + + grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +class TemplatedUlyssesAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + world_size = _parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + B, S_Q_LOCAL, H, D = query.shape + _, S_KV_LOCAL, _, _ = key.shape + H_LOCAL = H // world_size + query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) + query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) + + out = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=True, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + if return_lse: + lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() + lse = _all_to_all_single(lse, group) + lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() + else: + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh + world_size = ctx._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + B, S_LOCAL, H, D = grad_out.shape + H_LOCAL = H // world_size + + grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + grad_out = _all_to_all_single(grad_out, group) + grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) + + grad_query, grad_key, grad_value = ( + x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 
0, 2, 4).contiguous() + for x in (grad_query_op, grad_key_op, grad_value_op) + ) + grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value)) + grad_query, grad_key, grad_value = ( + x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +def _templated_context_parallel_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + *, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, +): + if attn_mask is not None: + raise ValueError("Attention mask is not yet supported for templated attention.") + if is_causal: + raise ValueError("Causal attention is not yet supported for templated attention.") + if enable_gqa: + raise ValueError("GQA is not yet supported for templated attention.") + + # TODO: add support for unified attention with ring/ulysses degree both being > 1 + if _parallel_config.context_parallel_config.ring_degree > 1: + return TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + elif _parallel_config.context_parallel_config.ulysses_degree > 1: + return TemplatedUlyssesAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + raise ValueError("Reaching this branch of code is unexpected. 
Please report a bug.") # ===== Attention backends ===== @@ -545,34 +1179,50 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, - scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False, - return_attn_probs: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out = flash_attn_func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, - ) - return out + lse = None + if _parallel_config is None: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + is_causal, + scale, + False, + return_lse, + forward_op=_flash_attention_forward_op, + backward_op=_flash_attention_backward_op, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -583,19 +1233,12 @@ def _flash_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape @@ -603,16 +1246,11 @@ def _flash_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -635,11 +1273,7 @@ def _flash_varlen_attention( 
dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, + return_attn_probs=return_lse, ) out = out.unflatten(0, (batch_size, -1)) @@ -656,30 +1290,17 @@ def _flash_attention_3( value: torch.Tensor, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - out, lse, *_ = flash_attn_3_func( + out, lse = _wrapped_flash_attn_3( q=query, k=key, v=value, softmax_scale=scale, causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, - window_size=window_size, - attention_chunk=0, - softcap=softcap, - num_splits=1, - pack_gqa=None, - deterministic=deterministic, - sm_margin=0, ) - return (out, lse) if return_attn_probs else out + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -696,6 +1317,7 @@ def _flash_attention_3_hub( softcap: float = 0.0, deterministic: bool = False, return_attn_probs: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: out = flash_attn_3_func_hub( q=query, @@ -728,17 +1350,11 @@ def _flash_varlen_attention_3( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape @@ -746,16 +1362,11 @@ def _flash_varlen_attention_3( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -775,24 +1386,12 @@ def _flash_varlen_attention_3( cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, - seqused_q=None, - seqused_k=None, softmax_scale=scale, causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, - window_size=window_size, - softcap=softcap, - num_splits=1, - pack_gqa=None, - deterministic=deterministic, - sm_margin=0, ) out = out.unflatten(0, (batch_size, -1)) - return (out, lse) if return_attn_probs else out + 
return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -808,7 +1407,7 @@ def _native_flex_attention( scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, - kernel_options: Optional[Dict[str, Any]] = None, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: # TODO: should we LRU cache the block mask creation? score_mod = None @@ -853,7 +1452,6 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): scale=scale, enable_gqa=enable_gqa, return_lse=return_lse, - kernel_options=kernel_options, ) out = out.permute(0, 2, 1, 3) return out @@ -872,7 +1470,11 @@ def _native_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( query=query, @@ -891,6 +1493,7 @@ def _native_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_CUDNN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_cudnn_attention( query: torch.Tensor, @@ -901,21 +1504,43 @@ def _native_cudnn_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, + lse = None + if _parallel_config is None and not return_lse: + query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=_cudnn_attention_forward_op, + backward_op=_cudnn_attention_backward_op, + _parallel_config=_parallel_config, ) - out = out.permute(0, 2, 1, 3) - return out + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -931,7 +1556,11 @@ def _native_efficient_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( @@ -960,7 +1589,11 @@ def _native_flash_attention( is_causal: bool = False, 
scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native flash attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( @@ -990,7 +1623,11 @@ def _native_math_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native math attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): out = torch.nn.functional.scaled_dot_product_attention( @@ -1017,7 +1654,11 @@ def _native_npu_attention( value: torch.Tensor, dropout_p: float = 0.0, scale: Optional[float] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("NPU attention backend does not support setting `return_lse=True`.") query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) out = npu_fusion_attention( query, @@ -1047,7 +1688,11 @@ def _native_xla_attention( key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("XLA attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query = query / math.sqrt(query.shape[-1]) out = xla_flash_attention( @@ -1063,6 +1708,7 @@ def _native_xla_attention( @_AttentionBackendRegistry.register( AttentionBackendName.SAGE, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _sage_attention( query: torch.Tensor, @@ -1071,16 +1717,40 @@ def _sage_attention( is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - return sageattn( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) + lse = None + if _parallel_config is None: + out = sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + 0.0, + is_causal, + scale, + False, + return_lse, + forward_op=_sage_attention_forward_op, + backward_op=_sage_attention_backward_op, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -1091,31 +1761,26 @@ def _sage_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, - 
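The Sage backend follows the same split as cuDNN: a direct `sageattn` call when no parallel config is set, and the templated context-parallel path otherwise. A hedged standalone sketch of the direct call; the `sageattention` package import and the tensor shapes are assumptions, while the keyword arguments mirror the call in the diff above.

import torch
from sageattention import sageattn  # assumption: external SageAttention package is installed

# Illustrative shapes in the "NHD" layout, i.e. [batch, seq_len, num_heads, head_dim].
query = key = value = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.float16)

# With return_lse=True the call returns extra values; keep the output and the log-sum-exp.
out, lse, *_ = sageattn(
    q=query, k=key, v=value, tensor_layout="NHD", is_causal=False, sm_scale=None, return_lse=True
)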
smooth_k: bool = True, - attn_mask: Optional[torch.Tensor] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") + batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -1137,7 +1802,6 @@ def _sage_varlen_attention( max_seqlen_k=max_seqlen_k, is_causal=is_causal, sm_scale=scale, - smooth_k=smooth_k, ) out = out.unflatten(0, (batch_size, -1)) @@ -1154,11 +1818,8 @@ def _sage_qk_int8_pv_fp8_cuda_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", - smooth_k: bool = True, - smooth_v: bool = False, return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp8_cuda( q=query, @@ -1166,11 +1827,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention( v=value, tensor_layout="NHD", is_causal=is_causal, - qk_quant_gran=qk_quant_gran, sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - smooth_v=smooth_v, return_lse=return_lse, ) @@ -1185,10 +1842,8 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", - smooth_k: bool = True, return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, @@ -1196,10 +1851,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( v=value, tensor_layout="NHD", is_causal=is_causal, - qk_quant_gran=qk_quant_gran, sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, return_lse=return_lse, ) @@ -1214,11 +1866,8 @@ def _sage_qk_int8_pv_fp16_cuda_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", - smooth_k: bool = True, - smooth_v: bool = False, return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp16_cuda( q=query, @@ -1226,11 +1875,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention( v=value, tensor_layout="NHD", is_causal=is_causal, - qk_quant_gran=qk_quant_gran, sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - smooth_v=smooth_v, 
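Dropping the `cu_seqlens_*`/`max_seqlen_*` arguments means the varlen backends always derive their metadata from the boolean attention mask. A hypothetical sketch of that derivation, in the spirit of `_prepare_for_flash_attn_or_sage_varlen` but not the library's exact implementation; the function name is illustrative.

import torch
import torch.nn.functional as F

def varlen_metadata(attn_mask: torch.Tensor, seq_len_q: int):
    # attn_mask: [batch, seq_kv] boolean mask, True for valid key/value tokens.
    batch_size = attn_mask.shape[0]
    device = attn_mask.device
    seqlens_k = attn_mask.sum(dim=1).to(torch.int32)
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    # Cumulative offsets into the packed (unpadded) query and key/value tensors.
    cu_seqlens_q = F.pad(seqlens_q.cumsum(0), (1, 0)).to(torch.int32)
    cu_seqlens_k = F.pad(seqlens_k.cumsum(0), (1, 0)).to(torch.int32)
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (seq_len_q, int(seqlens_k.max()))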
return_lse=return_lse, ) @@ -1245,19 +1890,16 @@ def _sage_qk_int8_pv_fp16_triton_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", - smooth_k: bool = True, return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp16_triton( q=query, k=key, v=value, tensor_layout="NHD", - quantization_backend=quantization_backend, is_causal=is_causal, sm_scale=scale, - smooth_k=smooth_k, return_lse=return_lse, ) @@ -1275,7 +1917,12 @@ def _xformers_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if return_lse: + raise ValueError("xformers attention backend does not support setting `return_lse=True`.") + batch_size, seq_len_q, num_heads_q, _ = query.shape _, seq_len_kv, num_heads_kv, _ = key.shape diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 2388989be215..b3d74954bd26 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -65,6 +65,7 @@ populate_model_card, ) from ..utils.torch_utils import empty_device_cache +from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig from .model_loading_utils import ( _caching_allocator_warmup, _determine_device_map, @@ -248,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _skip_layerwise_casting_patterns = None _supports_group_offloading = True _repeated_blocks = [] + _parallel_config = None + _cp_plan = None def __init__(self): super().__init__() @@ -620,8 +623,8 @@ def set_attention_backend(self, backend: str) -> None: def reset_attention_backend(self) -> None: """ - Resets the attention backend for the model. Following calls to `forward` will use the environment default or - the torch native scaled dot product attention. + Resets the attention backend for the model. Following calls to `forward` will use the environment default, if + set, or the torch native scaled dot product attention. """ from .attention import AttentionModuleMixin from .attention_processor import Attention, MochiAttention @@ -960,6 +963,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization_config = kwargs.pop("quantization_config", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) + parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None) is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING if is_parallel_loading_enabled and not low_cpu_mem_usage: @@ -1340,6 +1344,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() + if parallel_config is not None: + model.enable_parallelism(config=parallel_config) + if output_loading_info: return model, loading_info @@ -1478,6 +1485,73 @@ def compile_repeated_blocks(self, *args, **kwargs): f"Regional compilation failed because {repeated_blocks} classes are not found in the model. 
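With the new `parallel_config` kwarg, loading and sharding happen in one call: `from_pretrained` forwards the config to `enable_parallelism` once the weights are in place. A hedged usage sketch meant to run under `torchrun --nproc-per-node 2`; the checkpoint id is illustrative, `ring_degree` is assumed to be a constructor argument of `ContextParallelConfig` (matching the attribute names used in `enable_parallelism`), and the backend name in the comment is illustrative.

import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxTransformer2DModel

dist.init_process_group("nccl")  # torch.distributed must be initialized before parallelism is enabled

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",            # illustrative checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=2),
)
# transformer.set_attention_backend("flash")   # pick a backend registered with
#                                              # supports_context_parallel=True; name illustrative

Since `FluxTransformer2DModel` ships a `_cp_plan` (see further down in this diff), no explicit plan is needed here.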
" ) + def enable_parallelism( + self, + *, + config: Union[ParallelConfig, ContextParallelConfig], + cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None, + ): + from ..hooks.context_parallel import apply_context_parallel + from .attention import AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + logger.warning( + "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." + ) + + if isinstance(config, ContextParallelConfig): + config = ParallelConfig(context_parallel_config=config) + + if not torch.distributed.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.") + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + device_type = torch._C._get_accelerator().type + device_module = torch.get_device_module(device_type) + device = torch.device(device_type, rank % device_module.device_count()) + + cp_mesh = None + if config.context_parallel_config is not None: + cp_config = config.context_parallel_config + if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1: + raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") + if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1: + raise ValueError( + "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + ) + if cp_config.ring_degree * cp_config.ulysses_degree > world_size: + raise ValueError( + f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})." + ) + cp_mesh = torch.distributed.device_mesh.init_device_mesh( + device_type=device_type, + mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree), + mesh_dim_names=("ring", "ulysses"), + ) + + config.setup(rank, world_size, device, cp_mesh=cp_mesh) + + if cp_plan is None and self._cp_plan is None: + raise ValueError( + "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute." 
+ ) + cp_plan = cp_plan if cp_plan is not None else self._cp_plan + + if config.context_parallel_config is not None: + apply_context_parallel(self, config.context_parallel_config, cp_plan) + + self._parallel_config = config + + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_parallel_config"): + continue + processor._parallel_config = config + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 04a9c5645c81..d54679306e64 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed( class BriaAttnProcessor: _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -161,7 +162,12 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, backend=self._attention_backend + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7ab371a1a18e..1a4464432425 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -24,6 +24,7 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st class FluxAttnProcessor: _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -114,7 +116,12 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, backend=self._attention_backend + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" _attention_backend = None + _parallel_config = None def __init__( self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None @@ -220,6 +228,7 @@ def __call__( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -252,6 +261,7 @@ def __call__( dropout_p=0.0, is_causal=False, 
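Setting `processor._parallel_config` is what lets `dispatch_attention_fn` see the mesh at call time, which is why every stock processor in this PR declares a class-level `_parallel_config = None` and forwards it. A schematic custom processor following the same pattern; it assumes a diffusers-style attention module exposing `heads`, `to_q`/`to_k`/`to_v` and `to_out`, and imports the dispatcher from its internal module path.

from diffusers.models.attention_dispatch import dispatch_attention_fn

class MyAttnProcessor:
    _attention_backend = None
    _parallel_config = None  # filled in by enable_parallelism

    def __call__(self, attn, hidden_states, attention_mask=None, **kwargs):
        # Project and reshape to the dispatcher's [batch, seq, heads, head_dim] layout.
        query = attn.to_q(hidden_states).unflatten(2, (attn.heads, -1))
        key = attn.to_k(hidden_states).unflatten(2, (attn.heads, -1))
        value = attn.to_v(hidden_states).unflatten(2, (attn.heads, -1))
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        return attn.to_out[0](hidden_states.flatten(2, 3))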
backend=self._attention_backend, + parallel_config=self._parallel_config, ) current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) @@ -556,6 +566,15 @@ class FluxTransformer2DModel( _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 79149fb76067..9f3840690d81 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -24,6 +24,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -51,6 +52,7 @@ class LTXVideoAttnProcessor: """ _attention_backend = None + _parallel_config = None def __init__(self): if is_torch_version("<", "2.0"): @@ -100,6 +102,7 @@ def __call__( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -409,6 +412,18 @@ class LTXVideoTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTXVideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 846add8906ac..05379270c13b 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -25,6 +25,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from 
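Models that do not ship a `_cp_plan`, or that need a different sharding, can pass a plan to `enable_parallelism` directly. Keys are submodule names (`""` is the model root), `ContextParallelInput(split_dim, expected_dims, split_output)` marks tensors to shard before the module runs, and `ContextParallelOutput(gather_dim, expected_dims)` marks outputs to all-gather afterwards. A hedged sketch for a hypothetical transformer whose forward takes `hidden_states` of shape [batch, seq, dim] and whose `proj_out` emits the sequence to gather; the plan classes are imported from the internal `diffusers.models._modeling_parallel` module.

from diffusers import ContextParallelConfig
from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput

# model: any already-loaded ModelMixin subclass on this rank
cp_plan = {
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
model.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2), cp_plan=cp_plan)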
..attention import AttentionMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention @@ -261,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0: """ _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -334,6 +336,7 @@ def __call__( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) # Reshape back @@ -502,6 +505,18 @@ class QwenImageTransformer2DModel( _no_split_modules = ["QwenImageTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["QwenImageTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "pos_embed": { + 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 358759164b9e..6b600aa22487 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states class SkyReelsV2AttnProcessor: _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -139,6 +140,7 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -151,6 +153,7 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 968a0369c243..25c055fb563c 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -23,6 +23,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t class WanAttnProcessor: _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -132,6 +134,7 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = 
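Note the integer keys under `pos_embed` (and `rope` for LTX/Wan): with `split_output=True` the hook shards the module's outputs rather than its inputs, which is how the rotary-embedding frequencies end up partitioned to match the sharded sequence. Read as a standalone plan fragment (reusing the QwenImage values from this diff):

from diffusers.models._modeling_parallel import ContextParallelInput

rope_output_plan = {
    "pos_embed": {
        # The rotary embedder returns two [seq, dim] tensors; output indices 0 and 1
        # are each split along dim 0 so every rank keeps only its shard's frequencies.
        0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
        1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
    },
}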
hidden_states_img.type_as(query) @@ -144,6 +147,7 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -539,6 +543,19 @@ class WanTransformer3DModel( _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] + _cp_plan = { + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True), + }, + "blocks.0": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*": { + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bbb971249604..6e7d22797902 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -648,6 +648,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ContextParallelConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ControlNetModel(metaclass=DummyObject): _backends = ["torch"] @@ -1053,6 +1068,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ParallelConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PixArtTransformer2DModel(metaclass=DummyObject): _backends = ["torch"]
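Two details worth calling out in this last chunk. First, the Wan plan splits `hidden_states` only at `blocks.0` but `encoder_hidden_states` at the `blocks.*` wildcard; reading the plan, the image tokens flow from block to block already sharded, while each block receives the full text states from the root and so must re-shard them (an inference from the plan, not stated in the diff). Second, the dummy-object registrations keep the new configs importable in a torch-free environment; only constructing them raises. A small guard sketch using the public `is_torch_available` helper:

from diffusers.utils import is_torch_available
from diffusers import ContextParallelConfig, ParallelConfig

if is_torch_available():
    # Mirrors what enable_parallelism does when handed a bare ContextParallelConfig.
    config = ParallelConfig(context_parallel_config=ContextParallelConfig(ring_degree=2))
else:
    config = None  # the dummy classes above would raise here, pointing at the missing torch backend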