From efc7ed990a43b79ba9e99f64ae56768057a2ccbb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 3 Aug 2024 01:26:55 +0200 Subject: [PATCH 01/14] edit --- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/models/transformers/transformer_flux.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 855085c0d933..cbe2a3599d68 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1290,7 +1290,7 @@ class FluxSingleAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError("FluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 73ccc03b38c4..ecc72a492c77 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -356,7 +356,7 @@ def forward( ) hidden_states = self.x_embedder(hidden_states) - timestep = timestep.to(hidden_states.dtype) * 1000 + timestep = timestep.to(hidden_states.dtype) if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 4378f97ffd68..73c6d4ce7f05 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -705,8 +705,7 @@ def __call__( noise_pred = self.transformer( hidden_states=latents, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timestep / 1000, + timestep=timestep, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, From 9b8f8c79f87942e39a055afe3f038d9cdd137be3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 4 Aug 2024 00:50:12 +0200 Subject: [PATCH 02/14] refactor rotary embeds --- src/diffusers/models/attention_processor.py | 31 +++++--------- src/diffusers/models/embeddings.py | 20 +++++++++ .../models/transformers/transformer_flux.py | 41 ++----------------- 3 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cbe2a3599d68..f74f3e888829 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1274,15 +1274,6 @@ def __call__( return hidden_states -# YiYi to-do: refactor rope related functions/classes -def apply_rope(xq, xk, freqs_cis): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - class FluxSingleAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -1290,7 +1281,9 @@ class FluxSingleAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("FluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "FluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -1330,11 +1323,10 @@ def __call__( # Apply RoPE if needed if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 @@ -1418,11 +1410,10 @@ def __call__( value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2821ce0330fc..b4979deebdb4 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -453,6 +453,26 @@ def apply_rotary_emb( return x_out.type_as(x) +class FluxPosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.squeeze().float().cpu().numpy() + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed(self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + class TimestepEmbedding(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index ecc72a492c77..2c27d9e078b2 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn @@ -27,48 +27,13 @@ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# YiYi to-do: refactor rope related functions/classes -def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: - assert dim % 2 == 0, "The dimension must be even." - - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim - omega = 1.0 / (theta**scale) - - batch_size, seq_length = pos.shape - out = torch.einsum("...n,d->...nd", pos, omega) - cos_out = torch.cos(out) - sin_out = torch.sin(out) - - stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) - out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) - return out.float() - - -# YiYi to-do: refactor rope related functions/classes -class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: List[int]): - super().__init__() - self.dim = dim - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - emb = torch.cat( - [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], - dim=-3, - ) - - return emb.unsqueeze(1) - - @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): r""" @@ -264,7 +229,7 @@ def __init__( self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56]) + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56]) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) From 1887bda004fae0dbd199ef768e5f634904bfb101 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sat, 3 Aug 2024 12:54:12 -1000 Subject: [PATCH 03/14] Update src/diffusers/models/transformers/transformer_flux.py --- src/diffusers/models/transformers/transformer_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3523f8afb430..099d3ad16568 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -1,5 +1,4 @@ # Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved. - # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From a9cdfccc692d0594ae2ac54334d612c3bf456139 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 5 Aug 2024 21:55:36 +0200 Subject: [PATCH 04/14] fix --- src/diffusers/models/transformers/transformer_flux.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 099d3ad16568..0a0be0999a82 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -225,13 +225,13 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], + axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56]) + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings From de66c587b122d9cda9de5d0509305f2e457b8c4a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 17 Aug 2024 04:54:05 +0200 Subject: [PATCH 05/14] remove the batch dimension in ids --- src/diffusers/models/transformers/transformer_flux.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux.py | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 0a0be0999a82..b3b0992735c8 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -340,7 +340,7 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - ids = torch.cat((txt_ids, img_ids), dim=1) + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index c1da15898d00..b121b3e2f6a9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -343,10 +343,6 @@ def encode_prompt( scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt @@ -376,8 +372,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -437,9 +432,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) From abad85436cc97c2eacacd76986c9b48a98222a63 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 17 Aug 2024 10:12:33 +0200 Subject: [PATCH 06/14] keep transformer timesteps input same --- src/diffusers/models/transformers/transformer_flux.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index b3b0992735c8..31f35956d7db 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -328,7 +328,7 @@ def forward( ) hidden_states = self.x_embedder(hidden_states) - timestep = timestep.to(hidden_states.dtype) + timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index b121b3e2f6a9..31cbf063f29f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -701,7 +701,7 @@ def __call__( noise_pred = self.transformer( hidden_states=latents, - timestep=timestep, + timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, From 463b9101a5f2e9a3b890a9369f8b95822ab06d79 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 02:57:03 +0200 Subject: [PATCH 07/14] add freqs_dtype, allow torch.float64 and make adjustment for mps device --- src/diffusers/models/embeddings.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0533acae9f75..9bcc0fcd4074 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -359,6 +359,7 @@ def get_1d_rotary_pos_embed( linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -381,6 +382,8 @@ def get_1d_rotary_pos_embed( repeat_interleave_real (`bool`, *optional*, defaults to `True`): If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ @@ -389,19 +392,19 @@ def get_1d_rotary_pos_embed( if isinstance(pos, int): pos = np.arange(pos) theta = theta * ntk_factor - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: - freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] return freqs_cis @@ -464,8 +467,12 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.squeeze().float().cpu().numpy() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): - cos, sin = get_1d_rotary_pos_embed(self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True) + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype + ) cos_out.append(cos) sin_out.append(sin) freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) From f23cb1b7942b79863af85663b99efe72249fe9b2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 03:58:25 +0200 Subject: [PATCH 08/14] deprecate flux single attn processor --- src/diffusers/models/attention_processor.py | 163 ++++++------------ .../models/transformers/transformer_flux.py | 4 +- 2 files changed, 51 insertions(+), 116 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f74f3e888829..0eee95907a39 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1274,73 +1274,6 @@ def __call__( return hidden_states -class FluxSingleAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - return hidden_states - - class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" @@ -1356,16 +1289,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape # `sample` projections. query = attn.to_q(hidden_states) @@ -1384,30 +1308,32 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from .embeddings import apply_rotary_emb @@ -1419,23 +1345,21 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - return hidden_states, encoder_hidden_states + return hidden_states, encoder_hidden_states + else: + return hidden_states class XFormersAttnAddedKVProcessor: @@ -3418,6 +3342,17 @@ def __init__(self): pass +class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." + deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message) + super().__init__() + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 31f35956d7db..1286c45308f8 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...models.attention import FeedForward -from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0 +from ...models.attention_processor import Attention, FluxAttnProcessor2_0 from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers @@ -58,7 +58,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - processor = FluxSingleAttnProcessor2_0() + processor = FluxAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, From ab3a550ef8f8405ccca28ccf6d5675c95e7e68e5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 04:30:10 +0200 Subject: [PATCH 09/14] deprecate 2d ids inputs to flux transformer --- .../models/transformers/transformer_flux.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1286c45308f8..10fff043343f 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -340,6 +340,18 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if txt_ids.ndim == 2: + logger.warning( + "Passing `txt_ids` 2d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[:, 0] + if img_ids.ndim == 2: + logger.warning( + "Passing `img_ids` 2d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[:, 0] ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) From 4161d93a41aa96ae9cca202f5b8fd239bba59087 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 04:46:57 +0200 Subject: [PATCH 10/14] use FluxPosEmbed in flux controlnet too --- src/diffusers/models/controlnet_flux.py | 6 +++--- .../models/transformers/transformer_flux.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index ba4933dcad67..aecff4cb104f 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -24,9 +24,9 @@ from ..models.modeling_utils import ModelMixin from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from .controlnet import BaseOutput, zero_module -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings +from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock +from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -59,7 +59,7 @@ def __init__( self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim - self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope) + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 223a4cd49a1f..3f28f7d134ec 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -348,18 +348,18 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - if txt_ids.ndim == 2: + if txt_ids.ndim == 3: logger.warning( - "Passing `txt_ids` 2d torch.Tensor is deprecated." + "Passing `txt_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) - txt_ids = txt_ids[:, 0] - if img_ids.ndim == 2: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: logger.warning( - "Passing `img_ids` 2d torch.Tensor is deprecated." + "Passing `img_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) - img_ids = img_ids[:, 0] + img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) From 89e0cccaa52f79aa411d5c9997a1736176fdc85e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 05:12:56 +0200 Subject: [PATCH 11/14] apply same change to controlnet --- src/diffusers/models/controlnet_flux.py | 3 +-- .../pipelines/flux/pipeline_flux_controlnet.py | 11 ++--------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index aecff4cb104f..e51f528caf31 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -272,8 +272,7 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - txt_ids = txt_ids.expand(img_ids.size(0), -1, -1) - ids = torch.cat((txt_ids, img_ids), dim=1) + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) block_samples = () diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 84450374cb30..b9e93e720baf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -354,10 +354,6 @@ def encode_prompt( scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt @@ -387,8 +383,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -449,9 +444,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) @@ -804,7 +798,6 @@ def __call__( noise_pred = self.transformer( hidden_states=latents, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, From 0ff226647062102715f52e869f73b9c04ab27b35 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 07:59:54 +0200 Subject: [PATCH 12/14] add a test for deprecated flux tranformers inputs: txt and img ids as 3d tensors --- tests/models/test_modeling_common.py | 1 - .../test_models_transformer_flux.py | 32 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 64722e2d9797..0ce01fb93f40 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -976,7 +976,6 @@ def test_sharded_checkpoints_device_map(self): self.assertTrue(actual_num_shards == expected_num_shards) new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto") - new_model = new_model.to(torch_device) torch.manual_seed(0) if "generator" in inputs_dict: diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index bda37621c27d..538d158cbcb9 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -44,8 +44,8 @@ def dummy_input(self): hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) - text_ids = torch.randn((batch_size, sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((batch_size, height * width, num_image_channels)).to(torch_device) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) + image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) return { @@ -80,3 +80,31 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + + def test_deprecated_inputs_img_txt_ids_3d(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output_1 = model(**inputs_dict).to_tuple()[0] + + # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) + text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) + image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) + + assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" + assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" + + inputs_dict["txt_ids"] = text_ids_3d + inputs_dict["img_ids"] = image_ids_3d + + with torch.no_grad(): + output_2 = model(**inputs_dict).to_tuple()[0] + + self.assertEqual(output_1.shape, output_2.shape) + self.assertTrue( + torch.allclose(output_1, output_2, atol=1e-5), + msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + ) From 40e94e0f2f69e811ceaad0ea7164af501418306e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 Aug 2024 10:19:02 +0200 Subject: [PATCH 13/14] up --- src/diffusers/models/controlnet_flux.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index e51f528caf31..b29930f81ea2 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -272,6 +272,19 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) From 72d1cf0b120412b961ed8ef7edb1cce398b7d40a Mon Sep 17 00:00:00 2001 From: Joseph Smidt Date: Mon, 19 Aug 2024 19:27:35 +0200 Subject: [PATCH 14/14] adding jsmidt as co-author of this PR for https://github.com/huggingface/diffusers/pull/9133 --- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/models/embeddings.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e9e33c15e48d..fc225567ddc1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -4027,7 +4027,7 @@ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): def __init__(self): deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." - deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message) + deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) super().__init__() diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 873430aceb19..b2f496833176 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -544,6 +544,7 @@ def apply_rotary_emb( class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() self.theta = theta