From 8a1d13674933294f5b1ecefa021995adf89f2be7 Mon Sep 17 00:00:00 2001
From: sanaka <50254737+HorizonWind2004@users.noreply.github.com>
Date: Tue, 24 Sep 2024 22:50:09 +0800
Subject: [PATCH 1/6] Update transformer_flux.py

---
 src/diffusers/models/transformers/transformer_flux.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index e38efe668c6c..84ab3f253e57 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -161,6 +161,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
@@ -169,10 +170,13 @@ def forward(
         )
 
         # Attention.
+        joint_attention_kwargs = {} if joint_attention_kwargs is None else joint_attention_kwargs
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -429,6 +433,7 @@ def forward(
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
         else:
+            joint_attention_kwargs = {}
             lora_scale = 1.0
 
         if USE_PEFT_BACKEND:
@@ -497,6 +502,7 @@ def custom_forward(*inputs):
                 encoder_hidden_states=encoder_hidden_states,
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
             )
 
         # controlnet residual
@@ -533,6 +539,7 @@ def custom_forward(*inputs):
                 hidden_states=hidden_states,
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
             )
 
         # controlnet residual

From b12ee928705a6f528cd7311acd482673ae80e072 Mon Sep 17 00:00:00 2001
From: sanaka <50254737+HorizonWind2004@users.noreply.github.com>
Date: Tue, 24 Sep 2024 22:52:09 +0800
Subject: [PATCH 2/6] Update transformer_flux.py

---
 src/diffusers/models/transformers/transformer_flux.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 84ab3f253e57..f4ae0ee31d09 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -433,7 +433,6 @@ def forward(
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
         else:
-            joint_attention_kwargs = {}
             lora_scale = 1.0
 
         if USE_PEFT_BACKEND:

From 950976465addececeaaab819f33921c7babfa04f Mon Sep 17 00:00:00 2001
From: sanaka <50254737+HorizonWind2004@users.noreply.github.com>
Date: Wed, 25 Sep 2024 01:17:52 +0800
Subject: [PATCH 3/6] Update transformer_flux.py

---
 src/diffusers/models/transformers/transformer_flux.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index f4ae0ee31d09..0c2fb30ea28a 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -83,11 +83,12 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs if joint_attention_kwargs is not None else {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
@@ -161,7 +162,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
-        joint_attention_kwargs=None,
+        joint_attention_kwargs={},
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
@@ -170,7 +171,6 @@ def forward(
         )
 
         # Attention.
-        joint_attention_kwargs = {} if joint_attention_kwargs is None else joint_attention_kwargs
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,

From 51040989504870a8560866bcef7ff0e382954cc6 Mon Sep 17 00:00:00 2001
From: sanaka <50254737+HorizonWind2004@users.noreply.github.com>
Date: Thu, 26 Sep 2024 20:28:49 +0800
Subject: [PATCH 4/6] Update transformer_flux.py

fix a little bug
---
 src/diffusers/models/transformers/transformer_flux.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 0c2fb30ea28a..99eace66a83f 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -88,10 +88,11 @@ def forward(
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-        joint_attention_kwargs = joint_attention_kwargs if joint_attention_kwargs is not None else {}
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -162,14 +163,14 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
-        joint_attention_kwargs={},
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,

From 9ed52843e62d92ab97301622a31fd717715410fb Mon Sep 17 00:00:00 2001
From: sanaka <50254737+HorizonWind2004@users.noreply.github.com>
Date: Tue, 8 Oct 2024 19:52:27 +0800
Subject: [PATCH 5/6] Update transformer_flux.py

---
 src/diffusers/models/transformers/transformer_flux.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 99eace66a83f..2ee1e7c76b39 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -177,7 +177,6 @@ def forward(
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
-
         )
 
         # Process attention outputs for the `hidden_states`.
From b6128a453c07fd61c3e06c6864120fac7eef9337 Mon Sep 17 00:00:00 2001
From: sanaka
Date: Tue, 8 Oct 2024 20:06:18 +0800
Subject: [PATCH 6/6] fix the quality

---
 src/diffusers/models/transformers/transformer_flux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 2ee1e7c76b39..6238ab8044bb 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -92,7 +92,7 @@ def forward(
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
-            **joint_attention_kwargs
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
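Taken together, the six patches thread the joint_attention_kwargs dict from FluxTransformer2DModel.forward through FluxTransformerBlock and FluxSingleTransformerBlock into each block's self.attn(...) call, so extra keyword arguments can reach the attention processors (the "scale" key is still popped in forward as the LoRA scale and is not forwarded). The sketch below is illustrative only and not part of the patch series: ScaledFluxAttnProcessor, the attn_scale kwarg, and the placeholder tensors (latents, prompt_embeds, pooled_prompt_embeds, timestep, img_ids, txt_ids) are assumptions; in diffusers, Attention.forward only forwards kwargs that the installed processor's __call__ signature actually accepts.

# Illustrative sketch (not part of the patch series): route an extra kwarg through
# joint_attention_kwargs down to the Flux attention processors. Assumes `transformer`
# is an already-loaded FluxTransformer2DModel and that latents, prompt_embeds,
# pooled_prompt_embeds, timestep, img_ids and txt_ids were prepared elsewhere
# (e.g. by FluxPipeline); attn_scale is a made-up kwarg used only for illustration.
from diffusers.models.attention_processor import FluxAttnProcessor2_0


class ScaledFluxAttnProcessor(FluxAttnProcessor2_0):
    """FluxAttnProcessor2_0 variant that accepts an extra attn_scale kwarg."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None, attn_scale=1.0):
        out = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        # Double-stream blocks return (hidden_states, encoder_hidden_states);
        # single-stream blocks return a single tensor.
        if isinstance(out, tuple):
            return tuple(o * attn_scale for o in out)
        return out * attn_scale


transformer.set_attn_processor(ScaledFluxAttnProcessor())

# With the patches applied, the kwarg flows:
# FluxTransformer2DModel.forward -> transformer blocks -> self.attn -> processor __call__.
sample = transformer(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    joint_attention_kwargs={"attn_scale": 0.8},
).sample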