From f26edf51b36c7cf6cdb1a9bd703f151dbf6921b9 Mon Sep 17 00:00:00 2001 From: Jinay Jain Date: Wed, 28 Feb 2024 22:54:36 -0500 Subject: [PATCH 1/2] [bug] Fix float/int guidance scale not working in `StableVideoDiffusionPipeline` --- .../stable_video_diffusion/pipeline_stable_video_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index f53ebbafee2e..d882c708f3b9 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -299,7 +299,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): if isinstance(self.guidance_scale, (int, float)): - return self.guidance_scale + return self.guidance_scale > 1 return self.guidance_scale.max() > 1 @property From 8409bc0790d35f026d1bead9658a3c95314c5aea Mon Sep 17 00:00:00 2001 From: Jinay Jain Date: Mon, 4 Mar 2024 19:41:24 -0500 Subject: [PATCH 2/2] Add test to disable CFG on SVD --- .../test_stable_video_diffusion.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index 5a3c79422c2b..33cf4c72863b 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -496,6 +496,22 @@ def test_xformers_attention_forwardGenerator_pass(self): max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results") + def test_disable_cfg(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + inputs["max_guidance_scale"] = 1.0 + output = pipe(**inputs).frames + self.assertEqual(len(output.shape), 5) + @slow @require_torch_gpu