201 changes: 179 additions & 22 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -38,23 +39,31 @@
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle

from mindiesd import attention_forward as mindie_sd_attn_forward

STREAM_VECTOR = torch.npu.Stream()
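# Side NPU stream; the context-parallel path below uses it to overlap Q/K normalization,
# RoPE, and the all-to-all launches with work queued on the default stream.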

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

if torch.distributed.is_available():
import torch.distributed._functional_collectives as funcol

def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
from ..attention_dispatch import npu_fusion_attention
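# With cal_q=False the Q/K projections are skipped (the caller passes precomputed query/key)
# and only the value and encoder projections are computed.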
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
if cal_q:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)

return query, key, value, encoder_query, encoder_key, encoder_value

if cal_q:
return query, key, value, encoder_query, encoder_key, encoder_value
else:
return value, encoder_query, encoder_key, encoder_value

def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
@@ -66,11 +75,38 @@ def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_
return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
if attn.fused_projections and cal_q:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)

return _get_projections(attn, hidden_states, encoder_hidden_states, cal_q)

def _wait_tensor(tensor):
if isinstance(tensor, funcol.AsyncCollectiveTensor):
tensor = tensor.wait()
return tensor

def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
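# Equal-split all-to-all across the group: flatten, exchange, restore the shape, then wait
# for the async collective to complete.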
shape = x.shape
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
x = x.reshape(shape)
x = _wait_tensor(x)
return x

def ulysses_preforward(
x: torch.Tensor,
group,
world_size,
B,
S_LOCAL,
H,
D,
H_LOCAL
):
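# Ulysses pre-attention exchange: split the heads into world_size groups, move the
# destination-rank dimension to the front, and launch an async all-to-all; the caller
# waits on the returned tensor later.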
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
return x

class FluxAttnProcessor:
_attention_backend = None
@@ -87,10 +123,26 @@ def __call__(
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
pre_query: Optional[torch.Tensor] = None,
pre_key: Optional[torch.Tensor] = None,
cal_q=True
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
if hasattr(self._parallel_config, "context_parallel_config") and \
self._parallel_config.context_parallel_config is not None:

return self._context_parallel_forward_qkv(
attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q
)

qkv_proj_out = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states, cal_q
)
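# cal_q=False: Q/K come in as pre_query/pre_key; only V and the encoder projections are recomputed.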
if cal_q:
query, key, value, encoder_query, encoder_key, encoder_value = qkv_proj_out
else:
value, encoder_query, encoder_key, encoder_value = qkv_proj_out
query = pre_query
key = pre_key

query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
@@ -138,6 +190,106 @@ def __call__(
else:
return hidden_states

def _context_parallel_forward_qkv(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
pre_query: Optional[torch.Tensor] = None,
pre_key: Optional[torch.Tensor] = None,
cal_q=True
) -> torch.Tensor:
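# Ulysses-style sequence parallelism: each rank holds a slice of the sequence, an all-to-all
# trades sequence shards for head shards so every rank attends over the full sequence with
# heads // world_size heads, and a second all-to-all restores the sequence-sharded layout.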

ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh
world_size = self._parallel_config.context_parallel_config.ulysses_degree
group = ulysses_mesh.get_group()

ev_q = torch.npu.Event()
ev_k = torch.npu.Event()
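# The events mark when the Q and K projections finish on the default stream so the side
# stream can consume them without a full device synchronization.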

query = attn.to_q(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
ev_q.record()
key = attn.to_k(hidden_states)
key = key.unflatten(-1, (attn.heads, -1))
ev_k.record()

value = attn.to_v(hidden_states)
value = value.unflatten(-1, (attn.heads, -1))
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
value = torch.cat([encoder_value, value], dim=1)

with torch.npu.stream(STREAM_VECTOR):
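# Q/K normalization, RoPE, and the all-to-all launches run on the side stream so they can
# overlap with the value projection still queued on the default stream.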
ev_q.wait()
query = attn.norm_q(query)
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
query = torch.cat([encoder_query, query], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)

B, S_Q_LOCAL, H, D = query.shape
H_LOCAL = H // world_size
query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL)

ev_k.wait()
key = attn.norm_k(key)
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_key = attn.norm_added_k(encoder_key)
key = torch.cat([encoder_key, key], dim=1)
if image_rotary_emb is not None:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

_, S_KV_LOCAL, _, _ = key.shape
key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)

value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)

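# Wait for the async exchanges, then rearrange each tensor so this rank holds the full
# (text + image) sequence for its local head shard in the layout the attention kernel expects.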
query_all = _wait_tensor(query_all)
query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()

key_all = _wait_tensor(key_all)
key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()

value_all = _wait_tensor(value_all)
value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()

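# Fused attention over the full sequence with the local head shard, via the MindIE SD kernel.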
out = mindie_sd_attn_forward(
query_all,
key_all,
value_all,
opt_mode="manual",
op_type="ascend_laser_attention",
layout="BNSD"
)

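# Reverse all-to-all: trade the full sequence back for the full set of heads, restoring the
# sequence-sharded [B, S_local, H, D] layout before the output projections.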
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()

hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states
else:
return hidden_states


class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
Expand Down Expand Up @@ -633,6 +785,7 @@ def __init__(
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

self.gradient_checkpointing = False
self.image_rotary_emb = None

def forward(
self,
@@ -717,11 +870,15 @@ def forward(
img_ids = img_ids[0]

ids = torch.cat((txt_ids, img_ids), dim=0)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
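# Compute the rotary embeddings once and cache them; later denoising steps reuse the cache
# (this assumes txt_ids/img_ids are identical across steps).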
if self.image_rotary_emb is None:
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
self.image_rotary_emb = (
freqs_cos.npu().to(hidden_states.dtype),
freqs_sin.npu().to(hidden_states.dtype)
)
else:
self.image_rotary_emb = self.pos_embed(ids)

if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
@@ -735,7 +892,7 @@ def forward(
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
self.image_rotary_emb,
joint_attention_kwargs,
)

@@ -744,7 +901,7 @@ def forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
image_rotary_emb=self.image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

@@ -767,7 +924,7 @@ def forward(
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
self.image_rotary_emb,
joint_attention_kwargs,
)

@@ -776,7 +933,7 @@ def forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
image_rotary_emb=self.image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
