From fe5bb7c1bd9ae2ef5d350790b09348a882022fec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 12 Jan 2025 18:04:17 +0530 Subject: [PATCH 1/2] improve flux true cfg condition --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 33154db54c73..f5716dc9c8ea 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -790,7 +790,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, From f5b23030d363949f5e173b2f602d45c2ff5409e1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 12 Jan 2025 18:14:18 +0530 Subject: [PATCH 2/2] add test --- tests/pipelines/flux/test_pipeline_flux.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index ab36333c4056..addc29e14670 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -209,6 +209,17 @@ def test_flux_image_output_shape(self): output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + def test_flux_true_cfg(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("generator") + + no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + inputs["negative_prompt"] = "bad quality" + inputs["true_cfg_scale"] = 2.0 + true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + assert not np.allclose(no_true_cfg_out, true_cfg_out) + @nightly @require_big_gpu_with_torch_cuda