src/diffusers/hooks/_helpers.py (2 changes: 1 addition & 1 deletion)

@@ -151,8 +151,8 @@ def _register_attention_processors_metadata():
 
 
 def _register_transformer_blocks_metadata():
-    from ..models.attention import BasicTransformerBlock
     from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+    from ..models.transformers.transformer_2d import BasicTransformerBlock
     from ..models.transformers.transformer_bria import BriaTransformerBlock
     from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
     from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
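Illustrative sketch, not part of the diff: under this PR's layout, BasicTransformerBlock is assumed to be importable from diffusers.models.transformers.transformer_2d rather than diffusers.models.attention; the constructor arguments below are illustrative only.

import torch
from diffusers.models.transformers.transformer_2d import BasicTransformerBlock  # assumed new location

# Illustrative sizes (8 heads * 40 head_dim = 320 = dim); not taken from the PR.
block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40)
out = block(torch.randn(2, 64, 320))  # self-attention only; output keeps the input shape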
src/diffusers/models/attention.py (777 changes: 60 additions & 717 deletions)

Large diffs are not rendered by default.

src/diffusers/models/controlnets/controlnet_flux.py (2 changes: 1 addition & 1 deletion)

@@ -24,8 +24,8 @@
 from ..attention_processor import AttentionProcessor
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.modeling_common import Transformer2DModelOutput
 from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
 
 
src/diffusers/models/controlnets/controlnet_qwenimage.py (2 changes: 1 addition & 1 deletion)

@@ -24,8 +24,8 @@
 from ..attention_processor import AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..controlnets.controlnet import zero_module
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.modeling_common import Transformer2DModelOutput
 from ..transformers.transformer_qwenimage import (
     QwenEmbedRope,
     QwenImageTransformerBlock,
src/diffusers/models/controlnets/controlnet_sana.py (2 changes: 1 addition & 1 deletion)

@@ -23,9 +23,9 @@
 from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import AttentionProcessor
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..transformers.modeling_common import Transformer2DModelOutput
 from ..transformers.sana_transformer import SanaTransformerBlock
 from .controlnet import zero_module
 
src/diffusers/models/controlnets/controlnet_sd3.py (7 changes: 3 additions & 4 deletions)

@@ -22,12 +22,11 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import JointTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
+from ..transformers.modeling_common import Transformer2DModelOutput
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock, SD3TransformerBlock
 from .controlnet import BaseOutput, zero_module
 
 
@@ -132,7 +131,7 @@ def __init__(
         # It needs to crafted when we get the actual checkpoints.
         self.transformer_blocks = nn.ModuleList(
             [
-                JointTransformerBlock(
+                SD3TransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
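Illustrative sketch, not part of the diff: the SD3 MMDiT block is assumed to be constructed as SD3TransformerBlock (renamed from JointTransformerBlock) with the arguments visible in the hunk above; everything else is left at assumed defaults.

from diffusers.models.transformers.transformer_sd3 import SD3TransformerBlock  # assumed renamed class

# 24 heads * 64 head_dim = 1536 = dim; remaining constructor options left at defaults (assumption).
block = SD3TransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)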
src/diffusers/models/embeddings.py (10 changes: 5 additions & 5 deletions)

@@ -1530,7 +1530,7 @@ def forward(self, image_embeds: torch.Tensor):
 class IPAdapterFullImageProjection(nn.Module):
     def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
         super().__init__()
-        from .attention import FeedForward
+        from .transformers.modeling_common import FeedForward
 
         self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)
@@ -1542,7 +1542,7 @@ def forward(self, image_embeds: torch.Tensor):
 class IPAdapterFaceIDImageProjection(nn.Module):
     def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
         super().__init__()
-        from .attention import FeedForward
+        from .transformers.modeling_common import FeedForward
 
         self.num_tokens = num_tokens
         self.cross_attention_dim = cross_attention_dim
@@ -2219,7 +2219,7 @@ def __init__(
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward
+        from .transformers.modeling_common import FeedForward
 
         self.ln0 = nn.LayerNorm(embed_dims)
         self.ln1 = nn.LayerNorm(embed_dims)
@@ -2334,7 +2334,7 @@ def __init__(
         ffproj_ratio: int = 2,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward
+        from .transformers.modeling_common import FeedForward
 
         self.num_tokens = num_tokens
         self.embed_dim = embed_dims
@@ -2404,7 +2404,7 @@ def __init__(
         ffn_ratio: int = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward
+        from .transformers.modeling_common import FeedForward
 
         self.ln0 = nn.LayerNorm(hidden_dim)
         self.ln1 = nn.LayerNorm(hidden_dim)
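Illustrative sketch, not part of the diff: each hunk above only swaps the local FeedForward import, which is assumed to now live in diffusers.models.transformers.modeling_common.

import torch
from diffusers.models.transformers.modeling_common import FeedForward  # assumed new location

# Same arguments IPAdapterFullImageProjection passes above: dim, dim_out, mult=1, gelu activation.
ff = FeedForward(1024, 1024, mult=1, activation_fn="gelu")
out = ff(torch.randn(1, 77, 1024))  # projected to dim_out=1024, sequence length preserved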
src/diffusers/models/modeling_outputs.py (16 changes: 12 additions & 4 deletions)

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 
 
 @dataclass
@@ -17,8 +17,7 @@ class AutoencoderKLOutput(BaseOutput):
     latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
 
 
-@dataclass
-class Transformer2DModelOutput(BaseOutput):
+class Transformer2DModelOutput:
     """
     The output of [`Transformer2DModel`].
 
@@ -28,4 +27,13 @@ class Transformer2DModelOutput(BaseOutput):
             distributions for the unnoised latent pixels.
     """
 
-    sample: "torch.Tensor"  # noqa: F821
+    def __new__(cls, *args, **kwargs):
+        deprecate(
+            "Transformer2DModelOutput",
+            "1.0.0",
+            "Importing `Transformer2DModelOutput` from `diffusers.models.modeling_outputs` is deprecated. Please use `from diffusers.models.transformers.modeling_common import Transformer2DModelOutput` instead.",
+            standard_warn=False,
+        )
+        from .transformers.modeling_common import Transformer2DModelOutput
+
+        return Transformer2DModelOutput(*args, **kwargs)
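Illustrative sketch, not part of the diff: with this shim in place, the old import path is expected to keep working, emit a deprecation warning through __new__, and hand back an instance of the relocated class.

import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput as LegacyOutput  # deprecated path
from diffusers.models.transformers.modeling_common import Transformer2DModelOutput  # assumed new canonical path

out = LegacyOutput(sample=torch.randn(1, 4, 64, 64))  # triggers the deprecation warning
assert isinstance(out, Transformer2DModelOutput)  # __new__ returns the relocated class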
src/diffusers/models/transformers/auraflow_transformer_2d.py (7 changes: 4 additions & 3 deletions)

@@ -30,9 +30,9 @@
     FusedAuraFlowAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormZero, FP32LayerNorm
+from .modeling_common import Transformer2DModelOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -194,7 +194,8 @@ def forward(
 
 
 @maybe_allow_in_graph
-class AuraFlowJointTransformerBlock(nn.Module):
+# Copied from diffusers.models.transformers.transformer_sd3.SD3TransformerBlock with SD3->AuraFlow
+class AuraFlowTransformerBlock(nn.Module):
     r"""
     Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
 
@@ -337,7 +338,7 @@ def __init__(
 
         self.joint_transformer_blocks = nn.ModuleList(
             [
-                AuraFlowJointTransformerBlock(
+                AuraFlowTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
src/diffusers/models/transformers/cogvideox_transformer_3d.py (4 changes: 2 additions & 2 deletions)

@@ -22,13 +22,13 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
+from ..attention import Attention
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from .modeling_common import FeedForward, Transformer2DModelOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
src/diffusers/models/transformers/consisid_transformer_3d.py (4 changes: 2 additions & 2 deletions)

@@ -22,12 +22,12 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
+from ..attention import Attention
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from .modeling_common import FeedForward, Transformer2DModelOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name