
Commit 65a3bf5: Apply style fixes
Parent commit: ed33194

File tree: 3 files changed (+33 lines, -10 lines)


src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 21 additions & 5 deletions
@@ -160,7 +160,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}
 
@@ -224,7 +229,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
@@ -238,7 +247,9 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs,
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -492,7 +503,10 @@ def forward(
 
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs,
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    attention_kwargs=attention_kwargs,
                 )
 
         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -509,7 +523,9 @@ def forward(
                 )
 
             else:
-                combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs)
+                combined_hidden_states = block(
+                    hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+                )
 
         hidden_states = combined_hidden_states[:, encoder_seq_len:]
 
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 1 addition & 3 deletions
@@ -564,9 +564,7 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        lora_scale = (
-            self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
-        )
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
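A small hedged sketch of the lookup the collapsed one-liner performs, assuming `attention_kwargs` is either `None` or a dict such as `{"scale": 0.7}`:

```python
attention_kwargs = {"scale": 0.7}  # hypothetical caller-supplied value
lora_scale = attention_kwargs.get("scale", None) if attention_kwargs is not None else None
assert lora_scale == 0.7

attention_kwargs = None  # nothing supplied by the caller
lora_scale = attention_kwargs.get("scale", None) if attention_kwargs is not None else None
assert lora_scale is None
```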

tests/lora/utils.py

Lines changed: 11 additions & 2 deletions
@@ -2190,7 +2190,14 @@ def test_correct_lora_configs_with_different_ranks(self):
 
     @property
    def supports_text_encoder_lora(self):
-        return len({"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(self.pipeline_class._lora_loadable_modules)) != 0
+        return (
+            len(
+                {"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(
+                    self.pipeline_class._lora_loadable_modules
+                )
+            )
+            != 0
+        )
 
     def test_layerwise_casting_inference_denoiser(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
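Standalone sketch of the set-intersection check that the wrapped `supports_text_encoder_lora` return expresses; the `_lora_loadable_modules` value here is a hypothetical example, real pipeline classes define their own list:

```python
_lora_loadable_modules = ["transformer", "text_encoder"]  # assumed example value

supports_text_encoder_lora = (
    len({"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(_lora_loadable_modules)) != 0
)
print(supports_text_encoder_lora)  # True: "text_encoder" appears in both sets
```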
@@ -2249,7 +2256,9 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
         pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
         pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
 
-        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+        pipe_float8_e4m3_bf16 = initialize_pipeline(
+            storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
+        )
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
 
     @require_peft_version_greater("0.14.0")
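The wrapped `initialize_pipeline(...)` call pairs a float8 storage dtype with a bfloat16 compute dtype. Below is a generic, plain-PyTorch sketch of that storage/compute split (not the diffusers layerwise-casting helper; assumes a PyTorch build with `torch.float8_e4m3fn` support):

```python
import torch

# Weights kept in the compact storage dtype.
weight_fp8 = torch.randn(4, 4).to(torch.float8_e4m3fn)
x = torch.randn(2, 4, dtype=torch.bfloat16)

# float8 tensors are not directly usable in this matmul, so upcast to the compute dtype first.
out = x @ weight_fp8.to(torch.bfloat16).t()
print(out.dtype)  # torch.bfloat16
```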
