From 26c64d52672d1e8216977a40c236e880193950bb Mon Sep 17 00:00:00 2001 From: zhangtao0408 <365968531@qq.com> Date: Thu, 27 Nov 2025 09:59:10 +0800 Subject: [PATCH 1/3] Enhance FluxAttention with optional dual stream calculation Refactor FluxAttention to include optional dual stream calculation and integrate mindiesd attention forward. --- .../models/transformers/transformer_flux.py | 304 ++++++++++++++++-- 1 file changed, 281 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 16c526f437f2..ff230561efee 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +import math import inspect from typing import Any, Dict, List, Optional, Tuple, Union @@ -38,13 +39,25 @@ from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from mindiesd import attention_forward as mindie_sd_attn_forward -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +STREAM_VECTOR = torch.npu.Stream() -def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if torch.distributed.is_available(): + import torch.distributed._functional_collectives as funcol +current_stream = torch.npu.current_stream() +stream2 = torch.npu.Stream() +current_event = torch.npu.Event() +event2 = torch.npu.Event() + +from ..attention_dispatch import npu_fusion_attention +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True): + if cal_q: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) encoder_query = encoder_key = encoder_value = None @@ -52,9 +65,10 @@ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) - - return query, key, value, encoder_query, encoder_key, encoder_value - + if cal_q: + return query, key, value, encoder_query, encoder_key, encoder_value + else: + return value, encoder_query, encoder_key, encoder_value def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) @@ -66,11 +80,38 @@ def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_ return query, key, value, encoder_query, encoder_key, encoder_value -def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - if attn.fused_projections: +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True): + if attn.fused_projections and cal_q: return _get_fused_projections(attn, hidden_states, encoder_hidden_states) - return _get_projections(attn, hidden_states, encoder_hidden_states) - + return _get_projections(attn, hidden_states, encoder_hidden_states, cal_q) + +def _wait_tensor(tensor): + if isinstance(tensor, funcol.AsyncCollectiveTensor): + tensor = tensor.wait() + return tensor + +def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: + shape = x.shape + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + x = x.reshape(shape) + x = _wait_tensor(x) + return x + +def ulysses_preforward( + x: torch.Tensor, + group, + world_size, + B, + S_LOCAL, + H, + D, + H_LOCAL +): + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + return x class FluxAttnProcessor: _attention_backend = None @@ -87,10 +128,26 @@ def __call__( encoder_hidden_states: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True ) -> torch.Tensor: - query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( - attn, hidden_states, encoder_hidden_states + if hasattr(self._parallel_config, "context_parallel_config") and \ + self._parallel_config.context_parallel_config is not None: + + return self._context_parallel_forward_qkv( + attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q + ) + + qkv_proj_out = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states, cal_q ) + if cal_q: + query, key, value, encoder_query, encoder_key, encoder_value = qkv_proj_out + else: + value, encoder_query, encoder_key, encoder_value = qkv_proj_out + query = pre_query + key = pre_key query = query.unflatten(-1, (attn.heads, -1)) key = key.unflatten(-1, (attn.heads, -1)) @@ -138,6 +195,202 @@ def __call__( else: return hidden_states + def _context_parallel_forward( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + _, S_Q_LOCAL, _, _ = query.shape + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + out = npu_fusion_attention( + query_all, + key_all, + value_all, + H_LOCAL, # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(D), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + def _context_parallel_forward_qkv( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + B, S_Q_LOCAL, H, D = query.shape + H_LOCAL = H // world_size + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + _, S_KV_LOCAL, _, _ = key.shape + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + out = mindie_sd_attn_forward( + query_all, + key_all, + value_all, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + class FluxIPAdapterAttnProcessor(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" @@ -633,6 +886,7 @@ def __init__( self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False + self.image_rotary_emb = None def forward( self, @@ -717,11 +971,15 @@ def forward( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) - if is_torch_npu_available(): - freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = self.pos_embed(ids) + if self.image_rotary_emb is None: + if is_torch_npu_available(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + self.image_rotary_emb = ( + freqs_cos.npu().to(hidden_states.dtype), + freqs_sin.npu().to(hidden_states.dtype) + ) + else: + self.image_rotary_emb = self.pos_embed(ids) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") @@ -735,7 +993,7 @@ def forward( hidden_states, encoder_hidden_states, temb, - image_rotary_emb, + self.image_rotary_emb, joint_attention_kwargs, ) @@ -744,7 +1002,7 @@ def forward( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, - image_rotary_emb=image_rotary_emb, + image_rotary_emb=self.image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) @@ -767,7 +1025,7 @@ def forward( hidden_states, encoder_hidden_states, temb, - image_rotary_emb, + self.image_rotary_emb, joint_attention_kwargs, ) @@ -776,7 +1034,7 @@ def forward( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, - image_rotary_emb=image_rotary_emb, + image_rotary_emb=self.image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) From b5dfc0b6387dcbe36e341b9189d3c2ba36377061 Mon Sep 17 00:00:00 2001 From: zhangtao0408 <365968531@qq.com> Date: Thu, 27 Nov 2025 10:03:17 +0800 Subject: [PATCH 2/3] Remove _context_parallel_forward method Removed the _context_parallel_forward method, which handled context parallel forward operations for attention mechanisms. --- .../models/transformers/transformer_flux.py | 96 ------------------- 1 file changed, 96 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index ff230561efee..d90a088fba84 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -195,102 +195,6 @@ def __call__( else: return hidden_states - def _context_parallel_forward( - self, - attn: "FluxAttention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - pre_query: Optional[torch.Tensor] = None, - pre_key: Optional[torch.Tensor] = None, - cal_q=True - ) -> torch.Tensor: - - ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh - world_size = self._parallel_config.context_parallel_config.ulysses_degree - group = ulysses_mesh.get_group() - - value = attn.to_v(hidden_states) - value = value.unflatten(-1, (attn.heads, -1)) - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - value = torch.cat([encoder_value, value], dim=1) - - B, S_KV_LOCAL, H, D = value.shape - H_LOCAL = H // world_size - value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) - - query = attn.to_q(hidden_states) - query = query.unflatten(-1, (attn.heads, -1)) - query = attn.norm_q(query) - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_query = attn.norm_added_q(encoder_query) - query = torch.cat([encoder_query, query], dim=1) - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - _, S_Q_LOCAL, _, _ = query.shape - query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) - - key = attn.to_k(hidden_states) - key = key.unflatten(-1, (attn.heads, -1)) - key = attn.norm_k(key) - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_key = attn.norm_added_k(encoder_key) - key = torch.cat([encoder_key, key], dim=1) - if image_rotary_emb is not None: - key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) - - value_all = _wait_tensor(value_all) - value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - query_all = _wait_tensor(query_all) - query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - - key_all = _wait_tensor(key_all) - key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - out = npu_fusion_attention( - query_all, - key_all, - value_all, - H_LOCAL, # num_heads - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(D), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - out = out.transpose(1, 2).contiguous() - out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() - out = _all_to_all_single(out, group) - hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() - - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - def _context_parallel_forward_qkv( self, attn: "FluxAttention", From 200123686189f738cebc9cf7b4d085dd23c56054 Mon Sep 17 00:00:00 2001 From: zhangtao0408 <365968531@qq.com> Date: Thu, 27 Nov 2025 10:04:43 +0800 Subject: [PATCH 3/3] Clean up unused variables in transformer_flux.py Remove unused stream and event variables in transformer_flux.py --- src/diffusers/models/transformers/transformer_flux.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index d90a088fba84..9b4fbd53917d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -43,15 +43,10 @@ STREAM_VECTOR = torch.npu.Stream() - logger = logging.get_logger(__name__) # pylint: disable=invalid-name if torch.distributed.is_available(): import torch.distributed._functional_collectives as funcol -current_stream = torch.npu.current_stream() -stream2 = torch.npu.Stream() -current_event = torch.npu.Event() -event2 = torch.npu.Event() from ..attention_dispatch import npu_fusion_attention def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):