diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index f6e5bdd52d1f..9042917379f0 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -151,8 +151,8 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): - from ..models.attention import BasicTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock + from ..models.transformers.transformer_2d import BasicTransformerBlock from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5164cf311d3c..056e7db26c5d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -21,10 +21,8 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU -from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 +from .attention_processor import Attention, AttentionProcessor from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX if is_xformers_available(): @@ -505,19 +503,16 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor return encoder_hidden_states -def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): - # "feed_forward_chunk_size" can be used to save memory - if hidden_states.shape[chunk_dim] % chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = hidden_states.shape[chunk_dim] // chunk_size - ff_output = torch.cat( - [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, +def _chunked_feed_forward(*args, **kwargs): + deprecate( + "_chunked_feed_forward", + "1.0.0", + "Importing `_chunked_feed_forward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.modeling_common import _chunked_feed_forward` instead.", + standard_warn=False, ) - return ff_output + from .transformers.modeling_common import _chunked_feed_forward + + return _chunked_feed_forward(*args, **kwargs) @maybe_allow_in_graph @@ -577,161 +572,16 @@ class JointTransformerBlock(nn.Module): processing of `context` conditions. """ - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - context_pre_only: bool = False, - qk_norm: Optional[str] = None, - use_dual_attention: bool = False, - ): - super().__init__() - - self.use_dual_attention = use_dual_attention - self.context_pre_only = context_pre_only - context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - - if use_dual_attention: - self.norm1 = SD35AdaLayerNormZeroX(dim) - else: - self.norm1 = AdaLayerNormZero(dim) - - if context_norm_type == "ada_norm_continous": - self.norm1_context = AdaLayerNormContinuous( - dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" - ) - elif context_norm_type == "ada_norm_zero": - self.norm1_context = AdaLayerNormZero(dim) - else: - raise ValueError( - f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" - ) - - if hasattr(F, "scaled_dot_product_attention"): - processor = JointAttnProcessor2_0() - else: - raise ValueError( - "The current PyTorch version does not support the `scaled_dot_product_attention` function." - ) - - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=context_pre_only, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=1e-6, - ) - - if use_dual_attention: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=1e-6, - ) - else: - self.attn2 = None - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - if not context_pre_only: - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - else: - self.norm2_context = None - self.ff_context = None - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - joint_attention_kwargs = joint_attention_kwargs or {} - if self.use_dual_attention: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( - hidden_states, emb=temb - ) - else: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - - if self.context_pre_only: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) - else: - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - - # Attention. - attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - **joint_attention_kwargs, + def __new__(cls, *args, **kwargs): + deprecate( + "JointTransformerBlock", + "1.0.0", + "Importing `JointTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_sd3 import SD3TransformerBlock` instead.", + standard_warn=False, ) + from .transformers.transformer_sd3 import SD3TransformerBlock - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - if self.use_dual_attention: - attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) - attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 - hidden_states = hidden_states + attn_output2 - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - - # Process attention outputs for the `encoder_hidden_states`. - if self.context_pre_only: - encoder_hidden_states = None - else: - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - context_ff_output = _chunked_feed_forward( - self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size - ) - else: - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - - return encoder_hidden_states, hidden_states + return SD3TransformerBlock(*args, **kwargs) @maybe_allow_in_graph @@ -770,300 +620,16 @@ class BasicTransformerBlock(nn.Module): The maximum number of positional embeddings to apply. """ - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' - norm_eps: float = 1e-5, - final_dropout: bool = False, - attention_type: str = "default", - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, - ada_norm_bias: Optional[int] = None, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - ): - super().__init__() - self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.dropout = dropout - self.cross_attention_dim = cross_attention_dim - self.activation_fn = activation_fn - self.attention_bias = attention_bias - self.double_self_attention = double_self_attention - self.norm_elementwise_affine = norm_elementwise_affine - self.positional_embeddings = positional_embeddings - self.num_positional_embeddings = num_positional_embeddings - self.only_cross_attention = only_cross_attention - - # We keep these boolean flags for backward-compatibility. - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - self.norm_type = norm_type - self.num_embeds_ada_norm = num_embeds_ada_norm - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if norm_type == "ada_norm": - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_zero": - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_continuous": - self.norm1 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - if norm_type == "ada_norm": - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_continuous": - self.norm2 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) # is self-attn if encoder_hidden_states is none - else: - if norm_type == "ada_norm_single": # For Latte - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - if norm_type == "ada_norm_continuous": - self.norm3 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "layer_norm", - ) - - elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: - self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - elif norm_type == "layer_norm_i2vgen": - self.norm3 = None - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, + def __new__(cls, *args, **kwargs): + deprecate( + "BasicTransformerBlock", + "1.0.0", + "Importing `BasicTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_2d import BasicTransformerBlock` instead.", + standard_warn=False, ) + from .transformers.transformer_2d import BasicTransformerBlock - # 4. Fuser - if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) - - # 5. Scale-shift for PixArt-Alpha. - if norm_type == "ada_norm_single": - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif self.norm_type == "ada_norm_single": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - # 1. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.norm_type == "ada_norm_zero": - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.norm_type == "ada_norm_single": - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. Cross-Attention - if self.attn2 is not None: - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - # i2vgen doesn't have this norm 🤷‍♂️ - if self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif not self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm3(hidden_states) - - if self.norm_type == "ada_norm_zero": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - if self.norm_type == "ada_norm_zero": - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.norm_type == "ada_norm_single": - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states + return BasicTransformerBlock(*args, **kwargs) class LuminaFeedForward(nn.Module): @@ -1081,38 +647,16 @@ class LuminaFeedForward(nn.Module): dimension. Defaults to None. """ - def __init__( - self, - dim: int, - inner_dim: int, - multiple_of: Optional[int] = 256, - ffn_dim_multiplier: Optional[float] = None, - ): - super().__init__() - # custom hidden_size factor multiplier - if ffn_dim_multiplier is not None: - inner_dim = int(ffn_dim_multiplier * inner_dim) - inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) - - self.linear_1 = nn.Linear( - dim, - inner_dim, - bias=False, + def __new__(cls, *args, **kwargs): + deprecate( + "LuminaFeedForward", + "1.0.0", + "Importing `LuminaFeedForward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_lumina2 import LuminaFeedForward` instead.", + standard_warn=False, ) - self.linear_2 = nn.Linear( - inner_dim, - dim, - bias=False, - ) - self.linear_3 = nn.Linear( - dim, - inner_dim, - bias=False, - ) - self.silu = FP32SiLU() + from .transformers.transformer_lumina2 import LuminaFeedForward - def forward(self, x): - return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) + return LuminaFeedForward(*args, **kwargs) @maybe_allow_in_graph @@ -1128,193 +672,29 @@ class TemporalBasicTransformerBlock(nn.Module): cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. """ - def __init__( - self, - dim: int, - time_mix_inner_dim: int, - num_attention_heads: int, - attention_head_dim: int, - cross_attention_dim: Optional[int] = None, - ): - super().__init__() - self.is_res = dim == time_mix_inner_dim - - self.norm_in = nn.LayerNorm(dim) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - self.ff_in = FeedForward( - dim, - dim_out=time_mix_inner_dim, - activation_fn="geglu", + def __new__(cls, *args, **kwargs): + deprecate( + "TemporalBasicTransformerBlock", + "1.0.0", + "Importing `TemporalBasicTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_temporal import TemporalBasicTransformerBlock` instead.", + standard_warn=False, ) + from .transformers.transformer_temporal import TemporalBasicTransformerBlock - self.norm1 = nn.LayerNorm(time_mix_inner_dim) - self.attn1 = Attention( - query_dim=time_mix_inner_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - cross_attention_dim=None, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = nn.LayerNorm(time_mix_inner_dim) - self.attn2 = Attention( - query_dim=time_mix_inner_dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(time_mix_inner_dim) - self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = None - - def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): - # Sets chunk feed-forward - self._chunk_size = chunk_size - # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off - self._chunk_dim = 1 - - def forward( - self, - hidden_states: torch.Tensor, - num_frames: int, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - batch_frames, seq_length, channels = hidden_states.shape - batch_size = batch_frames // num_frames - - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) - - residual = hidden_states - hidden_states = self.norm_in(hidden_states) - - if self._chunk_size is not None: - hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) - else: - hidden_states = self.ff_in(hidden_states) - - if self.is_res: - hidden_states = hidden_states + residual - - norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) - hidden_states = attn_output + hidden_states - - # 3. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = self.norm2(hidden_states) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - if self.is_res: - hidden_states = ff_output + hidden_states - else: - hidden_states = ff_output - - hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) - - return hidden_states + return TemporalBasicTransformerBlock(*args, **kwargs) class SkipFFTransformerBlock(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - kv_input_dim: int, - kv_input_dim_proj_use_bias: bool, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - attention_out_bias: bool = True, - ): - super().__init__() - if kv_input_dim != dim: - self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) - else: - self.kv_mapper = None - - self.norm1 = RMSNorm(dim, 1e-06) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim, - out_bias=attention_out_bias, + def __new__(cls, *args, **kwargs): + deprecate( + "SkipFFTransformerBlock", + "1.0.0", + "Importing `SkipFFTransformerBlock` from `diffusers.models.attention` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.uvit_2d import SkipFFTransformerBlock` instead.", + standard_warn=False, ) + from .unets.uvit_2d import SkipFFTransformerBlock - self.norm2 = RMSNorm(dim, 1e-06) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - out_bias=attention_out_bias, - ) - - def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - - if self.kv_mapper is not None: - encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) - - norm_hidden_states = self.norm1(hidden_states) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - - hidden_states = attn_output + hidden_states - - norm_hidden_states = self.norm2(hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - - hidden_states = attn_output + hidden_states - - return hidden_states + return SkipFFTransformerBlock(*args, **kwargs) @maybe_allow_in_graph @@ -1679,50 +1059,13 @@ class FeedForward(nn.Module): bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim=None, - bias: bool = True, - ): - super().__init__() - if inner_dim is None: - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim, bias=bias) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim, bias=bias) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, bias=bias) - elif activation_fn == "swiglu": - act_fn = SwiGLU(dim, inner_dim, bias=bias) - elif activation_fn == "linear-silu": - act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states + def __new__(cls, *args, **kwargs): + deprecate( + "FeedForward", + "1.0.0", + "Importing `FeedForward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.modeling_common import FeedForward` instead.", + standard_warn=False, + ) + from .transformers.modeling_common import FeedForward + + return FeedForward(*args, **kwargs) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 063ff5bd8e2d..91a783c445dd 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -24,8 +24,8 @@ from ..attention_processor import AttentionProcessor from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from ..transformers.modeling_common import Transformer2DModelOutput from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 7c4955eb5828..2b6c22d65c97 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -24,8 +24,8 @@ from ..attention_processor import AttentionProcessor from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from ..transformers.modeling_common import Transformer2DModelOutput from ..transformers.transformer_qwenimage import ( QwenEmbedRope, QwenImageTransformerBlock, diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py index ed521adbedda..2c2ce0b4f82a 100644 --- a/src/diffusers/models/controlnets/controlnet_sana.py +++ b/src/diffusers/models/controlnets/controlnet_sana.py @@ -23,9 +23,9 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import AttentionProcessor from ..embeddings import PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm +from ..transformers.modeling_common import Transformer2DModelOutput from ..transformers.sana_transformer import SanaTransformerBlock from .controlnet import zero_module diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 0641c8bc0114..098c6069b392 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -22,12 +22,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import JointTransformerBlock from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..transformers.transformer_sd3 import SD3SingleTransformerBlock +from ..transformers.modeling_common import Transformer2DModelOutput +from ..transformers.transformer_sd3 import SD3SingleTransformerBlock, SD3TransformerBlock from .controlnet import BaseOutput, zero_module @@ -132,7 +131,7 @@ def __init__( # It needs to crafted when we get the actual checkpoints. self.transformer_blocks = nn.ModuleList( [ - JointTransformerBlock( + SD3TransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b51f5d7aec25..86c7e1ae4169 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1530,7 +1530,7 @@ def forward(self, image_embeds: torch.Tensor): class IPAdapterFullImageProjection(nn.Module): def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): super().__init__() - from .attention import FeedForward + from .transformers.modeling_common import FeedForward self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") self.norm = nn.LayerNorm(cross_attention_dim) @@ -1542,7 +1542,7 @@ def forward(self, image_embeds: torch.Tensor): class IPAdapterFaceIDImageProjection(nn.Module): def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() - from .attention import FeedForward + from .transformers.modeling_common import FeedForward self.num_tokens = num_tokens self.cross_attention_dim = cross_attention_dim @@ -2219,7 +2219,7 @@ def __init__( ffn_ratio: float = 4, ) -> None: super().__init__() - from .attention import FeedForward + from .transformers.modeling_common import FeedForward self.ln0 = nn.LayerNorm(embed_dims) self.ln1 = nn.LayerNorm(embed_dims) @@ -2334,7 +2334,7 @@ def __init__( ffproj_ratio: int = 2, ) -> None: super().__init__() - from .attention import FeedForward + from .transformers.modeling_common import FeedForward self.num_tokens = num_tokens self.embed_dim = embed_dims @@ -2404,7 +2404,7 @@ def __init__( ffn_ratio: int = 4, ) -> None: super().__init__() - from .attention import FeedForward + from .transformers.modeling_common import FeedForward self.ln0 = nn.LayerNorm(hidden_dim) self.ln1 = nn.LayerNorm(hidden_dim) diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py index 0120a34d9052..541e3ccc4f5d 100644 --- a/src/diffusers/models/modeling_outputs.py +++ b/src/diffusers/models/modeling_outputs.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate @dataclass @@ -17,8 +17,7 @@ class AutoencoderKLOutput(BaseOutput): latent_dist: "DiagonalGaussianDistribution" # noqa: F821 -@dataclass -class Transformer2DModelOutput(BaseOutput): +class Transformer2DModelOutput: """ The output of [`Transformer2DModel`]. @@ -28,4 +27,13 @@ class Transformer2DModelOutput(BaseOutput): distributions for the unnoised latent pixels. """ - sample: "torch.Tensor" # noqa: F821 + def __new__(cls, *args, **kwargs): + deprecate( + "Transformer2DModelOutput", + "1.0.0", + "Importing `Transformer2DModelOutput` from `diffusers.models.modeling_outputs` is deprecated. Please use `from diffusers.models.transformers.modeling_common import Transformer2DModelOutput` instead.", + standard_warn=False, + ) + from .transformers.modeling_common import Transformer2DModelOutput + + return Transformer2DModelOutput(*args, **kwargs) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index bf6d9e1b3803..a693b8e10931 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -30,9 +30,9 @@ FusedAuraFlowAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm +from .modeling_common import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -194,7 +194,8 @@ def forward( @maybe_allow_in_graph -class AuraFlowJointTransformerBlock(nn.Module): +# Copied from diffusers.models.transformers.transformer_sd3.SD3TransformerBlock with SD3->AuraFlow +class AuraFlowTransformerBlock(nn.Module): r""" Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): @@ -337,7 +338,7 @@ def __init__( self.joint_transformer_blocks = nn.ModuleList( [ - AuraFlowJointTransformerBlock( + AuraFlowTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 9e0afdee6615..0de1f6ed1c7b 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,13 +22,13 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward +from ..attention import Attention from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 91fe811f0013..5b607623242b 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -22,12 +22,12 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward +from ..attention import Attention from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 68f6f769436e..950ce58a9e6c 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -19,15 +19,348 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import BasicTransformerBlock -from ..embeddings import PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput +from ..attention import Attention, GatedSelfAttentionDense +from ..embeddings import PatchEmbed, SinusoidalPositionalEmbedding from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero +from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock +class DiTTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + class DiTTransformer2DModel(ModelMixin, ConfigMixin): r""" A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748). @@ -121,7 +454,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + DiTTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py index 24eed2168229..eedbb4402a5f 100644 --- a/src/diffusers/models/transformers/dual_transformer_2d.py +++ b/src/diffusers/models/transformers/dual_transformer_2d.py @@ -15,7 +15,7 @@ from torch import nn -from ..modeling_outputs import Transformer2DModelOutput +from .modeling_common import Transformer2DModelOutput from .transformer_2d import Transformer2DModel diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index fbe9fe8df91c..3dce2d50efba 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -19,16 +19,15 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, PatchEmbed, PixArtAlphaTextProjection, ) -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, FP32LayerNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 990c90512e39..d646d623d558 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -12,18 +12,359 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ..attention import BasicTransformerBlock +from ...utils import logging +from ..attention import Attention, GatedSelfAttentionDense from ..cache_utils import CacheMixin -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid -from ..modeling_outputs import Transformer2DModelOutput +from ..embeddings import ( + PatchEmbed, + PixArtAlphaTextProjection, + SinusoidalPositionalEmbedding, + get_1d_sincos_pos_embed_from_grid, +) from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero +from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock +class LatteTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): @@ -110,7 +451,7 @@ def __init__( # 2. Define spatial transformers blocks self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + LatteTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, @@ -130,7 +471,7 @@ def __init__( # 3. Define temporal transformers blocks self.temporal_transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + LatteTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index bed5e69c2d36..23c13ad74b50 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -19,15 +19,15 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import LuminaFeedForward from ..attention_processor import Attention, LuminaAttnProcessor2_0 from ..embeddings import ( LuminaCombinedTimestepCaptionEmbedding, LuminaPatchEmbed, ) -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm +from .modeling_common import Transformer2DModelOutput +from .transformer_lumina2 import LuminaFeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/modeling_common.py b/src/diffusers/models/transformers/modeling_common.py new file mode 100644 index 000000000000..86b3dbb92736 --- /dev/null +++ b/src/diffusers/models/transformers/modeling_common.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from ...utils import BaseOutput, deprecate +from ..activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: "torch.Tensor" # noqa: F821 + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 5a22144228ae..e792eee53fb6 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -18,17 +18,349 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import BasicTransformerBlock +from ..attention import GatedSelfAttentionDense from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero +from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock +class PixArtTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + class PixArtTransformer2DModel(ModelMixin, ConfigMixin): r""" A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426, @@ -151,7 +483,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + PixArtTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py index 565da0da8b6e..f13913e4913a 100644 --- a/src/diffusers/models/transformers/prior_transformer.py +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn.functional as F @@ -7,8 +7,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...utils import BaseOutput -from ..attention import BasicTransformerBlock +from ...utils import BaseOutput, logging +from ..attention import Attention, GatedSelfAttentionDense from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -16,8 +16,345 @@ AttnAddedKVProcessor, AttnProcessor, ) -from ..embeddings import TimestepEmbedding, Timesteps +from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero +from .modeling_common import FeedForward, _chunked_feed_forward + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock +class PriorTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states @dataclass @@ -133,7 +470,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock( + PriorTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 1e02ac32e86c..50dcde2b78ae 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -27,9 +27,9 @@ SanaLinearAttnProcessor2_0, ) from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm +from .modeling_common import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index 969e6db122d9..68caff835ce4 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -23,10 +23,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0 from ..modeling_utils import ModelMixin from ..transformers.transformer_2d import Transformer2DModelOutput +from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 67fe9a33109b..7f04cc61a703 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,16 +19,352 @@ from ...configuration_utils import LegacyConfigMixin, register_to_config from ...utils import deprecate, logging -from ..attention import BasicTransformerBlock -from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput +from ..attention import Attention, GatedSelfAttentionDense +from ..embeddings import ( + ImagePositionalEmbeddings, + PatchEmbed, + PixArtAlphaTextProjection, + SinusoidalPositionalEmbedding, +) from ..modeling_utils import LegacyModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero +from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + class Transformer2DModelOutput(Transformer2DModelOutput): def __init__(self, *args, **kwargs): deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead." diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 5fa59a71d977..de871629a175 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -22,13 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index d54679306e64..6f64f2562cff 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -10,13 +10,13 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionModuleMixin, FeedForward +from ..attention import AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 5823ae9d3da6..48499346ce24 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -24,12 +24,12 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput from .transformer_flux import FluxAttention, FluxAttnProcessor diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 7356f4a606bb..b65c0be98df6 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -20,12 +20,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 64e9a538a7c2..8a51ca03fd48 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -22,13 +22,12 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LayerNorm, RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 373b470ae37b..5dab3d495028 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -22,12 +22,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin from ...utils import is_torchvision_available -from ..attention import FeedForward from ..attention_processor import Attention from ..embeddings import Timesteps -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput if is_torchvision_available(): diff --git a/src/diffusers/models/transformers/transformer_easyanimate.py b/src/diffusers/models/transformers/transformer_easyanimate.py index 545fa29730db..aad94c843939 100755 --- a/src/diffusers/models/transformers/transformer_easyanimate.py +++ b/src/diffusers/models/transformers/transformer_easyanimate.py @@ -22,11 +22,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward +from ..attention import Attention from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1a4464432425..0bb076926590 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -25,7 +25,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( @@ -34,9 +34,9 @@ apply_rotary_emb, get_1d_rotary_pos_embed, ) -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 4a5aee29abc4..91f9c9302f61 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -6,8 +6,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin +from ...models.transformers.modeling_common import Transformer2DModelOutput from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index bc857ccab463..8c06026ccff8 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -23,7 +23,6 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor from ..cache_utils import CacheMixin from ..embeddings import ( @@ -33,9 +32,9 @@ Timesteps, get_1d_rotary_pos_embed, ) -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 60b40fff3cb8..87ef80480d5c 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -23,9 +23,9 @@ from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers from ..cache_utils import CacheMixin from ..embeddings import get_1d_rotary_pos_embed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous +from .modeling_common import Transformer2DModelOutput from .transformer_hunyuan_video import ( HunyuanVideoConditionEmbedding, HunyuanVideoPatchEmbed, diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 685c73c07c75..45a2108725ee 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -25,13 +25,13 @@ from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 77121edb9fc9..c4be82fbec61 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -23,17 +23,66 @@ from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import LuminaFeedForward +from ..activations import FP32SiLU from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm +from .modeling_common import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.silu = FP32SiLU() + + def forward(self, x): + return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) + + class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 63911fe7c10d..af2f51062b1f 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -23,13 +23,12 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 6939cac0a3a7..ca4411ad3cd9 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -23,9 +23,9 @@ from ...utils import logging from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, RMSNorm +from .modeling_common import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 05379270c13b..6cdbd63a9550 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -26,14 +26,14 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, RMSNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 762d89c303d7..0890582c75b2 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -20,7 +20,6 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward, JointTransformerBlock from ..attention_processor import ( Attention, AttentionProcessor, @@ -28,14 +27,184 @@ JointAttnProcessor2_0, ) from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX +from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class SD3TransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://huggingface.co/papers/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, + ): + super().__init__() + + self.use_dual_attention = use_dual_attention + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + processor = JointAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=context_pre_only, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + + if use_dual_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + else: + self.attn2 = None + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + self.ff_context = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + **joint_attention_kwargs, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + @maybe_allow_in_graph class SD3SingleTransformerBlock(nn.Module): def __init__( @@ -155,7 +324,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - JointTransformerBlock( + SD3TransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 6b600aa22487..53ffda6a6065 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -23,7 +23,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( @@ -32,9 +32,9 @@ get_1d_rotary_pos_embed, get_1d_sincos_pos_embed_from_grid, ) -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin, get_parameter_dtype from ..normalization import FP32LayerNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index ffaf31d04570..ef3f13215a2f 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -19,10 +19,13 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput -from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..resnet import AlphaBlender +from .modeling_common import FeedForward, _chunked_feed_forward +from .transformer_2d import BasicTransformerBlock @dataclass @@ -38,6 +41,136 @@ class TransformerTemporalModelOutput(BaseOutput): sample: torch.Tensor +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: int, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + class TransformerTemporalModel(ModelMixin, ConfigMixin): """ A Transformer model for video-like data. diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index dd75fb124f1a..aa121cf025d6 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -24,13 +24,13 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from .modeling_common import FeedForward, Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 30c38c244ad8..560027e4a193 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -21,11 +21,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin from ..cache_utils import CacheMixin -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from .modeling_common import FeedForward, Transformer2DModelOutput from .transformer_wan import ( WanAttention, WanAttnProcessor, diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index 7148723a84d8..5fb79073f06a 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -22,7 +22,7 @@ from ...loaders import UNet2DConditionLoadersMixin from ...utils import logging from ..activations import get_activation -from ..attention import Attention, FeedForward +from ..attention import Attention from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -33,6 +33,7 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin +from ..transformers.modeling_common import FeedForward from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( UNetMidBlock3DCrossAttn, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 26616e53bdfd..884fd782b7a3 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -24,7 +24,6 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import BaseOutput, deprecate, logging from ...utils.torch_utils import apply_freeu -from ..attention import BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -41,7 +40,7 @@ from ..modeling_utils import ModelMixin from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D from ..transformers.dual_transformer_2d import DualTransformer2DModel -from ..transformers.transformer_2d import Transformer2DModel +from ..transformers.transformer_2d import BasicTransformerBlock, Transformer2DModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py index 94b39c84f055..1f42f607d482 100644 --- a/src/diffusers/models/unets/uvit_2d.py +++ b/src/diffusers/models/unets/uvit_2d.py @@ -22,10 +22,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ..attention import BasicTransformerBlock, SkipFFTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, @@ -34,6 +34,79 @@ from ..modeling_utils import ModelMixin from ..normalization import GlobalResponseNorm, RMSNorm from ..resnet import Downsample2D, Upsample2D +from ..transformers.transformer_2d import BasicTransformerBlock + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: int = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e513f..e69de29bb2d1 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -1,147 +0,0 @@ -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import unittest - -import torch -from transformers import AutoTokenizer, T5EncoderModel - -from diffusers import ( - AutoencoderKLLTXVideo, - FlowMatchEulerDiscreteScheduler, - LTXPipeline, - LTXVideoTransformer3DModel, -) - -from ..testing_utils import floats_tensor, require_peft_backend - - -sys.path.append(".") - -from .utils import PeftLoraLoaderMixinTests # noqa: E402 - - -@require_peft_backend -class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): - pipeline_class = LTXPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler - scheduler_kwargs = {} - - transformer_kwargs = { - "in_channels": 8, - "out_channels": 8, - "patch_size": 1, - "patch_size_t": 1, - "num_attention_heads": 4, - "attention_head_dim": 8, - "cross_attention_dim": 32, - "num_layers": 1, - "caption_channels": 32, - } - transformer_cls = LTXVideoTransformer3DModel - vae_kwargs = { - "in_channels": 3, - "out_channels": 3, - "latent_channels": 8, - "block_out_channels": (8, 8, 8, 8), - "decoder_block_out_channels": (8, 8, 8, 8), - "layers_per_block": (1, 1, 1, 1, 1), - "decoder_layers_per_block": (1, 1, 1, 1, 1), - "spatio_temporal_scaling": (True, True, False, False), - "decoder_spatio_temporal_scaling": (True, True, False, False), - "decoder_inject_noise": (False, False, False, False, False), - "upsample_residual": (False, False, False, False), - "upsample_factor": (1, 1, 1, 1), - "timestep_conditioning": False, - "patch_size": 1, - "patch_size_t": 1, - "encoder_causal": True, - "decoder_causal": False, - } - vae_cls = AutoencoderKLLTXVideo - tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" - - text_encoder_target_modules = ["q", "k", "v", "o"] - - @property - def output_shape(self): - return (1, 9, 32, 32, 3) - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 16 - num_channels = 8 - num_frames = 9 - num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 - latent_height = 8 - latent_width = 8 - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width)) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "dance monkey", - "num_frames": num_frames, - "num_inference_steps": 4, - "guidance_scale": 6.0, - "height": 32, - "width": 32, - "max_sequence_length": sequence_length, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - @unittest.skip("Not supported in LTXVideo.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in LTXVideo.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - - @unittest.skip("Not supported in LTXVideo.") - def test_modify_padding_mode(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass