Skip to content

Commit fc71e97

Browse files
committed
init PipelineEfficiencyFunctionTesterMixin
1 parent cc4f805 commit fc71e97

File tree

3 files changed

+151
-35
lines changed

3 files changed

+151
-35
lines changed

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,23 @@
5252
TEXT_TO_IMAGE_IMAGE_PARAMS,
5353
TEXT_TO_IMAGE_PARAMS,
5454
)
55-
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
55+
from ..test_pipelines_common import (
56+
PipelineEfficiencyFunctionTesterMixin,
57+
PipelineKarrasSchedulerTesterMixin,
58+
PipelineLatentTesterMixin,
59+
PipelineTesterMixin,
60+
)
5661

5762

5863
enable_full_determinism()
5964

6065

6166
class StableDiffusion2PipelineFastTests(
62-
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
67+
PipelineEfficiencyFunctionTesterMixin,
68+
PipelineLatentTesterMixin,
69+
PipelineKarrasSchedulerTesterMixin,
70+
PipelineTesterMixin,
71+
unittest.TestCase,
6372
):
6473
pipeline_class = StableDiffusionPipeline
6574
params = TEXT_TO_IMAGE_PARAMS

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,23 @@
4949
TEXT_TO_IMAGE_IMAGE_PARAMS,
5050
TEXT_TO_IMAGE_PARAMS,
5151
)
52-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
52+
from ..test_pipelines_common import (
53+
PipelineEfficiencyFunctionTesterMixin,
54+
PipelineLatentTesterMixin,
55+
PipelineTesterMixin,
56+
SDXLOptionalComponentsTesterMixin,
57+
)
5358

5459

5560
enable_full_determinism()
5661

5762

5863
class StableDiffusionXLPipelineFastTests(
59-
PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
64+
PipelineEfficiencyFunctionTesterMixin,
65+
PipelineLatentTesterMixin,
66+
PipelineTesterMixin,
67+
SDXLOptionalComponentsTesterMixin,
68+
unittest.TestCase,
6069
):
6170
pipeline_class = StableDiffusionXLPipeline
6271
params = TEXT_TO_IMAGE_PARAMS
@@ -939,37 +948,6 @@ def test_stable_diffusion_xl_save_from_pretrained(self):
939948

940949
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
941950

942-
def test_stable_diffusion_xl_with_fused_qkv_projections(self):
943-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
944-
components = self.get_dummy_components()
945-
sd_pipe = StableDiffusionXLPipeline(**components)
946-
sd_pipe = sd_pipe.to(device)
947-
sd_pipe.set_progress_bar_config(disable=None)
948-
949-
inputs = self.get_dummy_inputs(device)
950-
image = sd_pipe(**inputs).images
951-
original_image_slice = image[0, -3:, -3:, -1]
952-
953-
sd_pipe.fuse_qkv_projections()
954-
inputs = self.get_dummy_inputs(device)
955-
image = sd_pipe(**inputs).images
956-
image_slice_fused = image[0, -3:, -3:, -1]
957-
958-
sd_pipe.unfuse_qkv_projections()
959-
inputs = self.get_dummy_inputs(device)
960-
image = sd_pipe(**inputs).images
961-
image_slice_disabled = image[0, -3:, -3:, -1]
962-
963-
assert np.allclose(
964-
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
965-
), "Fusion of QKV projections shouldn't affect the outputs."
966-
assert np.allclose(
967-
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
968-
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
969-
assert np.allclose(
970-
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
971-
), "Original outputs should match when fused QKV projections are disabled."
972-
973951
def test_pipeline_interrupt(self):
974952
components = self.get_dummy_components()
975953
sd_pipe = StableDiffusionXLPipeline(**components)

tests/pipelines/test_pipelines_common.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,135 @@ def check_same_shape(tensor_list):
5959
return all(shape == shapes[0] for shape in shapes[1:])
6060

6161

62+
class PipelineEfficiencyFunctionTesterMixin:
63+
"""
64+
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
65+
It provides a set of common tests for PyTorch pipeline that inherit from EfficiencyMixin, e.g. vae_slicing, vae_tiling, freeu, etc.
66+
"""
67+
68+
def test_vae_slicing(self):
69+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
70+
components = self.get_dummy_components()
71+
# components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)
72+
pipe = self.pipeline_class(**components)
73+
pipe = pipe.to(device)
74+
pipe.set_progress_bar_config(disable=None)
75+
76+
image_count = 4
77+
78+
inputs = self.get_dummy_inputs(device)
79+
inputs["prompt"] = [inputs["prompt"]] * image_count
80+
output_1 = pipe(**inputs)
81+
82+
# make sure sliced vae decode yields the same result
83+
pipe.enable_vae_slicing()
84+
inputs = self.get_dummy_inputs(device)
85+
inputs["prompt"] = [inputs["prompt"]] * image_count
86+
output_2 = pipe(**inputs)
87+
88+
# there is a small discrepancy at image borders vs. full batch decode
89+
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
90+
91+
def test_vae_tiling(self):
92+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
93+
components = self.get_dummy_components()
94+
95+
# make sure here that pndm scheduler skips prk
96+
components["safety_checker"] = None
97+
pipe = self.pipeline_class(**components)
98+
pipe = pipe.to(device)
99+
pipe.set_progress_bar_config(disable=None)
100+
101+
prompt = "A painting of a squirrel eating a burger"
102+
103+
# Test that tiled decode at 512x512 yields the same result as the non-tiled decode
104+
generator = torch.Generator(device=device).manual_seed(0)
105+
output_1 = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
106+
107+
# make sure tiled vae decode yields the same result
108+
pipe.enable_vae_tiling()
109+
generator = torch.Generator(device=device).manual_seed(0)
110+
output_2 = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
111+
112+
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1
113+
114+
# test that tiled decode works with various shapes
115+
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
116+
for shape in shapes:
117+
zeros = torch.zeros(shape).to(device)
118+
pipe.vae.decode(zeros)
119+
120+
def test_freeu_enabled(self):
121+
components = self.get_dummy_components()
122+
pipe = self.pipeline_class(**components)
123+
pipe = pipe.to(torch_device)
124+
pipe.set_progress_bar_config(disable=None)
125+
126+
prompt = "hey"
127+
output = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
128+
129+
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
130+
output_freeu = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
131+
132+
assert not np.allclose(
133+
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
134+
), "Enabling of FreeU should lead to different results."
135+
136+
def test_freeu_disabled(self):
137+
components = self.get_dummy_components()
138+
pipe = self.pipeline_class(**components)
139+
pipe = pipe.to(torch_device)
140+
pipe.set_progress_bar_config(disable=None)
141+
142+
prompt = "hey"
143+
output = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
144+
145+
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
146+
pipe.disable_freeu()
147+
148+
freeu_keys = {"s1", "s2", "b1", "b2"}
149+
for upsample_block in pipe.unet.up_blocks:
150+
for key in freeu_keys:
151+
assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."
152+
153+
output_no_freeu = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
154+
155+
assert np.allclose(
156+
output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
157+
), "Disabling of FreeU should lead to results similar to the default pipeline results."
158+
159+
def test_fused_qkv_projections(self):
160+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
161+
components = self.get_dummy_components()
162+
pipe = self.pipeline_class(**components)
163+
pipe = pipe.to(device)
164+
pipe.set_progress_bar_config(disable=None)
165+
166+
inputs = self.get_dummy_inputs(device)
167+
image = pipe(**inputs).images
168+
original_image_slice = image[0, -3:, -3:, -1]
169+
170+
pipe.fuse_qkv_projections()
171+
inputs = self.get_dummy_inputs(device)
172+
image = pipe(**inputs).images
173+
image_slice_fused = image[0, -3:, -3:, -1]
174+
175+
pipe.unfuse_qkv_projections()
176+
inputs = self.get_dummy_inputs(device)
177+
image = pipe(**inputs).images
178+
image_slice_disabled = image[0, -3:, -3:, -1]
179+
180+
assert np.allclose(
181+
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
182+
), "Fusion of QKV projections shouldn't affect the outputs."
183+
assert np.allclose(
184+
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
185+
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
186+
assert np.allclose(
187+
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
188+
), "Original outputs should match when fused QKV projections are disabled."
189+
190+
62191
class PipelineLatentTesterMixin:
63192
"""
64193
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.

0 commit comments

Comments
 (0)