Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 53 additions & 127 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,81 +1695,6 @@ def __call__(
return hidden_states


# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

return hidden_states


class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

Expand All @@ -1785,16 +1710,7 @@ def __call__(
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size = encoder_hidden_states.shape[0]
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

# `sample` projections.
query = attn.to_q(hidden_states)
Expand All @@ -1813,59 +1729,58 @@ def __call__(
if attn.norm_k is not None:
key = attn.norm_k(key)

# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
from .embeddings import apply_rotary_emb

query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states
return hidden_states, encoder_hidden_states
else:
return hidden_states


class XFormersAttnAddedKVProcessor:
Expand Down Expand Up @@ -4105,6 +4020,17 @@ def __init__(self):
pass


class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""

def __init__(self):
deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
super().__init__()


ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
Expand Down
22 changes: 17 additions & 5 deletions src/diffusers/models/controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
Expand Down Expand Up @@ -272,8 +272,20 @@ def forward(
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
ids = torch.cat((txt_ids, img_ids), dim=1)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]

ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)

block_samples = ()
Expand Down
42 changes: 35 additions & 7 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed(
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
Expand All @@ -468,6 +469,8 @@ def get_1d_rotary_pos_embed(
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concateanted with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
Expand All @@ -476,19 +479,19 @@ def get_1d_rotary_pos_embed(
if isinstance(pos, int):
pos = np.arange(pos)
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
return freqs_cis


Expand Down Expand Up @@ -540,6 +543,31 @@ def apply_rotary_emb(
return x_out.type_as(x)


class FluxPosEmbed(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a reference to the original BFL inference code?

# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim

def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.squeeze().float().cpu().numpy()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul the results for flux are identical with this refactor
the only other difference is here, where we downcast the dtype for mps see #9133 for more details

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aye, thanks!

for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin


class TimestepEmbedding(nn.Module):
def __init__(
self,
Expand Down
Loading